Skip to content

Commit 6dea0cf

Browse files
Drop support for non-unit strides, and assert that strides of iteration domain are 1.
1 parent c828576 commit 6dea0cf

File tree

1 file changed

+29
-13
lines changed

1 file changed

+29
-13
lines changed

mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp

Lines changed: 29 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -99,6 +99,11 @@ static std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>>
9999
getTileSizes(RewriterBase &rewriter, TilingInterface op,
100100
ArrayRef<Range> iterationDomain,
101101
const scf::SCFTilingOptions &options) {
102+
assert(
103+
llvm::all_of(iterationDomain,
104+
[](Range r) { return isConstantIntValue(r.stride, 1); }) &&
105+
"tile size computation assumes that all dimensions of the iteration "
106+
"domain have stride 1");
102107
OpFoldResult zero = rewriter.getIndexAttr(0);
103108
SmallVector<OpFoldResult> tileSizes, numThreads;
104109
size_t numLoops = iterationDomain.size();
@@ -119,19 +124,19 @@ getTileSizes(RewriterBase &rewriter, TilingInterface op,
119124
// of tiles as follows
120125
// - niters = ceilDiv(ub - lb, step)
121126
// - tileSize = ceilDiv(niters, numThreads)
122-
AffineExpr s0, s1, s2, s3;
123-
bindSymbols(rewriter.getContext(), s0, s1, s2, s3);
124-
AffineExpr numItersExpr = (s1 - s0).ceilDiv(s2);
125-
AffineExpr tileSizeExpr = numItersExpr.ceilDiv(s3);
127+
AffineExpr s0, s1, s2;
128+
bindSymbols(rewriter.getContext(), s0, s1, s2);
129+
// TODO: The step here is assumed to be 1.
130+
AffineExpr numItersExpr = (s1 - s0);
131+
AffineExpr tileSizeExpr = numItersExpr.ceilDiv(s2);
126132
tileSizes.resize(numLoops, zero);
127133
for (auto [index, range, nt] :
128134
llvm::enumerate(iterationDomain, numThreads)) {
129135
if (isConstantIntValue(nt, 0))
130136
continue;
131137

132138
tileSizes[index] = affine::makeComposedFoldedAffineApply(
133-
rewriter, op.getLoc(), tileSizeExpr,
134-
{range.offset, range.size, range.stride, nt});
139+
rewriter, op.getLoc(), tileSizeExpr, {range.offset, range.size, nt});
135140
}
136141
tileSizes.resize(numLoops, zero);
137142
return {tileSizes, numThreads};
@@ -244,13 +249,19 @@ getTileOffsetAndSizes(RewriterBase &rewriter, Location loc, ValueRange ivs,
244249
SmallVector<OpFoldResult> offsets, sizes;
245250
int materializedLoopNum = 0;
246251

252+
assert(
253+
llvm::all_of(iterationDomain,
254+
[](Range r) { return isConstantIntValue(r.stride, 1); }) &&
255+
"the offset and tile size computation assumes stride 1 for all "
256+
"dimensions of the iteration domain");
257+
247258
if (!numThreads.empty()) {
248-
AffineExpr d0, d1, s0, s1, s2;
259+
AffineExpr d0, d1, s0, s1;
249260
AffineExpr offsetExpr, residualTileSizeExpr;
250261
bindDims(rewriter.getContext(), d0, d1);
251-
bindSymbols(rewriter.getContext(), s0, s1, s2);
252-
offsetExpr = d0 + d1 * s0 * s1;
253-
residualTileSizeExpr = s2 - (d0 + d1 * s0 * s1);
262+
bindSymbols(rewriter.getContext(), s0, s1);
263+
offsetExpr = d0 + d1 * s0;
264+
residualTileSizeExpr = s1 - (d0 + d1 * s0);
254265

255266
for (auto [nt, tileSize, loopRange] :
256267
llvm::zip_equal(numThreads, tileSizes, iterationDomain)) {
@@ -264,11 +275,11 @@ getTileOffsetAndSizes(RewriterBase &rewriter, Location loc, ValueRange ivs,
264275
Value iv = ivs[materializedLoopNum++];
265276
OpFoldResult offset = affine::makeComposedFoldedAffineApply(
266277
rewriter, loc, offsetExpr,
267-
ArrayRef<OpFoldResult>{loopRange.offset, iv, loopRange.stride,
268-
tileSize});
278+
ArrayRef<OpFoldResult>{loopRange.offset, iv, tileSize});
269279
OpFoldResult residualTileSize = affine::makeComposedFoldedAffineApply(
270280
rewriter, loc, residualTileSizeExpr,
271-
{loopRange.offset, nt, loopRange.stride, tileSize, loopRange.size});
281+
{loopRange.offset, nt, tileSize, loopRange.size});
282+
272283
OpFoldResult size = tileSize;
273284
if (!isConstantIntValue(residualTileSize, 0)) {
274285
OpFoldResult sizeMinusOffsetPerThread =
@@ -776,6 +787,11 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
776787

777788
// 1. Get the range of the loops that are represented by the operation.
778789
SmallVector<Range> iterationDomain = op.getIterationDomain(rewriter);
790+
if (llvm::any_of(iterationDomain,
791+
[](Range r) { return !isConstantIntValue(r.stride, 1); })) {
792+
return rewriter.notifyMatchFailure(
793+
op, "unhandled tiling of iteration domain with non-unit stride");
794+
}
779795

780796
// 2. Materialize the tile sizes and/or number of threads;
781797
SmallVector<OpFoldResult> tileSizes, numThreads;

0 commit comments

Comments
 (0)