Skip to content

[mlir][ODS] Add OptionalTypesMatchWith and remove a custom assemblyFormat #68876

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Oct 19, 2023

Conversation

MacDue
Copy link
Member

@MacDue MacDue commented Oct 12, 2023

This is just a slight specialization of TypesMatchWith that returns success if an optional parameter is missing.

There may be other places this could help e.g.:

// TODO: Support optional arguments in `AllTypesMatch`. "type($src)" could
// then be removed from assemblyFormat.

...but I'm leaving those to avoid some churn.

This constraint will be handy for us in some later patches, it's a formalization of a short circuiting trick with the comparator of the TypesMatchWith constraint (devised for #69195).

TypesMatchWith<
  "padding type matches element type of result (if present)",
  "result", "padding",
  "::llvm::cast<VectorType>($_self).getElementType()",
  // This returns true if no padding is present, or it's present with a type that matches the element type of `result`.
  "!getPadding() || std::equal_to<>()">

This is a little non-obvious, so after this patch you can instead do:

OptionalTypesMatchWith<
  "padding type matches element type of result (if present)",
  "result", "padding",
  "::llvm::cast<VectorType>($_self).getElementType()">

@llvmbot
Copy link
Member

llvmbot commented Oct 12, 2023

@llvm/pr-subscribers-mlir-core
@llvm/pr-subscribers-mlir-vector
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-ods

Author: Benjamin Maxwell (MacDue)

Changes

This is just a slight specialization of TypesMatchWith that returns success if an optional parameter is missing.

There may be other places this could help e.g.:
https://github.com/llvm/llvm-project/blob/eb21049b4b904b072679ece60e73c6b0dc0d1ebf/mlir/include/mlir/Dialect/X86Vector/X86Vector.td#L58C5-L58C5 ...but I'm leaving those to avoid some churn.

(This constraint will be handy for us in some later patches)


Full diff: https://github.com/llvm/llvm-project/pull/68876.diff

5 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Vector/IR/VectorOps.td (+4-3)
  • (modified) mlir/include/mlir/IR/OpBase.td (+15)
  • (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (-41)
  • (modified) mlir/test/Dialect/Vector/invalid.mlir (-7)
  • (modified) mlir/test/Dialect/Vector/ops.mlir (+1-1)
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 2df2fe4c5ce8e9c..a22a082fb60ffb4 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -215,6 +215,8 @@ def Vector_ReductionOp :
   Vector_Op<"reduction", [Pure,
      PredOpTrait<"source operand and result have same element type",
                  TCresVTEtIsSameAsOpBase<0, 0>>,
+     OptionalTypesMatchWith<"dest and acc have the same type",
+                            "dest", "acc", "::llvm::cast<Type>($_self)">,
      DeclareOpInterfaceMethods<ArithFastMathInterface>,
      DeclareOpInterfaceMethods<MaskableOpInterface>,
      DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>
@@ -263,9 +265,8 @@ def Vector_ReductionOp :
                          "::mlir::arith::FastMathFlags::none">:$fastMathFlags)>
   ];
 
-  // TODO: Migrate to assemblyFormat once `AllTypesMatch` supports optional
-  // operands.
-  let hasCustomAssemblyFormat = 1;
+  let assemblyFormat = "$kind `,` $vector (`,` $acc^)? (`fastmath` $fastmath^)?"
+                       " attr-dict `:` type($vector) `into` type($dest)";
   let hasCanonicalizer = 1;
   let hasVerifier = 1;
 }
diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index 236dd74839dfb04..61babc93d49875b 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -568,6 +568,21 @@ class TypesMatchWith<string summary, string lhsArg, string rhsArg,
   string transformer = transform;
 }
 
+// Helper which makes the first letter of a string uppercase.
+// e.g. cat -> Cat
+class first_char_to_upper<string str>
+{
+  string ret = !toupper(!substr(str, 0, 1)) # !substr(str, 1);
+}
+
+// The same as TypesMatchWith but if either `lhsArg` or `rhsArg` are optional
+// and not present returns success.
+class OptionalTypesMatchWith<string summary, string lhsArg, string rhsArg,
+                     string transform, string comparator = "std::equal_to<>()">
+  : TypesMatchWith<summary, lhsArg, rhsArg, transform,
+     "!get" # first_char_to_upper<lhsArg>.ret # "()"
+     # " || !get" # first_char_to_upper<rhsArg>.ret # "() || " # comparator>;
+
 // Special variant of `TypesMatchWith` that provides a comparator suitable for
 // ranged arguments.
 class RangedTypesMatchWith<string summary, string lhsArg, string rhsArg,
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 044b6cc07d3d629..b63018dbd5d6aaa 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -485,47 +485,6 @@ LogicalResult ReductionOp::verify() {
   return success();
 }
 
