@@ -99,6 +99,11 @@ static std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>>
99
99
getTileSizes (RewriterBase &rewriter, TilingInterface op,
100
100
ArrayRef<Range> iterationDomain,
101
101
const scf::SCFTilingOptions &options) {
102
+ assert (
103
+ llvm::all_of (iterationDomain,
104
+ [](Range r) { return isConstantIntValue (r.stride , 1 ); }) &&
105
+ " tile size computation assumes that all dimensions of the iteration "
106
+ " domain have stride 1" );
102
107
OpFoldResult zero = rewriter.getIndexAttr (0 );
103
108
SmallVector<OpFoldResult> tileSizes, numThreads;
104
109
size_t numLoops = iterationDomain.size ();
@@ -119,19 +124,19 @@ getTileSizes(RewriterBase &rewriter, TilingInterface op,
119
124
// of tiles as follows
120
125
// - niters = ceilDiv(ub - lb, step)
121
126
// - tileSize = ceilDiv(niters, numThreads)
122
- AffineExpr s0, s1, s2, s3;
123
- bindSymbols (rewriter.getContext (), s0, s1, s2, s3);
124
- AffineExpr numItersExpr = (s1 - s0).ceilDiv (s2);
125
- AffineExpr tileSizeExpr = numItersExpr.ceilDiv (s3);
127
+ AffineExpr s0, s1, s2;
128
+ bindSymbols (rewriter.getContext (), s0, s1, s2);
129
+ // TODO: The step here is assumed to be 1.
130
+ AffineExpr numItersExpr = (s1 - s0);
131
+ AffineExpr tileSizeExpr = numItersExpr.ceilDiv (s2);
126
132
tileSizes.resize (numLoops, zero);
127
133
for (auto [index, range, nt] :
128
134
llvm::enumerate (iterationDomain, numThreads)) {
129
135
if (isConstantIntValue (nt, 0 ))
130
136
continue ;
131
137
132
138
tileSizes[index] = affine::makeComposedFoldedAffineApply (
133
- rewriter, op.getLoc (), tileSizeExpr,
134
- {range.offset , range.size , range.stride , nt});
139
+ rewriter, op.getLoc (), tileSizeExpr, {range.offset , range.size , nt});
135
140
}
136
141
tileSizes.resize (numLoops, zero);
137
142
return {tileSizes, numThreads};
@@ -244,13 +249,19 @@ getTileOffsetAndSizes(RewriterBase &rewriter, Location loc, ValueRange ivs,
244
249
SmallVector<OpFoldResult> offsets, sizes;
245
250
int materializedLoopNum = 0 ;
246
251
252
+ assert (
253
+ llvm::all_of (iterationDomain,
254
+ [](Range r) { return isConstantIntValue (r.stride , 1 ); }) &&
255
+ " the offset and tile size computation assumes stride 1 for all "
256
+ " dimensions of the iteration domain" );
257
+
247
258
if (!numThreads.empty ()) {
248
- AffineExpr d0, d1, s0, s1, s2 ;
259
+ AffineExpr d0, d1, s0, s1;
249
260
AffineExpr offsetExpr, residualTileSizeExpr;
250
261
bindDims (rewriter.getContext (), d0, d1);
251
- bindSymbols (rewriter.getContext (), s0, s1, s2 );
252
- offsetExpr = d0 + d1 * s0 * s1 ;
253
- residualTileSizeExpr = s2 - (d0 + d1 * s0 * s1 );
262
+ bindSymbols (rewriter.getContext (), s0, s1);
263
+ offsetExpr = d0 + d1 * s0;
264
+ residualTileSizeExpr = s1 - (d0 + d1 * s0);
254
265
255
266
for (auto [nt, tileSize, loopRange] :
256
267
llvm::zip_equal (numThreads, tileSizes, iterationDomain)) {
@@ -264,11 +275,11 @@ getTileOffsetAndSizes(RewriterBase &rewriter, Location loc, ValueRange ivs,
264
275
Value iv = ivs[materializedLoopNum++];
265
276
OpFoldResult offset = affine::makeComposedFoldedAffineApply (
266
277
rewriter, loc, offsetExpr,
267
- ArrayRef<OpFoldResult>{loopRange.offset , iv, loopRange.stride ,
268
- tileSize});
278
+ ArrayRef<OpFoldResult>{loopRange.offset , iv, tileSize});
269
279
OpFoldResult residualTileSize = affine::makeComposedFoldedAffineApply (
270
280
rewriter, loc, residualTileSizeExpr,
271
- {loopRange.offset , nt, loopRange.stride , tileSize, loopRange.size });
281
+ {loopRange.offset , nt, tileSize, loopRange.size });
282
+
272
283
OpFoldResult size = tileSize;
273
284
if (!isConstantIntValue (residualTileSize, 0 )) {
274
285
OpFoldResult sizeMinusOffsetPerThread =
@@ -776,6 +787,11 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
776
787
777
788
// 1. Get the range of the loops that are represented by the operation.
778
789
SmallVector<Range> iterationDomain = op.getIterationDomain (rewriter);
790
+ if (llvm::any_of (iterationDomain,
791
+ [](Range r) { return !isConstantIntValue (r.stride , 1 ); })) {
792
+ return rewriter.notifyMatchFailure (
793
+ op, " unhandled tiling of iteration domain with non-unit stride" );
794
+ }
779
795
780
796
// 2. Materialize the tile sizes and/or number of threads;
781
797
SmallVector<OpFoldResult> tileSizes, numThreads;
0 commit comments