Skip to content

Commit 0eac094

Browse files
dbudiichios202
authored andcommitted
Fixed pattern matching in mlir-query test files & removed asserts from slice-matchers
1 parent 6383a12 commit 0eac094

17 files changed

+529
-51
lines changed

mlir/include/mlir/IR/Matchers.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,15 +59,15 @@ struct NameOpMatcher {
5959
NameOpMatcher(StringRef name) : name(name) {}
6060
bool match(Operation *op) { return op->getName().getStringRef() == name; }
6161

62-
StringRef name;
62+
std::string name;
6363
};
6464

6565
/// The matcher that matches operations that have the specified attribute name.
6666
struct AttrOpMatcher {
6767
AttrOpMatcher(StringRef attrName) : attrName(attrName) {}
6868
bool match(Operation *op) { return op->hasAttr(attrName); }
6969

70-
StringRef attrName;
70+
std::string attrName;
7171
};
7272

7373
/// The matcher that matches operations that have the `ConstantLike` trait, and
Lines changed: 188 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,188 @@
1+
//===- ExtraMatchers.h - Various common matchers ---------------------*- C++
2+
//-*-===//
3+
//
4+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
5+
// See https://llvm.org/LICENSE.txt for license information.
6+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7+
//
8+
//===----------------------------------------------------------------------===//
9+
//
10+
// This file provides extra matchers that are very useful for mlir-query
11+
//
12+
//===----------------------------------------------------------------------===//
13+
14+
#ifndef MLIR_IR_EXTRAMATCHERS_H
15+
#define MLIR_IR_EXTRAMATCHERS_H
16+
17+
#include "MatchFinder.h"
18+
#include "MatchersInternal.h"
19+
#include "mlir/IR/Region.h"
20+
#include "mlir/Query/Query.h"
21+
#include "llvm/Support/raw_ostream.h"
22+
23+
namespace mlir {
24+
25+
namespace query {
26+
27+
namespace extramatcher {
28+
29+
namespace detail {
30+
31+
class BackwardSliceMatcher {
32+
public:
33+
BackwardSliceMatcher(matcher::DynMatcher &&innerMatcher, unsigned hops)
34+
: innerMatcher(std::move(innerMatcher)), hops(hops) {}
35+
36+
private:
37+
bool matches(Operation *op, SetVector<Operation *> &backwardSlice,
38+
QueryOptions &options, unsigned tempHops) {
39+
40+
if (op->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
41+
return false;
42+
}
43+
44+
auto processValue = [&](Value value) {
45+
if (tempHops == 0) {
46+
return;
47+
}
48+
if (auto *definingOp = value.getDefiningOp()) {
49+
if (backwardSlice.count(definingOp) == 0)
50+
matches(definingOp, backwardSlice, options, tempHops - 1);
51+
} else if (auto blockArg = dyn_cast<BlockArgument>(value)) {
52+
if (options.omitBlockArguments)
53+
return;
54+
Block *block = blockArg.getOwner();
55+
56+
Operation *parentOp = block->getParentOp();
57+
58+
if (parentOp && backwardSlice.count(parentOp) == 0) {
59+
if (parentOp->getNumRegions() != 1 &&
60+
parentOp->getRegion(0).getBlocks().size() != 1) {
61+
llvm::errs()
62+
<< "Error: Expected parentOp to have exactly one region and "
63+
<< "exactly one block, but found " << parentOp->getNumRegions()
64+
<< " regions and "
65+
<< (parentOp->getRegion(0).getBlocks().size()) << " blocks.\n";
66+
};
67+
matches(parentOp, backwardSlice, options, tempHops - 1);
68+
}
69+
} else {
70+
llvm::errs() << "No definingOp and not a block argument\n";
71+
return;
72+
}
73+
};
74+
75+
if (!options.omitUsesFromAbove) {
76+
llvm::for_each(op->getRegions(), [&](Region &region) {
77+
SmallPtrSet<Region *, 4> descendents;
78+
region.walk(
79+
[&](Region *childRegion) { descendents.insert(childRegion); });
80+
region.walk([&](Operation *op) {
81+
for (OpOperand &operand : op->getOpOperands()) {
82+
if (!descendents.contains(operand.get().getParentRegion()))
83+
processValue(operand.get());
84+
}
85+
});
86+
});
87+
}
88+
89+
llvm::for_each(op->getOperands(), processValue);
90+
backwardSlice.insert(op);
91+
return true;
92+
}
93+
94+
public:
95+
bool match(Operation *op, SetVector<Operation *> &backwardSlice,
96+
QueryOptions &options) {
97+
98+
if (innerMatcher.match(op) && matches(op, backwardSlice, options, hops)) {
99+
if (!options.inclusive) {
100+
backwardSlice.remove(op);
101+
}
102+
return true;
103+
}
104+
return false;
105+
}
106+
107+
private:
108+
matcher::DynMatcher innerMatcher;
109+
unsigned hops;
110+
};
111+
112+
class ForwardSliceMatcher {
113+
public:
114+
ForwardSliceMatcher(matcher::DynMatcher &&innerMatcher, unsigned hops)
115+
: innerMatcher(std::move(innerMatcher)), hops(hops) {}
116+
117+
private:
118+
bool matches(Operation *op, SetVector<Operation *> &forwardSlice,
119+
QueryOptions &options, unsigned tempHops) {
120+
121+
if (tempHops == 0) {
122+
forwardSlice.insert(op);
123+
return true;
124+
}
125+
126+
for (Region &region : op->getRegions())
127+
for (Block &block : region)
128+
for (Operation &blockOp : block)
129+
if (forwardSlice.count(&blockOp) == 0)
130+
matches(&blockOp, forwardSlice, options, tempHops - 1);
131+
for (Value result : op->getResults()) {
132+
for (Operation *userOp : result.getUsers())
133+
if (forwardSlice.count(userOp) == 0)
134+
matches(userOp, forwardSlice, options, tempHops - 1);
135+
}
136+
137+
forwardSlice.insert(op);
138+
return true;
139+
}
140+
141+
public:
142+
bool match(Operation *op, SetVector<Operation *> &forwardSlice,
143+
QueryOptions &options) {
144+
if (innerMatcher.match(op) && matches(op, forwardSlice, options, hops)) {
145+
if (!options.inclusive) {
146+
forwardSlice.remove(op);
147+
}
148+
SmallVector<Operation *, 0> v(forwardSlice.takeVector());
149+
forwardSlice.insert(v.rbegin(), v.rend());
150+
return true;
151+
}
152+
return false;
153+
}
154+
155+
private:
156+
matcher::DynMatcher innerMatcher;
157+
unsigned hops;
158+
};
159+
160+
} // namespace detail
161+
162+
inline detail::BackwardSliceMatcher
163+
definedBy(mlir::query::matcher::DynMatcher innerMatcher) {
164+
return detail::BackwardSliceMatcher(std::move(innerMatcher), 1);
165+
}
166+
167+
inline detail::BackwardSliceMatcher
168+
getDefinitions(mlir::query::matcher::DynMatcher innerMatcher, unsigned hops) {
169+
return detail::BackwardSliceMatcher(std::move(innerMatcher), hops);
170+
}
171+
172+
inline detail::ForwardSliceMatcher
173+
usedBy(mlir::query::matcher::DynMatcher innerMatcher) {
174+
return detail::ForwardSliceMatcher(std::move(innerMatcher), 1);
175+
}
176+
177+
inline detail::ForwardSliceMatcher
178+
getUses(mlir::query::matcher::DynMatcher innerMatcher, unsigned hops) {
179+
return detail::ForwardSliceMatcher(std::move(innerMatcher), hops);
180+
}
181+
182+
} // namespace extramatcher
183+
184+
} // namespace query
185+
186+
} // namespace mlir
187+
188+
#endif // MLIR_IR_EXTRAMATCHERS_H

