Skip to content

Commit 8d74612

Browse files
[mlir][SCF] Allow tiling by specifying maximum number of tiles.
1 parent 2b38688 commit 8d74612

File tree

8 files changed

+270
-302
lines changed

8 files changed

+270
-302
lines changed

mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,10 @@ class GenericOp;
3030
class LinalgOp;
3131
} // namespace linalg
3232

33+
namespace scf {
34+
struct SCFTilingResult;
35+
} // namespace scf
36+
3337
namespace tensor {
3438
class InsertSliceOp;
3539
class PackOp;
@@ -60,7 +64,7 @@ tileToForallOpImpl(RewriterBase &rewriter, transform::TransformState &state,
6064
ArrayRef<OpFoldResult> mixedNumThreads,
6165
ArrayRef<OpFoldResult> mixedTileSizes,
6266
std::optional<ArrayAttr> mapping,
63-
linalg::ForallTilingResult &tilingResult);
67+
scf::SCFTilingResult &tilingResult);
6468

6569
} // namespace transform
6670
} // namespace mlir

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 0 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -846,30 +846,6 @@ FailureOr<StaticMultiSizeSpecification>
846846
computeStaticMultiTileSizes(LinalgOp op, unsigned dimension, int64_t targetSize,
847847
int64_t divisor);
848848

