Skip to content

Commit 5fcf907

Browse files
[mlir][IR] Rename "update root" to "modify op" in rewriter API (#78260)
This commit renames 4 pattern rewriter API functions: * `updateRootInPlace` -> `modifyOpInPlace` * `startRootUpdate` -> `startOpModification` * `finalizeRootUpdate` -> `finalizeOpModification` * `cancelRootUpdate` -> `cancelOpModification` The term "root" is a misnomer. The root is the op that a rewrite pattern matches against (https://mlir.llvm.org/docs/PatternRewriter/#root-operation-name-optional). A rewriter must be notified of all in-place op modifications, not just in-place modifications of the root (https://mlir.llvm.org/docs/PatternRewriter/#pattern-rewriter). The old function names were confusing and have contributed to various broken rewrite patterns. Note: The new function names use the term "modify" instead of "update" for consistency with the `RewriterBase::Listener` terminology (`notifyOperationModified`).
1 parent 57b50ef commit 5fcf907

File tree

78 files changed

+246
-246
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

78 files changed

+246
-246
lines changed

flang/lib/Optimizer/CodeGen/BoxedProcedure.cpp

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -215,14 +215,14 @@ class BoxedProcedurePass
215215
rewriter.replaceOpWithNewOp<ConvertOp>(
216216
addr, typeConverter.convertType(addr.getType()), addr.getVal());
217217
} else if (typeConverter.needsConversion(resTy)) {
218-
rewriter.startRootUpdate(op);
218+
rewriter.startOpModification(op);
219219
op->getResult(0).setType(typeConverter.convertType(resTy));
220-
rewriter.finalizeRootUpdate(op);
220+
rewriter.finalizeOpModification(op);
221221
}
222222
} else if (auto func = mlir::dyn_cast<mlir::func::FuncOp>(op)) {
223223
mlir::FunctionType ty = func.getFunctionType();
224224
if (typeConverter.needsConversion(ty)) {
225-
rewriter.startRootUpdate(func);
225+
rewriter.startOpModification(func);
226226
auto toTy =
227227
typeConverter.convertType(ty).cast<mlir::FunctionType>();
228228
if (!func.empty())
@@ -235,7 +235,7 @@ class BoxedProcedurePass
235235
block.eraseArgument(i + 1);
236236
}
237237
func.setType(toTy);
238-
rewriter.finalizeRootUpdate(func);
238+
rewriter.finalizeOpModification(func);
239239
}
240240
} else if (auto embox = mlir::dyn_cast<EmboxProcOp>(op)) {
241241
// Rewrite all `fir.emboxproc` ops to either `fir.convert` or a thunk
@@ -273,10 +273,10 @@ class BoxedProcedurePass
273273
} else if (auto global = mlir::dyn_cast<GlobalOp>(op)) {
274274
auto ty = global.getType();
275275
if (typeConverter.needsConversion(ty)) {
276-
rewriter.startRootUpdate(global);
276+
rewriter.startOpModification(global);
277277
auto toTy = typeConverter.convertType(ty);
278278
global.setType(toTy);
279-
rewriter.finalizeRootUpdate(global);
279+
rewriter.finalizeOpModification(global);
280280
}
281281
} else if (auto mem = mlir::dyn_cast<AllocaOp>(op)) {
282282
auto ty = mem.getType();
@@ -339,17 +339,17 @@ class BoxedProcedurePass
339339
mem, toTy, index.getFieldId(), toOnTy, index.getTypeparams());
340340
}
341341
} else if (op->getDialect() == firDialect) {
342-
rewriter.startRootUpdate(op);
342+
rewriter.startOpModification(op);
343343
for (auto i : llvm::enumerate(op->getResultTypes()))
344344
if (typeConverter.needsConversion(i.value())) {
345345
auto toTy = typeConverter.convertType(i.value());
346346
op->getResult(i.index()).setType(toTy);
347347
}
348-
rewriter.finalizeRootUpdate(op);
348+
rewriter.finalizeOpModification(op);
349349
}
350350
// Ensure block arguments are updated if needed.
351351
if (op->getNumRegions() != 0) {
352-
rewriter.startRootUpdate(op);
352+
rewriter.startOpModification(op);
353353
for (mlir::Region &region : op->getRegions())
354354
for (mlir::Block &block : region.getBlocks())
355355
for (mlir::BlockArgument blockArg : block.getArguments())
@@ -358,7 +358,7 @@ class BoxedProcedurePass
358358
typeConverter.convertType(blockArg.getType());
359359
blockArg.setType(toTy);
360360
}
361-
rewriter.finalizeRootUpdate(op);
361+
rewriter.finalizeOpModification(op);
362362
}
363363
});
364364
}

