Skip to content

Commit be1f9bf

Browse files
committed
Update grammar
Make BackwardSlice matcher more generic Capture values in tests
1 parent f1c1658 commit be1f9bf

File tree

8 files changed

+99
-108
lines changed

8 files changed

+99
-108
lines changed

mlir/include/mlir/Query/Matcher/ExtraMatchers.h

Lines changed: 76 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -12,16 +12,15 @@
1212

1313
#ifndef MLIR_TOOLS_MLIRQUERY_MATCHERS_EXTRAMATCHERS_H
1414
#define MLIR_TOOLS_MLIRQUERY_MATCHERS_EXTRAMATCHERS_H
15+
1516
#include "mlir/Analysis/SliceAnalysis.h"
16-
#include "mlir/Query/Matcher/MatchersInternal.h"
1717

18-
/// A matcher encapsulating the initial `getBackwardSlice` method from
19-
/// SliceAnalysis.h
18+
/// A matcher encapsulating `getBackwardSlice` method from SliceAnalysis.h.
2019
/// Additionally, it limits the slice computation to a certain depth level using
21-
/// a custom filter
20+
/// a custom filter.
2221
///
23-
/// Example starting from node 9, assuming the matcher
24-
/// computes the slice for the first two depth levels
22+
/// Example: starting from node 9, assuming the matcher
23+
/// computes the slice for the first two depth levels:
2524
/// ============================
2625
/// 1 2 3 4
2726
/// |_______| |______|
@@ -37,18 +36,23 @@
3736
/// Assuming all local orders match the numbering order:
3837
/// {5, 7, 6, 8, 9}
3938
namespace mlir::query::matcher {
39+
40+
template <typename Matcher>
4041
class BackwardSliceMatcher {
4142
public:
42-
explicit BackwardSliceMatcher(query::matcher::DynMatcher &&innerMatcher,
43-
int64_t maxDepth, bool inclusive,
44-
bool omitBlockArguments, bool omitUsesFromAbove)
43+
BackwardSliceMatcher(Matcher innerMatcher, int64_t maxDepth, bool inclusive,
44+
bool omitBlockArguments, bool omitUsesFromAbove)
4545
: innerMatcher(std::move(innerMatcher)), maxDepth(maxDepth),
4646
inclusive(inclusive), omitBlockArguments(omitBlockArguments),
4747
omitUsesFromAbove(omitUsesFromAbove) {}
48-
bool match(Operation *op, SetVector<Operation *> &backwardSlice) {
48+
49+
bool match(Operation *rootOp, SetVector<Operation *> &backwardSlice) {
4950
BackwardSliceOptions options;
50-
return (innerMatcher.match(op) &&
51-
matches(op, backwardSlice, options, maxDepth));
51+
options.inclusive = inclusive;
52+
options.omitUsesFromAbove = omitUsesFromAbove;
53+
options.omitBlockArguments = omitBlockArguments;
54+
return (innerMatcher.match(rootOp) &&
55+
matches(rootOp, backwardSlice, options, maxDepth));
5256
}
5357

5458
private:
@@ -57,29 +61,75 @@ class BackwardSliceMatcher {
5761

5862
private:
5963
// The outer matcher (e.g., BackwardSliceMatcher) relies on the innerMatcher
60-
// to determine whether we want to traverse the DAG or not. For example, we
61-
// want to explore the DAG only if the top-level operation name is
62-
// "arith.addf".
63-
query::matcher::DynMatcher innerMatcher;
64-
// maxDepth specifies the maximum depth that the matcher can traverse in the
65-
// DAG. For example, if maxDepth is 2, the matcher will explore the defining
64+
// to determine whether we want to traverse the IR or not. For example, we
65+
// want to explore the IR only if the top-level operation name is
66+
// `"arith.addf"`.
67+
Matcher innerMatcher;
68+
// `maxDepth` specifies the maximum depth that the matcher can traverse the
69+
// IR. For example, if `maxDepth` is 2, the matcher will explore the defining
6670
// operations of the top-level op up to 2 levels.
6771
int64_t maxDepth;
68-
6972
bool inclusive;
7073
bool omitBlockArguments;
7174
bool omitUsesFromAbove;
7275
};
7376

74-
// Matches transitive defs of a top level operation up to N levels
75-
inline BackwardSliceMatcher
76-
m_GetDefinitions(query::matcher::DynMatcher innerMatcher, int64_t maxDepth,
77-
bool inclusive, bool omitBlockArguments,
78-
bool omitUsesFromAbove) {
77+
template <typename Matcher>
78+
bool BackwardSliceMatcher<Matcher>::matches(
79+
Operation *rootOp, llvm::SetVector<Operation *> &backwardSlice,
80+
BackwardSliceOptions &options, int64_t maxDepth) {
81+
backwardSlice.clear();
82+
llvm::DenseMap<Operation *, int64_t> opDepths;
83+
// The starting point is the root op; therefore, we set its depth to 0.
84+
opDepths[rootOp] = 0;
85+
options.filter = [&](Operation *subOp) {
86+
// If the subOp's depth exceeds maxDepth, we stop further slicing for this
87+
// branch.
88+
if (opDepths[subOp] > maxDepth)
89+
return false;
90+
// Examine subOp's operands to compute depths of their defining operations.
91+
for (auto operand : subOp->getOperands()) {
92+
if (auto definingOp = operand.getDefiningOp()) {
93+
// Set the defining operation's depth to one level greater than
94+
// subOp's depth.
95+
int64_t newDepth = opDepths[subOp] + 1;
96+
if (!opDepths.contains(definingOp)) {
97+
opDepths[definingOp] = newDepth;
98+
} else {
99+
opDepths[definingOp] = std::min(opDepths[definingOp], newDepth);
100+
}
101+
return !(opDepths[subOp] > maxDepth);
102+
} else {
103+
auto blockArgument = cast<BlockArgument>(operand);
104+
Operation *parentOp = blockArgument.getOwner()->getParentOp();
105+
if (!parentOp)
106+
continue;
107+
int64_t newDepth = opDepths[subOp] + 1;
108+
if (!opDepths.contains(parentOp)) {
109+
opDepths[parentOp] = newDepth;
110+
} else {
111+
opDepths[parentOp] = std::min(opDepths[parentOp], newDepth);
112+
}
113+
return !(opDepths[parentOp] > maxDepth);
114+
}
115+
}
116+
return true;
117+
};
118+
getBackwardSlice(rootOp, &backwardSlice, options);
119+
return true;
120+
}
121+
122+
// Matches transitive defs of a top-level operation up to N levels.
123+
template <typename Matcher>
124+
inline BackwardSliceMatcher<Matcher>
125+
m_GetDefinitions(Matcher innerMatcher, int64_t maxDepth, bool inclusive,
126+
bool omitBlockArguments, bool omitUsesFromAbove) {
79127
assert(maxDepth >= 0 && "maxDepth must be non-negative");
80-
return BackwardSliceMatcher(std::move(innerMatcher), maxDepth, inclusive,
81-
omitBlockArguments, omitUsesFromAbove);
128+
return BackwardSliceMatcher<Matcher>(std::move(innerMatcher), maxDepth,
129+
inclusive, omitBlockArguments,
130+
omitUsesFromAbove);
82131
}
132+
83133
} // namespace mlir::query::matcher
84134

85135
#endif // MLIR_TOOLS_MLIRQUERY_MATCHERS_EXTRAMATCHERS_H

mlir/include/mlir/Query/Matcher/MatchFinder.h

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -21,32 +21,35 @@
2121

2222
namespace mlir::query::matcher {
2323

24-
/// A class that provides utilities to find operations in a DAG
24+
/// A class that provides utilities to find operations in the IR.
2525
class MatchFinder {
2626

2727
public:
28-
/// A subclass which preserves the matching information
28+
/// A subclass which preserves the matching information. Each instance
29+
/// contains the `rootOp` along with the matching environment.
2930
struct MatchResult {
3031
MatchResult() = default;
3132
MatchResult(Operation *rootOp, std::vector<Operation *> matchedOps);
3233

33-
/// Contains the root operation of the matching environment
3434
Operation *rootOp = nullptr;
35-
/// Contains the matching enviroment. This allows the user to easily
36-
/// extract the matched operations
35+
/// Contains the matching environment.
3736
std::vector<Operation *> matchedOps;
3837
};
39-
/// Traverses the DAG and collects the "rootOp" + "matching enviroment" for
40-
/// a given Matcher
38+
39+
/// Traverses the IR and returns a vector of `MatchResult` for each match of
40+
/// the `matcher`.
4141
std::vector<MatchResult> collectMatches(Operation *root,
4242
DynMatcher matcher) const;
43-
/// Prints the matched operation
43+
44+
/// Prints the matched operation.
4445
void printMatch(llvm::raw_ostream &os, QuerySession &qs, Operation *op) const;
45-
/// Labels the matched operation with the given binding (e.g., "root") and
46-
/// prints it
46+
47+
/// Labels the matched operation with the given binding (e.g., `"root"`) and
48+
/// prints it.
4749
void printMatch(llvm::raw_ostream &os, QuerySession &qs, Operation *op,
4850
const std::string &binding) const;
49-
/// Flattens a vector of MatchResults into a vector of operations
51+
52+
/// Flattens a vector of `MatchResult` into a vector of operations.
5053
std::vector<Operation *>
5154
flattenMatchedOps(std::vector<MatchResult> &matches) const;
5255
};

mlir/lib/Query/Matcher/CMakeLists.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
add_mlir_library(MLIRQueryMatcher
22
MatchFinder.cpp
3-
ExtraMatchers.cpp
43
Parser.cpp
54
RegistryManager.cpp
65
VariantValue.cpp

mlir/lib/Query/Matcher/ExtraMatchers.cpp

Lines changed: 0 additions & 66 deletions
This file was deleted.

mlir/lib/Query/Matcher/Parser.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,11 @@
1616
// provided to the parser.
1717
//
1818
// The grammar for the supported expressions is as follows:
19-
// <Expression> := <StringLiteral> | <MatcherExpression>
19+
// <Expression> := <Literal> | <MatcherExpression>
20+
// <Literal> := <StringLiteral> | <NumericLiteral> | <BooleanLiteral>
2021
// <StringLiteral> := "quoted string"
22+
// <BooleanLiteral> := "true" | "false"
23+
// <NumericLiteral> := [0-9]+
2124
// <MatcherExpression> := <MatcherName>(<ArgumentList>)
2225
// <MatcherName> := [a-zA-Z]+
2326
// <ArgumentList> := <Expression> | <Expression>,<ArgumentList>

mlir/lib/Query/QueryParser.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,7 @@ QueryRef QueryParser::doParse() {
166166

167167
case ParsedQueryKind::Quit:
168168
return endQuery(new QuitQuery);
169+
169170
case ParsedQueryKind::Match: {
170171
if (completionPos) {
171172
return completeMatcherExpression();

mlir/test/mlir-query/complex-test.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ func.func @slice_use_from_above(%arg0: tensor<5x5xf32>, %arg1: tensor<5x5xf32>)
2626

2727
// CHECK: Match #2:
2828

29-
// CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %0 {{\[\[.*\]\]}} : tensor<5x5xf32> into tensor<25xf32>
29+
// CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[LINALG]] {{\[\[.*\]\]}} : tensor<5x5xf32> into tensor<25xf32>
3030
// CHECK: %[[C2:.*]] = arith.constant {{.*}} : index
3131
// CHECK: %[[EXTRACTED:.*]] = tensor.extract %[[COLLAPSED]][%[[C2]]] : tensor<25xf32>
3232
// CHECK: %[[ADDF2:.*]] = arith.addf %[[EXTRACTED]], %[[EXTRACTED]] : f32

mlir/tools/mlir-query/mlir-query.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,9 @@ int main(int argc, char **argv) {
4040
query::matcher::Registry matcherRegistry;
4141

4242
// Matchers registered in alphabetical order for consistency:
43-
matcherRegistry.registerMatcher("getDefinitions",
44-
query::matcher::m_GetDefinitions);
43+
matcherRegistry.registerMatcher(
44+
"getDefinitions",
45+
query::matcher::m_GetDefinitions<query::matcher::DynMatcher>);
4546
matcherRegistry.registerMatcher("hasOpAttrName",
4647
static_cast<HasOpAttrName *>(m_Attr));
4748
matcherRegistry.registerMatcher("hasOpName", static_cast<HasOpName *>(m_Op));

0 commit comments

Comments
 (0)