Skip to content

Commit d566a5c

Browse files
authored
[MLIR] Improve KernelOutlining to avoid introducing an extra block (#90359)
This fixes a TODO in the code.
1 parent cd68d7b commit d566a5c

File tree

2 files changed

+49
-20
lines changed

2 files changed

+49
-20
lines changed

mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -241,24 +241,26 @@ static gpu::GPUFuncOp outlineKernelFuncImpl(gpu::LaunchOp launchOp,
241241
map.map(operand.value(), entryBlock.getArgument(operand.index()));
242242

243243
// Clone the region of the gpu.launch operation into the gpu.func operation.
244-
// TODO: If cloneInto can be modified such that if a mapping for
245-
// a block exists, that block will be used to clone operations into (at the
246-
// end of the block), instead of creating a new block, this would be much
247-
// cleaner.
248244
launchOpBody.cloneInto(&outlinedFuncBody, map);
249245

250-
// Branch from entry of the gpu.func operation to the block that is cloned
251-
// from the entry block of the gpu.launch operation.
252-
Block &launchOpEntry = launchOpBody.front();
253-
Block *clonedLaunchOpEntry = map.lookup(&launchOpEntry);
254-
builder.setInsertionPointToEnd(&entryBlock);
255-
builder.create<cf::BranchOp>(loc, clonedLaunchOpEntry);
256-
257-
outlinedFunc.walk([](gpu::TerminatorOp op) {
258-
OpBuilder replacer(op);
259-
replacer.create<gpu::ReturnOp>(op.getLoc());
260-
op.erase();
261-
});
246+
// Replace the terminator op with returns.
247+
for (Block &block : launchOpBody) {
248+
Block *clonedBlock = map.lookup(&block);
249+
auto terminator = dyn_cast<gpu::TerminatorOp>(clonedBlock->getTerminator());
250+
if (!terminator)
251+
continue;
252+
OpBuilder replacer(terminator);
253+
replacer.create<gpu::ReturnOp>(terminator->getLoc());
254+
terminator->erase();
255+
}
256+
257+
// Splice now the entry block of the gpu.launch operation at the end of the
258+
// gpu.func entry block and erase the redundant block.
259+
Block *clonedLaunchOpEntry = map.lookup(&launchOpBody.front());
260+
entryBlock.getOperations().splice(entryBlock.getOperations().end(),
261+
clonedLaunchOpEntry->getOperations());
262+
clonedLaunchOpEntry->erase();
263+
262264
return outlinedFunc;
263265
}
264266

mlir/test/Dialect/GPU/outlining.mlir

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -54,12 +54,41 @@ func.func @launch() {
5454
// CHECK-NEXT: %[[BDIM:.*]] = gpu.block_dim x
5555
// CHECK-NEXT: = gpu.block_dim y
5656
// CHECK-NEXT: = gpu.block_dim z
57-
// CHECK-NEXT: cf.br ^[[BLOCK:.*]]
58-
// CHECK-NEXT: ^[[BLOCK]]:
5957
// CHECK-NEXT: "use"(%[[KERNEL_ARG0]]) : (f32) -> ()
6058
// CHECK-NEXT: "some_op"(%[[BID]], %[[BDIM]]) : (index, index) -> ()
6159
// CHECK-NEXT: = memref.load %[[KERNEL_ARG1]][%[[TID]]] : memref<?xf32, 1>
6260

61+
// -----
62+
63+
// Verify that we can outline a CFG
64+
// CHECK-LABEL: gpu.func @launchCFG_kernel(
65+
// CHECK: cf.br
66+
// CHECK: gpu.return
67+
func.func @launchCFG() {
68+
%0 = "op"() : () -> (f32)
69+
%1 = "op"() : () -> (memref<?xf32, 1>)
70+
%gDimX = arith.constant 8 : index
71+
%gDimY = arith.constant 12 : index
72+
%gDimZ = arith.constant 16 : index
73+
%bDimX = arith.constant 20 : index
74+
%bDimY = arith.constant 24 : index
75+
%bDimZ = arith.constant 28 : index
76+
77+
gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %gDimX, %grid_y = %gDimY,
78+
%grid_z = %gDimZ)
79+
threads(%tx, %ty, %tz) in (%block_x = %bDimX, %block_y = %bDimY,
80+
%block_z = %bDimZ) {
81+
"use"(%0): (f32) -> ()
82+
cf.br ^bb1
83+
^bb1:
84+
"some_op"(%bx, %block_x) : (index, index) -> ()
85+
%42 = memref.load %1[%tx] : memref<?xf32, 1>
86+
gpu.terminator
87+
}
88+
return
89+
}
90+
91+
6392
// -----
6493

6594
// This test checks gpu-out-lining can handle gpu.launch kernel from an llvm.func
@@ -475,8 +504,6 @@ func.func @launch_cluster() {
475504
// CHECK-NEXT: %[[CDIM:.*]] = gpu.cluster_dim x
476505
// CHECK-NEXT: = gpu.cluster_dim y
477506
// CHECK-NEXT: = gpu.cluster_dim z
478-
// CHECK-NEXT: cf.br ^[[BLOCK:.*]]
479-
// CHECK-NEXT: ^[[BLOCK]]:
480507
// CHECK-NEXT: "use"(%[[KERNEL_ARG0]]) : (f32) -> ()
481508
// CHECK-NEXT: "some_op"(%[[CID]], %[[BID]], %[[BDIM]]) : (index, index, index) -> ()
482509
// CHECK-NEXT: = memref.load %[[KERNEL_ARG1]][%[[TID]]] : memref<?xf32, 1>

0 commit comments

Comments
 (0)