Skip to content

Commit 44cfde5

Browse files
[mlir][PDL] Add support for native constraints with results
This adds support for native PDL (and PDLL) C++ constraints to return results. This is useful for situations where a pattern checks for certain constraints of multiple interdependent attributes and computes a new attribute value based on them. Currently, for such an example it is required to escape to C++ during matching to perform the check and after a successful match again escape to native C++ to perform the computation during the rewriting part of the pattern. With this work we can do the computation in C++ during matching and use the result in the rewriting part of the pattern. Effectively this enables a choice in the trade-off of memory consumption during matching vs recomputation of values. This is an example of a situation where this is useful: We have two operations with certain attributes that have interdependent constraints. For instance attr_foo: one_of [0, 2, 4, 8], attr_bar: one_of [0, 2, 4, 8] and attr_foo == attr_bar. The pattern should only match if all conditions are true. The new operation should be created with a new attribute which is computed from the two matched attributes e.g. attr_baz = attr_foo * attr_bar. For the check we already escape to native C++ and have all values at hand so it makes sense to directly compute the new attribute value as well: ``` Constraint checkAndCompute(attr0: Attr, attr1: Attr) -> Attr; Pattern example with benefit(1) { let foo = op<test.foo>() {attr = attr_foo : Attr}; let bar = op<test.bar>(foo) {attr = attr_bar : Attr}; let attr_baz = checkAndCompute(attr_foo, attr_bar); rewrite bar with { let baz = op<test.baz> {attr=attr_baz}; replace bar with baz; }; } ``` To achieve this the following notable changes were necessary: PDLL - Remove check in PDLL parser that prevented native constraints from returning results PDL - Change PDL definition of pdl.apply_native_constraint to allow variadic results PDL_interp - Change PDL_interp definition of pdl_interp.apply_constraint to allow variadic results PDLToPDLInterp Pass: The input to the pass is an arbitrary number of PDL patterns. The pass collects the predicates that are required to match all of the pdl patterns and establishes an ordering that allows creation of a single efficient matcher function to match all of them. Values that are matched and possibly used in the rewriting part of a pattern are represented as positions. This allows fusion and thus reusing a single position for multiple matching patterns. Accordingly, we introduce ConstraintPosition, which records the type and index of the result of the constraint. The problem is for the corresponding value to be used in the rewriting part of a pattern it has to be an input to the pdl_interp.record_match operation, which is generated early during the pass such that its surrounding block can be referred to by branching operations. In consequence the value has to be materialized after the original pdl.apply_native_constraint has been deleted but before we get the chance to generate the corresponding pdl_interp.apply_constraint operation. We solve this by emitting a placeholder value when a ConstraintPosition is evaluated. These placeholder values (due to fusion there may be multiple for one constraint result) are replaced later when the actual pdl_interp.apply_constraint operation is created. Bytecode generator and interpreter: Constraint functions which return results have a different type compared to existing constraint functions. They have the same type as native rewrite functions and hence are registered as rewrite functions. Co-authored-by: Martin Lücke <[email protected]>
1 parent 27498e9 commit 44cfde5

File tree

17 files changed

+513
-77
lines changed

17 files changed

+513
-77
lines changed

mlir/include/mlir/Dialect/PDL/IR/PDLOps.td

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,20 +35,25 @@ def PDL_ApplyNativeConstraintOp
3535
let description = [{
3636
`pdl.apply_native_constraint` operations apply a native C++ constraint, that
3737
has been registered externally with the consumer of PDL, to a given set of
38-
entities.
38+
entities and optionally return a number of values.
3939

4040
Example:
4141

4242
```mlir
4343
// Apply `myConstraint` to the entities defined by `input`, `attr`, and `op`.
4444
pdl.apply_native_constraint "myConstraint"(%input, %attr, %op : !pdl.value, !pdl.attribute, !pdl.operation)
45+
// Apply constraint `with_result` to `root`. This constraint returns an attribute.
46+
%attr = pdl.apply_native_constraint "with_result"(%root : !pdl.operation) : !pdl.attribute
4547
```
4648
}];
4749

