Skip to content

[mlir][PDL] Add support for native constraints with results #82760

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 1, 2024

Conversation

mgehre-amd
Copy link
Contributor

@mgehre-amd mgehre-amd commented Feb 23, 2024

From https://reviews.llvm.org/D153245

This adds support for native PDL (and PDLL) C++ constraints to return results.

This is useful for situations where a pattern checks for certain constraints of multiple interdependent attributes and computes a new attribute value based on them. Currently, for such an example it is required to escape to C++ during matching to perform the check and after a successful match again escape to native C++ to perform the computation during the rewriting part of the pattern. With this work we can do the computation in C++ during matching and use the result in the rewriting part of the pattern. Effectively this enables a choice in the trade-off of memory consumption during matching vs recomputation of values.

This is an example of a situation where this is useful: We have two operations with certain attributes that have interdependent constraints. For instance attr_foo: one_of [0, 2, 4, 8], attr_bar: one_of [0, 2, 4, 8] and attr_foo == attr_bar. The pattern should only match if all conditions are true. The new operation should be created with a new attribute which is computed from the two matched attributes e.g. attr_baz = attr_foo * attr_bar. For the check we already escape to native C++ and have all values at hand so it makes sense to directly compute the new attribute value as well:

Constraint checkAndCompute(attr0: Attr, attr1: Attr) -> Attr;

Pattern example with benefit(1) {
    let foo = op<test.foo>() {attr = attr_foo : Attr};
    let bar = op<test.bar>(foo) {attr = attr_bar : Attr};
    let attr_baz = checkAndCompute(attr_foo, attr_bar);
    rewrite bar with {
        let baz = op<test.baz> {attr=attr_baz};
        replace bar with baz;
    };
}

To achieve this the following notable changes were necessary:
PDLL:

  • Remove check in PDLL parser that prevented native constraints from returning results

PDL:

  • Change PDL definition of pdl.apply_native_constraint to allow variadic results

PDL_interp:

  • Change PDL_interp definition of pdl_interp.apply_constraint to allow variadic results

PDLToPDLInterp Pass:
The input to the pass is an arbitrary number of PDL patterns. The pass collects the predicates that are required to match all of the pdl patterns and establishes an ordering that allows creation of a single efficient matcher function to match all of them. Values that are matched and possibly used in the rewriting part of a pattern are represented as positions. This allows fusion and thus reusing a single position for multiple matching patterns. Accordingly, we introduce ConstraintPosition, which records the type and index of the result of the constraint. The problem is for the corresponding value to be used in the rewriting part of a pattern it has to be an input to the pdl_interp.record_match operation, which is generated early during the pass such that its surrounding block can be referred to by branching operations. In consequence the value has to be materialized after the original pdl.apply_native_constraint has been deleted but before we get the chance to generate the corresponding pdl_interp.apply_constraint operation. We solve this by emitting a placeholder value when a ConstraintPosition is evaluated. These placeholder values (due to fusion there may be multiple for one constraint result) are replaced later when the actual pdl_interp.apply_constraint operation is created.

Changes since the phabricator review:

  • Addressed all comments
  • In particular, removed registerConstraintFunctionWithResults and instead changed registerConstraintFunction so that contraint functions always have results (empty by default)
  • Thus we don't need to reuse rewriteFunctions to store constraint functions with results anymore, and can instead use constraintFunctions
  • Perform a stable sort of ConstraintQuestion, so that ConstraintQuestion appear before other ConstraintQuestion that use their results.
  • Don't create placeholders for pdl_interp::ApplyConstraintOp. Instead generate the pdl_interp::ApplyConstraintOp before generating the successor block.
  • Fixed a test failure in the pdl python bindings

Original code by @martin-luecke

@mgehre-amd mgehre-amd force-pushed the matthias.pdl_constraint_results branch from ab7d422 to 44cfde5 Compare February 23, 2024 13:07
Copy link

github-actions bot commented Feb 23, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

@mgehre-amd mgehre-amd force-pushed the matthias.pdl_constraint_results branch from 44cfde5 to dd474be Compare February 23, 2024 15:29
@mgehre-amd mgehre-amd force-pushed the matthias.pdl_constraint_results branch from dd474be to 403ecc3 Compare February 23, 2024 22:27
@mgehre-amd mgehre-amd requested a review from ftynse February 23, 2024 22:27
@mgehre-amd mgehre-amd marked this pull request as ready for review February 23, 2024 22:27
@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir mlir:pdl labels Feb 23, 2024
@llvmbot
Copy link
Member

llvmbot commented Feb 23, 2024

@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-pdl

@llvm/pr-subscribers-mlir-core

Author: Matthias Gehre (mgehre-amd)

Changes

From https://reviews.llvm.org/D153245

This adds support for native PDL (and PDLL) C++ constraints to return results.

