|
19 | 19 | #include "mlir/IR/BuiltinTypes.h"
|
20 | 20 | #include "mlir/IR/OpDefinition.h"
|
21 | 21 | #include "mlir/Interfaces/InferIntRangeInterface.h"
|
| 22 | +#include "mlir/Query/Matcher/MatchersInternal.h" |
| 23 | +#include "mlir/Query/Query.h" |
| 24 | +#include "llvm/ADT/SetVector.h" |
22 | 25 |
|
23 | 26 | namespace mlir {
|
24 | 27 |
|
@@ -363,8 +366,267 @@ struct RecursivePatternMatcher {
|
363 | 366 | std::tuple<OperandMatchers...> operandMatchers;
|
364 | 367 | };
|
365 | 368 |
|
| 369 | +/// Fills `backwardSlice` with the computed backward slice (i.e. |
| 370 | +/// all the transitive defs of op) |
| 371 | +/// |
| 372 | +/// The implementation traverses the def chains in postorder traversal for |
| 373 | +/// efficiency reasons: if an operation is already in `backwardSlice`, no |
| 374 | +/// need to traverse its definitions again. Since use-def chains form a DAG, |
| 375 | +/// this terminates. |
| 376 | +/// |
| 377 | +/// Upon return to the root call, `backwardSlice` is filled with a |
| 378 | +/// postorder list of defs. This happens to be a topological order, from the |
| 379 | +/// point of view of the use-def chains. |
| 380 | +/// |
| 381 | +/// Example starting from node 8 |
| 382 | +/// ============================ |
| 383 | +/// |
| 384 | +/// 1 2 3 4 |
| 385 | +/// |_______| |______| |
| 386 | +/// | | | |
| 387 | +/// | 5 6 |
| 388 | +/// |___|_____________| |
| 389 | +/// | | |
| 390 | +/// 7 8 |
| 391 | +/// |_______________| |
| 392 | +/// | |
| 393 | +/// 9 |
| 394 | +/// |
| 395 | +/// Assuming all local orders match the numbering order: |
| 396 | +/// {1, 2, 5, 3, 4, 6} |
| 397 | +/// |
| 398 | + |
| 399 | +class BackwardSliceMatcher { |
| 400 | +public: |
| 401 | + BackwardSliceMatcher(mlir::query::matcher::DynMatcher &&innerMatcher, |
| 402 | + int64_t maxDepth) |
| 403 | + : innerMatcher(std::move(innerMatcher)), maxDepth(maxDepth) {} |
| 404 | + |
| 405 | + bool match(Operation *op, SetVector<Operation *> &backwardSlice, |
| 406 | + mlir::query::QueryOptions &options) { |
| 407 | + |
| 408 | + if (innerMatcher.match(op) && |
| 409 | + matches(op, backwardSlice, options, maxDepth)) { |
| 410 | + if (!options.inclusive) { |
| 411 | + // Don't insert the top level operation, we just queried on it and don't |
| 412 | + // want it in the results. |
| 413 | + backwardSlice.remove(op); |
| 414 | + } |
| 415 | + return true; |
| 416 | + } |
| 417 | + return false; |
| 418 | + } |
| 419 | + |
| 420 | +private: |
| 421 | + bool matches(Operation *op, SetVector<Operation *> &backwardSlice, |
| 422 | + mlir::query::QueryOptions &options, int64_t remainingDepth) { |
| 423 | + |
| 424 | + if (op->hasTrait<OpTrait::IsIsolatedFromAbove>()) { |
| 425 | + return false; |
| 426 | + } |
| 427 | + |
| 428 | + auto processValue = [&](Value value) { |
| 429 | + // We need to check the current depth level; |
| 430 | + // if we have reached level 0, we stop further traversing |
| 431 | + if (remainingDepth == 0) { |
| 432 | + return; |
| 433 | + } |
| 434 | + if (auto *definingOp = value.getDefiningOp()) { |
| 435 | + // We omit traversing the same operations |
| 436 | + if (backwardSlice.count(definingOp) == 0) |
| 437 | + matches(definingOp, backwardSlice, options, remainingDepth - 1); |
| 438 | + } else if (auto blockArg = dyn_cast<BlockArgument>(value)) { |
| 439 | + if (options.omitBlockArguments) |
| 440 | + return; |
| 441 | + Block *block = blockArg.getOwner(); |
| 442 | + |
| 443 | + Operation *parentOp = block->getParentOp(); |
| 444 | + // TODO: determine whether we want to recurse backward into the other |
| 445 | + // blocks of parentOp, which are not technically backward unless they |
| 446 | + // flow into us. For now, just bail. |
| 447 | + if (parentOp && backwardSlice.count(parentOp) == 0) { |
| 448 | + if (parentOp->getNumRegions() != 1 && |
| 449 | + parentOp->getRegion(0).getBlocks().size() != 1) { |
| 450 | + llvm::errs() |
| 451 | + << "Error: Expected parentOp to have exactly one region and " |
| 452 | + << "exactly one block, but found " << parentOp->getNumRegions() |
| 453 | + << " regions and " |
| 454 | + << (parentOp->getRegion(0).getBlocks().size()) << " blocks.\n"; |
| 455 | + }; |
| 456 | + matches(parentOp, backwardSlice, options, remainingDepth - 1); |
| 457 | + } |
| 458 | + } else { |
| 459 | + llvm_unreachable("No definingOp and not a block argument\n"); |
| 460 | + return; |
| 461 | + } |
| 462 | + }; |
| 463 | + |
| 464 | + if (!options.omitUsesFromAbove) { |
| 465 | + llvm::for_each(op->getRegions(), [&](Region ®ion) { |
| 466 | + // Walk this region recursively to collect the regions that descend from |
| 467 | + // this op's nested regions (inclusive). |
| 468 | + SmallPtrSet<Region *, 4> descendents; |
| 469 | + region.walk( |
| 470 | + [&](Region *childRegion) { descendents.insert(childRegion); }); |
| 471 | + region.walk([&](Operation *op) { |
| 472 | + for (OpOperand &operand : op->getOpOperands()) { |
| 473 | + if (!descendents.contains(operand.get().getParentRegion())) |
| 474 | + processValue(operand.get()); |
| 475 | + } |
| 476 | + }); |
| 477 | + }); |
| 478 | + } |
| 479 | + |
| 480 | + llvm::for_each(op->getOperands(), processValue); |
| 481 | + backwardSlice.insert(op); |
| 482 | + return true; |
| 483 | + } |
| 484 | + |
| 485 | +private: |
| 486 | + // The outer matcher (e.g., BackwardSliceMatcher) relies on the innerMatcher |
| 487 | + // to determine whether we want to traverse the DAG or not. For example, we |
| 488 | + // want to explore the DAG only if the top-level operation name is |
| 489 | + // "arith.addf". |
| 490 | + mlir::query::matcher::DynMatcher innerMatcher; |
| 491 | + |
| 492 | + // maxDepth specifies the maximum depth that the matcher can traverse in the |
| 493 | + // DAG. For example, if maxDepth is 2, the matcher will explore the defining |
| 494 | + // operations of the top-level op up to 2 levels. |
| 495 | + int64_t maxDepth; |
| 496 | +}; |
| 497 | + |
| 498 | +/// Fills `forwardSlice` with the computed forward slice (i.e. all |
| 499 | +/// the transitive uses of op) |
| 500 | +/// |
| 501 | +/// |
| 502 | +/// The implementation traverses the use chains in postorder traversal for |
| 503 | +/// efficiency reasons: if an operation is already in `forwardSlice`, no |
| 504 | +/// need to traverse its uses again. Since use-def chains form a DAG, this |
| 505 | +/// terminates. |
| 506 | +/// |
| 507 | +/// Upon return to the root call, `forwardSlice` is filled with a |
| 508 | +/// postorder list of uses (i.e. a reverse topological order). To get a proper |
| 509 | +/// topological order, we just reverse the order in `forwardSlice` before |
| 510 | +/// returning. |
| 511 | +/// |
| 512 | +/// Example starting from node 0 |
| 513 | +/// ============================ |
| 514 | +/// |
| 515 | +/// 0 |
| 516 | +/// ___________|___________ |
| 517 | +/// 1 2 3 4 |
| 518 | +/// |_______| |______| |
| 519 | +/// | | | |
| 520 | +/// | 5 6 |
| 521 | +/// |___|_____________| |
| 522 | +/// | | |
| 523 | +/// 7 8 |
| 524 | +/// |_______________| |
| 525 | +/// | |
| 526 | +/// 9 |
| 527 | +/// |
| 528 | +/// Assuming all local orders match the numbering order: |
| 529 | +/// 1. after getting back to the root getForwardSlice, `forwardSlice` may |
| 530 | +/// contain: |
| 531 | +/// {9, 7, 8, 5, 1, 2, 6, 3, 4} |
| 532 | +/// 2. reversing the result of 1. gives: |
| 533 | +/// {4, 3, 6, 2, 1, 5, 8, 7, 9} |
| 534 | +/// |
| 535 | +class ForwardSliceMatcher { |
| 536 | +public: |
| 537 | + ForwardSliceMatcher(mlir::query::matcher::DynMatcher &&innerMatcher, |
| 538 | + int64_t maxDepth) |
| 539 | + : innerMatcher(std::move(innerMatcher)), maxDepth(maxDepth) {} |
| 540 | + |
| 541 | + bool match(Operation *op, SetVector<Operation *> &forwardSlice, |
| 542 | + mlir::query::QueryOptions &options) { |
| 543 | + if (innerMatcher.match(op) && |
| 544 | + matches(op, forwardSlice, options, maxDepth)) { |
| 545 | + if (!options.inclusive) { |
| 546 | + // Don't insert the top level operation, we just queried on it and don't |
| 547 | + // want it in the results. |
| 548 | + forwardSlice.remove(op); |
| 549 | + } |
| 550 | + // Reverse to get back the actual topological order. |
| 551 | + // std::reverse does not work out of the box on SetVector and I want an |
| 552 | + // in-place swap based thing (the real std::reverse, not the LLVM |
| 553 | + // adapter). |
| 554 | + SmallVector<Operation *, 0> v(forwardSlice.takeVector()); |
| 555 | + forwardSlice.insert(v.rbegin(), v.rend()); |
| 556 | + return true; |
| 557 | + } |
| 558 | + return false; |
| 559 | + } |
| 560 | + |
| 561 | +private: |
| 562 | + bool matches(Operation *op, SetVector<Operation *> &forwardSlice, |
| 563 | + mlir::query::QueryOptions &options, int64_t remainingDepth) { |
| 564 | + |
| 565 | + // We need to check the current depth level; |
| 566 | + // if we have reached level 0, we stop further traversing and insert |
| 567 | + // the last user in def-use chain |
| 568 | + if (remainingDepth == 0) { |
| 569 | + forwardSlice.insert(op); |
| 570 | + return true; |
| 571 | + } |
| 572 | + |
| 573 | + for (Region ®ion : op->getRegions()) |
| 574 | + for (Block &block : region) |
| 575 | + for (Operation &blockOp : block) |
| 576 | + if (forwardSlice.count(&blockOp) == 0) |
| 577 | + matches(&blockOp, forwardSlice, options, remainingDepth - 1); |
| 578 | + for (Value result : op->getResults()) { |
| 579 | + for (Operation *userOp : result.getUsers()) |
| 580 | + // We omit traversing the same operations |
| 581 | + if (forwardSlice.count(userOp) == 0) |
| 582 | + matches(userOp, forwardSlice, options, remainingDepth - 1); |
| 583 | + } |
| 584 | + |
| 585 | + forwardSlice.insert(op); |
| 586 | + return true; |
| 587 | + } |
| 588 | + |
| 589 | +private: |
| 590 | + // The outer matcher e.g (ForwardSliceMatcher) relies on the innerMatcher to |
| 591 | + // determine whether we want to traverse the graph or not. E.g: we want to |
| 592 | + // explore the DAG only if the top level operation name is "arith.addf" |
| 593 | + mlir::query::matcher::DynMatcher innerMatcher; |
| 594 | + |
| 595 | + // maxDepth specifies the maximum depth that the matcher can traverse the |
| 596 | + // graph E.g: if maxDepth is 2, the matcher will explore the user |
| 597 | + // operations of the top level op up to 2 levels |
| 598 | + int64_t maxDepth; |
| 599 | +}; |
| 600 | + |
366 | 601 | } // namespace detail
|
367 | 602 |
|
| 603 | +// Matches transitive defs of a top level operation up to 1 level |
| 604 | +inline detail::BackwardSliceMatcher |
| 605 | +m_DefinedBy(mlir::query::matcher::DynMatcher innerMatcher) { |
| 606 | + return detail::BackwardSliceMatcher(std::move(innerMatcher), 1); |
| 607 | +} |
| 608 | + |
| 609 | +// Matches transitive defs of a top level operation up to N levels |
| 610 | +inline detail::BackwardSliceMatcher |
| 611 | +m_GetDefinitions(mlir::query::matcher::DynMatcher innerMatcher, |
| 612 | + int64_t maxDepth) { |
| 613 | + assert(maxDepth >= 0 && "maxDepth must be non-negative"); |
| 614 | + return detail::BackwardSliceMatcher(std::move(innerMatcher), maxDepth); |
| 615 | +} |
| 616 | + |
| 617 | +// Matches uses of a top level operation up to 1 level |
| 618 | +inline detail::ForwardSliceMatcher |
| 619 | +m_UsedBy(mlir::query::matcher::DynMatcher innerMatcher) { |
| 620 | + return detail::ForwardSliceMatcher(std::move(innerMatcher), 1); |
| 621 | +} |
| 622 | + |
| 623 | +// Matches uses of a top level operation up to N levels |
| 624 | +inline detail::ForwardSliceMatcher |
| 625 | +m_GetUses(mlir::query::matcher::DynMatcher innerMatcher, int64_t maxDepth) { |
| 626 | + assert(maxDepth >= 0 && "maxDepth must be non-negative"); |
| 627 | + return detail::ForwardSliceMatcher(std::move(innerMatcher), maxDepth); |
| 628 | +} |
| 629 | + |
368 | 630 | /// Matches a constant foldable operation.
|
369 | 631 | inline detail::constant_op_matcher m_Constant() {
|
370 | 632 | return detail::constant_op_matcher();
|
|
0 commit comments