Skip to content

Commit 2a47ee0

Browse files
authored
[MLIR][Linalg] Enable fuse consumer (#85528)
This patch adds support for consumer fusion to the tiling interface, and implements fuse consumers on FuseIntoContainingOp. - Add interface method 'getIterDomainTilePositionFromOperandPosition' to tiling interface which get iteration domain position from operand position. - Add interface method 'getTiledImplementationFromOperandPosition' to tiling interface which generate tiled implementation according to operand position. - Implemented the above two methods and supported consumer fusion for FuseIntoContainingOp. Signed-off-by: Donald Chen
1 parent ec062f5 commit 2a47ee0

File tree

4 files changed

+149
-40
lines changed

4 files changed

+149
-40
lines changed

mlir/include/mlir/Interfaces/TilingInterface.td

Lines changed: 61 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def TilingInterface : OpInterface<"TilingInterface"> {
6363
The method returns the operation that is the tiled
6464
implementation.
6565
}],
66-
/*retType=*/"FailureOr<TilingResult>",
66+
/*retType=*/"FailureOr<::mlir::TilingResult>",
6767
/*methodName=*/"getTiledImplementation",
6868
/*args=*/(ins
6969
"OpBuilder &":$b,
@@ -82,15 +82,34 @@ def TilingInterface : OpInterface<"TilingInterface"> {
8282
by the tiled implementation. Expects the same `offsets` and `sizes` as
8383
used to obtain the tiled implementation of the operation.
8484
}],
85-
/*retType=*/"LogicalResult",
85+
/*retType=*/"::mlir::LogicalResult",
8686
/*methodName=*/"getResultTilePosition",
8787
/*args=*/(ins
8888
"OpBuilder &":$b,
8989
"unsigned":$resultNumber,
9090
"ArrayRef<OpFoldResult> ":$offsets,
9191
"ArrayRef<OpFoldResult> ":$sizes,
92-
"SmallVector<OpFoldResult> &":$resultOffsets,
93-
"SmallVector<OpFoldResult> &":$resultSizes),
92+
"SmallVectorImpl<OpFoldResult> &":$resultOffsets,
93+
"SmallVectorImpl<OpFoldResult> &":$resultSizes),
94+
/*methodBody=*/"",
95+
/*defaultImplementation=*/[{
96+
return failure();
97+
}]
98+
>,
99+
InterfaceMethod<
100+
/*desc=*/[{
101+
Method to return the position of iteration domain tile computed by the
102+
tiled operation.
103+
}],
104+
/*retType=*/"::mlir::LogicalResult",
105+
/*methodName=*/"getIterationDomainTileFromOperandTile",
106+
/*args=*/(ins
107+
"OpBuilder &":$b,
108+
"unsigned":$operandNumber,
109+
"ArrayRef<OpFoldResult> ":$offsets,
110+
"ArrayRef<OpFoldResult> ":$sizes,
111+
"SmallVectorImpl<OpFoldResult> &":$iterDomainOffsets,
112+
"SmallVectorImpl<OpFoldResult> &":$iterDomainSizes),
94113
/*methodBody=*/"",
95114
/*defaultImplementation=*/[{
96115
return failure();
@@ -119,7 +138,7 @@ def TilingInterface : OpInterface<"TilingInterface"> {
119138
iteration space).
120139
- `sizes` provides the size of the tile.
121140
}],
122-
/*retType=*/"FailureOr<TilingResult>",
141+
/*retType=*/"FailureOr<::mlir::TilingResult>",
123142
/*methodName=*/"generateResultTileValue",
124143
/*args=*/(ins
125144
"OpBuilder &":$b,
@@ -131,6 +150,42 @@ def TilingInterface : OpInterface<"TilingInterface"> {
131150
return failure();
132151
}]
133152
>,
153+
InterfaceMethod<
154+
/*desc=*/[{
155+
Method to generate the tiled implementation of an operation from
156+
operand tile position.
157+
158+
Generates the IR that computes the tiled implementation of an
159+
operation from operand tile. The `offsets` and `sizes`
160+
describe the tile of the operand required. This is different from
161+
`getTiledImplementation` which generates the tiled
162+
implementation of the operation given a tile of the
163+
iteration space. This method generates a tiled
164+
implementation of the operation based on the tile of the
165+
operand required. This method enables consumer fusion by using
166+
tile and fuse. The method returns failure if the operation
167+
can't be tiled to generate the operand tile. In practical terms
168+
this implies it cannot be tiled and fused with its producers.
169+
170+
- `offsets` provides the offset of the tile in the coordinate system
171+
of the original iteration space, i.e., if an iteration space
172+
dimension had non-zero offset, it must be included in the offset
173+
provided here (as opposed to zero-based offset "relative" to the
174+
iteration space).
175+
- `sizes` provides the size of the tile.
176+
}],
177+
/*retType=*/"FailureOr<::mlir::TilingResult>",
178+
/*methodName=*/"getTiledImplementationFromOperandTile",
179+
/*args=*/(ins
180+
"OpBuilder &":$b,
181+
"unsigned":$operandNumber,
182+
"ArrayRef<OpFoldResult>":$offsets,
183+
"ArrayRef<OpFoldResult>":$sizes),
184+
/*methodBody=*/"",
185+
/*defaultImplementation=*/[{
186+
return failure();
187+
}]
188+
>,
134189
InterfaceMethod<
135190
/*desc=*/[{
136191
Generates the scalar implementation of the operation.
@@ -142,7 +197,7 @@ def TilingInterface : OpInterface<"TilingInterface"> {
142197
transformations are done, this method can be used to lower to scalar
143198
code that can then be lowered to LLVM or SPIR-V dialects.
144199
}],
145-
/*retType=*/"LogicalResult",
200+
/*retType=*/"::mlir::LogicalResult",
146201
/*methodName=*/"generateScalarImplementation",
147202
/*args=*/(ins
148203
"OpBuilder &":$b,

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2425,8 +2425,8 @@ SoftmaxOp::getTiledImplementation(OpBuilder &builder,
24252425

24262426
LogicalResult SoftmaxOp::getResultTilePosition(
24272427
OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
2428-
ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
2429-
SmallVector<OpFoldResult> &resultSizes) {
2428+
ArrayRef<OpFoldResult> sizes, SmallVectorImpl<OpFoldResult> &resultOffsets,
2429+
SmallVectorImpl<OpFoldResult> &resultSizes) {
24302430
if (resultNumber == 0) {
24312431
resultOffsets.assign(offsets.begin(), offsets.end());
24322432
resultSizes.assign(sizes.begin(), sizes.end());

mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp

Lines changed: 80 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ struct LinalgOpTilingInterface
110110
}));
111111
}
112112

113-
// Instantiate the tiled implementation of the operation.
113+
/// Instantiate the tiled implementation of the operation.
114114
FailureOr<TilingResult>
115115
getTiledImplementation(Operation *op, OpBuilder &b,
116116
ArrayRef<OpFoldResult> offsets,
@@ -132,14 +132,66 @@ struct LinalgOpTilingInterface
132132
return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults())};
133133
}
134134

