Skip to content

Commit cad295c

Browse files
committed
Compute shortest depth in backwardSlice method
Relocate backwardSlice matcher to Query specific headers Remove unncecessary code
1 parent 374a5be commit cad295c

15 files changed

+493
-58
lines changed
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
//===- ExtraMatchers.h - Various common matchers --------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file provides matchers that depend on Query.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#ifndef MLIR_TOOLS_MLIRQUERY_MATCHERS_EXTRAMATCHERS_H
14+
#define MLIR_TOOLS_MLIRQUERY_MATCHERS_EXTRAMATCHERS_H
15+
#include "mlir/Analysis/SliceAnalysis.h"
16+
#include "mlir/Query/Matcher/MatchersInternal.h"
17+
18+
/// A matcher encapsulating the initial `getBackwardSlice` method from
19+
/// SliceAnalysis.h
20+
/// Additionally, it limits the slice computation to a certain depth level using
21+
/// a custom filter
22+
///
23+
/// Example starting from node 9, assuming the matcher
24+
/// computes the slice for the first two depth levels
25+
/// ============================
26+
/// 1 2 3 4
27+
/// |_______| |______|
28+
/// | | |
29+
/// | 5 6
30+
/// |___|_____________|
31+
/// | |
32+
/// 7 8
33+
/// |_______________|
34+
/// |
35+
/// 9
36+
///
37+
/// Assuming all local orders match the numbering order:
38+
/// {5, 7, 6, 8, 9}
39+
namespace mlir::query::matcher {
40+
class BackwardSliceMatcher {
41+
public:
42+
explicit BackwardSliceMatcher(query::matcher::DynMatcher &&innerMatcher,
43+
int64_t maxDepth, bool inclusive,
44+
bool omitBlockArguments, bool omitUsesFromAbove)
45+
: innerMatcher(std::move(innerMatcher)), maxDepth(maxDepth),
46+
inclusive(inclusive), omitBlockArguments(omitBlockArguments),
47+
omitUsesFromAbove(omitUsesFromAbove) {}
48+
bool match(Operation *op, SetVector<Operation *> &backwardSlice) {
49+
BackwardSliceOptions options;
50+
return (innerMatcher.match(op) &&
51+
matches(op, backwardSlice, options, maxDepth));
52+
}
53+
54+
private:
55+
bool matches(Operation *rootOp, llvm::SetVector<Operation *> &backwardSlice,
56+
BackwardSliceOptions &options, int64_t maxDepth);
57+
58+
private:
59+
// The outer matcher (e.g., BackwardSliceMatcher) relies on the innerMatcher
60+
// to determine whether we want to traverse the DAG or not. For example, we
61+
// want to explore the DAG only if the top-level operation name is
62+
// "arith.addf".
63+
query::matcher::DynMatcher innerMatcher;
64+
// maxDepth specifies the maximum depth that the matcher can traverse in the
65+
// DAG. For example, if maxDepth is 2, the matcher will explore the defining
66+
// operations of the top-level op up to 2 levels.
67+
int64_t maxDepth;
68+
69+
bool inclusive;
70+
bool omitBlockArguments;
71+
bool omitUsesFromAbove;
72+
};
73+
74+
// Matches transitive defs of a top level operation up to N levels
75+
inline BackwardSliceMatcher
76+
m_GetDefinitions(query::matcher::DynMatcher innerMatcher, int64_t maxDepth,
77+
bool inclusive, bool omitBlockArguments,
78+
bool omitUsesFromAbove) {
79+
assert(maxDepth >= 0 && "maxDepth must be non-negative");
80+
return BackwardSliceMatcher(std::move(innerMatcher), maxDepth, inclusive,
81+
omitBlockArguments, omitUsesFromAbove);
82+
}
83+
} // namespace mlir::query::matcher
84+
85+
#endif // MLIR_TOOLS_MLIRQUERY_MATCHERS_EXTRAMATCHERS_H

mlir/include/mlir/Query/Matcher/Marshallers.h

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,36 @@ struct ArgTypeTraits<llvm::StringRef> {
5050
}
5151
};
5252

