21
21
#include " mlir/Interfaces/InferIntRangeInterface.h"
22
22
#include " mlir/Query/Matcher/MatchersInternal.h"
23
23
#include " mlir/Query/Query.h"
24
- #include " llvm/ADT/SetVector.h"
25
-
26
24
namespace mlir {
27
25
28
26
namespace detail {
@@ -366,21 +364,14 @@ struct RecursivePatternMatcher {
366
364
std::tuple<OperandMatchers...> operandMatchers;
367
365
};
368
366
369
- // / Fills `backwardSlice` with the computed backward slice (i.e.
370
- // / all the transitive defs of op)
371
- // /
372
- // / The implementation traverses the def chains in postorder traversal for
373
- // / efficiency reasons: if an operation is already in `backwardSlice`, no
374
- // / need to traverse its definitions again. Since use-def chains form a DAG,
375
- // / this terminates.
376
- // /
377
- // / Upon return to the root call, `backwardSlice` is filled with a
378
- // / postorder list of defs. This happens to be a topological order, from the
379
- // / point of view of the use-def chains.
367
+ // / A matcher encapsulating the initial `getBackwardSlice` method from
368
+ // / SliceAnalysis.h
369
+ // / Additionally, it limits the slice computation to a certain depth level using
370
+ // / a custom filter
380
371
// /
381
- // / Example starting from node 8
372
+ // / Example starting from node 9, assuming the matcher
373
+ // / computes the slice for the first two depth levels
382
374
// / ============================
383
- // /
384
375
// / 1 2 3 4
385
376
// / |_______| |______|
386
377
// / | | |
@@ -393,240 +384,52 @@ struct RecursivePatternMatcher {
393
384
// / 9
394
385
// /
395
386
// / Assuming all local orders match the numbering order:
396
- // / {1, 2, 5, 3, 4, 6}
397
- // /
398
-
387
+ // / {5, 7, 6, 8, 9}
399
388
class BackwardSliceMatcher {
400
389
public:
401
- BackwardSliceMatcher (mlir:: query::matcher::DynMatcher &&innerMatcher,
390
+ BackwardSliceMatcher (query::matcher::DynMatcher &&innerMatcher,
402
391
int64_t maxDepth)
403
392
: innerMatcher(std::move(innerMatcher)), maxDepth(maxDepth) {}
404
-
405
393
bool match (Operation *op, SetVector<Operation *> &backwardSlice,
406
- mlir:: query::QueryOptions &options) {
394
+ query::QueryOptions &options) {
407
395
408
396
if (innerMatcher.match (op) &&
409
397
matches (op, backwardSlice, options, maxDepth)) {
410
- if (!options.inclusive ) {
411
- // Don't insert the top level operation, we just queried on it and don't
412
- // want it in the results.
413
- backwardSlice.remove (op);
414
- }
415
398
return true ;
416
399
}
417
400
return false ;
418
401
}
419
402
420
403
private:
421
- bool matches (Operation *op, SetVector<Operation *> &backwardSlice,
422
- mlir::query::QueryOptions &options, int64_t remainingDepth) {
423
-
424
- if (op->hasTrait <OpTrait::IsIsolatedFromAbove>()) {
425
- return false ;
426
- }
427
-
428
- auto processValue = [&](Value value) {
429
- // We need to check the current depth level;
430
- // if we have reached level 0, we stop further traversing
431
- if (remainingDepth == 0 ) {
432
- return ;
433
- }
434
- if (auto *definingOp = value.getDefiningOp ()) {
435
- // We omit traversing the same operations
436
- if (backwardSlice.count (definingOp) == 0 )
437
- matches (definingOp, backwardSlice, options, remainingDepth - 1 );
438
- } else if (auto blockArg = dyn_cast<BlockArgument>(value)) {
439
- if (options.omitBlockArguments )
440
- return ;
441
- Block *block = blockArg.getOwner ();
442
-
443
- Operation *parentOp = block->getParentOp ();
444
- // TODO: determine whether we want to recurse backward into the other
445
- // blocks of parentOp, which are not technically backward unless they
446
- // flow into us. For now, just bail.
447
- if (parentOp && backwardSlice.count (parentOp) == 0 ) {
448
- if (parentOp->getNumRegions () != 1 &&
449
- parentOp->getRegion (0 ).getBlocks ().size () != 1 ) {
450
- llvm::errs ()
451
- << " Error: Expected parentOp to have exactly one region and "
452
- << " exactly one block, but found " << parentOp->getNumRegions ()
453
- << " regions and "
454
- << (parentOp->getRegion (0 ).getBlocks ().size ()) << " blocks.\n " ;
455
- };
456
- matches (parentOp, backwardSlice, options, remainingDepth - 1 );
457
- }
458
- } else {
459
- llvm_unreachable (" No definingOp and not a block argument\n " );
460
- return ;
461
- }
462
- };
463
-
464
- if (!options.omitUsesFromAbove ) {
465
- llvm::for_each (op->getRegions (), [&](Region ®ion) {
466
- // Walk this region recursively to collect the regions that descend from
467
- // this op's nested regions (inclusive).
468
- SmallPtrSet<Region *, 4 > descendents;
469
- region.walk (
470
- [&](Region *childRegion) { descendents.insert (childRegion); });
471
- region.walk ([&](Operation *op) {
472
- for (OpOperand &operand : op->getOpOperands ()) {
473
- if (!descendents.contains (operand.get ().getParentRegion ()))
474
- processValue (operand.get ());
475
- }
476
- });
477
- });
478
- }
479
-
480
- llvm::for_each (op->getOperands (), processValue);
481
- backwardSlice.insert (op);
482
- return true ;
483
- }
404
+ bool matches (Operation *rootOp, llvm::SetVector<Operation *> &backwardSlice,
405
+ query::QueryOptions &options, int64_t maxDepth);
484
406
485
407
private:
486
408
// The outer matcher (e.g., BackwardSliceMatcher) relies on the innerMatcher
487
409
// to determine whether we want to traverse the DAG or not. For example, we
488
410
// want to explore the DAG only if the top-level operation name is
489
411
// "arith.addf".
490
- mlir::query::matcher::DynMatcher innerMatcher;
491
-
412
+ query::matcher::DynMatcher innerMatcher;
492
413
// maxDepth specifies the maximum depth that the matcher can traverse in the
493
414
// DAG. For example, if maxDepth is 2, the matcher will explore the defining
494
415
// operations of the top-level op up to 2 levels.
495
416
int64_t maxDepth;
496
417
};
497
-
498
- // / Fills `forwardSlice` with the computed forward slice (i.e. all
499
- // / the transitive uses of op)
500
- // /
501
- // /
502
- // / The implementation traverses the use chains in postorder traversal for
503
- // / efficiency reasons: if an operation is already in `forwardSlice`, no
504
- // / need to traverse its uses again. Since use-def chains form a DAG, this
505
- // / terminates.
506
- // /
507
- // / Upon return to the root call, `forwardSlice` is filled with a
508
- // / postorder list of uses (i.e. a reverse topological order). To get a proper
509
- // / topological order, we just reverse the order in `forwardSlice` before
510
- // / returning.
511
- // /
512
- // / Example starting from node 0
513
- // / ============================
514
- // /
515
- // / 0
516
- // / ___________|___________
517
- // / 1 2 3 4
518
- // / |_______| |______|
519
- // / | | |
520
- // / | 5 6
521
- // / |___|_____________|
522
- // / | |
523
- // / 7 8
524
- // / |_______________|
525
- // / |
526
- // / 9
527
- // /
528
- // / Assuming all local orders match the numbering order:
529
- // / 1. after getting back to the root getForwardSlice, `forwardSlice` may
530
- // / contain:
531
- // / {9, 7, 8, 5, 1, 2, 6, 3, 4}
532
- // / 2. reversing the result of 1. gives:
533
- // / {4, 3, 6, 2, 1, 5, 8, 7, 9}
534
- // /
535
- class ForwardSliceMatcher {
536
- public:
537
- ForwardSliceMatcher (mlir::query::matcher::DynMatcher &&innerMatcher,
538
- int64_t maxDepth)
539
- : innerMatcher(std::move(innerMatcher)), maxDepth(maxDepth) {}
540
-
541
- bool match (Operation *op, SetVector<Operation *> &forwardSlice,
542
- mlir::query::QueryOptions &options) {
543
- if (innerMatcher.match (op) &&
544
- matches (op, forwardSlice, options, maxDepth)) {
545
- if (!options.inclusive ) {
546
- // Don't insert the top level operation, we just queried on it and don't
547
- // want it in the results.
548
- forwardSlice.remove (op);
549
- }
550
- // Reverse to get back the actual topological order.
551
- // std::reverse does not work out of the box on SetVector and I want an
552
- // in-place swap based thing (the real std::reverse, not the LLVM
553
- // adapter).
554
- SmallVector<Operation *, 0 > v (forwardSlice.takeVector ());
555
- forwardSlice.insert (v.rbegin (), v.rend ());
556
- return true ;
557
- }
558
- return false ;
559
- }
560
-
561
- private:
562
- bool matches (Operation *op, SetVector<Operation *> &forwardSlice,
563
- mlir::query::QueryOptions &options, int64_t remainingDepth) {
564
-
565
- // We need to check the current depth level;
566
- // if we have reached level 0, we stop further traversing and insert
567
- // the last user in def-use chain
568
- if (remainingDepth == 0 ) {
569
- forwardSlice.insert (op);
570
- return true ;
571
- }
572
-
573
- for (Region ®ion : op->getRegions ())
574
- for (Block &block : region)
575
- for (Operation &blockOp : block)
576
- if (forwardSlice.count (&blockOp) == 0 )
577
- matches (&blockOp, forwardSlice, options, remainingDepth - 1 );
578
- for (Value result : op->getResults ()) {
579
- for (Operation *userOp : result.getUsers ())
580
- // We omit traversing the same operations
581
- if (forwardSlice.count (userOp) == 0 )
582
- matches (userOp, forwardSlice, options, remainingDepth - 1 );
583
- }
584
-
585
- forwardSlice.insert (op);
586
- return true ;
587
- }
588
-
589
- private:
590
- // The outer matcher e.g (ForwardSliceMatcher) relies on the innerMatcher to
591
- // determine whether we want to traverse the graph or not. E.g: we want to
592
- // explore the DAG only if the top level operation name is "arith.addf"
593
- mlir::query::matcher::DynMatcher innerMatcher;
594
-
595
- // maxDepth specifies the maximum depth that the matcher can traverse the
596
- // graph E.g: if maxDepth is 2, the matcher will explore the user
597
- // operations of the top level op up to 2 levels
598
- int64_t maxDepth;
599
- };
600
-
601
418
} // namespace detail
602
419
603
420
// Matches transitive defs of a top level operation up to 1 level
604
421
inline detail::BackwardSliceMatcher
605
- m_DefinedBy (mlir:: query::matcher::DynMatcher innerMatcher) {
422
+ m_DefinedBy (query::matcher::DynMatcher innerMatcher) {
606
423
return detail::BackwardSliceMatcher (std::move (innerMatcher), 1 );
607
424
}
608
425
609
426
// Matches transitive defs of a top level operation up to N levels
610
427
inline detail::BackwardSliceMatcher
611
- m_GetDefinitions (mlir::query::matcher::DynMatcher innerMatcher,
612
- int64_t maxDepth) {
428
+ m_GetDefinitions (query::matcher::DynMatcher innerMatcher, int64_t maxDepth) {
613
429
assert (maxDepth >= 0 && " maxDepth must be non-negative" );
614
430
return detail::BackwardSliceMatcher (std::move (innerMatcher), maxDepth);
615
431
}
616
432
617
- // Matches uses of a top level operation up to 1 level
618
- inline detail::ForwardSliceMatcher
619
- m_UsedBy (mlir::query::matcher::DynMatcher innerMatcher) {
620
- return detail::ForwardSliceMatcher (std::move (innerMatcher), 1 );
621
- }
622
-
623
- // Matches uses of a top level operation up to N levels
624
- inline detail::ForwardSliceMatcher
625
- m_GetUses (mlir::query::matcher::DynMatcher innerMatcher, int64_t maxDepth) {
626
- assert (maxDepth >= 0 && " maxDepth must be non-negative" );
627
- return detail::ForwardSliceMatcher (std::move (innerMatcher), maxDepth);
628
- }
629
-
630
433
// / Matches a constant foldable operation.
631
434
inline detail::constant_op_matcher m_Constant () {
632
435
return detail::constant_op_matcher ();
0 commit comments