Skip to content

Commit a8a301a

Browse files
committed
[mlir][spirv] Support MemRef in convert-to-spirv pass
1 parent f9f0ae1 commit a8a301a

File tree

4 files changed

+51
-0
lines changed

4 files changed

+51
-0
lines changed

mlir/lib/Conversion/ConvertToSPIRV/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ add_mlir_conversion_library(MLIRConvertToSPIRVPass
1717
MLIRFuncToSPIRV
1818
MLIRIndexToSPIRV
1919
MLIRIR
20+
MLIRMemRefToSPIRV
2021
MLIRPass
2122
MLIRRewrite
2223
MLIRSCFToSPIRV

mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include "mlir/Conversion/ArithToSPIRV/ArithToSPIRV.h"
1111
#include "mlir/Conversion/FuncToSPIRV/FuncToSPIRV.h"
1212
#include "mlir/Conversion/IndexToSPIRV/IndexToSPIRV.h"
13+
#include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h"
1314
#include "mlir/Conversion/SCFToSPIRV/SCFToSPIRV.h"
1415
#include "mlir/Conversion/UBToSPIRV/UBToSPIRV.h"
1516
#include "mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h"
@@ -62,12 +63,24 @@ struct ConvertToSPIRVPass final
6263
RewritePatternSet patterns(context);
6364
ScfToSPIRVContext scfToSPIRVContext;
6465

66+
// Map MemRef memory space to SPIR-V storage class.
67+
spirv::TargetEnv targetEnv(targetAttr);
68+
bool targetEnvSupportsKernelCapability =
69+
targetEnv.allows(spirv::Capability::Kernel);
70+
spirv::MemorySpaceToStorageClassMap memorySpaceMap =
71+
targetEnvSupportsKernelCapability
72+
? spirv::mapMemorySpaceToOpenCLStorageClass
73+
: spirv::mapMemorySpaceToVulkanStorageClass;
74+
spirv::MemorySpaceToStorageClassConverter converter(memorySpaceMap);
75+
spirv::convertMemRefTypesAndAttrs(op, converter);
76+
6577
// Populate patterns for each dialect.
6678
arith::populateCeilFloorDivExpandOpsPatterns(patterns);
6779
arith::populateArithToSPIRVPatterns(typeConverter, patterns);
6880
populateBuiltinFuncToSPIRVPatterns(typeConverter, patterns);
6981
populateFuncToSPIRVPatterns(typeConverter, patterns);
7082
index::populateIndexToSPIRVPatterns(typeConverter, patterns);
83+
populateMemRefToSPIRVPatterns(typeConverter, patterns);
7184
populateVectorToSPIRVPatterns(typeConverter, patterns);
7285
populateSCFToSPIRVPatterns(typeConverter, scfToSPIRVContext, patterns);
7386
ub::populateUBToSPIRVConversionPatterns(typeConverter, patterns);
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
// RUN: mlir-opt -convert-to-spirv="run-signature-conversion=false run-vector-unrolling=false" -cse -split-input-file %s | FileCheck %s
2+
3+
module attributes {
4+
spirv.target_env = #spirv.target_env<
5+
#spirv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>
6+
} {
7+
8+
// CHECK-LABEL: @load_store_float_rank_zero
9+
// CHECK-SAME: %[[ARG0:.*]]: !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0])>, StorageBuffer>, %[[ARG1:.*]]: !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0])>, StorageBuffer>
10+
// CHECK: %[[CST0:.*]] = spirv.Constant 0 : i32
11+
// CHECK: %[[AC0:.*]] = spirv.AccessChain %[[ARG0]][%[[CST0]], %[[CST0]]] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0])>, StorageBuffer>, i32, i32
12+
// CHECK: %[[LOAD:.*]] = spirv.Load "StorageBuffer" %[[AC0]] : f32
13+
// CHECK: %[[AC1:.*]] = spirv.AccessChain %[[ARG1]][%[[CST0]], %[[CST0]]] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0])>, StorageBuffer>, i32, i32
14+
// CHECK: spirv.Store "StorageBuffer" %[[AC1]], %[[LOAD]] : f32
15+
// CHECK: spirv.Return
16+
func.func @load_store_float_rank_zero(%arg0: memref<f32>, %arg1: memref<f32>) {
17+
%0 = memref.load %arg0[] : memref<f32>
18+
memref.store %0, %arg1[] : memref<f32>
19+
return
20+
}
21+
22+
// CHECK-LABEL: @load_store_int_rank_one
23+
// CHECK-SAME: %[[ARG0:.*]]: !spirv.ptr<!spirv.struct<(!spirv.array<4 x i32, stride=4> [0])>, StorageBuffer>, %[[ARG1:.*]]: !spirv.ptr<!spirv.struct<(!spirv.array<4 x i32, stride=4> [0])>, StorageBuffer>, %[[ARG2:.*]]: i32
24+
// CHECK: %[[CST0:.*]] = spirv.Constant 0 : i32
25+
// CHECK: %[[AC0:.*]] = spirv.AccessChain %[[ARG0]][%[[CST0]], %[[ARG2]]] : !spirv.ptr<!spirv.struct<(!spirv.array<4 x i32, stride=4> [0])>, StorageBuffer>, i32, i32
26+
// CHECK: %[[LOAD:.*]] = spirv.Load "StorageBuffer" %[[AC0]] : i32
27+
// CHECK: %[[AC1:.*]] = spirv.AccessChain %[[ARG1]][%[[CST0]], %[[ARG2]]] : !spirv.ptr<!spirv.struct<(!spirv.array<4 x i32, stride=4> [0])>, StorageBuffer>, i32, i32
28+
// CHECK: spirv.Store "StorageBuffer" %[[AC1]], %[[LOAD]] : i32
29+
// CHECK: spirv.Return
30+
func.func @load_store_int_rank_one(%arg0: memref<4xi32>, %arg1: memref<4xi32>, %arg2 : index) {
31+
%0 = memref.load %arg0[%arg2] : memref<4xi32>
32+
memref.store %0, %arg1[%arg2] : memref<4xi32>
33+
return
34+
}
35+
36+
} // end module

utils/bazel/llvm-project-overlay/mlir/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8316,6 +8316,7 @@ cc_library(
83168316
":FuncToSPIRV",
83178317
":IR",
83188318
":IndexToSPIRV",
8319+
":MemRefToSPIRV",
83198320
":Pass",
83208321
":Rewrite",
83218322
":SCFToSPIRV",

0 commit comments

Comments
 (0)