Skip to content

Commit bd44b77

Browse files
martin-lueckeFerdinand Lemaire
authored and
Ferdinand Lemaire
committed
FXML.1923: PDLL support for native constraints with attribute results (#24)
1 parent 4325a06 commit bd44b77

File tree

15 files changed

+255
-33
lines changed

15 files changed

+255
-33
lines changed

mlir/include/mlir/Dialect/PDL/IR/PDLOps.td

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def PDL_ApplyNativeConstraintOp
3535
let description = [{
3636
`pdl.apply_native_constraint` operations apply a native C++ constraint, that
3737
has been registered externally with the consumer of PDL, to a given set of
38-
entities.
38+
entities and optionally return a number of values..
3939

4040
Example:
4141

@@ -46,7 +46,8 @@ def PDL_ApplyNativeConstraintOp
4646
}];
4747

4848
let arguments = (ins StrAttr:$name, Variadic<PDL_AnyType>:$args);
49-
let assemblyFormat = "$name `(` $args `:` type($args) `)` attr-dict";
49+
let results = (outs Variadic<PDL_AnyType>:$results);
50+
let assemblyFormat = "$name `(` $args `:` type($args) `)` (`:` type($results)^ )? attr-dict";
5051
let hasVerifier = 1;
5152
}
5253

mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -88,8 +88,9 @@ def PDLInterp_ApplyConstraintOp : PDLInterp_PredicateOp<"apply_constraint"> {
8888
let description = [{
8989
`pdl_interp.apply_constraint` operations apply a generic constraint, that
9090
has been registered with the interpreter, with a given set of positional
91-
values. On success, this operation branches to the true destination,
92-
otherwise the false destination is taken.
91+
values. The constraint function may return any number of results.
92+
On success, this operation branches to the true destination, otherwise
93+
the false destination is taken.
9394

9495
Example:
9596

@@ -101,8 +102,9 @@ def PDLInterp_ApplyConstraintOp : PDLInterp_PredicateOp<"apply_constraint"> {
101102
}];
102103

103104
let arguments = (ins StrAttr:$name, Variadic<PDL_AnyType>:$args);
105+
let results = (outs Variadic<PDL_AnyType>:$results);
104106
let assemblyFormat = [{
105-
$name `(` $args `:` type($args) `)` attr-dict `->` successors
107+
$name `(` $args `:` type($args) `)` (`:` type($results)^)? attr-dict `->` successors
106108
}];
107109
}
108110

mlir/include/mlir/IR/PatternMatch.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1525,6 +1525,20 @@ class PDLPatternModule {
15251525
std::forward<ConstraintFnT>(constraintFn)));
15261526
}
15271527

1528+
/// Register a constraint function that produces results with PDL. A
1529+
/// constraint function with results uses the same registry as
1530+
/// rewrite functions. It may be specified as follows:
1531+
///
1532+
/// * `LogicalResult (PatternRewriter &, PDLResultList &,
1533+
/// ArrayRef<PDLValue>)`
1534+
///
1535+
/// In this overload the arguments of the constraint function are passed via
1536+
/// the low-level PDLValue form, and the results are manually appended to
1537+
/// the given result list.
1538+
///
1539+
void registerConstraintFunctionWithResults(StringRef name,
1540+
PDLRewriteFunction constraintFn);
1541+
15281542
/// Register a rewrite function with PDL. A rewrite function may be specified
15291543
/// in one of two ways:
15301544
///

mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,10 @@ struct PatternLowering {
148148
/// A mapping between pattern operations and the corresponding configuration
149149
/// set.
150150
DenseMap<Operation *, PDLPatternConfigSet *> *configMap;
151+
152+
/// A mapping between constraint questions that refer to values created by
153+
/// constraints and the temporary placeholder values created for them.
154+
DenseMap<std::pair<ConstraintQuestion *, unsigned>, Value> substitutions;
151155
};
152156
} // namespace
153157

@@ -364,6 +368,20 @@ Value PatternLowering::getValueAt(Block *&currentBlock, Position *pos) {
364368
loc, rawTypeAttr.cast<ArrayAttr>());
365369
break;
366370
}
371+
case Predicates::ConstraintResultPos: {
372+
// At this point in time the corresponding pdl.ApplyNativeConstraint op has
373+
// been deleted and the new pdl_interp.ApplyConstraint has not been created
374+
// yet. To enable use of results created by these operations we build a
375+
// placeholder value that will be replaced when the actual
376+
// pdl_interp.ApplyConstraint operation is created.
377+
auto *constrResPos = cast<ConstraintPosition>(pos);
378+
Value placeholderValue = builder.create<pdl_interp::CreateAttributeOp>(
379+
loc, StringAttr::get(builder.getContext(), "placeholder"));
380+
substitutions[{constrResPos->getQuestion(), constrResPos->getIndex()}] =
381+
placeholderValue;
382+
value = placeholderValue;
383+
break;
384+
}
367385
default:
368386
llvm_unreachable("Generating unknown Position getter");
369387
break;
@@ -447,8 +465,21 @@ void PatternLowering::generate(BoolNode *boolNode, Block *&currentBlock,
447465
}
448466
case Predicates::ConstraintQuestion: {
449467
auto *cstQuestion = cast<ConstraintQuestion>(question);
450-
builder.create<pdl_interp::ApplyConstraintOp>(loc, cstQuestion->getName(),
451-
args, success, failure);
468+
auto applyConstraintOp = builder.create<pdl_interp::ApplyConstraintOp>(
469+
loc, cstQuestion->getResultTypes(), cstQuestion->getName(), args,
470+
success, failure);
471+
// Replace the generated placeholders with the results of the constraint and
472+
// erase them
473+
for (auto result : llvm::enumerate(applyConstraintOp.getResults())) {
474+
std::pair<ConstraintQuestion *, unsigned> substitutionKey = {
475+
cstQuestion, result.index()};
476+
// Check if there are substitutions to perform. If the result is never
477+
// used no substitutions will have been generated.
478+
if (substitutions.count(substitutionKey)) {
479+
substitutions[substitutionKey].replaceAllUsesWith(result.value());
480+
substitutions[substitutionKey].getDefiningOp()->erase();
481+
}
482+
}
452483
break;
453484
}
454485
default:

mlir/lib/Conversion/PDLToPDLInterp/Predicate.h

Lines changed: 42 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ enum Kind : unsigned {
4747
OperandPos,
4848
OperandGroupPos,
4949
AttributePos,
50+
ConstraintResultPos,
5051
ResultPos,
5152
ResultGroupPos,
5253
TypePos,
@@ -187,6 +188,25 @@ struct AttributeLiteralPosition
187188
using PredicateBase::PredicateBase;
188189
};
189190

191+
//===----------------------------------------------------------------------===//
192+
// ConstraintPosition
193+
194+
struct ConstraintQuestion;
195+
196+
/// A position describing the result of a native constraint. It saves the
197+
/// corresponding ConstraintQuestion and result index to enable referring
198+
/// back to them
199+
struct ConstraintPosition
200+
: public PredicateBase<ConstraintPosition, Position,
201+
std::pair<ConstraintQuestion *, unsigned>,
202+
Predicates::ConstraintResultPos> {
203+
using PredicateBase::PredicateBase;
204+
205+
ConstraintQuestion *getQuestion() const { return key.first; }
206+
207+
unsigned getIndex() const { return key.second; }
208+
};
209+
190210
//===----------------------------------------------------------------------===//
191211
// ForEachPosition
192212

@@ -447,11 +467,13 @@ struct AttributeQuestion
447467
: public PredicateBase<AttributeQuestion, Qualifier, void,
448468
Predicates::AttributeQuestion> {};
449469

450-
/// Apply a parameterized constraint to multiple position values.
470+
/// Apply a parameterized constraint to multiple position values and possibly
471+
/// produce results.
451472
struct ConstraintQuestion
452-
: public PredicateBase<ConstraintQuestion, Qualifier,
453-
std::tuple<StringRef, ArrayRef<Position *>>,
454-
Predicates::ConstraintQuestion> {
473+
: public PredicateBase<
474+
ConstraintQuestion, Qualifier,
475+
std::tuple<StringRef, ArrayRef<Position *>, ArrayRef<Type>>,
476+
Predicates::ConstraintQuestion> {
455477
using Base::Base;
456478

457479
/// Return the name of the constraint.
@@ -460,11 +482,15 @@ struct ConstraintQuestion
460482
/// Return the arguments of the constraint.
461483
ArrayRef<Position *> getArgs() const { return std::get<1>(key); }
462484

485+
/// Return the result types of the constraint.
486+
ArrayRef<Type> getResultTypes() const { return std::get<2>(key); }
487+
463488
/// Construct an instance with the given storage allocator.
464489
static ConstraintQuestion *construct(StorageUniquer::StorageAllocator &alloc,
465490
KeyTy key) {
466491
return Base::construct(alloc, KeyTy{alloc.copyInto(std::get<0>(key)),
467-
alloc.copyInto(std::get<1>(key))});
492+
alloc.copyInto(std::get<1>(key)),
493+
alloc.copyInto(std::get<2>(key))});
468494
}
469495
};
470496

@@ -517,6 +543,7 @@ class PredicateUniquer : public StorageUniquer {
517543
// Register the types of Positions with the uniquer.
518544
registerParametricStorageType<AttributePosition>();
519545
registerParametricStorageType<AttributeLiteralPosition>();
546+
registerParametricStorageType<ConstraintPosition>();
520547
registerParametricStorageType<ForEachPosition>();
521548
registerParametricStorageType<OperandPosition>();
522549
registerParametricStorageType<OperandGroupPosition>();
@@ -579,6 +606,12 @@ class PredicateBuilder {
579606
return OperationPosition::get(uniquer, p);
580607
}
581608

609+
// Returns a position for a new value created by a constraint.
610+
ConstraintPosition *getConstraintPosition(ConstraintQuestion *q,
611+
unsigned index) {
612+
return ConstraintPosition::get(uniquer, std::make_pair(q, index));
613+
}
614+
582615
/// Returns an attribute position for an attribute of the given operation.
583616
Position *getAttribute(OperationPosition *p, StringRef name) {
584617
return AttributePosition::get(uniquer, p, StringAttr::get(ctx, name));
@@ -664,8 +697,10 @@ class PredicateBuilder {
664697
}
665698

666699
/// Create a predicate that applies a generic constraint.
667-
Predicate getConstraint(StringRef name, ArrayRef<Position *> pos) {
668-
return {ConstraintQuestion::get(uniquer, std::make_tuple(name, pos)),
700+
Predicate getConstraint(StringRef name, ArrayRef<Position *> args,
701+
ArrayRef<Type> resultTypes) {
702+
return {ConstraintQuestion::get(uniquer,
703+
std::make_tuple(name, args, resultTypes)),
669704
TrueAnswer::get(uniquer)};
670705
}
671706

mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -272,8 +272,16 @@ static void getConstraintPredicates(pdl::ApplyNativeConstraintOp op,
272272
// Push the constraint to the furthest position.
273273
Position *pos = *std::max_element(allPositions.begin(), allPositions.end(),
274274
comparePosDepth);
275-
PredicateBuilder::Predicate pred =
276-
builder.getConstraint(op.getName(), allPositions);
275+
ResultRange results = op.getResults();
276+
PredicateBuilder::Predicate pred = builder.getConstraint(
277+
op.getName(), allPositions, SmallVector<Type>(results.getTypes()));
278+
279+
// for each result register a position so it can be used later
280+
for (auto result : llvm::enumerate(results)) {
281+
ConstraintQuestion *q = cast<ConstraintQuestion>(pred.first);
282+
ConstraintPosition *pos = builder.getConstraintPosition(q, result.index());
283+
inputs[result.value()] = pos;
284+
}
277285
predList.emplace_back(pos, pred);
278286
}
279287

mlir/lib/IR/PatternMatch.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,15 @@ void PDLPatternModule::registerConstraintFunction(
204204
constraintFunctions.try_emplace(name, std::move(constraintFn));
205205
}
206206

207+
void PDLPatternModule::registerConstraintFunctionWithResults(
208+
StringRef name, PDLRewriteFunction constraintFn) {
209+
// TODO: Is it possible to diagnose when `name` is already registered to
210+
// a function that is not equivalent to `rewriteFn`?
211+
// Allow existing mappings in the case multiple patterns depend on the same
212+
// rewrite.
213+
registerRewriteFunction(name, std::move(constraintFn));
214+
}
215+
207216
void PDLPatternModule::registerRewriteFunction(StringRef name,
208217
PDLRewriteFunction rewriteFn) {
209218
// TODO: Is it possible to diagnose when `name` is already registered to

mlir/lib/Rewrite/ByteCode.cpp

Lines changed: 42 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -769,10 +769,28 @@ void Generator::generate(Operation *op, ByteCodeWriter &writer) {
769769

770770
void Generator::generate(pdl_interp::ApplyConstraintOp op,
771771
ByteCodeWriter &writer) {
772-
assert(constraintToMemIndex.count(op.getName()) &&
773-
"expected index for constraint function");
774-
writer.append(OpCode::ApplyConstraint, constraintToMemIndex[op.getName()]);
772+
/// Constraints that should return a value have to be registered as rewrites
773+
/// If the constraint and rewrite of similar name are registered the
774+
/// constraint fun takes precedence
775+
ResultRange results = op.getResults();
776+
if (results.size() == 0 && constraintToMemIndex.count(op.getName()) != 0) {
777+
writer.append(OpCode::ApplyConstraint, constraintToMemIndex[op.getName()]);
778+
} else if (results.size() > 0 &&
779+
externalRewriterToMemIndex.count(op.getName()) != 0) {
780+
writer.append(OpCode::ApplyConstraint,
781+
externalRewriterToMemIndex[op.getName()]);
782+
} else {
783+
assert(true && "expected index for constraint function, make sure it is "
784+
"registered properly. Note that native constraints with "
785+
"results have to be registered using "
786+
"PDLPatternModule::registerConstraintFunctionWithResults.");
787+
}
775788
writer.appendPDLValueList(op.getArgs());
789+
writer.append(ByteCodeField(results.size()));
790+
for (Value result : results) {
791+
// TODO: Handle result ranges
792+
writer.append(result);
793+
}
776794
writer.append(op.getSuccessors());
777795
}
778796
void Generator::generate(pdl_interp::ApplyRewriteOp op,
@@ -1406,7 +1424,7 @@ class ByteCodeRewriteResultList : public PDLResultList {
14061424

14071425
void ByteCodeExecutor::executeApplyConstraint(PatternRewriter &rewriter) {
14081426
LLVM_DEBUG(llvm::dbgs() << "Executing ApplyConstraint:\n");
1409-
const PDLConstraintFunction &constraintFn = constraintFunctions[read()];
1427+
ByteCodeField fun_idx = read();
14101428
SmallVector<PDLValue, 16> args;
14111429
readList<PDLValue>(args);
14121430

@@ -1415,8 +1433,26 @@ void ByteCodeExecutor::executeApplyConstraint(PatternRewriter &rewriter) {
14151433
llvm::interleaveComma(args, llvm::dbgs());
14161434
});
14171435

1418-
// Invoke the constraint and jump to the proper destination.
1419-
selectJump(succeeded(constraintFn(rewriter, args)));
1436+
ByteCodeField numResults = read();
1437+
if (numResults == 0) {
1438+
const PDLConstraintFunction &constraintFn = constraintFunctions[fun_idx];
1439+
LogicalResult rewriteResult = constraintFn(rewriter, args);
1440+
// Depending on the constraint jump to the proper destination.
1441+
selectJump(succeeded(rewriteResult));
1442+
} else {
1443+
const PDLRewriteFunction &constraintFn = rewriteFunctions[fun_idx];
1444+
ByteCodeRewriteResultList results(numResults);
1445+
LogicalResult rewriteResult = constraintFn(rewriter, results, args);
1446+
assert(results.getResults().size() == numResults &&
1447+
"native PDL rewrite function returned unexpected number of results");
1448+
1449+
for (PDLValue &result : results.getResults()) {
1450+
LLVM_DEBUG(llvm::dbgs() << " * Result: " << result << "\n");
1451+
memory[read()] = result.getAsOpaquePointer();
1452+
}
1453+
// Depending on the constraint jump to the proper destination.
1454+
selectJump(succeeded(rewriteResult));
1455+
}
14201456
}
14211457

14221458
LogicalResult ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) {

mlir/lib/Tools/PDLL/Parser/Parser.cpp

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1359,12 +1359,6 @@ FailureOr<T *> Parser::parseUserNativeConstraintOrRewriteDecl(
13591359
if (failed(parseToken(Token::semicolon,
13601360
"expected `;` after native declaration")))
13611361
return failure();
1362-
// TODO: PDL should be able to support constraint results in certain
1363-
// situations, we should revise this.
1364-
if (std::is_same<ast::UserConstraintDecl, T>::value && !results.empty()) {
1365-
return emitError(
1366-
"native Constraints currently do not support returning results");
1367-
}
13681362
return T::createNative(ctx, name, arguments, results, optCodeStr, resultType);
13691363
}
13701364

mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,34 @@ module @constraints {
7979

8080
// -----
8181

82+
// CHECK-LABEL: module @constraint_with_result
83+
module @constraint_with_result {
84+
// CHECK: func @matcher(%[[ROOT:.*]]: !pdl.operation)
85+
// CHECK: %[[ATTR:.*]] = pdl_interp.apply_constraint "check_op_and_get_attr_constr"(%[[ROOT]]
86+
// CHECK: pdl_interp.record_match @rewriters::@pdl_generated_rewriter(%[[ROOT]], %[[ATTR]] : !pdl.operation, !pdl.attribute)
87+
pdl.pattern : benefit(1) {
88+
%root = operation
89+
%attr = pdl.apply_native_constraint "check_op_and_get_attr_constr"(%root : !pdl.operation) : !pdl.attribute
90+
rewrite %root with "rewriter"(%attr : !pdl.attribute)
91+
}
92+
}
93+
94+
// -----
95+
96+
// CHECK-LABEL: module @constraint_with_unused_result
97+
module @constraint_with_unused_result {
98+
// CHECK: func @matcher(%[[ROOT:.*]]: !pdl.operation)
99+
// CHECK: %[[ATTR:.*]] = pdl_interp.apply_constraint "check_op_and_get_attr_constr"(%[[ROOT]]
100+
// CHECK: pdl_interp.record_match @rewriters::@pdl_generated_rewriter(%[[ROOT]] : !pdl.operation)
101+
pdl.pattern : benefit(1) {
102+
%root = operation
103+
%attr = pdl.apply_native_constraint "check_op_and_get_attr_constr"(%root : !pdl.operation) : !pdl.attribute
104+
rewrite %root with "rewriter"
105+
}
106+
}
107+
108+
// -----
109+
82110
// CHECK-LABEL: module @inputs
83111
module @inputs {
84112
// CHECK: func @matcher(%[[ROOT:.*]]: !pdl.operation)

mlir/test/Dialect/PDL/ops.mlir

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,24 @@ pdl.pattern @apply_rewrite_with_no_results : benefit(1) {
134134

135135
// -----
136136

137+
pdl.pattern @apply_constraint_with_no_results : benefit(1) {
138+
%root = operation
139+
apply_native_constraint "NativeConstraint"(%root : !pdl.operation)
140+
rewrite %root with "rewriter"
141+
}
142+
143+
// -----
144+
145+
pdl.pattern @apply_constraint_with_results : benefit(1) {
146+
%root = operation
147+
%attr = apply_native_constraint "NativeConstraint"(%root : !pdl.operation) : !pdl.attribute
148+
rewrite %root {
149+
apply_native_rewrite "NativeRewrite"(%attr : !pdl.attribute)
150+
}
151+
}
152+
153+
// -----
154+
137155
pdl.pattern @attribute_with_dict : benefit(1) {
138156
%root = operation
139157
rewrite %root {

0 commit comments

Comments
 (0)