12
12
13
13
#ifndef MLIR_TOOLS_MLIRQUERY_MATCHERS_EXTRAMATCHERS_H
14
14
#define MLIR_TOOLS_MLIRQUERY_MATCHERS_EXTRAMATCHERS_H
15
+
15
16
#include " mlir/Analysis/SliceAnalysis.h"
16
- #include " mlir/Query/Matcher/MatchersInternal.h"
17
17
18
18
// / A matcher encapsulating the initial `getBackwardSlice` method from
19
- // / SliceAnalysis.h
19
+ // / SliceAnalysis.h.
20
20
// / Additionally, it limits the slice computation to a certain depth level using
21
- // / a custom filter
21
+ // / a custom filter.
22
22
// /
23
- // / Example starting from node 9, assuming the matcher
24
- // / computes the slice for the first two depth levels
23
+ // / Example: starting from node 9, assuming the matcher
24
+ // / computes the slice for the first two depth levels:
25
25
// / ============================
26
26
// / 1 2 3 4
27
27
// / |_______| |______|
37
37
// / Assuming all local orders match the numbering order:
38
38
// / {5, 7, 6, 8, 9}
39
39
namespace mlir ::query::matcher {
40
+
41
+ template <typename Matcher>
40
42
class BackwardSliceMatcher {
41
43
public:
42
- explicit BackwardSliceMatcher (query::matcher::DynMatcher && innerMatcher,
43
- int64_t maxDepth , bool inclusive ,
44
- bool omitBlockArguments, bool omitUsesFromAbove)
44
+ explicit BackwardSliceMatcher (Matcher innerMatcher, int64_t maxDepth ,
45
+ bool inclusive , bool omitBlockArguments ,
46
+ bool omitUsesFromAbove)
45
47
: innerMatcher(std::move(innerMatcher)), maxDepth(maxDepth),
46
48
inclusive(inclusive), omitBlockArguments(omitBlockArguments),
47
49
omitUsesFromAbove(omitUsesFromAbove) {}
50
+
48
51
bool match (Operation *op, SetVector<Operation *> &backwardSlice) {
49
52
BackwardSliceOptions options;
53
+ options.inclusive = inclusive;
54
+ options.omitUsesFromAbove = omitUsesFromAbove;
55
+ options.omitBlockArguments = omitBlockArguments;
50
56
return (innerMatcher.match (op) &&
51
57
matches (op, backwardSlice, options, maxDepth));
52
58
}
@@ -59,27 +65,73 @@ class BackwardSliceMatcher {
59
65
// The outer matcher (e.g., BackwardSliceMatcher) relies on the innerMatcher
60
66
// to determine whether we want to traverse the DAG or not. For example, we
61
67
// want to explore the DAG only if the top-level operation name is
62
- // "arith.addf".
63
- query::matcher::DynMatcher innerMatcher;
64
- // maxDepth specifies the maximum depth that the matcher can traverse in the
65
- // DAG. For example, if maxDepth is 2, the matcher will explore the defining
68
+ // ` "arith.addf"` .
69
+ Matcher innerMatcher;
70
+ // ` maxDepth` specifies the maximum depth that the matcher can traverse in the
71
+ // DAG. For example, if ` maxDepth` is 2, the matcher will explore the defining
66
72
// operations of the top-level op up to 2 levels.
67
73
int64_t maxDepth;
68
-
69
74
bool inclusive;
70
75
bool omitBlockArguments;
71
76
bool omitUsesFromAbove;
72
77
};
73
78
74
- // Matches transitive defs of a top level operation up to N levels
75
- inline BackwardSliceMatcher
76
- m_GetDefinitions (query::matcher::DynMatcher innerMatcher, int64_t maxDepth,
77
- bool inclusive, bool omitBlockArguments,
78
- bool omitUsesFromAbove) {
79
+ template <typename Matcher>
80
+ bool BackwardSliceMatcher<Matcher>::matches(
81
+ Operation *rootOp, llvm::SetVector<Operation *> &backwardSlice,
82
+ BackwardSliceOptions &options, int64_t maxDepth) {
83
+ backwardSlice.clear ();
84
+ llvm::DenseMap<Operation *, int64_t > opDepths;
85
+ // The starting point is the root op; therefore, we set its depth to 0.
86
+ opDepths[rootOp] = 0 ;
87
+ options.filter = [&](Operation *subOp) {
88
+ // If the subOp's depth exceeds maxDepth, we stop further slicing for this
89
+ // branch.
90
+ if (opDepths[subOp] > maxDepth)
91
+ return false ;
92
+ // Examine subOp's operands to compute depths of their defining operations.
93
+ for (auto operand : subOp->getOperands ()) {
94
+ if (auto definingOp = operand.getDefiningOp ()) {
95
+ // Set the defining operation's depth to one level greater than
96
+ // subOp's depth.
97
+ int64_t newDepth = opDepths[subOp] + 1 ;
98
+ if (!opDepths.contains (definingOp)) {
99
+ opDepths[definingOp] = newDepth;
100
+ } else {
101
+ opDepths[definingOp] = std::min (opDepths[definingOp], newDepth);
102
+ }
103
+ return !(opDepths[subOp] > maxDepth);
104
+ } else {
105
+ auto blockArgument = cast<BlockArgument>(operand);
106
+ Operation *parentOp = blockArgument.getOwner ()->getParentOp ();
107
+ if (!parentOp)
108
+ continue ;
109
+ int64_t newDepth = opDepths[subOp] + 1 ;
110
+ if (!opDepths.contains (parentOp)) {
111
+ opDepths[parentOp] = newDepth;
112
+ } else {
113
+ opDepths[parentOp] = std::min (opDepths[parentOp], newDepth);
114
+ }
115
+ return !(opDepths[parentOp] > maxDepth);
116
+ }
117
+ }
118
+ return true ;
119
+ };
120
+ getBackwardSlice (rootOp, &backwardSlice, options);
121
+ return true ;
122
+ }
123
+
124
+ // Matches transitive defs of a top-level operation up to N levels.
125
+ template <typename Matcher>
126
+ inline BackwardSliceMatcher<Matcher>
127
+ m_GetDefinitions (Matcher innerMatcher, int64_t maxDepth, bool inclusive,
128
+ bool omitBlockArguments, bool omitUsesFromAbove) {
79
129
assert (maxDepth >= 0 && " maxDepth must be non-negative" );
80
- return BackwardSliceMatcher (std::move (innerMatcher), maxDepth, inclusive,
81
- omitBlockArguments, omitUsesFromAbove);
130
+ return BackwardSliceMatcher<Matcher>(std::move (innerMatcher), maxDepth,
131
+ inclusive, omitBlockArguments,
132
+ omitUsesFromAbove);
82
133
}
134
+
83
135
} // namespace mlir::query::matcher
84
136
85
137
#endif // MLIR_TOOLS_MLIRQUERY_MATCHERS_EXTRAMATCHERS_H
0 commit comments