Skip to content

Commit 2b2ce50

Browse files
[MLIR][SCF] Add an API to fuse consumer to a producer within scf loop (#88712)
This commit adds an API (`tileAndFuseConsumerOfSlice`) to fuse consumer to a producer within scf.for/scf.forall loop. To support this two new methods are added to the `TilingInterface` - `getIterationDomainTileFromOperandTile` - `getTiledImplementationFromOperandTile`. Consumer operations that implement this method can be used to be fused with tiled producer operands in a manner similar to (but essentially the inverse of) the fusion of an untiled producer with a tiled consumer. Note that this only does one `tiled producer` -> `consumer` fusion. This could be called repeatedly for fusing multiple consumers. The current implementation also is conservative in when this kicks in (like single use of the value returned by the inter-tile loops that surround the tiled producer, etc.) These can be relaxed over time. Signed-off-by: Abhishek Varma <[email protected]> --------- Signed-off-by: Abhishek Varma <[email protected]> Signed-off-by: Abhishek Varma <[email protected]> Co-authored-by: cxy <[email protected]>
1 parent 9e22c7a commit 2b2ce50

File tree

10 files changed

+1085
-29
lines changed

10 files changed

+1085
-29
lines changed

mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include "mlir/IR/PatternMatch.h"
1515
#include "mlir/Interfaces/LoopLikeInterface.h"
1616
#include "mlir/Interfaces/TilingInterface.h"
17+
#include "mlir/Interfaces/ViewLikeInterface.h"
1718

1819
#include <deque>
1920

@@ -239,6 +240,19 @@ tileConsumerAndFuseProducersUsingSCF(RewriterBase &rewriter,
239240
TilingInterface consumer,
240241
const SCFTileAndFuseOptions &options);
241242

243+
/// Fuse the consumer of the source of `candidateSliceOp` by computing the
244+
/// required slice of the consumer in-place. Note that the method
245+
/// replaces the uses of `candidateSliceOp` with the tiled and fused consumer
246+
/// value but does not delete the slice operation.
247+
struct SCFFuseConsumerOfSliceResult {
248+
OpOperand *origConsumerOperand; // Original untiled consumer's operand.
249+
OpOperand
250+
*tiledAndFusedConsumerOperand; // Tiled and fused consumer's operand.
251+
SmallVector<Operation *> tiledOps;
252+
};
253+
FailureOr<scf::SCFFuseConsumerOfSliceResult>
254+
tileAndFuseConsumerOfSlice(RewriterBase &rewriter, Operation *candidateSliceOp);
255+
242256
/// Method to lower an `op` that implements the `TilingInterface` to
243257
/// loops/scalars.
244258
FailureOr<SmallVector<scf::ForOp>>

mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
#include "mlir/Dialect/Tensor/IR/Tensor.h"
1313
#include "mlir/IR/PatternMatch.h"
14+
#include "mlir/Interfaces/ViewLikeInterface.h"
1415

1516
namespace mlir {
1617

@@ -22,14 +23,21 @@ namespace tensor {
2223
// Patterns
2324
//===----------------------------------------------------------------------===//
2425

25-
/// Pattern to swap an `tensor.extract_slice` with its producer when the
26+
/// Method to swap an `tensor.extract_slice` with its producer when the
2627
/// producer implements the `TilingInterface`. The pattern itself does not
2728
/// provide a mechanism to control where the application happens. With use of
2829
/// transform dialect that control is done within the transform dialect. Other
2930
/// use cases can inherit from this pattern and add necessary controls.
3031
FailureOr<TilingResult> replaceExtractSliceWithTiledProducer(
3132
OpBuilder &builder, tensor::ExtractSliceOp sliceOp, OpResult producerOp);
3233

34+
/// Method to swap an `tensor.insert_slice` with its consumer when the
35+
/// consumer implements the `TilingInterface`.
36+
FailureOr<TilingResult>
37+
replaceInsertSliceWithTiledConsumer(OpBuilder &builder,
38+
OffsetSizeAndStrideOpInterface sliceOp,
39+
OpOperand &consumerOp);
40+
3341
//===----------------------------------------------------------------------===//
3442
// Populate functions.
3543
//===----------------------------------------------------------------------===//

mlir/include/mlir/Interfaces/TilingInterface.td

Lines changed: 62 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def TilingInterface : OpInterface<"TilingInterface"> {
6363
The method returns the operation that is the tiled
6464
implementation.
6565
}],
66-
/*retType=*/"FailureOr<TilingResult>",
66+
/*retType=*/"FailureOr<::mlir::TilingResult>",
6767
/*methodName=*/"getTiledImplementation",
6868
/*args=*/(ins
6969
"OpBuilder &":$b,
@@ -82,7 +82,7 @@ def TilingInterface : OpInterface<"TilingInterface"> {
8282
by the tiled implementation. Expects the same `offsets` and `sizes` as
8383
used to obtain the tiled implementation of the operation.
8484
}],
85-
/*retType=*/"LogicalResult",
85+
/*retType=*/"::mlir::LogicalResult",
8686
/*methodName=*/"getResultTilePosition",
8787
/*args=*/(ins
8888
"OpBuilder &":$b,
@@ -96,6 +96,25 @@ def TilingInterface : OpInterface<"TilingInterface"> {
9696
return failure();
9797
}]
9898
>,
99+
InterfaceMethod<
100+
/*desc=*/[{
101+
Method to return the tile of the iteration domain where
102+
values from the given tile of the operand are used.
103+
}],
104+
/*retType=*/"::mlir::LogicalResult",
105+
/*methodName=*/"getIterationDomainTileFromOperandTile",
106+
/*args=*/(ins
107+
"OpBuilder &":$b,
108+
"unsigned":$operandNumber,
109+
"ArrayRef<OpFoldResult> ":$offsets,
110+
"ArrayRef<OpFoldResult> ":$sizes,
111+
"SmallVectorImpl<OpFoldResult> &":$iterDomainOffsets,
112+
"SmallVectorImpl<OpFoldResult> &":$iterDomainSizes),
113+
/*methodBody=*/"",
114+
/*defaultImplementation=*/[{
115+
return failure();
116+
}]
117+
>,
99118
InterfaceMethod<
100119
/*desc=*/[{
101120
Method to generate the code that produces a tile of the result.
@@ -119,7 +138,7 @@ def TilingInterface : OpInterface<"TilingInterface"> {
119138
iteration space).
120139
- `sizes` provides the size of the tile.
121140
}],
122-
/*retType=*/"FailureOr<TilingResult>",
141+
/*retType=*/"FailureOr<::mlir::TilingResult>",
123142
/*methodName=*/"generateResultTileValue",
124143
/*args=*/(ins
125144
"OpBuilder &":$b,
@@ -131,6 +150,45 @@ def TilingInterface : OpInterface<"TilingInterface"> {
131150
return failure();
132151
}]
133152
>,
153+
InterfaceMethod<
154+
/*desc=*/[{
155+
Method to generate the tiled implementation of an operation from
156+
operand tile position.
157+
158+
NOTE: For most operations, this should be a trivial composition of
159+
getIterationDomainTileFromOperandTile and getTiledImplementation.
160+
161+
Generates the IR that computes the tiled implementation of an
162+
operation from operand tile. The `offsets` and `sizes`
163+
describe the tile of the operand required. This is different from
164+
`getTiledImplementation` which generates the tiled
165+
implementation of the operation given a tile of the
166+
iteration space. This method generates a tiled
167+
implementation of the operation based on the tile of the
168+
operand required. This method enables consumer fusion by using
169+
tile and fuse. The method returns failure if the operation
170+
can't be tiled to generate the operand tile. In practical terms
171+
this implies it cannot be tiled and fused with its producers.
172+
173+
- `offsets` provides the offset of the tile in the coordinate system
174+
of the original iteration space, i.e., if an iteration space
175+
dimension had non-zero offset, it must be included in the offset
176+
provided here (as opposed to zero-based offset "relative" to the
177+
iteration space).
178+
- `sizes` provides the size of the tile.
179+
}],
180+
/*retType=*/"FailureOr<::mlir::TilingResult>",
181+
/*methodName=*/"getTiledImplementationFromOperandTile",
182+
/*args=*/(ins
183+
"OpBuilder &":$b,
184+
"unsigned":$operandNumber,
185+
"ArrayRef<OpFoldResult>":$offsets,
186+
"ArrayRef<OpFoldResult>":$sizes),
187+
/*methodBody=*/"",
188+
/*defaultImplementation=*/[{
189+
return failure();
190+
}]
191+
>,
134192
InterfaceMethod<
135193
/*desc=*/[{
136194
Generates the scalar implementation of the operation.
@@ -142,7 +200,7 @@ def TilingInterface : OpInterface<"TilingInterface"> {
142200
transformations are done, this method can be used to lower to scalar
143201
code that can then be lowered to LLVM or SPIR-V dialects.
144202
}],
145-
/*retType=*/"LogicalResult",
203+
/*retType=*/"::mlir::LogicalResult",
146204
/*methodName=*/"generateScalarImplementation",
147205
/*args=*/(ins
148206
"OpBuilder &":$b,

mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp

Lines changed: 80 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ struct LinalgOpTilingInterface
110110
}));
111111
}
112112

