Skip to content

Commit 8ec28af

Browse files
committed
Reapply "[mlir][PDL] Add support for native constraints with results (#82760)"
with a small stack-use-after-scope fix in getConstraintPredicates() This reverts commit c80e6ed.
1 parent da591d3 commit 8ec28af

File tree

18 files changed

+557
-98
lines changed

18 files changed

+557
-98
lines changed

mlir/include/mlir/Dialect/PDL/IR/PDLOps.td

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,20 +35,25 @@ def PDL_ApplyNativeConstraintOp
3535
let description = [{
3636
`pdl.apply_native_constraint` operations apply a native C++ constraint, that
3737
has been registered externally with the consumer of PDL, to a given set of
38-
entities.
38+
entities and optionally return a number of values.
3939

4040
Example:
4141

4242
```mlir
4343
// Apply `myConstraint` to the entities defined by `input`, `attr`, and `op`.
4444
pdl.apply_native_constraint "myConstraint"(%input, %attr, %op : !pdl.value, !pdl.attribute, !pdl.operation)
45+
// Apply constraint `with_result` to `root`. This constraint returns an attribute.
46+
%attr = pdl.apply_native_constraint "with_result"(%root : !pdl.operation) : !pdl.attribute
4547
```
4648
}];
4749

4850
let arguments = (ins StrAttr:$name,
4951
Variadic<PDL_AnyType>:$args,
5052
DefaultValuedAttr<BoolAttr, "false">:$isNegated);
51-
let assemblyFormat = "$name `(` $args `:` type($args) `)` attr-dict";
53+
let results = (outs Variadic<PDL_AnyType>:$results);
54+
let assemblyFormat = [{
55+
$name `(` $args `:` type($args) `)` (`:` type($results)^ )? attr-dict
56+
}];
5257
let hasVerifier = 1;
5358
}
5459

mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,9 @@ def PDLInterp_ApplyConstraintOp : PDLInterp_PredicateOp<"apply_constraint"> {
8888
let description = [{
8989
`pdl_interp.apply_constraint` operations apply a generic constraint, that
9090
has been registered with the interpreter, with a given set of positional
91-
values. On success, this operation branches to the true destination,
91+
values.
92+
The constraint function may return any number of results.
93+
On success, this operation branches to the true destination,
9294
otherwise the false destination is taken. This behavior can be reversed
9395
by setting the attribute `isNegated` to true.
9496

@@ -104,8 +106,10 @@ def PDLInterp_ApplyConstraintOp : PDLInterp_PredicateOp<"apply_constraint"> {
104106
let arguments = (ins StrAttr:$name,
105107
Variadic<PDL_AnyType>:$args,
106108
DefaultValuedAttr<BoolAttr, "false">:$isNegated);
109+
let results = (outs Variadic<PDL_AnyType>:$results);
107110
let assemblyFormat = [{
108-
$name `(` $args `:` type($args) `)` attr-dict `->` successors
111+
$name `(` $args `:` type($args) `)` (`:` type($results)^)? attr-dict
112+
`->` successors
109113
}];
110114
}
111115

mlir/include/mlir/IR/PDLPatternMatch.h.inc

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -318,8 +318,9 @@ protected:
318318
/// A generic PDL pattern constraint function. This function applies a
319319
/// constraint to a given set of opaque PDLValue entities. Returns success if
320320
/// the constraint successfully held, failure otherwise.
321-
using PDLConstraintFunction =
322-
std::function<LogicalResult(PatternRewriter &, ArrayRef<PDLValue>)>;
321+
using PDLConstraintFunction = std::function<LogicalResult(
322+
PatternRewriter &, PDLResultList &, ArrayRef<PDLValue>)>;
323+
323324
/// A native PDL rewrite function. This function performs a rewrite on the
324325
/// given set of values. Any results from this rewrite that should be passed
325326
/// back to PDL should be added to the provided result list. This method is only
@@ -726,7 +727,7 @@ std::enable_if_t<
726727
PDLConstraintFunction>
727728
buildConstraintFn(ConstraintFnT &&constraintFn) {
728729
return [constraintFn = std::forward<ConstraintFnT>(constraintFn)](
729-
PatternRewriter &rewriter,
730+
PatternRewriter &rewriter, PDLResultList &,
730731
ArrayRef<PDLValue> values) -> LogicalResult {
731732
auto argIndices = std::make_index_sequence<
732733
llvm::function_traits<ConstraintFnT>::num_args - 1>();
@@ -842,10 +843,13 @@ public:
842843
/// Register a constraint function with PDL. A constraint function may be
843844
/// specified in one of two ways:
844845
///
845-
/// * `LogicalResult (PatternRewriter &, ArrayRef<PDLValue>)`
846+
/// * `LogicalResult (PatternRewriter &,
847+
/// PDLResultList &,
848+
/// ArrayRef<PDLValue>)`
846849
///
847850
/// In this overload the arguments of the constraint function are passed via
848-
/// the low-level PDLValue form.
851+
/// the low-level PDLValue form, and the results are manually appended to
852+
/// the given result list.
849853
///
850854
/// * `LogicalResult (PatternRewriter &, ValueTs... values)`
851855
///
@@ -960,8 +964,8 @@ public:
960964
}
961965
};
962966
class PDLResultList {};
963-
using PDLConstraintFunction =
964-
std::function<LogicalResult(PatternRewriter &, ArrayRef<PDLValue>)>;
967+
using PDLConstraintFunction = std::function<LogicalResult(
968+
PatternRewriter &, PDLResultList &, ArrayRef<PDLValue>)>;
965969
using PDLRewriteFunction = std::function<LogicalResult(
966970
PatternRewriter &, PDLResultList &, ArrayRef<PDLValue>)>;
967971

mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp

Lines changed: 31 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,8 @@ struct PatternLowering {
5050

5151
/// Generate interpreter operations for the tree rooted at the given matcher
5252
/// node, in the specified region.
53-
Block *generateMatcher(MatcherNode &node, Region &region);
53+
Block *generateMatcher(MatcherNode &node, Region &region,
54+
Block *block = nullptr);
5455

5556
/// Get or create an access to the provided positional value in the current
5657
/// block. This operation may mutate the provided block pointer if nested
@@ -148,6 +149,10 @@ struct PatternLowering {
148149
/// A mapping between pattern operations and the corresponding configuration
149150
/// set.
150151
DenseMap<Operation *, PDLPatternConfigSet *> *configMap;
152+
153+
/// A mapping from a constraint question to the ApplyConstraintOp
154+
/// that implements it.
155+
DenseMap<ConstraintQuestion *, pdl_interp::ApplyConstraintOp> constraintOpMap;
151156
};
152157
} // namespace
153158

@@ -182,9 +187,11 @@ void PatternLowering::lower(ModuleOp module) {
182187
firstMatcherBlock->erase();
183188
}
184189

185-
Block *PatternLowering::generateMatcher(MatcherNode &node, Region &region) {
190+
Block *PatternLowering::generateMatcher(MatcherNode &node, Region &region,
191+
Block *block) {
186192
// Push a new scope for the values used by this matcher.
187-
Block *block = &region.emplaceBlock();
193+
if (!block)
194+
block = &region.emplaceBlock();
188195
ValueMapScope scope(values);
189196

190197
// If this is the return node, simply insert the corresponding interpreter
@@ -364,6 +371,15 @@ Value PatternLowering::getValueAt(Block *&currentBlock, Position *pos) {
364371
loc, cast<ArrayAttr>(rawTypeAttr));
365372
break;
366373
}
374+
case Predicates::ConstraintResultPos: {
375+
// Due to the order of traversal, the ApplyConstraintOp has already been
376+
// created and we can find it in constraintOpMap.
377+
auto *constrResPos = cast<ConstraintPosition>(pos);
378+
auto i = constraintOpMap.find(constrResPos->getQuestion());
379+
assert(i != constraintOpMap.end());
380+
value = i->second->getResult(constrResPos->getIndex());
381+
break;
382+
}
367383
default:
368384
llvm_unreachable("Generating unknown Position getter");
369385
break;
@@ -390,12 +406,11 @@ void PatternLowering::generate(BoolNode *boolNode, Block *&currentBlock,
390406
args.push_back(getValueAt(currentBlock, position));
391407
}
392408

393-
// Generate the matcher in the current (potentially nested) region
394-
// and get the failure successor.
395-
Block *success = generateMatcher(*boolNode->getSuccessNode(), *region);
409+
// Generate a new block as success successor and get the failure successor.
410+
Block *success = &region->emplaceBlock();
396411
Block *failure = failureBlockStack.back();
397412

398-
// Finally, create the predicate.
413+
// Create the predicate.
399414
builder.setInsertionPointToEnd(currentBlock);
400415
Predicates::Kind kind = question->getKind();
401416
switch (kind) {
@@ -447,14 +462,20 @@ void PatternLowering::generate(BoolNode *boolNode, Block *&currentBlock,
447462
}
448463
case Predicates::ConstraintQuestion: {
449464
auto *cstQuestion = cast<ConstraintQuestion>(question);
450-
builder.create<pdl_interp::ApplyConstraintOp>(
451-
loc, cstQuestion->getName(), args, cstQuestion->getIsNegated(), success,
452-
failure);
465+
auto applyConstraintOp = builder.create<pdl_interp::ApplyConstraintOp>(
466+
loc, cstQuestion->getResultTypes(), cstQuestion->getName(), args,
467+
cstQuestion->getIsNegated(), success, failure);
468+
469+
constraintOpMap.insert({cstQuestion, applyConstraintOp});
453470
break;
454471
}
455472
default:
456473
llvm_unreachable("Generating unknown Predicate operation");
457474
}
475+
476+
// Generate the matcher in the current (potentially nested) region.
477+
// This might use the results of the current predicate.
478+
generateMatcher(*boolNode->getSuccessNode(), *region, success);
458479
}
459480

460481
template <typename OpT, typename PredT, typename ValT = typename PredT::KeyTy>

mlir/lib/Conversion/PDLToPDLInterp/Predicate.h

Lines changed: 47 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ enum Kind : unsigned {
4747
OperandPos,
4848
OperandGroupPos,
4949
AttributePos,
50+
ConstraintResultPos,
5051
ResultPos,
5152
ResultGroupPos,
5253
TypePos,
@@ -279,6 +280,28 @@ struct OperationPosition : public PredicateBase<OperationPosition, Position,
279280
bool isOperandDefiningOp() const;
280281
};
281282

283+
//===----------------------------------------------------------------------===//
284+
// ConstraintPosition
285+
286+
struct ConstraintQuestion;
287+
288+
/// A position describing the result of a native constraint. It saves the
289+
/// corresponding ConstraintQuestion and result index to enable referring
290+
/// back to them
291+
struct ConstraintPosition
292+
: public PredicateBase<ConstraintPosition, Position,
293+
std::pair<ConstraintQuestion *, unsigned>,
294+
Predicates::ConstraintResultPos> {
295+
using PredicateBase::PredicateBase;
296+
297+
/// Returns the ConstraintQuestion to enable keeping track of the native
298+
/// constraint this position stems from.
299+
ConstraintQuestion *getQuestion() const { return key.first; }
300+
301+
// Returns the result index of this position
302+
unsigned getIndex() const { return key.second; }
303+
};
304+
282305
//===----------------------------------------------------------------------===//
283306
// ResultPosition
284307

@@ -447,11 +470,13 @@ struct AttributeQuestion
447470
: public PredicateBase<AttributeQuestion, Qualifier, void,
448471
Predicates::AttributeQuestion> {};
449472

450-
/// Apply a parameterized constraint to multiple position values.
473+
/// Apply a parameterized constraint to multiple position values and possibly
474+
/// produce results.
451475
struct ConstraintQuestion
452-
: public PredicateBase<ConstraintQuestion, Qualifier,
453-
std::tuple<StringRef, ArrayRef<Position *>, bool>,
454-
Predicates::ConstraintQuestion> {
476+
: public PredicateBase<
477+
ConstraintQuestion, Qualifier,
478+
std::tuple<StringRef, ArrayRef<Position *>, ArrayRef<Type>, bool>,
479+
Predicates::ConstraintQuestion> {
455480
using Base::Base;
456481

457482
/// Return the name of the constraint.
@@ -460,15 +485,19 @@ struct ConstraintQuestion
460485
/// Return the arguments of the constraint.
461486
ArrayRef<Position *> getArgs() const { return std::get<1>(key); }
462487

488+
/// Return the result types of the constraint.
489+
ArrayRef<Type> getResultTypes() const { return std::get<2>(key); }
490+
463491
/// Return the negation status of the constraint.
464-
bool getIsNegated() const { return std::get<2>(key); }
492+
bool getIsNegated() const { return std::get<3>(key); }
465493

466494
/// Construct an instance with the given storage allocator.
467495
static ConstraintQuestion *construct(StorageUniquer::StorageAllocator &alloc,
468496
KeyTy key) {
469497
return Base::construct(alloc, KeyTy{alloc.copyInto(std::get<0>(key)),
470498
alloc.copyInto(std::get<1>(key)),
471-
std::get<2>(key)});
499+
alloc.copyInto(std::get<2>(key)),
500+
std::get<3>(key)});
472501
}
473502

474503
/// Returns a hash suitable for the given keytype.
@@ -526,6 +555,7 @@ class PredicateUniquer : public StorageUniquer {
526555
// Register the types of Positions with the uniquer.
527556
registerParametricStorageType<AttributePosition>();
528557
registerParametricStorageType<AttributeLiteralPosition>();
558+
registerParametricStorageType<ConstraintPosition>();
529559
registerParametricStorageType<ForEachPosition>();
530560
registerParametricStorageType<OperandPosition>();
531561
registerParametricStorageType<OperandGroupPosition>();
@@ -588,6 +618,12 @@ class PredicateBuilder {
588618
return OperationPosition::get(uniquer, p);
589619
}
590620

621+
// Returns a position for a new value created by a constraint.
622+
ConstraintPosition *getConstraintPosition(ConstraintQuestion *q,
623+
unsigned index) {
624+
return ConstraintPosition::get(uniquer, std::make_pair(q, index));
625+
}
626+
591627
/// Returns an attribute position for an attribute of the given operation.
592628
Position *getAttribute(OperationPosition *p, StringRef name) {
593629
return AttributePosition::get(uniquer, p, StringAttr::get(ctx, name));
@@ -673,11 +709,11 @@ class PredicateBuilder {
673709
}
674710

675711
/// Create a predicate that applies a generic constraint.
676-
Predicate getConstraint(StringRef name, ArrayRef<Position *> pos,
677-
bool isNegated) {
678-
return {
679-
ConstraintQuestion::get(uniquer, std::make_tuple(name, pos, isNegated)),
680-
TrueAnswer::get(uniquer)};
712+
Predicate getConstraint(StringRef name, ArrayRef<Position *> args,
713+
ArrayRef<Type> resultTypes, bool isNegated) {
714+
return {ConstraintQuestion::get(
715+
uniquer, std::make_tuple(name, args, resultTypes, isNegated)),
716+
TrueAnswer::get(uniquer)};
681717
}
682718

683719
/// Create a predicate comparing a value with null.

0 commit comments

Comments
 (0)