135-
// Return the details of the output tile generated by the tiled
136-
// implementation.
135+
void
136+
getMappedOffsetAndSize(LinalgOp linalgOp, OpBuilder &b, AffineMap indexingMap,
137+
ArrayRef<OpFoldResult> offsets,
138+
ArrayRef<OpFoldResult> sizes,
139+
SmallVectorImpl<OpFoldResult> &mappedOffsets,
140+
SmallVectorImpl<OpFoldResult> &mappedSizes) const {
141+
unsigned numLoops = linalgOp.getNumLoops();
142+
auto tilingInterfaceOp = cast<TilingInterface>(linalgOp.getOperation());
143+
mappedOffsets.resize(numLoops);
144+
mappedSizes.resize(numLoops);
145+
if (!indexingMap.isPermutation()) {
146+
SmallVector<Range> iterationDomain =
147+
tilingInterfaceOp.getIterationDomain(b);
148+
for (const auto &&[index, value] : llvm::enumerate(iterationDomain)) {
149+
mappedOffsets[index] = value.offset;
150+
mappedSizes[index] = value.size;
151+
}
152+
}
153+
for (const auto &&[index, value] :
154+
llvm::enumerate(indexingMap.getResults())) {
155+
unsigned dimPosition = cast<AffineDimExpr>(value).getPosition();
156+
mappedOffsets[dimPosition] = offsets[index];
157+
mappedSizes[dimPosition] = sizes[index];
158+
}
159+
}
160+
161+
/// Return the details of the output tile generated by the tiled
162+
/// implementation.
163+
LogicalResult getIterationDomainTileFromOperandTile(
164+
Operation *op, OpBuilder &b, unsigned operandNumber,
165+
ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
166+
SmallVectorImpl<OpFoldResult> &iterDomainOffsets,
167+
SmallVectorImpl<OpFoldResult> &iterDomainSizes) const {
168+
auto linalgOp = cast<LinalgOp>(op);
169+
170+
// Check that the indexing map used for the operand is a projected
171+
// permutation. This could be relaxed with a more general approach that can
172+
// map the offsets and sizes from the operand to iteration space tiles
173+
// (filling in full extent for dimensions not used to access the result).
174+
AffineMap indexingMap =
175+
linalgOp.getMatchingIndexingMap(&op->getOpOperand(operandNumber));
176+
if (!indexingMap.isProjectedPermutation()) {
177+
return emitError(op->getLoc(),
178+
"unhandled get iter domain position when operand is not "
179+
"accessed using a permuted projection");
180+
}
181+
182+
getMappedOffsetAndSize(linalgOp, b, indexingMap, offsets, sizes,
183+
iterDomainOffsets, iterDomainSizes);
184+
return success();
185+
}
186+
187+
/// Return the details of the output tile generated by the tiled
188+
/// implementation.
137189
LogicalResult
138190
getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
139191
ArrayRef<OpFoldResult> offsets,
140192
ArrayRef<OpFoldResult> sizes,
141-
SmallVector<OpFoldResult> &resultOffsets,
142-
SmallVector<OpFoldResult> &resultSizes) const {
193+
SmallVectorImpl<OpFoldResult> &resultOffsets,
194+
SmallVectorImpl<OpFoldResult> &resultSizes) const {
143195
Location loc = op->getLoc();
144196
LinalgOp linalgOp = cast<LinalgOp>(op);
145197

@@ -160,6 +212,21 @@ struct LinalgOpTilingInterface
160212
return success();
161213
}
162214

