Skip to content

Commit 7a9fbd6

Browse files
committed
Enhance matcher and QueryOptions documentation
- Enhance docs for matchers and QueryOptions - Fix whitespace and alignment issues - Move matchers to Matchers.h - Change data type from unsigned to signed for arithmetic operations
1 parent 69dc854 commit 7a9fbd6

File tree

13 files changed

+383
-268
lines changed

13 files changed

+383
-268
lines changed

mlir/include/mlir/IR/Matchers.h

Lines changed: 262 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@
1919
#include "mlir/IR/BuiltinTypes.h"
2020
#include "mlir/IR/OpDefinition.h"
2121
#include "mlir/Interfaces/InferIntRangeInterface.h"
22+
#include "mlir/Query/Matcher/MatchersInternal.h"
23+
#include "mlir/Query/Query.h"
24+
#include "llvm/ADT/SetVector.h"
2225

2326
namespace mlir {
2427

@@ -363,8 +366,267 @@ struct RecursivePatternMatcher {
363366
std::tuple<OperandMatchers...> operandMatchers;
364367
};
365368

369+
/// Fills `backwardSlice` with the computed backward slice (i.e.
370+
/// all the transitive defs of op)
371+
///
372+
/// The implementation traverses the def chains in postorder traversal for
373+
/// efficiency reasons: if an operation is already in `backwardSlice`, no
374+
/// need to traverse its definitions again. Since use-def chains form a DAG,
375+
/// this terminates.
376+
///
377+
/// Upon return to the root call, `backwardSlice` is filled with a
378+
/// postorder list of defs. This happens to be a topological order, from the
379+
/// point of view of the use-def chains.
380+
///
381+
/// Example starting from node 8
382+
/// ============================
383+
///
384+
/// 1 2 3 4
385+
/// |_______| |______|
386+
/// | | |
387+
/// | 5 6
388+
/// |___|_____________|
389+
/// | |
390+
/// 7 8
391+
/// |_______________|
392+
/// |
393+
/// 9
394+
///
395+
/// Assuming all local orders match the numbering order:
396+
/// {1, 2, 5, 3, 4, 6}
397+
///
398+
399+
class BackwardSliceMatcher {
400+
public:
401+
BackwardSliceMatcher(mlir::query::matcher::DynMatcher &&innerMatcher,
402+
int64_t maxDepth)
403+
: innerMatcher(std::move(innerMatcher)), maxDepth(maxDepth) {}
404+
405+
bool match(Operation *op, SetVector<Operation *> &backwardSlice,
406+
mlir::query::QueryOptions &options) {
407+
408+
if (innerMatcher.match(op) &&
409+
matches(op, backwardSlice, options, maxDepth)) {
410+
if (!options.inclusive) {
411+
// Don't insert the top level operation, we just queried on it and don't
412+
// want it in the results.
413+
backwardSlice.remove(op);
414+
}
415+
return true;
416+
}
417+
return false;
418+
}
419+
420+
private:
421+
bool matches(Operation *op, SetVector<Operation *> &backwardSlice,
422+
mlir::query::QueryOptions &options, int64_t remainingDepth) {
423+
424+
if (op->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
425+
return false;
426+
}
427+
428+
auto processValue = [&](Value value) {
429+
// We need to check the current depth level;
430+
// if we have reached level 0, we stop further traversing
431+
if (remainingDepth == 0) {
432+
return;
433+
}
434+
if (auto *definingOp = value.getDefiningOp()) {
435+
// We omit traversing the same operations
436+
if (backwardSlice.count(definingOp) == 0)
437+
matches(definingOp, backwardSlice, options, remainingDepth - 1);
438+
} else if (auto blockArg = dyn_cast<BlockArgument>(value)) {
439+
if (options.omitBlockArguments)
440+
return;
441+
Block *block = blockArg.getOwner();
442+
443+
Operation *parentOp = block->getParentOp();
444+
// TODO: determine whether we want to recurse backward into the other
445+
// blocks of parentOp, which are not technically backward unless they
446+
// flow into us. For now, just bail.
447+
if (parentOp && backwardSlice.count(parentOp) == 0) {
448+
if (parentOp->getNumRegions() != 1 &&
449+
parentOp->getRegion(0).getBlocks().size() != 1) {
450+
llvm::errs()
451+
<< "Error: Expected parentOp to have exactly one region and "
452+
<< "exactly one block, but found " << parentOp->getNumRegions()
453+
<< " regions and "
454+
<< (parentOp->getRegion(0).getBlocks().size()) << " blocks.\n";
455+
};
456+
matches(parentOp, backwardSlice, options, remainingDepth - 1);
457+
}
458+
} else {
459+
llvm_unreachable("No definingOp and not a block argument\n");
460+
return;
461+
}
462+
};
463+
464+
if (!options.omitUsesFromAbove) {
465+
llvm::for_each(op->getRegions(), [&](Region &region) {
466+
// Walk this region recursively to collect the regions that descend from
467+
// this op's nested regions (inclusive).
468+
SmallPtrSet<Region *, 4> descendents;
469+
region.walk(
470+
[&](Region *childRegion) { descendents.insert(childRegion); });
471+
region.walk([&](Operation *op) {
472+
for (OpOperand &operand : op->getOpOperands()) {
473+
if (!descendents.contains(operand.get().getParentRegion()))
474+
processValue(operand.get());
475+
}
476+
});
477+
});
478+
}
479+
480+
llvm::for_each(op->getOperands(), processValue);
481+
backwardSlice.insert(op);
482+
return true;
483+
}
484+
485+
private:
486+
// The outer matcher (e.g., BackwardSliceMatcher) relies on the innerMatcher
487+
// to determine whether we want to traverse the DAG or not. For example, we
488+
// want to explore the DAG only if the top-level operation name is
489+
// "arith.addf".
490+
mlir::query::matcher::DynMatcher innerMatcher;
491+
492+
// maxDepth specifies the maximum depth that the matcher can traverse in the
493+
// DAG. For example, if maxDepth is 2, the matcher will explore the defining
494+
// operations of the top-level op up to 2 levels.
495+
int64_t maxDepth;
496+
};
497+
498+
/// Fills `forwardSlice` with the computed forward slice (i.e. all
499+
/// the transitive uses of op)
500+
///
501+
///
502+
/// The implementation traverses the use chains in postorder traversal for
503+
/// efficiency reasons: if an operation is already in `forwardSlice`, no
504+
/// need to traverse its uses again. Since use-def chains form a DAG, this
505+
/// terminates.
506+
///
507+
/// Upon return to the root call, `forwardSlice` is filled with a
508+
/// postorder list of uses (i.e. a reverse topological order). To get a proper
509+
/// topological order, we just reverse the order in `forwardSlice` before
510+
/// returning.
511+
///
512+
/// Example starting from node 0
513+
/// ============================
514+
///
515+
/// 0
516+
/// ___________|___________
517+
/// 1 2 3 4
518+
/// |_______| |______|
519+
/// | | |
520+
/// | 5 6
521+
/// |___|_____________|
522+
/// | |
523+
/// 7 8
524+
/// |_______________|
525+
/// |
526+
/// 9
527+
///
528+
/// Assuming all local orders match the numbering order:
529+
/// 1. after getting back to the root getForwardSlice, `forwardSlice` may
530+
/// contain:
531+
/// {9, 7, 8, 5, 1, 2, 6, 3, 4}
532+
/// 2. reversing the result of 1. gives:
533+
/// {4, 3, 6, 2, 1, 5, 8, 7, 9}
534+
///
535+
class ForwardSliceMatcher {
536+
public:
537+
ForwardSliceMatcher(mlir::query::matcher::DynMatcher &&innerMatcher,
538+
int64_t maxDepth)
539+
: innerMatcher(std::move(innerMatcher)), maxDepth(maxDepth) {}
540+
541+
bool match(Operation *op, SetVector<Operation *> &forwardSlice,
542+
mlir::query::QueryOptions &options) {
543+
if (innerMatcher.match(op) &&
544+
matches(op, forwardSlice, options, maxDepth)) {
545+
if (!options.inclusive) {
546+
// Don't insert the top level operation, we just queried on it and don't
547+
// want it in the results.
548+
forwardSlice.remove(op);
549+
}
550+
// Reverse to get back the actual topological order.
551+
// std::reverse does not work out of the box on SetVector and I want an
552+
// in-place swap based thing (the real std::reverse, not the LLVM
553+
// adapter).
554+
SmallVector<Operation *, 0> v(forwardSlice.takeVector());
555+
forwardSlice.insert(v.rbegin(), v.rend());
556+
return true;
557+
}
558+
return false;
559+
}
560+
561+
private:
562+
bool matches(Operation *op, SetVector<Operation *> &forwardSlice,
563+
mlir::query::QueryOptions &options, int64_t remainingDepth) {
564+
565+
// We need to check the current depth level;
566+
// if we have reached level 0, we stop further traversing and insert
567+
// the last user in def-use chain
568+
if (remainingDepth == 0) {
569+
forwardSlice.insert(op);
570+
return true;
571+
}
572+
573+
for (Region &region : op->getRegions())
574+
for (Block &block : region)
575+
for (Operation &blockOp : block)
576+
if (forwardSlice.count(&blockOp) == 0)
577+
matches(&blockOp, forwardSlice, options, remainingDepth - 1);
578+
for (Value result : op->getResults()) {
579+
for (Operation *userOp : result.getUsers())
580+
// We omit traversing the same operations
581+
if (forwardSlice.count(userOp) == 0)
582+
matches(userOp, forwardSlice, options, remainingDepth - 1);
583+
}
584+
585+
forwardSlice.insert(op);
586+
return true;
587+
}
588+
589+
private:
590+
// The outer matcher e.g (ForwardSliceMatcher) relies on the innerMatcher to
591+
// determine whether we want to traverse the graph or not. E.g: we want to
592+
// explore the DAG only if the top level operation name is "arith.addf"
593+
mlir::query::matcher::DynMatcher innerMatcher;
594+
595+
// maxDepth specifies the maximum depth that the matcher can traverse the
596+
// graph E.g: if maxDepth is 2, the matcher will explore the user
597+
// operations of the top level op up to 2 levels
598+
int64_t maxDepth;
599+
};
600+
366601
} // namespace detail
367602

