Skip to content

Commit e4ecd3c

Browse files
Remove use of getLoopBounds to avoid unnecessary lit test churn.
1 parent 6dea0cf commit e4ecd3c

File tree

1 file changed

+24
-23
lines changed

1 file changed

+24
-23
lines changed

mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp

Lines changed: 24 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -94,16 +94,12 @@ verifyTileSizeOptions(RewriterBase &rewriter, Location loc,
9494
return success();
9595
}
9696

97-
/// Compute the tile sizes and num threads values passed in.
97+
/// Method to instantiate the tile sizes and/or number of threads specified
98+
/// by the user.
9899
static std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>>
99-
getTileSizes(RewriterBase &rewriter, TilingInterface op,
100-
ArrayRef<Range> iterationDomain,
101-
const scf::SCFTilingOptions &options) {
102-
assert(
103-
llvm::all_of(iterationDomain,
104-
[](Range r) { return isConstantIntValue(r.stride, 1); }) &&
105-
"tile size computation assumes that all dimensions of the iteration "
106-
"domain have stride 1");
100+
getUserTileSizesAndNumThreads(RewriterBase &rewriter, TilingInterface op,
101+
ArrayRef<Range> iterationDomain,
102+
const scf::SCFTilingOptions &options) {
107103
OpFoldResult zero = rewriter.getIndexAttr(0);
108104
SmallVector<OpFoldResult> tileSizes, numThreads;
109105
size_t numLoops = iterationDomain.size();
@@ -240,7 +236,9 @@ static bool canOmitTileOffsetInBoundsCheck(OpFoldResult tileSize,
240236
return *tileSizeConst * (*numThreadsConst - 1) < *iterSizeConst;
241237
}
242238

243-
/// Compute the tile offsets and sizes.
239+
/// Compute the `OpFoldResult`s that represent the multi-dimensional
240+
/// `offset`s and `size`s of the tile of the iteration space that the
241+
/// innermost loop body of the generated tiled loops corresponds to.
244242
static std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>>
245243
getTileOffsetAndSizes(RewriterBase &rewriter, Location loc, ValueRange ivs,
246244
ArrayRef<Range> iterationDomain,
@@ -249,12 +247,6 @@ getTileOffsetAndSizes(RewriterBase &rewriter, Location loc, ValueRange ivs,
249247
SmallVector<OpFoldResult> offsets, sizes;
250248
int materializedLoopNum = 0;
251249

252-
assert(
253-
llvm::all_of(iterationDomain,
254-
[](Range r) { return isConstantIntValue(r.stride, 1); }) &&
255-
"the offset and tile size computation assumes stride 1 for all "
256-
"dimensions of the iteration domain");
257-
258250
if (!numThreads.empty()) {
259251
AffineExpr d0, d1, s0, s1;
260252
AffineExpr offsetExpr, residualTileSizeExpr;
@@ -266,7 +258,9 @@ getTileOffsetAndSizes(RewriterBase &rewriter, Location loc, ValueRange ivs,
266258
for (auto [nt, tileSize, loopRange] :
267259
llvm::zip_equal(numThreads, tileSizes, iterationDomain)) {
268260

269-
if (isConstantIntValue(nt, 0) || isConstantIntValue(nt, 1)) {
261+
// Non-tiled cases, set the offset and size to the
262+
// `loopRange.offset/size`.
263+
if (isConstantIntValue(nt, 0)) {
270264
offsets.push_back(loopRange.offset);
271265
sizes.push_back(loopRange.size);
272266
continue;
@@ -290,6 +284,16 @@ getTileOffsetAndSizes(RewriterBase &rewriter, Location loc, ValueRange ivs,
290284
AffineMap::getMultiDimIdentityMap(2, rewriter.getContext()),
291285
{sizeMinusOffsetPerThread, tileSize});
292286
}
287+
288+
// Consider the case where the original loop was `[0, 100)`.
289+
// If the number of threads is `14`, the tile size would be computed as
290+
// `ceilDiv(100, 14) = 8`. For the last thread (thread_id = 13)
291+
// - `offset = 0 + 13 * 8 = 104`
292+
// - `tileSize = min(8, 100 - 104) = -4`
293+
// To avoid negative tile sizes, we need to do a further
294+
// `nonNegativeTileSize = affine.max(0, tileSize)`.
295+
// This `max` can be avoided if
296+
// `tileSize * (numThreads - 1) < (ub - lb)`
293297
if (!canOmitTileOffsetInBoundsCheck(tileSize, nt, loopRange.size)) {
294298
AffineMap maxMap =
295299
AffineMap::getMultiDimIdentityMap(2, rewriter.getContext());
@@ -305,6 +309,8 @@ getTileOffsetAndSizes(RewriterBase &rewriter, Location loc, ValueRange ivs,
305309
for (auto [tileSize, loopRange] :
306310
llvm::zip_equal(tileSizes, iterationDomain)) {
307311

312+
// Non-tiled cases, set the offset and size to the
313+
// `loopRange.offset/size`.
308314
if (isConstantIntValue(tileSize, 0)) {
309315
offsets.push_back(loopRange.offset);
310316
sizes.push_back(loopRange.size);
@@ -787,16 +793,11 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
787793

788794
// 1. Get the range of the loops that are represented by the operation.
789795
SmallVector<Range> iterationDomain = op.getIterationDomain(rewriter);
790-
if (llvm::any_of(iterationDomain,
791-
[](Range r) { return !isConstantIntValue(r.stride, 1); })) {
792-
return rewriter.notifyMatchFailure(
793-
op, "unhandled tiling of iteration domain with non-unit stride");
794-
}
795796

796797
// 2. Materialize the tile sizes and/or number of threads.
797798
SmallVector<OpFoldResult> tileSizes, numThreads;
798799
std::tie(tileSizes, numThreads) =
799-
getTileSizes(rewriter, op, iterationDomain, options);
800+
getUserTileSizesAndNumThreads(rewriter, op, iterationDomain, options);
800801

801802
// Check if it is safe to tile. This is a holdover from previous iterations
802803
// of tile to for-all. Consider dropping it.

0 commit comments

Comments
 (0)