mlir/include/mlir/Query/Matcher/Marshallers.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,21 @@ struct ArgTypeTraits<llvm::StringRef> {
5050
}
5151
};
5252

53+
template <>
54+
struct ArgTypeTraits<unsigned> {
55+
static bool hasCorrectType(const VariantValue &value) {
56+
return value.isUnsigned();
57+
}
58+
59+
static unsigned get(const VariantValue &value) { return value.getUnsigned(); }
60+
61+
static ArgKind getKind() { return ArgKind::Unsigned; }
62+
63+
static std::optional<std::string> getBestGuess(const VariantValue &) {
64+
return std::nullopt;
65+
}
66+
};
67+
5368
template <>
5469
struct ArgTypeTraits<DynMatcher> {
5570

mlir/include/mlir/Query/Matcher/MatchFinder.h

Lines changed: 40 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,32 +7,60 @@
77
//===----------------------------------------------------------------------===//
88
//
99
// This file contains the MatchFinder class, which is used to find operations
10-
// that match a given matcher.
10+
// that match a given matcher and print them.
1111
//
1212
//===----------------------------------------------------------------------===//
1313

1414
#ifndef MLIR_TOOLS_MLIRQUERY_MATCHER_MATCHERFINDER_H
1515
#define MLIR_TOOLS_MLIRQUERY_MATCHER_MATCHERFINDER_H
1616

1717
#include "MatchersInternal.h"
18+
#include "mlir/Query/QuerySession.h"
19+
#include "llvm/ADT/SetVector.h"
20+
#include "llvm/Support/SourceMgr.h"
21+
#include "llvm/Support/raw_ostream.h"
1822

1923
namespace mlir::query::matcher {
2024

21-
// MatchFinder is used to find all operations that match a given matcher.
2225
class MatchFinder {
23-
public:
24-
// Returns all operations that match the given matcher.
25-
static std::vector<Operation *> getMatches(Operation *root,
26-
DynMatcher matcher) {
27-
std::vector<Operation *> matches;
26+
private:
27+
static void printMatch(llvm::raw_ostream &os, QuerySession &qs,
28+
mlir::Operation *op, const std::string &binding) {
29+
auto fileLoc = op->getLoc()->findInstanceOf<FileLineColLoc>();
30+
auto smloc = qs.getSourceManager().FindLocForLineAndColumn(
31+
qs.getBufferId(), fileLoc.getLine(), fileLoc.getColumn());
32+
qs.getSourceManager().PrintMessage(os, smloc, llvm::SourceMgr::DK_Note,
33+
"\"" + binding + "\" binds here");
34+
};
2835

29-
// Simple match finding with walk.
36+
public:
37+
static std::vector<Operation *>
38+
getMatches(Operation *root, QueryOptions &options, DynMatcher matcher,
39+
llvm::raw_ostream &os, QuerySession &qs) {
40+
unsigned matchCount = 0;
41+
std::vector<Operation *> matchedOps;
42+
SetVector<Operation *> tempStorage;
43+
os << "\n";
3044
root->walk([&](Operation *subOp) {
31-
if (matcher.match(subOp))
32-
matches.push_back(subOp);
45+
if (matcher.match(subOp)) {
46+
matchedOps.push_back(subOp);
47+
os << "Match #" << ++matchCount << ":\n\n";
48+
printMatch(os, qs, subOp, "root");
49+
} else {
50+
SmallVector<Operation *> printingOps;
51+
if (matcher.match(subOp, tempStorage, options)) {
52+
os << "Match #" << ++matchCount << ":\n\n";
53+
SmallVector<Operation *> printingOps(tempStorage.takeVector());
54+
for (auto op : printingOps) {
55+
printMatch(os, qs, op, "root");
56+
matchedOps.push_back(op);
57+
}
58+
printingOps.clear();
59+
}
60+
}
3361
});
34-
35-
return matches;
62+
os << matchCount << (matchCount == 1 ? " match.\n\n" : " matches.\n\n");
63+
return matchedOps;
3664
}
3765
};
3866

0 commit comments

Comments
 (0)