4850
let arguments = (ins StrAttr:$name,
4951
Variadic<PDL_AnyType>:$args,
5052
DefaultValuedAttr<BoolAttr, "false">:$isNegated);
51-
let assemblyFormat = "$name `(` $args `:` type($args) `)` attr-dict";
53+
let results = (outs Variadic<PDL_AnyType>:$results);
54+
let assemblyFormat = [{
55+
$name `(` $args `:` type($args) `)` (`:` type($results)^ )? attr-dict
56+
}];
5257
let hasVerifier = 1;
5358
}
5459

mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,9 @@ def PDLInterp_ApplyConstraintOp : PDLInterp_PredicateOp<"apply_constraint"> {
8888
let description = [{
8989
`pdl_interp.apply_constraint` operations apply a generic constraint, that
9090
has been registered with the interpreter, with a given set of positional
91-
values. On success, this operation branches to the true destination,
91+
values.
92+
The constraint function may return any number of results.
93+
On success, this operation branches to the true destination,
9294
otherwise the false destination is taken. This behavior can be reversed
9395
by setting the attribute `isNegated` to true.
9496

@@ -104,8 +106,10 @@ def PDLInterp_ApplyConstraintOp : PDLInterp_PredicateOp<"apply_constraint"> {
104106
let arguments = (ins StrAttr:$name,
105107
Variadic<PDL_AnyType>:$args,
106108
DefaultValuedAttr<BoolAttr, "false">:$isNegated);
109+
let results = (outs Variadic<PDL_AnyType>:$results);
107110
let assemblyFormat = [{
108-
$name `(` $args `:` type($args) `)` attr-dict `->` successors
111+
$name `(` $args `:` type($args) `)` (`:` type($results)^)? attr-dict
112+
`->` successors
109113
}];
110114
}
111115

mlir/include/mlir/IR/PDLPatternMatch.h.inc

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -868,6 +868,19 @@ public:
868868
std::forward<ConstraintFnT>(constraintFn)));
869869
}
870870

871+
/// Register a constraint function that produces results with PDL. A
872+
/// constraint function with results uses the same registry as rewrite
873+
/// functions. It may be specified as follows:
874+
///
875+
/// * `LogicalResult (PatternRewriter &, PDLResultList &,
876+
/// ArrayRef<PDLValue>)`
877+
///
878+
/// The arguments of the constraint function are passed via the low-level
879+
/// PDLValue form, and the results are manually appended to the given result
880+
/// list.
881+
void registerConstraintFunctionWithResults(StringRef name,
882+
PDLRewriteFunction constraintFn);
883+
871884
/// Register a rewrite function with PDL. A rewrite function may be specified
872885
/// in one of two ways:
873886
///

mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp

Lines changed: 40 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,12 @@ struct PatternLowering {
148148
/// A mapping between pattern operations and the corresponding configuration
149149
/// set.
150150
DenseMap<Operation *, PDLPatternConfigSet *> *configMap;
151+
152+
/// A mapping from a constraint question and result index that together
153+
/// refer to a value created by a constraint to the temporary placeholder
154+
/// values created for them.
155+
std::multimap<std::pair<ConstraintQuestion *, unsigned>, Value>
156+
constraintResultMap;
151157
};
152158
} // namespace
153159

