Skip to content

Commit 69472f0

Browse files
committed
Update grammar
Make BackwardSlice matcher more generic Capture values in tests
1 parent e07e1fe commit 69472f0

File tree

8 files changed

+95
-102
lines changed

8 files changed

+95
-102
lines changed

mlir/include/mlir/Query/Matcher/ExtraMatchers.h

Lines changed: 72 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -12,16 +12,16 @@
1212

1313
#ifndef MLIR_TOOLS_MLIRQUERY_MATCHERS_EXTRAMATCHERS_H
1414
#define MLIR_TOOLS_MLIRQUERY_MATCHERS_EXTRAMATCHERS_H
15+
1516
#include "mlir/Analysis/SliceAnalysis.h"
16-
#include "mlir/Query/Matcher/MatchersInternal.h"
1717

1818
/// A matcher encapsulating the initial `getBackwardSlice` method from
19-
/// SliceAnalysis.h
19+
/// SliceAnalysis.h.
2020
/// Additionally, it limits the slice computation to a certain depth level using
21-
/// a custom filter
21+
/// a custom filter.
2222
///
23-
/// Example starting from node 9, assuming the matcher
24-
/// computes the slice for the first two depth levels
23+
/// Example: starting from node 9, assuming the matcher
24+
/// computes the slice for the first two depth levels:
2525
/// ============================
2626
/// 1 2 3 4
2727
/// |_______| |______|
@@ -37,16 +37,22 @@
3737
/// Assuming all local orders match the numbering order:
3838
/// {5, 7, 6, 8, 9}
3939
namespace mlir::query::matcher {
40+
41+
template <typename Matcher>
4042
class BackwardSliceMatcher {
4143
public:
42-
explicit BackwardSliceMatcher(query::matcher::DynMatcher &&innerMatcher,
43-
int64_t maxDepth, bool inclusive,
44-
bool omitBlockArguments, bool omitUsesFromAbove)
44+
explicit BackwardSliceMatcher(Matcher innerMatcher, int64_t maxDepth,
45+
bool inclusive, bool omitBlockArguments,
46+
bool omitUsesFromAbove)
4547
: innerMatcher(std::move(innerMatcher)), maxDepth(maxDepth),
4648
inclusive(inclusive), omitBlockArguments(omitBlockArguments),
4749
omitUsesFromAbove(omitUsesFromAbove) {}
50+
4851
bool match(Operation *op, SetVector<Operation *> &backwardSlice) {
4952
BackwardSliceOptions options;
53+
options.inclusive = inclusive;
54+
options.omitUsesFromAbove = omitUsesFromAbove;
55+
options.omitBlockArguments = omitBlockArguments;
5056
return (innerMatcher.match(op) &&
5157
matches(op, backwardSlice, options, maxDepth));
5258
}
@@ -59,27 +65,73 @@ class BackwardSliceMatcher {
5965
// The outer matcher (e.g., BackwardSliceMatcher) relies on the innerMatcher
6066
// to determine whether we want to traverse the DAG or not. For example, we
6167
// want to explore the DAG only if the top-level operation name is
62-
// "arith.addf".
63-
query::matcher::DynMatcher innerMatcher;
64-
// maxDepth specifies the maximum depth that the matcher can traverse in the
65-
// DAG. For example, if maxDepth is 2, the matcher will explore the defining
68+
// `"arith.addf"`.
69+
Matcher innerMatcher;
70+
// `maxDepth` specifies the maximum depth that the matcher can traverse in the
71+
// DAG. For example, if `maxDepth` is 2, the matcher will explore the defining
6672
// operations of the top-level op up to 2 levels.
6773
int64_t maxDepth;
68-
6974
bool inclusive;
7075
bool omitBlockArguments;
7176
bool omitUsesFromAbove;
7277
};
7378

74-
// Matches transitive defs of a top level operation up to N levels
75-
inline BackwardSliceMatcher
76-
m_GetDefinitions(query::matcher::DynMatcher innerMatcher, int64_t maxDepth,
77-
bool inclusive, bool omitBlockArguments,
78-
bool omitUsesFromAbove) {
79+
template <typename Matcher>
80+
bool BackwardSliceMatcher<Matcher>::matches(
81+
Operation *rootOp, llvm::SetVector<Operation *> &backwardSlice,
82+
BackwardSliceOptions &options, int64_t maxDepth) {
83+
backwardSlice.clear();
84+
llvm::DenseMap<Operation *, int64_t> opDepths;
85+
// The starting point is the root op; therefore, we set its depth to 0.
86+
opDepths[rootOp] = 0;
87+
options.filter = [&](Operation *subOp) {
88+
// If the subOp's depth exceeds maxDepth, we stop further slicing for this
89+
// branch.
90+
if (opDepths[subOp] > maxDepth)
91+
return false;
92+
// Examine subOp's operands to compute depths of their defining operations.
93+
for (auto operand : subOp->getOperands()) {
94+
if (auto definingOp = operand.getDefiningOp()) {
95+
// Set the defining operation's depth to one level greater than
96+
// subOp's depth.
97+
int64_t newDepth = opDepths[subOp] + 1;
98+
if (!opDepths.contains(definingOp)) {
99+
opDepths[definingOp] = newDepth;
100+
} else {
101+
opDepths[definingOp] = std::min(opDepths[definingOp], newDepth);
102+
}
103+
return !(opDepths[subOp] > maxDepth);
104+
} else {
105+
auto blockArgument = cast<BlockArgument>(operand);
106+
Operation *parentOp = blockArgument.getOwner()->getParentOp();
107+
if (!parentOp)
108+
continue;
109+
int64_t newDepth = opDepths[subOp] + 1;
110+
if (!opDepths.contains(parentOp)) {
111+
opDepths[parentOp] = newDepth;
112+
} else {
113+
opDepths[parentOp] = std::min(opDepths[parentOp], newDepth);
114+
}
115+
return !(opDepths[parentOp] > maxDepth);
116+
}
117+
}
118+
return true;
119+
};
120+
getBackwardSlice(rootOp, &backwardSlice, options);
121+
return true;
122+
}
123+
124+
// Matches transitive defs of a top-level operation up to N levels.
125+
template <typename Matcher>
126+
inline BackwardSliceMatcher<Matcher>
127+
m_GetDefinitions(Matcher innerMatcher, int64_t maxDepth, bool inclusive,
128+
bool omitBlockArguments, bool omitUsesFromAbove) {
79129
assert(maxDepth >= 0 && "maxDepth must be non-negative");
80-
return BackwardSliceMatcher(std::move(innerMatcher), maxDepth, inclusive,
81-
omitBlockArguments, omitUsesFromAbove);
130+
return BackwardSliceMatcher<Matcher>(std::move(innerMatcher), maxDepth,
131+
inclusive, omitBlockArguments,
132+
omitUsesFromAbove);
82133
}
134+
83135
} // namespace mlir::query::matcher
84136

85137
#endif // MLIR_TOOLS_MLIRQUERY_MATCHERS_EXTRAMATCHERS_H

mlir/include/mlir/Query/Matcher/MatchFinder.h

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -21,32 +21,35 @@
2121

2222
namespace mlir::query::matcher {
2323

24-
/// A class that provides utilities to find operations in a DAG
24+
/// A class that provides utilities to find operations in the IR.
2525
class MatchFinder {
2626

2727
public:
28-
/// A subclass which preserves the matching information
28+
/// A subclass which preserves the matching information. Each instance
29+
/// contains the `rootOp` along with the matching environment.
2930
struct MatchResult {
3031
MatchResult() = default;
3132
MatchResult(Operation *rootOp, std::vector<Operation *> matchedOps);
3233

33-
/// Contains the root operation of the matching environment
3434
Operation *rootOp = nullptr;
35-
/// Contains the matching enviroment. This allows the user to easily
36-
/// extract the matched operations
35+
/// Contains the matching environment.
3736
std::vector<Operation *> matchedOps;
3837
};
39-
/// Traverses the DAG and collects the "rootOp" + "matching enviroment" for
40-
/// a given Matcher
38+
39+
/// Traverses the IR and returns a vector of `MatchResult` for each match of
40+
/// the `matcher`.
4141
std::vector<MatchResult> collectMatches(Operation *root,
4242
DynMatcher matcher) const;
43-
/// Prints the matched operation
43+
44+
/// Prints the matched operation.
4445
void printMatch(llvm::raw_ostream &os, QuerySession &qs, Operation *op) const;
45-
/// Labels the matched operation with the given binding (e.g., "root") and
46-
/// prints it
46+
47+
/// Labels the matched operation with the given binding (e.g., `"root"`) and
48+
/// prints it.
4749
void printMatch(llvm::raw_ostream &os, QuerySession &qs, Operation *op,
4850
const std::string &binding) const;
49-
/// Flattens a vector of MatchResults into a vector of operations
51+
52+
/// Flattens a vector of `MatchResult` into a vector of operations.
5053
std::vector<Operation *>
5154
flattenMatchedOps(std::vector<MatchResult> &matches) const;
5255
};

mlir/lib/Query/Matcher/CMakeLists.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
add_mlir_library(MLIRQueryMatcher
22
MatchFinder.cpp
3-
ExtraMatchers.cpp
43
Parser.cpp
54
RegistryManager.cpp
65
VariantValue.cpp

mlir/lib/Query/Matcher/ExtraMatchers.cpp

Lines changed: 0 additions & 66 deletions
This file was deleted.

mlir/lib/Query/Matcher/Parser.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,11 @@
1616
// provided to the parser.
1717
//
1818
// The grammar for the supported expressions is as follows:
19-
// <Expression> := <StringLiteral> | <MatcherExpression>
19+
// <Expression> := <Literal> | <MatcherExpression>
20+
// <Literal> := <StringLiteral> | <NumericLiteral> | <BooleanLiteral>
2021
// <StringLiteral> := "quoted string"
22+
// <BooleanLiteral> := "true" | "false"
23+
// <NumericLiteral> := [0-9]+
2124
// <MatcherExpression> := <MatcherName>(<ArgumentList>)
2225
// <MatcherName> := [a-zA-Z]+
2326
// <ArgumentList> := <Expression> | <Expression>,<ArgumentList>

mlir/lib/Query/QueryParser.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,7 @@ QueryRef QueryParser::doParse() {
166166

167167
case ParsedQueryKind::Quit:
168168
return endQuery(new QuitQuery);
169+
169170
case ParsedQueryKind::Match: {
170171
if (completionPos) {
171172
return completeMatcherExpression();

mlir/test/mlir-query/complex-test.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ func.func @slice_use_from_above(%arg0: tensor<5x5xf32>, %arg1: tensor<5x5xf32>)
2626

2727
// CHECK: Match #2:
2828

29-
// CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %0 {{\[\[.*\]\]}} : tensor<5x5xf32> into tensor<25xf32>
29+
// CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[LINALG]] {{\[\[.*\]\]}} : tensor<5x5xf32> into tensor<25xf32>
3030
// CHECK: %[[C2:.*]] = arith.constant {{.*}} : index
3131
// CHECK: %[[EXTRACTED:.*]] = tensor.extract %[[COLLAPSED]][%[[C2]]] : tensor<25xf32>
3232
// CHECK: %[[ADDF2:.*]] = arith.addf %[[EXTRACTED]], %[[EXTRACTED]] : f32

mlir/tools/mlir-query/mlir-query.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,9 @@ int main(int argc, char **argv) {
4040
query::matcher::Registry matcherRegistry;
4141

4242
// Matchers registered in alphabetical order for consistency:
43-
matcherRegistry.registerMatcher("getDefinitions",
44-
query::matcher::m_GetDefinitions);
43+
matcherRegistry.registerMatcher(
44+
"getDefinitions",
45+
query::matcher::m_GetDefinitions<query::matcher::DynMatcher>);
4546
matcherRegistry.registerMatcher("hasOpAttrName",
4647
static_cast<HasOpAttrName *>(m_Attr));
4748
matcherRegistry.registerMatcher("hasOpName", static_cast<HasOpName *>(m_Op));

0 commit comments

Comments
 (0)