flang/lib/Optimizer/CodeGen/CodeGen.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3763,13 +3763,13 @@ class RenameMSVCLibmCallees
37633763
mlir::LogicalResult
37643764
matchAndRewrite(mlir::LLVM::CallOp op,
37653765
mlir::PatternRewriter &rewriter) const override {
3766-
rewriter.startRootUpdate(op);
3766+
rewriter.startOpModification(op);
37673767
auto callee = op.getCallee();
37683768
if (callee)
37693769
if (callee->equals("hypotf"))
37703770
op.setCalleeAttr(mlir::SymbolRefAttr::get(op.getContext(), "_hypotf"));
37713771

3772-
rewriter.finalizeRootUpdate(op);
3772+
rewriter.finalizeOpModification(op);
37733773
return mlir::success();
37743774
}
37753775
};
@@ -3782,10 +3782,10 @@ class RenameMSVCLibmFuncs
37823782
mlir::LogicalResult
37833783
matchAndRewrite(mlir::LLVM::LLVMFuncOp op,
37843784
mlir::PatternRewriter &rewriter) const override {
3785-
rewriter.startRootUpdate(op);
3785+
rewriter.startOpModification(op);
37863786
if (op.getSymName().equals("hypotf"))
37873787
op.setSymNameAttr(rewriter.getStringAttr("_hypotf"));
3788-
rewriter.finalizeRootUpdate(op);
3788+
rewriter.finalizeOpModification(op);
37893789
return mlir::success();
37903790
}
37913791
};

flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -256,9 +256,9 @@ struct AssignOpConversion : public mlir::OpConversionPattern<hlfir::AssignOp> {
256256
llvm::SmallVector<mlir::Value> newOperands;
257257
for (mlir::Value operand : adaptor.getOperands())
258258
newOperands.push_back(getBufferizedExprStorage(operand));
259-
rewriter.startRootUpdate(assign);
259+
rewriter.startOpModification(assign);
260260
assign->setOperands(newOperands);
261-
rewriter.finalizeRootUpdate(assign);
261+
rewriter.finalizeOpModification(assign);
262262
return mlir::success();
263263
}
264264
};
@@ -834,9 +834,9 @@ struct ElementalOpConversion
834834
// Explicitly delete the body of the elemental to get rid
835835
// of any users of hlfir.expr values inside the body as early
836836
// as possible.
837-
rewriter.startRootUpdate(elemental);
837+
rewriter.startOpModification(elemental);
838838
rewriter.eraseBlock(elemental.getBody());
839-
rewriter.finalizeRootUpdate(elemental);
839+
rewriter.finalizeOpModification(elemental);
840840
rewriter.replaceOp(elemental, bufferizedExpr);
841841
return mlir::success();
842842
}

flang/lib/Optimizer/Transforms/AffineDemotion.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -114,9 +114,9 @@ class ConvertConversion : public mlir::OpRewritePattern<fir::ConvertOp> {
114114
op.getValue());
115115
return success();
116116
}
117-
rewriter.startRootUpdate(op->getParentOp());
117+
rewriter.startOpModification(op->getParentOp());
118118
op.getResult().replaceAllUsesWith(op.getValue());
119-
rewriter.finalizeRootUpdate(op->getParentOp());
119+
rewriter.finalizeOpModification(op->getParentOp());
120120
rewriter.eraseOp(op);
121121
}
122122
return success();