This is useful for situations where a pattern checks for certain constraints of multiple interdependent attributes and computes a new attribute value based on them. Currently, for such an example it is required to escape to C++ during matching to perform the check and after a successful match again escape to native C++ to perform the computation during the rewriting part of the pattern. With this work we can do the computation in C++ during matching and use the result in the rewriting part of the pattern. Effectively this enables a choice in the trade-off of memory consumption during matching vs recomputation of values.

This is an example of a situation where this is useful: We have two operations with certain attributes that have interdependent constraints. For instance attr_foo: one_of [0, 2, 4, 8], attr_bar: one_of [0, 2, 4, 8] and attr_foo == attr_bar. The pattern should only match if all conditions are true. The new operation should be created with a new attribute which is computed from the two matched attributes e.g. attr_baz = attr_foo * attr_bar. For the check we already escape to native C++ and have all values at hand so it makes sense to directly compute the new attribute value as well:

Constraint checkAndCompute(attr0: Attr, attr1: Attr) -&gt; Attr;

Pattern example with benefit(1) {
    let foo = op&lt;test.foo&gt;() {attr = attr_foo : Attr};
    let bar = op&lt;test.bar&gt;(foo) {attr = attr_bar : Attr};
    let attr_baz = checkAndCompute(attr_foo, attr_bar);
    rewrite bar with {
        let baz = op&lt;test.baz&gt; {attr=attr_baz};
        replace bar with baz;
    };
}

To achieve this the following notable changes were necessary:
PDLL:

  • Remove check in PDLL parser that prevented native constraints from returning results

PDL:

  • Change PDL definition of pdl.apply_native_constraint to allow variadic results

PDL_interp:

  • Change PDL_interp definition of pdl_interp.apply_constraint to allow variadic results

PDLToPDLInterp Pass:
The input to the pass is an arbitrary number of PDL patterns. The pass collects the predicates that are required to match all of the pdl patterns and establishes an ordering that allows creation of a single efficient matcher function to match all of them. Values that are matched and possibly used in the rewriting part of a pattern are represented as positions. This allows fusion and thus reusing a single position for multiple matching patterns. Accordingly, we introduce ConstraintPosition, which records the type and index of the result of the constraint. The problem is for the corresponding value to be used in the rewriting part of a pattern it has to be an input to the pdl_interp.record_match operation, which is generated early during the pass such that its surrounding block can be referred to by branching operations. In consequence the value has to be materialized after the original pdl.apply_native_constraint has been deleted but before we get the chance to generate the corresponding pdl_interp.apply_constraint operation. We solve this by emitting a placeholder value when a ConstraintPosition is evaluated. These placeholder values (due to fusion there may be multiple for one constraint result) are replaced later when the actual pdl_interp.apply_constraint operation is created.

Changes since the phabricator review:

  • Addressed all comments
  • In particular, removed registerConstraintFunctionWithResults and instead changed registerConstraintFunction so that contraint functions always have results (empty by default)
  • Thus we don't need to reuse rewriteFunctions to store constraint functions with results anymore, and can instead use constraintFunctions
  • Perform a stable sort of ConstraintQuestion, so that ConstraintQuestion appear before other ConstraintQuestion that use their results.
  • Don't create placeholders for pdl_interp::ApplyConstraintOp. Instead generate the pdl_interp::ApplyConstraintOp before generating the successor block.
  • Fixed a test failure in the pdl python bindings

Original code by @martin-luecke


Patch is 41.58 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/82760.diff

18 Files Affected:

  • (modified) mlir/include/mlir/Dialect/PDL/IR/PDLOps.td (+7-2)
  • (modified) mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td (+6-2)
  • (modified) mlir/include/mlir/IR/PDLPatternMatch.h.inc (+11-7)
  • (modified) mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp (+31-10)
  • (modified) mlir/lib/Conversion/PDLToPDLInterp/Predicate.h (+47-11)
  • (modified) mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp (+51-2)
  • (modified) mlir/lib/Dialect/PDL/IR/PDL.cpp (+6)
  • (modified) mlir/lib/Rewrite/ByteCode.cpp (+99-45)
  • (modified) mlir/lib/Tools/PDLL/Parser/Parser.cpp (-6)
  • (modified) mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir (+51)
  • (added) mlir/test/Conversion/PDLToPDLInterp/use-constraint-result.mlir (+21)
  • (modified) mlir/test/Dialect/PDL/ops.mlir (+18)
  • (modified) mlir/test/Rewrite/pdl-bytecode.mlir (+68)
  • (modified) mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp (+1-1)
  • (modified) mlir/test/lib/Rewrite/TestPDLByteCode.cpp (+50)
  • (modified) mlir/test/mlir-pdll/Parser/constraint-failure.pdll (-5)
  • (modified) mlir/test/mlir-pdll/Parser/constraint.pdll (+8)
  • (modified) mlir/test/python/dialects/pdl_ops.py (+1-1)