53+
template <>
54+
struct ArgTypeTraits<int64_t> {
55+
static bool hasCorrectType(const VariantValue &value) {
56+
return value.isSigned();
57+
}
58+
59+
static unsigned get(const VariantValue &value) { return value.getSigned(); }
60+
61+
static ArgKind getKind() { return ArgKind::Signed; }
62+
63+
static std::optional<std::string> getBestGuess(const VariantValue &) {
64+
return std::nullopt;
65+
}
66+
};
67+
68+
template <>
69+
struct ArgTypeTraits<bool> {
70+
static bool hasCorrectType(const VariantValue &value) {
71+
return value.isBoolean();
72+
}
73+
74+
static unsigned get(const VariantValue &value) { return value.getBoolean(); }
75+
76+
static ArgKind getKind() { return ArgKind::Boolean; }
77+
78+
static std::optional<std::string> getBestGuess(const VariantValue &) {
79+
return std::nullopt;
80+
}
81+
};
82+
5383
template <>
5484
struct ArgTypeTraits<DynMatcher> {
5585

mlir/include/mlir/Query/Matcher/MatchFinder.h

Lines changed: 30 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -7,33 +7,48 @@
77
//===----------------------------------------------------------------------===//
88
//
99
// This file contains the MatchFinder class, which is used to find operations
10-
// that match a given matcher.
10+
// that match a given matcher and print them.
1111
//
1212
//===----------------------------------------------------------------------===//
1313

1414
#ifndef MLIR_TOOLS_MLIRQUERY_MATCHER_MATCHERFINDER_H
1515
#define MLIR_TOOLS_MLIRQUERY_MATCHER_MATCHERFINDER_H
1616

1717
#include "MatchersInternal.h"
18+
#include "mlir/Query/Query.h"
19+
#include "mlir/Query/QuerySession.h"
20+
#include "llvm/ADT/SetVector.h"
1821

1922
namespace mlir::query::matcher {
2023

21-
// MatchFinder is used to find all operations that match a given matcher.
24+
/// A class that provides utilities to find operations in a DAG
2225
class MatchFinder {
26+
2327
public:
24-
// Returns all operations that match the given matcher.
25-
static std::vector<Operation *> getMatches(Operation *root,
26-
DynMatcher matcher) {
27-
std::vector<Operation *> matches;
28-
29-
// Simple match finding with walk.
30-
root->walk([&](Operation *subOp) {
31-
if (matcher.match(subOp))
32-
matches.push_back(subOp);
33-
});
34-
35-
return matches;
36-
}
28+
/// A subclass which preserves the matching information
29+
struct MatchResult {
30+
MatchResult() = default;
31+
MatchResult(Operation *rootOp, std::vector<Operation *> matchedOps);
32+
33+
/// Contains the root operation of the matching environment
34+
Operation *rootOp = nullptr;
35+
/// Contains the matching enviroment. This allows the user to easily
36+
/// extract the matched operations
37+
std::vector<Operation *> matchedOps;
38+
};
39+
/// Traverses the DAG and collects the "rootOp" + "matching enviroment" for
40+
/// a given Matcher
41+
std::vector<MatchResult> collectMatches(Operation *root,
42+
DynMatcher matcher) const;
43+
/// Prints the matched operation
44+
void printMatch(llvm::raw_ostream &os, QuerySession &qs, Operation *op) const;
45+
/// Labels the matched operation with the given binding (e.g., "root") and
46+
/// prints it
47+
void printMatch(llvm::raw_ostream &os, QuerySession &qs, Operation *op,
48+
const std::string &binding) const;
49+
/// Flattens a vector of MatchResults into a vector of operations
50+
std::vector<Operation *>
51+
flattenMatchedOps(std::vector<MatchResult> &matches) const;
3752
};
3853

3954
} // namespace mlir::query::matcher

mlir/include/mlir/Query/Matcher/MatchersInternal.h

Lines changed: 49 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,9 @@
88
//
99
// Implements the base layer of the matcher framework.
1010
//
11-
// Matchers are methods that return a Matcher which provides a method
12-
// match(Operation *op)
11+
// Matchers are methods that return a Matcher which provides a method one of the
12+
// following methods: match(Operation *op), match(Operation *op,
13+
// SetVector<Operation *> &matchedOps)
1314
//
1415
// The matcher functions are defined in include/mlir/IR/Matchers.h.
1516
// This file contains the wrapper classes needed to construct matchers for
@@ -25,13 +26,39 @@
2526

2627
namespace mlir::query::matcher {
2728

29+
// Defaults to false if T has no match() method with the signature:
30+
// match(Operation* op).
31+
template <typename T, typename = void>
32+
struct has_simple_match : std::false_type {};
33+
34+
// Specialized type trait that evaluates to true if T has a match() method
35+
// with the signature: match(Operation* op).
36+
template <typename T>
37+
struct has_simple_match<T, std::void_t<decltype(std::declval<T>().match(
38+
std::declval<Operation *>()))>>
39+
: std::true_type {};
40+
41+
// Defaults to false if T has no match() method with the signature:
42+
// match(Operation* op, SetVector<Operation*>&).
43+
template <typename T, typename = void>
44+
struct has_bound_match : std::false_type {};
45+
46+
// Specialized type trait that evaluates to true if T has a match() method
47+
// with the signature: match(Operation* op, SetVector<Operation*>&).
48+
template <typename T>
49+
struct has_bound_match<T, std::void_t<decltype(std::declval<T>().match(
50+
std::declval<Operation *>(),
51+
std::declval<SetVector<Operation *> &>()))>>
52+
: std::true_type {};
53+
2854
// Generic interface for matchers on an MLIR operation.
2955
class MatcherInterface
3056
: public llvm::ThreadSafeRefCountedBase<MatcherInterface> {
3157
public:
3258
virtual ~MatcherInterface() = default;
3359

3460
virtual bool match(Operation *op) = 0;
61+
virtual bool match(Operation *op, SetVector<Operation *> &matchedOps) = 0;
3562
};
3663

3764
// MatcherFnImpl takes a matcher function object and implements
@@ -40,14 +67,25 @@ template <typename MatcherFn>
4067
class MatcherFnImpl : public MatcherInterface {
4168
public:
4269
MatcherFnImpl(MatcherFn &matcherFn) : matcherFn(matcherFn) {}
43-
bool match(Operation *op) override { return matcherFn.match(op); }
70+
71+
bool match(Operation *op) override {
72+
if constexpr (has_simple_match<MatcherFn>::value)
73+
return matcherFn.match(op);
74+
return false;
75+
}
76+
77+
bool match(Operation *op, SetVector<Operation *> &matchedOps) override {
78+
if constexpr (has_bound_match<MatcherFn>::value)
79+
return matcherFn.match(op, matchedOps);
80+
return false;
81+
}
4482

4583
private:
4684
MatcherFn matcherFn;
4785
};
4886

49-
// Matcher wraps a MatcherInterface implementation and provides a match()
50-
// method that redirects calls to the underlying implementation.
87+
// Matcher wraps a MatcherInterface implementation and provides match()
88+
// methods that redirect calls to the underlying implementation.
5189
class DynMatcher {
5290
public:
5391
// Takes ownership of the provided implementation pointer.
@@ -62,12 +100,13 @@ class DynMatcher {
62100
}
63101

64102
bool match(Operation *op) const { return implementation->match(op); }
103+
bool match(Operation *op, SetVector<Operation *> &matchedOps) const {
104+
return implementation->match(op, matchedOps);
105+
}
65106

66-
void setFunctionName(StringRef name) { functionName = name.str(); };
67-
68-
bool hasFunctionName() const { return !functionName.empty(); };
69-
70-
StringRef getFunctionName() const { return functionName; };
107+
void setFunctionName(StringRef name) { functionName = name.str(); }
108+
bool hasFunctionName() const { return !functionName.empty(); }
109+
StringRef getFunctionName() const { return functionName; }
71110

72111
private:
73112
llvm::IntrusiveRefCntPtr<MatcherInterface> implementation;

mlir/include/mlir/Query/Matcher/VariantValue.h

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
namespace mlir::query::matcher {
2222

2323
// All types that VariantValue can contain.
24-
enum class ArgKind { Matcher, String };
24+
enum class ArgKind { Boolean, Matcher, Signed, String };
2525

2626
// A variant matcher object to abstract simple and complex matchers into a
2727
// single object type.
@@ -81,6 +81,8 @@ class VariantValue {
8181
// Specific constructors for each supported type.
8282
VariantValue(const llvm::StringRef string);
8383
VariantValue(const VariantMatcher &matcher);
84+
VariantValue(int64_t signedValue);
85+
VariantValue(bool setBoolean);
8486

8587
// String value functions.
8688
bool isString() const;
@@ -92,21 +94,36 @@ class VariantValue {
9294
const VariantMatcher &getMatcher() const;
9395
void setMatcher(const VariantMatcher &matcher);
9496

97+
// Signed value functions.
98+
bool isSigned() const;
99+
int64_t getSigned() const;
100+
void setSigned(int64_t signedValue);
101+
102+
// Boolean value functions.
103+
bool isBoolean() const;
104+
bool getBoolean() const;
105+
void setBoolean(bool booleanValue);
95106
// String representation of the type of the value.
96107
std::string getTypeAsString() const;
108+
explicit operator bool() const { return hasValue(); }
109+
bool hasValue() const { return type != ValueType::Nothing; }
97110

98111
private:
99112
void reset();
100113

101114
// All supported value types.
102115
enum class ValueType {
116+
Boolean,
117+
Matcher,
103118
Nothing,
119+
Signed,
104120
String,
105-
Matcher,
106121
};
107122

108123
// All supported value types.
109124
union AllValues {
125+
bool Boolean;
126+
int64_t Signed;
110127
llvm::StringRef *String;
111128
VariantMatcher *Matcher;
112129
};

mlir/lib/Query/Matcher/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
add_mlir_library(MLIRQueryMatcher
2+
MatchFinder.cpp
3+
ExtraMatchers.cpp
24
Parser.cpp
35
RegistryManager.cpp
46
VariantValue.cpp

0 commit comments

Comments
 (0)