@@ -1081,11 +1081,13 @@ BlockArgument LaunchOp::addPrivateAttribution(Type type, Location loc) {
1081
1081
// ===----------------------------------------------------------------------===//
1082
1082
1083
1083
void LaunchFuncOp::build (OpBuilder &builder, OperationState &result,
1084
- GPUFuncOp kernelFunc , KernelDim3 gridSize,
1084
+ SymbolRefAttr kernelSymbol , KernelDim3 gridSize,
1085
1085
KernelDim3 getBlockSize, Value dynamicSharedMemorySize,
1086
1086
ValueRange kernelOperands, Type asyncTokenType,
1087
1087
ValueRange asyncDependencies,
1088
1088
std::optional<KernelDim3> clusterSize) {
1089
+ assert (kernelSymbol.getNestedReferences ().size () == 1 &&
1090
+ " expected a symbol reference with a single nested reference" );
1089
1091
result.addOperands (asyncDependencies);
1090
1092
if (asyncTokenType)
1091
1093
result.types .push_back (builder.getType <AsyncTokenType>());
@@ -1098,10 +1100,6 @@ void LaunchFuncOp::build(OpBuilder &builder, OperationState &result,
1098
1100
if (dynamicSharedMemorySize)
1099
1101
result.addOperands (dynamicSharedMemorySize);
1100
1102
result.addOperands (kernelOperands);
1101
- auto kernelModule = kernelFunc->getParentOfType <GPUModuleOp>();
1102
- auto kernelSymbol =
1103
- SymbolRefAttr::get (kernelModule.getNameAttr (),
1104
- {SymbolRefAttr::get (kernelFunc.getNameAttr ())});
1105
1103
1106
1104
Properties &prop = result.getOrAddProperties <Properties>();
1107
1105
prop.kernel = kernelSymbol;
@@ -1122,6 +1120,21 @@ void LaunchFuncOp::build(OpBuilder &builder, OperationState &result,
1122
1120
prop.operandSegmentSizes [segmentSizesLen - 1 ] = 0 ;
1123
1121
}
1124
1122
1123
+ void LaunchFuncOp::build (OpBuilder &builder, OperationState &result,
1124
+ GPUFuncOp kernelFunc, KernelDim3 gridSize,
1125
+ KernelDim3 getBlockSize, Value dynamicSharedMemorySize,
1126
+ ValueRange kernelOperands, Type asyncTokenType,
1127
+ ValueRange asyncDependencies,
1128
+ std::optional<KernelDim3> clusterSize) {
1129
+ auto kernelModule = kernelFunc->getParentOfType <GPUModuleOp>();
1130
+ auto kernelSymbol =
1131
+ SymbolRefAttr::get (kernelModule.getNameAttr (),
1132
+ {SymbolRefAttr::get (kernelFunc.getNameAttr ())});
1133
+ build (builder, result, kernelSymbol, gridSize, getBlockSize,
1134
+ dynamicSharedMemorySize, kernelOperands, asyncTokenType,
1135
+ asyncDependencies, clusterSize);
1136
+ }
1137
+
1125
1138
void LaunchFuncOp::build (OpBuilder &builder, OperationState &result,
1126
1139
SymbolRefAttr kernel, KernelDim3 gridSize,
1127
1140
KernelDim3 getBlockSize, Value dynamicSharedMemorySize,
0 commit comments