Skip to content

Extend TilingInterface to allow more flexible tiling #95422

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ struct SCFTilingResult {
/// Values to use as replacements for the untiled op. Is the same size as the
/// number of results of the untiled op.
SmallVector<Value> replacements;
SmallVector<Operation *> extractSliceOps;
};

/// Method to tile an op that implements the `TilingInterface` using
Expand Down Expand Up @@ -135,6 +136,7 @@ struct SCFFuseProducerOfSliceResult {
OpResult origProducer; // Original untiled producer.
Value tiledAndFusedProducer; // Tile and fused producer value.
SmallVector<Operation *> tiledOps;
SmallVector<Operation *> extractSliceOps;
};
std::optional<SCFFuseProducerOfSliceResult>
tileAndFuseProducerOfSlice(RewriterBase &rewriter,
Expand Down
4 changes: 4 additions & 0 deletions mlir/include/mlir/Interfaces/TilingInterface.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,13 @@ namespace mlir {
/// are returned to the caller for further transformations.
/// - `tiledValues` contains the tiled value corresponding to the result of the
/// untiled operation.
/// - `extractSliceOps` contains all the `tensor.extract_slice` ops used in
/// generating the `tiledOps`. Usually these are operands to the `tiledOps`
/// but they can be embedded in regions owned by `tiledOps`.
struct TilingResult {
/// Tiled operations that were generated; handed back to the caller for
/// further transformations.
SmallVector<Operation *> tiledOps;
/// Tiled values corresponding to the results of the untiled operation.
SmallVector<Value> tiledValues;
/// All `tensor.extract_slice` ops used in generating `tiledOps`. These are
/// usually operands of `tiledOps`, but may also live inside regions owned by
/// `tiledOps`.
SmallVector<Operation *> extractSliceOps;
};

} // namespace mlir
Expand Down
8 changes: 7 additions & 1 deletion mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2501,7 +2501,13 @@ SoftmaxOp::getTiledImplementation(OpBuilder &builder,
Operation *tiledOp =
mlir::clone(builder, getOperation(), resultTypes, tiledOperands);

return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults())};
SmallVector<Operation *> sliceOps;
for (Value operand : tiledOperands)
if (auto sliceOp = operand.getDefiningOp<tensor::ExtractSliceOp>())
sliceOps.push_back(sliceOp);

return TilingResult{
{tiledOp}, SmallVector<Value>(tiledOp->getResults()), sliceOps};
}

LogicalResult SoftmaxOp::getResultTilePosition(
Expand Down
11 changes: 9 additions & 2 deletions mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,13 @@ struct LinalgOpTilingInterface
Operation *tiledOp = clone(b, linalgOp, resultTensorTypes, tiledOperands);
offsetIndices(b, cast<LinalgOp>(tiledOp), offsets);

return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults())};
SmallVector<Operation *> sliceOps;
for (Value operand : tiledOperands)
if (auto sliceOp = operand.getDefiningOp<tensor::ExtractSliceOp>())
sliceOps.push_back(sliceOp);

return TilingResult{
{tiledOp}, SmallVector<Value>(tiledOp->getResults()), sliceOps};
}

/// Utility to fetch the offsets and sizes when applied as per the indexing
Expand Down Expand Up @@ -247,7 +253,8 @@ struct LinalgOpTilingInterface

return TilingResult{
tilingResult->tiledOps,
SmallVector<Value>{tilingResult->tiledValues[resultNumber]}};
SmallVector<Value>{tilingResult->tiledValues[resultNumber]},
tilingResult->extractSliceOps};
}

/// Method to generate the tiled implementation of an operation from the tile
Expand Down
37 changes: 19 additions & 18 deletions mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -619,7 +619,8 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
if (llvm::all_of(tileSizes, isZeroIndex)) {
tiledResults.append(clonedOp->result_begin(), clonedOp->result_end());
tilingResult =
TilingResult{/*tiledOps=*/{clonedOp}, clonedOp->getResults()};
TilingResult{/*tiledOps=*/{clonedOp}, clonedOp->getResults(),
/*extractSliceOps=*/{}};
return success();
}

Expand Down Expand Up @@ -675,12 +676,14 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
// op.
if (loops.empty()) {
return scf::SCFTilingResult{tilingResult->tiledOps, loops,
tilingResult->tiledValues};
tilingResult->tiledValues,
tilingResult->extractSliceOps};
}

SmallVector<Value> replacements = llvm::map_to_vector(
loops.front()->getResults(), [](OpResult r) -> Value { return r; });
return scf::SCFTilingResult{tilingResult->tiledOps, loops, replacements};
return scf::SCFTilingResult{tilingResult->tiledOps, loops, replacements,
tilingResult->extractSliceOps};
}

