Skip to content

Commit 6b3e000

Browse files
[mlir][Transforms][NFC] GreedyPatternRewriteDriver: Use composition instead of inheritance (#92785)
This commit simplifies the design of the `GreedyPatternRewriterDriver` class. This class used to inherit from both `PatternRewriter` and `RewriterBase::Listener` and then attached itself as a listener. In the new design, the class has a `PatternRewriter` field instead of inheriting from `PatternRewriter`, which is generally perferred in object-oriented programming. --------- Co-authored-by: Markus Böck <[email protected]>
1 parent 7f5d1f1 commit 6b3e000

File tree

2 files changed

+23
-18
lines changed

2 files changed

+23
-18
lines changed

mlir/include/mlir/IR/PatternMatch.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -784,6 +784,7 @@ class IRRewriter : public RewriterBase {
784784
/// place.
785785
class PatternRewriter : public RewriterBase {
786786
public:
787+
explicit PatternRewriter(MLIRContext *ctx) : RewriterBase(ctx) {}
787788
using RewriterBase::RewriterBase;
788789

789790
/// A hook used to indicate if the pattern rewriter can recover from failure

mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp

Lines changed: 22 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -319,8 +319,7 @@ class RandomizedWorklist : public Worklist {
319319
/// This abstract class manages the worklist and contains helper methods for
320320
/// rewriting ops on the worklist. Derived classes specify how ops are added
321321
/// to the worklist in the beginning.
322-
class GreedyPatternRewriteDriver : public PatternRewriter,
323-
public RewriterBase::Listener {
322+
class GreedyPatternRewriteDriver : public RewriterBase::Listener {
324323
protected:
325324
explicit GreedyPatternRewriteDriver(MLIRContext *ctx,
326325
const FrozenRewritePatternSet &patterns,
@@ -339,7 +338,8 @@ class GreedyPatternRewriteDriver : public PatternRewriter,
339338
/// Notify the driver that the specified operation was inserted. Update the
340339
/// worklist as needed: The operation is enqueued depending on scope and
341340
/// strict mode.
342-
void notifyOperationInserted(Operation *op, InsertPoint previous) override;
341+
void notifyOperationInserted(Operation *op,
342+
OpBuilder::InsertPoint previous) override;
343343

344344
/// Notify the driver that the specified operation was removed. Update the
345345
/// worklist as needed: The operation and its children are removed from the
@@ -354,6 +354,10 @@ class GreedyPatternRewriteDriver : public PatternRewriter,
354354
/// reached. Return `true` if any IR was changed.
355355
bool processWorklist();
356356

357+
/// The pattern rewriter that is used for making IR modifications and is
358+
/// passed to rewrite patterns.
359+
PatternRewriter rewriter;
360+
357361
/// The worklist for this transformation keeps track of the operations that
358362
/// need to be (re)visited.
359363
#ifdef MLIR_GREEDY_REWRITE_RANDOMIZER_SEED
@@ -407,7 +411,7 @@ class GreedyPatternRewriteDriver : public PatternRewriter,
407411
GreedyPatternRewriteDriver::GreedyPatternRewriteDriver(
408412
MLIRContext *ctx, const FrozenRewritePatternSet &patterns,
409413
const GreedyRewriteConfig &config)
410-
: PatternRewriter(ctx), config(config), matcher(patterns)
414+
: rewriter(ctx), config(config), matcher(patterns)
411415
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
412416
// clang-format off
413417
, expensiveChecks(
@@ -423,9 +427,9 @@ GreedyPatternRewriteDriver::GreedyPatternRewriteDriver(
423427
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
424428
// Send IR notifications to the debug handler. This handler will then forward
425429
// all notifications to this GreedyPatternRewriteDriver.
426-
setListener(&expensiveChecks);
430+
rewriter.setListener(&expensiveChecks);
427431
#else
428-
setListener(this);
432+
rewriter.setListener(this);
429433
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
430434
}
431435

@@ -473,7 +477,7 @@ bool GreedyPatternRewriteDriver::processWorklist() {
473477

474478
// If the operation is trivially dead - remove it.
475479
if (isOpTriviallyDead(op)) {
476-
eraseOp(op);
480+
rewriter.eraseOp(op);
477481
changed = true;
478482

479483
LLVM_DEBUG(logResultWithLine("success", "operation is trivially dead"));
@@ -505,8 +509,8 @@ bool GreedyPatternRewriteDriver::processWorklist() {
505509
// Op results can be replaced with `foldResults`.
506510
assert(foldResults.size() == op->getNumResults() &&
507511
"folder produced incorrect number of results");
508-
OpBuilder::InsertionGuard g(*this);
509-
setInsertionPoint(op);
512+
OpBuilder::InsertionGuard g(rewriter);
513+
rewriter.setInsertionPoint(op);
510514
SmallVector<Value> replacements;
511515
bool materializationSucceeded = true;
512516
for (auto [ofr, resultType] :
@@ -519,7 +523,7 @@ bool GreedyPatternRewriteDriver::processWorklist() {
519523
}
520524
// Materialize Attributes as SSA values.
521525
Operation *constOp = op->getDialect()->materializeConstant(
522-
*this, ofr.get<Attribute>(), resultType, op->getLoc());
526+
rewriter, ofr.get<Attribute>(), resultType, op->getLoc());
523527

524528
if (!constOp) {
525529
// If materialization fails, cleanup any operations generated for
@@ -532,7 +536,7 @@ bool GreedyPatternRewriteDriver::processWorklist() {
532536
replacementOps.insert(replacement.getDefiningOp());
533537
}
534538
for (Operation *op : replacementOps) {
535-
eraseOp(op);
539+
rewriter.eraseOp(op);
536540
}
537541

538542
materializationSucceeded = false;
@@ -547,7 +551,7 @@ bool GreedyPatternRewriteDriver::processWorklist() {
547551
}
548552

549553
if (materializationSucceeded) {
550-
replaceOp(op, replacements);
554+
rewriter.replaceOp(op, replacements);
551555
changed = true;
552556
LLVM_DEBUG(logSuccessfulFolding(dumpRootOp));
553557
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
@@ -608,7 +612,7 @@ bool GreedyPatternRewriteDriver::processWorklist() {
608612
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
609613

610614
LogicalResult matchResult =
611-
matcher.matchAndRewrite(op, *this, canApply, onFailure, onSuccess);
615+
matcher.matchAndRewrite(op, rewriter, canApply, onFailure, onSuccess);
612616

613617
if (succeeded(matchResult)) {
614618
LLVM_DEBUG(logResultWithLine("success", "pattern matched"));
@@ -664,8 +668,8 @@ void GreedyPatternRewriteDriver::notifyBlockErased(Block *block) {
664668
config.listener->notifyBlockErased(block);
665669
}
666670

667-
void GreedyPatternRewriteDriver::notifyOperationInserted(Operation *op,
668-
InsertPoint previous) {
671+
void GreedyPatternRewriteDriver::notifyOperationInserted(
672+
Operation *op, OpBuilder::InsertPoint previous) {
669673
LLVM_DEBUG({
670674
logger.startLine() << "** Insert : '" << op->getName() << "'(" << op
671675
<< ")\n";
@@ -822,7 +826,7 @@ class GreedyPatternRewriteIteration
822826
LogicalResult RegionPatternRewriteDriver::simplify(bool *changed) && {
823827
bool continueRewrites = false;
824828
int64_t iteration = 0;
825-
MLIRContext *ctx = getContext();
829+
MLIRContext *ctx = rewriter.getContext();
826830
do {
827831
// Check if the iteration limit was reached.
828832
if (++iteration > config.maxIterations &&
@@ -834,7 +838,7 @@ LogicalResult RegionPatternRewriteDriver::simplify(bool *changed) && {
834838

835839
// `OperationFolder` CSE's constant ops (and may move them into parents
836840
// regions to enable more aggressive CSE'ing).
837-
OperationFolder folder(getContext(), this);
841+
OperationFolder folder(ctx, this);
838842
auto insertKnownConstant = [&](Operation *op) {
839843
// Check for existing constants when populating the worklist. This avoids
840844
// accidentally reversing the constant order during processing.
@@ -872,7 +876,7 @@ LogicalResult RegionPatternRewriteDriver::simplify(bool *changed) && {
872876
// After applying patterns, make sure that the CFG of each of the
873877
// regions is kept up to date.
874878
if (config.enableRegionSimplification)
875-
continueRewrites |= succeeded(simplifyRegions(*this, region));
879+
continueRewrites |= succeeded(simplifyRegions(rewriter, region));
876880
},
877881
{&region}, iteration);
878882
} while (continueRewrites);

0 commit comments

Comments
 (0)