@@ -101,7 +101,7 @@ class ConvertOpToGpuRuntimeCallPattern : public ConvertOpToLLVMPattern<OpTy> {
101
101
// Call descriptor for the runtime entry point named by the string literal
// below. The initializer appears to be {name, result type, argument types}
// (FunctionCallBuilder's definition is not visible here -- confirm).
// The result is an llvmPointerType annotated as the module. This diff hunk
// replaces the one-element argument list (the cubin pointer) with a
// two-element list that adds an llvmInt64Type argument annotated as the size.
FunctionCallBuilder moduleLoadCallBuilder = {
102
102
" mgpuModuleLoad" ,
103
103
llvmPointerType /* void *module */ ,
104
- {llvmPointerType /* void *cubin */ }};
104
+ {llvmPointerType /* void *cubin */ , llvmInt64Type /* size_t size */ }};
105
105
// Call descriptor for the runtime entry point named by the string literal
// below: llvmVoidType result and a single llvmPointerType argument annotated
// as the module. Presumably the counterpart of moduleLoadCallBuilder, whose
// result is also annotated as the module -- confirm against the runtime.
FunctionCallBuilder moduleUnloadCallBuilder = {
106
106
" mgpuModuleUnload" , llvmVoidType, {llvmPointerType /* void *module */ }};
107
107
FunctionCallBuilder moduleGetFunctionCallBuilder = {
@@ -125,7 +125,8 @@ class ConvertOpToGpuRuntimeCallPattern : public ConvertOpToLLVMPattern<OpTy> {
125
125
llvmInt32Type, /* unsigned int sharedMemBytes */
126
126
llvmPointerType, /* void *hstream */
127
127
llvmPointerPointerType, /* void **kernelParams */
128
- llvmPointerPointerType /* void **extra */
128
+ llvmPointerPointerType, /* void **extra */
129
+ llvmInt64Type /* size_t paramsCount */
129
130
}};
130
131
// Call descriptor for the runtime entry point named by the string literal
// below: it takes no arguments (empty argument list) and yields an
// llvmPointerType annotated as the stream.
FunctionCallBuilder streamCreateCallBuilder = {
131
132
" mgpuStreamCreate" , llvmPointerType /* void *stream */ , {}};
@@ -1134,7 +1135,23 @@ LogicalResult ConvertLaunchFuncOpToGpuRuntimeCallPattern::matchAndRewrite(
1134
1135
loc, rewriter, nameBuffer.str (), binaryAttr.getValue (),
1135
1136
LLVM::Linkage::Internal, getTypeConverter ()->useOpaquePointers ());
1136
1137
1137
- auto module = moduleLoadCallBuilder.create (loc, rewriter, data);
1138
+ // Pass the binary size along with the data; SPIR-V requires the binary size.
1139
+ auto gpuBlob = binaryAttr.getValue ();
1140
+ auto gpuBlobSize = rewriter.create <mlir::LLVM::ConstantOp>(
1141
+ loc, llvmInt64Type,
1142
+ mlir::IntegerAttr::get (llvmInt64Type,
1143
+ static_cast <int64_t >(gpuBlob.size ())));
1144
+
1145
+ auto module =
1146
+ moduleLoadCallBuilder.create (loc, rewriter, {data, gpuBlobSize});
1147
+
1148
+ // Pass the number of kernel parameters to the runtime wrappers.
1149
+ auto paramsCount = rewriter.create <mlir::LLVM::ConstantOp>(
1150
+ loc, llvmInt64Type,
1151
+ mlir::IntegerAttr::get (
1152
+ llvmInt64Type,
1153
+ static_cast <int64_t >(launchOp.getNumKernelOperands ())));
1154
+
1138
1155
// Get the function from the module. The name corresponds to the name of
1139
1156
// the kernel function.
1140
1157
auto kernelName = generateKernelNameConstant (
@@ -1158,7 +1175,7 @@ LogicalResult ConvertLaunchFuncOpToGpuRuntimeCallPattern::matchAndRewrite(
1158
1175
{function.getResult (), adaptor.getGridSizeX (), adaptor.getGridSizeY (),
1159
1176
adaptor.getGridSizeZ (), adaptor.getBlockSizeX (), adaptor.getBlockSizeY (),
1160
1177
adaptor.getBlockSizeZ (), dynamicSharedMemorySize, stream, kernelParams,
1161
- /* extra=*/ nullpointer});
1178
+ /* extra=*/ nullpointer, paramsCount });
1162
1179
1163
1180
if (launchOp.getAsyncToken ()) {
1164
1181
// Async launch: make dependent ops use the same stream.
0 commit comments