Skip to content

[mlir][IR] Change MutableArrayRange to enumerate OpOperand & #66622

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 0 additions & 7 deletions mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -250,17 +250,10 @@ def ForOp : SCF_Op<"for",
"expected an index less than the number of region iter args");
return getBody()->getArguments().drop_front(getNumInductionVars())[index];
}
MutableArrayRef<OpOperand> getIterOpOperands() {
return
getOperation()->getOpOperands().drop_front(getNumControlOperands());
}

void setLowerBound(Value bound) { getOperation()->setOperand(0, bound); }
void setUpperBound(Value bound) { getOperation()->setOperand(1, bound); }
void setStep(Value step) { getOperation()->setOperand(2, step); }
void setIterArg(unsigned iterArgNum, Value iterArgValue) {
getOperation()->setOperand(iterArgNum + getNumControlOperands(), iterArgValue);
}

/// Number of induction variables, always 1 for scf::ForOp.
unsigned getNumInductionVars() { return 1; }
Expand Down
10 changes: 3 additions & 7 deletions mlir/include/mlir/IR/ValueRange.h
Original file line number Diff line number Diff line change
Expand Up @@ -165,13 +165,9 @@ class MutableOperandRange {
/// Returns the OpOperand at the given index.
OpOperand &operator[](unsigned index) const;

OperandRange::iterator begin() const {
return static_cast<OperandRange>(*this).begin();
}

OperandRange::iterator end() const {
return static_cast<OperandRange>(*this).end();
}
/// Iterators enumerate OpOperands.
MutableArrayRef<OpOperand>::iterator begin() const;
MutableArrayRef<OpOperand>::iterator end() const;

private:
/// Update the length of this range to the one provided.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,10 @@ static Value buildBoolValue(OpBuilder &builder, Location loc, bool value) {

static bool isMemref(Value v) { return v.getType().isa<BaseMemRefType>(); }

static bool isMemrefOperand(OpOperand &operand) {
return isMemref(operand.get());
}

//===----------------------------------------------------------------------===//
// Backedges analysis
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -937,7 +941,7 @@ BufferDeallocation::handleInterface(RegionBranchTerminatorOpInterface op) {

// Add an additional operand for every MemRef for the ownership indicator.
if (!funcWithoutDynamicOwnership) {
unsigned numMemRefs = llvm::count_if(operands, isMemref);
unsigned numMemRefs = llvm::count_if(operands, isMemrefOperand);
SmallVector<Value> newOperands{OperandRange(operands)};
auto ownershipValues =
deallocOp.getUpdatedConditions().take_front(numMemRefs);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,12 +96,12 @@ struct CondBranchOpInterface
mapping[retained] = ownership;
}
SmallVector<Value> replacements, ownerships;
for (Value operand : destOperands) {
replacements.push_back(operand);
if (isMemref(operand)) {
assert(mapping.contains(operand) &&
for (OpOperand &operand : destOperands) {
replacements.push_back(operand.get());
if (isMemref(operand.get())) {
assert(mapping.contains(operand.get()) &&
"Should be contained at this point");
ownerships.push_back(mapping[operand]);
ownerships.push_back(mapping[operand.get()]);
}
}
replacements.append(ownerships);
Expand Down
4 changes: 2 additions & 2 deletions mlir/lib/Dialect/SCF/IR/SCF.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -932,7 +932,7 @@ replaceTensorCastForOpIterArg(PatternRewriter &rewriter, OpOperand &operand,
assert(operand.get().getType() != replacement.getType() &&
"Expected a different type");
SmallVector<Value> newIterOperands;
for (OpOperand &opOperand : forOp.getIterOpOperands()) {
for (OpOperand &opOperand : forOp.getInitArgsMutable()) {
if (opOperand.getOperandNumber() == operand.getOperandNumber()) {
newIterOperands.push_back(replacement);
continue;
Expand Down Expand Up @@ -1015,7 +1015,7 @@ struct ForOpTensorCastFolder : public OpRewritePattern<ForOp> {

LogicalResult matchAndRewrite(ForOp op,
PatternRewriter &rewriter) const override {
for (auto it : llvm::zip(op.getIterOpOperands(), op.getResults())) {
for (auto it : llvm::zip(op.getInitArgsMutable(), op.getResults())) {
OpOperand &iterOpOperand = std::get<0>(it);
auto incomingCast = iterOpOperand.get().getDefiningOp<tensor::CastOp>();
if (!incomingCast ||
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,7 @@ DenseSet<int64_t> getEquivalentBuffers(Block::BlockArgListType bbArgs,
/// Helper function for loop bufferization. Return the bufferized values of the
/// given OpOperands. If an operand is not a tensor, return the original value.
static FailureOr<SmallVector<Value>>
getBuffers(RewriterBase &rewriter, MutableArrayRef<OpOperand> operands,
getBuffers(RewriterBase &rewriter, MutableOperandRange operands,
const BufferizationOptions &options) {
SmallVector<Value> result;
for (OpOperand &opOperand : operands) {
Expand Down Expand Up @@ -606,7 +606,7 @@ struct ForOpInterface

// The new memref init_args of the loop.
FailureOr<SmallVector<Value>> maybeInitArgs =
getBuffers(rewriter, forOp.getIterOpOperands(), options);
getBuffers(rewriter, forOp.getInitArgsMutable(), options);
if (failed(maybeInitArgs))
return failure();
SmallVector<Value> initArgs = *maybeInitArgs;
Expand Down Expand Up @@ -825,7 +825,7 @@ struct WhileOpInterface

// The new memref init_args of the loop.
FailureOr<SmallVector<Value>> maybeInitArgs =
getBuffers(rewriter, whileOp->getOpOperands(), options);
getBuffers(rewriter, whileOp.getInitsMutable(), options);
if (failed(maybeInitArgs))
return failure();
SmallVector<Value> initArgs = *maybeInitArgs;
Expand Down
16 changes: 7 additions & 9 deletions mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -508,7 +508,7 @@ mlir::scf::tileAndFuseProducerOfSlice(RewriterBase &rewriter,
MutableArrayRef<scf::ForOp> loops) {
// 1. Get the producer of the source (potentially walking through
// `iter_args` of nested `scf.for`)
auto [fusableProducer, destinationIterArg] =
auto [fusableProducer, destinationInitArg] =
getUntiledProducerFromSliceSource(&candidateSliceOp.getSourceMutable()[0],
loops);
if (!fusableProducer)
Expand Down Expand Up @@ -575,17 +575,15 @@ mlir::scf::tileAndFuseProducerOfSlice(RewriterBase &rewriter,
// TODO: This can be modeled better if the `DestinationStyleOpInterface`.
// Update to use that when it does become available.
scf::ForOp outerMostLoop = loops.front();
std::optional<unsigned> iterArgNumber;
if (destinationIterArg) {
iterArgNumber =
outerMostLoop.getIterArgNumberForOpOperand(*destinationIterArg.value());
}
if (iterArgNumber) {
if (destinationInitArg &&
(*destinationInitArg)->getOwner() == outerMostLoop) {
std::optional<unsigned> iterArgNumber =
outerMostLoop.getIterArgNumberForOpOperand(**destinationInitArg);
int64_t resultNumber = fusableProducer.getResultNumber();
if (auto dstOp =
dyn_cast<DestinationStyleOpInterface>(fusableProducer.getOwner())) {
outerMostLoop.setIterArg(iterArgNumber.value(),
dstOp.getTiedOpOperand(fusableProducer)->get());
(*destinationInitArg)
->set(dstOp.getTiedOpOperand(fusableProducer)->get());
}
for (auto tileAndFusedOp : tileAndFuseResult->tiledOps) {
auto dstOp = dyn_cast<DestinationStyleOpInterface>(tileAndFusedOp);
Expand Down
8 changes: 8 additions & 0 deletions mlir/lib/IR/OperationSupport.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -522,6 +522,14 @@ OpOperand &MutableOperandRange::operator[](unsigned index) const {
return owner->getOpOperand(start + index);
}

MutableArrayRef<OpOperand>::iterator MutableOperandRange::begin() const {
return owner->getOpOperands().slice(start, length).begin();
}

MutableArrayRef<OpOperand>::iterator MutableOperandRange::end() const {
return owner->getOpOperands().slice(start, length).end();
}

//===----------------------------------------------------------------------===//
// MutableOperandRangeRange

Expand Down
28 changes: 20 additions & 8 deletions mlir/lib/Transforms/Utils/CFGToSCF.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,13 @@ getMutableSuccessorOperands(Block *block, unsigned successorIndex) {
return succOps.getMutableForwardedOperands();
}

/// Return the operand range used to transfer operands from `block` to its
/// successor with the given index.
static OperandRange getSuccessorOperands(Block *block,
unsigned successorIndex) {
return getMutableSuccessorOperands(block, successorIndex);
}

/// Appends all the block arguments from `other` to the block arguments of
/// `block`, copying their types and locations.
static void addBlockArgumentsFromOther(Block *block, Block *other) {
Expand Down Expand Up @@ -175,8 +182,14 @@ class Edge {

/// Returns the arguments of this edge that are passed to the block arguments
/// of the successor.
MutableOperandRange getSuccessorOperands() const {
return getMutableSuccessorOperands(fromBlock, successorIndex);
MutableOperandRange getMutableSuccessorOperands() const {
return ::getMutableSuccessorOperands(fromBlock, successorIndex);
}

/// Returns the arguments of this edge that are passed to the block arguments
/// of the successor.
OperandRange getSuccessorOperands() const {
return ::getSuccessorOperands(fromBlock, successorIndex);
}
};

Expand Down Expand Up @@ -262,7 +275,7 @@ class EdgeMultiplexer {
assert(result != blockArgMapping.end() &&
"Edge was not originally passed to `create` method.");

MutableOperandRange successorOperands = edge.getSuccessorOperands();
MutableOperandRange successorOperands = edge.getMutableSuccessorOperands();

// Extra arguments are always appended at the end of the block arguments.
unsigned extraArgsBeginIndex =
Expand Down Expand Up @@ -666,7 +679,7 @@ transformToReduceLoop(Block *loopHeader, Block *exitBlock,
// invalidated when mutating the operands through a different
// `MutableOperandRange` of the same operation.
SmallVector<Value> loopHeaderSuccessorOperands =
llvm::to_vector(getMutableSuccessorOperands(latch, loopHeaderIndex));
llvm::to_vector(getSuccessorOperands(latch, loopHeaderIndex));

// Add all values used in the next iteration to the exit block. Replace
// any uses that are outside the loop with the newly created exit block.
Expand Down Expand Up @@ -742,7 +755,7 @@ transformToReduceLoop(Block *loopHeader, Block *exitBlock,

loopHeaderSuccessorOperands.push_back(argument);
for (Edge edge : successorEdges(latch))
edge.getSuccessorOperands().append(argument);
edge.getMutableSuccessorOperands().append(argument);
}

use.set(blockArgument);
Expand Down Expand Up @@ -939,9 +952,8 @@ static FailureOr<SmallVector<Block *>> transformToStructuredCFBranches(
if (regionEntry->getNumSuccessors() == 1) {
// Single successor we can just splice together.
Block *successor = regionEntry->getSuccessor(0);
for (auto &&[oldValue, newValue] :
llvm::zip(successor->getArguments(),
getMutableSuccessorOperands(regionEntry, 0)))
for (auto &&[oldValue, newValue] : llvm::zip(
successor->getArguments(), getSuccessorOperands(regionEntry, 0)))
oldValue.replaceAllUsesWith(newValue);
regionEntry->getTerminator()->erase();

Expand Down