Skip to content

Commit 8e044b2

Browse files
[MLIR][SCF] Add an API to fuse consumer to a producer within scf loop
-- This commit adds an API to fuse consumer to a producer within scf.for/scf.forall loop. Signed-off-by: Abhishek Varma <[email protected]>
1 parent c006b90 commit 8e044b2

File tree

7 files changed

+804
-21
lines changed

7 files changed

+804
-21
lines changed

mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,19 @@ struct SCFTileAndFuseOptions {
126126
}
127127
};
128128

129+
/// Fuse the consumer of the source of `candidateSliceOp` by computing the
130+
/// required slice of the consumer in-place. Note that the method
131+
/// replaces the uses of `candidateSliceOp` with the tiled and fused consumer
132+
/// value but does not delete the slice operation.
133+
struct SCFFuseConsumerOfSliceResult {
134+
Operation *origConsumer; // Original untiled consumer.
135+
Value tiledAndFusedConsumer; // Tile and fused consumer value.
136+
SmallVector<Operation *> tiledOps;
137+
};
138+
FailureOr<scf::SCFFuseConsumerOfSliceResult>
139+
tileAndFuseConsumerOfSlice(RewriterBase &rewriter, Operation *candidateSliceOp,
140+
bool useSCFFor);
141+
129142
/// Fuse the producer of the source of `candidateSliceOp` by computing the
130143
/// required slice of the producer in-place. Note that the method
131144
/// replaces the uses of `candidateSliceOp` with the tiled and fused producer

mlir/include/mlir/Interfaces/TilingInterface.td

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,25 @@ def TilingInterface : OpInterface<"TilingInterface"> {
7474
return {};
7575
}]
7676
>,
77+
InterfaceMethod<
78+
/*desc=*/[{
79+
Method to return iterator domain position computed by the
80+
input operand position.
81+
}],
82+
/*retType=*/"LogicalResult",
83+
/*methodName=*/"getIterDomainTilePositionFromOperandPosition",
84+
/*args=*/(ins
85+
"OpBuilder &":$b,
86+
"unsigned":$operandNumber,
87+
"ArrayRef<OpFoldResult> ":$offsets,
88+
"ArrayRef<OpFoldResult> ":$sizes,
89+
"SmallVector<OpFoldResult> &":$iterDomainOffsets,
90+
"SmallVector<OpFoldResult> &":$iterDomainSizes),
91+
/*methodBody=*/"",
92+
/*defaultImplementation=*/[{
93+
return failure();
94+
}]
95+
>,
7796
InterfaceMethod<
7897
/*desc=*/[{
7998
Method to return the position of the result tile computed by the tiled operation.
@@ -96,6 +115,42 @@ def TilingInterface : OpInterface<"TilingInterface"> {
96115
return failure();
97116
}]
98117
>,
118+
InterfaceMethod<
119+
/*desc=*/[{
120+
Method to generate the tiled implementation of an operation from
121+
operand position.
122+
123+
Generates the IR that generate the tiled implementation of an
124+
operation from operand position. The `offsets` and `sizes`
125+
describe the tile of the operand required. This is different from
126+
`getTiledImplementation` which generates the tiled
127+
implementation of the operation given a tile of the
128+
iteration space. This method generates a tiled
129+
implementation of the operation based on the position of the
130+
operand required. This method enables fusion consumer by using
131+
tile and fuse. The method returns failure if the operation
132+
can't be tiled to generate the operand tile. In practical terms
133+
this implies it cannot be tiled and fused with its producers.
134+
135+
- `offsets` provides the offset of the tile in the coordinate system
136+
of the original iteration space, i.e., if an iteration space
137+
dimension had non-zero offset, it must be included in the offset
138+
provided here (as opposed to zero-based offset "relative" to the
139+
iteration space).
140+
- `sizes` provides the size of the tile.
141+
}],
142+
/*retType=*/"FailureOr<TilingResult>",
143+
/*methodName=*/"getTiledImplementationFromOperandPosition",
144+
/*args=*/(ins
145+
"OpBuilder &":$b,
146+
"unsigned":$operandNumber,
147+
"ArrayRef<OpFoldResult>":$offsets,
148+
"ArrayRef<OpFoldResult>":$sizes),
149+
/*methodBody=*/"",
150+
/*defaultImplementation=*/[{
151+
return failure();
152+
}]
153+
>,
99154
InterfaceMethod<
100155
/*desc=*/[{
101156
Method to generate the code that produces a tile of the result.

mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp

Lines changed: 75 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,59 @@ struct LinalgOpTilingInterface
132132
return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults())};
133133
}
134134

