Skip to content

Commit c51a7b9

Browse files
committed
Implement nested slicing matcher & enhance MatchFinder class
- nested slicing matcher - enhance MatchFinder class - rename getSlice static method to avoid collision with SliceAnalysis::getSlice
1 parent 410c5c9 commit c51a7b9

File tree

15 files changed

+208
-336
lines changed

15 files changed

+208
-336
lines changed

mlir/include/mlir/IR/Matchers.h

Lines changed: 14 additions & 211 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,6 @@
2121
#include "mlir/Interfaces/InferIntRangeInterface.h"
2222
#include "mlir/Query/Matcher/MatchersInternal.h"
2323
#include "mlir/Query/Query.h"
24-
#include "llvm/ADT/SetVector.h"
25-
2624
namespace mlir {
2725

2826
namespace detail {
@@ -366,21 +364,14 @@ struct RecursivePatternMatcher {
366364
std::tuple<OperandMatchers...> operandMatchers;
367365
};
368366

369-
/// Fills `backwardSlice` with the computed backward slice (i.e.
370-
/// all the transitive defs of op)
371-
///
372-
/// The implementation traverses the def chains in postorder traversal for
373-
/// efficiency reasons: if an operation is already in `backwardSlice`, no
374-
/// need to traverse its definitions again. Since use-def chains form a DAG,
375-
/// this terminates.
376-
///
377-
/// Upon return to the root call, `backwardSlice` is filled with a
378-
/// postorder list of defs. This happens to be a topological order, from the
379-
/// point of view of the use-def chains.
367+
/// A matcher encapsulating the initial `getBackwardSlice` method from
368+
/// SliceAnalysis.h
369+
/// Additionally, it limits the slice computation to a certain depth level using
370+
/// a custom filter
380371
///
381-
/// Example starting from node 8
372+
/// Example starting from node 9, assuming the matcher
373+
/// computes the slice for the first two depth levels
382374
/// ============================
383-
///
384375
/// 1 2 3 4
385376
/// |_______| |______|
386377
/// | | |
@@ -393,240 +384,52 @@ struct RecursivePatternMatcher {
393384
/// 9
394385
///
395386
/// Assuming all local orders match the numbering order:
396-
/// {1, 2, 5, 3, 4, 6}
397-
///
398-
387+
/// {5, 7, 6, 8, 9}
399388
class BackwardSliceMatcher {
400389
public:
401-
BackwardSliceMatcher(mlir::query::matcher::DynMatcher &&innerMatcher,
390+
BackwardSliceMatcher(query::matcher::DynMatcher &&innerMatcher,
402391
int64_t maxDepth)
403392
: innerMatcher(std::move(innerMatcher)), maxDepth(maxDepth) {}
404-
405393
bool match(Operation *op, SetVector<Operation *> &backwardSlice,
406-
mlir::query::QueryOptions &options) {
394+
query::QueryOptions &options) {
407395

408396
if (innerMatcher.match(op) &&
409397
matches(op, backwardSlice, options, maxDepth)) {
410-
if (!options.inclusive) {
411-
// Don't insert the top level operation, we just queried on it and don't
412-
// want it in the results.
413-
backwardSlice.remove(op);
414-
}
415398
return true;
416399
}
417400
return false;
418401
}
419402

420403
private:
421-
bool matches(Operation *op, SetVector<Operation *> &backwardSlice,
422-
mlir::query::QueryOptions &options, int64_t remainingDepth) {
423-
424-
if (op->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
425-
return false;
426-
}
427-
428-
auto processValue = [&](Value value) {
429-
// We need to check the current depth level;
430-
// if we have reached level 0, we stop further traversing
431-
if (remainingDepth == 0) {
432-
return;
433-
}
434-
if (auto *definingOp = value.getDefiningOp()) {
435-
// We omit traversing the same operations
436-
if (backwardSlice.count(definingOp) == 0)
437-
matches(definingOp, backwardSlice, options, remainingDepth - 1);
438-
} else if (auto blockArg = dyn_cast<BlockArgument>(value)) {
439-
if (options.omitBlockArguments)
440-
return;
441-
Block *block = blockArg.getOwner();
442-
443-
Operation *parentOp = block->getParentOp();
444-
// TODO: determine whether we want to recurse backward into the other
445-
// blocks of parentOp, which are not technically backward unless they
446-
// flow into us. For now, just bail.
447-
if (parentOp && backwardSlice.count(parentOp) == 0) {
448-
if (parentOp->getNumRegions() != 1 &&
449-
parentOp->getRegion(0).getBlocks().size() != 1) {
450-
llvm::errs()
451-
<< "Error: Expected parentOp to have exactly one region and "
452-
<< "exactly one block, but found " << parentOp->getNumRegions()
453-
<< " regions and "
454-
<< (parentOp->getRegion(0).getBlocks().size()) << " blocks.\n";
455-
};
456-
matches(parentOp, backwardSlice, options, remainingDepth - 1);
457-
}
458-
} else {
459-
llvm_unreachable("No definingOp and not a block argument\n");
460-
return;
461-
}
462-
};
463-
464-
if (!options.omitUsesFromAbove) {
465-
llvm::for_each(op->getRegions(), [&](Region &region) {
466-
// Walk this region recursively to collect the regions that descend from
467-
// this op's nested regions (inclusive).
468-
SmallPtrSet<Region *, 4> descendents;
469-
region.walk(
470-
[&](Region *childRegion) { descendents.insert(childRegion); });
471-
region.walk([&](Operation *op) {
472-
for (OpOperand &operand : op->getOpOperands()) {
473-
if (!descendents.contains(operand.get().getParentRegion()))
474-
processValue(operand.get());
475-
}
476-
});
477-
});
478-
}
479-
480-
llvm::for_each(op->getOperands(), processValue);
481-
backwardSlice.insert(op);
482-
return true;
483-
}
404+
bool matches(Operation *rootOp, llvm::SetVector<Operation *> &backwardSlice,
405+
query::QueryOptions &options, int64_t maxDepth);
484406

485407
private:
486408
// The outer matcher (e.g., BackwardSliceMatcher) relies on the innerMatcher
487409
// to determine whether we want to traverse the DAG or not. For example, we
488410
// want to explore the DAG only if the top-level operation name is
489411
// "arith.addf".
490-
mlir::query::matcher::DynMatcher innerMatcher;
491-
412+
query::matcher::DynMatcher innerMatcher;
492413
// maxDepth specifies the maximum depth that the matcher can traverse in the
493414
// DAG. For example, if maxDepth is 2, the matcher will explore the defining
494415
// operations of the top-level op up to 2 levels.
495416
int64_t maxDepth;
496417
};
497-
498-
/// Fills `forwardSlice` with the computed forward slice (i.e. all
499-
/// the transitive uses of op)
500-
///
501-
///
502-
/// The implementation traverses the use chains in postorder traversal for
503-
/// efficiency reasons: if an operation is already in `forwardSlice`, no
504-
/// need to traverse its uses again. Since use-def chains form a DAG, this
505-
/// terminates.
506-
///
507-
/// Upon return to the root call, `forwardSlice` is filled with a
508-
/// postorder list of uses (i.e. a reverse topological order). To get a proper
509-
/// topological order, we just reverse the order in `forwardSlice` before
510-
/// returning.
511-
///
512-
/// Example starting from node 0
513-
/// ============================
514-
///
515-
/// 0
516-
/// ___________|___________
517-
/// 1 2 3 4
518-
/// |_______| |______|
519-
/// | | |
520-
/// | 5 6
521-
/// |___|_____________|
522-
/// | |
523-
/// 7 8
524-
/// |_______________|
525-
/// |
526-
/// 9
527-
///
528-
/// Assuming all local orders match the numbering order:
529-
/// 1. after getting back to the root getForwardSlice, `forwardSlice` may
530-
/// contain:
531-
/// {9, 7, 8, 5, 1, 2, 6, 3, 4}
532-
/// 2. reversing the result of 1. gives:
533-
/// {4, 3, 6, 2, 1, 5, 8, 7, 9}
534-
///
535-
class ForwardSliceMatcher {
536-
public:
537-
ForwardSliceMatcher(mlir::query::matcher::DynMatcher &&innerMatcher,
538-
int64_t maxDepth)
539-
: innerMatcher(std::move(innerMatcher)), maxDepth(maxDepth) {}
540-
541-
bool match(Operation *op, SetVector<Operation *> &forwardSlice,
542-
mlir::query::QueryOptions &options) {
543-
if (innerMatcher.match(op) &&
544-
matches(op, forwardSlice, options, maxDepth)) {
545-
if (!options.inclusive) {
546-
// Don't insert the top level operation, we just queried on it and don't
547-
// want it in the results.
548-
forwardSlice.remove(op);
549-
}
550-
// Reverse to get back the actual topological order.
551-
// std::reverse does not work out of the box on SetVector and I want an
552-
// in-place swap based thing (the real std::reverse, not the LLVM
553-
// adapter).
554-
SmallVector<Operation *, 0> v(forwardSlice.takeVector());
555-
forwardSlice.insert(v.rbegin(), v.rend());
556-
return true;
557-
}
558-
return false;
559-
}
560-
561-
private:
562-
bool matches(Operation *op, SetVector<Operation *> &forwardSlice,
563-
mlir::query::QueryOptions &options, int64_t remainingDepth) {
564-
565-
// We need to check the current depth level;
566-
// if we have reached level 0, we stop further traversing and insert
567-
// the last user in def-use chain
568-
if (remainingDepth == 0) {
569-
forwardSlice.insert(op);
570-
return true;
571-
}
572-
573-
for (Region &region : op->getRegions())
574-
for (Block &block : region)
575-
for (Operation &blockOp : block)
576-
if (forwardSlice.count(&blockOp) == 0)
577-
matches(&blockOp, forwardSlice, options, remainingDepth - 1);
578-
for (Value result : op->getResults()) {
579-
for (Operation *userOp : result.getUsers())
580-
// We omit traversing the same operations
581-
if (forwardSlice.count(userOp) == 0)
582-
matches(userOp, forwardSlice, options, remainingDepth - 1);
583-
}
584-
585-
forwardSlice.insert(op);
586-
return true;
587-
}
588-
589-
private:
590-
// The outer matcher e.g (ForwardSliceMatcher) relies on the innerMatcher to
591-
// determine whether we want to traverse the graph or not. E.g: we want to
592-
// explore the DAG only if the top level operation name is "arith.addf"
593-
mlir::query::matcher::DynMatcher innerMatcher;
594-
595-
// maxDepth specifies the maximum depth that the matcher can traverse the
596-
// graph E.g: if maxDepth is 2, the matcher will explore the user
597-
// operations of the top level op up to 2 levels
598-
int64_t maxDepth;
599-
};
600-
601418
} // namespace detail
602419

603420
// Matches transitive defs of a top level operation up to 1 level
604421
inline detail::BackwardSliceMatcher
605-
m_DefinedBy(mlir::query::matcher::DynMatcher innerMatcher) {
422+
m_DefinedBy(query::matcher::DynMatcher innerMatcher) {
606423
return detail::BackwardSliceMatcher(std::move(innerMatcher), 1);
607424
}
608425

609426
// Matches transitive defs of a top level operation up to N levels
610427
inline detail::BackwardSliceMatcher
611-
m_GetDefinitions(mlir::query::matcher::DynMatcher innerMatcher,
612-
int64_t maxDepth) {
428+
m_GetDefinitions(query::matcher::DynMatcher innerMatcher, int64_t maxDepth) {
613429
assert(maxDepth >= 0 && "maxDepth must be non-negative");
614430
return detail::BackwardSliceMatcher(std::move(innerMatcher), maxDepth);
615431
}
616432

617-
// Matches uses of a top level operation up to 1 level
618-
inline detail::ForwardSliceMatcher
619-
m_UsedBy(mlir::query::matcher::DynMatcher innerMatcher) {
620-
return detail::ForwardSliceMatcher(std::move(innerMatcher), 1);
621-
}
622-
623-
// Matches uses of a top level operation up to N levels
624-
inline detail::ForwardSliceMatcher
625-
m_GetUses(mlir::query::matcher::DynMatcher innerMatcher, int64_t maxDepth) {
626-
assert(maxDepth >= 0 && "maxDepth must be non-negative");
627-
return detail::ForwardSliceMatcher(std::move(innerMatcher), maxDepth);
628-
}
629-
630433
/// Matches a constant foldable operation.
631434
inline detail::constant_op_matcher m_Constant() {
632435
return detail::constant_op_matcher();

mlir/include/mlir/Query/Matcher/MatchFinder.h

Lines changed: 26 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -15,86 +15,41 @@
1515
#define MLIR_TOOLS_MLIRQUERY_MATCHER_MATCHERFINDER_H
1616

1717
#include "MatchersInternal.h"
18+
#include "mlir/Query/Query.h"
1819
#include "mlir/Query/QuerySession.h"
19-
#include "llvm/ADT/SetVector.h"
20-
#include "llvm/Support/SourceMgr.h"
21-
#include "llvm/Support/raw_ostream.h"
2220

2321
namespace mlir::query::matcher {
2422

23+
/// A class that provides utilities to find operations in a DAG
2524
class MatchFinder {
2625

2726
public:
28-
//
29-
// getMatches walks the IR and prints operations as soon as it matches them
30-
// if a matcher is to be further extracted into the function, then it does not
31-
// print operations
32-
//
33-
static std::vector<Operation *>
34-
getMatches(Operation *root, QueryOptions &options, DynMatcher matcher,
35-
llvm::raw_ostream &os, QuerySession &qs) {
36-
int matchCount = 0;
37-
bool printMatchingOps = true;
38-
// If matcher is to be extracted to a function, we don't want to print
39-
// matching ops to sdout
40-
if (matcher.hasFunctionName()) {
41-
printMatchingOps = false;
42-
}
43-
std::vector<Operation *> matchedOps;
44-
SetVector<Operation *> tempStorage;
45-
os << "\n";
46-
root->walk([&](Operation *subOp) {
47-
if (matcher.match(subOp)) {
48-
matchedOps.push_back(subOp);
49-
if (printMatchingOps) {
50-
os << "Match #" << ++matchCount << ":\n\n";
51-
printMatch(os, qs, subOp, "root");
52-
}
53-
} else {
54-
SmallVector<Operation *> printingOps;
55-
if (matcher.match(subOp, tempStorage, options)) {
56-
if (printMatchingOps) {
57-
os << "Match #" << ++matchCount << ":\n\n";
58-
}
59-
SmallVector<Operation *> printingOps(tempStorage.takeVector());
60-
for (auto op : printingOps) {
61-
if (printMatchingOps) {
62-
printMatch(os, qs, op, "root");
63-
}
64-
matchedOps.push_back(op);
65-
}
66-
printingOps.clear();
67-
}
68-
}
69-
});
70-
if (printMatchingOps) {
71-
os << matchCount << (matchCount == 1 ? " match.\n\n" : " matches.\n\n");
72-
}
73-
return matchedOps;
74-
}
75-
76-
private:
77-
// Overloaded version that doesn't print the binding
78-
static void printMatch(llvm::raw_ostream &os, QuerySession &qs,
79-
mlir::Operation *op) {
80-
auto fileLoc = op->getLoc()->dyn_cast<FileLineColLoc>();
81-
SMLoc smloc = qs.getSourceManager().FindLocForLineAndColumn(
82-
qs.getBufferId(), fileLoc.getLine(), fileLoc.getColumn());
27+
/// A subclass which preserves the matching information
28+
struct MatchResult {
29+
MatchResult() = default;
30+
MatchResult(Operation *rootOp, std::vector<Operation *> matchedOps);
8331

84-
llvm::SMDiagnostic diag =
85-
qs.getSourceManager().GetMessage(smloc, llvm::SourceMgr::DK_Note,
32+
/// Contains the root operation of the matching environment
33+
Operation *rootOp = nullptr;
8634

87-
"");
88-
diag.print("", os, true, false, true);
89-
}
90-
static void printMatch(llvm::raw_ostream &os, QuerySession &qs,
91-
mlir::Operation *op, const std::string &binding) {
92-
auto fileLoc = op->getLoc()->findInstanceOf<FileLineColLoc>();
93-
auto smloc = qs.getSourceManager().FindLocForLineAndColumn(
94-
qs.getBufferId(), fileLoc.getLine(), fileLoc.getColumn());
95-
qs.getSourceManager().PrintMessage(os, smloc, llvm::SourceMgr::DK_Note,
96-
"\"" + binding + "\" binds here");
97-
}
35+
/// Contains the matching enviroment. This allows the user to easily extract
36+
/// the matched operations
37+
std::vector<Operation *> matchedOps;
38+
};
39+
/// Traverses the DAG and collects the "rootOp" + "matching enviroment" for a
40+
/// given Matcher
41+
std::vector<MatchResult>
42+
collectMatches(Operation *root, QueryOptions &options, DynMatcher matcher,
43+
llvm::raw_ostream &os, QuerySession &qs) const;
44+
/// Prints the matched operation
45+
void printMatch(llvm::raw_ostream &os, QuerySession &qs, Operation *op) const;
46+
/// Labels the matched operation with the given binding (e.g., "root") and
47+
/// prints it
48+
void printMatch(llvm::raw_ostream &os, QuerySession &qs, Operation *op,
49+
const std::string &binding) const;
50+
/// Flattens a vector of MatchResults into a vector of operations
51+
std::vector<Operation *>
52+
flattenMatchedOps(std::vector<MatchResult> &matches) const;
9853
};
9954

10055
} // namespace mlir::query::matcher

0 commit comments

Comments
 (0)