@@ -364,6 +370,21 @@ Value PatternLowering::getValueAt(Block *&currentBlock, Position *pos) {
364370
loc, cast<ArrayAttr>(rawTypeAttr));
365371
break;
366372
}
373+
case Predicates::ConstraintResultPos: {
374+
// The corresponding pdl.ApplyNativeConstraint op has already been deleted
375+
// and the new pdl_interp.ApplyConstraint has not been created yet. To
376+
// enable referring to results created by this operation we build a
377+
// placeholder value that will be replaced when the actual
378+
// pdl_interp.ApplyConstraint operation is created.
379+
auto *constrResPos = cast<ConstraintPosition>(pos);
380+
Value placeholderValue = builder.create<pdl_interp::CreateAttributeOp>(
381+
loc, StringAttr::get(builder.getContext(), "placeholder"));
382+
constraintResultMap.insert(
383+
{{constrResPos->getQuestion(), constrResPos->getIndex()},
384+
placeholderValue});
385+
value = placeholderValue;
386+
break;
387+
}
367388
default:
368389
llvm_unreachable("Generating unknown Position getter");
369390
break;
@@ -447,9 +468,25 @@ void PatternLowering::generate(BoolNode *boolNode, Block *&currentBlock,
447468
}
448469
case Predicates::ConstraintQuestion: {
449470
auto *cstQuestion = cast<ConstraintQuestion>(question);
450-
builder.create<pdl_interp::ApplyConstraintOp>(
451-
loc, cstQuestion->getName(), args, cstQuestion->getIsNegated(), success,
452-
failure);
471+
auto applyConstraintOp = builder.create<pdl_interp::ApplyConstraintOp>(
472+
loc, cstQuestion->getResultTypes(), cstQuestion->getName(), args,
473+
cstQuestion->getIsNegated(), success, failure);
474+
// Replace the generated placeholders with the results of the constraint and
475+
// erase them
476+
for (auto result : llvm::enumerate(applyConstraintOp.getResults())) {
477+
std::pair<ConstraintQuestion *, unsigned> substitutionKey = {
478+
cstQuestion, result.index()};
479+
// Check if there are substitutions to perform. If the result is never
480+
// used or multiple calls to the same constraint have been merged,
481+
// no substitutions will have been generated for this specific op.
482+
auto range = constraintResultMap.equal_range(substitutionKey);
483+
std::for_each(range.first, range.second, [&](const auto &elem) {
484+
Value placeholder = elem.second;
485+
placeholder.replaceAllUsesWith(result.value());
486+
placeholder.getDefiningOp()->erase();
487+
});
488+
constraintResultMap.erase(substitutionKey);
489+
}
453490
break;
454491
}
455492
default:

mlir/lib/Conversion/PDLToPDLInterp/Predicate.h

Lines changed: 47 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ enum Kind : unsigned {
4747
OperandPos,
4848
OperandGroupPos,
4949
AttributePos,
50+
ConstraintResultPos,
5051
ResultPos,
5152
ResultGroupPos,
5253
TypePos,
@@ -279,6 +280,28 @@ struct OperationPosition : public PredicateBase<OperationPosition, Position,
279280
bool isOperandDefiningOp() const;
280281
};
281282

283+
//===----------------------------------------------------------------------===//
284+
// ConstraintPosition
285+
286+
struct ConstraintQuestion;
287+
288+
/// A position describing the result of a native constraint. It saves the
289+
/// corresponding ConstraintQuestion and result index to enable referring
290+
/// back to them
291+
struct ConstraintPosition
292+
: public PredicateBase<ConstraintPosition, Position,
293+
std::pair<ConstraintQuestion *, unsigned>,
294+
Predicates::ConstraintResultPos> {
295+
using PredicateBase::PredicateBase;
296+
297+
/// Returns the ConstraintQuestion to enable keeping track of the native
298+
/// constraint this position stems from.
299+
ConstraintQuestion *getQuestion() const { return key.first; }
300+
301+
// Returns the result index of this position
302+
unsigned getIndex() const { return key.second; }
303+
};
304+
282305
//===----------------------------------------------------------------------===//
283306
// ResultPosition
284307

@@ -447,11 +470,13 @@ struct AttributeQuestion
447470
: public PredicateBase<AttributeQuestion, Qualifier, void,
448471
Predicates::AttributeQuestion> {};
449472

450-
/// Apply a parameterized constraint to multiple position values.
473+
/// Apply a parameterized constraint to multiple position values and possibly
474+
/// produce results.
451475
struct ConstraintQuestion
452-
: public PredicateBase<ConstraintQuestion, Qualifier,
453-
std::tuple<StringRef, ArrayRef<Position *>, bool>,
454-
Predicates::ConstraintQuestion> {
476+
: public PredicateBase<
477+
ConstraintQuestion, Qualifier,
478+
std::tuple<StringRef, ArrayRef<Position *>, ArrayRef<Type>, bool>,
479+
Predicates::ConstraintQuestion> {
455480
using Base::Base;
456481

457482
/// Return the name of the constraint.
@@ -460,15 +485,19 @@ struct ConstraintQuestion
460485
/// Return the arguments of the constraint.
461486
ArrayRef<Position *> getArgs() const { return std::get<1>(key); }
462487

488+
/// Return the result types of the constraint.
489+
ArrayRef<Type> getResultTypes() const { return std::get<2>(key); }
490+
463491
/// Return the negation status of the constraint.
464-
bool getIsNegated() const { return std::get<2>(key); }
492+
bool getIsNegated() const { return std::get<3>(key); }
465493

466494
/// Construct an instance with the given storage allocator.
467495
static ConstraintQuestion *construct(StorageUniquer::StorageAllocator &alloc,
468496
KeyTy key) {
469497
return Base::construct(alloc, KeyTy{alloc.copyInto(std::get<0>(key)),
470498
alloc.copyInto(std::get<1>(key)),
471-
std::get<2>(key)});
499+
alloc.copyInto(std::get<2>(key)),
500+
std::get<3>(key)});
472501
}
473502

