Skip to content

Commit 98b49d4

Browse files
Address remaining comments (round 3)
1 parent 09dab87 commit 98b49d4

File tree

2 files changed

+12
-11
lines changed

2 files changed

+12
-11
lines changed

mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,8 @@ static OpFoldResult getBoundedTileSize(OpBuilder &b, Location loc,
112112
/// - `resultSizes` is of the same size as `tiledValues` and represents
113113
/// the size of the corresponding element from `tiledValues` inserted into
114114
/// the element from `newBbArgs`.
115+
/// In case the method needs to return `failure()` the method is expected
116+
/// to clean up any inserted operations.
115117
using YieldTiledValuesFn = std::function<LogicalResult(
116118
RewriterBase &rewriter, Location loc, ValueRange ivs, ValueRange newBbArgs,
117119
SmallVector<Value> &tiledValues,
@@ -354,17 +356,10 @@ FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop<scf::ForOp>(
354356
if (failed(yieldTiledValuesFn(rewriter, loc, newLoop.getInductionVar(),
355357
newRegionIterArgs, tiledValues, resultOffsets,
356358
resultSizes))) {
359+
rewriter.eraseOp(newLoop);
357360
return rewriter.notifyMatchFailure(loopOp, "failed to get tiled values");
358361
}
359362

360-
if (tiledValues.size() != resultOffsets.size() ||
361-
tiledValues.size() != resultSizes.size()) {
362-
return rewriter.notifyMatchFailure(
363-
loopOp,
364-
"expected number of tiled values returned, the number of offset "
365-
"vectors and number of size vectors to be the same");
366-
}
367-
368363
SmallVector<Value> newYieldValues = llvm::to_vector(yieldOp.getOperands());
369364
for (auto [tiledValue, regionIterArg, resultOffset, resultSize] :
370365
llvm::zip_equal(tiledValues, newRegionIterArgs, resultOffsets,
@@ -414,6 +409,7 @@ FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop<scf::ForallOp>(
414409
if (failed(yieldTiledValuesFn(rewriter, loc, newLoop.getInductionVars(),
415410
regionIterArgs, tiledValues, resultOffsets,
416411
resultSizes))) {
412+
rewriter.eraseOp(newLoop);
417413
return rewriter.notifyMatchFailure(loopOp,
418414
"failed to get yielded tiled values");
419415
}
@@ -625,8 +621,10 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
625621

626622
// 5c. Tile the cloned operation.
627623
tilingResult = clonedOp.getTiledImplementation(rewriter, offsets, sizes);
628-
if (failed(tilingResult))
629-
return op.emitOpError("failed to tile operation");
624+
if (failed(tilingResult)) {
625+
rewriter.eraseOp(clonedOp);
626+
return op.emitOpError("faild to tile operation");
627+
}
630628

631629
// 5d. Delete the cloned operation.
632630
rewriter.eraseOp(clonedOp);
@@ -639,6 +637,9 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
639637
SmallVector<OpFoldResult> resultOffset, resultSize;
640638
if (failed(op.getResultTilePosition(rewriter, index, offsets, sizes,
641639
resultOffset, resultSize))) {
640+
for (auto op : tilingResult->tiledOps) {
641+
rewriter.eraseOp(op);
642+
}
642643
return rewriter.notifyMatchFailure(
643644
op, "failed to get slice of result produced");
644645
}

mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def TestFuseAndYieldOp : Op<Transform_Dialect, "test.fuse_and_yield",
3838
(ins TransformHandleTypeInterface:$target,
3939
DefaultValuedAttr<I64ArrayAttr, "{}">:$tile_sizes,
4040
DefaultValuedAttr<I64ArrayAttr, "{}">:$tile_interchange,
41-
DefaultValuedAttr<BoolAttr, "{false}">:$use_forall);
41+
DefaultValuedAttr<BoolAttr, "false">:$use_forall);
4242
let results = (outs TransformHandleTypeInterface:$transfomed,
4343
Variadic<TransformHandleTypeInterface>:$loops);
4444

0 commit comments

Comments
 (0)