diff --git a/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td b/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td
index 4e9ebccba77d88..1e108c3d8ac77a 100644
--- a/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td
+++ b/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td
@@ -35,20 +35,25 @@ def PDL_ApplyNativeConstraintOp
   let description = [{
     `pdl.apply_native_constraint` operations apply a native C++ constraint, that
     has been registered externally with the consumer of PDL, to a given set of
-    entities.
+    entities and optionally return a number of values.
 
     Example:
 
     ```mlir
     // Apply `myConstraint` to the entities defined by `input`, `attr`, and `op`.
     pdl.apply_native_constraint "myConstraint"(%input, %attr, %op : !pdl.value, !pdl.attribute, !pdl.operation)
+    // Apply constraint `with_result` to `root`. This constraint returns an attribute.
+    %attr = pdl.apply_native_constraint "with_result"(%root : !pdl.operation) : !pdl.attribute
     ```
   }];
 
   let arguments = (ins StrAttr:$name, 
                        Variadic<PDL_AnyType>:$args, 
                        DefaultValuedAttr<BoolAttr, "false">:$isNegated);
-  let assemblyFormat = "$name `(` $args `:` type($args) `)` attr-dict";
+  let results = (outs Variadic<PDL_AnyType>:$results);
+  let assemblyFormat = [{
+    $name `(` $args `:` type($args) `)` (`:`  type($results)^ )? attr-dict
+  }];
   let hasVerifier = 1;
 }
 