474503
/// Returns a hash suitable for the given keytype.
@@ -526,6 +555,7 @@ class PredicateUniquer : public StorageUniquer {
526555
// Register the types of Positions with the uniquer.
527556
registerParametricStorageType<AttributePosition>();
528557
registerParametricStorageType<AttributeLiteralPosition>();
558+
registerParametricStorageType<ConstraintPosition>();
529559
registerParametricStorageType<ForEachPosition>();
530560
registerParametricStorageType<OperandPosition>();
531561
registerParametricStorageType<OperandGroupPosition>();
@@ -588,6 +618,12 @@ class PredicateBuilder {
588618
return OperationPosition::get(uniquer, p);
589619
}
590620

621+
// Returns a position for a new value created by a constraint.
622+
ConstraintPosition *getConstraintPosition(ConstraintQuestion *q,
623+
unsigned index) {
624+
return ConstraintPosition::get(uniquer, std::make_pair(q, index));
625+
}
626+
591627
/// Returns an attribute position for an attribute of the given operation.
592628
Position *getAttribute(OperationPosition *p, StringRef name) {
593629
return AttributePosition::get(uniquer, p, StringAttr::get(ctx, name));
@@ -673,11 +709,11 @@ class PredicateBuilder {
673709
}
674710

675711
/// Create a predicate that applies a generic constraint.
676-
Predicate getConstraint(StringRef name, ArrayRef<Position *> pos,
677-
bool isNegated) {
678-
return {
679-
ConstraintQuestion::get(uniquer, std::make_tuple(name, pos, isNegated)),
680-
TrueAnswer::get(uniquer)};
712+
Predicate getConstraint(StringRef name, ArrayRef<Position *> args,
713+
ArrayRef<Type> resultTypes, bool isNegated) {
714+
return {ConstraintQuestion::get(
715+
uniquer, std::make_tuple(name, args, resultTypes, isNegated)),
716+
TrueAnswer::get(uniquer)};
681717
}
682718

683719
/// Create a predicate comparing a value with null.

mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp

Lines changed: 51 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include "mlir/IR/BuiltinOps.h"
1616
#include "mlir/Interfaces/InferTypeOpInterface.h"
1717
#include "llvm/ADT/MapVector.h"
18+
#include "llvm/ADT/SmallPtrSet.h"
1819
#include "llvm/ADT/TypeSwitch.h"
1920
#include "llvm/Support/Debug.h"
2021
#include <queue>
@@ -272,8 +273,17 @@ static void getConstraintPredicates(pdl::ApplyNativeConstraintOp op,
272273
// Push the constraint to the furthest position.
273274
Position *pos = *std::max_element(allPositions.begin(), allPositions.end(),
274275
comparePosDepth);
275-
PredicateBuilder::Predicate pred =
276-
builder.getConstraint(op.getName(), allPositions, op.getIsNegated());
276+
ResultRange results = op.getResults();
277+
PredicateBuilder::Predicate pred = builder.getConstraint(
278+
op.getName(), allPositions, SmallVector<Type>(results.getTypes()),
279+
op.getIsNegated());
280+
281+
// for each result register a position so it can be used later
282+
for (auto result : llvm::enumerate(results)) {
283+
ConstraintQuestion *q = cast<ConstraintQuestion>(pred.first);
284+
ConstraintPosition *pos = builder.getConstraintPosition(q, result.index());
285+
inputs[result.value()] = pos;
286+
}
277287
predList.emplace_back(pos, pred);
278288
}
279289

@@ -875,6 +885,27 @@ static void insertExitNode(std::unique_ptr<MatcherNode> *root) {
875885
*root = std::make_unique<ExitNode>();
876886
}
877887

888+
/// Sorts the range begin/end with the partial order given by cmp.
889+
template <typename Iterator, typename Compare>
890+
void stableTopologicalSort(Iterator begin, Iterator end, Compare cmp) {
891+
while (begin != end) {
892+
// Cannot compute sortBeforeOthers in the predicate of stable_partition
893+
// because stable_partition will not keep the [begin, end) range intact
894+
// while it runs.
895+
llvm::SmallPtrSet<typename Iterator::value_type, 16> sortBeforeOthers;
896+
for (auto i = begin; i != end; ++i) {
897+
if (std::none_of(begin, end, [&](auto const &b) { return cmp(b, *i); }))
898+
sortBeforeOthers.insert(*i);
899+
}
900+
901+
auto const next = std::stable_partition(begin, end, [&](auto const &a) {
902+
return sortBeforeOthers.contains(a);
903+
});
904+
assert(next != begin && "not a partial ordering");
905+
begin = next;
906+
}
907+
}
908+
878909
/// Given a module containing PDL pattern operations, generate a matcher tree
879910
/// using the patterns within the given module and return the root matcher node.
880911
std::unique_ptr<MatcherNode>
@@ -955,6 +986,24 @@ MatcherNode::generateMatcherTree(ModuleOp module, PredicateBuilder &builder,
955986
return *lhs < *rhs;
956987
});
957988

989+
// Mostly keep the now established order, but also ensure that
990+
// ConstraintQuestions come after the results they use.
991+
stableTopologicalSort(ordered.begin(), ordered.end(),
992+
[](OrderedPredicate *a, OrderedPredicate *b) {
993+
auto *cqa = dyn_cast<ConstraintQuestion>(a->question);
994+
auto *cqb = dyn_cast<ConstraintQuestion>(b->question);
995+
if (cqa && cqb) {
996+
// Does any argument of b use a? Then b must be
997+
// sorted after a.
998+
return llvm::any_of(
999+
cqb->getArgs(), [&](Position *p) {
1000+
auto *cp = dyn_cast<ConstraintPosition>(p);
1001+
return cp && cp->getQuestion() == cqa;
1002+
});
1003+
}
1004+
return false;
1005+
});
1006+
9581007
// Build the matchers for each of the pattern predicate lists.
9591008
std::unique_ptr<MatcherNode> root;
9601009
for (OrderedPredicateList &list : lists)

mlir/lib/Dialect/PDL/IR/PDL.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,12 @@ static void visit(Operation *op, DenseSet<Operation *> &visited) {
9494
LogicalResult ApplyNativeConstraintOp::verify() {
9595
if (getNumOperands() == 0)
9696
return emitOpError("expected at least one argument");
97+
if (llvm::any_of(getResults(), [](OpResult result) {
98+
return result.getType().isa<OperationType>();
99+
})) {
100+
return emitOpError(
101+
"returning an operation from a constraint is not supported");
102+
}
97103
return success();
98104
}
99105

mlir/lib/IR/PDL/PDLPatternMatch.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,15 @@ void PDLPatternModule::registerConstraintFunction(
123123
constraintFunctions.try_emplace(name, std::move(constraintFn));
124124
}
125125

126+
void PDLPatternModule::registerConstraintFunctionWithResults(
127+
StringRef name, PDLRewriteFunction constraintFn) {
128+
// TODO: Is it possible to diagnose when `name` is already registered to
129+
// a function that is not equivalent to `rewriteFn`?
130+
// Allow existing mappings in the case multiple patterns depend on the same
131+
// rewrite.
132+
registerRewriteFunction(name, std::move(constraintFn));
133+
}
134+
126135
void PDLPatternModule::registerRewriteFunction(StringRef name,
127136
PDLRewriteFunction rewriteFn) {
128137
// TODO: Is it possible to diagnose when `name` is already registered to

0 commit comments

Comments
 (0)