603+
// Matches transitive defs of a top level operation up to 1 level
604+
inline detail::BackwardSliceMatcher
605+
m_DefinedBy(mlir::query::matcher::DynMatcher innerMatcher) {
606+
return detail::BackwardSliceMatcher(std::move(innerMatcher), 1);
607+
}
608+
609+
// Matches transitive defs of a top level operation up to N levels
610+
inline detail::BackwardSliceMatcher
611+
m_GetDefinitions(mlir::query::matcher::DynMatcher innerMatcher,
612+
int64_t maxDepth) {
613+
assert(maxDepth >= 0 && "maxDepth must be non-negative");
614+
return detail::BackwardSliceMatcher(std::move(innerMatcher), maxDepth);
615+
}
616+
617+
// Matches uses of a top level operation up to 1 level
618+
inline detail::ForwardSliceMatcher
619+
m_UsedBy(mlir::query::matcher::DynMatcher innerMatcher) {
620+
return detail::ForwardSliceMatcher(std::move(innerMatcher), 1);
621+
}
622+
623+
// Matches uses of a top level operation up to N levels
624+
inline detail::ForwardSliceMatcher
625+
m_GetUses(mlir::query::matcher::DynMatcher innerMatcher, int64_t maxDepth) {
626+
assert(maxDepth >= 0 && "maxDepth must be non-negative");
627+
return detail::ForwardSliceMatcher(std::move(innerMatcher), maxDepth);
628+
}
629+
368630
/// Matches a constant foldable operation.
369631
inline detail::constant_op_matcher m_Constant() {
370632
return detail::constant_op_matcher();

0 commit comments

Comments
 (0)