113-
// Instantiate the tiled implementation of the operation.
113+
/// Instantiate the tiled implementation of the operation.
114114
FailureOr<TilingResult>
115115
getTiledImplementation(Operation *op, OpBuilder &b,
116116
ArrayRef<OpFoldResult> offsets,
@@ -132,8 +132,63 @@ struct LinalgOpTilingInterface
132132
return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults())};
133133
}
134134

135-
// Return the details of the output tile generated by the tiled
136-
// implementation.
135+
/// Utility to fetch the offsets and sizes when applied as per the indexing
136+
/// map of the linalg op. This helps in fusing the linalg op as a consumer of
137+
/// a given slice op.
138+
void
139+
getMappedOffsetAndSize(LinalgOp linalgOp, OpBuilder &b, AffineMap indexingMap,
140+
ArrayRef<OpFoldResult> offsets,
141+
ArrayRef<OpFoldResult> sizes,
142+
SmallVectorImpl<OpFoldResult> &mappedOffsets,
143+
SmallVectorImpl<OpFoldResult> &mappedSizes) const {
144+
unsigned numLoops = linalgOp.getNumLoops();
145+
auto tilingInterfaceOp = cast<TilingInterface>(linalgOp.getOperation());
146+
mappedOffsets.resize(numLoops);
147+
mappedSizes.resize(numLoops);
148+
if (!indexingMap.isPermutation()) {
149+
SmallVector<Range> iterationDomain =
150+
tilingInterfaceOp.getIterationDomain(b);
151+
for (const auto &&[index, value] : llvm::enumerate(iterationDomain)) {
152+
mappedOffsets[index] = value.offset;
153+
mappedSizes[index] = value.size;
154+
}
155+
}
156+
for (const auto &&[index, value] :
157+
llvm::enumerate(indexingMap.getResults())) {
158+
unsigned dimPosition = cast<AffineDimExpr>(value).getPosition();
159+
mappedOffsets[dimPosition] = offsets[index];
160+
mappedSizes[dimPosition] = sizes[index];
161+
}
162+
}
163+
164+
/// Method to return the position of the result tile computed by the tiled
165+
/// operation.
166+
LogicalResult getIterationDomainTileFromOperandTile(
167+
Operation *op, OpBuilder &b, unsigned operandNumber,
168+
ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
169+
SmallVectorImpl<OpFoldResult> &iterDomainOffsets,
170+
SmallVectorImpl<OpFoldResult> &iterDomainSizes) const {
171+
auto linalgOp = cast<LinalgOp>(op);
172+
173+
// Check that the indexing map used for the operand is a projected
174+
// permutation. This could be relaxed with a more general approach that can
175+
// map the offsets and sizes from the operand to iteration space tiles
176+
// (filling in full extent for dimensions not used to access the result).
177+
AffineMap indexingMap =
178+
linalgOp.getMatchingIndexingMap(&op->getOpOperand(operandNumber));
179+
if (!indexingMap.isProjectedPermutation()) {
180+
return op->emitError()
181+
<< "unhandled get iter domain position when operand is not "
182+
"accessed using a permuted projection";
183+
}
184+
185+
getMappedOffsetAndSize(linalgOp, b, indexingMap, offsets, sizes,
186+
iterDomainOffsets, iterDomainSizes);
187+
return success();
188+
}
189+
190+
/// Return the details of the output tile generated by the tiled
191+
/// implementation.
137192
LogicalResult
138193
getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
139194
ArrayRef<OpFoldResult> offsets,
@@ -177,29 +232,16 @@ struct LinalgOpTilingInterface
177232
"unhandled tiled implementation generation when result is not "
178233
"accessed using a permuted projection");
179234
}
180-
181-
auto numLoops = linalgOp.getNumLoops();
235+
SmallVector<OpFoldResult> mappedOffsets, mappedSizes;
236+
getMappedOffsetAndSize(linalgOp, b, indexingMap, offsets, sizes,
237+
mappedOffsets, mappedSizes);
182238
auto tilingInterfaceOp = cast<TilingInterface>(op);
183-
SmallVector<OpFoldResult> iterationTileOffsets(numLoops),
184-
iterationTileSizes(numLoops);
185-
if (!indexingMap.isPermutation()) {
186-
SmallVector<Range> iterationDomain =
187-
tilingInterfaceOp.getIterationDomain(b);
188-
for (const auto &range : llvm::enumerate(iterationDomain)) {
189-
iterationTileOffsets[range.index()] = range.value().offset;
190-
iterationTileSizes[range.index()] = range.value().size;
191-
}
192-
}
193-
for (const auto &resultExpr : llvm::enumerate(indexingMap.getResults())) {
194-
unsigned dimPosition =
195-
cast<AffineDimExpr>(resultExpr.value()).getPosition();
196-
iterationTileOffsets[dimPosition] = offsets[resultExpr.index()];
197-
iterationTileSizes[dimPosition] = sizes[resultExpr.index()];
198-
}
199-
200239
FailureOr<TilingResult> tilingResult =
201-
tilingInterfaceOp.getTiledImplementation(b, iterationTileOffsets,
202-
iterationTileSizes);
240+
tilingInterfaceOp.getTiledImplementation(b, mappedOffsets, mappedSizes);
241+
242+
if (failed(tilingResult))
243+
return failure();
244+
203245
if (tilingResult->tiledOps.size() != 1)
204246
return op->emitOpError("failed to generate tiled implementation");
205247

@@ -208,6 +250,20 @@ struct LinalgOpTilingInterface
208250
SmallVector<Value>{tilingResult->tiledValues[resultNumber]}};
209251
}
210252

253+
/// Method to generate the tiled implementation of an operation from the tile
254+
/// of the operand.
255+
FailureOr<TilingResult> getTiledImplementationFromOperandTile(
256+
Operation *op, OpBuilder &b, unsigned operandNumber,
257+
ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes) const {
258+
SmallVector<OpFoldResult> mappedOffsets, mappedSizes;
259+
if (failed(getIterationDomainTileFromOperandTile(
260+
op, b, operandNumber, offsets, sizes, mappedOffsets,
261+
mappedSizes))) {
262+
return failure();
263+
}
264+
return getTiledImplementation(op, b, mappedOffsets, mappedSizes);
265+
}
266+
211267
LogicalResult generateScalarImplementation(Operation *op, OpBuilder &builder,
212268
Location loc,
213269
ValueRange ivs) const {

0 commit comments

Comments
 (0)