215+
FailureOr<TilingResult> getTiledImplementationFromOperandTile(
216+
Operation *op, OpBuilder &b, unsigned operandNumber,
217+
ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes) const {
218+
SmallVector<OpFoldResult> mappedOffsets, mappedSizes;
219+
auto tilingInterfaceOp = cast<TilingInterface>(op);
220+
if (failed(tilingInterfaceOp.getIterationDomainTileFromOperandTile(
221+
b, operandNumber, offsets, sizes, mappedOffsets, mappedSizes))) {
222+
return emitError(
223+
op->getLoc(),
224+
"unable to obtain the iter domain position of the operation.");
225+
}
226+
return tilingInterfaceOp.getTiledImplementation(b, mappedOffsets,
227+
mappedSizes);
228+
}
229+
163230
FailureOr<TilingResult>
164231
generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber,
165232
ArrayRef<OpFoldResult> offsets,
@@ -177,29 +244,16 @@ struct LinalgOpTilingInterface
177244
"unhandled tiled implementation generation when result is not "
178245
"accessed using a permuted projection");
179246
}
180-
181-
auto numLoops = linalgOp.getNumLoops();
247+
SmallVector<OpFoldResult> mappedOffsets, mappedSizes;
248+
getMappedOffsetAndSize(linalgOp, b, indexingMap, offsets, sizes,
249+
mappedOffsets, mappedSizes);
182250
auto tilingInterfaceOp = cast<TilingInterface>(op);
183-
SmallVector<OpFoldResult> iterationTileOffsets(numLoops),
184-
iterationTileSizes(numLoops);
185-
if (!indexingMap.isPermutation()) {
186-
SmallVector<Range> iterationDomain =
187-
tilingInterfaceOp.getIterationDomain(b);
188-
for (const auto &range : llvm::enumerate(iterationDomain)) {
189-
iterationTileOffsets[range.index()] = range.value().offset;
190-
iterationTileSizes[range.index()] = range.value().size;
191-
}
192-
}
193-
for (const auto &resultExpr : llvm::enumerate(indexingMap.getResults())) {
194-
unsigned dimPosition =
195-
cast<AffineDimExpr>(resultExpr.value()).getPosition();
196-
iterationTileOffsets[dimPosition] = offsets[resultExpr.index()];
197-
iterationTileSizes[dimPosition] = sizes[resultExpr.index()];
198-
}
199-
200251
FailureOr<TilingResult> tilingResult =
201-
tilingInterfaceOp.getTiledImplementation(b, iterationTileOffsets,
202-
iterationTileSizes);
252+
tilingInterfaceOp.getTiledImplementation(b, mappedOffsets, mappedSizes);
253+
254+
if (failed(tilingResult))
255+
return failure();
256+
203257
if (tilingResult->tiledOps.size() != 1)
204258
return op->emitOpError("failed to generate tiled implementation");
205259

mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,8 @@ struct PadOpTiling : public TilingInterface::ExternalModel<PadOpTiling, PadOp> {
6161
getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
6262
ArrayRef<OpFoldResult> offsets,
6363
ArrayRef<OpFoldResult> sizes,
64-
SmallVector<OpFoldResult> &resultOffsets,
65-
SmallVector<OpFoldResult> &resultSizes) const {
64+
SmallVectorImpl<OpFoldResult> &resultOffsets,
65+
SmallVectorImpl<OpFoldResult> &resultSizes) const {
6666
resultOffsets.assign(offsets.begin(), offsets.end());
6767
resultSizes.assign(sizes.begin(), sizes.end());
6868
return success();
@@ -199,8 +199,8 @@ struct PackOpTiling
199199
getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
200200
ArrayRef<OpFoldResult> offsets,
201201
ArrayRef<OpFoldResult> sizes,
202-
SmallVector<OpFoldResult> &resultOffsets,
203-
SmallVector<OpFoldResult> &resultSizes) const {
202+
SmallVectorImpl<OpFoldResult> &resultOffsets,
203+
SmallVectorImpl<OpFoldResult> &resultSizes) const {
204204
// The iteration domain is over outer dimensions of packed layout. In this
205205
// context, the outer dimensions of `resultOffsets` are `offsets`. The
206206
// inner dimensions of `resultOffsets` are zeros because tiling is not
@@ -452,8 +452,8 @@ struct UnPackOpTiling
452452
getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
453453
ArrayRef<OpFoldResult> offsets,
454454
ArrayRef<OpFoldResult> sizes,
455-
SmallVector<OpFoldResult> &resultOffsets,
456-
SmallVector<OpFoldResult> &resultSizes) const {
455+
SmallVectorImpl<OpFoldResult> &resultOffsets,
456+
SmallVectorImpl<OpFoldResult> &resultSizes) const {
457457
resultOffsets = llvm::to_vector(offsets);
458458
resultSizes = llvm::to_vector(sizes);
459459
return success();

0 commit comments

Comments
 (0)