Skip to content

Commit 391d9a7

Browse files
[mlir][TilingInterface] Use LoopLikeOpInterface in tiling using SCF to unify tiling with scf.for and scf.forall.
Using `LoopLikeOpInterface` as the basis for the implementation unifies all the tiling logic for both `scf.for` and `scf.forall`. The only difference is the actual loop generation. Instead of many entry points for each loop type, the loop type is now passed as part of the options passed to the tiling method. This is a breaking change with the following changes 1) The `scf::tileUsingSCFForOp` is renamed to `scf::tileUsingSCF` 2) The `scf::tileUsingSCFForallOp` is deprecated. The same functionality is obtained by using `scf::tileUsingSCF` and setting the loop type in `scf::SCFTilingOptions` passed into this method to `scf::SCFTilingOptions::LoopType::ForallOp` (using the `setLoopType` method). 3) The `scf::tileConsumerAndFusedProducerGreedilyUsingSCFForOp` is renamed to `scf::tileConsumerAndFuseProducerUsingSCF`. The use of the `controlFn` in `scf::SCFTileAndFuseOptions` allows implementing any strategy with the default callback implemeting the greedy fusion. 4) The `scf::SCFTilingResult` and `scf::SCFTileAndFuseResult` now use `SmallVector<LoopLikeOpInterface>`. 5) To make `scf::ForallOp` implement the parts of `LoopLikeOpInterface` needed, the `getOutputBlockArguments()` method is replaced with `getRegionIterArgs()` This change also introduces a new interface method for `LoopLikeOpInterface`, that allows loop constructs to handle tiled yields. These changes now bring the tiling and fusion capabilities using `scf.forall` on par with what was already supported
1 parent aa2a96a commit 391d9a7

25 files changed

+1176
-534
lines changed

mlir/include/mlir/Dialect/SCF/IR/SCFOps.td

