Skip to content

Commit d637929

Browse files
committed
Improve depth limiting approach
1 parent 5f940da commit d637929

File tree

3 files changed

+31
-22
lines changed

3 files changed

+31
-22
lines changed

mlir/include/mlir/Query/Matcher/ExtraMatchers.h

Lines changed: 27 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@
66
//
77
//===----------------------------------------------------------------------===//
88
//
9-
// This file provides matchers that depend on Query.
9+
// This file provides matchers for MLIRQuery with more involved pattern-matching
10+
// logic.
1011
//
1112
//===----------------------------------------------------------------------===//
1213

@@ -80,46 +81,42 @@ bool BackwardSliceMatcher<Matcher>::matches(
8081
BackwardSliceOptions &options, int64_t maxDepth) {
8182
backwardSlice.clear();
8283
llvm::DenseMap<Operation *, int64_t> opDepths;
83-
// The starting point is the root op; therefore, we set its depth to 0.
84+
// Initializing the root op with a depth of 0
8485
opDepths[rootOp] = 0;
8586
options.filter = [&](Operation *subOp) {
86-
// If the subOp's depth exceeds maxDepth, we stop further slicing for this
87-
// branch.
88-
if (opDepths[subOp] > maxDepth)
87+
// If the subOp hasn't been recorded in opDepths, it is deeper than
88+
// maxDepth.
89+
if (!opDepths.contains(subOp))
8990
return false;
9091
// Examine subOp's operands to compute depths of their defining operations.
9192
for (auto operand : subOp->getOperands()) {
93+
int64_t newDepth = opDepths[subOp] + 1;
94+
// If the newDepth is greater than maxDepth, further computation can be
95+
// skipped.
96+
if (newDepth > maxDepth)
97+
continue;
98+
9299
if (auto definingOp = operand.getDefiningOp()) {
93-
// Set the defining operation's depth to one level greater than
94-
// subOp's depth.
95-
int64_t newDepth = opDepths[subOp] + 1;
96-
if (!opDepths.contains(definingOp)) {
100+
// Registers the minimum depth
101+
if (!opDepths.contains(definingOp) || newDepth < opDepths[definingOp])
97102
opDepths[definingOp] = newDepth;
98-
} else {
99-
opDepths[definingOp] = std::min(opDepths[definingOp], newDepth);
100-
}
101-
return !(opDepths[subOp] > maxDepth);
102103
} else {
103104
auto blockArgument = cast<BlockArgument>(operand);
104105
Operation *parentOp = blockArgument.getOwner()->getParentOp();
105106
if (!parentOp)
106107
continue;
107-
int64_t newDepth = opDepths[subOp] + 1;
108-
if (!opDepths.contains(parentOp)) {
108+
109+
if (!opDepths.contains(parentOp) || newDepth < opDepths[parentOp])
109110
opDepths[parentOp] = newDepth;
110-
} else {
111-
opDepths[parentOp] = std::min(opDepths[parentOp], newDepth);
112-
}
113-
return !(opDepths[parentOp] > maxDepth);
114111
}
115112
}
116113
return true;
117114
};
118115
getBackwardSlice(rootOp, &backwardSlice, options);
119-
return true;
116+
return backwardSlice.size() >= 1;
120117
}
121118

122-
// Matches transitive defs of a top-level operation up to N levels.
119+
/// Matches transitive defs of a top-level operation up to N levels.
123120
template <typename Matcher>
124121
inline BackwardSliceMatcher<Matcher>
125122
m_GetDefinitions(Matcher innerMatcher, int64_t maxDepth, bool inclusive,
@@ -130,6 +127,15 @@ m_GetDefinitions(Matcher innerMatcher, int64_t maxDepth, bool inclusive,
130127
omitUsesFromAbove);
131128
}
132129

130+
/// Matches all transitive defs of a top-level operation up to N levels
131+
template <typename Matcher>
132+
inline BackwardSliceMatcher<Matcher> m_GetAllDefinitions(Matcher innerMatcher,
133+
int64_t maxDepth) {
134+
assert(maxDepth >= 0 && "maxDepth must be non-negative");
135+
return BackwardSliceMatcher<Matcher>(std::move(innerMatcher), maxDepth, true,
136+
false, false);
137+
}
138+
133139
} // namespace mlir::query::matcher
134140

135141
#endif // MLIR_TOOLS_MLIRQUERY_MATCHERS_EXTRAMATCHERS_H

mlir/test/mlir-query/complex-test.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: mlir-query %s -c "m getDefinitions(hasOpName(\"arith.addf\"),2,true,false,false)" | FileCheck %s
1+
// RUN: mlir-query %s -c "m getAllDefinitions(hasOpName(\"arith.addf\"),2)" | FileCheck %s
22

33
#map = affine_map<(d0, d1) -> (d0, d1)>
44
func.func @slice_use_from_above(%arg0: tensor<5x5xf32>, %arg1: tensor<5x5xf32>) {

mlir/tools/mlir-query/mlir-query.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,9 @@ int main(int argc, char **argv) {
4343
matcherRegistry.registerMatcher(
4444
"getDefinitions",
4545
query::matcher::m_GetDefinitions<query::matcher::DynMatcher>);
46+
matcherRegistry.registerMatcher(
47+
"getAllDefinitions",
48+
query::matcher::m_GetAllDefinitions<query::matcher::DynMatcher>);
4649
matcherRegistry.registerMatcher("hasOpAttrName",
4750
static_cast<HasOpAttrName *>(m_Attr));
4851
matcherRegistry.registerMatcher("hasOpName", static_cast<HasOpName *>(m_Op));

0 commit comments

Comments
 (0)