17
17
#include " mlir/Conversion/FuncToSPIRV/FuncToSPIRV.h"
18
18
#include " mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h"
19
19
#include " mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h"
20
+ #include " mlir/Dialect/Func/IR/FuncOps.h"
20
21
#include " mlir/Dialect/GPU/IR/GPUDialect.h"
21
22
#include " mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
22
23
#include " mlir/Dialect/SPIRV/IR/SPIRVOps.h"
@@ -54,22 +55,47 @@ void GPUToSPIRVPass::runOnOperation() {
54
55
55
56
SmallVector<Operation *, 1 > gpuModules;
56
57
OpBuilder builder (context);
58
+
59
+ auto targetEnvSupportsKernelCapability = [](gpu::GPUModuleOp moduleOp) {
60
+ Operation *gpuModule = moduleOp.getOperation ();
61
+ auto targetAttr = spirv::lookupTargetEnvOrDefault (gpuModule);
62
+ spirv::TargetEnv targetEnv (targetAttr);
63
+ return targetEnv.allows (spirv::Capability::Kernel);
64
+ };
65
+
57
66
module.walk ([&](gpu::GPUModuleOp moduleOp) {
58
67
// Clone each GPU kernel module for conversion, given that the GPU
59
68
// launch op still needs the original GPU kernel module.
60
- builder.setInsertionPoint (moduleOp.getOperation ());
69
+ // For Vulkan Shader capabilities, we insert the newly converted SPIR-V
70
+ // module right after the original GPU module, as that's the expectation of
71
+ // the in-tree Vulkan runner.
72
+ // For OpenCL Kernel capabilities, we insert the newly converted SPIR-V
73
+ // module inside the original GPU module, as that's the expectation of the
74
+ // normal GPU compilation pipeline.
75
+ if (targetEnvSupportsKernelCapability (moduleOp)) {
76
+ builder.setInsertionPoint (moduleOp.getBody (),
77
+ moduleOp.getBody ()->begin ());
78
+ } else {
79
+ builder.setInsertionPoint (moduleOp.getOperation ());
80
+ }
61
81
gpuModules.push_back (builder.clone (*moduleOp.getOperation ()));
62
82
});
63
83
64
84
// Run conversion for each module independently as they can have different
65
85
// TargetEnv attributes.
66
86
for (Operation *gpuModule : gpuModules) {
87
+ spirv::TargetEnvAttr targetAttr =
88
+ spirv::lookupTargetEnvOrDefault (gpuModule);
89
+
67
90
// Map MemRef memory space to SPIR-V storage class first if requested.
68
91
if (mapMemorySpace) {
69
92
std::unique_ptr<ConversionTarget> target =
70
93
spirv::getMemorySpaceToStorageClassTarget (*context);
71
94
spirv::MemorySpaceToStorageClassMap memorySpaceMap =
72
- spirv::mapMemorySpaceToVulkanStorageClass;
95
+ targetEnvSupportsKernelCapability (
96
+ dyn_cast<gpu::GPUModuleOp>(gpuModule))
97
+ ? spirv::mapMemorySpaceToOpenCLStorageClass
98
+ : spirv::mapMemorySpaceToVulkanStorageClass;
73
99
spirv::MemorySpaceToStorageClassConverter converter (memorySpaceMap);
74
100
75
101
RewritePatternSet patterns (context);
@@ -79,7 +105,6 @@ void GPUToSPIRVPass::runOnOperation() {
79
105
return signalPassFailure ();
80
106
}
81
107
82
- auto targetAttr = spirv::lookupTargetEnvOrDefault (gpuModule);
83
108
std::unique_ptr<ConversionTarget> target =
84
109
SPIRVConversionTarget::get (targetAttr);
85
110
@@ -108,6 +133,25 @@ void GPUToSPIRVPass::runOnOperation() {
108
133
if (failed (applyFullConversion (gpuModule, *target, std::move (patterns))))
109
134
return signalPassFailure ();
110
135
}
136
+
137
+ // For OpenCL, the gpu.func op in the original gpu.module op needs to be
138
+ // replaced with an empty func.func op with the same arguments as the gpu.func
139
+ // op. The func.func op needs the gpu.kernel attribute set.
140
+ module.walk ([&](gpu::GPUModuleOp moduleOp) {
141
+ if (targetEnvSupportsKernelCapability (moduleOp)) {
142
+ moduleOp.walk ([&](gpu::GPUFuncOp funcOp) {
143
+ builder.setInsertionPoint (funcOp);
144
+ auto newFuncOp = builder.create <func::FuncOp>(
145
+ funcOp.getLoc (), funcOp.getName (), funcOp.getFunctionType ());
146
+ auto entryBlock = newFuncOp.addEntryBlock ();
147
+ builder.setInsertionPointToEnd (entryBlock);
148
+ builder.create <func::ReturnOp>(funcOp.getLoc ());
149
+ newFuncOp->setAttr (gpu::GPUDialect::getKernelFuncAttrName (),
150
+ builder.getUnitAttr ());
151
+ funcOp.erase ();
152
+ });
153
+ }
154
+ });
111
155
}
112
156
113
157
} // namespace
0 commit comments