Skip to content

Commit 2b0bf95

Browse files
committed
[MLIR] Improve KernelOutlining to avoid introducing an extra block
This fixes a TODO in the code.
1 parent de375fb commit 2b0bf95

File tree

2 files changed

+11
-20
lines changed

2 files changed

+11
-20
lines changed

mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp

Lines changed: 11 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -241,24 +241,19 @@ static gpu::GPUFuncOp outlineKernelFuncImpl(gpu::LaunchOp launchOp,
241241
map.map(operand.value(), entryBlock.getArgument(operand.index()));
242242

243243
// Clone the region of the gpu.launch operation into the gpu.func operation.
244-
// TODO: If cloneInto can be modified such that if a mapping for
245-
// a block exists, that block will be used to clone operations into (at the
246-
// end of the block), instead of creating a new block, this would be much
247-
// cleaner.
248244
launchOpBody.cloneInto(&outlinedFuncBody, map);
249245

250-
// Branch from entry of the gpu.func operation to the block that is cloned
251-
// from the entry block of the gpu.launch operation.
252-
Block &launchOpEntry = launchOpBody.front();
253-
Block *clonedLaunchOpEntry = map.lookup(&launchOpEntry);
254-
builder.setInsertionPointToEnd(&entryBlock);
255-
builder.create<cf::BranchOp>(loc, clonedLaunchOpEntry);
256-
257-
outlinedFunc.walk([](gpu::TerminatorOp op) {
258-
OpBuilder replacer(op);
259-
replacer.create<gpu::ReturnOp>(op.getLoc());
260-
op.erase();
261-
});
246+
// Now splice the cloned entry block of the gpu.launch operation onto the end
247+
// of the gpu.func entry block and erase the redundant (now empty) block.
248+
Block *clonedLaunchOpEntry = map.lookup(&launchOpBody.front());
249+
Operation *terminator = clonedLaunchOpEntry->getTerminator();
250+
OpBuilder replacer(terminator);
251+
replacer.create<gpu::ReturnOp>(terminator->getLoc());
252+
terminator->erase();
253+
entryBlock.getOperations().splice(entryBlock.getOperations().end(),
254+
clonedLaunchOpEntry->getOperations());
255+
clonedLaunchOpEntry->erase();
256+
262257
return outlinedFunc;
263258
}
264259

mlir/test/Dialect/GPU/outlining.mlir

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,6 @@ func.func @launch() {
5454
// CHECK-NEXT: %[[BDIM:.*]] = gpu.block_dim x
5555
// CHECK-NEXT: = gpu.block_dim y
5656
// CHECK-NEXT: = gpu.block_dim z
57-
// CHECK-NEXT: cf.br ^[[BLOCK:.*]]
58-
// CHECK-NEXT: ^[[BLOCK]]:
5957
// CHECK-NEXT: "use"(%[[KERNEL_ARG0]]) : (f32) -> ()
6058
// CHECK-NEXT: "some_op"(%[[BID]], %[[BDIM]]) : (index, index) -> ()
6159
// CHECK-NEXT: = memref.load %[[KERNEL_ARG1]][%[[TID]]] : memref<?xf32, 1>
@@ -475,8 +473,6 @@ func.func @launch_cluster() {
475473
// CHECK-NEXT: %[[CDIM:.*]] = gpu.cluster_dim x
476474
// CHECK-NEXT: = gpu.cluster_dim y
477475
// CHECK-NEXT: = gpu.cluster_dim z
478-
// CHECK-NEXT: cf.br ^[[BLOCK:.*]]
479-
// CHECK-NEXT: ^[[BLOCK]]:
480476
// CHECK-NEXT: "use"(%[[KERNEL_ARG0]]) : (f32) -> ()
481477
// CHECK-NEXT: "some_op"(%[[CID]], %[[BID]], %[[BDIM]]) : (index, index, index) -> ()
482478
// CHECK-NEXT: = memref.load %[[KERNEL_ARG1]][%[[TID]]] : memref<?xf32, 1>

0 commit comments

Comments
 (0)