flang/lib/Optimizer/Transforms/AffinePromotion.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -464,15 +464,15 @@ class AffineLoopConversion : public mlir::OpRewritePattern<fir::DoLoopOp> {
464464
auto affineFor = loopAndIndex.first;
465465
auto inductionVar = loopAndIndex.second;
466466

467-
rewriter.startRootUpdate(affineFor.getOperation());
467+
rewriter.startOpModification(affineFor.getOperation());
468468
affineFor.getBody()->getOperations().splice(
469469
std::prev(affineFor.getBody()->end()), loopOps, loopOps.begin(),
470470
std::prev(loopOps.end()));
471-
rewriter.finalizeRootUpdate(affineFor.getOperation());
471+
rewriter.finalizeOpModification(affineFor.getOperation());
472472

473-
rewriter.startRootUpdate(loop.getOperation());
473+
rewriter.startOpModification(loop.getOperation());
474474
loop.getInductionVar().replaceAllUsesWith(inductionVar);
475-
rewriter.finalizeRootUpdate(loop.getOperation());
475+
rewriter.finalizeOpModification(loop.getOperation());
476476

477477
rewriteMemoryOps(affineFor.getBody(), rewriter);
478478

@@ -561,7 +561,7 @@ class AffineIfConversion : public mlir::OpRewritePattern<fir::IfOp> {
561561
auto affineIf = rewriter.create<affine::AffineIfOp>(
562562
op.getLoc(), affineCondition.getIntegerSet(),
563563
affineCondition.getAffineArgs(), !op.getElseRegion().empty());
564-
rewriter.startRootUpdate(affineIf);
564+
rewriter.startOpModification(affineIf);
565565
affineIf.getThenBlock()->getOperations().splice(
566566
std::prev(affineIf.getThenBlock()->end()), ifOps, ifOps.begin(),
567567
std::prev(ifOps.end()));
@@ -571,7 +571,7 @@ class AffineIfConversion : public mlir::OpRewritePattern<fir::IfOp> {
571571
std::prev(affineIf.getElseBlock()->end()), otherOps, otherOps.begin(),
572572
std::prev(otherOps.end()));
573573
}
574-
rewriter.finalizeRootUpdate(affineIf);
574+
rewriter.finalizeOpModification(affineIf);
575575
rewriteMemoryOps(affineIf.getBody(), rewriter);
576576

577577
LLVM_DEBUG(llvm::dbgs() << "AffineIfConversion: if converted to:\n";

flang/lib/Optimizer/Transforms/ExternalNameConversion.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ struct MangleNameOnFuncOp : public mlir::OpRewritePattern<mlir::func::FuncOp> {
7676
matchAndRewrite(mlir::func::FuncOp op,
7777
mlir::PatternRewriter &rewriter) const override {
7878
mlir::LogicalResult ret = success();
79-
rewriter.startRootUpdate(op);
79+
rewriter.startOpModification(op);
8080
llvm::StringRef oldName = op.getSymName();
8181
auto result = fir::NameUniquer::deconstruct(oldName);
8282
if (fir::NameUniquer::isExternalFacingUniquedName(result)) {
@@ -95,7 +95,7 @@ struct MangleNameOnFuncOp : public mlir::OpRewritePattern<mlir::func::FuncOp> {
9595
}
9696

9797
updateEarlyOutliningParentName(op, appendUnderscore);
98-
rewriter.finalizeRootUpdate(op);
98+
rewriter.finalizeOpModification(op);
9999
return ret;
100100
}
101101

@@ -114,15 +114,15 @@ struct MangleNameForCommonBlock : public mlir::OpRewritePattern<fir::GlobalOp> {
114114
mlir::LogicalResult
115115
matchAndRewrite(fir::GlobalOp op,
116116
mlir::PatternRewriter &rewriter) const override {
117-
rewriter.startRootUpdate(op);
117+
rewriter.startOpModification(op);
118118
auto result = fir::NameUniquer::deconstruct(
119119
op.getSymref().getRootReference().getValue());
120120
if (fir::NameUniquer::isExternalFacingUniquedName(result)) {
121121
auto newName = mangleExternalName(result, appendUnderscore);
122122
op.setSymrefAttr(mlir::SymbolRefAttr::get(op.getContext(), newName));
123123
SymbolTable::setSymbolName(op, newName);
124124
}
125-
rewriter.finalizeRootUpdate(op);
125+
rewriter.finalizeOpModification(op);
126126
return success();
127127
}
128128

mlir/docs/PatternRewriter.md

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -213,15 +213,15 @@ user is determined by the specific pattern driver.
213213
This method replaces an operation's results with a set of provided values, and
214214
erases the operation.
215215
216-
* Update an Operation in-place : `(start|cancel|finalize)RootUpdate`
216+
* Update an Operation in-place : `(start|cancel|finalize)OpModification`
217217
218218
This is a collection of methods that provide a transaction-like API for updating
219219
the attributes, location, operands, or successors of an operation in-place
220220
within a pattern. An in-place update transaction is started with
221-
`startRootUpdate`, and may either be canceled or finalized with
222-
`cancelRootUpdate` and `finalizeRootUpdate` respectively. A convenience wrapper,
223-
`updateRootInPlace`, is provided that wraps a `start` and `finalize` around a
224-
callback.
221+
`startOpModification`, and may either be canceled or finalized with
222+
`cancelOpModification` and `finalizeOpModification` respectively. A convenience
223+
wrapper, `modifyOpInPlace`, is provided that wraps a `start` and `finalize`
224+
around a callback.
225225
226226
* OpBuilder API
227227

mlir/examples/standalone/lib/Standalone/StandalonePasses.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ class StandaloneSwitchBarFooRewriter : public OpRewritePattern<func::FuncOp> {
2424
LogicalResult matchAndRewrite(func::FuncOp op,
2525
PatternRewriter &rewriter) const final {
2626
if (op.getSymName() == "bar") {
27-
rewriter.updateRootInPlace(op, [&op]() { op.setSymName("foo"); });
27+
rewriter.modifyOpInPlace(op, [&op]() { op.setSymName("foo"); });
2828
return success();
2929
}
3030
return failure();

mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -260,8 +260,8 @@ struct PrintOpLowering : public OpConversionPattern<toy::PrintOp> {
260260
ConversionPatternRewriter &rewriter) const final {
261261
// We don't lower "toy.print" in this pass, but we need to update its
262262
// operands.
263-
rewriter.updateRootInPlace(op,
264-
[&] { op->setOperands(adaptor.getOperands()); });
263+
rewriter.modifyOpInPlace(op,
264+
[&] { op->setOperands(adaptor.getOperands()); });
265265
return success();
266266
}
267267
};

mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -260,8 +260,8 @@ struct PrintOpLowering : public OpConversionPattern<toy::PrintOp> {
260260
ConversionPatternRewriter &rewriter) const final {
261261
// We don't lower "toy.print" in this pass, but we need to update its
262262
// operands.
263-
rewriter.updateRootInPlace(op,
264-
[&] { op->setOperands(adaptor.getOperands()); });
263+
rewriter.modifyOpInPlace(op,
264+
[&] { op->setOperands(adaptor.getOperands()); });
265265
return success();
266266
}
267267
};

mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -260,8 +260,8 @@ struct PrintOpLowering : public OpConversionPattern<toy::PrintOp> {
260260
ConversionPatternRewriter &rewriter) const final {
261261
// We don't lower "toy.print" in this pass, but we need to update its
262262
// operands.
263-
rewriter.updateRootInPlace(op,
264-
[&] { op->setOperands(adaptor.getOperands()); });
263+
rewriter.modifyOpInPlace(op,
264+
[&] { op->setOperands(adaptor.getOperands()); });
265265
return success();
266266
}
267267
};

mlir/include/mlir/IR/PatternMatch.h

Lines changed: 23 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -585,28 +585,30 @@ class RewriterBase : public OpBuilder {
585585

586586
/// This method is used to notify the rewriter that an in-place operation
587587
/// modification is about to happen. A call to this function *must* be
588-
/// followed by a call to either `finalizeRootUpdate` or `cancelRootUpdate`.
589-
/// This is a minor efficiency win (it avoids creating a new operation and
590-
/// removing the old one) but also often allows simpler code in the client.
591-
virtual void startRootUpdate(Operation *op) {}
592-
593-
/// This method is used to signal the end of a root update on the given
594-
/// operation. This can only be called on operations that were provided to a
595-
/// call to `startRootUpdate`.
596-
virtual void finalizeRootUpdate(Operation *op);
597-
598-
/// This method cancels a pending root update. This can only be called on
599-
/// operations that were provided to a call to `startRootUpdate`.
600-
virtual void cancelRootUpdate(Operation *op) {}
601-
602-
/// This method is a utility wrapper around a root update of an operation. It
603-
/// wraps calls to `startRootUpdate` and `finalizeRootUpdate` around the given
604-
/// callable.
588+
/// followed by a call to either `finalizeOpModification` or
589+
/// `cancelOpModification`. This is a minor efficiency win (it avoids creating
590+
/// a new operation and removing the old one) but also often allows simpler
591+
/// code in the client.
592+
virtual void startOpModification(Operation *op) {}
593+
594+
/// This method is used to signal the end of an in-place modification of the
595+
/// given operation. This can only be called on operations that were provided
596+
/// to a call to `startOpModification`.
597+
virtual void finalizeOpModification(Operation *op);
598+
599+
/// This method cancels a pending in-place modification. This can only be
600+
/// called on operations that were provided to a call to
601+
/// `startOpModification`.
602+
virtual void cancelOpModification(Operation *op) {}
603+
604+
/// This method is a utility wrapper around an in-place modification of an
605+
/// operation. It wraps calls to `startOpModification` and
606+
/// `finalizeOpModification` around the given callable.
605607
template <typename CallableT>
606-
void updateRootInPlace(Operation *root, CallableT &&callable) {
607-
startRootUpdate(root);
608+
void modifyOpInPlace(Operation *root, CallableT &&callable) {
609+
startOpModification(root);
608610
callable();
609-
finalizeRootUpdate(root);
611+
finalizeOpModification(root);
610612
}
611613

612614
/// Find uses of `from` and replace them with `to`. It also marks every
@@ -619,7 +621,7 @@ class RewriterBase : public OpBuilder {
619621
void replaceAllUsesWith(IRObjectWithUseList<OperandType> *from, ValueT &&to) {
620622
for (OperandType &operand : llvm::make_early_inc_range(from->getUses())) {
621623
Operation *op = operand.getOwner();
622-
updateRootInPlace(op, [&]() { operand.set(to); });
624+
modifyOpInPlace(op, [&]() { operand.set(to); });
623625
}
624626
}
625627
void replaceAllUsesWith(ValueRange from, ValueRange to) {

mlir/include/mlir/Transforms/DialectConversion.h

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -739,17 +739,17 @@ class ConversionPatternRewriter final : public PatternRewriter,
739739
/// PatternRewriter hook for inserting a new operation.
740740
void notifyOperationInserted(Operation *op) override;
741741

742-
/// PatternRewriter hook for updating the root operation in-place.
743-
/// Note: These methods only track updates to the top-level operation itself,
742+
/// PatternRewriter hook for updating the given operation in-place.
743+
/// Note: These methods only track updates to the given operation itself,
744744
/// and not nested regions. Updates to regions will still require notification
745745
/// through other more specific hooks above.
746-
void startRootUpdate(Operation *op) override;
746+
void startOpModification(Operation *op) override;
747747

748-
/// PatternRewriter hook for updating the root operation in-place.
749-
void finalizeRootUpdate(Operation *op) override;
748+
/// PatternRewriter hook for updating the given operation in-place.
749+
void finalizeOpModification(Operation *op) override;
750750

751-
/// PatternRewriter hook for updating the root operation in-place.
752-
void cancelRootUpdate(Operation *op) override;
751+
/// PatternRewriter hook for updating the given operation in-place.
752+
void cancelOpModification(Operation *op) override;
753753

754754
/// PatternRewriter hook for notifying match failure reasons.
755755
LogicalResult

mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -255,7 +255,7 @@ struct ConvertArmSMESpillsAndFillsToLLVM : public ConvertToLLVMPattern {
255255
// Step 2. Assign the op a real tile ID.
256256
// For simplicity, we always use tile 0 (which always exists).
257257
auto zeroTileId = rewriter.getI32IntegerAttr(0);
258-
rewriter.updateRootInPlace(tileOp, [&] { tileOp.setTileId(zeroTileId); });
258+
rewriter.modifyOpInPlace(tileOp, [&] { tileOp.setTileId(zeroTileId); });
259259

260260
VectorType tileVectorType = tileOp.getTileType();
261261
auto sliceType = VectorType::Builder(tileVectorType).dropDim(0);

mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -918,8 +918,7 @@ LogicalResult ConvertAsyncYieldToGpuRuntimeCallPattern::matchAndRewrite(
918918
for (auto stream : streams)
919919
streamDestroyCallBuilder.create(loc, rewriter, {stream});
920920

921-
rewriter.updateRootInPlace(yieldOp,
922-
[&] { yieldOp->setOperands(newOperands); });
921+
rewriter.modifyOpInPlace(yieldOp, [&] { yieldOp->setOperands(newOperands); });
923922
return success();
924923
}
925924

mlir/lib/Conversion/OpenACCToSCF/OpenACCToSCF.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,14 +43,13 @@ class ExpandIfCondition : public OpRewritePattern<OpTy> {
4343
if (!matchPattern(op.getIfCond(), m_Constant(&constAttr))) {
4444
auto ifOp = rewriter.create<scf::IfOp>(op.getLoc(), TypeRange(),
4545
op.getIfCond(), false);
46-
rewriter.updateRootInPlace(op, [&]() { op.getIfCondMutable().erase(0); });
46+
rewriter.modifyOpInPlace(op, [&]() { op.getIfCondMutable().erase(0); });
4747
auto thenBodyBuilder = ifOp.getThenBodyBuilder(rewriter.getListener());
4848
thenBodyBuilder.clone(*op.getOperation());
4949
rewriter.eraseOp(op);
5050
} else {
5151
if (constAttr.getInt())
52-
rewriter.updateRootInPlace(op,
53-
[&]() { op.getIfCondMutable().erase(0); });
52+
rewriter.modifyOpInPlace(op, [&]() { op.getIfCondMutable().erase(0); });
5453
else
5554
rewriter.eraseOp(op);
5655
}

0 commit comments

Comments
 (0)