Skip to content

Commit dd474be

Browse files
[mlir][PDL] Add support for native constraints with results
This adds support for native PDL (and PDLL) C++ constraints to return results. This is useful for situations where a pattern checks for certain constraints of multiple interdependent attributes and computes a new attribute value based on them. Currently, for such an example it is required to escape to C++ during matching to perform the check and after a successful match again escape to native C++ to perform the computation during the rewriting part of the pattern. With this work we can do the computation in C++ during matching and use the result in the rewriting part of the pattern. Effectively this enables a choice in the trade-off of memory consumption during matching vs recomputation of values. This is an example of a situation where this is useful: We have two operations with certain attributes that have interdependent constraints. For instance attr_foo: one_of [0, 2, 4, 8], attr_bar: one_of [0, 2, 4, 8] and attr_foo == attr_bar. The pattern should only match if all conditions are true. The new operation should be created with a new attribute which is computed from the two matched attributes e.g. attr_baz = attr_foo * attr_bar. For the check we already escape to native C++ and have all values at hand so it makes sense to directly compute the new attribute value as well: ``` Constraint checkAndCompute(attr0: Attr, attr1: Attr) -> Attr; Pattern example with benefit(1) { let foo = op<test.foo>() {attr = attr_foo : Attr}; let bar = op<test.bar>(foo) {attr = attr_bar : Attr}; let attr_baz = checkAndCompute(attr_foo, attr_bar); rewrite bar with { let baz = op<test.baz> {attr=attr_baz}; replace bar with baz; }; } ``` To achieve this the following notable changes were necessary: PDLL - Remove check in PDLL parser that prevented native constraints from returning results PDL - Change PDL definition of pdl.apply_native_constraint to allow variadic results PDL_interp - Change PDL_interp definition of pdl_interp.apply_constraint to allow variadic results PDLToPDLInterp Pass: The input to the pass is an arbitrary number of PDL patterns. The pass collects the predicates that are required to match all of the pdl patterns and establishes an ordering that allows creation of a single efficient matcher function to match all of them. Values that are matched and possibly used in the rewriting part of a pattern are represented as positions. This allows fusion and thus reusing a single position for multiple matching patterns. Accordingly, we introduce ConstraintPosition, which records the type and index of the result of the constraint. The problem is for the corresponding value to be used in the rewriting part of a pattern it has to be an input to the pdl_interp.record_match operation, which is generated early during the pass such that its surrounding block can be referred to by branching operations. In consequence the value has to be materialized after the original pdl.apply_native_constraint has been deleted but before we get the chance to generate the corresponding pdl_interp.apply_constraint operation. We solve this by emitting a placeholder value when a ConstraintPosition is evaluated. These placeholder values (due to fusion there may be multiple for one constraint result) are replaced later when the actual pdl_interp.apply_constraint operation is created. Bytecode generator and interpreter: Constraint functions which return results have a different type compared to existing constraint functions. They have the same type as native rewrite functions and hence are registered as rewrite functions. Co-authored-by: Martin Lücke <[email protected]>
1 parent 27498e9 commit dd474be

File tree

18 files changed

+489
-92
lines changed

18 files changed

+489
-92
lines changed

mlir/include/mlir/Dialect/PDL/IR/PDLOps.td

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,20 +35,25 @@ def PDL_ApplyNativeConstraintOp
3535
let description = [{
3636
`pdl.apply_native_constraint` operations apply a native C++ constraint, that
3737
has been registered externally with the consumer of PDL, to a given set of
38-
entities.
38+
entities and optionally return a number of values.
3939

4040
Example:
4141

4242
```mlir
4343
// Apply `myConstraint` to the entities defined by `input`, `attr`, and `op`.
4444
pdl.apply_native_constraint "myConstraint"(%input, %attr, %op : !pdl.value, !pdl.attribute, !pdl.operation)
45+
// Apply constraint `with_result` to `root`. This constraint returns an attribute.
46+
%attr = pdl.apply_native_constraint "with_result"(%root : !pdl.operation) : !pdl.attribute
4547
```
4648
}];
4749

4850
let arguments = (ins StrAttr:$name,
4951
Variadic<PDL_AnyType>:$args,
5052
DefaultValuedAttr<BoolAttr, "false">:$isNegated);
51-
let assemblyFormat = "$name `(` $args `:` type($args) `)` attr-dict";
53+
let results = (outs Variadic<PDL_AnyType>:$results);
54+
let assemblyFormat = [{
55+
$name `(` $args `:` type($args) `)` (`:` type($results)^ )? attr-dict
56+
}];
5257
let hasVerifier = 1;
5358
}
5459

mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,9 @@ def PDLInterp_ApplyConstraintOp : PDLInterp_PredicateOp<"apply_constraint"> {
8888
let description = [{
8989
`pdl_interp.apply_constraint` operations apply a generic constraint, that
9090
has been registered with the interpreter, with a given set of positional
91-
values. On success, this operation branches to the true destination,
91+
values.
92+
The constraint function may return any number of results.
93+
On success, this operation branches to the true destination,
9294
otherwise the false destination is taken. This behavior can be reversed
9395
by setting the attribute `isNegated` to true.
9496

@@ -104,8 +106,10 @@ def PDLInterp_ApplyConstraintOp : PDLInterp_PredicateOp<"apply_constraint"> {
104106
let arguments = (ins StrAttr:$name,
105107
Variadic<PDL_AnyType>:$args,
106108
DefaultValuedAttr<BoolAttr, "false">:$isNegated);
109+
let results = (outs Variadic<PDL_AnyType>:$results);
107110
let assemblyFormat = [{
108-
$name `(` $args `:` type($args) `)` attr-dict `->` successors
111+
$name `(` $args `:` type($args) `)` (`:` type($results)^)? attr-dict
112+
`->` successors
109113
}];
110114
}
111115

mlir/include/mlir/IR/PDLPatternMatch.h.inc

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -318,8 +318,9 @@ protected:
318318
/// A generic PDL pattern constraint function. This function applies a
319319
/// constraint to a given set of opaque PDLValue entities. Returns success if
320320
/// the constraint successfully held, failure otherwise.
321-
using PDLConstraintFunction =
322-
std::function<LogicalResult(PatternRewriter &, ArrayRef<PDLValue>)>;
321+
using PDLConstraintFunction = std::function<LogicalResult(
322+
PatternRewriter &, PDLResultList &, ArrayRef<PDLValue>)>;
323+
323324
/// A native PDL rewrite function. This function performs a rewrite on the
324325
/// given set of values. Any results from this rewrite that should be passed
325326
/// back to PDL should be added to the provided result list. This method is only
@@ -726,7 +727,7 @@ std::enable_if_t<
726727
PDLConstraintFunction>
727728
buildConstraintFn(ConstraintFnT &&constraintFn) {
728729
return [constraintFn = std::forward<ConstraintFnT>(constraintFn)](
729-
PatternRewriter &rewriter,
730+
PatternRewriter &rewriter, PDLResultList &,
730731
ArrayRef<PDLValue> values) -> LogicalResult {
731732
auto argIndices = std::make_index_sequence<
732733
llvm::function_traits<ConstraintFnT>::num_args - 1>();
@@ -842,10 +843,13 @@ public:
842843
/// Register a constraint function with PDL. A constraint function may be
843844
/// specified in one of two ways:
844845
///
845-
/// * `LogicalResult (PatternRewriter &, ArrayRef<PDLValue>)`
846+
/// * `LogicalResult (PatternRewriter &,
847+
/// PDLResultList &,
848+
/// ArrayRef<PDLValue>)`
846849
///
847850
/// In this overload the arguments of the constraint function are passed via
848-
/// the low-level PDLValue form.
851+
/// the low-level PDLValue form, and the results are manually appended to
852+
/// the given result list.
849853
///
850854
/// * `LogicalResult (PatternRewriter &, ValueTs... values)`
851855
///
@@ -960,8 +964,8 @@ public:
960964
}
961965
};
962966
class PDLResultList {};
963-
using PDLConstraintFunction =
964-
std::function<LogicalResult(PatternRewriter &, ArrayRef<PDLValue>)>;
967+
using PDLConstraintFunction = std::function<LogicalResult(
968+
PatternRewriter &, PDLResultList &, ArrayRef<PDLValue>)>;
965969
using PDLRewriteFunction = std::function<LogicalResult(
966970
PatternRewriter &, PDLResultList &, ArrayRef<PDLValue>)>;
967971

mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp

Lines changed: 44 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,8 @@ struct PatternLowering {
5050

5151
/// Generate interpreter operations for the tree rooted at the given matcher
5252
/// node, in the specified region.
53-
Block *generateMatcher(MatcherNode &node, Region &region);
53+
Block *generateMatcher(MatcherNode &node, Region &region,
54+
Block *block = nullptr);
5455

5556
/// Get or create an access to the provided positional value in the current
5657
/// block. This operation may mutate the provided block pointer if nested
@@ -148,6 +149,12 @@ struct PatternLowering {
148149
/// A mapping between pattern operations and the corresponding configuration
149150
/// set.
150151
DenseMap<Operation *, PDLPatternConfigSet *> *configMap;
152+
153+
/// A mapping from a constraint question and result index that together
154+
/// refer to a value created by a constraint to the temporary placeholder
155+
/// values created for them.
156+
std::multimap<std::pair<ConstraintQuestion *, unsigned>, Value>
157+
constraintResultMap;
151158
};
152159
} // namespace
153160

@@ -182,9 +189,12 @@ void PatternLowering::lower(ModuleOp module) {
182189
firstMatcherBlock->erase();
183190
}
184191

185-
Block *PatternLowering::generateMatcher(MatcherNode &node, Region &region) {
192+
Block *PatternLowering::generateMatcher(MatcherNode &node, Region &region,
193+
Block *block) {
186194
// Push a new scope for the values used by this matcher.
187-
Block *block = &region.emplaceBlock();
195+
if (!block) {
196+
block = &region.emplaceBlock();
197+
}
188198
ValueMapScope scope(values);
189199

190200
// If this is the return node, simply insert the corresponding interpreter
@@ -364,6 +374,19 @@ Value PatternLowering::getValueAt(Block *&currentBlock, Position *pos) {
364374
loc, cast<ArrayAttr>(rawTypeAttr));
365375
break;
366376
}
377+
case Predicates::ConstraintResultPos: {
378+
// The corresponding pdl.ApplyNativeConstraint op has already been deleted
379+
// and the new pdl_interp.ApplyConstraint has not been created yet. To
380+
// enable referring to results created by this operation we build a
381+
// placeholder value that will be replaced when the actual
382+
// pdl_interp.ApplyConstraint operation is created.
383+
auto *constrResPos = cast<ConstraintPosition>(pos);
384+
auto i = constraintResultMap.find(
385+
{constrResPos->getQuestion(), constrResPos->getIndex()});
386+
assert(i != constraintResultMap.end());
387+
value = i->second;
388+
break;
389+
}
367390
default:
368391
llvm_unreachable("Generating unknown Position getter");
369392
break;
@@ -390,12 +413,11 @@ void PatternLowering::generate(BoolNode *boolNode, Block *&currentBlock,
390413
args.push_back(getValueAt(currentBlock, position));
391414
}
392415

393-
// Generate the matcher in the current (potentially nested) region
394-
// and get the failure successor.
395-
Block *success = generateMatcher(*boolNode->getSuccessNode(), *region);
416+
// Generate a new block as success successor and get the failure successor.
417+
Block *success = &region->emplaceBlock();
396418
Block *failure = failureBlockStack.back();
397419

398-
// Finally, create the predicate.
420+
// Create the predicate.
399421
builder.setInsertionPointToEnd(currentBlock);
400422
Predicates::Kind kind = question->getKind();
401423
switch (kind) {
@@ -447,14 +469,26 @@ void PatternLowering::generate(BoolNode *boolNode, Block *&currentBlock,
447469
}
448470
case Predicates::ConstraintQuestion: {
449471
auto *cstQuestion = cast<ConstraintQuestion>(question);
450-
builder.create<pdl_interp::ApplyConstraintOp>(
451-
loc, cstQuestion->getName(), args, cstQuestion->getIsNegated(), success,
452-
failure);
472+
auto applyConstraintOp = builder.create<pdl_interp::ApplyConstraintOp>(
473+
loc, cstQuestion->getResultTypes(), cstQuestion->getName(), args,
474+
cstQuestion->getIsNegated(), success, failure);
475+
476+
// Replace the generated placeholders with the results of the constraint and
477+
// erase them
478+
for (auto result : llvm::enumerate(applyConstraintOp.getResults())) {
479+
std::pair<ConstraintQuestion *, unsigned> substitutionKey = {
480+
cstQuestion, result.index()};
481+
constraintResultMap.insert({substitutionKey, result.value()});
482+
}
453483
break;
454484
}
455485
default:
456486
llvm_unreachable("Generating unknown Predicate operation");
457487
}
488+
489+
// Generate the matcher in the current (potentially nested) region.
490+
// This might use the results of the current predicate.
491+
generateMatcher(*boolNode->getSuccessNode(), *region, success);
458492
}
459493

460494
template <typename OpT, typename PredT, typename ValT = typename PredT::KeyTy>

mlir/lib/Conversion/PDLToPDLInterp/Predicate.h

Lines changed: 47 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ enum Kind : unsigned {
4747
OperandPos,
4848
OperandGroupPos,
4949
AttributePos,
50+
ConstraintResultPos,
5051
ResultPos,
5152
ResultGroupPos,
5253
TypePos,
@@ -279,6 +280,28 @@ struct OperationPosition : public PredicateBase<OperationPosition, Position,
279280
bool isOperandDefiningOp() const;
280281
};
281282

283+
//===----------------------------------------------------------------------===//
284+
// ConstraintPosition
285+
286+
struct ConstraintQuestion;
287+
288+
/// A position describing the result of a native constraint. It saves the
289+
/// corresponding ConstraintQuestion and result index to enable referring
290+
/// back to them
291+
struct ConstraintPosition
292+
: public PredicateBase<ConstraintPosition, Position,
293+
std::pair<ConstraintQuestion *, unsigned>,
294+
Predicates::ConstraintResultPos> {
295+
using PredicateBase::PredicateBase;
296+
297+
/// Returns the ConstraintQuestion to enable keeping track of the native
298+
/// constraint this position stems from.
299+
ConstraintQuestion *getQuestion() const { return key.first; }
300+
301+
// Returns the result index of this position
302+
unsigned getIndex() const { return key.second; }
303+
};
304+
282305
//===----------------------------------------------------------------------===//
283306
// ResultPosition
284307

@@ -447,11 +470,13 @@ struct AttributeQuestion
447470
: public PredicateBase<AttributeQuestion, Qualifier, void,
448471
Predicates::AttributeQuestion> {};
449472

450-
/// Apply a parameterized constraint to multiple position values.
473+
/// Apply a parameterized constraint to multiple position values and possibly
474+
/// produce results.
451475
struct ConstraintQuestion
452-
: public PredicateBase<ConstraintQuestion, Qualifier,
453-
std::tuple<StringRef, ArrayRef<Position *>, bool>,
454-
Predicates::ConstraintQuestion> {
476+
: public PredicateBase<
477+
ConstraintQuestion, Qualifier,
478+
std::tuple<StringRef, ArrayRef<Position *>, ArrayRef<Type>, bool>,
479+
Predicates::ConstraintQuestion> {
455480
using Base::Base;
456481

457482
/// Return the name of the constraint.
@@ -460,15 +485,19 @@ struct ConstraintQuestion
460485
/// Return the arguments of the constraint.
461486
ArrayRef<Position *> getArgs() const { return std::get<1>(key); }
462487

488+
/// Return the result types of the constraint.
489+
ArrayRef<Type> getResultTypes() const { return std::get<2>(key); }
490+
463491
/// Return the negation status of the constraint.
464-
bool getIsNegated() const { return std::get<2>(key); }
492+
bool getIsNegated() const { return std::get<3>(key); }
465493

466494
/// Construct an instance with the given storage allocator.
467495
static ConstraintQuestion *construct(StorageUniquer::StorageAllocator &alloc,
468496
KeyTy key) {
469497
return Base::construct(alloc, KeyTy{alloc.copyInto(std::get<0>(key)),
470498
alloc.copyInto(std::get<1>(key)),
471-
std::get<2>(key)});
499+
alloc.copyInto(std::get<2>(key)),
500+
std::get<3>(key)});
472501
}
473502

474503
/// Returns a hash suitable for the given keytype.
@@ -526,6 +555,7 @@ class PredicateUniquer : public StorageUniquer {
526555
// Register the types of Positions with the uniquer.
527556
registerParametricStorageType<AttributePosition>();
528557
registerParametricStorageType<AttributeLiteralPosition>();
558+
registerParametricStorageType<ConstraintPosition>();
529559
registerParametricStorageType<ForEachPosition>();
530560
registerParametricStorageType<OperandPosition>();
531561
registerParametricStorageType<OperandGroupPosition>();
@@ -588,6 +618,12 @@ class PredicateBuilder {
588618
return OperationPosition::get(uniquer, p);
589619
}
590620

621+
// Returns a position for a new value created by a constraint.
622+
ConstraintPosition *getConstraintPosition(ConstraintQuestion *q,
623+
unsigned index) {
624+
return ConstraintPosition::get(uniquer, std::make_pair(q, index));
625+
}
626+
591627
/// Returns an attribute position for an attribute of the given operation.
592628
Position *getAttribute(OperationPosition *p, StringRef name) {
593629
return AttributePosition::get(uniquer, p, StringAttr::get(ctx, name));
@@ -673,11 +709,11 @@ class PredicateBuilder {
673709
}
674710

675711
/// Create a predicate that applies a generic constraint.
676-
Predicate getConstraint(StringRef name, ArrayRef<Position *> pos,
677-
bool isNegated) {
678-
return {
679-
ConstraintQuestion::get(uniquer, std::make_tuple(name, pos, isNegated)),
680-
TrueAnswer::get(uniquer)};
712+
Predicate getConstraint(StringRef name, ArrayRef<Position *> args,
713+
ArrayRef<Type> resultTypes, bool isNegated) {
714+
return {ConstraintQuestion::get(
715+
uniquer, std::make_tuple(name, args, resultTypes, isNegated)),
716+
TrueAnswer::get(uniquer)};
681717
}
682718

683719
/// Create a predicate comparing a value with null.

mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp

Lines changed: 51 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include "mlir/IR/BuiltinOps.h"
1616
#include "mlir/Interfaces/InferTypeOpInterface.h"
1717
#include "llvm/ADT/MapVector.h"
18+
#include "llvm/ADT/SmallPtrSet.h"
1819
#include "llvm/ADT/TypeSwitch.h"
1920
#include "llvm/Support/Debug.h"
2021
#include <queue>
@@ -272,8 +273,17 @@ static void getConstraintPredicates(pdl::ApplyNativeConstraintOp op,
272273
// Push the constraint to the furthest position.
273274
Position *pos = *std::max_element(allPositions.begin(), allPositions.end(),
274275
comparePosDepth);
275-
PredicateBuilder::Predicate pred =
276-
builder.getConstraint(op.getName(), allPositions, op.getIsNegated());
276+
ResultRange results = op.getResults();
277+
PredicateBuilder::Predicate pred = builder.getConstraint(
278+
op.getName(), allPositions, SmallVector<Type>(results.getTypes()),
279+
op.getIsNegated());
280+
281+
// For each result register a position so it can be used later
282+
for (auto [i, result] : llvm::enumerate(results)) {
283+
ConstraintQuestion *q = cast<ConstraintQuestion>(pred.first);
284+
ConstraintPosition *pos = builder.getConstraintPosition(q, i);
285+
inputs[result] = pos;
286+
}
277287
predList.emplace_back(pos, pred);
278288
}
279289

@@ -875,6 +885,27 @@ static void insertExitNode(std::unique_ptr<MatcherNode> *root) {
875885
*root = std::make_unique<ExitNode>();
876886
}
877887

888+
/// Sorts the range begin/end with the partial order given by cmp.
889+
template <typename Iterator, typename Compare>
890+
void stableTopologicalSort(Iterator begin, Iterator end, Compare cmp) {
891+
while (begin != end) {
892+
// Cannot compute sortBeforeOthers in the predicate of stable_partition
893+
// because stable_partition will not keep the [begin, end) range intact
894+
// while it runs.
895+
llvm::SmallPtrSet<typename Iterator::value_type, 16> sortBeforeOthers;
896+
for (auto i = begin; i != end; ++i) {
897+
if (std::none_of(begin, end, [&](auto const &b) { return cmp(b, *i); }))
898+
sortBeforeOthers.insert(*i);
899+
}
900+
901+
auto const next = std::stable_partition(begin, end, [&](auto const &a) {
902+
return sortBeforeOthers.contains(a);
903+
});
904+
assert(next != begin && "not a partial ordering");
905+
begin = next;
906+
}
907+
}
908+
878909
/// Given a module containing PDL pattern operations, generate a matcher tree
879910
/// using the patterns within the given module and return the root matcher node.
880911
std::unique_ptr<MatcherNode>
@@ -955,6 +986,24 @@ MatcherNode::generateMatcherTree(ModuleOp module, PredicateBuilder &builder,
955986
return *lhs < *rhs;
956987
});
957988

989+
// Mostly keep the now established order, but also ensure that
990+
// ConstraintQuestions come after the results they use.
991+
stableTopologicalSort(ordered.begin(), ordered.end(),
992+
[](OrderedPredicate *a, OrderedPredicate *b) {
993+
auto *cqa = dyn_cast<ConstraintQuestion>(a->question);
994+
auto *cqb = dyn_cast<ConstraintQuestion>(b->question);
995+
if (cqa && cqb) {
996+
// Does any argument of b use a? Then b must be
997+
// sorted after a.
998+
return llvm::any_of(
999+
cqb->getArgs(), [&](Position *p) {
1000+
auto *cp = dyn_cast<ConstraintPosition>(p);
1001+
return cp && cp->getQuestion() == cqa;
1002+
});
1003+
}
1004+
return false;
1005+
});
1006+
9581007
// Build the matchers for each of the pattern predicate lists.
9591008
std::unique_ptr<MatcherNode> root;
9601009
for (OrderedPredicateList &list : lists)

0 commit comments

Comments
 (0)