Skip to content

[MLIR][SCF] Add an API to fuse consumer to a producer within scf loop #88712

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Jun 1, 2024
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/LoopLikeInterface.h"
#include "mlir/Interfaces/TilingInterface.h"
#include "mlir/Interfaces/ViewLikeInterface.h"

#include <deque>

Expand Down Expand Up @@ -239,6 +240,19 @@ tileConsumerAndFuseProducersUsingSCF(RewriterBase &rewriter,
TilingInterface consumer,
const SCFTileAndFuseOptions &options);

/// Fuse the consumer of the source of `candidateSliceOp` by computing the
/// required slice of the consumer in-place. Note that the method
/// replaces the uses of `candidateSliceOp` with the tiled and fused consumer
/// value but does not delete the slice operation.
struct SCFFuseConsumerOfSliceResult {
OpOperand *origConsumerOperand; // Original untiled consumer's operand.
OpOperand
*tiledAndFusedConsumerOperand; // Tiled and fused consumer's operand.
SmallVector<Operation *> tiledOps;
};
FailureOr<scf::SCFFuseConsumerOfSliceResult>
tileAndFuseConsumerOfSlice(RewriterBase &rewriter, Operation *candidateSliceOp);

/// Method to lower an `op` that implements the `TilingInterface` to
/// loops/scalars.
FailureOr<SmallVector<scf::ForOp>>
Expand Down
10 changes: 9 additions & 1 deletion mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/ViewLikeInterface.h"

