6
6
//
7
7
// ===----------------------------------------------------------------------===//
8
8
//
9
- // This file provides matchers that depend on Query.
9
+ // This file provides matchers for MLIRQuery with more involved pattern-matching
10
+ // logic.
10
11
//
11
12
// ===----------------------------------------------------------------------===//
12
13
@@ -80,46 +81,42 @@ bool BackwardSliceMatcher<Matcher>::matches(
80
81
BackwardSliceOptions &options, int64_t maxDepth) {
81
82
backwardSlice.clear ();
82
83
llvm::DenseMap<Operation *, int64_t > opDepths;
83
- // The starting point is the root op; therefore, we set its depth to 0.
84
+ // Initializing the root op with a depth of 0
84
85
opDepths[rootOp] = 0 ;
85
86
options.filter = [&](Operation *subOp) {
86
- // If the subOp's depth exceeds maxDepth, we stop further slicing for this
87
- // branch .
88
- if (opDepths[ subOp] > maxDepth )
87
+ // If the subOp hasn't been recorded in opDepths, it is deeper than
88
+ // maxDepth .
89
+ if (! opDepths. contains ( subOp) )
89
90
return false ;
90
91
// Examine subOp's operands to compute depths of their defining operations.
91
92
for (auto operand : subOp->getOperands ()) {
93
+ int64_t newDepth = opDepths[subOp] + 1 ;
94
+ // If the newDepth is greater than maxDepth, further computation can be
95
+ // skipped.
96
+ if (newDepth > maxDepth)
97
+ continue ;
98
+
92
99
if (auto definingOp = operand.getDefiningOp ()) {
93
- // Set the defining operation's depth to one level greater than
94
- // subOp's depth.
95
- int64_t newDepth = opDepths[subOp] + 1 ;
96
- if (!opDepths.contains (definingOp)) {
100
+ // Registers the minimum depth
101
+ if (!opDepths.contains (definingOp) || newDepth < opDepths[definingOp])
97
102
opDepths[definingOp] = newDepth;
98
- } else {
99
- opDepths[definingOp] = std::min (opDepths[definingOp], newDepth);
100
- }
101
- return !(opDepths[subOp] > maxDepth);
102
103
} else {
103
104
auto blockArgument = cast<BlockArgument>(operand);
104
105
Operation *parentOp = blockArgument.getOwner ()->getParentOp ();
105
106
if (!parentOp)
106
107
continue ;
107
- int64_t newDepth = opDepths[subOp] + 1 ;
108
- if (!opDepths.contains (parentOp)) {
108
+
109
+ if (!opDepths.contains (parentOp) || newDepth < opDepths[parentOp])
109
110
opDepths[parentOp] = newDepth;
110
- } else {
111
- opDepths[parentOp] = std::min (opDepths[parentOp], newDepth);
112
- }
113
- return !(opDepths[parentOp] > maxDepth);
114
111
}
115
112
}
116
113
return true ;
117
114
};
118
115
getBackwardSlice (rootOp, &backwardSlice, options);
119
- return true ;
116
+ return backwardSlice. size () >= 1 ;
120
117
}
121
118
122
- // Matches transitive defs of a top-level operation up to N levels.
119
+ // / Matches transitive defs of a top-level operation up to N levels.
123
120
template <typename Matcher>
124
121
inline BackwardSliceMatcher<Matcher>
125
122
m_GetDefinitions (Matcher innerMatcher, int64_t maxDepth, bool inclusive,
@@ -130,6 +127,15 @@ m_GetDefinitions(Matcher innerMatcher, int64_t maxDepth, bool inclusive,
130
127
omitUsesFromAbove);
131
128
}
132
129
130
+ // / Matches all transitive defs of a top-level operation up to N levels
131
+ template <typename Matcher>
132
+ inline BackwardSliceMatcher<Matcher> m_GetAllDefinitions (Matcher innerMatcher,
133
+ int64_t maxDepth) {
134
+ assert (maxDepth >= 0 && " maxDepth must be non-negative" );
135
+ return BackwardSliceMatcher<Matcher>(std::move (innerMatcher), maxDepth, true ,
136
+ false , false );
137
+ }
138
+
133
139
} // namespace mlir::query::matcher
134
140
135
141
#endif // MLIR_TOOLS_MLIRQUERY_MATCHERS_EXTRAMATCHERS_H
0 commit comments