@@ -94,16 +94,12 @@ verifyTileSizeOptions(RewriterBase &rewriter, Location loc,
94
94
return success ();
95
95
}
96
96
97
- // / Compute the tile sizes and num threads values passed in.
97
+ /// Method to instantiate the tile sizes and/or number of threads specified
98
+ /// by the user.
98
99
static std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>>
99
- getTileSizes (RewriterBase &rewriter, TilingInterface op,
100
- ArrayRef<Range> iterationDomain,
101
- const scf::SCFTilingOptions &options) {
102
- assert (
103
- llvm::all_of (iterationDomain,
104
- [](Range r) { return isConstantIntValue (r.stride , 1 ); }) &&
105
- " tile size computation assumes that all dimensions of the iteration "
106
- " domain have stride 1" );
100
+ getUserTileSizesAndNumThreads (RewriterBase &rewriter, TilingInterface op,
101
+ ArrayRef<Range> iterationDomain,
102
+ const scf::SCFTilingOptions &options) {
107
103
OpFoldResult zero = rewriter.getIndexAttr (0 );
108
104
SmallVector<OpFoldResult> tileSizes, numThreads;
109
105
size_t numLoops = iterationDomain.size ();
@@ -240,7 +236,9 @@ static bool canOmitTileOffsetInBoundsCheck(OpFoldResult tileSize,
240
236
return *tileSizeConst * (*numThreadsConst - 1 ) < *iterSizeConst;
241
237
}
242
238
243
- // / Compute the tile offsets and sizes.
239
+ /// Compute the `OpFoldResult`s that represent the multi-dimensional
240
+ /// `offset`s and `size`s of the tile of the iteration space that the
241
+ /// innermost loop body of the generated tiled loops corresponds to.
244
242
static std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>>
245
243
getTileOffsetAndSizes (RewriterBase &rewriter, Location loc, ValueRange ivs,
246
244
ArrayRef<Range> iterationDomain,
@@ -249,12 +247,6 @@ getTileOffsetAndSizes(RewriterBase &rewriter, Location loc, ValueRange ivs,
249
247
SmallVector<OpFoldResult> offsets, sizes;
250
248
int materializedLoopNum = 0 ;
251
249
252
- assert (
253
- llvm::all_of (iterationDomain,
254
- [](Range r) { return isConstantIntValue (r.stride , 1 ); }) &&
255
- " the offset and tile size computation assumes stride 1 for all "
256
- " dimensions of the iteration domain" );
257
-
258
250
if (!numThreads.empty ()) {
259
251
AffineExpr d0, d1, s0, s1;
260
252
AffineExpr offsetExpr, residualTileSizeExpr;
@@ -266,7 +258,9 @@ getTileOffsetAndSizes(RewriterBase &rewriter, Location loc, ValueRange ivs,
266
258
for (auto [nt, tileSize, loopRange] :
267
259
llvm::zip_equal (numThreads, tileSizes, iterationDomain)) {
268
260
269
- if (isConstantIntValue (nt, 0 ) || isConstantIntValue (nt, 1 )) {
261
+ // Non-tiled cases, set the offset and size to the
262
+ // `loopRange.offset/size`.
263
+ if (isConstantIntValue (nt, 0 )) {
270
264
offsets.push_back (loopRange.offset );
271
265
sizes.push_back (loopRange.size );
272
266
continue ;
@@ -290,6 +284,16 @@ getTileOffsetAndSizes(RewriterBase &rewriter, Location loc, ValueRange ivs,
290
284
AffineMap::getMultiDimIdentityMap (2 , rewriter.getContext ()),
291
285
{sizeMinusOffsetPerThread, tileSize});
292
286
}
287
+
288
+ // Consider the case where the original loop was `[0, 10)`.
289
+ // If the number of threads is `7`, the tile size would be computed as
290
+ // `ceilDiv(10, 7) = 2`. For the last thread (thread_id = 6)
291
+ // - `offset = 0 + 6 * 2 = 12`
292
+ // - `tileSize = min(2, 10 - 12) = -2`
293
+ // To avoid negative tile sizes, we need to do a further
294
+ // `nonNegativeTileSize = affine.max(0, tileSize)`.
295
+ // This `max` can be avoided if
296
+ // `offset + tileSize * (numThreads - 1) < (ub - lb)`
293
297
if (!canOmitTileOffsetInBoundsCheck (tileSize, nt, loopRange.size )) {
294
298
AffineMap maxMap =
295
299
AffineMap::getMultiDimIdentityMap (2 , rewriter.getContext ());
@@ -305,6 +309,8 @@ getTileOffsetAndSizes(RewriterBase &rewriter, Location loc, ValueRange ivs,
305
309
for (auto [tileSize, loopRange] :
306
310
llvm::zip_equal (tileSizes, iterationDomain)) {
307
311
312
+ // Non-tiled cases, set the offset and size to the
313
+ // `loopRange.offset/size`.
308
314
if (isConstantIntValue (tileSize, 0 )) {
309
315
offsets.push_back (loopRange.offset );
310
316
sizes.push_back (loopRange.size );
@@ -787,16 +793,11 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
787
793
788
794
// 1. Get the range of the loops that are represented by the operation.
789
795
SmallVector<Range> iterationDomain = op.getIterationDomain (rewriter);
790
- if (llvm::any_of (iterationDomain,
791
- [](Range r) { return !isConstantIntValue (r.stride , 1 ); })) {
792
- return rewriter.notifyMatchFailure (
793
- op, " unhandled tiling of iteration domain with non-unit stride" );
794
- }
795
796
796
797
// 2. Materialize the tile sizes and/or number of threads;
797
798
SmallVector<OpFoldResult> tileSizes, numThreads;
798
799
std::tie (tileSizes, numThreads) =
799
- getTileSizes (rewriter, op, iterationDomain, options);
800
+ getUserTileSizesAndNumThreads (rewriter, op, iterationDomain, options);
800
801
801
802
// Check if it is safe to tile. This is a holdover from previous iterations
802
803
// of tile to for-all. Consider dropping it.
0 commit comments