Skip to content

[mlir][TilingInterface] Use LoopLikeOpInterface in tiling using SCF to unify tiling with scf.for and scf.forall. #77874

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 15 commits into from
Jan 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions flang/lib/Optimizer/Dialect/FIROps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1935,7 +1935,7 @@ mlir::Value fir::IterWhileOp::blockArgToSourceOp(unsigned blockArgNum) {
return {};
}

llvm::MutableArrayRef<mlir::OpOperand>
std::optional<llvm::MutableArrayRef<mlir::OpOperand>>
fir::IterWhileOp::getYieldedValuesMutable() {
auto *term = getRegion().front().getTerminator();
return getFinalValue() ? term->getOpOperands().drop_front()
Expand Down Expand Up @@ -2247,7 +2247,7 @@ mlir::Value fir::DoLoopOp::blockArgToSourceOp(unsigned blockArgNum) {
return {};
}

llvm::MutableArrayRef<mlir::OpOperand>
std::optional<llvm::MutableArrayRef<mlir::OpOperand>>
fir::DoLoopOp::getYieldedValuesMutable() {
auto *term = getRegion().front().getTerminator();
return getFinalValue() ? term->getOpOperands().drop_front()
Expand Down
22 changes: 8 additions & 14 deletions mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -135,10 +135,11 @@ def ExecuteRegionOp : SCF_Op<"execute_region", [

def ForOp : SCF_Op<"for",
[AutomaticAllocationScope, DeclareOpInterfaceMethods<LoopLikeOpInterface,
["getInitsMutable", "getSingleInductionVar", "getSingleLowerBound",
"getSingleStep", "getSingleUpperBound", "getYieldedValuesMutable",
"getLoopResults", "promoteIfSingleIteration",
"replaceWithAdditionalYields"]>,
["getInitsMutable", "getLoopResults", "getRegionIterArgs",
"getSingleInductionVar", "getSingleLowerBound", "getSingleStep",
"getSingleUpperBound", "getYieldedValuesMutable",
"promoteIfSingleIteration", "replaceWithAdditionalYields",
"yieldTiledValuesAndReplace"]>,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

drop yieldTiledValuesAndReplace

AllTypesMatch<["lowerBound", "upperBound", "step"]>,
ConditionallySpeculatable,
DeclareOpInterfaceMethods<RegionBranchOpInterface,
Expand Down Expand Up @@ -259,10 +260,6 @@ def ForOp : SCF_Op<"for",

Value getInductionVar() { return getBody()->getArgument(0); }

Block::BlockArgListType getRegionIterArgs() {
return getBody()->getArguments().drop_front(getNumInductionVars());
}

/// Return the `index`-th region iteration argument.
BlockArgument getRegionIterArg(unsigned index) {
assert(index < getNumRegionIterArgs() &&
Expand Down Expand Up @@ -304,8 +301,9 @@ def ForallOp : SCF_Op<"forall", [
AttrSizedOperandSegments,
AutomaticAllocationScope,
DeclareOpInterfaceMethods<LoopLikeOpInterface,
["promoteIfSingleIteration", "getSingleInductionVar",
"getSingleLowerBound", "getSingleUpperBound", "getSingleStep"]>,
["getInitsMutable", "getRegionIterArgs", "getSingleInductionVar",
"getSingleLowerBound", "getSingleUpperBound", "getSingleStep",
"promoteIfSingleIteration", "yieldTiledValuesAndReplace"]>,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

drop yieldTiledValuesAndReplace

RecursiveMemoryEffects,
SingleBlockImplicitTerminator<"scf::InParallelOp">,
DeclareOpInterfaceMethods<RegionBranchOpInterface>,
Expand Down Expand Up @@ -585,10 +583,6 @@ def ForallOp : SCF_Op<"forall", [
getNumDynamicControlOperands() + getRank());
}

ArrayRef<BlockArgument> getOutputBlockArguments() {
return getBody()->getArguments().drop_front(getRank());
}

::mlir::ValueRange getInductionVars() {
return getBody()->getArguments().take_front(getRank());
}
Expand Down
41 changes: 22 additions & 19 deletions mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/LoopLikeInterface.h"
#include "mlir/Interfaces/TilingInterface.h"

#include <deque>
Expand Down Expand Up @@ -52,6 +53,14 @@ struct SCFTilingOptions {
return *this;
}

/// Specify which loop construct to use for tile and fuse.
enum class LoopType { ForOp, ForallOp };
LoopType loopType = LoopType::ForOp;
SCFTilingOptions &setLoopType(LoopType type) {
loopType = type;
return *this;
}

/// Specify mapping of loops to devices. This is only respected when the loop
/// constructs support such a mapping (like `scf.forall`). Will be ignored
/// when using loop constructs that dont support such a mapping (like
Expand All @@ -71,23 +80,17 @@ struct SCFTilingResult {
/// of the last op.
SmallVector<Operation *> tiledOps;
/// The `scf.for` operations that iterate over the tiles.
SmallVector<Operation *> loops;
SmallVector<LoopLikeOpInterface> loops;
/// Values to use as replacements for the untiled op. Is the same size as the
/// number of results of the untiled op.
SmallVector<Value> replacements;
};

/// Method to tile an op that implements the `TilingInterface` using
/// `scf.for` for iterating over the tiles.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Comment is outdated: This function may geneate scf.for or scf.forall.

FailureOr<SCFTilingResult> tileUsingSCFForOp(RewriterBase &rewriter,
TilingInterface op,
const SCFTilingOptions &options);

/// Method to tile an op that implements the `TilingInterface` using
/// `scf.forall`.
FailureOr<SCFTilingResult>
tileUsingSCFForallOp(RewriterBase &rewriter, TilingInterface op,
const SCFTilingOptions &options);
FailureOr<SCFTilingResult> tileUsingSCF(RewriterBase &rewriter,
TilingInterface op,
const SCFTilingOptions &options);

/// Options used to control tile + fuse.
struct SCFTileAndFuseOptions {
Expand Down Expand Up @@ -135,7 +138,7 @@ struct SCFFuseProducerOfSliceResult {
std::optional<SCFFuseProducerOfSliceResult>
tileAndFuseProducerOfSlice(RewriterBase &rewriter,
tensor::ExtractSliceOp candidateSliceOp,
MutableArrayRef<scf::ForOp> loops);
MutableArrayRef<LoopLikeOpInterface> loops);

/// Reconstruct the fused producer from within the tiled-and-fused code. Based
/// on the slice of the producer computed in place it is possible that within
Expand Down Expand Up @@ -187,10 +190,10 @@ tileAndFuseProducerOfSlice(RewriterBase &rewriter,
/// where `%0` had other uses as well. If not reconstructed from within the loop
/// body, uses of `%0` could not be replaced, making it still live and the
/// fusion immaterial.
void yieldReplacementForFusedProducer(
LogicalResult yieldReplacementForFusedProducer(
RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp,
scf::SCFFuseProducerOfSliceResult fusedProducerInfo,
MutableArrayRef<scf::ForOp> loops);
MutableArrayRef<LoopLikeOpInterface> loops);

/// Transformation information returned after tile and fuse.
struct SCFTileAndFuseResult {
Expand All @@ -201,7 +204,7 @@ struct SCFTileAndFuseResult {
/// generated operation.
llvm::SetVector<Operation *> tiledAndFusedOps;
/// The `scf.for` operations that iterate over the tiles.
SmallVector<Operation *> loops;
SmallVector<LoopLikeOpInterface> loops;
/// The replacement values to use for the tiled and fused operations.
llvm::DenseMap<Value, Value> replacements;
};
Expand Down Expand Up @@ -232,9 +235,9 @@ struct SCFTileAndFuseResult {
/// }
/// ```
FailureOr<SCFTileAndFuseResult>
tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
RewriterBase &rewriter, TilingInterface consumer,
const SCFTileAndFuseOptions &options);
tileConsumerAndFuseProducersUsingSCF(RewriterBase &rewriter,
TilingInterface consumer,
const SCFTileAndFuseOptions &options);

/// Method to lower an `op` that implements the `TilingInterface` to
/// loops/scalars.
Expand All @@ -249,8 +252,8 @@ struct SCFReductionTilingResult {
Operation *mergeOp;
/// Initial op
Operation *initialOp;
/// The `scf.for` operations that iterate over the tiles.
SmallVector<scf::ForOp> loops;
/// The loop operations that iterate over the tiles.
SmallVector<LoopLikeOpInterface> loops;
};

/// Method to tile a reduction and generate a parallel op within a serial loop.
Expand Down
29 changes: 18 additions & 11 deletions mlir/include/mlir/Interfaces/LoopLikeInterface.td
Original file line number Diff line number Diff line change
Expand Up @@ -164,13 +164,16 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
InterfaceMethod<[{
Return the mutable operand range of values that are yielded to the next
iteration by the loop terminator.

For loop operations that dont yield a value, this should return
std::nullopt.
}],
/*retTy=*/"::llvm::MutableArrayRef<::mlir::OpOperand>",
/*retTy=*/"std::optional<::llvm::MutableArrayRef<::mlir::OpOperand>>",
/*methodName=*/"getYieldedValuesMutable",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
return {};
return std::nullopt;
}]
>,
InterfaceMethod<[{
Expand Down Expand Up @@ -217,7 +220,7 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
/*defaultImplementation=*/[{
return ::mlir::failure();
}]
>,
>
];

let extraClassDeclaration = [{
Expand All @@ -244,16 +247,17 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
});
}

/// Return the values that are yielded to the next iteration.
/// Return the values that are yielded to the next iteration. If
/// the loop doesnt yield any values return `{}`.
::mlir::ValueRange getYieldedValues() {
auto mutableValues = $_op.getYieldedValuesMutable();
if (mutableValues.empty())
if (!mutableValues || mutableValues->empty())
return {};
Operation *yieldOp = mutableValues.begin()->getOwner();
unsigned firstOperandIndex = mutableValues.begin()->getOperandNumber();
Operation *yieldOp = mutableValues->begin()->getOwner();
unsigned firstOperandIndex = mutableValues->begin()->getOperandNumber();
return OperandRange(
yieldOp->operand_begin() + firstOperandIndex,
yieldOp->operand_begin() + firstOperandIndex + mutableValues.size());
yieldOp->operand_begin() + firstOperandIndex + mutableValues->size());
}

/// Return the "init" operands that are used as initialization values for
Expand Down Expand Up @@ -318,14 +322,17 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {

/// Return the yielded value that corresponds to the given region iter_arg.
/// Return "nullptr" if the given block argument is not a region iter_arg
/// of this loop op.
/// of this loop op or if there is no yield corresponding to this `bbArg`.
OpOperand *getTiedLoopYieldedValue(BlockArgument bbArg) {
auto iterArgs = $_op.getRegionIterArgs();
auto it = llvm::find(iterArgs, bbArg);
if (it == iterArgs.end())
return {};
return
&$_op.getYieldedValuesMutable()[std::distance(iterArgs.begin(), it)];
std::optional<llvm::MutableArrayRef<::mlir::OpOperand>> yieldValues =
$_op.getYieldedValuesMutable();
if (!yieldValues)
return {};
return &yieldValues.value()[std::distance(iterArgs.begin(), it)];
}

/// Return the loop result that corresponds to the given init operand.
Expand Down
3 changes: 2 additions & 1 deletion mlir/lib/Dialect/Affine/IR/AffineOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2127,7 +2127,8 @@ unsigned AffineForOp::getNumIterOperands() {
return getNumOperands() - lbMap.getNumInputs() - ubMap.getNumInputs();
}

MutableArrayRef<OpOperand> AffineForOp::getYieldedValuesMutable() {
std::optional<MutableArrayRef<OpOperand>>
AffineForOp::getYieldedValuesMutable() {
return cast<AffineYieldOp>(getBody()->getTerminator()).getOperandsMutable();
}

Expand Down
10 changes: 5 additions & 5 deletions mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -489,8 +489,8 @@ transform::FuseOp::apply(transform::TransformRewriter &rewriter,
tileSizes.size() - llvm::count(tileSizes, 0), transformResults,
[&](TilingInterface tilingInterfaceOp)
-> FailureOr<scf::SCFTileAndFuseResult> {
return tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
rewriter, tilingInterfaceOp, tileAndFuseOptions);
return tileConsumerAndFuseProducersUsingSCF(rewriter, tilingInterfaceOp,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: is the 'UsingSCF' part of the name still load-bearing now that we don't distinguish between scf.for and scf.forall anymore ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well, good question. This is in the SCF dialect... So I wanted to leave the "SCF" part. I dont have a strong preference, but this is living in SCF dialect.

tileAndFuseOptions);
});
return failed(result) ? DiagnosedSilenceableFailure::definiteFailure()
: DiagnosedSilenceableFailure::success();
Expand Down Expand Up @@ -588,7 +588,7 @@ static Operation *replaceForAllWithNewSignature(
Operation *firstYieldOp = yieldingOps.front();
rewriter.setInsertionPoint(firstYieldOp);
Value src = tileAndFuseResult.tiledValues[0];
Value dst = newforallOp.getOutputBlockArguments().back();
Value dst = newforallOp.getRegionIterArgs().back();
SmallVector<OpFoldResult> strides(offsets.size(), rewriter.getIndexAttr(1));
rewriter.create<tensor::ParallelInsertSliceOp>(firstYieldOp->getLoc(), src,
dst, offsets, sizes, strides);
Expand Down Expand Up @@ -2067,7 +2067,7 @@ transform::ScalarizeOp::applyToOne(transform::TransformRewriter &rewriter,
});
SmallVector<int64_t> emptyTileSizes;
rewriter.setInsertionPoint(target);
FailureOr<scf::SCFTilingResult> maybeTilingResult = tileUsingSCFForOp(
FailureOr<scf::SCFTilingResult> maybeTilingResult = tileUsingSCF(
rewriter, cast<TilingInterface>(target.getOperation()), tilingOptions);
if (failed(maybeTilingResult))
return emitDefaultDefiniteFailure(target);
Expand Down Expand Up @@ -2651,7 +2651,7 @@ transform::TileUsingForOp::apply(transform::TransformRewriter &rewriter,

tilingOptions.setInterchange(getInterchange());
FailureOr<scf::SCFTilingResult> maybeTilingResult =
tileUsingSCFForOp(rewriter, tilingInterface, tilingOptions);
tileUsingSCF(rewriter, tilingInterface, tilingOptions);
if (failed(maybeTilingResult))
return DiagnosedSilenceableFailure::definiteFailure();

Expand Down
4 changes: 2 additions & 2 deletions mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,7 @@ static FailureOr<ForallTilingResult> tileToForallOpImpl(

// 3. Clone the tileable op and update its destination operands to use the
// output bbArgs of the ForallOp.
ArrayRef<BlockArgument> destBbArgs = forallOp.getOutputBlockArguments();
ArrayRef<BlockArgument> destBbArgs = forallOp.getRegionIterArgs();
Operation *tiledOp = nullptr;
SmallVector<Value> tiledValues;
{
Expand Down Expand Up @@ -695,7 +695,7 @@ FailureOr<linalg::ForallReductionTilingResult> linalg::tileReductionUsingForall(
// 4. Clone the tileable op and update its destination operands to use the
// output bbArgs of the ForallOp.
SmallVector<Value> tilingResults;
ArrayRef<BlockArgument> destBbArgs = forallOp.getOutputBlockArguments();
ArrayRef<BlockArgument> destBbArgs = forallOp.getRegionIterArgs();
{
// 4.a. RAII guard, inserting within forallOp, before terminator.
OpBuilder::InsertionGuard g(b);
Expand Down
26 changes: 16 additions & 10 deletions mlir/lib/Dialect/SCF/IR/SCF.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -523,6 +523,10 @@ ParseResult ForOp::parse(OpAsmParser &parser, OperationState &result) {

SmallVector<Region *> ForOp::getLoopRegions() { return {&getRegion()}; }

Block::BlockArgListType ForOp::getRegionIterArgs() {
return getBody()->getArguments().drop_front(getNumInductionVars());
}

MutableArrayRef<OpOperand> ForOp::getInitsMutable() {
return getInitArgsMutable();
}
Expand Down Expand Up @@ -618,6 +622,14 @@ LogicalResult scf::ForallOp::promoteIfSingleIteration(RewriterBase &rewriter) {
return success();
}

Block::BlockArgListType ForallOp::getRegionIterArgs() {
return getBody()->getArguments().drop_front(getRank());
}

MutableArrayRef<OpOperand> ForallOp::getInitsMutable() {
Copy link
Member

@matthias-springer matthias-springer Jan 14, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you are missing a getYieldedValuesMutable implementation for ForallOp.

This is what the interface documentation says:

    Loop-carried variables can be exposed through this interface. There are
    3 components to a loop-carried variable.
    - The "region iter_arg" is the block argument of the entry block that
      represents the loop-carried variable in each iteration.
    - The "init value" is an operand of the loop op that serves as the initial
      region iter_arg value for the first iteration (if any).
    - The "yielded" value is the value that is forwarded from one iteration to
      serve as the region iter_arg of the next iteration.

    If one of the respective interface methods is implemented, so must the other
    two. The interface verifier ensures that the number of types of the region
    iter_args, init values and yielded values match.

This is checked by the op verifier, so you should be seeing verification failures.

The problem is that this op does not have yielded values.

We could have getYieldedValuesMutable return a std::optional<::llvm::MutableArrayRef<::mlir::OpOperand>> to account for the fact that some loops do not have yielding semantics. (Same as we do for getLoopResults.)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

On second thought, scf.forall should not implement getRegionIterArgs at all. The region iter_args in the loop like op interface are loop-carried variables, but the scf.forall does not have any loop-carried variables.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you are missing a getYieldedValuesMutable implementation for ForallOp.

This is what the interface documentation says:

    Loop-carried variables can be exposed through this interface. There are
    3 components to a loop-carried variable.
    - The "region iter_arg" is the block argument of the entry block that
      represents the loop-carried variable in each iteration.
    - The "init value" is an operand of the loop op that serves as the initial
      region iter_arg value for the first iteration (if any).
    - The "yielded" value is the value that is forwarded from one iteration to
      serve as the region iter_arg of the next iteration.

    If one of the respective interface methods is implemented, so must the other
    two. The interface verifier ensures that the number of types of the region
    iter_args, init values and yielded values match.

Obviously that definition does not work for scf.forall. Might need to update this since the scf.forall does not yield any value. The region iter arg and init value can stay though.

This is checked by the op verifier, so you should be seeing verification failures.

I did. I fixed the verifier also, but I think as you note below, getYieldedValuesMutable should return an std::optional<..>

On second thought, scf.forall should not implement getRegionIterArgs at all. The region iter_args in the loop like op interface are loop-carried variables, but the scf.forall does not have any loop-carried variables.

We probably dont need that.. the region iter args can still be tied to init. If there is no yield value, then the verifier can handle it appropriately.

return getOutputsMutable();
}

/// Promotes the loop body of a scf::ForallOp to its containing block.
void mlir::scf::promote(RewriterBase &rewriter, scf::ForallOp forallOp) {
OpBuilder::InsertionGuard g(rewriter);
Expand Down Expand Up @@ -1092,7 +1104,7 @@ std::optional<APInt> ForOp::getConstantStep() {
return {};
}

MutableArrayRef<OpOperand> ForOp::getYieldedValuesMutable() {
std::optional<MutableArrayRef<OpOperand>> ForOp::getYieldedValuesMutable() {
return cast<scf::YieldOp>(getBody()->getTerminator()).getResultsMutable();
}

Expand Down Expand Up @@ -1351,11 +1363,6 @@ void ForallOp::build(
return;
}
bodyBuilderFn(b, result.location, bodyBlock.getArguments());
#ifndef NDEBUG
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Was this dropped on purpose?

auto terminator = llvm::dyn_cast<InParallelOp>(bodyBlock.getTerminator());
assert(terminator &&
"expected bodyBuilderFn to create InParallelOp terminator");
#endif // NDEBUG
}

// Builder that takes loop bounds.
Expand Down Expand Up @@ -1626,9 +1633,8 @@ struct FoldTensorCastOfOutputIntoForallOp
// mapped to the tensor.cast old-typed results of the output bbArgs. The
// destination have to be updated to point to the output bbArgs directly.
auto terminator = newForallOp.getTerminator();
for (auto [yieldingOp, outputBlockArg] :
llvm::zip(terminator.getYieldingOps(),
newForallOp.getOutputBlockArguments())) {
for (auto [yieldingOp, outputBlockArg] : llvm::zip(
terminator.getYieldingOps(), newForallOp.getRegionIterArgs())) {
auto insertSliceOp = cast<tensor::ParallelInsertSliceOp>(yieldingOp);
insertSliceOp.getDestMutable().assign(outputBlockArg);
}
Expand Down Expand Up @@ -3108,7 +3114,7 @@ YieldOp WhileOp::getYieldOp() {
return cast<YieldOp>(getAfterBody()->getTerminator());
}

MutableArrayRef<OpOperand> WhileOp::getYieldedValuesMutable() {
std::optional<MutableArrayRef<OpOperand>> WhileOp::getYieldedValuesMutable() {
return getYieldOp().getResultsMutable();
}

Expand Down
Loading