-ParseResult ReductionOp::parse(OpAsmParser &parser, OperationState &result) {
-  SmallVector<OpAsmParser::UnresolvedOperand, 2> operandsInfo;
-  Type redType;
-  Type resType;
-  CombiningKindAttr kindAttr;
-  arith::FastMathFlagsAttr fastMathAttr;
-  if (parser.parseCustomAttributeWithFallback(kindAttr, Type{}, "kind",
-                                              result.attributes) ||
-      parser.parseComma() || parser.parseOperandList(operandsInfo) ||
-      (succeeded(parser.parseOptionalKeyword("fastmath")) &&
-       parser.parseCustomAttributeWithFallback(fastMathAttr, Type{}, "fastmath",
-                                               result.attributes)) ||
-      parser.parseColonType(redType) ||
-      parser.parseKeywordType("into", resType) ||
-      (!operandsInfo.empty() &&
-       parser.resolveOperand(operandsInfo[0], redType, result.operands)) ||
-      (operandsInfo.size() > 1 &&
-       parser.resolveOperand(operandsInfo[1], resType, result.operands)) ||
-      parser.addTypeToList(resType, result.types))
-    return failure();
-  if (operandsInfo.empty() || operandsInfo.size() > 2)
-    return parser.emitError(parser.getNameLoc(),
-                            "unsupported number of operands");
-  return success();
-}
-
-void ReductionOp::print(OpAsmPrinter &p) {
-  p << " ";
-  getKindAttr().print(p);
-  p << ", " << getVector();
-  if (getAcc())
-    p << ", " << getAcc();
-
-  if (getFastmathAttr() &&
-      getFastmathAttr().getValue() != arith::FastMathFlags::none) {
-    p << ' ' << getFastmathAttrName().getValue();
-    p.printStrippedAttrOrType(getFastmathAttr());
-  }
-  p << " : " << getVector().getType() << " into " << getDest().getType();
-}
-
 // MaskableOpInterface methods.
 
 /// Returns the mask type expected by this operation.
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 5967a8d69bbfcc0..ce8b56a5d57a2b6 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1168,13 +1168,6 @@ func.func @reduce_unsupported_attr(%arg0: vector<16xf32>) -> i32 {
 
 // -----
 
-func.func @reduce_unsupported_third_argument(%arg0: vector<16xf32>, %arg1: f32) -> f32 {
-  // expected-error@+1 {{'vector.reduction' unsupported number of operands}}
-  %0 = vector.reduction <add>, %arg0, %arg1, %arg1 : vector<16xf32> into f32
-}
-
-// -----
-
 func.func @reduce_unsupported_rank(%arg0: vector<4x16xf32>) -> f32 {
   // expected-error@+1 {{'vector.reduction' op unsupported reduction rank: 2}}
   %0 = vector.reduction <add>, %arg0 : vector<4x16xf32> into f32
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index 6cfddac94efd850..fbbb61959d12666 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -1042,7 +1042,7 @@ func.func @contraction_masked_scalable(%A: vector<3x4xf32>,
 
 // CHECK-LABEL:   func.func @fastmath(
 func.func @fastmath(%x: vector<42xf32>) -> f32 {
-  // CHECK: vector.reduction <minf>, %{{.*}} fastmath<reassoc,nnan,ninf>
+  // CHECK: vector.reduction <minf>, %{{.*}} fastmath <reassoc,nnan,ninf>
   %min = vector.reduction <minf>, %x fastmath<reassoc,nnan,ninf> : vector<42xf32> into f32
   return %min: f32
 }

@MacDue MacDue force-pushed the yeet_handrolled_parser_and_printer branch from 4925ff1 to f597f93 Compare October 12, 2023 14:35
…Format

This is just a slight specialization of `TypesMatchWith` that returns
success if an optional parameter is missing.

There may be other places this could help e.g.:
https://github.com/llvm/llvm-project/blob/eb21049b4b904b072679ece60e73c6b0dc0d1ebf/mlir/include/mlir/Dialect/X86Vector/X86Vector.td#L58C5-L58C5
But I'm leaving those to avoid some churn.

(This constraint will be handy for us in some later patches)
A !tocamelcase bang would nice in future :)
@MacDue MacDue force-pushed the yeet_handrolled_parser_and_printer branch from f597f93 to 9669a02 Compare October 12, 2023 16:04
Copy link
Collaborator

@c-rhodes c-rhodes left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One minor comment but otherwise LGTM, this is really handy, cheers!

@banach-space
Copy link
Contributor

Nice!

This is just a slight specialization of TypesMatchWith that returns success if an optional parameter is missing.

There may be other places this could help e.g.: https://github.com/llvm/llvm-project/blob/eb21049b4b904b072679ece60e73c6b0dc0d1ebf/mlir/include/mlir/Dialect/X86Vector/X86Vector.td#L58C5-L58C5 ...but I'm leaving those to avoid some churn.

(This constraint will be handy for us in some later patches)

  • Line numbers in your link are incorrect ("L58C5-L58C5")
  • "(This constraint will be handy for us in some later patches)" - for completeness, are you able to make this more specific or perhaps even refer to some existing PR?

@MacDue
Copy link
Member Author

MacDue commented Oct 17, 2023

Nice!

This is just a slight specialization of TypesMatchWith that returns success if an optional parameter is missing.
There may be other places this could help e.g.: https://github.com/llvm/llvm-project/blob/eb21049b4b904b072679ece60e73c6b0dc0d1ebf/mlir/include/mlir/Dialect/X86Vector/X86Vector.td#L58C5-L58C5 ...but I'm leaving those to avoid some churn.
(This constraint will be handy for us in some later patches)

  • Line numbers in your link are incorrect ("L58C5-L58C5")
  • "(This constraint will be handy for us in some later patches)" - for completeness, are you able to make this more specific or perhaps even refer to some existing PR?

This is a formalization of the comparator trick I came up with for Cullen in #69195:

  TypesMatchWith<
    "padding type matches element type of result (if present)",
    "result", "padding",
    "::llvm::cast<VectorType>($_self).getElementType()",
    "!getPadding() || std::equal_to<>()"
  >

Giving it a proper name seemed like it'd help make it more discoverable.

@MacDue
Copy link
Member Author

MacDue commented Oct 17, 2023

I've now updated the summary (and fixed the line numbers, thanks!)

Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thank you!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants