-
Notifications
You must be signed in to change notification settings - Fork 13.6k
[mlir][TilingInterface] Use LoopLikeOpInterface
in tiling using SCF to unify tiling with scf.for
and scf.forall
.
#77874
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
baaf648
2a55c1b
8b53308
edd6d90
e59bc5c
660da82
e4aa097
a54da70
67bc728
0eb979f
413192c
1634eff
46f591f
2a2583c
5d2f0e9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -135,10 +135,11 @@ def ExecuteRegionOp : SCF_Op<"execute_region", [ | |
|
||
def ForOp : SCF_Op<"for", | ||
[AutomaticAllocationScope, DeclareOpInterfaceMethods<LoopLikeOpInterface, | ||
["getInitsMutable", "getSingleInductionVar", "getSingleLowerBound", | ||
"getSingleStep", "getSingleUpperBound", "getYieldedValuesMutable", | ||
"getLoopResults", "promoteIfSingleIteration", | ||
"replaceWithAdditionalYields"]>, | ||
["getInitsMutable", "getLoopResults", "getRegionIterArgs", | ||
"getSingleInductionVar", "getSingleLowerBound", "getSingleStep", | ||
"getSingleUpperBound", "getYieldedValuesMutable", | ||
"promoteIfSingleIteration", "replaceWithAdditionalYields", | ||
"yieldTiledValuesAndReplace"]>, | ||
AllTypesMatch<["lowerBound", "upperBound", "step"]>, | ||
ConditionallySpeculatable, | ||
DeclareOpInterfaceMethods<RegionBranchOpInterface, | ||
|
@@ -259,10 +260,6 @@ def ForOp : SCF_Op<"for", | |
|
||
Value getInductionVar() { return getBody()->getArgument(0); } | ||
|
||
Block::BlockArgListType getRegionIterArgs() { | ||
return getBody()->getArguments().drop_front(getNumInductionVars()); | ||
} | ||
|
||
/// Return the `index`-th region iteration argument. | ||
BlockArgument getRegionIterArg(unsigned index) { | ||
assert(index < getNumRegionIterArgs() && | ||
|
@@ -304,8 +301,9 @@ def ForallOp : SCF_Op<"forall", [ | |
AttrSizedOperandSegments, | ||
AutomaticAllocationScope, | ||
DeclareOpInterfaceMethods<LoopLikeOpInterface, | ||
["promoteIfSingleIteration", "getSingleInductionVar", | ||
"getSingleLowerBound", "getSingleUpperBound", "getSingleStep"]>, | ||
["getInitsMutable", "getRegionIterArgs", "getSingleInductionVar", | ||
MaheshRavishankar marked this conversation as resolved.
Show resolved
Hide resolved
|
||
"getSingleLowerBound", "getSingleUpperBound", "getSingleStep", | ||
"promoteIfSingleIteration", "yieldTiledValuesAndReplace"]>, | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. drop yieldTiledValuesAndReplace |
||
RecursiveMemoryEffects, | ||
SingleBlockImplicitTerminator<"scf::InParallelOp">, | ||
DeclareOpInterfaceMethods<RegionBranchOpInterface>, | ||
|
@@ -585,10 +583,6 @@ def ForallOp : SCF_Op<"forall", [ | |
getNumDynamicControlOperands() + getRank()); | ||
} | ||
|
||
ArrayRef<BlockArgument> getOutputBlockArguments() { | ||
return getBody()->getArguments().drop_front(getRank()); | ||
} | ||
|
||
::mlir::ValueRange getInductionVars() { | ||
return getBody()->getArguments().take_front(getRank()); | ||
} | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -12,6 +12,7 @@ | |
#include "mlir/Dialect/SCF/IR/SCF.h" | ||
#include "mlir/Dialect/Tensor/Transforms/Transforms.h" | ||
#include "mlir/IR/PatternMatch.h" | ||
#include "mlir/Interfaces/LoopLikeInterface.h" | ||
#include "mlir/Interfaces/TilingInterface.h" | ||
|
||
#include <deque> | ||
|
@@ -52,6 +53,14 @@ struct SCFTilingOptions { | |
return *this; | ||
} | ||
|
||
/// Specify which loop construct to use for tile and fuse. | ||
enum class LoopType { ForOp, ForallOp }; | ||
LoopType loopType = LoopType::ForOp; | ||
/// Stores `type` in `loopType` and returns a reference to this options | ||
/// object. | ||
SCFTilingOptions &setLoopType(LoopType type) { | ||
loopType = type; | ||
return *this; | ||
} | ||
|
||
/// Specify mapping of loops to devices. This is only respected when the loop | ||
/// constructs support such a mapping (like `scf.forall`). Will be ignored | ||
/// when using loop constructs that don't support such a mapping (like | ||
|
@@ -71,23 +80,17 @@ struct SCFTilingResult { | |
/// of the last op. | ||
SmallVector<Operation *> tiledOps; | ||
/// The loop operations that iterate over the tiles. | ||
SmallVector<Operation *> loops; | ||
SmallVector<LoopLikeOpInterface> loops; | ||
/// Values to use as replacements for the untiled op. Is the same size as the | ||
/// number of results of the untiled op. | ||
SmallVector<Value> replacements; | ||
}; | ||
|
||
/// Method to tile an op that implements the `TilingInterface` using | ||
/// `scf.for` for iterating over the tiles. | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Comment is outdated: This function may generate `scf.for` or `scf.forall` loops. |
||
FailureOr<SCFTilingResult> tileUsingSCFForOp(RewriterBase &rewriter, | ||
TilingInterface op, | ||
const SCFTilingOptions &options); | ||
|
||
/// Method to tile an op that implements the `TilingInterface` using | ||
/// `scf.forall`. | ||
FailureOr<SCFTilingResult> | ||
tileUsingSCFForallOp(RewriterBase &rewriter, TilingInterface op, | ||
const SCFTilingOptions &options); | ||
FailureOr<SCFTilingResult> tileUsingSCF(RewriterBase &rewriter, | ||
TilingInterface op, | ||
const SCFTilingOptions &options); | ||
|
||
/// Options used to control tile + fuse. | ||
struct SCFTileAndFuseOptions { | ||
|
@@ -135,7 +138,7 @@ struct SCFFuseProducerOfSliceResult { | |
std::optional<SCFFuseProducerOfSliceResult> | ||
tileAndFuseProducerOfSlice(RewriterBase &rewriter, | ||
tensor::ExtractSliceOp candidateSliceOp, | ||
MutableArrayRef<scf::ForOp> loops); | ||
MutableArrayRef<LoopLikeOpInterface> loops); | ||
|
||
/// Reconstruct the fused producer from within the tiled-and-fused code. Based | ||
/// on the slice of the producer computed in place it is possible that within | ||
|
@@ -187,10 +190,10 @@ tileAndFuseProducerOfSlice(RewriterBase &rewriter, | |
/// where `%0` had other uses as well. If not reconstructed from within the loop | ||
/// body, uses of `%0` could not be replaced, making it still live and the | ||
/// fusion immaterial. | ||
void yieldReplacementForFusedProducer( | ||
LogicalResult yieldReplacementForFusedProducer( | ||
MaheshRavishankar marked this conversation as resolved.
Show resolved
Hide resolved
|
||
RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp, | ||
scf::SCFFuseProducerOfSliceResult fusedProducerInfo, | ||
MutableArrayRef<scf::ForOp> loops); | ||
MutableArrayRef<LoopLikeOpInterface> loops); | ||
|
||
/// Transformation information returned after tile and fuse. | ||
struct SCFTileAndFuseResult { | ||
|
@@ -201,7 +204,7 @@ struct SCFTileAndFuseResult { | |
/// generated operation. | ||
llvm::SetVector<Operation *> tiledAndFusedOps; | ||
/// The loop operations that iterate over the tiles. | ||
SmallVector<Operation *> loops; | ||
SmallVector<LoopLikeOpInterface> loops; | ||
/// The replacement values to use for the tiled and fused operations. | ||
llvm::DenseMap<Value, Value> replacements; | ||
}; | ||
|
@@ -232,9 +235,9 @@ struct SCFTileAndFuseResult { | |
/// } | ||
/// ``` | ||
FailureOr<SCFTileAndFuseResult> | ||
tileConsumerAndFuseProducerGreedilyUsingSCFForOp( | ||
RewriterBase &rewriter, TilingInterface consumer, | ||
const SCFTileAndFuseOptions &options); | ||
tileConsumerAndFuseProducersUsingSCF(RewriterBase &rewriter, | ||
TilingInterface consumer, | ||
const SCFTileAndFuseOptions &options); | ||
|
||
/// Method to lower an `op` that implements the `TilingInterface` to | ||
/// loops/scalars. | ||
|
@@ -249,8 +252,8 @@ struct SCFReductionTilingResult { | |
Operation *mergeOp; | ||
/// Initial op | ||
Operation *initialOp; | ||
/// The `scf.for` operations that iterate over the tiles. | ||
SmallVector<scf::ForOp> loops; | ||
/// The loop operations that iterate over the tiles. | ||
SmallVector<LoopLikeOpInterface> loops; | ||
}; | ||
|
||
/// Method to tile a reduction and generate a parallel op within a serial loop. | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -489,8 +489,8 @@ transform::FuseOp::apply(transform::TransformRewriter &rewriter, | |
tileSizes.size() - llvm::count(tileSizes, 0), transformResults, | ||
[&](TilingInterface tilingInterfaceOp) | ||
-> FailureOr<scf::SCFTileAndFuseResult> { | ||
return tileConsumerAndFuseProducerGreedilyUsingSCFForOp( | ||
rewriter, tilingInterfaceOp, tileAndFuseOptions); | ||
return tileConsumerAndFuseProducersUsingSCF(rewriter, tilingInterfaceOp, | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. nit: is the 'UsingSCF' part of the name still load-bearing now that we don't distinguish between `scf.for` and `scf.forall`? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Well, good question. This is in the SCF dialect... So I wanted to leave the "SCF" part. I don't have a strong preference, but this is living in the SCF dialect. |
||
tileAndFuseOptions); | ||
}); | ||
return failed(result) ? DiagnosedSilenceableFailure::definiteFailure() | ||
: DiagnosedSilenceableFailure::success(); | ||
|
@@ -588,7 +588,7 @@ static Operation *replaceForAllWithNewSignature( | |
Operation *firstYieldOp = yieldingOps.front(); | ||
rewriter.setInsertionPoint(firstYieldOp); | ||
Value src = tileAndFuseResult.tiledValues[0]; | ||
Value dst = newforallOp.getOutputBlockArguments().back(); | ||
Value dst = newforallOp.getRegionIterArgs().back(); | ||
SmallVector<OpFoldResult> strides(offsets.size(), rewriter.getIndexAttr(1)); | ||
rewriter.create<tensor::ParallelInsertSliceOp>(firstYieldOp->getLoc(), src, | ||
dst, offsets, sizes, strides); | ||
|
@@ -2067,7 +2067,7 @@ transform::ScalarizeOp::applyToOne(transform::TransformRewriter &rewriter, | |
}); | ||
SmallVector<int64_t> emptyTileSizes; | ||
rewriter.setInsertionPoint(target); | ||
FailureOr<scf::SCFTilingResult> maybeTilingResult = tileUsingSCFForOp( | ||
FailureOr<scf::SCFTilingResult> maybeTilingResult = tileUsingSCF( | ||
rewriter, cast<TilingInterface>(target.getOperation()), tilingOptions); | ||
if (failed(maybeTilingResult)) | ||
return emitDefaultDefiniteFailure(target); | ||
|
@@ -2651,7 +2651,7 @@ transform::TileUsingForOp::apply(transform::TransformRewriter &rewriter, | |
|
||
tilingOptions.setInterchange(getInterchange()); | ||
FailureOr<scf::SCFTilingResult> maybeTilingResult = | ||
tileUsingSCFForOp(rewriter, tilingInterface, tilingOptions); | ||
tileUsingSCF(rewriter, tilingInterface, tilingOptions); | ||
if (failed(maybeTilingResult)) | ||
return DiagnosedSilenceableFailure::definiteFailure(); | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -523,6 +523,10 @@ ParseResult ForOp::parse(OpAsmParser &parser, OperationState &result) { | |
|
||
/// Returns a single-element list holding a pointer to this op's region. | ||
SmallVector<Region *> ForOp::getLoopRegions() { return {&getRegion()}; } | ||
|
||
/// Returns the body block's arguments with the leading | ||
/// `getNumInductionVars()` entries dropped. | ||
Block::BlockArgListType ForOp::getRegionIterArgs() { | ||
return getBody()->getArguments().drop_front(getNumInductionVars()); | ||
} | ||
|
||
/// Returns the mutable init operands by forwarding to | ||
/// `getInitArgsMutable()`. | ||
MutableArrayRef<OpOperand> ForOp::getInitsMutable() { | ||
return getInitArgsMutable(); | ||
} | ||
|
@@ -618,6 +622,14 @@ LogicalResult scf::ForallOp::promoteIfSingleIteration(RewriterBase &rewriter) { | |
return success(); | ||
} | ||
|
||
/// Returns the body block's arguments with the leading `getRank()` entries | ||
/// dropped. | ||
Block::BlockArgListType ForallOp::getRegionIterArgs() { | ||
return getBody()->getArguments().drop_front(getRank()); | ||
} | ||
|
||
MutableArrayRef<OpOperand> ForallOp::getInitsMutable() { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think you are missing a This is what the interface documentation says:
This is checked by the op verifier, so you should be seeing verification failures. The problem is that this op does not have yielded values. We could have There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. On second thought, There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Obviously that definition does not work for
I did. I fixed the verifier also, but I think as you note below,
We probably don't need that. The region iter args can still be tied to init. If there is no yield value, then the verifier can handle it appropriately. |
||
return getOutputsMutable(); | ||
} | ||
|
||
/// Promotes the loop body of a scf::ForallOp to its containing block. | ||
void mlir::scf::promote(RewriterBase &rewriter, scf::ForallOp forallOp) { | ||
OpBuilder::InsertionGuard g(rewriter); | ||
|
@@ -1092,7 +1104,7 @@ std::optional<APInt> ForOp::getConstantStep() { | |
return {}; | ||
} | ||
|
||
MutableArrayRef<OpOperand> ForOp::getYieldedValuesMutable() { | ||
std::optional<MutableArrayRef<OpOperand>> ForOp::getYieldedValuesMutable() { | ||
return cast<scf::YieldOp>(getBody()->getTerminator()).getResultsMutable(); | ||
} | ||
|
||
|
@@ -1351,11 +1363,6 @@ void ForallOp::build( | |
return; | ||
} | ||
bodyBuilderFn(b, result.location, bodyBlock.getArguments()); | ||
#ifndef NDEBUG | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Was this dropped on purpose? |
||
auto terminator = llvm::dyn_cast<InParallelOp>(bodyBlock.getTerminator()); | ||
assert(terminator && | ||
"expected bodyBuilderFn to create InParallelOp terminator"); | ||
#endif // NDEBUG | ||
} | ||
|
||
// Builder that takes loop bounds. | ||
|
@@ -1626,9 +1633,8 @@ struct FoldTensorCastOfOutputIntoForallOp | |
// mapped to the tensor.cast old-typed results of the output bbArgs. The | ||
// destination have to be updated to point to the output bbArgs directly. | ||
auto terminator = newForallOp.getTerminator(); | ||
for (auto [yieldingOp, outputBlockArg] : | ||
llvm::zip(terminator.getYieldingOps(), | ||
newForallOp.getOutputBlockArguments())) { | ||
for (auto [yieldingOp, outputBlockArg] : llvm::zip( | ||
MaheshRavishankar marked this conversation as resolved.
Show resolved
Hide resolved
|
||
terminator.getYieldingOps(), newForallOp.getRegionIterArgs())) { | ||
auto insertSliceOp = cast<tensor::ParallelInsertSliceOp>(yieldingOp); | ||
insertSliceOp.getDestMutable().assign(outputBlockArg); | ||
} | ||
|
@@ -3108,7 +3114,7 @@ YieldOp WhileOp::getYieldOp() { | |
return cast<YieldOp>(getAfterBody()->getTerminator()); | ||
} | ||
|
||
MutableArrayRef<OpOperand> WhileOp::getYieldedValuesMutable() { | ||
std::optional<MutableArrayRef<OpOperand>> WhileOp::getYieldedValuesMutable() { | ||
return getYieldOp().getResultsMutable(); | ||
} | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
drop
yieldTiledValuesAndReplace