FailureOr<scf::SCFReductionTilingResult>
Expand Down Expand Up @@ -931,9 +934,9 @@ mlir::scf::tileAndFuseProducerOfSlice(
->getOpOperands()[destinationInitArg.value()->getOperandNumber()]
.set(origDestinationTensors[resultNumber]);
}
return scf::SCFFuseProducerOfSliceResult{fusableProducer,
tileAndFuseResult->tiledValues[0],
tileAndFuseResult->tiledOps};
return scf::SCFFuseProducerOfSliceResult{
fusableProducer, tileAndFuseResult->tiledValues[0],
tileAndFuseResult->tiledOps, tileAndFuseResult->extractSliceOps};
}

/// Reconstruct the fused producer from within the tiled-and-fused code.
Expand Down Expand Up @@ -962,13 +965,12 @@ LogicalResult mlir::scf::yieldReplacementForFusedProducer(
.getDefiningOp<DestinationStyleOpInterface>()) {
rewriter.setInsertionPoint(tiledDestStyleOp);
Value newRegionArg = newRegionIterArgs.back();
auto destSlice = rewriter.create<tensor::ExtractSliceOp>(
sliceOp.getLoc(), newRegionArg, sliceOp.getMixedOffsets(),
sliceOp.getMixedSizes(), sliceOp.getMixedStrides());
unsigned resultNumber = fusableProducer.getResultNumber();
rewriter.modifyOpInPlace(tiledDestStyleOp, [&]() {
tiledDestStyleOp.getDpsInitsMutable()[resultNumber].set(destSlice);
});
auto origSlice = tiledDestStyleOp.getDpsInits()[resultNumber]
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Previously, we created a new tensor::ExtractSliceOp and set it as the destination argument of the tiled op. That worked because tileConsumerAndFuseProducersUsingSCF built its worklist of candidate tensor::ExtractSliceOps by inspecting the operands of the tiled op. However, we now directly use the slice ops returned by the TilingInterfaceOp, and those are collected before we call into yieldReplacementForFusedProducer.

We therefore update the source of the fused producer's existing slice op in place instead of creating a new tensor::ExtractSliceOp.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I need to try this out more... I cloned the slice because that made it easier. I think that without the clone I kept running into invalid IR creation, but maybe that was an artifact of something else not working. I will have to try this out in IREE to really stress test it.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This PR does indeed need this change, to ensure that the slice ops returned by the TilingInterfaceOp do not become stale after yieldReplacementForFusedProducer. Otherwise, it would need another way to update the slice ops cached in SCFFuseProducerOfSliceResult to the latest ones (those created in yieldReplacementForFusedProducer).

BTW, this change will also affect another PR involving yieldReplacementForFusedProducer, which will hopefully be merged before this PR.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the heads up @Yun-Fly

.getDefiningOp<tensor::ExtractSliceOp>();
if (origSlice) {
origSlice.getSourceMutable().set(newRegionArg);
}
}
Block *block = rewriter.getInsertionPoint()->getBlock();
rewriter.setInsertionPoint(block->getTerminator());
Expand Down Expand Up @@ -1036,15 +1038,14 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
// operations. If the producers of the source of the `tensor.extract_slice`
// can be tiled such that the tiled value is generated in-place, that
// effectively tiles + fuses the operations.
auto addCandidateSlices = [](Operation *fusedOp,
auto addCandidateSlices = [](const SmallVector<Operation *> &newSliceOps,
std::deque<tensor::ExtractSliceOp> &candidates) {
for (Value operand : fusedOp->getOperands())
if (auto sliceOp = operand.getDefiningOp<tensor::ExtractSliceOp>())
candidates.push_back(sliceOp);
for (auto *op : newSliceOps)
candidates.push_back(llvm::cast<tensor::ExtractSliceOp>(op));
};

std::deque<tensor::ExtractSliceOp> candidates;
addCandidateSlices(tiledAndFusedOps.back(), candidates);
addCandidateSlices(tilingResult->extractSliceOps, candidates);
OpBuilder::InsertionGuard g(rewriter);
while (!candidates.empty()) {
// Traverse the slices in BFS fashion.
Expand Down Expand Up @@ -1086,7 +1087,7 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
fusedResult->tiledAndFusedProducer.getDefiningOp()) {
fusedProducers.insert(fusedResult->origProducer.getDefiningOp());
tiledAndFusedOps.insert(tiledAndFusedOp);
addCandidateSlices(tiledAndFusedOp, candidates);
addCandidateSlices(fusedResult->extractSliceOps, candidates);
}
}

Expand Down
Loading