Skip to content

Commit 170a25a

Browse files
[mlir][TilingInterface] Make the tiling set tile sizes function use OpFoldResult. (#66566)
1 parent 75fdf2e commit 170a25a

File tree

5 files changed

+47
-56
lines changed

5 files changed

+47
-56
lines changed

mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ namespace mlir {
2626
namespace scf {
2727

2828
using SCFTileSizeComputationFunction =
29-
std::function<SmallVector<Value>(OpBuilder &, Operation *)>;
29+
std::function<SmallVector<OpFoldResult>(OpBuilder &, Operation *)>;
3030

3131
/// Options to use to control tiling.
3232
struct SCFTilingOptions {
@@ -40,17 +40,10 @@ struct SCFTilingOptions {
4040
tileSizeComputationFunction = std::move(fun);
4141
return *this;
4242
}
43-
/// Set the `tileSizeComputationFunction` to return the values `ts`. The
44-
/// values must not fold away when tiling. Otherwise, use a more robust
45-
/// `tileSizeComputationFunction`.
46-
SCFTilingOptions &setTileSizes(const SmallVector<Value, 4> &ts) {
47-
tileSizeComputationFunction = [=](OpBuilder &, Operation *) { return ts; };
48-
return *this;
49-
}
5043
/// Convenience function to set the `tileSizeComputationFunction` to a
5144
/// function that computes tile sizes at the point they are needed. Allows
5245
/// proper interaction with folding.
53-
SCFTilingOptions &setTileSizes(ArrayRef<int64_t> ts);
46+
SCFTilingOptions &setTileSizes(ArrayRef<OpFoldResult> ts);
5447

5548
/// The interchange vector to reorder the tiled loops.
5649
SmallVector<int64_t> interchangeVector = {};

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -473,7 +473,9 @@ transform::FuseOp::apply(transform::TransformRewriter &rewriter,
473473

474474
scf::SCFTilingOptions tilingOptions;
475475
tilingOptions.interchangeVector = tileInterchange;
476-
tilingOptions = tilingOptions.setTileSizes(tileSizes);
476+
SmallVector<OpFoldResult> tileSizesOfr =
477+
getAsIndexOpFoldResult(rewriter.getContext(), tileSizes);
478+
tilingOptions = tilingOptions.setTileSizes(tileSizesOfr);
477479
scf::SCFTileAndFuseOptions tileAndFuseOptions;
478480
tileAndFuseOptions.tilingOptions = tilingOptions;
479481
LogicalResult result = applyTilingToAll(
@@ -923,7 +925,7 @@ transform::FuseIntoContainingOp::apply(transform::TransformRewriter &rewriter,
923925
auto nextProducer = getNextProducer();
924926
if (failed(nextProducer)) {
925927
auto diag = mlir::emitSilenceableFailure(getLoc())
926-
<< "could not find next producer to fuse into container";
928+
<< "could not find next producer to fuse into container";
927929
diag.attachNote(containingOp->getLoc()) << "containing op";
928930
return diag;
929931
}
@@ -1999,7 +2001,7 @@ transform::ScalarizeOp::applyToOne(transform::TransformRewriter &rewriter,
19992001
transform::TransformState &state) {
20002002
scf::SCFTilingOptions tilingOptions;
20012003
tilingOptions.setTileSizeComputationFunction([&](OpBuilder &b, Operation *) {
2002-
SmallVector<Value, 4> tileSizes;
2004+
SmallVector<OpFoldResult> tileSizes;
20032005
Location loc = target.getLoc();
20042006
SmallVector<OpFoldResult> allShapeSizes =
20052007
target.createFlatListOfOperandDims(b, loc);
@@ -2012,9 +2014,8 @@ transform::ScalarizeOp::applyToOne(transform::TransformRewriter &rewriter,
20122014
// If the shape size is dynamic, tile by 1.
20132015
// Otherwise, do not tile (i.e. tile size 0).
20142016
for (OpFoldResult shapeSize : shapeSizes) {
2015-
tileSizes.push_back(getConstantIntValue(shapeSize)
2016-
? b.create<arith::ConstantIndexOp>(loc, 0)
2017-
: b.create<arith::ConstantIndexOp>(loc, 1));
2017+
tileSizes.push_back(getConstantIntValue(shapeSize) ? b.getIndexAttr(0)
2018+
: b.getIndexAttr(1));
20182019
}
20192020
return tileSizes;
20202021
});
@@ -2549,7 +2550,7 @@ transform::TileOp::apply(transform::TransformRewriter &rewriter,
25492550
if (!tileSizes.empty()) {
25502551
tilingOptions.setTileSizeComputationFunction([&, index = i](OpBuilder &b,
25512552
Operation *) {
2552-
SmallVector<Value, 4> sizes;
2553+
SmallVector<OpFoldResult> sizes;
25532554
sizes.reserve(tileSizes.size());
25542555
unsigned dynamicIdx = 0;
25552556

@@ -2560,10 +2561,10 @@ transform::TileOp::apply(transform::TransformRewriter &rewriter,
25602561
getLoc(), attr.cast<IntegerAttr>().getInt());
25612562
Value vscale =
25622563
b.create<vector::VectorScaleOp>(getLoc(), b.getIndexType());
2563-
sizes.push_back(b.create<arith::MulIOp>(getLoc(), val, vscale));
2564+
sizes.push_back(
2565+
b.create<arith::MulIOp>(getLoc(), val, vscale).getResult());
25642566
} else {
2565-
sizes.push_back(b.create<arith::ConstantIndexOp>(
2566-
getLoc(), cast<IntegerAttr>(attr).getInt()));
2567+
sizes.push_back(attr);
25672568
}
25682569
continue;
25692570
}
@@ -2573,8 +2574,7 @@ transform::TileOp::apply(transform::TransformRewriter &rewriter,
25732574
assert((dynamicSizes.empty() ^ params.empty()) &&
25742575
"expected either dynamic sizes or parameters");
25752576
if (!params.empty()) {
2576-
sizes.push_back(
2577-
b.create<arith::ConstantIndexOp>(getLoc(), params[index]));
2577+
sizes.push_back(b.getIndexAttr(params[index]));
25782578
} else {
25792579
sizes.push_back(dynamicSizes[index]->getResult(0));
25802580
}

mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp

Lines changed: 21 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -31,19 +31,11 @@
3131
using namespace mlir;
3232

3333
scf::SCFTilingOptions &
34-
scf::SCFTilingOptions::setTileSizes(ArrayRef<int64_t> ts) {
34+
scf::SCFTilingOptions::setTileSizes(ArrayRef<OpFoldResult> ts) {
3535
assert(!tileSizeComputationFunction && "tile sizes already set");
36-
SmallVector<int64_t> tileSizes(ts.begin(), ts.end());
36+
auto tileSizes = llvm::to_vector(ts);
3737
tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) {
38-
OpBuilder::InsertionGuard guard(b);
39-
b.setInsertionPointToStart(
40-
&op->getParentWithTrait<OpTrait::IsIsolatedFromAbove>()
41-
->getRegion(0)
42-
.front());
43-
return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) {
44-
Value v = b.create<arith::ConstantIndexOp>(op->getLoc(), s);
45-
return v;
46-
}));
38+
return tileSizes;
4739
};
4840
return *this;
4941
}
@@ -108,17 +100,16 @@ static OpFoldResult getBoundedTileSize(OpBuilder &b, Location loc,
108100

109101
/// Generate an empty loop nest that represents the tiled loop nest shell.
110102
/// - `loopRanges` specifies the lb, ub and step of the untiled iteration space.
111-
/// - `tileSizeVals` is the tile sizes to use. Zero represent untiled loops.
103+
/// - `tileSizes` is the tile sizes to use. Zero represent untiled loops.
112104
/// - In `offsets` and `sizes` return the multi-dimensional offset and size of
113105
/// the
114106
/// tile processed within the inner most loop.
115-
static SmallVector<scf::ForOp>
116-
generateTileLoopNest(OpBuilder &builder, Location loc,
117-
ArrayRef<Range> loopRanges, ArrayRef<Value> tileSizeVals,
118-
SmallVector<OpFoldResult> &offsets,
119-
SmallVector<OpFoldResult> &sizes) {
107+
static SmallVector<scf::ForOp> generateTileLoopNest(
108+
OpBuilder &builder, Location loc, ArrayRef<Range> loopRanges,
109+
ArrayRef<OpFoldResult> tileSizes, SmallVector<OpFoldResult> &offsets,
110+
SmallVector<OpFoldResult> &sizes) {
120111
assert(!loopRanges.empty() && "expected at least one loop range");
121-
assert(loopRanges.size() == tileSizeVals.size() &&
112+
assert(loopRanges.size() == tileSizes.size() &&
122113
"expected as many tile sizes as loop ranges");
123114
OpBuilder::InsertionGuard guard(builder);
124115
SmallVector<scf::ForOp> loops;
@@ -130,7 +121,8 @@ generateTileLoopNest(OpBuilder &builder, Location loc,
130121
getValueOrCreateConstantIndexOp(builder, loc, loopRange.value().offset);
131122
Value size =
132123
getValueOrCreateConstantIndexOp(builder, loc, loopRange.value().size);
133-
Value tileSize = tileSizeVals[loopRange.index()];
124+
Value tileSize = getValueOrCreateConstantIndexOp(
125+
builder, loc, tileSizes[loopRange.index()]);
134126
// No loops if tile size is zero. Set offset and size to the loop
135127
// offset and size.
136128
if (matchPattern(tileSize, m_Zero())) {
@@ -296,10 +288,10 @@ mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op,
296288
// skips tiling a particular dimension. This convention is significantly
297289
// simpler to handle instead of adjusting affine maps to account for missing
298290
// dimensions.
299-
SmallVector<Value> tileSizeVector =
291+
SmallVector<OpFoldResult> tileSizeVector =
300292
options.tileSizeComputationFunction(rewriter, op);
301293
if (tileSizeVector.size() < iterationDomain.size()) {
302-
auto zero = rewriter.create<arith::ConstantIndexOp>(op.getLoc(), 0);
294+
auto zero = rewriter.getIndexAttr(0);
303295
tileSizeVector.append(numLoops - tileSizeVector.size(), zero);
304296
}
305297

@@ -402,17 +394,17 @@ mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op,
402394
FailureOr<scf::SCFReductionTilingResult>
403395
mlir::scf::tileReductionUsingScf(RewriterBase &b,
404396
PartialReductionOpInterface op,
405-
ArrayRef<OpFoldResult> tileSize) {
397+
ArrayRef<OpFoldResult> tileSizes) {
406398
Location loc = op.getLoc();
407399
// Ops implementing PartialReductionOpInterface are expected to implement
408400
// TilingInterface.
409401
auto tilingInterfaceOp = cast<TilingInterface>(op.getOperation());
410402
SmallVector<Range> iterationDomain = tilingInterfaceOp.getIterationDomain(b);
411-
SmallVector<Value> tileSizeVector =
412-
getValueOrCreateConstantIndexOp(b, loc, tileSize);
413-
if (tileSizeVector.size() < iterationDomain.size()) {
414-
auto zero = b.create<arith::ConstantIndexOp>(loc, 0);
415-
tileSizeVector.append(iterationDomain.size() - tileSizeVector.size(), zero);
403+
auto tileSizesVector = llvm::to_vector(tileSizes);
404+
if (tileSizesVector.size() < iterationDomain.size()) {
405+
auto zero = b.getIndexAttr(0);
406+
tileSizesVector.append(iterationDomain.size() - tileSizesVector.size(),
407+
zero);
416408
}
417409
if (op->getNumResults() != 1)
418410
return b.notifyMatchFailure(
@@ -429,15 +421,15 @@ mlir::scf::tileReductionUsingScf(RewriterBase &b,
429421

430422
// 1. create the inital tensor value.
431423
FailureOr<Operation *> identityTensor =
432-
op.generateInitialTensorForPartialReduction(b, loc, tileSize,
424+
op.generateInitialTensorForPartialReduction(b, loc, tileSizesVector,
433425
reductionDims);
434426
if (failed(identityTensor))
435427
return b.notifyMatchFailure(op,
436428
"cannot create a tensor of identity value.");
437429
// 2. Create the nested loops.
438430
SmallVector<OpFoldResult> offsets, sizes;
439431
SmallVector<scf::ForOp> loops = generateTileLoopNest(
440-
b, loc, iterationDomain, tileSizeVector, offsets, sizes);
432+
b, loc, iterationDomain, tileSizesVector, offsets, sizes);
441433

442434
// 3. Generate the tiled implementation within the inner most loop.
443435
b.setInsertionPoint(loops.back().getBody()->getTerminator());

mlir/test/Dialect/Linalg/transform-op-tile.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -190,16 +190,16 @@ transform.sequence failures(propagate) {
190190
// -----
191191

192192
// CHECK-LABEL: func.func @scalable_and_fixed_length_tile
193-
// CHECK: %[[STEP_0:.*]] = arith.constant 4 : index
194-
// CHECK: %[[STEP_1:.*]] = arith.constant 4 : index
195193
// CHECK: %[[C4:.*]] = arith.constant 4 : index
196194
// CHECK: %[[VS:.*]] = vector.vscale
197195
// CHECK: %[[STEP_2:.*]] = arith.muli %[[C4]], %[[VS]] : index
198196
// CHECK: %[[C0:.*]] = arith.constant 0 : index
199197
// CHECK: %[[C128:.*]] = arith.constant 128 : index
198+
// CHECK: %[[STEP_0:.*]] = arith.constant 4 : index
200199
// CHECK: scf.for %[[VAL_11:.*]] = %[[C0]] to %[[C128]] step %[[STEP_0]]
201200
// CHECK: %[[C0_1:.*]] = arith.constant 0 : index
202201
// CHECK: %[[C128_1:.*]] = arith.constant 128 : index
202+
// CHECK: %[[STEP_1:.*]] = arith.constant 4 : index
203203
// CHECK: scf.for %[[VAL_16:.*]] = %[[C0_1]] to %[[C128_1]] step %[[STEP_1]]
204204
// CHECK: %[[C0_2:.*]] = arith.constant 0 : index
205205
// CHECK: %[[C128_2:.*]] = arith.constant 128 : index

mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -450,7 +450,9 @@ static void addPatternForTiling(MLIRContext *context,
450450
ArrayRef<int64_t> tileSizes,
451451
ArrayRef<int64_t> interchange = {}) {
452452
scf::SCFTilingOptions tilingOptions;
453-
tilingOptions.setTileSizes(tileSizes).setInterchange(interchange);
453+
SmallVector<OpFoldResult> tileSizesOfr =
454+
getAsIndexOpFoldResult(context, tileSizes);
455+
tilingOptions.setTileSizes(tileSizesOfr).setInterchange(interchange);
454456
LinalgTransformationFilter filter(StringAttr::get(context, filterName),
455457
StringAttr::get(context, "tiled"));
456458
patterns.add<TestTileUsingSCFForOp>(context, tilingOptions, filter);
@@ -462,7 +464,9 @@ static void addPatternForTileFuseAndYield(MLIRContext *context,
462464
ArrayRef<int64_t> tileSizes,
463465
ArrayRef<int64_t> interchange = {}) {
464466
scf::SCFTilingOptions tilingOptions;
465-
tilingOptions.setTileSizes(tileSizes).setInterchange(interchange);
467+
SmallVector<OpFoldResult> tileSizesOfr =
468+
getAsIndexOpFoldResult(context, tileSizes);
469+
tilingOptions.setTileSizes(tileSizesOfr).setInterchange(interchange);
466470
LinalgTransformationFilter filter(StringAttr::get(context, filterName),
467471
StringAttr::get(context, "tiled"));
468472
patterns.add<TestTileConsumerFuseAndYieldProducerUsingSCFForOp>(
@@ -475,8 +479,10 @@ static void addPatternForTileAndFuse(MLIRContext *context,
475479
ArrayRef<int64_t> tileSizes,
476480
ArrayRef<int64_t> interchange = {}) {
477481
scf::SCFTileAndFuseOptions tileAndFuseOptions;
478-
tileAndFuseOptions.tilingOptions.setTileSizes(tileSizes).setInterchange(
479-
interchange);
482+
SmallVector<OpFoldResult> tileSizesOfr =
483+
getAsIndexOpFoldResult(context, tileSizes);
484+
tileAndFuseOptions.tilingOptions.setTileSizes(tileSizesOfr)
485+
.setInterchange(interchange);
480486
LinalgTransformationFilter filter(StringAttr::get(context, filterName),
481487
StringAttr::get(context, "tiled"));
482488
patterns.add<TestTileConsumerAndFuseProducersGreedilyUsingSCFForOp>(

0 commit comments

Comments
 (0)