namespace mlir {

Expand All @@ -22,14 +23,21 @@ namespace tensor {
// Patterns
//===----------------------------------------------------------------------===//

/// Pattern to swap an `tensor.extract_slice` with its producer when the
/// Method to swap an `tensor.extract_slice` with its producer when the
/// producer implements the `TilingInterface`. The pattern itself does not
/// provide a mechanism to control where the application happens. With use of
/// transform dialect that control is done within the transform dialect. Other
/// use cases can inherit from this pattern and add necessary controls.
FailureOr<TilingResult> replaceExtractSliceWithTiledProducer(
OpBuilder &builder, tensor::ExtractSliceOp sliceOp, OpResult producerOp);

/// Method to swap an `tensor.insert_slice` with its consumer when the
/// consumer implements the `TilingInterface`.
FailureOr<TilingResult>
replaceInsertSliceWithTiledConsumer(OpBuilder &builder,
OffsetSizeAndStrideOpInterface sliceOp,
OpOperand &consumerOp);

//===----------------------------------------------------------------------===//
// Populate functions.
//===----------------------------------------------------------------------===//
Expand Down
69 changes: 65 additions & 4 deletions mlir/include/mlir/Interfaces/TilingInterface.td
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def TilingInterface : OpInterface<"TilingInterface"> {
The method returns the operation that is the tiled
implementation.
}],
/*retType=*/"FailureOr<TilingResult>",
/*retType=*/"FailureOr<::mlir::TilingResult>",
/*methodName=*/"getTiledImplementation",
/*args=*/(ins
"OpBuilder &":$b,
Expand All @@ -82,7 +82,7 @@ def TilingInterface : OpInterface<"TilingInterface"> {
by the tiled implementation. Expects the same `offsets` and `sizes` as
used to obtain the tiled implementation of the operation.
}],
/*retType=*/"LogicalResult",
/*retType=*/"::mlir::LogicalResult",
/*methodName=*/"getResultTilePosition",
/*args=*/(ins
"OpBuilder &":$b,
Expand All @@ -96,6 +96,25 @@ def TilingInterface : OpInterface<"TilingInterface"> {
return failure();
}]
>,
InterfaceMethod<
/*desc=*/[{
Method to return the the tile of the iteration domain where
values from the given tile of the operand are used
}],
/*retType=*/"::mlir::LogicalResult",
/*methodName=*/"getIterationDomainTileFromOperandTile",
/*args=*/(ins
"OpBuilder &":$b,
"unsigned":$operandNumber,
"ArrayRef<OpFoldResult> ":$offsets,
"ArrayRef<OpFoldResult> ":$sizes,
"SmallVectorImpl<OpFoldResult> &":$iterDomainOffsets,
"SmallVectorImpl<OpFoldResult> &":$iterDomainSizes),
/*methodBody=*/"",
/*defaultImplementation=*/[{
return failure();
}]
>,
InterfaceMethod<
/*desc=*/[{
Method to generate the code that produces a tile of the result.
Expand All @@ -119,7 +138,7 @@ def TilingInterface : OpInterface<"TilingInterface"> {
iteration space).
- `sizes` provides the size of the tile.
}],
/*retType=*/"FailureOr<TilingResult>",
/*retType=*/"FailureOr<::mlir::TilingResult>",
/*methodName=*/"generateResultTileValue",
/*args=*/(ins
"OpBuilder &":$b,
Expand All @@ -131,6 +150,48 @@ def TilingInterface : OpInterface<"TilingInterface"> {
return failure();
}]
>,
InterfaceMethod<
/*desc=*/[{
Method to generate the tiled implementation of an operation from
operand tile position.

Generates the IR that computes the tiled implementation of an
operation from operand tile. The `offsets` and `sizes`
describe the tile of the operand required. This is different from
`getTiledImplementation` which generates the tiled
implementation of the operation given a tile of the
iteration space. This method generates a tiled
implementation of the operation based on the tile of the
operand required. This method enables consumer fusion by using
tile and fuse. The method returns failure if the operation
can't be tiled to generate the operand tile. In practical terms
this implies it cannot be tiled and fused with its producers.

- `offsets` provides the offset of the tile in the coordinate system
of the original iteration space, i.e., if an iteration space
dimension had non-zero offset, it must be included in the offset
provided here (as opposed to zero-based offset "relative" to the
iteration space).
- `sizes` provides the size of the tile.
}],
/*retType=*/"FailureOr<::mlir::TilingResult>",
/*methodName=*/"getTiledImplementationFromOperandTile",
/*args=*/(ins
"OpBuilder &":$b,
"unsigned":$operandNumber,
"ArrayRef<OpFoldResult>":$offsets,
"ArrayRef<OpFoldResult>":$sizes),
/*methodBody=*/"",
/*defaultImplementation=*/[{
::llvm::SmallVector<OpFoldResult> mappedOffsets, mappedSizes;
auto tilingInterfaceOp = cast<::mlir::TilingInterface>($_op.getOperation());
if (failed(tilingInterfaceOp.getIterationDomainTileFromOperandTile(
b, operandNumber, offsets, sizes, mappedOffsets, mappedSizes))) {
return failure();
}
return tilingInterfaceOp.getTiledImplementation(b, mappedOffsets, mappedSizes);
}]
>,
InterfaceMethod<
/*desc=*/[{
Generates the scalar implementation of the operation.
Expand All @@ -142,7 +203,7 @@ def TilingInterface : OpInterface<"TilingInterface"> {
transformations are done, this method can be used to lower to scalar
code that can then be lowered to LLVM or SPIR-V dialects.
}],
/*retType=*/"LogicalResult",
/*retType=*/"::mlir::LogicalResult",
/*methodName=*/"generateScalarImplementation",
/*args=*/(ins
"OpBuilder &":$b,
Expand Down
90 changes: 66 additions & 24 deletions mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ struct LinalgOpTilingInterface
}));
}

// Instantiate the tiled implementation of the operation.
/// Instantiate the tiled implementation of the operation.
FailureOr<TilingResult>
getTiledImplementation(Operation *op, OpBuilder &b,
ArrayRef<OpFoldResult> offsets,
Expand All @@ -132,8 +132,63 @@ struct LinalgOpTilingInterface
return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults())};
}

