-
Notifications
You must be signed in to change notification settings - Fork 13.5k
[mlir][ODS] Add OptionalTypesMatchWith
and remove a custom assemblyFormat
#68876
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir-core @llvm/pr-subscribers-mlir-ods Author: Benjamin Maxwell (MacDue) ChangesThis is just a slight specialization of There may be other places this could help e.g.: (This constraint will be handy for us in some later patches) Full diff: https://github.com/llvm/llvm-project/pull/68876.diff 5 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 2df2fe4c5ce8e9c..a22a082fb60ffb4 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -215,6 +215,8 @@ def Vector_ReductionOp :
Vector_Op<"reduction", [Pure,
PredOpTrait<"source operand and result have same element type",
TCresVTEtIsSameAsOpBase<0, 0>>,
+ OptionalTypesMatchWith<"dest and acc have the same type",
+ "dest", "acc", "::llvm::cast<Type>($_self)">,
DeclareOpInterfaceMethods<ArithFastMathInterface>,
DeclareOpInterfaceMethods<MaskableOpInterface>,
DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>
@@ -263,9 +265,8 @@ def Vector_ReductionOp :
"::mlir::arith::FastMathFlags::none">:$fastMathFlags)>
];
- // TODO: Migrate to assemblyFormat once `AllTypesMatch` supports optional
- // operands.
- let hasCustomAssemblyFormat = 1;
+ let assemblyFormat = "$kind `,` $vector (`,` $acc^)? (`fastmath` $fastmath^)?"
+ " attr-dict `:` type($vector) `into` type($dest)";
let hasCanonicalizer = 1;
let hasVerifier = 1;
}
diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index 236dd74839dfb04..61babc93d49875b 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -568,6 +568,21 @@ class TypesMatchWith<string summary, string lhsArg, string rhsArg,
string transformer = transform;
}
+// Helper which makes the first letter of a string uppercase.
+// e.g. cat -> Cat
+class first_char_to_upper<string str>
+{
+ string ret = !toupper(!substr(str, 0, 1)) # !substr(str, 1);
+}
+
+// The same as TypesMatchWith but if either `lhsArg` or `rhsArg` are optional
+// and not present returns success.
+class OptionalTypesMatchWith<string summary, string lhsArg, string rhsArg,
+ string transform, string comparator = "std::equal_to<>()">
+ : TypesMatchWith<summary, lhsArg, rhsArg, transform,
+ "!get" # first_char_to_upper<lhsArg>.ret # "()"
+ # " || !get" # first_char_to_upper<rhsArg>.ret # "() || " # comparator>;
+
// Special variant of `TypesMatchWith` that provides a comparator suitable for
// ranged arguments.
class RangedTypesMatchWith<string summary, string lhsArg, string rhsArg,
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 044b6cc07d3d629..b63018dbd5d6aaa 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -485,47 +485,6 @@ LogicalResult ReductionOp::verify() {
return success();
}
-ParseResult ReductionOp::parse(OpAsmParser &parser, OperationState &result) {
- SmallVector<OpAsmParser::UnresolvedOperand, 2> operandsInfo;
- Type redType;
- Type resType;
- CombiningKindAttr kindAttr;
- arith::FastMathFlagsAttr fastMathAttr;
- if (parser.parseCustomAttributeWithFallback(kindAttr, Type{}, "kind",
- result.attributes) ||
- parser.parseComma() || parser.parseOperandList(operandsInfo) ||
- (succeeded(parser.parseOptionalKeyword("fastmath")) &&
- parser.parseCustomAttributeWithFallback(fastMathAttr, Type{}, "fastmath",
- result.attributes)) ||
- parser.parseColonType(redType) ||
- parser.parseKeywordType("into", resType) ||
- (!operandsInfo.empty() &&
- parser.resolveOperand(operandsInfo[0], redType, result.operands)) ||
- (operandsInfo.size() > 1 &&
- parser.resolveOperand(operandsInfo[1], resType, result.operands)) ||
- parser.addTypeToList(resType, result.types))
- return failure();
- if (operandsInfo.empty() || operandsInfo.size() > 2)
- return parser.emitError(parser.getNameLoc(),
- "unsupported number of operands");
- return success();
-}
-
-void ReductionOp::print(OpAsmPrinter &p) {
- p << " ";
- getKindAttr().print(p);
- p << ", " << getVector();
- if (getAcc())
- p << ", " << getAcc();
-
- if (getFastmathAttr() &&
- getFastmathAttr().getValue() != arith::FastMathFlags::none) {
- p << ' ' << getFastmathAttrName().getValue();
- p.printStrippedAttrOrType(getFastmathAttr());
- }
- p << " : " << getVector().getType() << " into " << getDest().getType();
-}
-
// MaskableOpInterface methods.
/// Returns the mask type expected by this operation.
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 5967a8d69bbfcc0..ce8b56a5d57a2b6 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1168,13 +1168,6 @@ func.func @reduce_unsupported_attr(%arg0: vector<16xf32>) -> i32 {
// -----
-func.func @reduce_unsupported_third_argument(%arg0: vector<16xf32>, %arg1: f32) -> f32 {
- // expected-error@+1 {{'vector.reduction' unsupported number of operands}}
- %0 = vector.reduction <add>, %arg0, %arg1, %arg1 : vector<16xf32> into f32
-}
-
-// -----
-
func.func @reduce_unsupported_rank(%arg0: vector<4x16xf32>) -> f32 {
// expected-error@+1 {{'vector.reduction' op unsupported reduction rank: 2}}
%0 = vector.reduction <add>, %arg0 : vector<4x16xf32> into f32
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index 6cfddac94efd850..fbbb61959d12666 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -1042,7 +1042,7 @@ func.func @contraction_masked_scalable(%A: vector<3x4xf32>,
// CHECK-LABEL: func.func @fastmath(
func.func @fastmath(%x: vector<42xf32>) -> f32 {
- // CHECK: vector.reduction <minf>, %{{.*}} fastmath<reassoc,nnan,ninf>
+ // CHECK: vector.reduction <minf>, %{{.*}} fastmath <reassoc,nnan,ninf>
%min = vector.reduction <minf>, %x fastmath<reassoc,nnan,ninf> : vector<42xf32> into f32
return %min: f32
}
|
4925ff1
to
f597f93
Compare
…Format This is just a slight specialization of `TypesMatchWith` that returns success if an optional parameter is missing. There may be other places this could help e.g.: https://github.com/llvm/llvm-project/blob/eb21049b4b904b072679ece60e73c6b0dc0d1ebf/mlir/include/mlir/Dialect/X86Vector/X86Vector.td#L58C5-L58C5 But I'm leaving those to avoid some churn. (This constraint will be handy for us in some later patches)
A !tocamelcase bang would nice in future :)
f597f93
to
9669a02
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
One minor comment but otherwise LGTM, this is really handy, cheers!
Nice!
|
This is a formalization of the comparator trick I came up with for Cullen in #69195:
Giving it a proper name seemed like it'd help make it more discoverable. |
I've now updated the summary (and fixed the line numbers, thanks!) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thank you!
This is just a slight specialization of
TypesMatchWith
that returns success if an optional parameter is missing.There may be other places this could help e.g.:
llvm-project/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
Lines 58 to 59 in eb21049
...but I'm leaving those to avoid some churn.
This constraint will be handy for us in some later patches, it's a formalization of a short circuiting trick with the
comparator
of theTypesMatchWith
constraint (devised for #69195).This is a little non-obvious, so after this patch you can instead do: