@@ -319,8 +319,7 @@ class RandomizedWorklist : public Worklist {
319
319
// / This abstract class manages the worklist and contains helper methods for
320
320
// / rewriting ops on the worklist. Derived classes specify how ops are added
321
321
// / to the worklist in the beginning.
322
- class GreedyPatternRewriteDriver : public PatternRewriter ,
323
- public RewriterBase::Listener {
322
+ class GreedyPatternRewriteDriver : public RewriterBase ::Listener {
324
323
protected:
325
324
explicit GreedyPatternRewriteDriver (MLIRContext *ctx,
326
325
const FrozenRewritePatternSet &patterns,
@@ -339,7 +338,8 @@ class GreedyPatternRewriteDriver : public PatternRewriter,
339
338
// / Notify the driver that the specified operation was inserted. Update the
340
339
// / worklist as needed: The operation is enqueued depending on scope and
341
340
// / strict mode.
342
- void notifyOperationInserted (Operation *op, InsertPoint previous) override ;
341
+ void notifyOperationInserted (Operation *op,
342
+ OpBuilder::InsertPoint previous) override ;
343
343
344
344
// / Notify the driver that the specified operation was removed. Update the
345
345
// / worklist as needed: The operation and its children are removed from the
@@ -354,6 +354,10 @@ class GreedyPatternRewriteDriver : public PatternRewriter,
354
354
// / reached. Return `true` if any IR was changed.
355
355
bool processWorklist ();
356
356
357
+ // / The pattern rewriter that is used for making IR modifications and is
358
+ // / passed to rewrite patterns.
359
+ PatternRewriter rewriter;
360
+
357
361
// / The worklist for this transformation keeps track of the operations that
358
362
// / need to be (re)visited.
359
363
#ifdef MLIR_GREEDY_REWRITE_RANDOMIZER_SEED
@@ -407,7 +411,7 @@ class GreedyPatternRewriteDriver : public PatternRewriter,
407
411
GreedyPatternRewriteDriver::GreedyPatternRewriteDriver (
408
412
MLIRContext *ctx, const FrozenRewritePatternSet &patterns,
409
413
const GreedyRewriteConfig &config)
410
- : PatternRewriter (ctx), config(config), matcher(patterns)
414
+ : rewriter (ctx), config(config), matcher(patterns)
411
415
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
412
416
// clang-format off
413
417
, expensiveChecks(
@@ -423,9 +427,9 @@ GreedyPatternRewriteDriver::GreedyPatternRewriteDriver(
423
427
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
424
428
// Send IR notifications to the debug handler. This handler will then forward
425
429
// all notifications to this GreedyPatternRewriteDriver.
426
- setListener (&expensiveChecks);
430
+ rewriter. setListener (&expensiveChecks);
427
431
#else
428
- setListener (this );
432
+ rewriter. setListener (this );
429
433
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
430
434
}
431
435
@@ -473,7 +477,7 @@ bool GreedyPatternRewriteDriver::processWorklist() {
473
477
474
478
// If the operation is trivially dead - remove it.
475
479
if (isOpTriviallyDead (op)) {
476
- eraseOp (op);
480
+ rewriter. eraseOp (op);
477
481
changed = true ;
478
482
479
483
LLVM_DEBUG (logResultWithLine (" success" , " operation is trivially dead" ));
@@ -505,8 +509,8 @@ bool GreedyPatternRewriteDriver::processWorklist() {
505
509
// Op results can be replaced with `foldResults`.
506
510
assert (foldResults.size () == op->getNumResults () &&
507
511
" folder produced incorrect number of results" );
508
- OpBuilder::InsertionGuard g (* this );
509
- setInsertionPoint (op);
512
+ OpBuilder::InsertionGuard g (rewriter );
513
+ rewriter. setInsertionPoint (op);
510
514
SmallVector<Value> replacements;
511
515
bool materializationSucceeded = true ;
512
516
for (auto [ofr, resultType] :
@@ -519,7 +523,7 @@ bool GreedyPatternRewriteDriver::processWorklist() {
519
523
}
520
524
// Materialize Attributes as SSA values.
521
525
Operation *constOp = op->getDialect ()->materializeConstant (
522
- * this , ofr.get <Attribute>(), resultType, op->getLoc ());
526
+ rewriter , ofr.get <Attribute>(), resultType, op->getLoc ());
523
527
524
528
if (!constOp) {
525
529
// If materialization fails, cleanup any operations generated for
@@ -532,7 +536,7 @@ bool GreedyPatternRewriteDriver::processWorklist() {
532
536
replacementOps.insert (replacement.getDefiningOp ());
533
537
}
534
538
for (Operation *op : replacementOps) {
535
- eraseOp (op);
539
+ rewriter. eraseOp (op);
536
540
}
537
541
538
542
materializationSucceeded = false ;
@@ -547,7 +551,7 @@ bool GreedyPatternRewriteDriver::processWorklist() {
547
551
}
548
552
549
553
if (materializationSucceeded) {
550
- replaceOp (op, replacements);
554
+ rewriter. replaceOp (op, replacements);
551
555
changed = true ;
552
556
LLVM_DEBUG (logSuccessfulFolding (dumpRootOp));
553
557
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
@@ -608,7 +612,7 @@ bool GreedyPatternRewriteDriver::processWorklist() {
608
612
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
609
613
610
614
LogicalResult matchResult =
611
- matcher.matchAndRewrite (op, * this , canApply, onFailure, onSuccess);
615
+ matcher.matchAndRewrite (op, rewriter , canApply, onFailure, onSuccess);
612
616
613
617
if (succeeded (matchResult)) {
614
618
LLVM_DEBUG (logResultWithLine (" success" , " pattern matched" ));
@@ -664,8 +668,8 @@ void GreedyPatternRewriteDriver::notifyBlockErased(Block *block) {
664
668
config.listener ->notifyBlockErased (block);
665
669
}
666
670
667
- void GreedyPatternRewriteDriver::notifyOperationInserted (Operation *op,
668
- InsertPoint previous) {
671
+ void GreedyPatternRewriteDriver::notifyOperationInserted (
672
+ Operation *op, OpBuilder:: InsertPoint previous) {
669
673
LLVM_DEBUG ({
670
674
logger.startLine () << " ** Insert : '" << op->getName () << " '(" << op
671
675
<< " )\n " ;
@@ -822,7 +826,7 @@ class GreedyPatternRewriteIteration
822
826
LogicalResult RegionPatternRewriteDriver::simplify (bool *changed) && {
823
827
bool continueRewrites = false ;
824
828
int64_t iteration = 0 ;
825
- MLIRContext *ctx = getContext ();
829
+ MLIRContext *ctx = rewriter. getContext ();
826
830
do {
827
831
// Check if the iteration limit was reached.
828
832
if (++iteration > config.maxIterations &&
@@ -834,7 +838,7 @@ LogicalResult RegionPatternRewriteDriver::simplify(bool *changed) && {
834
838
835
839
// `OperationFolder` CSE's constant ops (and may move them into parents
836
840
// regions to enable more aggressive CSE'ing).
837
- OperationFolder folder (getContext () , this );
841
+ OperationFolder folder (ctx , this );
838
842
auto insertKnownConstant = [&](Operation *op) {
839
843
// Check for existing constants when populating the worklist. This avoids
840
844
// accidentally reversing the constant order during processing.
@@ -872,7 +876,7 @@ LogicalResult RegionPatternRewriteDriver::simplify(bool *changed) && {
872
876
// After applying patterns, make sure that the CFG of each of the
873
877
// regions is kept up to date.
874
878
if (config.enableRegionSimplification )
875
- continueRewrites |= succeeded (simplifyRegions (* this , region));
879
+ continueRewrites |= succeeded (simplifyRegions (rewriter , region));
876
880
},
877
881
{®ion}, iteration);
878
882
} while (continueRewrites);
0 commit comments