135+
void getMappedOffsetAndSize(Operation *op, OpBuilder &b,
136+
AffineMap indexingMap,
137+
ArrayRef<OpFoldResult> offsets,
138+
ArrayRef<OpFoldResult> sizes,
139+
SmallVector<OpFoldResult> &mappedOffsets,
140+
SmallVector<OpFoldResult> &mappedSizes) const {
141+
auto linalgOp = cast<LinalgOp>(op);
142+
auto numLoops = linalgOp.getNumLoops();
143+
auto tilingInterfaceOp = cast<TilingInterface>(op);
144+
mappedOffsets.resize(numLoops);
145+
mappedSizes.resize(numLoops);
146+
if (!indexingMap.isPermutation()) {
147+
SmallVector<Range> iterationDomain =
148+
tilingInterfaceOp.getIterationDomain(b);
149+
for (const auto &range : llvm::enumerate(iterationDomain)) {
150+
mappedOffsets[range.index()] = range.value().offset;
151+
mappedSizes[range.index()] = range.value().size;
152+
}
153+
}
154+
for (const auto &resultExpr : llvm::enumerate(indexingMap.getResults())) {
155+
unsigned dimPosition =
156+
cast<AffineDimExpr>(resultExpr.value()).getPosition();
157+
mappedOffsets[dimPosition] = offsets[resultExpr.index()];
158+
mappedSizes[dimPosition] = sizes[resultExpr.index()];
159+
}
160+
}
161+
162+
// Return the details of the output tile generated by the tiled
163+
// implementation.
164+
LogicalResult getIterDomainTilePositionFromOperandPosition(
165+
Operation *op, OpBuilder &b, unsigned operandNumber,
166+
ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
167+
SmallVector<OpFoldResult> &iterDomainOffsets,
168+
SmallVector<OpFoldResult> &iterDomainSizes) const {
169+
auto linalgOp = cast<LinalgOp>(op);
170+
171+
// Check that the indexing map used for the operand is a projected
172+
// permutation. This could be relaxed with a more general approach that can
173+
// map the offsets and sizes from the operand to iteration space tiles
174+
// (filling in full extent for dimensions not used to access the result).
175+
AffineMap indexingMap =
176+
linalgOp.getMatchingIndexingMap(&op->getOpOperand(operandNumber));
177+
if (!indexingMap.isProjectedPermutation()) {
178+
return op->emitOpError(
179+
"unhandled get iter domain position when operand is not "
180+
"accessed using a permuted projection");
181+
}
182+
183+
getMappedOffsetAndSize(op, b, indexingMap, offsets, sizes,
184+
iterDomainOffsets, iterDomainSizes);
185+
return success();
186+
}
187+
135188
// Return the details of the output tile generated by the tiled
136189
// implementation.
137190
LogicalResult
@@ -160,6 +213,20 @@ struct LinalgOpTilingInterface
160213
return success();
161214
}
162215

216+
FailureOr<TilingResult> getTiledImplementationFromOperandPosition(
217+
Operation *op, OpBuilder &b, unsigned operandNumber,
218+
ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes) const {
219+
SmallVector<OpFoldResult> mappedOffsets, mappedSizes;
220+
auto tilingInterfaceOp = cast<TilingInterface>(op);
221+
if (failed(tilingInterfaceOp.getIterDomainTilePositionFromOperandPosition(
222+
b, operandNumber, offsets, sizes, mappedOffsets, mappedSizes))) {
223+
return op->emitOpError(
224+
"unable to obtain the iter domain position of the operation.");
225+
}
226+
return tilingInterfaceOp.getTiledImplementation(b, mappedOffsets,
227+
mappedSizes);
228+
}
229+
163230
FailureOr<TilingResult>
164231
generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber,
165232
ArrayRef<OpFoldResult> offsets,
@@ -177,29 +244,16 @@ struct LinalgOpTilingInterface
177244
"unhandled tiled implementation generation when result is not "
178245
"accessed using a permuted projection");
179246
}
180-
181-
auto numLoops = linalgOp.getNumLoops();
247+
SmallVector<OpFoldResult> mappedOffsets, mappedSizes;
248+
getMappedOffsetAndSize(op, b, indexingMap, offsets, sizes, mappedOffsets,
249+
mappedSizes);
182250
auto tilingInterfaceOp = cast<TilingInterface>(op);
183-
SmallVector<OpFoldResult> iterationTileOffsets(numLoops),
184-
iterationTileSizes(numLoops);
185-
if (!indexingMap.isPermutation()) {
186-
SmallVector<Range> iterationDomain =
187-
tilingInterfaceOp.getIterationDomain(b);
188-
for (const auto &range : llvm::enumerate(iterationDomain)) {
189-
iterationTileOffsets[range.index()] = range.value().offset;
190-
iterationTileSizes[range.index()] = range.value().size;
191-
}
192-
}
193-
for (const auto &resultExpr : llvm::enumerate(indexingMap.getResults())) {
194-
unsigned dimPosition =
195-
cast<AffineDimExpr>(resultExpr.value()).getPosition();
196-
iterationTileOffsets[dimPosition] = offsets[resultExpr.index()];
197-
iterationTileSizes[dimPosition] = sizes[resultExpr.index()];
198-
}
199-
200251
FailureOr<TilingResult> tilingResult =
201-
tilingInterfaceOp.getTiledImplementation(b, iterationTileOffsets,
202-
iterationTileSizes);
252+
tilingInterfaceOp.getTiledImplementation(b, mappedOffsets, mappedSizes);
253+
254+
if (failed(tilingResult))
255+
return failure();
256+
203257
if (tilingResult->tiledOps.size() != 1)
204258
return op->emitOpError("failed to generate tiled implementation");
205259

0 commit comments

Comments
 (0)