-
Notifications
You must be signed in to change notification settings - Fork 13.5k
[MLIR][SCF] Add an API to fuse consumer to a producer within scf loop #88712
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[MLIR][SCF] Add an API to fuse consumer to a producer within scf loop #88712
Conversation
@llvm/pr-subscribers-mlir-tensor @llvm/pr-subscribers-mlir-scf Author: Abhishek Varma (Abhishek-Varma) Changes-- This commit adds an API to fuse consumer to a producer within scf.for/scf.forall loop. Signed-off-by: Abhishek Varma <[email protected]> Patch is 41.90 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/88712.diff 7 Files Affected:
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
index 965ef9e203be28..75b48a2cdd8dc3 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
@@ -126,6 +126,19 @@ struct SCFTileAndFuseOptions {
}
};
+/// Fuse the consumer of the source of `candidateSliceOp` by computing the
+/// required slice of the consumer in-place. Note that the method
+/// replaces the uses of `candidateSliceOp` with the tiled and fused consumer
+/// value but does not delete the slice operation.
+struct SCFFuseConsumerOfSliceResult {
+ Operation *origConsumer; // Original untiled consumer.
+ Value tiledAndFusedConsumer; // Tile and fused consumer value.
+ SmallVector<Operation *> tiledOps;
+};
+FailureOr<scf::SCFFuseConsumerOfSliceResult>
+tileAndFuseConsumerOfSlice(RewriterBase &rewriter, Operation *candidateSliceOp,
+ bool useSCFFor);
+
/// Fuse the producer of the source of `candidateSliceOp` by computing the
/// required slice of the producer in-place. Note that the method
/// replaces the uses of `candidateSliceOp` with the tiled and fused producer
diff --git a/mlir/include/mlir/Interfaces/TilingInterface.td b/mlir/include/mlir/Interfaces/TilingInterface.td
index 66382f29c24249..4c62d45822ad44 100644
--- a/mlir/include/mlir/Interfaces/TilingInterface.td
+++ b/mlir/include/mlir/Interfaces/TilingInterface.td
@@ -74,6 +74,25 @@ def TilingInterface : OpInterface<"TilingInterface"> {
return {};
}]
>,
+ InterfaceMethod<
+ /*desc=*/[{
+ Method to return iterator domain position computed by the
+ input operand position.
+ }],
+ /*retType=*/"LogicalResult",
+ /*methodName=*/"getIterDomainTilePositionFromOperandPosition",
+ /*args=*/(ins
+ "OpBuilder &":$b,
+ "unsigned":$operandNumber,
+ "ArrayRef<OpFoldResult> ":$offsets,
+ "ArrayRef<OpFoldResult> ":$sizes,
+ "SmallVector<OpFoldResult> &":$iterDomainOffsets,
+ "SmallVector<OpFoldResult> &":$iterDomainSizes),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ return failure();
+ }]
+ >,
InterfaceMethod<
/*desc=*/[{
Method to return the position of the result tile computed by the tiled operation.
@@ -96,6 +115,42 @@ def TilingInterface : OpInterface<"TilingInterface"> {
return failure();
}]
>,
+ InterfaceMethod<
+ /*desc=*/[{
+ Method to generate the tiled implementation of an operation from
+ operand position.
+
+ Generates the IR that generate the tiled implementation of an
+ operation from operand position. The `offsets` and `sizes`
+ describe the tile of the operand required. This is different from
+ `getTiledImplementation` which generates the tiled
+ implementation of the operation given a tile of the
+ iteration space. This method generates a tiled
+ implementation of the operation based on the position of the
+ operand required. This method enables fusion consumer by using
+ tile and fuse. The method returns failure if the operation
+ can't be tiled to generate the operand tile. In practical terms
+ this implies it cannot be tiled and fused with its producers.
+
+ - `offsets` provides the offset of the tile in the coordinate system
+ of the original iteration space, i.e., if an iteration space
+ dimension had non-zero offset, it must be included in the offset
+ provided here (as opposed to zero-based offset "relative" to the
+ iteration space).
+ - `sizes` provides the size of the tile.
+ }],
+ /*retType=*/"FailureOr<TilingResult>",
+ /*methodName=*/"getTiledImplementationFromOperandPosition",
+ /*args=*/(ins
+ "OpBuilder &":$b,
+ "unsigned":$operandNumber,
+ "ArrayRef<OpFoldResult>":$offsets,
+ "ArrayRef<OpFoldResult>":$sizes),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ return failure();
+ }]
+ >,
InterfaceMethod<
/*desc=*/[{
Method to generate the code that produces a tile of the result.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index bd870d4f982e5d..01bf19764b0938 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -132,6 +132,59 @@ struct LinalgOpTilingInterface
return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults())};
}
+ void getMappedOffsetAndSize(Operation *op, OpBuilder &b,
+ AffineMap indexingMap,
+ ArrayRef<OpFoldResult> offsets,
+ ArrayRef<OpFoldResult> sizes,
+ SmallVector<OpFoldResult> &mappedOffsets,
+ SmallVector<OpFoldResult> &mappedSizes) const {
+ auto linalgOp = cast<LinalgOp>(op);
+ auto numLoops = linalgOp.getNumLoops();
+ auto tilingInterfaceOp = cast<TilingInterface>(op);
+ mappedOffsets.resize(numLoops);
+ mappedSizes.resize(numLoops);
+ if (!indexingMap.isPermutation()) {
+ SmallVector<Range> iterationDomain =
+ tilingInterfaceOp.getIterationDomain(b);
+ for (const auto &range : llvm::enumerate(iterationDomain)) {
+ mappedOffsets[range.index()] = range.value().offset;
+ mappedSizes[range.index()] = range.value().size;
+ }
+ }
+ for (const auto &resultExpr : llvm::enumerate(indexingMap.getResults())) {
+ unsigned dimPosition =
+ cast<AffineDimExpr>(resultExpr.value()).getPosition();
+ mappedOffsets[dimPosition] = offsets[resultExpr.index()];
+ mappedSizes[dimPosition] = sizes[resultExpr.index()];
+ }
+ }
+
+ // Return the details of the output tile generated by the tiled
+ // implementation.
+ LogicalResult getIterDomainTilePositionFromOperandPosition(
+ Operation *op, OpBuilder &b, unsigned operandNumber,
+ ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
+ SmallVector<OpFoldResult> &iterDomainOffsets,
+ SmallVector<OpFoldResult> &iterDomainSizes) const {
+ auto linalgOp = cast<LinalgOp>(op);
+
+ // Check that the indexing map used for the operand is a projected
+ // permutation. This could be relaxed with a more general approach that can
+ // map the offsets and sizes from the operand to iteration space tiles
+ // (filling in full extent for dimensions not used to access the result).
+ AffineMap indexingMap =
+ linalgOp.getMatchingIndexingMap(&op->getOpOperand(operandNumber));
+ if (!indexingMap.isProjectedPermutation()) {
+ return op->emitOpError(
+ "unhandled get iter domain position when operand is not "
+ "accessed using a permuted projection");
+ }
+
+ getMappedOffsetAndSize(op, b, indexingMap, offsets, sizes,
+ iterDomainOffsets, iterDomainSizes);
+ return success();
+ }
+
// Return the details of the output tile generated by the tiled
// implementation.
LogicalResult
@@ -160,6 +213,20 @@ struct LinalgOpTilingInterface
return success();
}
+ FailureOr<TilingResult> getTiledImplementationFromOperandPosition(
+ Operation *op, OpBuilder &b, unsigned operandNumber,
+ ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes) const {
+ SmallVector<OpFoldResult> mappedOffsets, mappedSizes;
+ auto tilingInterfaceOp = cast<TilingInterface>(op);
+ if (failed(tilingInterfaceOp.getIterDomainTilePositionFromOperandPosition(
+ b, operandNumber, offsets, sizes, mappedOffsets, mappedSizes))) {
+ return op->emitOpError(
+ "unable to obtain the iter domain position of the operation.");
+ }
+ return tilingInterfaceOp.getTiledImplementation(b, mappedOffsets,
+ mappedSizes);
+ }
+
FailureOr<TilingResult>
generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber,
ArrayRef<OpFoldResult> offsets,
@@ -177,29 +244,16 @@ struct LinalgOpTilingInterface
"unhandled tiled implementation generation when result is not "
"accessed using a permuted projection");
}
-
- auto numLoops = linalgOp.getNumLoops();
+ SmallVector<OpFoldResult> mappedOffsets, mappedSizes;
+ getMappedOffsetAndSize(op, b, indexingMap, offsets, sizes, mappedOffsets,
+ mappedSizes);
auto tilingInterfaceOp = cast<TilingInterface>(op);
- SmallVector<OpFoldResult> iterationTileOffsets(numLoops),
- iterationTileSizes(numLoops);
- if (!indexingMap.isPermutation()) {
- SmallVector<Range> iterationDomain =
- tilingInterfaceOp.getIterationDomain(b);
- for (const auto &range : llvm::enumerate(iterationDomain)) {
- iterationTileOffsets[range.index()] = range.value().offset;
- iterationTileSizes[range.index()] = range.value().size;
- }
- }
- for (const auto &resultExpr : llvm::enumerate(indexingMap.getResults())) {
- unsigned dimPosition =
- cast<AffineDimExpr>(resultExpr.value()).getPosition();
- iterationTileOffsets[dimPosition] = offsets[resultExpr.index()];
- iterationTileSizes[dimPosition] = sizes[resultExpr.index()];
- }
-
FailureOr<TilingResult> tilingResult =
- tilingInterfaceOp.getTiledImplementation(b, iterationTileOffsets,
- iterationTileSizes);
+ tilingInterfaceOp.getTiledImplementation(b, mappedOffsets, mappedSizes);
+
+ if (failed(tilingResult))
+ return failure();
+
if (tilingResult->tiledOps.size() != 1)
return op->emitOpError("failed to generate tiled implementation");
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 1a84a59ddb69df..45c8f8362ad581 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -19,6 +19,7 @@
#include "mlir/Dialect/SCF/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/IR/Dominance.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
@@ -798,6 +799,59 @@ mlir::scf::tileReductionUsingScf(RewriterBase &b,
// tileConsumerAndFuseProducersUsingSCF implementation.
//===----------------------------------------------------------------------===//
+static std::tuple<Operation *, std::optional<OpOperand *>>
+getUntiledConsumerFromSliceDestSCFFor(OpOperand &source,
+ unsigned &operandNumber) {
+ // Step 1. Fetch the corresponding output
+ // TODO(avarma): Make it generic for multiple values yielding scf.for.
+ unsigned yieldOperandNumber = source.getOperandNumber();
+ Value resultingValue =
+ source.getOwner()->getParentOp()->getResult(yieldOperandNumber);
+
+ // Step 3. Get users.
+ std::optional<OpOperand *> destinationIterArg;
+ Operation *untiledConsumer;
+ for (Operation *user : resultingValue.getUsers()) {
+ // TODO(avarma): Address the case where the consumer op itself can return
+ // more than one result.
+ for (Value operand : user->getOperands()) {
+ if (operand == resultingValue) {
+ untiledConsumer = user;
+ break;
+ }
+ operandNumber++;
+ }
+ break;
+ }
+ return {untiledConsumer, destinationIterArg};
+}
+
+static std::tuple<Operation *, std::optional<OpOperand *>>
+getUntiledConsumerFromSliceDestSCFForall(OpOperand *source,
+ unsigned &operandNumber) {
+ // Step 1. Fetch the corresponding output
+ // TODO(avarma): Make it generic for multiple values yielding scf.forall.
+ auto iterArg = dyn_cast<BlockArgument>(source->get());
+ Value resultingValue = iterArg.getOwner()->getParentOp()->getResult(0);
+
+ // Step 3. Get users.
+ std::optional<OpOperand *> destinationIterArg;
+ Operation *untiledConsumer;
+ for (Operation *user : resultingValue.getUsers()) {
+ // TODO(avarma): Address the case where the consumer op itself can return
+ // more than one result.
+ for (Value operand : user->getOperands()) {
+ if (operand == resultingValue) {
+ untiledConsumer = user;
+ break;
+ }
+ operandNumber++;
+ }
+ break;
+ }
+ return {untiledConsumer, destinationIterArg};
+}
+
/// Return the untiled producer whose slice is used in a tiled consumer. The
/// method traverses the tile loop nest (`loops`) if needed, and returns the
/// `iter_args` of the outer most that is encountered. Traversing the iter_args
@@ -820,6 +874,463 @@ getUntiledProducerFromSliceSource(OpOperand *source,
return {dyn_cast<OpResult>(source->get()), destinationIterArg};
}
+/// Implementation of fusing consumer of a single slice by computing the
+/// slice of the consumer in-place for scf.forall.
+static FailureOr<scf::SCFFuseConsumerOfSliceResult>
+tileAndFuseConsumerOfSliceSCFForall(
+ RewriterBase &rewriter, tensor::ParallelInsertSliceOp candidateSliceOp) {
+ // 1. Get the consumer of the source.
+ unsigned operandNumber = 0;
+ auto [consumerOp, destinationInitArg] =
+ getUntiledConsumerFromSliceDestSCFForall(
+ &candidateSliceOp.getDestMutable(), operandNumber);
+ if (!consumerOp)
+ return failure();
+ OpBuilder::InsertionGuard g(rewriter);
+ // Using candidateSliceOp->getParentOp() because we have the following case :-
+ // scf.forall.in_parallel {
+ // tensor.parallel_insert_slice ...
+ // }
+ rewriter.setInsertionPoint(candidateSliceOp->getParentOp());
+
+ Operation *containingOp = candidateSliceOp->getParentOp()->getParentOp();
+ // Check consumer has tiling interface.
+ auto tileableConsumer = dyn_cast<TilingInterface>(consumerOp);
+ if (!tileableConsumer) {
+ llvm::outs() << "consumer is not a TileableInterface: " << *consumerOp
+ << "\n";
+ return failure();
+ }
+
+ // Check containing op is "scf::ForallOp".
+ auto forallOp = dyn_cast<scf::ForallOp>(containingOp);
+ if (!forallOp) {
+ llvm::outs() << "containing op is not a scf.forall: " << containingOp
+ << "\n";
+ return failure();
+ }
+
+ // Check consumer don't use more than one result of containingOp.
+ Value bridge(nullptr);
+ SmallVector<unsigned> operandNums;
+ for (auto [idx, opd] : llvm::enumerate((consumerOp->getOperands()))) {
+ if (opd.getDefiningOp() == containingOp) {
+ operandNums.push_back(idx);
+ if (!bridge) {
+ bridge = opd;
+ } else if (bridge != opd) {
+ llvm::outs()
+ << "consumer's operand use more than one containingOp's result"
+ << "\n";
+ return failure();
+ }
+ }
+ }
+
+ // TODO: We have to init result of consumer before scf.forall, use
+ // DestinationStyleOpInterface to get result shape from init for now.
+ // Add support for other op such as op has InferTypeOpInterface.
+ // Check consumer has DestinationStyleOpInterface.
+ auto dstOp = dyn_cast<DestinationStyleOpInterface>(consumerOp);
+ if (!dstOp) {
+ llvm::outs() << "consumer op should have destination style op interface"
+ << "\n";
+ return failure();
+ }
+
+ // Check consumer doon't use scf.forall's output as init.
+ SmallVector<Value> dpsInits = llvm::to_vector<4>(
+ llvm::map_range(dstOp.getDpsInits(), [](Value v) { return v; }));
+ if (llvm::is_contained(dpsInits, bridge)) {
+ llvm::outs() << "consumer op take result of scf.forall as init"
+ << "\n";
+ return failure();
+ }
+
+ // Check result was inserted only once.
+ int64_t bridgeResultIdx = cast<OpResult>(bridge).getResultNumber();
+ auto bridgeBlockArg = forallOp.getRegionOutArgs()[bridgeResultIdx];
+ scf::InParallelOp terminatorOp = forallOp.getTerminator();
+
+ tensor::ParallelInsertSliceOp targetInsertOp(nullptr);
+ for (Operation &op : terminatorOp.getRegion().front().getOperations()) {
+ auto parallelInsertSliceOp = cast<tensor::ParallelInsertSliceOp>(op);
+ if (parallelInsertSliceOp.getDest() == bridgeBlockArg) {
+ if (!targetInsertOp) {
+ targetInsertOp = parallelInsertSliceOp;
+ } else {
+ llvm::outs() << "containingOp's result inserted multi time"
+ << "\n";
+ return failure();
+ }
+ }
+ }
+
+ if (!targetInsertOp) {
+ llvm::outs() << "containingOp's result was not inserted"
+ << "\n";
+ return failure();
+ }
+
+ SmallVector<OpFoldResult> offsets = targetInsertOp.getMixedOffsets();
+ SmallVector<OpFoldResult> sizes = targetInsertOp.getMixedSizes();
+ SmallVector<OpFoldResult> strides = targetInsertOp.getMixedStrides();
+
+ // Check all insert stride is 1.
+ if (llvm::any_of(strides, [](OpFoldResult foldRes) {
+ if (auto attr = foldRes.dyn_cast<Attribute>()) {
+ return cast<IntegerAttr>(attr).getInt() != 1;
+ }
+ return true;
+ })) {
+ llvm::outs() << "containingOp's result yield with stride"
+ << "\n";
+ return failure();
+ }
+
+ Location loc = forallOp.getLoc();
+ rewriter.setInsertionPoint(terminatorOp);
+
+ SmallVector<OpFoldResult> iterDomainOffsets, iterDomainSizes;
+
+ // Try to get iter domain position from input position.
+ if (failed(tileableConsumer.getIterDomainTilePositionFromOperandPosition(
+ rewriter, operandNums.front(), offsets, sizes, iterDomainOffsets,
+ iterDomainSizes))) {
+ llvm::outs() << "can't get iter domain position from input position"
+ << "\n";
+ return failure();
+ }
+
+ // Try to get all containing op result's position from iter domain position.
+ llvm::SmallVector<std::pair<llvm::SmallVector<OpFoldResult>,
+ llvm::SmallVector<OpFoldResult>>>
+ resultPositions(consumerOp->getNumResults());
+ for (auto [idx, v] : llvm::enumerate(consumerOp->getResults())) {
+ if (failed(tileableConsumer.getResultTilePosition(
+ rewriter, idx, iterDomainOffsets, iterDomainSizes,
+ resultPositions[idx].first, resultPositions[idx].second))) {
+ llvm::outs()
+ << "can't get result domain position from iter domain position"
+ << "\n";
+ return failure();
+ }
+ }
+
+ // All check passed, try to fuse consumer.
+ // Create tiled implementation of containing op.
+ FailureOr<TilingResult> tileAndFuseResult =
+ tileableConsumer.getTiledImplementationFromOperandPosition(
+ rewriter, operandNums.front(), offsets, sizes);
+ if (failed(tileAndFuseResult)) {
+ llvm::outs() << "get tiled implementation failed"
+ << "\n";
+ return failure();
+ }
+
+ auto tiledOps = tileAndFuseResult->tiledOps;
+ if (failed(tileAndFuseResult) || tiledOps.size() != 1) {
+ llvm::outs() << "failed to tile consumer op: " << *tileableConsumer << "\n";
+ return failure();
+ }
+
+ // Replace tiled op's operand.
+ for (auto operandNum : operandNums) {
+ tiledOps[0]->setOperand(operandNum, targetInsertOp.getSource());
+ }
+ rewriter.replaceUsesWithIf(bridge, forallOp.getOutputs()[bridgeResultIdx],
+ [&](OpOperand &use) {
+ Operation *op = use.getOwner();
+ return forallOp->isProperAncestor(op);
+ });
+
+ SmallVector<Value> newOuts(forallOp.getOutputs());
+ newOuts.append(dpsInits);
+
+ // Create new scf.forall op.
+ rewriter.setInsertionPoint(consumerOp);
+ auto newforallOp = rewriter.create<scf::ForallOp>(
+ loc, forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
+ forallOp.getMixedStep(), newOuts, forallOp.getMapping());
+ rewriter.eraseBlock(newforallOp.getBody());
+ newforallOp.getRegion().takeBody(forallOp.getRegion());
+
+ for (auto v : dpsInits) ...
[truncated]
|
@llvm/pr-subscribers-mlir Author: Abhishek Varma (Abhishek-Varma) Changes-- This commit adds an API to fuse consumer to a producer within scf.for/scf.forall loop. Signed-off-by: Abhishek Varma <[email protected]> Patch is 41.90 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/88712.diff 7 Files Affected:
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
index 965ef9e203be28..75b48a2cdd8dc3 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
@@ -126,6 +126,19 @@ struct SCFTileAndFuseOptions {
}
};
+/// Fuse the consumer of the source of `candidateSliceOp` by computing the
+/// required slice of the consumer in-place. Note that the method
+/// replaces the uses of `candidateSliceOp` with the tiled and fused consumer
+/// value but does not delete the slice operation.
+struct SCFFuseConsumerOfSliceResult {
+ Operation *origConsumer; // Original untiled consumer.
+ Value tiledAndFusedConsumer; // Tile and fused consumer value.
+ SmallVector<Operation *> tiledOps;
+};
+FailureOr<scf::SCFFuseConsumerOfSliceResult>
+tileAndFuseConsumerOfSlice(RewriterBase &rewriter, Operation *candidateSliceOp,
+ bool useSCFFor);
+
/// Fuse the producer of the source of `candidateSliceOp` by computing the
/// required slice of the producer in-place. Note that the method
/// replaces the uses of `candidateSliceOp` with the tiled and fused producer
diff --git a/mlir/include/mlir/Interfaces/TilingInterface.td b/mlir/include/mlir/Interfaces/TilingInterface.td
index 66382f29c24249..4c62d45822ad44 100644
--- a/mlir/include/mlir/Interfaces/TilingInterface.td
+++ b/mlir/include/mlir/Interfaces/TilingInterface.td
@@ -74,6 +74,25 @@ def TilingInterface : OpInterface<"TilingInterface"> {
return {};
}]
>,
+ InterfaceMethod<
+ /*desc=*/[{
+ Method to return iterator domain position computed by the
+ input operand position.
+ }],
+ /*retType=*/"LogicalResult",
+ /*methodName=*/"getIterDomainTilePositionFromOperandPosition",
+ /*args=*/(ins
+ "OpBuilder &":$b,
+ "unsigned":$operandNumber,
+ "ArrayRef<OpFoldResult> ":$offsets,
+ "ArrayRef<OpFoldResult> ":$sizes,
+ "SmallVector<OpFoldResult> &":$iterDomainOffsets,
+ "SmallVector<OpFoldResult> &":$iterDomainSizes),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ return failure();
+ }]
+ >,
InterfaceMethod<
/*desc=*/[{
Method to return the position of the result tile computed by the tiled operation.
@@ -96,6 +115,42 @@ def TilingInterface : OpInterface<"TilingInterface"> {
return failure();
}]
>,
+ InterfaceMethod<
+ /*desc=*/[{
+ Method to generate the tiled implementation of an operation from
+ operand position.
+
+ Generates the IR that generate the tiled implementation of an
+ operation from operand position. The `offsets` and `sizes`
+ describe the tile of the operand required. This is different from
+ `getTiledImplementation` which generates the tiled
+ implementation of the operation given a tile of the
+ iteration space. This method generates a tiled
+ implementation of the operation based on the position of the
+ operand required. This method enables fusion consumer by using
+ tile and fuse. The method returns failure if the operation
+ can't be tiled to generate the operand tile. In practical terms
+ this implies it cannot be tiled and fused with its producers.
+
+ - `offsets` provides the offset of the tile in the coordinate system
+ of the original iteration space, i.e., if an iteration space
+ dimension had non-zero offset, it must be included in the offset
+ provided here (as opposed to zero-based offset "relative" to the
+ iteration space).
+ - `sizes` provides the size of the tile.
+ }],
+ /*retType=*/"FailureOr<TilingResult>",
+ /*methodName=*/"getTiledImplementationFromOperandPosition",
+ /*args=*/(ins
+ "OpBuilder &":$b,
+ "unsigned":$operandNumber,
+ "ArrayRef<OpFoldResult>":$offsets,
+ "ArrayRef<OpFoldResult>":$sizes),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ return failure();
+ }]
+ >,
InterfaceMethod<
/*desc=*/[{
Method to generate the code that produces a tile of the result.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index bd870d4f982e5d..01bf19764b0938 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -132,6 +132,59 @@ struct LinalgOpTilingInterface
return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults())};
}
+ void getMappedOffsetAndSize(Operation *op, OpBuilder &b,
+ AffineMap indexingMap,
+ ArrayRef<OpFoldResult> offsets,
+ ArrayRef<OpFoldResult> sizes,
+ SmallVector<OpFoldResult> &mappedOffsets,
+ SmallVector<OpFoldResult> &mappedSizes) const {
+ auto linalgOp = cast<LinalgOp>(op);
+ auto numLoops = linalgOp.getNumLoops();
+ auto tilingInterfaceOp = cast<TilingInterface>(op);
+ mappedOffsets.resize(numLoops);
+ mappedSizes.resize(numLoops);
+ if (!indexingMap.isPermutation()) {
+ SmallVector<Range> iterationDomain =
+ tilingInterfaceOp.getIterationDomain(b);
+ for (const auto &range : llvm::enumerate(iterationDomain)) {
+ mappedOffsets[range.index()] = range.value().offset;
+ mappedSizes[range.index()] = range.value().size;
+ }
+ }
+ for (const auto &resultExpr : llvm::enumerate(indexingMap.getResults())) {
+ unsigned dimPosition =
+ cast<AffineDimExpr>(resultExpr.value()).getPosition();
+ mappedOffsets[dimPosition] = offsets[resultExpr.index()];
+ mappedSizes[dimPosition] = sizes[resultExpr.index()];
+ }
+ }
+
+ // Return the details of the output tile generated by the tiled
+ // implementation.
+ LogicalResult getIterDomainTilePositionFromOperandPosition(
+ Operation *op, OpBuilder &b, unsigned operandNumber,
+ ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
+ SmallVector<OpFoldResult> &iterDomainOffsets,
+ SmallVector<OpFoldResult> &iterDomainSizes) const {
+ auto linalgOp = cast<LinalgOp>(op);
+
+ // Check that the indexing map used for the operand is a projected
+ // permutation. This could be relaxed with a more general approach that can
+ // map the offsets and sizes from the operand to iteration space tiles
+ // (filling in full extent for dimensions not used to access the result).
+ AffineMap indexingMap =
+ linalgOp.getMatchingIndexingMap(&op->getOpOperand(operandNumber));
+ if (!indexingMap.isProjectedPermutation()) {
+ return op->emitOpError(
+ "unhandled get iter domain position when operand is not "
+ "accessed using a permuted projection");
+ }
+
+ getMappedOffsetAndSize(op, b, indexingMap, offsets, sizes,
+ iterDomainOffsets, iterDomainSizes);
+ return success();
+ }
+
// Return the details of the output tile generated by the tiled
// implementation.
LogicalResult
@@ -160,6 +213,20 @@ struct LinalgOpTilingInterface
return success();
}
+ FailureOr<TilingResult> getTiledImplementationFromOperandPosition(
+ Operation *op, OpBuilder &b, unsigned operandNumber,
+ ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes) const {
+ SmallVector<OpFoldResult> mappedOffsets, mappedSizes;
+ auto tilingInterfaceOp = cast<TilingInterface>(op);
+ if (failed(tilingInterfaceOp.getIterDomainTilePositionFromOperandPosition(
+ b, operandNumber, offsets, sizes, mappedOffsets, mappedSizes))) {
+ return op->emitOpError(
+ "unable to obtain the iter domain position of the operation.");
+ }
+ return tilingInterfaceOp.getTiledImplementation(b, mappedOffsets,
+ mappedSizes);
+ }
+
FailureOr<TilingResult>
generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber,
ArrayRef<OpFoldResult> offsets,
@@ -177,29 +244,16 @@ struct LinalgOpTilingInterface
"unhandled tiled implementation generation when result is not "
"accessed using a permuted projection");
}
-
- auto numLoops = linalgOp.getNumLoops();
+ SmallVector<OpFoldResult> mappedOffsets, mappedSizes;
+ getMappedOffsetAndSize(op, b, indexingMap, offsets, sizes, mappedOffsets,
+ mappedSizes);
auto tilingInterfaceOp = cast<TilingInterface>(op);
- SmallVector<OpFoldResult> iterationTileOffsets(numLoops),
- iterationTileSizes(numLoops);
- if (!indexingMap.isPermutation()) {
- SmallVector<Range> iterationDomain =
- tilingInterfaceOp.getIterationDomain(b);
- for (const auto &range : llvm::enumerate(iterationDomain)) {
- iterationTileOffsets[range.index()] = range.value().offset;
- iterationTileSizes[range.index()] = range.value().size;
- }
- }
- for (const auto &resultExpr : llvm::enumerate(indexingMap.getResults())) {
- unsigned dimPosition =
- cast<AffineDimExpr>(resultExpr.value()).getPosition();
- iterationTileOffsets[dimPosition] = offsets[resultExpr.index()];
- iterationTileSizes[dimPosition] = sizes[resultExpr.index()];
- }
-
FailureOr<TilingResult> tilingResult =
- tilingInterfaceOp.getTiledImplementation(b, iterationTileOffsets,
- iterationTileSizes);
+ tilingInterfaceOp.getTiledImplementation(b, mappedOffsets, mappedSizes);
+
+ if (failed(tilingResult))
+ return failure();
+
if (tilingResult->tiledOps.size() != 1)
return op->emitOpError("failed to generate tiled implementation");
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 1a84a59ddb69df..45c8f8362ad581 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -19,6 +19,7 @@
#include "mlir/Dialect/SCF/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/IR/Dominance.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
@@ -798,6 +799,59 @@ mlir::scf::tileReductionUsingScf(RewriterBase &b,
// tileConsumerAndFuseProducersUsingSCF implementation.
//===----------------------------------------------------------------------===//
+static std::tuple<Operation *, std::optional<OpOperand *>>
+getUntiledConsumerFromSliceDestSCFFor(OpOperand &source,
+ unsigned &operandNumber) {
+ // Step 1. Fetch the corresponding output
+ // TODO(avarma): Make it generic for multiple values yielding scf.for.
+ unsigned yieldOperandNumber = source.getOperandNumber();
+ Value resultingValue =
+ source.getOwner()->getParentOp()->getResult(yieldOperandNumber);
+
+ // Step 3. Get users.
+ std::optional<OpOperand *> destinationIterArg;
+ Operation *untiledConsumer;
+ for (Operation *user : resultingValue.getUsers()) {
+ // TODO(avarma): Address the case where the consumer op itself can return
+ // more than one result.
+ for (Value operand : user->getOperands()) {
+ if (operand == resultingValue) {
+ untiledConsumer = user;
+ break;
+ }
+ operandNumber++;
+ }
+ break;
+ }
+ return {untiledConsumer, destinationIterArg};
+}
+
+static std::tuple<Operation *, std::optional<OpOperand *>>
+getUntiledConsumerFromSliceDestSCFForall(OpOperand *source,
+ unsigned &operandNumber) {
+ // Step 1. Fetch the corresponding output
+ // TODO(avarma): Make it generic for multiple values yielding scf.forall.
+ auto iterArg = dyn_cast<BlockArgument>(source->get());
+ Value resultingValue = iterArg.getOwner()->getParentOp()->getResult(0);
+
+ // Step 3. Get users.
+ std::optional<OpOperand *> destinationIterArg;
+ Operation *untiledConsumer;
+ for (Operation *user : resultingValue.getUsers()) {
+ // TODO(avarma): Address the case where the consumer op itself can return
+ // more than one result.
+ for (Value operand : user->getOperands()) {
+ if (operand == resultingValue) {
+ untiledConsumer = user;
+ break;
+ }
+ operandNumber++;
+ }
+ break;
+ }
+ return {untiledConsumer, destinationIterArg};
+}
+
/// Return the untiled producer whose slice is used in a tiled consumer. The
/// method traverses the tile loop nest (`loops`) if needed, and returns the
/// `iter_args` of the outer most that is encountered. Traversing the iter_args
@@ -820,6 +874,463 @@ getUntiledProducerFromSliceSource(OpOperand *source,
return {dyn_cast<OpResult>(source->get()), destinationIterArg};
}
+/// Implementation of fusing consumer of a single slice by computing the
+/// slice of the consumer in-place for scf.forall.
+static FailureOr<scf::SCFFuseConsumerOfSliceResult>
+tileAndFuseConsumerOfSliceSCFForall(
+ RewriterBase &rewriter, tensor::ParallelInsertSliceOp candidateSliceOp) {
+ // 1. Get the consumer of the source.
+ unsigned operandNumber = 0;
+ auto [consumerOp, destinationInitArg] =
+ getUntiledConsumerFromSliceDestSCFForall(
+ &candidateSliceOp.getDestMutable(), operandNumber);
+ if (!consumerOp)
+ return failure();
+ OpBuilder::InsertionGuard g(rewriter);
+ // Using candidateSliceOp->getParentOp() because we have the following case :-
+ // scf.forall.in_parallel {
+ // tensor.parallel_insert_slice ...
+ // }
+ rewriter.setInsertionPoint(candidateSliceOp->getParentOp());
+
+ Operation *containingOp = candidateSliceOp->getParentOp()->getParentOp();
+ // Check consumer has tiling interface.
+ auto tileableConsumer = dyn_cast<TilingInterface>(consumerOp);
+ if (!tileableConsumer) {
+ llvm::outs() << "consumer is not a TileableInterface: " << *consumerOp
+ << "\n";
+ return failure();
+ }
+
+ // Check containing op is "scf::ForallOp".
+ auto forallOp = dyn_cast<scf::ForallOp>(containingOp);
+ if (!forallOp) {
+ llvm::outs() << "containing op is not a scf.forall: " << containingOp
+ << "\n";
+ return failure();
+ }
+
+ // Check consumer don't use more than one result of containingOp.
+ Value bridge(nullptr);
+ SmallVector<unsigned> operandNums;
+ for (auto [idx, opd] : llvm::enumerate((consumerOp->getOperands()))) {
+ if (opd.getDefiningOp() == containingOp) {
+ operandNums.push_back(idx);
+ if (!bridge) {
+ bridge = opd;
+ } else if (bridge != opd) {
+ llvm::outs()
+ << "consumer's operand use more than one containingOp's result"
+ << "\n";
+ return failure();
+ }
+ }
+ }
+
+ // TODO: We have to init result of consumer before scf.forall, use
+ // DestinationStyleOpInterface to get result shape from init for now.
+ // Add support for other op such as op has InferTypeOpInterface.
+ // Check consumer has DestinationStyleOpInterface.
+ auto dstOp = dyn_cast<DestinationStyleOpInterface>(consumerOp);
+ if (!dstOp) {
+ llvm::outs() << "consumer op should have destination style op interface"
+ << "\n";
+ return failure();
+ }
+
+ // Check consumer doon't use scf.forall's output as init.
+ SmallVector<Value> dpsInits = llvm::to_vector<4>(
+ llvm::map_range(dstOp.getDpsInits(), [](Value v) { return v; }));
+ if (llvm::is_contained(dpsInits, bridge)) {
+ llvm::outs() << "consumer op take result of scf.forall as init"
+ << "\n";
+ return failure();
+ }
+
+ // Check result was inserted only once.
+ int64_t bridgeResultIdx = cast<OpResult>(bridge).getResultNumber();
+ auto bridgeBlockArg = forallOp.getRegionOutArgs()[bridgeResultIdx];
+ scf::InParallelOp terminatorOp = forallOp.getTerminator();
+
+ tensor::ParallelInsertSliceOp targetInsertOp(nullptr);
+ for (Operation &op : terminatorOp.getRegion().front().getOperations()) {
+ auto parallelInsertSliceOp = cast<tensor::ParallelInsertSliceOp>(op);
+ if (parallelInsertSliceOp.getDest() == bridgeBlockArg) {
+ if (!targetInsertOp) {
+ targetInsertOp = parallelInsertSliceOp;
+ } else {
+ llvm::outs() << "containingOp's result inserted multi time"
+ << "\n";
+ return failure();
+ }
+ }
+ }
+
+ if (!targetInsertOp) {
+ llvm::outs() << "containingOp's result was not inserted"
+ << "\n";
+ return failure();
+ }
+
+ SmallVector<OpFoldResult> offsets = targetInsertOp.getMixedOffsets();
+ SmallVector<OpFoldResult> sizes = targetInsertOp.getMixedSizes();
+ SmallVector<OpFoldResult> strides = targetInsertOp.getMixedStrides();
+
+ // Check all insert stride is 1.
+ if (llvm::any_of(strides, [](OpFoldResult foldRes) {
+ if (auto attr = foldRes.dyn_cast<Attribute>()) {
+ return cast<IntegerAttr>(attr).getInt() != 1;
+ }
+ return true;
+ })) {
+ llvm::outs() << "containingOp's result yield with stride"
+ << "\n";
+ return failure();
+ }
+
+ Location loc = forallOp.getLoc();
+ rewriter.setInsertionPoint(terminatorOp);
+
+ SmallVector<OpFoldResult> iterDomainOffsets, iterDomainSizes;
+
+ // Try to get iter domain position from input position.
+ if (failed(tileableConsumer.getIterDomainTilePositionFromOperandPosition(
+ rewriter, operandNums.front(), offsets, sizes, iterDomainOffsets,
+ iterDomainSizes))) {
+ llvm::outs() << "can't get iter domain position from input position"
+ << "\n";
+ return failure();
+ }
+
+ // Try to get all containing op result's position from iter domain position.
+ llvm::SmallVector<std::pair<llvm::SmallVector<OpFoldResult>,
+ llvm::SmallVector<OpFoldResult>>>
+ resultPositions(consumerOp->getNumResults());
+ for (auto [idx, v] : llvm::enumerate(consumerOp->getResults())) {
+ if (failed(tileableConsumer.getResultTilePosition(
+ rewriter, idx, iterDomainOffsets, iterDomainSizes,
+ resultPositions[idx].first, resultPositions[idx].second))) {
+ llvm::outs()
+ << "can't get result domain position from iter domain position"
+ << "\n";
+ return failure();
+ }
+ }
+
+ // All check passed, try to fuse consumer.
+ // Create tiled implementation of containing op.
+ FailureOr<TilingResult> tileAndFuseResult =
+ tileableConsumer.getTiledImplementationFromOperandPosition(
+ rewriter, operandNums.front(), offsets, sizes);
+ if (failed(tileAndFuseResult)) {
+ llvm::outs() << "get tiled implementation failed"
+ << "\n";
+ return failure();
+ }
+
+ auto tiledOps = tileAndFuseResult->tiledOps;
+ if (failed(tileAndFuseResult) || tiledOps.size() != 1) {
+ llvm::outs() << "failed to tile consumer op: " << *tileableConsumer << "\n";
+ return failure();
+ }
+
+ // Replace tiled op's operand.
+ for (auto operandNum : operandNums) {
+ tiledOps[0]->setOperand(operandNum, targetInsertOp.getSource());
+ }
+ rewriter.replaceUsesWithIf(bridge, forallOp.getOutputs()[bridgeResultIdx],
+ [&](OpOperand &use) {
+ Operation *op = use.getOwner();
+ return forallOp->isProperAncestor(op);
+ });
+
+ SmallVector<Value> newOuts(forallOp.getOutputs());
+ newOuts.append(dpsInits);
+
+ // Create new scf.forall op.
+ rewriter.setInsertionPoint(consumerOp);
+ auto newforallOp = rewriter.create<scf::ForallOp>(
+ loc, forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
+ forallOp.getMixedStep(), newOuts, forallOp.getMapping());
+ rewriter.eraseBlock(newforallOp.getBody());
+ newforallOp.getRegion().takeBody(forallOp.getRegion());
+
+ for (auto v : dpsInits) ...
[truncated]
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
99723e4
to
8e044b2
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Great start! I reviewed this to some extent. I left one main comment about how to restructure this to make this close to tileConsumerAndFuseProducer
that makes it easy to handle some of the destination args correctly. Let me know if you want to chat about this more in a setting with higher-bandwidth.
8e044b2
to
75be119
Compare
Thank you so much @MaheshRavishankar for such detailed implementation pointers! It was super helpful! I have addressed other comments as well. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks @Abhishek-Varma ! This is looking much closer. I have another round of comments, but I think this is looking pretty close to being landable....
75be119
to
f2da670
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks a lot @Abhishek-Varma . I left a comment, if it is unclear we can do some quick pair programming to get this working, or I can take this PR and get a better idea of the state of the code at each point you are seeing here.
f2da670
to
4fd26b6
Compare
Thank you @MaheshRavishankar @ftynse for your review comments! The one involving I've addressed all your comments except the ones related to The major changes in the latest push in this PR are :-
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok, I think I have seen this enough and I am comfortable with this landing. Over time might be better to collapse the scf.for implementation and scf.forall implementation into one using loop interface and maybe adding appropriate interface methods.
Thanks @Abhishek-Varma
Thanks @Abhishek-Varma lets land this PR after #85528 lands. |
4fd26b6
to
53e607e
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
First of all, thanks for your contribution @Abhishek-Varma
Now, I am afraid I cannot subscribe to any large code changes to tiling transforms until #77874 is addressed to my satisfaction. I already gave a pass 6 months ago and another one 3 months ago when things were supposed to be addressed "in short order".
We're now past tech-debt reduction time:
- first let's address [mlir][TilingInterface] Use
LoopLikeOpInterface
in tiling using SCF to unify tiling withscf.for
andscf.forall
. #77874, - then let's address the reuse as also pointed out by @ftynse
All of this code looks very similar to the scf.for case. Is it possible to hoist it out into a helper function, potentially with templates?
More generally, the entire 1,2,3,4, ... thing here could be a sequence of calls to appropriately named and documented functions. But not blocking on this one, only hoisting out long blocks of common code.
Thanks for helping move this forward!
At this time, more copy-pasta of code is unwelcome. Let's first reduce the duplication
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thanks for pushing on this!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please bear with me. This is a rather big and complex patch, I need time to properly re-review this, in particular design comments and tests.
operandToReplace.set(sliceOp.getResult()); | ||
} else if (auto sliceOp = | ||
dyn_cast<tensor::ParallelInsertSliceOp>(candidateSliceOp)) { | ||
operandToReplace.set(sliceOp.getSource()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This seems off. Here the source of the sliceOp
is the tiled operation, but this is used to replace the operand of the cloned consumer which is untiled. This is also potentially related to the extract_slice
issue pointed out by @qedawkins earlier. This is what I think this should be doing
- Create a new
tensor.insert_slice
that represent the insertion. In thescf.for
case this is a clone of thecandidateSliceOp
. In thescf.forall
case this is created from thetensor.parallel_insert_slice
except now it has a result. - Clone the consumer op (already done here). Replace the operand in the consumer along which fusion is happening with the result of the
tensor.insert_slice
. - Call the method
tensor::replaceInsertSliceWithTiledConsumer
with the cloned insert slice op and the cloned consumer op - Now replace all
tensor.extract_slice
uses of the clonedtensor.insert_slice
with the source of the clonedtensor.insert_slice
. By construction all the slice uses of thetensor.insert_slice
are exactly the same shape as the source.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you so much @MaheshRavishankar ! This made so much sense and has really helped clean up the issue with tensor.extract_slice
!
I've addressed it in the latest push.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
By construction all the slice uses of the tensor.insert_slice are exactly the same shape as the source.
Where does this guarantee come from? Looking at the implementation of replaceInsertSliceWithTiledConsumer
it is just calling getTiledImplementationFromOperandTile
without passing in the source of the insert_slice
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Adding here an explanation in case it helps :-
State of the IR after step 3 of the algo in the above thread :
%clone_insert_slice = cloned tensor.insert_slice :
%source<32> into %dest<64> | OFFSET | STRIDES | SIZES
%tiled_operand = tensor.extract_slice:
%clone_insert_slice<64> to <32> | OFFSET | STRIDES | SIZES
%res = tiled_consumer_op %tiled_operand
NOTE:
- OFFSETS | STRIDES | SIZES of both
tensor.insert_slice
andtensor.extract_slice
is SAME, therefore it is a valid intermediate IR yielded by thegetTiledImplementationFromOperandTile
API which is in fact taking those OFFSETS/STRIDES info. And this basically means thattensor.insert_slice
andtensor.extract_slice
is just inverse of each other. - Right before calling the API to generate tiled consumer op, we have
%res = consumer_op %clone_insert_slice<64>
.
Now, perhaps what @MaheshRavishankar meant above for step 4 is that by design the use of %tiled_operand
is going to be exactly the same shape as the "source" of %clone_insert_slice
. Thus we can simply replace all uses of %tiled_operand
with %source
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
My main question is where in the tiling interface do we require that the tiled implementation generates this exact %tiled_operand = tensor.extract_slice | OFFSET | STRIDES | SIZES
extract slice? The only requirement I see in getTiledImplementation
is that it produces a result tile that corresponds to the given offsets/sizes/strides. I don't see any requirements given about how the operands are used.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh okay.
Well, not sure about the doc comments there, but implementation-wise getTiledImplementation
does in fact operate on the operands too via makeTiledShapes here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That's for linalg. Other operations can do whatever they want (not trying to be difficult, just trying to make sure we at least clarify any implicit interface API requirements like this).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
By construction all the slice uses of the tensor.insert_slice are exactly the same shape as the source.
Where does this guarantee come from? Looking at the implementation of
replaceInsertSliceWithTiledConsumer
it is just callinggetTiledImplementationFromOperandTile
without passing in the source of theinsert_slice
.
You are using the slice of the operand to compute the slice of the iteration space that computes that operand. That is only possible if that is a bijection. Then the tile of the operand computed from the iteration space is going to be the same. Things have to be consistent.
Yes, we are indexing on tensor.extract_slice
and tensor.insert_slice
, which we can maybe eventually switch to an interface that allows different kind of "slices", but this is the current state right now. If we miss it, then its a "missed optimization" and not a correctness issue.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok I'll defer to you then, but I think it would be good to document this implicit dependence on extract_slice
on the SCF api.
60a1b7b
to
7e9f0b5
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this patch LGTM now
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Now when we come to tile and fuse of producer, the flip of this is all we should need. IIRC this is what this patch initially started with but through many many rounds of reviews has deflected from initial intent (which is now lost in all the comment threads).
As @MaheshRavishankar said, from this view, this PR LGTM overall.
Also, I have same concerned about what you are talking related to tiling interface change. However, personally, I prefer that we can open a RFC and another PR to formally solve it,
611cb4e
to
17037e9
Compare
Hi all. Please let me know if the current state of the PR is okay to land and then carry on the RFC discussion about. The growing thread of discussion is making this quite difficult to follow/manage - I tried sifting through the comments to the best of my capabilities/understanding and have accordingly reverted the state to what it was previously (with lesser conflicts) and I believe we should be able to unblock landing this PR now and shift the discussion thread as an RFC. CC: @MaheshRavishankar @ftynse @qedawkins @nicolasvasilache @Yun-Fly @cxy-1993 |
Which specific concerns? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The growing thread of discussion is making this quite difficult to follow/manage - I tried sifting through the comments to the best of my capabilities/understanding and have accordingly reverted the state to what it was previously (with lesser conflicts) and I believe we should be able to unblock landing this PR now and shift the discussion thread as an RFC.
Please add a note in the documentation of getTiledImplementationFromResultTile
that, for most operations, it should be a trivial composition of two other methods. This will significantly simplify the life of any downstream (other than yours) that will want to adopt your change.
And make sure you are comfortable with the level of understanding of the changes you are about to submit....
I am unblocking this. My arguments against the design still stand, but they date back to https://reviews.llvm.org/D127809, where I should have requested an RFC for changing something in lib/Interfaces. I won't make that mistake again, but @Abhishek-Varma is not the one who introduced that, so it's not up to him to resolve it.
The direction of how to refactor
I guess you mean that preparing tiled operand by caller in advance and than calling
In another side, I guess @MaheshRavishankar insist that how to take slice is also one component of how to do tiling on on an operation, which should not be hand over to the caller to deal with tiled operand. BTW, I think anyway we can create a dummy
It is really hard to stand by which one(or even any other new proposal) is general solution before we have much clear definition of responsibility taken by |
A follow up here is indeed what would be good to add to the interface... I will update the documentation for
I consider myself the defacto maintainer of the TilingInterface (and its use in the tiling algorithm under SCF dialect using the interface). I have spent enough time here understand what is implemented here, and I am comfortable with this landing. Please redirect any concerns/issues you see to me, and I will try to address/contextualize them.
Since this was originally introduced by me, I am happy to chat more. Value your concerns a lot here, but we might be approaching this from different angles, but there is no reason why we cant find a meet point. I am happy to setup some time to discuss your concerns offline and reach a shared understanding. |
I think #91878 addresses the concerns in this comment. That is also landing soon. @nicolasvasilache please unblock if you agree. |
Will look tomorrow, thanks for flagging! |
Assuming #91878 gets finished and landed, my main concern does not hold anymore.
CI failures seem unrelated. Merging this change. |
-- This commit adds an API to fuse consumer to a producer within scf.for/scf.forall loop.
Signed-off-by: Abhishek Varma [email protected]