Skip to content

[mlir][Linalg] Deprecate linalg::tileToForallOp and linalg::tileToForallOpUsingTileSizes #91878

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@ class GenericOp;
class LinalgOp;
} // namespace linalg

namespace scf {
struct SCFTilingResult;
} // namespace scf

namespace tensor {
class InsertSliceOp;
class PackOp;
Expand Down Expand Up @@ -60,7 +64,7 @@ tileToForallOpImpl(RewriterBase &rewriter, transform::TransformState &state,
ArrayRef<OpFoldResult> mixedNumThreads,
ArrayRef<OpFoldResult> mixedTileSizes,
std::optional<ArrayAttr> mapping,
linalg::ForallTilingResult &tilingResult);
scf::SCFTilingResult &tilingResult);

} // namespace transform
} // namespace mlir
Expand Down
33 changes: 6 additions & 27 deletions mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
Original file line number Diff line number Diff line change
Expand Up @@ -866,29 +866,6 @@ FailureOr<ContinuousTileSizeSpecification>
computeContinuousTileSizes(OpBuilder &builder, TilingInterface op,
unsigned dimension, OpFoldResult targetSize,
bool emitAssertions);
/// Rewrite a TilingInterface `op` to a tiled `scf.forall`, applying
/// tiling by `numThreads`.
/// If non-empty, the `mapping` is added as an attribute to the
/// resulting `scf.forall`.
/// Zero tile sizes indicate that the dimension is not tiled, and can be
/// thought of as tiling by the full size of data. It is the user's
/// responsibility to ensure that `numThreads` is a valid tiling specification
/// (i.e. that only tiles parallel dimensions, e.g. in the Linalg case).
struct ForallTilingResult {
Operation *tileOp;
Operation *tiledOp;
};
FailureOr<ForallTilingResult> tileToForallOp(RewriterBase &builder,
TilingInterface op,
ArrayRef<OpFoldResult> numThreads,
std::optional<ArrayAttr> mapping);

/// Same as `tileToForallOp`, but calculate the number of threads
/// required using the given tileSizes.
FailureOr<ForallTilingResult>
tileToForallOpUsingTileSizes(RewriterBase &builder, TilingInterface op,
ArrayRef<OpFoldResult> tileSizes,
std::optional<ArrayAttr> mapping);

