Skip to content

Commit 611cb4e

Browse files
Refactor code v1
1 parent d8f69ff commit 611cb4e

File tree

9 files changed

+104
-43
lines changed

9 files changed

+104
-43
lines changed

mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,7 @@ def Linalg_SoftmaxOp : Linalg_Op<"softmax",
100100
["getIterationDomain",
101101
"getLoopIteratorTypes",
102102
"getResultTilePosition",
103+
"getOperandTilesForIterationDomainTile",
103104
"getTiledImplementation"]>]> {
104105
let summary = "Softmax operator";
105106
let description = [{

mlir/include/mlir/Interfaces/TilingInterface.td

Lines changed: 34 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,8 @@ def TilingInterface : OpInterface<"TilingInterface"> {
4747
>,
4848
InterfaceMethod<
4949
/*desc=*/[{
50-
Method to generate the tiled implementation of an operation.
50+
Method to generate the tiled implementation of an operation given the
51+
slices of its operands.
5152

5253
The iteration space of the operation is returned by
5354
`getIterationDomain`. The caller provides the information of the
@@ -67,6 +68,7 @@ def TilingInterface : OpInterface<"TilingInterface"> {
6768
/*methodName=*/"getTiledImplementation",
6869
/*args=*/(ins
6970
"OpBuilder &":$b,
71+
"ArrayRef<Value> ": $tiledOperands,
7072
"ArrayRef<OpFoldResult> ":$offsets,
7173
"ArrayRef<OpFoldResult> ":$sizes),
7274
/*methodBody=*/"",
@@ -150,6 +152,28 @@ def TilingInterface : OpInterface<"TilingInterface"> {
150152
return failure();
151153
}]
152154
>,
155+
InterfaceMethod<
156+
/*desc=*/[{
157+
Method to generate slices of the operands.
158+
159+
- `offsets` provides the offset of the tile in the coordinate system
160+
of the original coordinate space, i.e., if a dimension from the
161+
coordinate space of the operand has a non-zero offset, it must be
162+
included in the offset provided here (as opposed to zero-based offset
163+
"relative" to the coordinate space).
164+
- `sizes` provides the size of the tile.
165+
}],
166+
/*retType=*/"SmallVector<Value>",
167+
/*methodName=*/"getOperandTilesForIterationDomainTile",
168+
/*args=*/(ins
169+
"OpBuilder &":$b,
170+
"ArrayRef<OpFoldResult>":$offsets,
171+
"ArrayRef<OpFoldResult>":$sizes),
172+
/*methodBody=*/"",
173+
/*defaultImplementation=*/[{
174+
return {};
175+
}]
176+
>,
153177
InterfaceMethod<
154178
/*desc=*/[{
155179
Method to generate the tiled implementation of an operation from
@@ -168,10 +192,10 @@ def TilingInterface : OpInterface<"TilingInterface"> {
168192
this implies it cannot be tiled and fused with its producers.
169193

170194
- `offsets` provides the offset of the tile in the coordinate system
171-
of the original iteration space, i.e., if an iteration space
172-
dimension had non-zero offset, it must be included in the offset
173-
provided here (as opposed to zero-based offset "relative" to the
174-
iteration space).
195+
of the original coordinate space, i.e., if a dimension from the
196+
coordinate space of the operand has a non-zero offset, it must be
197+
included in the offset provided here (as opposed to zero-based offset
198+
"relative" to the coordinate space).
175199
- `sizes` provides the size of the tile.
176200
}],
177201
/*retType=*/"FailureOr<::mlir::TilingResult>",
@@ -184,12 +208,15 @@ def TilingInterface : OpInterface<"TilingInterface"> {
184208
/*methodBody=*/"",
185209
/*defaultImplementation=*/[{
186210
::llvm::SmallVector<OpFoldResult> mappedOffsets, mappedSizes;
187-
auto tilingInterfaceOp = cast<::mlir::TilingInterface>($_op.getOperation());
211+
Operation* op = $_op.getOperation();
212+
auto tilingInterfaceOp = cast<::mlir::TilingInterface>(op);
188213
if (failed(tilingInterfaceOp.getIterationDomainTileFromOperandTile(
189214
b, operandNumber, offsets, sizes, mappedOffsets, mappedSizes))) {
190215
return failure();
191216
}
192-
return tilingInterfaceOp.getTiledImplementation(b, mappedOffsets, mappedSizes);
217+
SmallVector<Value> tiledOperands = tilingInterfaceOp.getOperandTilesForIterationDomainTile(b, mappedOffsets, mappedSizes);
218+
tiledOperands[operandNumber] = op->getOperand(operandNumber);
219+
return tilingInterfaceOp.getTiledImplementation(b, tiledOperands, mappedOffsets, mappedSizes);
193220
}]
194221
>,
195222
InterfaceMethod<

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2430,10 +2430,10 @@ SmallVector<utils::IteratorType> SoftmaxOp::getLoopIteratorTypes() {
24302430
return iteratorTypes;
24312431
}
24322432

2433-
FailureOr<TilingResult>
2434-
SoftmaxOp::getTiledImplementation(OpBuilder &builder,
2435-
ArrayRef<OpFoldResult> offsets,
2436-
ArrayRef<OpFoldResult> sizes) {
2433+
SmallVector<Value>
2434+
SoftmaxOp::getOperandTilesForIterationDomainTile(OpBuilder &builder,
2435+
ArrayRef<OpFoldResult> offsets,
2436+
ArrayRef<OpFoldResult> sizes) {
24372437
int64_t rank = getInputOperandRank();
24382438
auto oneAttr = builder.getI64IntegerAttr(1);
24392439
SmallVector<OpFoldResult> strides(rank, oneAttr);
@@ -2442,7 +2442,12 @@ SoftmaxOp::getTiledImplementation(OpBuilder &builder,
24422442
getSlice(builder, getLoc(), getInput(), offsets, sizes, strides));
24432443
tiledOperands.emplace_back(
24442444
getSlice(builder, getLoc(), getOutput(), offsets, sizes, strides));
2445+
return tiledOperands;
2446+
}
24452447

2448+
FailureOr<TilingResult> SoftmaxOp::getTiledImplementation(
2449+
OpBuilder &builder, ArrayRef<Value> tiledOperands,
2450+
ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes) {
24462451
SmallVector<Type, 4> resultTypes;
24472452
if (hasPureTensorSemantics())
24482453
resultTypes.push_back(tiledOperands[1].getType());

mlir/lib/Dialect/Linalg/Transforms/Split.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,11 @@ createSplitPart(RewriterBase &b, Location loc, TilingInterface op,
4040
sizesCopy[dimension] = size;
4141
offsetsCopy[dimension] = offset;
4242

43-
// Create the part as it it were a single tile.
43+
// Create the part as it it were a single tile by fetching the operand tiles.
44+
SmallVector<Value> tiledOperands =
45+
op.getOperandTilesForIterationDomainTile(b, offsetsCopy, sizesCopy);
4446
FailureOr<TilingResult> tilingResult =
45-
op.getTiledImplementation(b, offsetsCopy, sizesCopy);
47+
op.getTiledImplementation(b, tiledOperands, offsetsCopy, sizesCopy);
4648

4749
// Insert the results back and populate the `results` list.
4850
for (auto [index, result] : llvm::enumerate(tilingResult->tiledValues)) {

mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -415,9 +415,13 @@ static FailureOr<ForallTilingResult> tileToForallOpImpl(
415415
}
416416

417417
// 4. Tile the cloned op and delete the clone.
418+
auto tilingInterfaceOp = cast<TilingInterface>(clonedOp);
419+
SmallVector<Value> tiledOperands =
420+
tilingInterfaceOp.getOperandTilesForIterationDomainTile(b, tiledOffsets,
421+
tiledSizes);
418422
FailureOr<TilingResult> tilingResult =
419-
cast<TilingInterface>(clonedOp).getTiledImplementation(b, tiledOffsets,
420-
tiledSizes);
423+
tilingInterfaceOp.getTiledImplementation(b, tiledOperands, tiledOffsets,
424+
tiledSizes);
421425
if (failed(tilingResult))
422426
return clonedOp->emitError("Failed to tile op: ");
423427
if (tilingResult->tiledOps.size() != 1) {
@@ -766,9 +770,13 @@ FailureOr<linalg::ForallReductionTilingResult> linalg::tileReductionUsingForall(
766770

767771
// 5. Tile the cloned op and delete the clone.
768772
if (tileSizes.empty()) {
769-
FailureOr<TilingResult> tilingResult =
770-
cast<TilingInterface>(clonedOp).getTiledImplementation(
773+
auto tilingInterfaceOp = cast<TilingInterface>(clonedOp);
774+
SmallVector<Value> tiledOperands =
775+
tilingInterfaceOp.getOperandTilesForIterationDomainTile(
771776
b, tiledOffsets, tiledSizes);
777+
FailureOr<TilingResult> tilingResult =
778+
tilingInterfaceOp.getTiledImplementation(b, tiledOperands,
779+
tiledOffsets, tiledSizes);
772780
if (failed(tilingResult))
773781
return clonedOp->emitError("Failed to tile op: ");
774782
if (tilingResult->tiledOps.size() != 1) {

mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -110,18 +110,24 @@ struct LinalgOpTilingInterface
110110
}));
111111
}
112112

113+
/// Method to generate slices of the operands.
114+
SmallVector<Value>
115+
getOperandTilesForIterationDomainTile(Operation *op, OpBuilder &b,
116+
ArrayRef<OpFoldResult> offsets,
117+
ArrayRef<OpFoldResult> sizes) const {
118+
Location loc = op->getLoc();
119+
LinalgOp linalgOp = cast<LinalgOp>(op);
120+
return makeTiledShapes(b, loc, linalgOp, op->getOperands(), offsets, sizes,
121+
{}, true);
122+
}
123+
113124
/// Instantiate the tiled implementation of the operation.
114-
FailureOr<TilingResult>
115-
getTiledImplementation(Operation *op, OpBuilder &b,
116-
ArrayRef<OpFoldResult> offsets,
117-
ArrayRef<OpFoldResult> sizes) const {
125+
FailureOr<TilingResult> getTiledImplementation(
126+
Operation *op, OpBuilder &b, ArrayRef<Value> tiledOperands,
127+
ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes) const {
118128
// Leave the `sizeBounds` value empty. That is only needed when the `sizes`
119129
// specified could lead to out of bounds accesses.
120-
Location loc = op->getLoc();
121130
LinalgOp linalgOp = cast<LinalgOp>(op);
122-
SmallVector<Value> valuesToTile = linalgOp->getOperands();
123-
SmallVector<Value, 4> tiledOperands = makeTiledShapes(
124-
b, loc, linalgOp, valuesToTile, offsets, sizes, {}, true);
125131

126132
SmallVector<Type> resultTensorTypes =
127133
getTensorOutputTypes(linalgOp, tiledOperands);
@@ -235,9 +241,17 @@ struct LinalgOpTilingInterface
235241
SmallVector<OpFoldResult> mappedOffsets, mappedSizes;
236242
getMappedOffsetAndSize(linalgOp, b, indexingMap, offsets, sizes,
237243
mappedOffsets, mappedSizes);
244+
245+
// Fetch the tiled slices of all operands.
238246
auto tilingInterfaceOp = cast<TilingInterface>(op);
247+
SmallVector<Value> tiledOperands =
248+
tilingInterfaceOp.getOperandTilesForIterationDomainTile(
249+
b, mappedOffsets, mappedSizes);
250+
251+
// Fetch the tiled implementation of the op.
239252
FailureOr<TilingResult> tilingResult =
240-
tilingInterfaceOp.getTiledImplementation(b, mappedOffsets, mappedSizes);
253+
tilingInterfaceOp.getTiledImplementation(b, tiledOperands,
254+
mappedOffsets, mappedSizes);
241255

242256
if (failed(tilingResult))
243257
return failure();

mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -622,7 +622,11 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
622622
}
623623

624624
// 5c. Tile the cloned operation.
625-
tilingResult = clonedOp.getTiledImplementation(rewriter, offsets, sizes);
625+
SmallVector<Value> tiledOperands =
626+
clonedOp.getOperandTilesForIterationDomainTile(rewriter, offsets,
627+
sizes);
628+
tilingResult = clonedOp.getTiledImplementation(rewriter, tiledOperands,
629+
offsets, sizes);
626630
if (failed(tilingResult)) {
627631
rewriter.eraseOp(clonedOp);
628632
return op.emitOpError("faild to tile operation");

mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -46,10 +46,9 @@ struct PadOpTiling : public TilingInterface::ExternalModel<PadOpTiling, PadOp> {
4646
return loopRanges;
4747
}
4848

49-
FailureOr<TilingResult>
50-
getTiledImplementation(Operation *op, OpBuilder &b,
51-
ArrayRef<OpFoldResult> offsets,
52-
ArrayRef<OpFoldResult> sizes) const {
49+
FailureOr<TilingResult> getTiledImplementation(
50+
Operation *op, OpBuilder &b, ArrayRef<Value> tiledOperands,
51+
ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes) const {
5352
FailureOr<TilingResult> result =
5453
tensor::bubbleUpPadSlice(b, cast<PadOp>(op), offsets, sizes);
5554
if (failed(result))
@@ -116,10 +115,9 @@ struct PackOpTiling
116115
return getPackUnPackIterationDomain<PackOp>(cast<PackOp>(op), b);
117116
}
118117

119-
FailureOr<TilingResult>
120-
getTiledImplementation(Operation *op, OpBuilder &b,
121-
ArrayRef<OpFoldResult> offsets,
122-
ArrayRef<OpFoldResult> sizes) const {
118+
FailureOr<TilingResult> getTiledImplementation(
119+
Operation *op, OpBuilder &b, ArrayRef<Value> _tiledOperands,
120+
ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes) const {
123121
auto packOp = cast<PackOp>(op);
124122
Location loc = packOp.getLoc();
125123

@@ -241,7 +239,8 @@ struct PackOpTiling
241239
return failure();
242240

243241
FailureOr<TilingResult> tilingResult = getTiledImplementation(
244-
op, b, offsets.drop_back(numTiles), sizes.drop_back(numTiles));
242+
op, b, /*tiledOperands=*/{}, offsets.drop_back(numTiles),
243+
sizes.drop_back(numTiles));
245244
if (failed(tilingResult))
246245
return failure();
247246
return tilingResult.value();
@@ -380,10 +379,9 @@ struct UnPackOpTiling
380379
/// (3, 7). In this context, the tiled unpack produces a (3 * n) elements
381380
/// because there are 3 rows in total. Follow by a tensor.extract_slice op, we
382381
/// can get the actual result.
383-
FailureOr<TilingResult>
384-
getTiledImplementation(Operation *op, OpBuilder &b,
385-
ArrayRef<OpFoldResult> offsets,
386-
ArrayRef<OpFoldResult> sizes) const {
382+
FailureOr<TilingResult> getTiledImplementation(
383+
Operation *op, OpBuilder &b, ArrayRef<Value> _tiledOperands,
384+
ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes) const {
387385
auto unpackOp = cast<UnPackOp>(op);
388386
int64_t srcRank = unpackOp.getSourceRank();
389387
int64_t destRank = unpackOp.getDestRank();
@@ -431,6 +429,8 @@ struct UnPackOpTiling
431429
unpackOp.getDestType().getElementType());
432430
}
433431

432+
// TODO: Refactor this as a follow-up and use the passed `_tiledOperands`
433+
// instead.
434434
SmallVector<Value> tiledOperands = {sliceSource, sliceDest};
435435
for (auto tile : unpackOp.getInnerTiles())
436436
tiledOperands.push_back(tile);
@@ -464,7 +464,7 @@ struct UnPackOpTiling
464464
ArrayRef<OpFoldResult> offsets,
465465
ArrayRef<OpFoldResult> sizes) const {
466466
FailureOr<TilingResult> tilingResult =
467-
getTiledImplementation(op, b, offsets, sizes);
467+
getTiledImplementation(op, b, /*tiledOperands=*/{}, offsets, sizes);
468468
if (failed(tilingResult))
469469
return failure();
470470
return tilingResult.value();

mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,12 +45,12 @@ module attributes {transform.with_named_sequence} {
4545
// CHECK: %[[MAT_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV]]] [32] [1]
4646
// CHECK: %[[MAT_OUT:.*]] = linalg.generic
4747
// CHECK-SAME: outs(%[[MAT_OUT_SLICE]] : tensor<32xf32>)
48-
// CHECK: %[[INSERT_MAT:.*]] = tensor.insert_slice %[[MAT_OUT]] into %[[FIRST_OUT_ARG]][%[[IV]]] [32] [1]
4948
// CHECK: %[[SLICE_OPERAND2:.*]] = tensor.extract_slice %0[%[[IV]]] [32] [1]
5049
// CHECK: %[[SLICE_OUT:.*]] = tensor.extract_slice %[[ELEM_OUT_ARG]][%[[IV]]] [32] [1]
5150
// CHECK: %[[ELEM_OUT:.*]] = linalg.elemwise_binary {fun = #linalg.binary_fn<add>}
5251
// CHECK-SAME: ins(%[[MAT_OUT]], %[[SLICE_OPERAND2]] :
5352
// CHECK-SAME: outs(%[[SLICE_OUT]] :
53+
// CHECK: %[[INSERT_MAT:.*]] = tensor.insert_slice %[[MAT_OUT]] into %[[FIRST_OUT_ARG]][%[[IV]]] [32] [1]
5454
// CHECK: %[[INSERT_ELEM:.*]] = tensor.insert_slice %[[ELEM_OUT]] into %[[ELEM_OUT_ARG]][%[[IV]]] [32] [1]
5555
// CHECK: scf.yield %[[SECOND_OUT_ARG]], %[[INSERT_MAT]], %[[INSERT_ELEM]] :
5656
// CHECK: }
@@ -170,13 +170,13 @@ module attributes {transform.with_named_sequence} {
170170
// CHECK: %[[MAT_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV]]] [32] [1]
171171
// CHECK: %[[MAT_OUT:.*]] = linalg.generic
172172
// CHECK-SAME: outs(%[[MAT_OUT_SLICE]] : tensor<32xf32>)
173-
// CHECK: %[[INSERT_MAT:.*]] = tensor.insert_slice %[[MAT_OUT]] into %[[FIRST_OUT_ARG]][%[[IV]]] [32] [1]
174173
// CHECK: %[[SLICE_OPERAND2:.*]] = tensor.extract_slice %0[%[[IV]]] [32] [1]
175174
// CHECK: %[[SLICE_OUT_0:.*]] = tensor.extract_slice %[[ELEM_OUT_ARG_0]][%[[IV]]] [32] [1]
176175
// CHECK: %[[SLICE_OUT_1:.*]] = tensor.extract_slice %[[ELEM_OUT_ARG_1]][%[[IV]]] [32] [1]
177176
// CHECK: %[[ELEM_OUT:.*]]:2 = linalg.generic
178177
// CHECK-SAME: ins(%[[MAT_OUT]], %[[SLICE_OPERAND2]] :
179178
// CHECK-SAME: outs(%[[SLICE_OUT_0]], %[[SLICE_OUT_1]] :
179+
// CHECK: %[[INSERT_MAT:.*]] = tensor.insert_slice %[[MAT_OUT]] into %[[FIRST_OUT_ARG]][%[[IV]]] [32] [1]
180180
// CHECK: %[[INSERT_ELEM_0:.*]] = tensor.insert_slice %[[ELEM_OUT]]#0 into %[[ELEM_OUT_ARG_0]][%[[IV]]] [32] [1]
181181
// CHECK: %[[INSERT_ELEM_1:.*]] = tensor.insert_slice %[[ELEM_OUT]]#1 into %[[ELEM_OUT_ARG_1]][%[[IV]]] [32] [1]
182182
// CHECK: scf.yield %[[SECOND_OUT_ARG]], %[[INSERT_MAT]], %[[INSERT_ELEM_0]], %[[INSERT_ELEM_1]] :

0 commit comments

Comments
 (0)