// Return the details of the output tile generated by the tiled
// implementation.
/// Utility to fetch the offsets and sizes when applied as per the indexing
/// map of the linalg op. This helps in fusing the linalg op as a consumer of
/// a given slice op.
void
getMappedOffsetAndSize(LinalgOp linalgOp, OpBuilder &b, AffineMap indexingMap,
ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes,
SmallVectorImpl<OpFoldResult> &mappedOffsets,
SmallVectorImpl<OpFoldResult> &mappedSizes) const {
unsigned numLoops = linalgOp.getNumLoops();
auto tilingInterfaceOp = cast<TilingInterface>(linalgOp.getOperation());
mappedOffsets.resize(numLoops);
mappedSizes.resize(numLoops);
if (!indexingMap.isPermutation()) {
SmallVector<Range> iterationDomain =
tilingInterfaceOp.getIterationDomain(b);
for (const auto &&[index, value] : llvm::enumerate(iterationDomain)) {
mappedOffsets[index] = value.offset;
mappedSizes[index] = value.size;
}
}
for (const auto &&[index, value] :
llvm::enumerate(indexingMap.getResults())) {
unsigned dimPosition = cast<AffineDimExpr>(value).getPosition();
mappedOffsets[dimPosition] = offsets[index];
mappedSizes[dimPosition] = sizes[index];
}
}

/// Method to return the position of the result tile computed by the tiled
/// operation.
LogicalResult getIterationDomainTileFromOperandTile(
Operation *op, OpBuilder &b, unsigned operandNumber,
ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
SmallVectorImpl<OpFoldResult> &iterDomainOffsets,
SmallVectorImpl<OpFoldResult> &iterDomainSizes) const {
auto linalgOp = cast<LinalgOp>(op);

// Check that the indexing map used for the operand is a projected
// permutation. This could be relaxed with a more general approach that can
// map the offsets and sizes from the operand to iteration space tiles
// (filling in full extent for dimensions not used to access the result).
AffineMap indexingMap =
linalgOp.getMatchingIndexingMap(&op->getOpOperand(operandNumber));
if (!indexingMap.isProjectedPermutation()) {
return op->emitError()
<< "unhandled get iter domain position when operand is not "
"accessed using a permuted projection";
}

getMappedOffsetAndSize(linalgOp, b, indexingMap, offsets, sizes,
iterDomainOffsets, iterDomainSizes);
return success();
}

/// Return the details of the output tile generated by the tiled
/// implementation.
LogicalResult
getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
ArrayRef<OpFoldResult> offsets,
Expand Down Expand Up @@ -177,29 +232,16 @@ struct LinalgOpTilingInterface
"unhandled tiled implementation generation when result is not "
"accessed using a permuted projection");
}

auto numLoops = linalgOp.getNumLoops();
SmallVector<OpFoldResult> mappedOffsets, mappedSizes;
getMappedOffsetAndSize(linalgOp, b, indexingMap, offsets, sizes,
mappedOffsets, mappedSizes);
auto tilingInterfaceOp = cast<TilingInterface>(op);
SmallVector<OpFoldResult> iterationTileOffsets(numLoops),
iterationTileSizes(numLoops);
if (!indexingMap.isPermutation()) {
SmallVector<Range> iterationDomain =
tilingInterfaceOp.getIterationDomain(b);
for (const auto &range : llvm::enumerate(iterationDomain)) {
iterationTileOffsets[range.index()] = range.value().offset;
iterationTileSizes[range.index()] = range.value().size;
}
}
for (const auto &resultExpr : llvm::enumerate(indexingMap.getResults())) {
unsigned dimPosition =
cast<AffineDimExpr>(resultExpr.value()).getPosition();
iterationTileOffsets[dimPosition] = offsets[resultExpr.index()];
iterationTileSizes[dimPosition] = sizes[resultExpr.index()];
}

FailureOr<TilingResult> tilingResult =
tilingInterfaceOp.getTiledImplementation(b, iterationTileOffsets,
iterationTileSizes);
tilingInterfaceOp.getTiledImplementation(b, mappedOffsets, mappedSizes);

if (failed(tilingResult))
return failure();

if (tilingResult->tiledOps.size() != 1)
return op->emitOpError("failed to generate tiled implementation");

Expand Down
Loading
Loading