12
12
13
13
#ifndef MLIR_TOOLS_MLIRQUERY_MATCHERS_EXTRAMATCHERS_H
14
14
#define MLIR_TOOLS_MLIRQUERY_MATCHERS_EXTRAMATCHERS_H
15
+
15
16
#include " mlir/Analysis/SliceAnalysis.h"
16
- #include " mlir/Query/Matcher/MatchersInternal.h"
17
17
18
- // / A matcher encapsulating the initial `getBackwardSlice` method from
19
- // / SliceAnalysis.h
18
+ // / A matcher encapsulating `getBackwardSlice` method from SliceAnalysis.h.
20
19
// / Additionally, it limits the slice computation to a certain depth level using
21
- // / a custom filter
20
+ // / a custom filter.
22
21
// /
23
- // / Example starting from node 9, assuming the matcher
24
- // / computes the slice for the first two depth levels
22
+ // / Example: starting from node 9, assuming the matcher
23
+ // / computes the slice for the first two depth levels:
25
24
// / ============================
26
25
// / 1 2 3 4
27
26
// / |_______| |______|
37
36
// / Assuming all local orders match the numbering order:
38
37
// / {5, 7, 6, 8, 9}
39
38
namespace mlir ::query::matcher {
39
+
40
+ template <typename Matcher>
40
41
class BackwardSliceMatcher {
41
42
public:
42
- explicit BackwardSliceMatcher (query::matcher::DynMatcher &&innerMatcher,
43
- int64_t maxDepth, bool inclusive,
44
- bool omitBlockArguments, bool omitUsesFromAbove)
43
+ BackwardSliceMatcher (Matcher innerMatcher, int64_t maxDepth, bool inclusive,
44
+ bool omitBlockArguments, bool omitUsesFromAbove)
45
45
: innerMatcher(std::move(innerMatcher)), maxDepth(maxDepth),
46
46
inclusive (inclusive), omitBlockArguments(omitBlockArguments),
47
47
omitUsesFromAbove(omitUsesFromAbove) {}
48
- bool match (Operation *op, SetVector<Operation *> &backwardSlice) {
48
+
49
+ bool match (Operation *rootOp, SetVector<Operation *> &backwardSlice) {
49
50
BackwardSliceOptions options;
50
- return (innerMatcher.match (op) &&
51
- matches (op, backwardSlice, options, maxDepth));
51
+ options.inclusive = inclusive;
52
+ options.omitUsesFromAbove = omitUsesFromAbove;
53
+ options.omitBlockArguments = omitBlockArguments;
54
+ return (innerMatcher.match (rootOp) &&
55
+ matches (rootOp, backwardSlice, options, maxDepth));
52
56
}
53
57
54
58
private:
@@ -57,29 +61,75 @@ class BackwardSliceMatcher {
57
61
58
62
private:
59
63
// The outer matcher (e.g., BackwardSliceMatcher) relies on the innerMatcher
60
- // to determine whether we want to traverse the DAG or not. For example, we
61
- // want to explore the DAG only if the top-level operation name is
62
- // "arith.addf".
63
- query::matcher::DynMatcher innerMatcher;
64
- // maxDepth specifies the maximum depth that the matcher can traverse in the
65
- // DAG . For example, if maxDepth is 2, the matcher will explore the defining
64
+ // to determine whether we want to traverse the IR or not. For example, we
65
+ // want to explore the IR only if the top-level operation name is
66
+ // ` "arith.addf"` .
67
+ Matcher innerMatcher;
68
+ // ` maxDepth` specifies the maximum depth that the matcher can traverse the
69
+ // IR . For example, if ` maxDepth` is 2, the matcher will explore the defining
66
70
// operations of the top-level op up to 2 levels.
67
71
int64_t maxDepth;
68
-
69
72
bool inclusive;
70
73
bool omitBlockArguments;
71
74
bool omitUsesFromAbove;
72
75
};
73
76
74
- // Matches transitive defs of a top level operation up to N levels
75
- inline BackwardSliceMatcher
76
- m_GetDefinitions (query::matcher::DynMatcher innerMatcher, int64_t maxDepth,
77
- bool inclusive, bool omitBlockArguments,
78
- bool omitUsesFromAbove) {
77
+ template <typename Matcher>
78
+ bool BackwardSliceMatcher<Matcher>::matches(
79
+ Operation *rootOp, llvm::SetVector<Operation *> &backwardSlice,
80
+ BackwardSliceOptions &options, int64_t maxDepth) {
81
+ backwardSlice.clear ();
82
+ llvm::DenseMap<Operation *, int64_t > opDepths;
83
+ // The starting point is the root op; therefore, we set its depth to 0.
84
+ opDepths[rootOp] = 0 ;
85
+ options.filter = [&](Operation *subOp) {
86
+ // If the subOp's depth exceeds maxDepth, we stop further slicing for this
87
+ // branch.
88
+ if (opDepths[subOp] > maxDepth)
89
+ return false ;
90
+ // Examine subOp's operands to compute depths of their defining operations.
91
+ for (auto operand : subOp->getOperands ()) {
92
+ if (auto definingOp = operand.getDefiningOp ()) {
93
+ // Set the defining operation's depth to one level greater than
94
+ // subOp's depth.
95
+ int64_t newDepth = opDepths[subOp] + 1 ;
96
+ if (!opDepths.contains (definingOp)) {
97
+ opDepths[definingOp] = newDepth;
98
+ } else {
99
+ opDepths[definingOp] = std::min (opDepths[definingOp], newDepth);
100
+ }
101
+ return !(opDepths[subOp] > maxDepth);
102
+ } else {
103
+ auto blockArgument = cast<BlockArgument>(operand);
104
+ Operation *parentOp = blockArgument.getOwner ()->getParentOp ();
105
+ if (!parentOp)
106
+ continue ;
107
+ int64_t newDepth = opDepths[subOp] + 1 ;
108
+ if (!opDepths.contains (parentOp)) {
109
+ opDepths[parentOp] = newDepth;
110
+ } else {
111
+ opDepths[parentOp] = std::min (opDepths[parentOp], newDepth);
112
+ }
113
+ return !(opDepths[parentOp] > maxDepth);
114
+ }
115
+ }
116
+ return true ;
117
+ };
118
+ getBackwardSlice (rootOp, &backwardSlice, options);
119
+ return true ;
120
+ }
121
+
122
+ // Matches transitive defs of a top-level operation up to N levels.
123
+ template <typename Matcher>
124
+ inline BackwardSliceMatcher<Matcher>
125
+ m_GetDefinitions (Matcher innerMatcher, int64_t maxDepth, bool inclusive,
126
+ bool omitBlockArguments, bool omitUsesFromAbove) {
79
127
assert (maxDepth >= 0 && " maxDepth must be non-negative" );
80
- return BackwardSliceMatcher (std::move (innerMatcher), maxDepth, inclusive,
81
- omitBlockArguments, omitUsesFromAbove);
128
+ return BackwardSliceMatcher<Matcher>(std::move (innerMatcher), maxDepth,
129
+ inclusive, omitBlockArguments,
130
+ omitUsesFromAbove);
82
131
}
132
+
83
133
} // namespace mlir::query::matcher
84
134
85
135
#endif // MLIR_TOOLS_MLIRQUERY_MATCHERS_EXTRAMATCHERS_H
0 commit comments