
Commit 6740d70

[mlir][Linalg] Deprecate linalg::tileToForallOp and linalg::tileToForallOpUsingTileSizes (#91878)
The implementations of these methods are legacy, and they are removed in favor of the `scf::tileUsingSCF` method as a replacement. To bring the latter on par with the requirements of the deprecated methods, the tiling options now allow specifying the maximum number of tiles to use instead of specifying the tile sizes. When tiling to `scf.forall`, this specification is used to generate the `num_threads` version of the operation.

One deviation from the previous implementation: the deprecated methods always generated the `num_threads` variant of the `scf.forall` operation, whereas now this is driven by the tiling options specified. This reduces the indexing math generated when the tile sizes are specified.

**Moving from `linalg::tileToForallOp` to `scf::tileUsingSCF`**

```
OpBuilder b;
TilingInterface op;
ArrayRef<OpFoldResult> numThreads;
ArrayAttr mapping;
FailureOr<ForallTilingResult> result =
    linalg::tileToForallOp(b, op, numThreads, mapping);
```

can be replaced by

```
scf::SCFTilingOptions options;
options.setNumThreads(numThreads);
options.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp);
// Note the difference: setMapping takes an ArrayRef<Attribute>.
options.setMapping(mapping.getValue());
FailureOr<scf::SCFTilingResult> result = scf::tileUsingSCF(b, op, options);
```

This generates the `num_threads` version of the `scf.forall` for the inter-tile loops, i.e.

```
... = scf.forall (%arg0, %arg1) in (%nt0, %nt1) shared_outs(...)
```
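The `num_threads` partitioning above can be modeled with a little plain-Python arithmetic (an illustrative sketch, not MLIR code; all names below are made up for illustration): each thread covers a tile of size `ceildiv(trip_count, num_threads)`, clamped at the upper bound.

```python
def ceildiv(a, b):
    return -(-a // b)

def num_threads_partition(trip_count, num_threads):
    """Return the [begin, end) iteration range each thread covers."""
    tile = ceildiv(trip_count, num_threads)
    ranges = []
    for t in range(num_threads):
        begin = t * tile
        end = min(trip_count, begin + tile)
        ranges.append((begin, max(begin, end)))  # empty range past the bound
    return ranges

# Example: 10 iterations over 4 threads -> tiles of 3, last thread gets 1.
print(num_threads_partition(10, 4))  # [(0, 3), (3, 6), (6, 9), (9, 10)]
```

The ranges cover the full iteration space without overlap, which is why the inter-tile loop can be an `scf.forall` over the thread ids.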
**Moving from `linalg::tileToForallOpUsingTileSizes` to `scf::tileUsingSCF`**

```
OpBuilder b;
TilingInterface op;
ArrayRef<OpFoldResult> tileSizes;
ArrayAttr mapping;
FailureOr<ForallTilingResult> result =
    linalg::tileToForallOpUsingTileSizes(b, op, tileSizes, mapping);
```

can be replaced by

```
scf::SCFTilingOptions options;
options.setTileSizes(tileSizes);
options.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp);
// Note the difference: setMapping takes an ArrayRef<Attribute>.
options.setMapping(mapping.getValue());
FailureOr<scf::SCFTilingResult> result = scf::tileUsingSCF(b, op, options);
```

Also note that `linalg::tileToForallOpUsingTileSizes` would effectively call `linalg::tileToForallOp` by computing the `numThreads` from the `op` and `tileSizes`, and then generate the `num_threads` version of the `scf.forall`. That is no longer the case. Instead, this directly generates the `tileSizes` version of the `scf.forall` op:

```
... = scf.forall(%arg0, %arg1) = (%lb0, %lb1) to (%ub0, %ub1) step(%step0, %step1) shared_outs(...)
```

If you actually want to use the `num_threads` version, it is up to the caller to compute the `numThreads` and call `options.setNumThreads` instead of `options.setTileSizes`.

Note that there is a slight difference between the num threads version and the tile size version: the former requires an additional `affine.max` on the tile size to ensure non-negative tile sizes. This `affine.max` is not needed when the `numThreads` are computed from the tile sizes, since by construction the tile sizes are then non-negative. In previous implementations, the `num_threads` version generated by `linalg::tileToForallOpUsingTileSizes` would avoid generating the `affine.max` operation. To get to the same state, downstream users will have to additionally normalize the `scf.forall` operation.
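The arithmetic behind the last two paragraphs can be sketched in plain Python (a hedged model mirroring the commit message, not actual MLIR code; the helper names are invented): a caller that still wants the `num_threads` form computes `numThreads = ceildiv(trip_count, tile_size)` per dimension, and the `num_threads` form clamps the boundary tile size with a `max(0, ...)` (the `affine.max`), since an over-partitioned iteration space yields negative remainders.

```python
def ceildiv(a, b):
    return -(-a // b)

def num_threads_from_tile_sizes(trip_counts, tile_sizes):
    """Per-dimension thread counts; a zero tile size means 'untiled'."""
    return [ceildiv(trip, ts) if ts != 0 else 0
            for trip, ts in zip(trip_counts, tile_sizes)]

def bounded_tile_size(trip_count, tile_size, thread_id):
    """Boundary tile size in the num_threads form: the max(0, ...) models
    the affine.max that guards against negative sizes."""
    return max(0, min(tile_size, trip_count - thread_id * tile_size))

nts = num_threads_from_tile_sizes([10], [4])             # -> [3]
sizes = [bounded_tile_size(10, 4, t) for t in range(4)]  # -> [4, 4, 2, 0]
print(nts, sizes)
```

With `num_threads` derived via `ceildiv` (here 3), every per-thread size is already non-negative, which is why that path could historically skip the `affine.max`.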
**Changes to `transform.structured.tile_using_forall`**

The transform dialect op that called into `linalg::tileToForallOp` and `linalg::tileToForallOpUsingTileSizes` has been modified to call `scf::tileUsingSCF`. The transform dialect op always generates the `num_threads` version of the `scf.forall` op. So when `tile_sizes` are specified for the transform dialect op, the `tile_sizes` version of the `scf.forall` is first generated by `scf::tileUsingSCF` and then normalized to get back to the same state. So there is no functional change to `transform.structured.tile_using_forall`: it always generates the `num_threads` version of the `scf.forall` op (as it did before this change).

---------

Signed-off-by: MaheshRavishankar <[email protected]>
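The normalization step mentioned above can be sketched in plain Python (an illustrative model, not MLIR): a loop with bounds `(lb, ub, step)` becomes `(0, ceildiv(ub - lb, step), 1)`, and uses of the old induction variable become `lb + iv * step` — the same arithmetic as the `normalizeUpperBounds`/`denormalizeIndVar` helpers in this patch. Both forms visit exactly the same indices.

```python
def ceildiv(a, b):
    return -(-a // b)

def iteration_values(lb, ub, step):
    """Indices visited by the original loop."""
    return list(range(lb, ub, step))

def normalized_iteration_values(lb, ub, step):
    """Indices visited after normalization plus denormalization of the IV."""
    norm_ub = ceildiv(ub - lb, step)    # normalized upper bound
    return [lb + iv * step              # denormalized induction variable
            for iv in range(0, norm_ub, 1)]

print(iteration_values(3, 20, 4))             # [3, 7, 11, 15, 19]
print(normalized_iteration_values(3, 20, 4))  # [3, 7, 11, 15, 19]
```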
1 parent ef67664 commit 6740d70

File tree

17 files changed (+600, -409 lines)


mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h

Lines changed: 5 additions & 1 deletion
```diff
@@ -30,6 +30,10 @@ class GenericOp;
 class LinalgOp;
 } // namespace linalg
 
+namespace scf {
+struct SCFTilingResult;
+} // namespace scf
+
 namespace tensor {
 class InsertSliceOp;
 class PackOp;
@@ -60,7 +64,7 @@ tileToForallOpImpl(RewriterBase &rewriter, transform::TransformState &state,
                    ArrayRef<OpFoldResult> mixedNumThreads,
                    ArrayRef<OpFoldResult> mixedTileSizes,
                    std::optional<ArrayAttr> mapping,
-                   linalg::ForallTilingResult &tilingResult);
+                   scf::SCFTilingResult &tilingResult);
 
 } // namespace transform
 } // namespace mlir
```

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 6 additions & 27 deletions
```diff
@@ -866,29 +866,6 @@ FailureOr<ContinuousTileSizeSpecification>
 computeContinuousTileSizes(OpBuilder &builder, TilingInterface op,
                            unsigned dimension, OpFoldResult targetSize,
                            bool emitAssertions);
-/// Rewrite a TilingInterface `op` to a tiled `scf.forall`, applying
-/// tiling by `numThreads`.
-/// If non-empty, the `mapping` is added as an attribute to the
-/// resulting `scf.forall`.
-/// Zero tile sizes indicate that the dimension is not tiled, and can be
-/// thought of as tiling by the full size of data. It is the user's
-/// responsibility to ensure that `numThreads` is a valid tiling specification
-/// (i.e. that only tiles parallel dimensions, e.g. in the Linalg case).
-struct ForallTilingResult {
-  Operation *tileOp;
-  Operation *tiledOp;
-};
-FailureOr<ForallTilingResult> tileToForallOp(RewriterBase &builder,
-                                             TilingInterface op,
-                                             ArrayRef<OpFoldResult> numThreads,
-                                             std::optional<ArrayAttr> mapping);
-
-/// Same as `tileToForallOp`, but calculate the number of threads
-/// required using the given tileSizes.
-FailureOr<ForallTilingResult>
-tileToForallOpUsingTileSizes(RewriterBase &builder, TilingInterface op,
-                             ArrayRef<OpFoldResult> tileSizes,
-                             std::optional<ArrayAttr> mapping);
 
 /// Transformation information returned after reduction tiling.
 struct ForallReductionTilingResult {
@@ -1750,10 +1727,12 @@ void populateWinogradConv2DPatterns(RewritePatternSet &patterns, int64_t m,
 void populateDecomposeWinogradOpsPatterns(RewritePatternSet &patterns);
 
 /// Adds patterns that reduce the rank of named contraction ops that have
-/// unit dimensions in the operand(s) by converting to a sequence of `collapse_shape`,
-/// `<corresponding linalg named op>`, `expand_shape` (if on tensors). For example a
-/// `linalg.batch_matmul` with unit batch size will convert to `linalg.matmul`
-/// and a `linalg.matvec` with with unit spatial dim in lhs will convert to a `linalg.dot`.
+/// unit dimensions in the operand(s) by converting to a sequence of
+/// `collapse_shape`,
+/// `<corresponding linalg named op>`, `expand_shape` (if on tensors). For
+/// example a `linalg.batch_matmul` with unit batch size will convert to
+/// `linalg.matmul` and a `linalg.matvec` with with unit spatial dim in lhs will
+/// convert to a `linalg.dot`.
 void populateContractionOpRankReducingPatterns(RewritePatternSet &patterns);
 
 } // namespace linalg
```

mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h

Lines changed: 28 additions & 7 deletions
```diff
@@ -32,9 +32,11 @@ using SCFTileSizeComputationFunction =
 
 /// Options to use to control tiling.
 struct SCFTilingOptions {
-  /// Computation function that returns the tile sizes for each operation.
-  /// Delayed construction of constant tile sizes should occur to interoperate
-  /// with folding.
+  /// Computation function that returns the tile sizes to use for each loop.
+  /// Returning a tile size of zero implies no tiling for that loop. If the
+  /// size of the returned vector is smaller than the number of loops, the inner
+  /// loops are not tiled. If the size of the returned vector is larger, then
+  /// the vector is truncated to number of loops.
   SCFTileSizeComputationFunction tileSizeComputationFunction = nullptr;
 
   SCFTilingOptions &
@@ -45,7 +47,27 @@ struct SCFTilingOptions {
   /// Convenience function to set the `tileSizeComputationFunction` to a
   /// function that computes tile sizes at the point they are needed. Allows
   /// proper interaction with folding.
-  SCFTilingOptions &setTileSizes(ArrayRef<OpFoldResult> ts);
+  SCFTilingOptions &setTileSizes(ArrayRef<OpFoldResult> tileSizes);
+
+  /// Computation function that returns the number of threads to use for
+  /// each loop. Returning a num threads of zero implies no tiling for that
+  /// loop. If the size of the returned vector is smaller than the number of
+  /// loops, the inner loops are not tiled. If the size of the returned vector
+  /// is larger, then the vector is truncated to number of loops. Note: This
+  /// option is only supported with loopType set to `LoopType::ForallOp`. If the
+  /// tile size function is not specified while the num threads computation is,
+  /// then the tile size is determined automatically to map at most one tile per
+  /// thread.
+  SCFTileSizeComputationFunction numThreadsComputationFunction = nullptr;
+
+  SCFTilingOptions &
+  setNumThreadsComputationFunction(SCFTileSizeComputationFunction fun) {
+    numThreadsComputationFunction = std::move(fun);
+    return *this;
+  }
+  /// Convenience function to set the `numThreadsComputationFunction` to a
+  /// function that computes num threads at the point they are needed.
+  SCFTilingOptions &setNumThreads(ArrayRef<OpFoldResult> numThreads);
 
   /// The interchange vector to reorder the tiled loops.
   SmallVector<int64_t> interchangeVector = {};
@@ -67,9 +89,8 @@ struct SCFTilingOptions {
   /// when using loop constructs that dont support such a mapping (like
   /// `scf.for`)
   SmallVector<Attribute> mappingVector = {};
-  SCFTilingOptions &setMapping(ArrayRef<DeviceMappingAttrInterface> mapping) {
-    mappingVector = llvm::map_to_vector(
-        mapping, [](auto attr) -> Attribute { return attr; });
+  SCFTilingOptions &setMapping(ArrayRef<Attribute> mapping) {
+    mappingVector = llvm::to_vector(mapping);
     return *this;
   }
 };
```
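The tile-size vector semantics documented in the options above (a zero entry means "do not tile", a short vector leaves inner loops untiled, a long vector is truncated) can be sketched in plain Python (an assumption-labeled model, not the MLIR implementation):

```python
def effective_tile_sizes(num_loops, tile_sizes):
    """Model of the documented vector semantics: truncate extras, pad
    missing entries with 0 (i.e. leave inner loops untiled)."""
    sizes = list(tile_sizes[:num_loops])      # truncate if too long
    sizes += [0] * (num_loops - len(sizes))   # pad: inner loops untiled
    return sizes

print(effective_tile_sizes(3, [8]))        # [8, 0, 0]
print(effective_tile_sizes(2, [8, 4, 2]))  # [8, 4]
```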

mlir/include/mlir/Dialect/SCF/Utils/Utils.h

Lines changed: 8 additions & 0 deletions
```diff
@@ -195,6 +195,14 @@ scf::ForallOp fuseIndependentSiblingForallLoops(scf::ForallOp target,
 scf::ForOp fuseIndependentSiblingForLoops(scf::ForOp target, scf::ForOp source,
                                           RewriterBase &rewriter);
 
+/// Normalize an `scf.forall` operation. Returns `failure()`if normalization
+/// fails.
+// On `success()` returns the
+/// newly created operation with all uses of the original operation replaced
+/// with results of the new operation.
+FailureOr<scf::ForallOp> normalizeForallOp(RewriterBase &rewriter,
+                                           scf::ForallOp forallOp);
+
 } // namespace mlir
 
 #endif // MLIR_DIALECT_SCF_UTILS_UTILS_H_
```

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 118 additions & 13 deletions
```diff
@@ -12,6 +12,7 @@
 
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
 #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
@@ -3151,12 +3152,100 @@ void transform::TileUsingForallOp::build(OpBuilder &builder,
                                          /*mapping=*/mapping);
 }
 
+/// Given `lbs`, `ubs` and `steps` of loops, return (for each loop), the
+/// normalized upper bound.
+static SmallVector<OpFoldResult>
+normalizeUpperBounds(RewriterBase &rewriter, Location loc,
+                     ArrayRef<OpFoldResult> lbs, ArrayRef<OpFoldResult> ubs,
+                     ArrayRef<OpFoldResult> steps) {
+  AffineExpr s0, s1, s2;
+  bindSymbols(rewriter.getContext(), s0, s1, s2);
+  AffineExpr normalizedUbExpr = (s1 - s0).ceilDiv(s2);
+  SmallVector<OpFoldResult> normalizedUbs;
+  for (auto [lb, ub, step] : llvm::zip_equal(lbs, ubs, steps)) {
+    OpFoldResult normalizedUb = affine::makeComposedFoldedAffineApply(
+        rewriter, loc, normalizedUbExpr, {lb, ub, step});
+    normalizedUbs.push_back(normalizedUb);
+  }
+  return normalizedUbs;
+}
+
+/// When a loop is normalized, the uses of the induction variable within the
+/// loop need to replaced with `original_lb + old_iv * original_step`.
+static SmallVector<Value> denormalizeIndVar(RewriterBase &rewriter,
+                                            Location loc, ValueRange ivs,
+                                            ArrayRef<OpFoldResult> lbs,
+                                            ArrayRef<OpFoldResult> steps) {
+  AffineExpr s0, s1;
+  AffineExpr d0;
+  bindSymbols(rewriter.getContext(), s0, s1);
+  bindDims(rewriter.getContext(), d0);
+  AffineExpr denormExpr = s0 + d0 * s1;
+  SmallVector<Value> denormalizedIvs;
+
+  for (auto [iv, lb, step] : llvm::zip_equal(ivs, lbs, steps)) {
+    OpFoldResult denormValue = affine::makeComposedFoldedAffineApply(
+        rewriter, loc, denormExpr, ArrayRef<OpFoldResult>{iv, lb, step});
+    denormalizedIvs.push_back(
+        getValueOrCreateConstantIndexOp(rewriter, loc, denormValue));
+  }
+  return denormalizedIvs;
+}
+
+/// Given a `scf.forall` loop return a loop op with the loop bounds
+/// normalized.
+/// TODO: Replace this with a general utility to normalize `scf.forall`.
+/// At the time of writing, this wasnt done since adding this to `scf`
+/// dialect would disallow using of `affine.apply` operations due
+/// to cyclic dependencies. To avoid churn in lit tests
+/// with the change this was added with, defer that to a follow up.
+static scf::ForallOp normalizeForallLoopOp(RewriterBase &rewriter,
+                                           scf::ForallOp loop) {
+  SmallVector<OpFoldResult> lbs = loop.getMixedLowerBound();
+  SmallVector<OpFoldResult> ubs = loop.getMixedUpperBound();
+  SmallVector<OpFoldResult> steps = loop.getMixedStep();
+
+  if (llvm::all_of(
+          lbs, [](OpFoldResult ofr) { return isConstantIntValue(ofr, 0); }) &&
+      llvm::all_of(
+          steps, [](OpFoldResult ofr) { return isConstantIntValue(ofr, 1); })) {
+    return loop;
+  }
+
+  Location loc = loop.getLoc();
+  SmallVector<OpFoldResult> normalizedUbs =
+      normalizeUpperBounds(rewriter, loc, lbs, ubs, steps);
+  SmallVector<OpFoldResult> normalizedLbs(normalizedUbs.size(),
+                                          rewriter.getIndexAttr(0));
+  SmallVector<OpFoldResult> normalizedSteps(normalizedUbs.size(),
+                                            rewriter.getIndexAttr(1));
+
+  auto normalizedForallOp = rewriter.create<scf::ForallOp>(
+      loc, normalizedLbs, normalizedUbs, normalizedSteps, loop.getOutputs(),
+      loop.getMapping(), [](OpBuilder &, Location, ValueRange) {});
+
+  auto normalizedLoopIvs = normalizedForallOp.getInductionVars();
+  OpBuilder::InsertionGuard g(rewriter);
+  Block *normalizedLoopBlock = normalizedForallOp.getBody();
+  rewriter.setInsertionPointToStart(normalizedLoopBlock);
+
+  SmallVector<Value> argValues =
+      denormalizeIndVar(rewriter, loc, normalizedLoopIvs, lbs, steps);
+  argValues.append(normalizedForallOp.getRegionIterArgs().begin(),
+                   normalizedForallOp.getRegionIterArgs().end());
+  Block *origLoopBlock = loop.getBody();
+  rewriter.mergeBlocks(origLoopBlock, normalizedLoopBlock, argValues);
+
+  rewriter.replaceOp(loop, normalizedForallOp);
+  return normalizedForallOp;
+}
+
 DiagnosedSilenceableFailure transform::tileToForallOpImpl(
     RewriterBase &rewriter, transform::TransformState &state,
     TransformOpInterface transformOp, Operation *target,
     ArrayRef<OpFoldResult> mixedNumThreads,
     ArrayRef<OpFoldResult> mixedTileSizes, std::optional<ArrayAttr> mapping,
-    linalg::ForallTilingResult &tilingResult) {
+    scf::SCFTilingResult &tilingResult) {
   // Transform all targets one by one.
   auto tileableOp = dyn_cast<TilingInterface>(target);
   if (!tileableOp) {
@@ -3167,20 +3256,35 @@ DiagnosedSilenceableFailure transform::tileToForallOpImpl(
     return diag;
   }
   rewriter.setInsertionPoint(tileableOp);
-  FailureOr<linalg::ForallTilingResult> maybeTilingResult = failure();
+  scf::SCFTilingOptions options;
+  options.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp);
   if (!mixedNumThreads.empty()) {
-    maybeTilingResult =
-        linalg::tileToForallOp(rewriter, tileableOp, mixedNumThreads, mapping);
+    options.setNumThreads(mixedNumThreads);
   } else {
-    maybeTilingResult = linalg::tileToForallOpUsingTileSizes(
-        rewriter, tileableOp, mixedTileSizes, mapping);
+    options.setTileSizes(mixedTileSizes);
   }
+  if (mapping) {
+    options.setMapping(mapping.value().getValue());
+  }
+  FailureOr<scf::SCFTilingResult> maybeTilingResult =
+      scf::tileUsingSCF(rewriter, tileableOp, options);
 
   if (failed(maybeTilingResult))
     return transformOp.emitDefaultSilenceableFailure(tileableOp);
-  rewriter.replaceOp(tileableOp, maybeTilingResult->tileOp->getResults());
+
+  rewriter.replaceOp(tileableOp, maybeTilingResult->replacements);
 
   tilingResult = *maybeTilingResult;
+
+  if (mixedNumThreads.empty()) {
+    auto generatedForallOp = cast<scf::ForallOp>(tilingResult.loops.front());
+    OpBuilder::InsertionGuard g(rewriter);
+    rewriter.setInsertionPoint(generatedForallOp);
+    scf::ForallOp normalizedForallOp =
+        normalizeForallLoopOp(rewriter, generatedForallOp);
+    tilingResult.loops.front() = normalizedForallOp;
+  }
+
   return DiagnosedSilenceableFailure::success();
 }
 
@@ -3214,14 +3318,14 @@ DiagnosedSilenceableFailure transform::TileUsingForallOp::apply(
     return status;
 
   for (Operation *target : state.getPayloadOps(getTarget())) {
-    linalg::ForallTilingResult tilingResult;
+    scf::SCFTilingResult tilingResult;
     DiagnosedSilenceableFailure diag = tileToForallOpImpl(
         rewriter, state, transformOp, target, mixedNumThreads, mixedTileSizes,
        getMapping(), tilingResult);
     if (!diag.succeeded())
       return diag;
-    tileOps.push_back(tilingResult.tileOp);
-    tiledOps.push_back(tilingResult.tiledOp);
+    tileOps.push_back(tilingResult.loops.front());
+    tiledOps.append(tilingResult.tiledOps);
   }
 
   transformResults.set(cast<OpResult>(getForallOp()), tileOps);
@@ -3699,7 +3803,7 @@ DiagnosedSilenceableFailure transform::MapCopyToThreadsOp::applyToOne(
 
   // OpBuilder only used to compute attributes.
   OpBuilder b(getContext());
-  linalg::ForallTilingResult tilingResult;
+  scf::SCFTilingResult tilingResult;
   DiagnosedSilenceableFailure diag = tileToForallOpImpl(
       /*rewriter=*/rewriter,
       /*state=*/state,
@@ -3712,8 +3816,9 @@ DiagnosedSilenceableFailure transform::MapCopyToThreadsOp::applyToOne(
   if (!diag.succeeded())
     return diag;
 
-  results.push_back(tilingResult.tileOp);
-  results.push_back(tilingResult.tiledOp);
+  results.push_back(tilingResult.loops.front());
+  for (auto op : tilingResult.tiledOps)
+    results.push_back(op);
   return DiagnosedSilenceableFailure::success();
 }
 
```