/// Transformation information returned after reduction tiling.
struct ForallReductionTilingResult {
Expand Down Expand Up @@ -1750,10 +1727,12 @@ void populateWinogradConv2DPatterns(RewritePatternSet &patterns, int64_t m,
void populateDecomposeWinogradOpsPatterns(RewritePatternSet &patterns);

/// Adds patterns that reduce the rank of named contraction ops that have
/// unit dimensions in the operand(s) by converting to a sequence of `collapse_shape`,
/// `<corresponding linalg named op>`, `expand_shape` (if on tensors). For example a
/// `linalg.batch_matmul` with unit batch size will convert to `linalg.matmul`
/// and a `linalg.matvec` with with unit spatial dim in lhs will convert to a `linalg.dot`.
/// unit dimensions in the operand(s) by converting to a sequence of
/// `collapse_shape`,
/// `<corresponding linalg named op>`, `expand_shape` (if on tensors). For
/// example a `linalg.batch_matmul` with unit batch size will convert to
/// `linalg.matmul` and a `linalg.matvec` with unit spatial dim in lhs will
/// convert to a `linalg.dot`.
void populateContractionOpRankReducingPatterns(RewritePatternSet &patterns);

} // namespace linalg
Expand Down
35 changes: 28 additions & 7 deletions mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,11 @@ using SCFTileSizeComputationFunction =

/// Options to use to control tiling.
struct SCFTilingOptions {
/// Computation function that returns the tile sizes for each operation.
/// Delayed construction of constant tile sizes should occur to interoperate
/// with folding.
/// Computation function that returns the tile sizes to use for each loop.
/// Returning a tile size of zero implies no tiling for that loop. If the
/// size of the returned vector is smaller than the number of loops, the inner
/// loops are not tiled. If the size of the returned vector is larger, then
/// the vector is truncated to number of loops.
SCFTileSizeComputationFunction tileSizeComputationFunction = nullptr;

SCFTilingOptions &
Expand All @@ -45,7 +47,27 @@ struct SCFTilingOptions {
/// Convenience function to set the `tileSizeComputationFunction` to a
/// function that computes tile sizes at the point they are needed. Allows
/// proper interaction with folding.
SCFTilingOptions &setTileSizes(ArrayRef<OpFoldResult> ts);
SCFTilingOptions &setTileSizes(ArrayRef<OpFoldResult> tileSizes);

/// Computation function that returns the number of threads to use for
/// each loop. Returning a num threads of zero implies no tiling for that
/// loop. If the size of the returned vector is smaller than the number of
/// loops, the inner loops are not tiled. If the size of the returned vector
/// is larger, then the vector is truncated to number of loops. Note: This
/// option is only supported with loopType set to `LoopType::ForallOp`. If the
/// tile size function is not specified while the num threads computation is,
/// then the tile size is determined automatically to map at most one tile per
/// thread.
SCFTileSizeComputationFunction numThreadsComputationFunction = nullptr;

/// Sets `numThreadsComputationFunction`; returns `*this` so option setters
/// can be chained.
SCFTilingOptions &
setNumThreadsComputationFunction(SCFTileSizeComputationFunction fun) {
numThreadsComputationFunction = std::move(fun);
return *this;
}
/// Convenience function to set the `numThreadsComputationFunction` to a
/// function that computes num threads at the point they are needed.
SCFTilingOptions &setNumThreads(ArrayRef<OpFoldResult> numThreads);

/// The interchange vector to reorder the tiled loops.
SmallVector<int64_t> interchangeVector = {};
Expand All @@ -67,9 +89,8 @@ struct SCFTilingOptions {
/// when using loop constructs that don't support such a mapping (like
/// `scf.for`)
SmallVector<Attribute> mappingVector = {};
SCFTilingOptions &setMapping(ArrayRef<DeviceMappingAttrInterface> mapping) {
mappingVector = llvm::map_to_vector(
mapping, [](auto attr) -> Attribute { return attr; });
SCFTilingOptions &setMapping(ArrayRef<Attribute> mapping) {
mappingVector = llvm::to_vector(mapping);
return *this;
}
};
Expand Down
8 changes: 8 additions & 0 deletions mlir/include/mlir/Dialect/SCF/Utils/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,14 @@ scf::ForallOp fuseIndependentSiblingForallLoops(scf::ForallOp target,
scf::ForOp fuseIndependentSiblingForLoops(scf::ForOp target, scf::ForOp source,
RewriterBase &rewriter);

/// Normalize an `scf.forall` operation. Returns `failure()` if normalization
/// fails. On `success()` returns the newly created operation, with all uses
/// of the original operation replaced by results of the new operation.
FailureOr<scf::ForallOp> normalizeForallOp(RewriterBase &rewriter,
scf::ForallOp forallOp);

} // namespace mlir

#endif // MLIR_DIALECT_SCF_UTILS_UTILS_H_
131 changes: 118 additions & 13 deletions mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
Expand Down Expand Up @@ -3151,12 +3152,100 @@ void transform::TileUsingForallOp::build(OpBuilder &builder,
/*mapping=*/mapping);
}

/// Given the `lbs`, `ubs` and `steps` of a set of loops, compute the
/// normalized (exclusive) upper bound of each loop, i.e. its trip count.
static SmallVector<OpFoldResult>
normalizeUpperBounds(RewriterBase &rewriter, Location loc,
                     ArrayRef<OpFoldResult> lbs, ArrayRef<OpFoldResult> ubs,
                     ArrayRef<OpFoldResult> steps) {
  // ceil((ub - lb) / step) is the trip count, which becomes the upper bound
  // of the equivalent zero-based, unit-step loop.
  AffineExpr lbSym, ubSym, stepSym;
  bindSymbols(rewriter.getContext(), lbSym, ubSym, stepSym);
  AffineExpr tripCountExpr = (ubSym - lbSym).ceilDiv(stepSym);

  SmallVector<OpFoldResult> newUbs;
  newUbs.reserve(lbs.size());
  for (auto [lb, ub, step] : llvm::zip_equal(lbs, ubs, steps))
    newUbs.push_back(affine::makeComposedFoldedAffineApply(
        rewriter, loc, tripCountExpr, {lb, ub, step}));
  return newUbs;
}

/// When a loop is normalized, every use of the induction variable inside the
/// loop body must be replaced by `original_lb + iv * original_step`.
static SmallVector<Value> denormalizeIndVar(RewriterBase &rewriter,
                                            Location loc, ValueRange ivs,
                                            ArrayRef<OpFoldResult> lbs,
                                            ArrayRef<OpFoldResult> steps) {
  AffineExpr lbSym, stepSym;
  AffineExpr ivDim;
  bindSymbols(rewriter.getContext(), lbSym, stepSym);
  bindDims(rewriter.getContext(), ivDim);
  // lb + iv * step; composed and folded so constant bounds simplify away.
  AffineExpr denormExpr = lbSym + ivDim * stepSym;

  SmallVector<Value> denormalized;
  denormalized.reserve(lbs.size());
  for (auto [iv, lb, step] : llvm::zip_equal(ivs, lbs, steps)) {
    OpFoldResult folded = affine::makeComposedFoldedAffineApply(
        rewriter, loc, denormExpr, ArrayRef<OpFoldResult>{iv, lb, step});
    denormalized.push_back(
        getValueOrCreateConstantIndexOp(rewriter, loc, folded));
  }
  return denormalized;
}

/// Given a `scf.forall` loop return a loop op with the loop bounds
/// normalized, i.e. all lower bounds set to zero and all steps set to one.
/// TODO: Replace this with a general utility to normalize `scf.forall`.
/// At the time of writing, this wasn't done since adding this to the `scf`
/// dialect would disallow use of `affine.apply` operations due
/// to cyclic dependencies. To avoid churn in lit tests
/// with the change this was added with, defer that to a follow up.
static scf::ForallOp normalizeForallLoopOp(RewriterBase &rewriter,
scf::ForallOp loop) {
SmallVector<OpFoldResult> lbs = loop.getMixedLowerBound();
SmallVector<OpFoldResult> ubs = loop.getMixedUpperBound();
SmallVector<OpFoldResult> steps = loop.getMixedStep();

// Already normalized (all-zero lower bounds, all-one steps): return the
// loop unchanged.
if (llvm::all_of(
lbs, [](OpFoldResult ofr) { return isConstantIntValue(ofr, 0); }) &&
llvm::all_of(
steps, [](OpFoldResult ofr) { return isConstantIntValue(ofr, 1); })) {
return loop;
}

Location loc = loop.getLoc();
// New bounds: lb = 0, step = 1, ub = trip count of the original loop.
SmallVector<OpFoldResult> normalizedUbs =
normalizeUpperBounds(rewriter, loc, lbs, ubs, steps);
SmallVector<OpFoldResult> normalizedLbs(normalizedUbs.size(),
rewriter.getIndexAttr(0));
SmallVector<OpFoldResult> normalizedSteps(normalizedUbs.size(),
rewriter.getIndexAttr(1));

// Create the replacement loop with an empty body; the original body is
// merged into it below.
auto normalizedForallOp = rewriter.create<scf::ForallOp>(
loc, normalizedLbs, normalizedUbs, normalizedSteps, loop.getOutputs(),
loop.getMapping(), [](OpBuilder &, Location, ValueRange) {});

auto normalizedLoopIvs = normalizedForallOp.getInductionVars();
OpBuilder::InsertionGuard g(rewriter);
Block *normalizedLoopBlock = normalizedForallOp.getBody();
rewriter.setInsertionPointToStart(normalizedLoopBlock);

// Express the original induction variables in terms of the normalized ones
// (iv -> lb + iv * step), then map the iter args through unchanged so the
// original block can be merged in directly.
SmallVector<Value> argValues =
denormalizeIndVar(rewriter, loc, normalizedLoopIvs, lbs, steps);
argValues.append(normalizedForallOp.getRegionIterArgs().begin(),
normalizedForallOp.getRegionIterArgs().end());
Block *origLoopBlock = loop.getBody();
rewriter.mergeBlocks(origLoopBlock, normalizedLoopBlock, argValues);

rewriter.replaceOp(loop, normalizedForallOp);
return normalizedForallOp;
}

DiagnosedSilenceableFailure transform::tileToForallOpImpl(
RewriterBase &rewriter, transform::TransformState &state,
TransformOpInterface transformOp, Operation *target,
ArrayRef<OpFoldResult> mixedNumThreads,
ArrayRef<OpFoldResult> mixedTileSizes, std::optional<ArrayAttr> mapping,
linalg::ForallTilingResult &tilingResult) {
scf::SCFTilingResult &tilingResult) {
// Transform all targets one by one.
auto tileableOp = dyn_cast<TilingInterface>(target);
if (!tileableOp) {
Expand All @@ -3167,20 +3256,35 @@ DiagnosedSilenceableFailure transform::tileToForallOpImpl(
return diag;
}
rewriter.setInsertionPoint(tileableOp);
FailureOr<linalg::ForallTilingResult> maybeTilingResult = failure();
scf::SCFTilingOptions options;
options.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp);
if (!mixedNumThreads.empty()) {
maybeTilingResult =
linalg::tileToForallOp(rewriter, tileableOp, mixedNumThreads, mapping);
options.setNumThreads(mixedNumThreads);
} else {
maybeTilingResult = linalg::tileToForallOpUsingTileSizes(
rewriter, tileableOp, mixedTileSizes, mapping);
options.setTileSizes(mixedTileSizes);
}
if (mapping) {
options.setMapping(mapping.value().getValue());
}
FailureOr<scf::SCFTilingResult> maybeTilingResult =
scf::tileUsingSCF(rewriter, tileableOp, options);

if (failed(maybeTilingResult))
return transformOp.emitDefaultSilenceableFailure(tileableOp);
rewriter.replaceOp(tileableOp, maybeTilingResult->tileOp->getResults());

rewriter.replaceOp(tileableOp, maybeTilingResult->replacements);

tilingResult = *maybeTilingResult;

if (mixedNumThreads.empty()) {
auto generatedForallOp = cast<scf::ForallOp>(tilingResult.loops.front());
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(generatedForallOp);
scf::ForallOp normalizedForallOp =
normalizeForallLoopOp(rewriter, generatedForallOp);
tilingResult.loops.front() = normalizedForallOp;
}

return DiagnosedSilenceableFailure::success();
}

Expand Down Expand Up @@ -3214,14 +3318,14 @@ DiagnosedSilenceableFailure transform::TileUsingForallOp::apply(
return status;

for (Operation *target : state.getPayloadOps(getTarget())) {
linalg::ForallTilingResult tilingResult;
scf::SCFTilingResult tilingResult;
DiagnosedSilenceableFailure diag = tileToForallOpImpl(
rewriter, state, transformOp, target, mixedNumThreads, mixedTileSizes,
getMapping(), tilingResult);
if (!diag.succeeded())
return diag;
tileOps.push_back(tilingResult.tileOp);
tiledOps.push_back(tilingResult.tiledOp);
tileOps.push_back(tilingResult.loops.front());
tiledOps.append(tilingResult.tiledOps);
}

transformResults.set(cast<OpResult>(getForallOp()), tileOps);
Expand Down Expand Up @@ -3699,7 +3803,7 @@ DiagnosedSilenceableFailure transform::MapCopyToThreadsOp::applyToOne(

// OpBuilder only used to compute attributes.
OpBuilder b(getContext());
linalg::ForallTilingResult tilingResult;
scf::SCFTilingResult tilingResult;
DiagnosedSilenceableFailure diag = tileToForallOpImpl(
/*rewriter=*/rewriter,
/*state=*/state,
Expand All @@ -3712,8 +3816,9 @@ DiagnosedSilenceableFailure transform::MapCopyToThreadsOp::applyToOne(
if (!diag.succeeded())
return diag;

results.push_back(tilingResult.tileOp);
results.push_back(tilingResult.tiledOp);
results.push_back(tilingResult.loops.front());
for (auto op : tilingResult.tiledOps)
results.push_back(op);
return DiagnosedSilenceableFailure::success();
}

Expand Down
Loading