diff --git a/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td b/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td
index 48f625bd1fa3fd..901acc0e6733bb 100644
--- a/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td
+++ b/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td
@@ -88,7 +88,9 @@ def PDLInterp_ApplyConstraintOp : PDLInterp_PredicateOp<"apply_constraint"> {
   let description = [{
     `pdl_interp.apply_constraint` operations apply a generic constraint, that
     has been registered with the interpreter, with a given set of positional
-    values. On success, this operation branches to the true destination,
+    values.
+    The constraint function may return any number of results.
+    On success, this operation branches to the true destination,
     otherwise the false destination is taken. This behavior can be reversed
     by setting the attribute `isNegated` to true.
 
@@ -104,8 +106,10 @@ def PDLInterp_ApplyConstraintOp : PDLInterp_PredicateOp<"apply_constraint"> {
   let arguments = (ins StrAttr:$name, 
                        Variadic<PDL_AnyType>:$args,
                        DefaultValuedAttr<BoolAttr, "false">:$isNegated);
+  let results = (outs Variadic<PDL_AnyType>:$results);
   let assemblyFormat = [{
-    $name `(` $args `:` type($args) `)` attr-dict `->` successors
+    $name `(` $args `:` type($args) `)` (`:` type($results)^)? attr-dict 
+    `->` successors
   }];
 }
 
diff --git a/mlir/include/mlir/IR/PDLPatternMatch.h.inc b/mlir/include/mlir/IR/PDLPatternMatch.h.inc
index a215da8cb6431d..66286ed7a4c898 100644
--- a/mlir/include/mlir/IR/PDLPatternMatch.h.inc
+++ b/mlir/include/mlir/IR/PDLPatternMatch.h.inc
@@ -318,8 +318,9 @@ protected:
 /// A generic PDL pattern constraint function. This function applies a
 /// constraint to a given set of opaque PDLValue entities. Returns success if
 /// the constraint successfully held, failure otherwise.
-using PDLConstraintFunction =
-    std::function<LogicalResult(PatternRewriter &, ArrayRef<PDLValue>)>;
+using PDLConstraintFunction = std::function<LogicalResult(
+    PatternRewriter &, PDLResultList &, ArrayRef<PDLValue>)>;
+
 /// A native PDL rewrite function. This function performs a rewrite on the
 /// given set of values. Any results from this rewrite that should be passed
 /// back to PDL should be added to the provided result list. This method is only
@@ -726,7 +727,7 @@ std::enable_if_t<
     PDLConstraintFunction>
 buildConstraintFn(ConstraintFnT &&constraintFn) {
   return [constraintFn = std::forward<ConstraintFnT>(constraintFn)](
-             PatternRewriter &rewriter,
+             PatternRewriter &rewriter, PDLResultList &,
              ArrayRef<PDLValue> values) -> LogicalResult {
     auto argIndices = std::make_index_sequence<
         llvm::function_traits<ConstraintFnT>::num_args - 1>();
@@ -842,10 +843,13 @@ public:
   /// Register a constraint function with PDL. A constraint function may be
   /// specified in one of two ways:
   ///
-  ///   * `LogicalResult (PatternRewriter &, ArrayRef<PDLValue>)`
+  ///   * `LogicalResult (PatternRewriter &,
+  ///                     PDLResultList &,
+  ///                     ArrayRef<PDLValue>)`
   ///
   ///   In this overload the arguments of the constraint function are passed via
-  ///   the low-level PDLValue form.
+  ///   the low-level PDLValue form, and the results are manually appended to
+  ///   the given result list.
   ///
   ///   * `LogicalResult (PatternRewriter &, ValueTs... values)`
   ///
@@ -960,8 +964,8 @@ public:
   }
 };
 class PDLResultList {};
-using PDLConstraintFunction =
-    std::function<LogicalResult(PatternRewriter &, ArrayRef<PDLValue>)>;
+using PDLConstraintFunction = std::function<LogicalResult(
+    PatternRewriter &, PDLResultList &, ArrayRef<PDLValue>)>;
 using PDLRewriteFunction = std::function<LogicalResult(
     PatternRewriter &, PDLResultList &, ArrayRef<PDLValue>)>;
 
diff --git a/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp b/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp
index e911631a4bc52a..b00cd0dee3ae80 100644
--- a/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp
+++ b/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp
@@ -50,7 +50,8 @@ struct PatternLowering {
 
   /// Generate interpreter operations for the tree rooted at the given matcher
   /// node, in the specified region.
-  Block *generateMatcher(MatcherNode &node, Region &region);
+  Block *generateMatcher(MatcherNode &node, Region &region,
+                         Block *block = nullptr);
 
   /// Get or create an access to the provided positional value in the current
   /// block. This operation may mutate the provided block pointer if nested
@@ -148,6 +149,10 @@ struct PatternLowering {
   /// A mapping between pattern operations and the corresponding configuration
   /// set.
   DenseMap<Operation *, PDLPatternConfigSet *> *configMap;
+
+  /// A mapping from a constraint question to the ApplyConstraintOp
+  /// that implements it.
+  DenseMap<ConstraintQuestion *, pdl_interp::ApplyConstraintOp> constraintOpMap;
 };
 } // namespace
 
@@ -182,9 +187,11 @@ void PatternLowering::lower(ModuleOp module) {
   firstMatcherBlock->erase();
 }
 
-Block *PatternLowering::generateMatcher(MatcherNode &node, Region &region) {
+Block *PatternLowering::generateMatcher(MatcherNode &node, Region &region,
+                                        Block *block) {
   // Push a new scope for the values used by this matcher.
-  Block *block = &region.emplaceBlock();
+  if (!block)
+    block = &region.emplaceBlock();
   ValueMapScope scope(values);
 
   // If this is the return node, simply insert the corresponding interpreter
@@ -364,6 +371,15 @@ Value PatternLowering::getValueAt(Block *&currentBlock, Position *pos) {
           loc, cast<ArrayAttr>(rawTypeAttr));
     break;
   }
+  case Predicates::ConstraintResultPos: {
+    // Due to the order of traversal, the ApplyConstraintOp has already been
+    // created and we can find it in constraintOpMap.
+    auto *constrResPos = cast<ConstraintPosition>(pos);
+    auto i = constraintOpMap.find(constrResPos->getQuestion());
+    assert(i != constraintOpMap.end());
+    value = i->second->getResult(constrResPos->getIndex());
+    break;
+  }
   default:
     llvm_unreachable("Generating unknown Position getter");
     break;
@@ -390,12 +406,11 @@ void PatternLowering::generate(BoolNode *boolNode, Block *&currentBlock,
       args.push_back(getValueAt(currentBlock, position));
   }
 
-  // Generate the matcher in the current (potentially nested) region
-  // and get the failure successor.
-  Block *success = generateMatcher(*boolNode->getSuccessNode(), *region);
+  // Generate a new block as success successor and get the failure successor.
+  Block *success = &region->emplaceBlock();
   Block *failure = failureBlockStack.back();
 
-  // Finally, create the predicate.
+  // Create the predicate.
   builder.setInsertionPointToEnd(currentBlock);
   Predicates::Kind kind = question->getKind();
   switch (kind) {
@@ -447,14 +462,20 @@ void PatternLowering::generate(BoolNode *boolNode, Block *&currentBlock,
   }
   case Predicates::ConstraintQuestion: {
     auto *cstQuestion = cast<ConstraintQuestion>(question);
-    builder.create<pdl_interp::ApplyConstraintOp>(
-        loc, cstQuestion->getName(), args, cstQuestion->getIsNegated(), success,
-        failure);
+    auto applyConstraintOp = builder.create<pdl_interp::ApplyConstraintOp>(
+        loc, cstQuestion->getResultTypes(), cstQuestion->getName(), args,
+        cstQuestion->getIsNegated(), success, failure);
+
+    constraintOpMap.insert({cstQuestion, applyConstraintOp});
     break;
   }
   default:
     llvm_unreachable("Generating unknown Predicate operation");
   }
+
+  // Generate the matcher in the current (potentially nested) region.
+  // This might use the results of the current predicate.
+  generateMatcher(*boolNode->getSuccessNode(), *region, success);
 }
 
 template <typename OpT, typename PredT, typename ValT = typename PredT::KeyTy>
diff --git a/mlir/lib/Conversion/PDLToPDLInterp/Predicate.h b/mlir/lib/Conversion/PDLToPDLInterp/Predicate.h
index 2c9b63f86d6efa..5ad2c477573a5b 100644
--- a/mlir/lib/Conversion/PDLToPDLInterp/Predicate.h
+++ b/mlir/lib/Conversion/PDLToPDLInterp/Predicate.h
@@ -47,6 +47,7 @@ enum Kind : unsigned {
   OperandPos,
   OperandGroupPos,
   AttributePos,
+  ConstraintResultPos,
   ResultPos,
   ResultGroupPos,
   TypePos,
@@ -279,6 +280,28 @@ struct OperationPosition : public PredicateBase<OperationPosition, Position,
   bool isOperandDefiningOp() const;
 };
 
+//===----------------------------------------------------------------------===//
+// ConstraintPosition
+
+struct ConstraintQuestion;
+
+/// A position describing the result of a native constraint. It saves the
+/// corresponding ConstraintQuestion and result index to enable referring
+/// back to them
+struct ConstraintPosition
+    : public PredicateBase<ConstraintPosition, Position,
+                           std::pair<ConstraintQuestion *, unsigned>,
+                           Predicates::ConstraintResultPos> {
+  using PredicateBase::PredicateBase;
+
+  /// Returns the ConstraintQuestion to enable keeping track of the native
+  /// constraint this position stems from.
+  ConstraintQuestion *getQuestion() const { return key.first; }
+
+  // Returns the result index of this position
+  unsigned getIndex() const { return key.second; }
+};
+
 //===----------------------------------------------------------------------===//
 // ResultPosition
 
@@ -447,11 +470,13 @@ struct AttributeQuestion
     : public PredicateBase<AttributeQuestion, Qualifier, void,
                            Predicates::AttributeQuestion> {};
 
-/// Apply a parameterized constraint to multiple position values.
+/// Apply a parameterized constraint to multiple position values and possibly
+/// produce results.
 struct ConstraintQuestion
-    : public PredicateBase<ConstraintQuestion, Qualifier,
-                           std::tuple<StringRef, ArrayRef<Position *>, bool>,
-                           Predicates::ConstraintQuestion> {
+    : public PredicateBase<
+          ConstraintQuestion, Qualifier,
+          std::tuple<StringRef, ArrayRef<Position *>, ArrayRef<Type>, bool>,
+          Predicates::ConstraintQuestion> {
   using Base::Base;
 
   /// Return the name of the constraint.
@@ -460,15 +485,19 @@ struct ConstraintQuestion
   /// Return the arguments of the constraint.
   ArrayRef<Position *> getArgs() const { return std::get<1>(key); }
 
+  /// Return the result types of the constraint.
+  ArrayRef<Type> getResultTypes() const { return std::get<2>(key); }
+
   /// Return the negation status of the constraint.
-  bool getIsNegated() const { return std::get<2>(key); }
+  bool getIsNegated() const { return std::get<3>(key); }
 
   /// Construct an instance with the given storage allocator.
   static ConstraintQuestion *construct(StorageUniquer::StorageAllocator &alloc,
                                        KeyTy key) {
     return Base::construct(alloc, KeyTy{alloc.copyInto(std::get<0>(key)),
                                         alloc.copyInto(std::get<1>(key)),
-                                        std::get<2>(key)});
+                                        alloc.copyInto(std::get<2>(key)),
+                                        std::get<3>(key)});
   }
 
   /// Returns a hash suitable for the given keytype.
@@ -526,6 +555,7 @@ class PredicateUniquer : public StorageUniquer {
     // Register the types of Positions with the uniquer.
     registerParametricStorageType<AttributePosition>();
     registerParametricStorageType<AttributeLiteralPosition>();
+    registerParametricStorageType<ConstraintPosition>();
     registerParametricStorageType<ForEachPosition>();
     registerParametricStorageType<OperandPosition>();
     registerParametricStorageType<OperandGroupPosition>();
@@ -588,6 +618,12 @@ class PredicateBuilder {
     return OperationPosition::get(uniquer, p);
   }
 
+  // Returns a position for a new value created by a constraint.
+  ConstraintPosition *getConstraintPosition(ConstraintQuestion *q,
+                                            unsigned index) {
+    return ConstraintPosition::get(uniquer, std::make_pair(q, index));
+  }
+
   /// Returns an attribute position for an attribute of the given operation.
   Position *getAttribute(OperationPosition *p, StringRef name) {
     return AttributePosition::get(uniquer, p, StringAttr::get(ctx, name));
@@ -673,11 +709,11 @@ class PredicateBuilder {
   }
 
   /// Create a predicate that applies a generic constraint.
-  Predicate getConstraint(StringRef name, ArrayRef<Position *> pos,
-                          bool isNegated) {
-    return {
-        ConstraintQuestion::get(uniquer, std::make_tuple(name, pos, isNegated)),
-        TrueAnswer::get(uniquer)};
+  Predicate getConstraint(StringRef name, ArrayRef<Position *> args,
+                          ArrayRef<Type> resultTypes, bool isNegated) {
+    return {ConstraintQuestion::get(
+                uniquer, std::make_tuple(name, args, resultTypes, isNegated)),
+            TrueAnswer::get(uniquer)};
   }
 
   /// Create a predicate comparing a value with null.
diff --git a/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp b/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp
index a9c3b0a71ef0d7..1fd637641e9296 100644
--- a/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp
+++ b/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp
@@ -15,6 +15,7 @@
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "llvm/ADT/MapVector.h"
+#include "llvm/ADT/SmallPtrSet.h"
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/Debug.h"
 #include <queue>
@@ -272,8 +273,17 @@ static void getConstraintPredicates(pdl::ApplyNativeConstraintOp op,
   // Push the constraint to the furthest position.
   Position *pos = *std::max_element(allPositions.begin(), allPositions.end(),
                                     comparePosDepth);
-  PredicateBuilder::Predicate pred =
-      builder.getConstraint(op.getName(), allPositions, op.getIsNegated());
+  ResultRange results = op.getResults();
+  PredicateBuilder::Predicate pred = builder.getConstraint(
+      op.getName(), allPositions, SmallVector<Type>(results.getTypes()),
+      op.getIsNegated());
+
+  // For each result register a position so it can be used later
+  for (auto [i, result] : llvm::enumerate(results)) {
+    ConstraintQuestion *q = cast<ConstraintQuestion>(pred.first);
+    ConstraintPosition *pos = builder.getConstraintPosition(q, i);
+    inputs[result] = pos;
+  }
   predList.emplace_back(pos, pred);
 }
 
@@ -875,6 +885,27 @@ static void insertExitNode(std::unique_ptr<MatcherNode> *root) {
   *root = std::make_unique<ExitNode>();
 }
 
+/// Sorts the range begin/end with the partial order given by cmp.
+template <typename Iterator, typename Compare>
+static void stableTopologicalSort(Iterator begin, Iterator end, Compare cmp) {
+  while (begin != end) {
+    // Cannot compute sortBeforeOthers in the predicate of stable_partition
+    // because stable_partition will not keep the [begin, end) range intact
+    // while it runs.
+    llvm::SmallPtrSet<typename Iterator::value_type, 16> sortBeforeOthers;
+    for (auto i = begin; i != end; ++i) {
+      if (std::none_of(begin, end, [&](auto const &b) { return cmp(b, *i); }))
+        sortBeforeOthers.insert(*i);
+    }
+
+    auto const next = std::stable_partition(begin, end, [&](auto const &a) {
+      return sortBeforeOthers.contains(a);
+    });
+    assert(next != begin && "not a partial ordering");
+    begin = next;
+  }
+}
+
 /// Given a module containing PDL pattern operations, generate a matcher tree
 /// using the patterns within the given module and return the root matcher node.
 std::unique_ptr<MatcherNode>
@@ -955,6 +986,24 @@ MatcherNode::generateMatcherTree(ModuleOp module, PredicateBuilder &builder,
     return *lhs < *rhs;
   });
 
+  // Mostly keep the now established order, but also ensure that
+  // ConstraintQuestions come after the results they use.
+  stableTopologicalSort(ordered.begin(), ordered.end(),
+                        [](OrderedPredicate *a, OrderedPredicate *b) {
+                          auto *cqa = dyn_cast<ConstraintQuestion>(a->question);
+                          auto *cqb = dyn_cast<ConstraintQuestion>(b->question);
+                          if (cqa && cqb) {
+                            // Does any argument of b use a? Then b must be
+                            // sorted after a.
+                            return llvm::any_of(
+                                cqb->getArgs(), [&](Position *p) {
+                                  auto *cp = dyn_cast<ConstraintPosition>(p);
+                                  return cp && cp->getQuestion() == cqa;
+                                });
+                          }
+                          return false;
+                        });
+
   // Build the matchers for each of the pattern predicate lists.
   std::unique_ptr<MatcherNode> root;
   for (OrderedPredicateList &list : lists)
diff --git a/mlir/lib/Dialect/PDL/IR/PDL.cpp b/mlir/lib/Dialect/PDL/IR/PDL.cpp
index d5f34679f06c60..428b19f4c74af8 100644
--- a/mlir/lib/Dialect/PDL/IR/PDL.cpp
+++ b/mlir/lib/Dialect/PDL/IR/PDL.cpp
@@ -94,6 +94,12 @@ static void visit(Operation *op, DenseSet<Operation *> &visited) {
 LogicalResult ApplyNativeConstraintOp::verify() {
   if (getNumOperands() == 0)
     return emitOpError("expected at least one argument");
+  if (llvm::any_of(getResults(), [](OpResult result) {
+        return isa<OperationType>(result.getType());
+      })) {
+    return emitOpError(
+        "returning an operation from a constraint is not supported");
+  }
   return success();
 }
 
diff --git a/mlir/lib/Rewrite/ByteCode.cpp b/mlir/lib/Rewrite/ByteCode.cpp
index 6e6992dcdeea78..559ce7f466a83d 100644
--- a/mlir/lib/Rewrite/ByteCode.cpp
+++ b/mlir/lib/Rewrite/ByteCode.cpp
@@ -769,11 +769,25 @@ void Generator::generate(Operation *op, ByteCodeWriter &writer) {
 
 void Generator::generate(pdl_interp::ApplyConstraintOp op,
                          ByteCodeWriter &writer) {
-  assert(constraintToMemIndex.count(op.getName()) &&
-         "expected index for constraint function");
+  // Constraints that should return a value have to be registered as rewrites.
+  // If a constraint and a rewrite of similar name are registered the
+  // constraint takes precedence
   writer.append(OpCode::ApplyConstraint, constraintToMemIndex[op.getName()]);
   writer.appendPDLValueList(op.getArgs());
   writer.append(ByteCodeField(op.getIsNegated()));
+  ResultRange results = op.getResults();
+  writer.append(ByteCodeField(results.size()));
+  for (Value result : results) {
+    // We record the expected kind of the result, so that we can provide extra
+    // verification of the native rewrite function and handle the failure case
+    // of constraints accordingly.
+  ...
[truncated]

@mgehre-amd mgehre-amd force-pushed the matthias.pdl_constraint_results branch from 403ecc3 to 54a3c99 Compare February 27, 2024 21:55
Copy link
Contributor

@martin-luecke martin-luecke left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for bringing this over the line!

LGTM with the recent changes

This adds support for native PDL (and PDLL) C++ constraints to return results.

This is useful for situations where a pattern checks for certain constraints of multiple interdependent attributes and computes a new attribute value based on them. Currently, for such an example it is required to escape to C++ during matching to perform the check and after a successful match again escape to native C++ to perform the computation during the rewriting part of the pattern.
With this work we can do the computation in C++ during matching and use the result in the rewriting part of the pattern. Effectively this enables a choice in the trade-off of memory consumption during matching vs recomputation of values.

This is an example of a situation where this is useful: We have two operations with certain attributes that have interdependent constraints. For instance attr_foo: one_of [0, 2, 4, 8], attr_bar: one_of [0, 2, 4, 8] and attr_foo == attr_bar. The pattern should only match if all conditions are true. The new operation should be created with a new attribute which is computed from the two matched attributes e.g. attr_baz = attr_foo * attr_bar. For the check we already escape to native C++ and have all values at hand so it makes sense to directly compute the new attribute value as well:

```
Constraint checkAndCompute(attr0: Attr, attr1: Attr) -> Attr;

Pattern example with benefit(1) {
    let foo = op<test.foo>() {attr = attr_foo : Attr};
    let bar = op<test.bar>(foo) {attr = attr_bar : Attr};
    let attr_baz = checkAndCompute(attr_foo, attr_bar);
    rewrite bar with {
        let baz = op<test.baz> {attr=attr_baz};
        replace bar with baz;
    };
}
```
To achieve this the following notable changes were necessary:
PDLL
- Remove check in PDLL parser that prevented native constraints from returning results
PDL
- Change PDL definition of pdl.apply_native_constraint to allow variadic results
PDL_interp
- Change PDL_interp definition of pdl_interp.apply_constraint to allow variadic results

PDLToPDLInterp Pass:
The input to the pass is an arbitrary number of PDL patterns. The pass collects the predicates
that are required to match all of the pdl patterns and establishes an ordering that allows
creation of a single efficient matcher function to match all of them. Values that are matched and
possibly used in the rewriting part of a pattern are represented as positions. This allows fusion
and thus reusing a single position for multiple matching patterns.
Accordingly, we introduce ConstraintPosition, which records the type and index of the result of the constraint.
The problem is for the corresponding value to be used in the rewriting part of a pattern it has
to be an input to the pdl_interp.record_match operation, which is generated early during the pass
such that its surrounding block can be referred to by branching operations. In consequence the value has
to be materialized after the original pdl.apply_native_constraint has been deleted but before
we get the chance to generate the corresponding pdl_interp.apply_constraint operation.
We solve this by emitting a placeholder value when a ConstraintPosition is evaluated.
These placeholder values (due to fusion there may be multiple for one constraint result) are
replaced later when the actual pdl_interp.apply_constraint operation is created.

Bytecode generator and interpreter:
Constraint functions which return results have a different type compared to existing constraint functions.
They have the same type as native rewrite functions and hence are registered as rewrite functions.

Co-authored-by: martin-luecke <[email protected]>
@mgehre-amd mgehre-amd force-pushed the matthias.pdl_constraint_results branch from 54a3c99 to b7bd3e9 Compare February 28, 2024 12:03
Copy link
Contributor

@Mogball Mogball left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice

@mgehre-amd mgehre-amd merged commit dca32a3 into llvm:main Mar 1, 2024
@mgehre-amd mgehre-amd deleted the matthias.pdl_constraint_results branch March 1, 2024 06:29
mgehre-amd added a commit that referenced this pull request Mar 2, 2024
…82760)"

with a small stack-use-after-scope fix in getConstraintPredicates()

This reverts commit c80e6ed.
@mmilanifard
Copy link
Contributor

mmilanifard commented Mar 2, 2024

Some builds broken again due to unused var here:

error: unused variable 'constraintResults'
1447 |   ArrayRef<PDLValue> constraintResults = results.getResults();

Buildbot passed though. This issue might not show up always as constraintResults is used conditionally (IIUC).

Please revert to fix breakage.

@mgehre-amd
Copy link
Contributor Author

mgehre-amd commented Mar 2, 2024

Some builds broken again due to unused var here:

error: unused variable 'constraintResults'
1447 |   ArrayRef<PDLValue> constraintResults = results.getResults();

Buildbot passed though. This issue might not show up always as constraintResults is used conditionally (IIUC).

Please revert to fix breakage.

Thanks for letting me know, I pushed a fix (0ec318e).

mgehre-amd added a commit that referenced this pull request Mar 2, 2024
ge28boj pushed a commit to Xilinx/llvm-project that referenced this pull request Jun 3, 2024
From https://reviews.llvm.org/D153245

This adds support for native PDL (and PDLL) C++ constraints to return
results.

This is useful for situations where a pattern checks for certain
constraints of multiple interdependent attributes and computes a new
attribute value based on them. Currently, for such an example it is
required to escape to C++ during matching to perform the check and after
a successful match again escape to native C++ to perform the computation
during the rewriting part of the pattern. With this work we can do the
computation in C++ during matching and use the result in the rewriting
part of the pattern. Effectively this enables a choice in the trade-off
of memory consumption during matching vs recomputation of values.

This is an example of a situation where this is useful: We have two
operations with certain attributes that have interdependent constraints.
For instance `attr_foo: one_of [0, 2, 4, 8], attr_bar: one_of [0, 2, 4,
8]` and `attr_foo == attr_bar`. The pattern should only match if all
conditions are true. The new operation should be created with a new
attribute which is computed from the two matched attributes e.g.
`attr_baz = attr_foo * attr_bar`. For the check we already escape to
native C++ and have all values at hand so it makes sense to directly
compute the new attribute value as well:

```
Constraint checkAndCompute(attr0: Attr, attr1: Attr) -> Attr;

Pattern example with benefit(1) {
    let foo = op<test.foo>() {attr = attr_foo : Attr};
    let bar = op<test.bar>(foo) {attr = attr_bar : Attr};
    let attr_baz = checkAndCompute(attr_foo, attr_bar);
    rewrite bar with {
        let baz = op<test.baz> {attr=attr_baz};
        replace bar with baz;
    };
}
```
To achieve this the following notable changes were necessary:
PDLL:
- Remove check in PDLL parser that prevented native constraints from
returning results

PDL:
- Change PDL definition of pdl.apply_native_constraint to allow variadic
results

PDL_interp:
- Change PDL_interp definition of pdl_interp.apply_constraint to allow
variadic results

PDLToPDLInterp Pass:
The input to the pass is an arbitrary number of PDL patterns. The pass
collects the predicates that are required to match all of the pdl
patterns and establishes an ordering that allows creation of a single
efficient matcher function to match all of them. Values that are matched
and possibly used in the rewriting part of a pattern are represented as
positions. This allows fusion and thus reusing a single position for
multiple matching patterns. Accordingly, we introduce
ConstraintPosition, which records the type and index of the result of
the constraint. The problem is for the corresponding value to be used in
the rewriting part of a pattern it has to be an input to the
pdl_interp.record_match operation, which is generated early during the
pass such that its surrounding block can be referred to by branching
operations. In consequence the value has to be materialized after the
original pdl.apply_native_constraint has been deleted but before we get
the chance to generate the corresponding pdl_interp.apply_constraint
operation. We solve this by emitting a placeholder value when a
ConstraintPosition is evaluated. These placeholder values (due to fusion
there may be multiple for one constraint result) are replaced later when
the actual pdl_interp.apply_constraint operation is created.

Changes since the phabricator review:
- Addressed all comments
- In particular, removed registerConstraintFunctionWithResults and
instead changed registerConstraintFunction so that contraint functions
always have results (empty by default)
- Thus we don't need to reuse `rewriteFunctions` to store constraint
functions with results anymore, and can instead use
`constraintFunctions`
- Perform a stable sort of ConstraintQuestion, so that
ConstraintQuestion appear before other ConstraintQuestion that use their
results.
- Don't create placeholders for pdl_interp::ApplyConstraintOp. Instead
generate the `pdl_interp::ApplyConstraintOp` before generating the
successor block.
- Fixed a test failure in the pdl python bindings

Original code by @martin-luecke

Co-authored-by: martin-luecke <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:core MLIR Core Infrastructure mlir:pdl mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants