Skip to content

Commit 76ead96

Browse files
[mlir][TilingInterface] Use LoopLikeOpInterface in tiling using SCF to unify tiling with scf.for and scf.forall. (#77874)
Using `LoopLikeOpInterface` as the basis for the implementation unifies all the tiling logic for both `scf.for` and `scf.forall`. The only difference is the actual loop generation. This is a follow-up to #72178. Instead of many entry points for each loop type, the loop type is now passed as part of the options passed to the tiling method. This is a breaking change with the following changes: 1) The `scf::tileUsingSCFForOp` is renamed to `scf::tileUsingSCF`. 2) The `scf::tileUsingSCFForallOp` is deprecated. The same functionality is obtained by using `scf::tileUsingSCF` and setting the loop type in `scf::SCFTilingOptions` passed into this method to `scf::SCFTilingOptions::LoopType::ForallOp` (using the `setLoopType` method). 3) The `scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp` is renamed to `scf::tileConsumerAndFuseProducersUsingSCF`. The use of the `controlFn` in `scf::SCFTileAndFuseOptions` allows implementing any strategy, with the default callback implementing the greedy fusion. 4) The `scf::SCFTilingResult` and `scf::SCFTileAndFuseResult` now use `SmallVector<LoopLikeOpInterface>`. 5) To make `scf::ForallOp` implement the parts of `LoopLikeOpInterface` needed, the `getOutputBlockArguments()` method is replaced with `getRegionIterArgs()`. These changes now bring the tiling and fusion capabilities using `scf.forall` on par with what was already supported by `scf.for`.
1 parent 3c9f34c commit 76ead96

28 files changed

+1236
-549
lines changed

flang/lib/Optimizer/Dialect/FIROps.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1935,7 +1935,7 @@ mlir::Value fir::IterWhileOp::blockArgToSourceOp(unsigned blockArgNum) {
19351935
return {};
19361936
}
19371937

1938-
llvm::MutableArrayRef<mlir::OpOperand>
1938+
std::optional<llvm::MutableArrayRef<mlir::OpOperand>>
19391939
fir::IterWhileOp::getYieldedValuesMutable() {
19401940
auto *term = getRegion().front().getTerminator();
19411941
return getFinalValue() ? term->getOpOperands().drop_front()
@@ -2247,7 +2247,7 @@ mlir::Value fir::DoLoopOp::blockArgToSourceOp(unsigned blockArgNum) {
22472247
return {};
22482248
}
22492249

2250-
llvm::MutableArrayRef<mlir::OpOperand>
2250+
std::optional<llvm::MutableArrayRef<mlir::OpOperand>>
22512251
fir::DoLoopOp::getYieldedValuesMutable() {
22522252
auto *term = getRegion().front().getTerminator();
22532253
return getFinalValue() ? term->getOpOperands().drop_front()

mlir/include/mlir/Dialect/SCF/IR/SCFOps.td

Lines changed: 8 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -135,10 +135,11 @@ def ExecuteRegionOp : SCF_Op<"execute_region", [
135135

136136
def ForOp : SCF_Op<"for",
137137
[AutomaticAllocationScope, DeclareOpInterfaceMethods<LoopLikeOpInterface,
138-
["getInitsMutable", "getSingleInductionVar", "getSingleLowerBound",
139-
"getSingleStep", "getSingleUpperBound", "getYieldedValuesMutable",
140-
"getLoopResults", "promoteIfSingleIteration",
141-
"replaceWithAdditionalYields"]>,
138+
["getInitsMutable", "getLoopResults", "getRegionIterArgs",
139+
"getSingleInductionVar", "getSingleLowerBound", "getSingleStep",
140+
"getSingleUpperBound", "getYieldedValuesMutable",
141+
"promoteIfSingleIteration", "replaceWithAdditionalYields",
142+
"yieldTiledValuesAndReplace"]>,
142143
AllTypesMatch<["lowerBound", "upperBound", "step"]>,
143144
ConditionallySpeculatable,
144145
DeclareOpInterfaceMethods<RegionBranchOpInterface,
@@ -259,10 +260,6 @@ def ForOp : SCF_Op<"for",
259260

260261
Value getInductionVar() { return getBody()->getArgument(0); }
261262

262-
Block::BlockArgListType getRegionIterArgs() {
263-
return getBody()->getArguments().drop_front(getNumInductionVars());
264-
}
265-
266263
/// Return the `index`-th region iteration argument.
267264
BlockArgument getRegionIterArg(unsigned index) {
268265
assert(index < getNumRegionIterArgs() &&
@@ -304,8 +301,9 @@ def ForallOp : SCF_Op<"forall", [
304301
AttrSizedOperandSegments,
305302
AutomaticAllocationScope,
306303
DeclareOpInterfaceMethods<LoopLikeOpInterface,
307-
["promoteIfSingleIteration", "getSingleInductionVar",
308-
"getSingleLowerBound", "getSingleUpperBound", "getSingleStep"]>,
304+
["getInitsMutable", "getRegionIterArgs", "getSingleInductionVar",
305+
"getSingleLowerBound", "getSingleUpperBound", "getSingleStep",
306+
"promoteIfSingleIteration", "yieldTiledValuesAndReplace"]>,
309307
RecursiveMemoryEffects,
310308
SingleBlockImplicitTerminator<"scf::InParallelOp">,
311309
DeclareOpInterfaceMethods<RegionBranchOpInterface>,
@@ -585,10 +583,6 @@ def ForallOp : SCF_Op<"forall", [
585583
getNumDynamicControlOperands() + getRank());
586584
}
587585

588-
ArrayRef<BlockArgument> getOutputBlockArguments() {
589-
return getBody()->getArguments().drop_front(getRank());
590-
}
591-
592586
::mlir::ValueRange getInductionVars() {
593587
return getBody()->getArguments().take_front(getRank());
594588
}

mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h

Lines changed: 22 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include "mlir/Dialect/SCF/IR/SCF.h"
1313
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
1414
#include "mlir/IR/PatternMatch.h"
15+
#include "mlir/Interfaces/LoopLikeInterface.h"
1516
#include "mlir/Interfaces/TilingInterface.h"
1617

1718
#include <deque>
@@ -52,6 +53,14 @@ struct SCFTilingOptions {
5253
return *this;
5354
}
5455

56+
/// Specify which loop construct to use for tile and fuse.
57+
enum class LoopType { ForOp, ForallOp };
58+
LoopType loopType = LoopType::ForOp;
59+
SCFTilingOptions &setLoopType(LoopType type) {
60+
loopType = type;
61+
return *this;
62+
}
63+
5564
/// Specify mapping of loops to devices. This is only respected when the loop
5665
/// constructs support such a mapping (like `scf.forall`). Will be ignored
5766
/// when using loop constructs that don't support such a mapping (like
@@ -71,23 +80,17 @@ struct SCFTilingResult {
7180
/// of the last op.
7281
SmallVector<Operation *> tiledOps;
7382
/// The loop operations that iterate over the tiles.
74-
SmallVector<Operation *> loops;
83+
SmallVector<LoopLikeOpInterface> loops;
7584
/// Values to use as replacements for the untiled op. Is the same size as the
7685
/// number of results of the untiled op.
7786
SmallVector<Value> replacements;
7887
};
7988

8089
/// Method to tile an op that implements the `TilingInterface` using
8190
/// the loop construct selected in `options` (`scf.for` by default) for iterating over the tiles.
82-
FailureOr<SCFTilingResult> tileUsingSCFForOp(RewriterBase &rewriter,
83-
TilingInterface op,
84-
const SCFTilingOptions &options);
85-
86-
/// Method to tile an op that implements the `TilingInterface` using
87-
/// `scf.forall`.
88-
FailureOr<SCFTilingResult>
89-
tileUsingSCFForallOp(RewriterBase &rewriter, TilingInterface op,
90-
const SCFTilingOptions &options);
91+
FailureOr<SCFTilingResult> tileUsingSCF(RewriterBase &rewriter,
92+
TilingInterface op,
93+
const SCFTilingOptions &options);
9194

9295
/// Options used to control tile + fuse.
9396
struct SCFTileAndFuseOptions {
@@ -135,7 +138,7 @@ struct SCFFuseProducerOfSliceResult {
135138
std::optional<SCFFuseProducerOfSliceResult>
136139
tileAndFuseProducerOfSlice(RewriterBase &rewriter,
137140
tensor::ExtractSliceOp candidateSliceOp,
138-
MutableArrayRef<scf::ForOp> loops);
141+
MutableArrayRef<LoopLikeOpInterface> loops);
139142

140143
/// Reconstruct the fused producer from within the tiled-and-fused code. Based
141144
/// on the slice of the producer computed in place it is possible that within
@@ -187,10 +190,10 @@ tileAndFuseProducerOfSlice(RewriterBase &rewriter,
187190
/// where `%0` had other uses as well. If not reconstructed from within the loop
188191
/// body, uses of `%0` could not be replaced, making it still live and the
189192
/// fusion immaterial.
190-
void yieldReplacementForFusedProducer(
193+
LogicalResult yieldReplacementForFusedProducer(
191194
RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp,
192195
scf::SCFFuseProducerOfSliceResult fusedProducerInfo,
193-
MutableArrayRef<scf::ForOp> loops);
196+
MutableArrayRef<LoopLikeOpInterface> loops);
194197

195198
/// Transformation information returned after tile and fuse.
196199
struct SCFTileAndFuseResult {
@@ -201,7 +204,7 @@ struct SCFTileAndFuseResult {
201204
/// generated operation.
202205
llvm::SetVector<Operation *> tiledAndFusedOps;
203206
/// The loop operations that iterate over the tiles.
204-
SmallVector<Operation *> loops;
207+
SmallVector<LoopLikeOpInterface> loops;
205208
/// The replacement values to use for the tiled and fused operations.
206209
llvm::DenseMap<Value, Value> replacements;
207210
};
@@ -232,9 +235,9 @@ struct SCFTileAndFuseResult {
232235
/// }
233236
/// ```
234237
FailureOr<SCFTileAndFuseResult>
235-
tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
236-
RewriterBase &rewriter, TilingInterface consumer,
237-
const SCFTileAndFuseOptions &options);
238+
tileConsumerAndFuseProducersUsingSCF(RewriterBase &rewriter,
239+
TilingInterface consumer,
240+
const SCFTileAndFuseOptions &options);
238241

239242
/// Method to lower an `op` that implements the `TilingInterface` to
240243
/// loops/scalars.
@@ -249,8 +252,8 @@ struct SCFReductionTilingResult {
249252
Operation *mergeOp;
250253
/// Initial op
251254
Operation *initialOp;
252-
/// The `scf.for` operations that iterate over the tiles.
253-
SmallVector<scf::ForOp> loops;
255+
/// The loop operations that iterate over the tiles.
256+
SmallVector<LoopLikeOpInterface> loops;
254257
};
255258

256259
/// Method to tile a reduction and generate a parallel op within a serial loop.

mlir/include/mlir/Interfaces/LoopLikeInterface.td

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -164,13 +164,16 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
164164
InterfaceMethod<[{
165165
Return the mutable operand range of values that are yielded to the next
166166
iteration by the loop terminator.
167+
168+
For loop operations that don't yield a value, this should return
169+
std::nullopt.
167170
}],
168-
/*retTy=*/"::llvm::MutableArrayRef<::mlir::OpOperand>",
171+
/*retTy=*/"std::optional<::llvm::MutableArrayRef<::mlir::OpOperand>>",
169172
/*methodName=*/"getYieldedValuesMutable",
170173
/*args=*/(ins),
171174
/*methodBody=*/"",
172175
/*defaultImplementation=*/[{
173-
return {};
176+
return std::nullopt;
174177
}]
175178
>,
176179
InterfaceMethod<[{
@@ -217,7 +220,7 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
217220
/*defaultImplementation=*/[{
218221
return ::mlir::failure();
219222
}]
220-
>,
223+
>
221224
];
222225

223226
let extraClassDeclaration = [{
@@ -244,16 +247,17 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
244247
});
245248
}
246249

247-
/// Return the values that are yielded to the next iteration.
250+
/// Return the values that are yielded to the next iteration. If
251+
/// the loop doesn't yield any values, return `{}`.
248252
::mlir::ValueRange getYieldedValues() {
249253
auto mutableValues = $_op.getYieldedValuesMutable();
250-
if (mutableValues.empty())
254+
if (!mutableValues || mutableValues->empty())
251255
return {};
252-
Operation *yieldOp = mutableValues.begin()->getOwner();
253-
unsigned firstOperandIndex = mutableValues.begin()->getOperandNumber();
256+
Operation *yieldOp = mutableValues->begin()->getOwner();
257+
unsigned firstOperandIndex = mutableValues->begin()->getOperandNumber();
254258
return OperandRange(
255259
yieldOp->operand_begin() + firstOperandIndex,
256-
yieldOp->operand_begin() + firstOperandIndex + mutableValues.size());
260+
yieldOp->operand_begin() + firstOperandIndex + mutableValues->size());
257261
}
258262

259263
/// Return the "init" operands that are used as initialization values for
@@ -318,14 +322,17 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
318322

319323
/// Return the yielded value that corresponds to the given region iter_arg.
320324
/// Return "nullptr" if the given block argument is not a region iter_arg
321-
/// of this loop op.
325+
/// of this loop op or if there is no yield corresponding to this `bbArg`.
322326
OpOperand *getTiedLoopYieldedValue(BlockArgument bbArg) {
323327
auto iterArgs = $_op.getRegionIterArgs();
324328
auto it = llvm::find(iterArgs, bbArg);
325329
if (it == iterArgs.end())
326330
return {};
327-
return
328-
&$_op.getYieldedValuesMutable()[std::distance(iterArgs.begin(), it)];
331+
std::optional<llvm::MutableArrayRef<::mlir::OpOperand>> yieldValues =
332+
$_op.getYieldedValuesMutable();
333+
if (!yieldValues)
334+
return {};
335+
return &yieldValues.value()[std::distance(iterArgs.begin(), it)];
329336
}
330337

331338
/// Return the loop result that corresponds to the given init operand.

mlir/lib/Dialect/Affine/IR/AffineOps.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2127,7 +2127,8 @@ unsigned AffineForOp::getNumIterOperands() {
21272127
return getNumOperands() - lbMap.getNumInputs() - ubMap.getNumInputs();
21282128
}
21292129

2130-
MutableArrayRef<OpOperand> AffineForOp::getYieldedValuesMutable() {
2130+
std::optional<MutableArrayRef<OpOperand>>
2131+
AffineForOp::getYieldedValuesMutable() {
21312132
return cast<AffineYieldOp>(getBody()->getTerminator()).getOperandsMutable();
21322133
}
21332134

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -489,8 +489,8 @@ transform::FuseOp::apply(transform::TransformRewriter &rewriter,
489489
tileSizes.size() - llvm::count(tileSizes, 0), transformResults,
490490
[&](TilingInterface tilingInterfaceOp)
491491
-> FailureOr<scf::SCFTileAndFuseResult> {
492-
return tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
493-
rewriter, tilingInterfaceOp, tileAndFuseOptions);
492+
return tileConsumerAndFuseProducersUsingSCF(rewriter, tilingInterfaceOp,
493+
tileAndFuseOptions);
494494
});
495495
return failed(result) ? DiagnosedSilenceableFailure::definiteFailure()
496496
: DiagnosedSilenceableFailure::success();
@@ -588,7 +588,7 @@ static Operation *replaceForAllWithNewSignature(
588588
Operation *firstYieldOp = yieldingOps.front();
589589
rewriter.setInsertionPoint(firstYieldOp);
590590
Value src = tileAndFuseResult.tiledValues[0];
591-
Value dst = newforallOp.getOutputBlockArguments().back();
591+
Value dst = newforallOp.getRegionIterArgs().back();
592592
SmallVector<OpFoldResult> strides(offsets.size(), rewriter.getIndexAttr(1));
593593
rewriter.create<tensor::ParallelInsertSliceOp>(firstYieldOp->getLoc(), src,
594594
dst, offsets, sizes, strides);
@@ -2067,7 +2067,7 @@ transform::ScalarizeOp::applyToOne(transform::TransformRewriter &rewriter,
20672067
});
20682068
SmallVector<int64_t> emptyTileSizes;
20692069
rewriter.setInsertionPoint(target);
2070-
FailureOr<scf::SCFTilingResult> maybeTilingResult = tileUsingSCFForOp(
2070+
FailureOr<scf::SCFTilingResult> maybeTilingResult = tileUsingSCF(
20712071
rewriter, cast<TilingInterface>(target.getOperation()), tilingOptions);
20722072
if (failed(maybeTilingResult))
20732073
return emitDefaultDefiniteFailure(target);
@@ -2651,7 +2651,7 @@ transform::TileUsingForOp::apply(transform::TransformRewriter &rewriter,
26512651

26522652
tilingOptions.setInterchange(getInterchange());
26532653
FailureOr<scf::SCFTilingResult> maybeTilingResult =
2654-
tileUsingSCFForOp(rewriter, tilingInterface, tilingOptions);
2654+
tileUsingSCF(rewriter, tilingInterface, tilingOptions);
26552655
if (failed(maybeTilingResult))
26562656
return DiagnosedSilenceableFailure::definiteFailure();
26572657

mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -358,7 +358,7 @@ static FailureOr<ForallTilingResult> tileToForallOpImpl(
358358

359359
// 3. Clone the tileable op and update its destination operands to use the
360360
// output bbArgs of the ForallOp.
361-
ArrayRef<BlockArgument> destBbArgs = forallOp.getOutputBlockArguments();
361+
ArrayRef<BlockArgument> destBbArgs = forallOp.getRegionIterArgs();
362362
Operation *tiledOp = nullptr;
363363
SmallVector<Value> tiledValues;
364364
{
@@ -695,7 +695,7 @@ FailureOr<linalg::ForallReductionTilingResult> linalg::tileReductionUsingForall(
695695
// 4. Clone the tileable op and update its destination operands to use the
696696
// output bbArgs of the ForallOp.
697697
SmallVector<Value> tilingResults;
698-
ArrayRef<BlockArgument> destBbArgs = forallOp.getOutputBlockArguments();
698+
ArrayRef<BlockArgument> destBbArgs = forallOp.getRegionIterArgs();
699699
{
700700
// 4.a. RAII guard, inserting within forallOp, before terminator.
701701
OpBuilder::InsertionGuard g(b);

mlir/lib/Dialect/SCF/IR/SCF.cpp

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -523,6 +523,10 @@ ParseResult ForOp::parse(OpAsmParser &parser, OperationState &result) {
523523

524524
SmallVector<Region *> ForOp::getLoopRegions() { return {&getRegion()}; }
525525

526+
Block::BlockArgListType ForOp::getRegionIterArgs() {
527+
return getBody()->getArguments().drop_front(getNumInductionVars());
528+
}
529+
526530
MutableArrayRef<OpOperand> ForOp::getInitsMutable() {
527531
return getInitArgsMutable();
528532
}
@@ -618,6 +622,14 @@ LogicalResult scf::ForallOp::promoteIfSingleIteration(RewriterBase &rewriter) {
618622
return success();
619623
}
620624

625+
Block::BlockArgListType ForallOp::getRegionIterArgs() {
626+
return getBody()->getArguments().drop_front(getRank());
627+
}
628+
629+
MutableArrayRef<OpOperand> ForallOp::getInitsMutable() {
630+
return getOutputsMutable();
631+
}
632+
621633
/// Promotes the loop body of a scf::ForallOp to its containing block.
622634
void mlir::scf::promote(RewriterBase &rewriter, scf::ForallOp forallOp) {
623635
OpBuilder::InsertionGuard g(rewriter);
@@ -1092,7 +1104,7 @@ std::optional<APInt> ForOp::getConstantStep() {
10921104
return {};
10931105
}
10941106

1095-
MutableArrayRef<OpOperand> ForOp::getYieldedValuesMutable() {
1107+
std::optional<MutableArrayRef<OpOperand>> ForOp::getYieldedValuesMutable() {
10961108
return cast<scf::YieldOp>(getBody()->getTerminator()).getResultsMutable();
10971109
}
10981110

@@ -1351,11 +1363,6 @@ void ForallOp::build(
13511363
return;
13521364
}
13531365
bodyBuilderFn(b, result.location, bodyBlock.getArguments());
1354-
#ifndef NDEBUG
1355-
auto terminator = llvm::dyn_cast<InParallelOp>(bodyBlock.getTerminator());
1356-
assert(terminator &&
1357-
"expected bodyBuilderFn to create InParallelOp terminator");
1358-
#endif // NDEBUG
13591366
}
13601367

13611368
// Builder that takes loop bounds.
@@ -1626,9 +1633,8 @@ struct FoldTensorCastOfOutputIntoForallOp
16261633
// mapped to the tensor.cast old-typed results of the output bbArgs. The
16271634
// destination have to be updated to point to the output bbArgs directly.
16281635
auto terminator = newForallOp.getTerminator();
1629-
for (auto [yieldingOp, outputBlockArg] :
1630-
llvm::zip(terminator.getYieldingOps(),
1631-
newForallOp.getOutputBlockArguments())) {
1636+
for (auto [yieldingOp, outputBlockArg] : llvm::zip(
1637+
terminator.getYieldingOps(), newForallOp.getRegionIterArgs())) {
16321638
auto insertSliceOp = cast<tensor::ParallelInsertSliceOp>(yieldingOp);
16331639
insertSliceOp.getDestMutable().assign(outputBlockArg);
16341640
}
@@ -3108,7 +3114,7 @@ YieldOp WhileOp::getYieldOp() {
31083114
return cast<YieldOp>(getAfterBody()->getTerminator());
31093115
}
31103116

3111-
MutableArrayRef<OpOperand> WhileOp::getYieldedValuesMutable() {
3117+
std::optional<MutableArrayRef<OpOperand>> WhileOp::getYieldedValuesMutable() {
31123118
return getYieldOp().getResultsMutable();
31133119
}
31143120

0 commit comments

Comments
 (0)