Lines changed: 7 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -135,10 +135,10 @@ def ExecuteRegionOp : SCF_Op<"execute_region", [
135135

136136
def ForOp : SCF_Op<"for",
137137
[AutomaticAllocationScope, DeclareOpInterfaceMethods<LoopLikeOpInterface,
138-
["getInitsMutable", "getSingleInductionVar", "getSingleLowerBound",
139-
"getSingleStep", "getSingleUpperBound", "getYieldedValuesMutable",
140-
"getLoopResults", "promoteIfSingleIteration",
141-
"replaceWithAdditionalYields"]>,
138+
["getInitsMutable", "getRegionIterArgs", "getSingleInductionVar",
139+
"getSingleLowerBound", "getSingleStep", "getSingleUpperBound",
140+
"getYieldedValuesMutable", "getLoopResults", "promoteIfSingleIteration",
141+
"replaceWithAdditionalYields", "yieldTiledValuesAndReplace"]>,
142142
AllTypesMatch<["lowerBound", "upperBound", "step"]>,
143143
ConditionallySpeculatable,
144144
DeclareOpInterfaceMethods<RegionBranchOpInterface,
@@ -259,10 +259,6 @@ def ForOp : SCF_Op<"for",
259259

260260
Value getInductionVar() { return getBody()->getArgument(0); }
261261

262-
Block::BlockArgListType getRegionIterArgs() {
263-
return getBody()->getArguments().drop_front(getNumInductionVars());
264-
}
265-
266262
/// Return the `index`-th region iteration argument.
267263
BlockArgument getRegionIterArg(unsigned index) {
268264
assert(index < getNumRegionIterArgs() &&
@@ -304,8 +300,9 @@ def ForallOp : SCF_Op<"forall", [
304300
AttrSizedOperandSegments,
305301
AutomaticAllocationScope,
306302
DeclareOpInterfaceMethods<LoopLikeOpInterface,
307-
["promoteIfSingleIteration", "getSingleInductionVar",
308-
"getSingleLowerBound", "getSingleUpperBound", "getSingleStep"]>,
303+
["getInitsMutable", "getRegionIterArgs", "getSingleInductionVar",
304+
"getSingleLowerBound", "getSingleUpperBound", "getSingleStep",
305+
"promoteIfSingleIteration", "yieldTiledValuesAndReplace"]>,
309306
RecursiveMemoryEffects,
310307
SingleBlockImplicitTerminator<"scf::InParallelOp">,
311308
DeclareOpInterfaceMethods<RegionBranchOpInterface>,
@@ -585,10 +582,6 @@ def ForallOp : SCF_Op<"forall", [
585582
getNumDynamicControlOperands() + getRank());
586583
}
587584

588-
ArrayRef<BlockArgument> getOutputBlockArguments() {
589-
return getBody()->getArguments().drop_front(getRank());
590-
}
591-
592585
::mlir::ValueRange getInductionVars() {
593586
return getBody()->getArguments().take_front(getRank());
594587
}

mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h

Lines changed: 22 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include "mlir/Dialect/SCF/IR/SCF.h"
1313
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
1414
#include "mlir/IR/PatternMatch.h"
15+
#include "mlir/Interfaces/LoopLikeInterface.h"
1516
#include "mlir/Interfaces/TilingInterface.h"
1617

1718
#include <deque>
@@ -52,6 +53,14 @@ struct SCFTilingOptions {
5253
return *this;
5354
}
5455

56+
/// Specify which loop construct to use for tile and fuse.
57+
enum class LoopType { ForOp, ForallOp };
58+
LoopType loopType = LoopType::ForOp;
59+
SCFTilingOptions &setLoopType(LoopType type) {
60+
loopType = type;
61+
return *this;
62+
}
63+
5564
/// Specify mapping of loops to devices. This is only respected when the loop
5665
/// constructs support such a mapping (like `scf.forall`). Will be ignored
5766
/// when using loop constructs that dont support such a mapping (like
@@ -71,23 +80,17 @@ struct SCFTilingResult {
7180
/// of the last op.
7281
SmallVector<Operation *> tiledOps;
7382
/// The `scf.for` operations that iterate over the tiles.
74-
SmallVector<Operation *> loops;
83+
SmallVector<LoopLikeOpInterface> loops;
7584
/// Values to use as replacements for the untiled op. Is the same size as the
7685
/// number of results of the untiled op.
7786
SmallVector<Value> replacements;
7887
};
7988

8089
/// Method to tile an op that implements the `TilingInterface` using
8190
/// `scf.for` for iterating over the tiles.
82-
FailureOr<SCFTilingResult> tileUsingSCFForOp(RewriterBase &rewriter,
83-
TilingInterface op,
84-
const SCFTilingOptions &options);
85-
86-
/// Method to tile an op that implements the `TilingInterface` using
87-
/// `scf.forall`.
88-
FailureOr<SCFTilingResult>
89-
tileUsingSCFForallOp(RewriterBase &rewriter, TilingInterface op,
90-
const SCFTilingOptions &options);
91+
FailureOr<SCFTilingResult> tileUsingSCF(RewriterBase &rewriter,
92+
TilingInterface op,
93+
const SCFTilingOptions &options);
9194

9295
/// Options used to control tile + fuse.
9396
struct SCFTileAndFuseOptions {
@@ -135,7 +138,7 @@ struct SCFFuseProducerOfSliceResult {
135138
std::optional<SCFFuseProducerOfSliceResult>
136139
tileAndFuseProducerOfSlice(RewriterBase &rewriter,
137140
tensor::ExtractSliceOp candidateSliceOp,
138-
MutableArrayRef<scf::ForOp> loops);
141+
MutableArrayRef<LoopLikeOpInterface> loops);
139142

140143
/// Reconstruct the fused producer from within the tiled-and-fused code. Based
141144
/// on the slice of the producer computed in place it is possible that within
@@ -187,10 +190,10 @@ tileAndFuseProducerOfSlice(RewriterBase &rewriter,
187190
/// where `%0` had other uses as well. If not reconstructed from within the loop
188191
/// body, uses of `%0` could not be replaced, making it still live and the
189192
/// fusion immaterial.
190-
void yieldReplacementForFusedProducer(
193+
LogicalResult yieldReplacementForFusedProducer(
191194
RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp,
192195
scf::SCFFuseProducerOfSliceResult fusedProducerInfo,
193-
MutableArrayRef<scf::ForOp> loops);
196+
MutableArrayRef<LoopLikeOpInterface> loops);
194197

195198
/// Transformation information returned after tile and fuse.
196199
struct SCFTileAndFuseResult {
@@ -201,7 +204,7 @@ struct SCFTileAndFuseResult {
201204
/// generated operation.
202205
llvm::SetVector<Operation *> tiledAndFusedOps;
203206
/// The `scf.for` operations that iterate over the tiles.
204-
SmallVector<Operation *> loops;
207+
SmallVector<LoopLikeOpInterface> loops;
205208
/// The replacement values to use for the tiled and fused operations.
206209
llvm::DenseMap<Value, Value> replacements;
207210
};
@@ -232,9 +235,9 @@ struct SCFTileAndFuseResult {
232235
/// }
233236
/// ```
234237
FailureOr<SCFTileAndFuseResult>
235-
tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
236-
RewriterBase &rewriter, TilingInterface consumer,
237-
const SCFTileAndFuseOptions &options);
238+
tileConsumerAndFuseProducersUsingSCF(RewriterBase &rewriter,
239+
TilingInterface consumer,
240+
const SCFTileAndFuseOptions &options);
238241

239242
/// Method to lower an `op` that implements the `TilingInterface` to
240243
/// loops/scalars.
@@ -249,8 +252,8 @@ struct SCFReductionTilingResult {
249252
Operation *mergeOp;
250253
/// Initial op
251254
Operation *initialOp;
252-
/// The `scf.for` operations that iterate over the tiles.
253-
SmallVector<scf::ForOp> loops;
255+
/// The loop operations that iterate over the tiles.
256+
SmallVector<LoopLikeOpInterface> loops;
254257
};
255258

256259
/// Method to tile a reduction and generate a parallel op within a serial loop.

mlir/include/mlir/Interfaces/LoopLikeInterface.h

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,28 @@ class RewriterBase;
2525
using NewYieldValuesFn = std::function<SmallVector<Value>(
2626
OpBuilder &b, Location loc, ArrayRef<BlockArgument> newBbArgs)>;
2727

28+
/// A function that allows returning additional yielded values during
29+
/// `yieldTiledValuesAndReplace`.
30+
/// - `ivs` induction variable for the loop.
31+
/// - `newBbArgs` basic block arguments corresponding to newly added iter_args.
32+
/// - `tiledValues` the tiled values to return. Must be of same size as
33+
/// `newbbArgs`, each element of this array is inserted into the corresponding
34+
/// element in `newbbArgs`.
35+
/// - `resultOffsets` is of the same size as `tiledValues` and represents
36+
/// the offsets to use when inserting corresponding element from `tiledValues`
37+
/// into the element from `newBbArgs`.
38+
/// - `resultSizes` is of the same size as `tiledValues` and represents
39+
/// the size of the corresponding element from `tiledValues` inserted into
40+
/// the element from `newBbArgs`.
41+
/// - `resultStrides` is of the same size as `tiledValues` and represents
42+
/// the strides to use when inserting corresponding element from `tiledValues`
43+
/// into the element from `newBbArgs`.
44+
using YieldTiledValuesFn = std::function<LogicalResult(
45+
RewriterBase &rewriter, Location loc, ValueRange ivs, ValueRange newBbArgs,
46+
SmallVector<Value> &tiledValues,
47+
SmallVector<SmallVector<OpFoldResult>> &resultOffsets,
48+
SmallVector<SmallVector<OpFoldResult>> &resultSizes)>;
49+
2850
namespace detail {
2951
/// Verify invariants of the LoopLikeOpInterface.
3052
LogicalResult verifyLoopLikeOpInterface(Operation *op);

mlir/include/mlir/Interfaces/LoopLikeInterface.td

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,19 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
218218
return ::mlir::failure();
219219
}]
220220
>,
221+
InterfaceMethod<[{
222+
TODO
223+
}],
224+
/*retTy=*/"::mlir::FailureOr<::mlir::LoopLikeOpInterface>",
225+
/*methodName=*/"yieldTiledValuesAndReplace",
226+
/*args=*/(ins "::mlir::RewriterBase &":$rewriter,
227+
"::mlir::ValueRange":$newInitOperands,
228+
"const ::mlir::YieldTiledValuesFn &":$yieldTiledValuesFn),
229+
/*methodBody=*/"",
230+
/*defaultImplementation=*/[{
231+
return ::mlir::failure();
232+
}]
233+
>,
221234
];
222235

223236
let extraClassDeclaration = [{

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -485,8 +485,8 @@ transform::FuseOp::apply(transform::TransformRewriter &rewriter,
485485
tileSizes.size() - llvm::count(tileSizes, 0), transformResults,
486486
[&](TilingInterface tilingInterfaceOp)
487487
-> FailureOr<scf::SCFTileAndFuseResult> {
488-
return tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
489-
rewriter, tilingInterfaceOp, tileAndFuseOptions);
488+
return tileConsumerAndFuseProducersUsingSCF(rewriter, tilingInterfaceOp,
489+
tileAndFuseOptions);
490490
});
491491
return failed(result) ? DiagnosedSilenceableFailure::definiteFailure()
492492
: DiagnosedSilenceableFailure::success();
@@ -584,7 +584,7 @@ static Operation *replaceForAllWithNewSignature(
584584
Operation *firstYieldOp = yieldingOps.front();
585585
rewriter.setInsertionPoint(firstYieldOp);
586586
Value src = tileAndFuseResult.tiledValues[0];
587-
Value dst = newforallOp.getOutputBlockArguments().back();
587+
Value dst = newforallOp.getRegionIterArgs().back();
588588
SmallVector<OpFoldResult> strides(offsets.size(), rewriter.getIndexAttr(1));
589589
rewriter.create<tensor::ParallelInsertSliceOp>(firstYieldOp->getLoc(), src,
590590
dst, offsets, sizes, strides);
@@ -2063,7 +2063,7 @@ transform::ScalarizeOp::applyToOne(transform::TransformRewriter &rewriter,
20632063
});
20642064
SmallVector<int64_t> emptyTileSizes;
20652065
rewriter.setInsertionPoint(target);
2066-
FailureOr<scf::SCFTilingResult> maybeTilingResult = tileUsingSCFForOp(
2066+
FailureOr<scf::SCFTilingResult> maybeTilingResult = tileUsingSCF(
20672067
rewriter, cast<TilingInterface>(target.getOperation()), tilingOptions);
20682068
if (failed(maybeTilingResult))
20692069
return emitDefaultDefiniteFailure(target);
@@ -2647,7 +2647,7 @@ transform::TileUsingForOp::apply(transform::TransformRewriter &rewriter,
26472647

26482648
tilingOptions.setInterchange(getInterchange());
26492649
FailureOr<scf::SCFTilingResult> maybeTilingResult =
2650-
tileUsingSCFForOp(rewriter, tilingInterface, tilingOptions);
2650+
tileUsingSCF(rewriter, tilingInterface, tilingOptions);
26512651
if (failed(maybeTilingResult))
26522652
return DiagnosedSilenceableFailure::definiteFailure();
26532653

mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -358,7 +358,7 @@ static FailureOr<ForallTilingResult> tileToForallOpImpl(
358358

359359
// 3. Clone the tileable op and update its destination operands to use the
360360
// output bbArgs of the ForallOp.
361-
ArrayRef<BlockArgument> destBbArgs = forallOp.getOutputBlockArguments();
361+
ArrayRef<BlockArgument> destBbArgs = forallOp.getRegionIterArgs();
362362
Operation *tiledOp = nullptr;
363363
SmallVector<Value> tiledValues;
364364
{
@@ -695,7 +695,7 @@ FailureOr<linalg::ForallReductionTilingResult> linalg::tileReductionUsingForall(
695695
// 4. Clone the tileable op and update its destination operands to use the
696696
// output bbArgs of the ForallOp.
697697
SmallVector<Value> tilingResults;
698-
ArrayRef<BlockArgument> destBbArgs = forallOp.getOutputBlockArguments();
698+
ArrayRef<BlockArgument> destBbArgs = forallOp.getRegionIterArgs();
699699
{
700700
// 4.a. RAII guard, inserting within forallOp, before terminator.
701701
OpBuilder::InsertionGuard g(b);

0 commit comments

Comments
 (0)