849-
/// Rewrite a TilingInterface `op` to a tiled `scf.forall`, applying
850-
/// tiling by `numThreads`.
851-
/// If non-empty, the `mapping` is added as an attribute to the
852-
/// resulting `scf.forall`.
853-
/// Zero tile sizes indicate that the dimension is not tiled, and can be
854-
/// thought of as tiling by the full size of data. It is the user's
855-
/// responsibility to ensure that `numThreads` is a valid tiling specification
856-
/// (i.e. that only tiles parallel dimensions, e.g. in the Linalg case).
857-
struct ForallTilingResult {
858-
Operation *tileOp;
859-
Operation *tiledOp;
860-
};
861-
FailureOr<ForallTilingResult> tileToForallOp(RewriterBase &builder,
862-
TilingInterface op,
863-
ArrayRef<OpFoldResult> numThreads,
864-
std::optional<ArrayAttr> mapping);
865-
866-
/// Same as `tileToForallOp`, but calculate the number of threads
867-
/// required using the given tileSizes.
868-
FailureOr<ForallTilingResult>
869-
tileToForallOpUsingTileSizes(RewriterBase &builder, TilingInterface op,
870-
ArrayRef<OpFoldResult> tileSizes,
871-
std::optional<ArrayAttr> mapping);
872-
873849
/// Transformation information returned after reduction tiling.
874850
struct ForallReductionTilingResult {
875851
/// The partial reduction tiled op generated.

mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h

Lines changed: 28 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,13 @@ using SCFTileSizeComputationFunction =
3131

3232
/// Options to use to control tiling.
3333
struct SCFTilingOptions {
34-
/// Computation function that returns the tile sizes for each operation.
35-
/// Delayed construction of constant tile sizes should occur to interoperate
36-
/// with folding.
34+
/// Computation function that returns the tile sizes to use for each loop.
35+
/// Returning a tile size of zero implies no tiling for that loop. If the
36+
/// size of the returned vector is smaller than the number of loops, the inner
37+
/// loops are not tiled. If the size of the returned vector is larger, then
38+
/// the vector is truncated to number of loops. Only one of
39+
/// `tileSizeComputationFunction` or `maxNumTilesComputationFunction` should
40+
/// be used.
3741
SCFTileSizeComputationFunction tileSizeComputationFunction = nullptr;
3842

3943
SCFTilingOptions &
@@ -44,7 +48,25 @@ struct SCFTilingOptions {
4448
/// Convenience function to set the `tileSizeComputationFunction` to a
4549
/// function that computes tile sizes at the point they are needed. Allows
4650
/// proper interaction with folding.
47-
SCFTilingOptions &setTileSizes(ArrayRef<OpFoldResult> ts);
51+
SCFTilingOptions &setTileSizes(ArrayRef<OpFoldResult> tileSizes);
52+
53+
/// Computation function that returns the maximum number of tiles to use for
54+
/// each loop. Returning zero for a loop implies no tiling for that loop.
55+
/// If the size of the returned vector is smaller than the number of loops,
56+
/// the inner loops are not tiled. If the size of the returned vector is
57+
/// larger, then the vector is truncated to number of loops. Only one of
58+
/// `tileSizeComputationFunction` or `maxNumTilesComputationFunction` should
59+
/// be used.
60+
SCFTileSizeComputationFunction maxNumTilesComputationFunction = nullptr;
61+
62+
SCFTilingOptions &
63+
setMaxNumTilesComputationFunction(SCFTileSizeComputationFunction fun) {
64+
maxNumTilesComputationFunction = std::move(fun);
65+
return *this;
66+
}
67+
/// Convenience function to set the `maxNumTilesComputationFunction` to a
68+
/// function that computes the maximum number of tiles at the point they are needed.
69+
SCFTilingOptions &setMaxNumTiles(ArrayRef<OpFoldResult> numTiles);
4870

4971
/// The interchange vector to reorder the tiled loops.
5072
SmallVector<int64_t> interchangeVector = {};
@@ -66,9 +88,8 @@ struct SCFTilingOptions {
6688
/// when using loop constructs that don't support such a mapping (like
6789
/// `scf.for`)
6890
SmallVector<Attribute> mappingVector = {};
69-
SCFTilingOptions &setMapping(ArrayRef<DeviceMappingAttrInterface> mapping) {
70-
mappingVector = llvm::map_to_vector(
71-
mapping, [](auto attr) -> Attribute { return attr; });
91+
SCFTilingOptions &setMapping(ArrayRef<Attribute> mapping) {
92+
mappingVector = llvm::to_vector(mapping);
7293
return *this;
7394
}
7495
};

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2917,7 +2917,7 @@ DiagnosedSilenceableFailure transform::tileToForallOpImpl(
29172917
TransformOpInterface transformOp, Operation *target,
29182918
ArrayRef<OpFoldResult> mixedNumThreads,
29192919
ArrayRef<OpFoldResult> mixedTileSizes, std::optional<ArrayAttr> mapping,
2920-
linalg::ForallTilingResult &tilingResult) {
2920+
scf::SCFTilingResult &tilingResult) {
29212921
// Transform all targets one by one.
29222922
auto tileableOp = dyn_cast<TilingInterface>(target);
29232923
if (!tileableOp) {
@@ -2928,18 +2928,22 @@ DiagnosedSilenceableFailure transform::tileToForallOpImpl(
29282928
return diag;
29292929
}
29302930
rewriter.setInsertionPoint(tileableOp);
2931-
FailureOr<linalg::ForallTilingResult> maybeTilingResult = failure();
2931+
scf::SCFTilingOptions options;
2932+
options.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp);
29322933
if (!mixedNumThreads.empty()) {
2933-
maybeTilingResult =
2934-
linalg::tileToForallOp(rewriter, tileableOp, mixedNumThreads, mapping);
2934+
options.setMaxNumTiles(mixedNumThreads);
29352935
} else {
2936-
maybeTilingResult = linalg::tileToForallOpUsingTileSizes(
2937-
rewriter, tileableOp, mixedTileSizes, mapping);
2936+
options.setTileSizes(mixedTileSizes);
29382937
}
2938+
if (mapping) {
2939+
options.setMapping(mapping.value().getValue());
2940+
}
2941+
FailureOr<scf::SCFTilingResult> maybeTilingResult =
2942+
scf::tileUsingSCF(rewriter, tileableOp, options);
29392943

29402944
if (failed(maybeTilingResult))
29412945
return transformOp.emitDefaultSilenceableFailure(tileableOp);
2942-
rewriter.replaceOp(tileableOp, maybeTilingResult->tileOp->getResults());
2946+
rewriter.replaceOp(tileableOp, maybeTilingResult->replacements);
29432947

29442948
tilingResult = *maybeTilingResult;
29452949
return DiagnosedSilenceableFailure::success();
@@ -2975,14 +2979,14 @@ DiagnosedSilenceableFailure transform::TileUsingForallOp::apply(
29752979
return status;
29762980

29772981
for (Operation *target : state.getPayloadOps(getTarget())) {
2978-
linalg::ForallTilingResult tilingResult;
2982+
scf::SCFTilingResult tilingResult;
29792983
DiagnosedSilenceableFailure diag = tileToForallOpImpl(
29802984
rewriter, state, transformOp, target, mixedNumThreads, mixedTileSizes,
29812985
getMapping(), tilingResult);
29822986
if (!diag.succeeded())
29832987
return diag;
2984-
tileOps.push_back(tilingResult.tileOp);
2985-
tiledOps.push_back(tilingResult.tiledOp);
2988+
tileOps.push_back(tilingResult.loops.front());
2989+
tiledOps.append(tilingResult.tiledOps);
29862990
}
29872991

29882992
transformResults.set(cast<OpResult>(getForallOp()), tileOps);
@@ -3460,7 +3464,7 @@ DiagnosedSilenceableFailure transform::MapCopyToThreadsOp::applyToOne(
34603464

34613465
// OpBuilder only used to compute attributes.
34623466
OpBuilder b(getContext());
3463-
linalg::ForallTilingResult tilingResult;
3467+
scf::SCFTilingResult tilingResult;
34643468
DiagnosedSilenceableFailure diag = tileToForallOpImpl(
34653469
/*rewriter=*/rewriter,
34663470
/*state=*/state,
@@ -3473,8 +3477,9 @@ DiagnosedSilenceableFailure transform::MapCopyToThreadsOp::applyToOne(
34733477
if (!diag.succeeded())
34743478
return diag;
34753479

3476-
results.push_back(tilingResult.tileOp);
3477-
results.push_back(tilingResult.tiledOp);
3480+
results.push_back(tilingResult.loops.front());
3481+
for (auto op : tilingResult.tiledOps)
3482+
results.push_back(op);
34783483
return DiagnosedSilenceableFailure::success();
34793484
}
34803485

mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp

Lines changed: 0 additions & 182 deletions
Original file line numberDiff line numberDiff line change
@@ -304,188 +304,6 @@ static void calculateTileOffsetsAndSizes(
304304
}
305305
}
306306

307-
/// Returns a vector of bools representing if, for each axis, `op` can be tiled
308-
/// without incurring in a race condition and thus it is thread-safe to do the
309-
/// tiling. This is checked by iterating over numThreads and ensuring that the
310-
/// corresponding iterator type is "parallel". If it is not, then we know that
311-
/// such dimension is unsafe to tile.
312-
SmallVector<bool> safeToTileToForall(mlir::MLIRContext *ctx, LinalgOp linalgOp,
313-
ArrayRef<OpFoldResult> numThreads) {
314-
auto iterators = linalgOp.getIteratorTypesArray();
315-
SmallVector<bool> safeToTile(numThreads.size(), true);
316-
317-
for (unsigned i = 0, e = numThreads.size(); i != e; i++) {
318-
if (auto attr = llvm::dyn_cast_if_present<Attribute>(numThreads[i])) {
319-
if (cast<IntegerAttr>(attr).getValue().getSExtValue() > 1) {
320-
safeToTile[i] = iterators[i] == utils::IteratorType::parallel;
321-
}
322-
} else {
323-
safeToTile[i] = iterators[i] == utils::IteratorType::parallel;
324-
}
325-
}
326-
return safeToTile;
327-
}
328-
329-
/// Rewrite a TilingInterface `op` to a tiled `scf.forall`. The
330-
/// tiling is specified by the number of tiles/threads `numThreads` and the
331-
/// optional nominal tile size `nominalTileSizes`. If `nominalTileSizes` is
332-
/// not specified, then it is derived from `numThreads` as `ceilDiv(dimSize[i],
333-
/// numThreads[i])`. If non-empty, the `mapping` is added as an
334-
/// attribute to the resulting `scf.forall`. Zero tile sizes indicate
335-
/// that the dimension is not tiled, and can be thought of as tiling by the full
336-
/// size of data.
337-
/// It is the user's responsibility to ensure that `numThreads` is a valid
338-
/// tiling specification (i.e. that only tiles parallel dimensions, e.g. in the
339-
/// Linalg case). If the dimension is not parallelizable, a warning is issued to
340-
/// notify the user that the generated code is not safe to parallelize. If
341-
/// `omitTileOffsetBoundsCheck` is true, then the function will assume that
342-
/// `tileSize[i] * (numThread[i] -1) <= dimSize[i]` holds.
343-
static FailureOr<ForallTilingResult> tileToForallOpImpl(
344-
RewriterBase &b, TilingInterface op, ArrayRef<OpFoldResult> numThreads,
345-
std::optional<ArrayRef<OpFoldResult>> nominalTileSizes,
346-
std::optional<ArrayAttr> mapping, bool omitTileOffsetBoundsCheck) {
347-
Location loc = op->getLoc();
348-
OpBuilder::InsertionGuard g(b);
349-
350-
SmallVector<Range> loopRanges = op.getIterationDomain(b);
351-
if (loopRanges.empty())
352-
return op->emitOpError("expected non-empty loop ranges");
353-
auto hasStrideOne = [](Range r) { return !isConstantIntValue(r.stride, 1); };
354-
if (llvm::any_of(loopRanges, hasStrideOne))
355-
return op->emitOpError("only stride-1 supported atm");
356-
357-
// Gather destination tensors.
358-
SmallVector<Value> dest;
359-
if (failed(tensor::getOrCreateDestinations(b, loc, op, dest)))
360-
return op->emitOpError("failed to get destination tensors");
361-
362-
SmallVector<OpFoldResult> nonZeroNumThreads =
363-
llvm::to_vector(llvm::make_filter_range(numThreads, [](OpFoldResult ofr) {
364-
return !isConstantIntValue(ofr, 0);
365-
}));
366-
SmallVector<Value> materializedNonZeroNumThreads =
367-
llvm::to_vector(llvm::map_range(nonZeroNumThreads, [&](OpFoldResult ofr) {
368-
return getValueOrCreateConstantIndexOp(b, loc, ofr);
369-
}));
370-
371-
LinalgOp linalgOp = dyn_cast<LinalgOp>(op.getOperation());
372-
if (linalgOp) {
373-
// Check if tiling is thread safe and print a warning if not.
374-
SmallVector<bool> tilingSafety =
375-
safeToTileToForall(b.getContext(), linalgOp, numThreads);
376-
for (size_t i = 0; i < tilingSafety.size(); i++)
377-
if (!tilingSafety[i])
378-
op.emitWarning() << "tiling is not thread safe at axis #" << i;
379-
}
380-
381-
// 1. Create the ForallOp. We don't use the lambda body-builder
382-
// version because we require the use of RewriterBase in the body, so we
383-
// manually move the insertion point to the body below.
384-
scf::ForallOp forallOp = b.create<scf::ForallOp>(
385-
loc, getAsOpFoldResult((materializedNonZeroNumThreads)), dest, mapping);
386-
387-
// 2. Fill out the ForallOp body.
388-
SmallVector<OpFoldResult> tiledOffsets, tiledSizes;
389-
calculateTileOffsetsAndSizes(b, loc, forallOp, numThreads, loopRanges,
390-
omitTileOffsetBoundsCheck, nominalTileSizes,
391-
tiledOffsets, tiledSizes);
392-
393-
// 3. Clone the tileable op and update its destination operands to use the
394-
// output bbArgs of the ForallOp.
395-
ArrayRef<BlockArgument> destBbArgs = forallOp.getRegionIterArgs();
396-
Operation *tiledOp = nullptr;
397-
SmallVector<Value> tiledValues;
398-
{
399-
// 3.a. RAII guard, inserting within forallOp, before terminator.
400-
OpBuilder::InsertionGuard g(b);
401-
b.setInsertionPoint(forallOp.getTerminator());
402-
Operation *clonedOp = b.clone(*op.getOperation());
403-
auto destinationStyleOp = dyn_cast<DestinationStyleOpInterface>(clonedOp);
404-
if (destinationStyleOp) {
405-
for (OpOperand &outOperand : destinationStyleOp.getDpsInitsMutable()) {
406-
// Swap tensor inits with the corresponding block argument of the
407-
// scf.forall op. Memref inits remain as is.
408-
if (isa<TensorType>(outOperand.get().getType())) {
409-
auto *it = llvm::find(dest, outOperand.get());
410-
assert(it != dest.end() && "could not find destination tensor");
411-
unsigned destNum = std::distance(dest.begin(), it);
412-
outOperand.set(destBbArgs[destNum]);
413-
}
414-
}
415-
}
416-
417-
// 4. Tile the cloned op and delete the clone.
418-
FailureOr<TilingResult> tilingResult =
419-
cast<TilingInterface>(clonedOp).getTiledImplementation(b, tiledOffsets,
420-
tiledSizes);
421-
if (failed(tilingResult))
422-
return clonedOp->emitError("Failed to tile op: ");
423-
if (tilingResult->tiledOps.size() != 1) {
424-
return clonedOp->emitError("expected a single produced tiled op, got ")
425-
<< tilingResult->tiledOps.size();
426-
}
427-
428-
b.eraseOp(clonedOp);
429-
tiledOp = tilingResult->tiledOps.front();
430-
tiledValues = tilingResult->tiledValues;
431-
}
432-
433-
// 5. Parallel insert back into the result tensor.
434-
for (auto it : llvm::zip(llvm::seq(unsigned(0), unsigned(dest.size())),
435-
tiledValues, destBbArgs)) {
436-
// 5.a. Partial subset information is inserted just before the terminator.
437-
OpBuilder::InsertionGuard g(b);
438-
b.setInsertionPoint(forallOp.getTerminator());
439-
440-
SmallVector<OpFoldResult> resultOffsets, resultSizes;
441-
if (failed(op.getResultTilePosition(b, std::get<0>(it), tiledOffsets,
442-
tiledSizes, resultOffsets,
443-
resultSizes)))
444-
return op->emitOpError("output offsets couldn't be calculated");
445-
SmallVector<OpFoldResult> strides(resultSizes.size(), b.getIndexAttr(1));
446-
447-
// 5.b. Parallel insertions are inserted at the end of the combining
448-
// terminator.
449-
b.setInsertionPointToEnd(forallOp.getTerminator().getBody());
450-
b.create<tensor::ParallelInsertSliceOp>(loc, std::get<1>(it),
451-
std::get<2>(it), resultOffsets,
452-
resultSizes, strides);
453-
}
454-
return ForallTilingResult{forallOp, tiledOp};
455-
}
456-
457-
FailureOr<ForallTilingResult>
458-
linalg::tileToForallOp(RewriterBase &b, TilingInterface op,
459-
ArrayRef<OpFoldResult> numThreads,
460-
std::optional<ArrayAttr> mapping) {
461-
return tileToForallOpImpl(b, op, numThreads,
462-
/*nominalTileSizes=*/std::nullopt, mapping,
463-
/*omitTileOffsetBoundsCheck=*/false);
464-
}
465-
466-
FailureOr<ForallTilingResult>
467-
linalg::tileToForallOpUsingTileSizes(RewriterBase &b, TilingInterface op,
468-
ArrayRef<OpFoldResult> tileSizes,
469-
std::optional<ArrayAttr> mapping) {
470-
SmallVector<Range> loopRanges = op.getIterationDomain(b);
471-
unsigned nLoops = loopRanges.size();
472-
SmallVector<OpFoldResult> numThreads;
473-
numThreads.reserve(nLoops);
474-
AffineExpr s0, s1;
475-
bindSymbols(b.getContext(), s0, s1);
476-
AffineExpr divExpr = s0.ceilDiv(s1);
477-
for (const auto &it : llvm::zip(tileSizes, loopRanges)) {
478-
OpFoldResult numTiles = std::get<0>(it);
479-
if (!isConstantIntValue(numTiles, 0))
480-
numTiles = makeComposedFoldedAffineApply(
481-
b, op.getLoc(), divExpr, {std::get<1>(it).size, std::get<0>(it)});
482-
numThreads.push_back(numTiles);
483-
}
484-
return tileToForallOpImpl(b, op, numThreads,
485-
/*nominalTileSizes=*/tileSizes, mapping,
486-
/*omitTileOffsetBoundsCheck=*/true);
487-
}
488-
489307
template <typename LoopTy>
490308
static FailureOr<TiledLinalgOp>
491309
tileLinalgOpImpl(RewriterBase &b, LinalgOp op, ArrayRef<OpFoldResult> tileSizes,

0 commit comments

Comments
 (0)