Skip to content

Commit 03d1c99

Browse files
authored
[mlir][ODS] Add OptionalTypesMatchWith and remove a custom assemblyFormat (#68876)
This is just a slight specialization of `TypesMatchWith` that returns success if an optional parameter is missing. There may be other places this could help e.g.: https://github.com/llvm/llvm-project/blob/eb21049b4b904b072679ece60e73c6b0dc0d1ebf/mlir/include/mlir/Dialect/X86Vector/X86Vector.td#L58-L59 ...but I'm leaving those to avoid some churn. This constraint will be handy for us in some later patches, it's a formalization of a short circuiting trick with the `comparator` of the `TypesMatchWith` constraint (devised for #69195). ``` TypesMatchWith< "padding type matches element type of result (if present)", "result", "padding", "::llvm::cast<VectorType>($_self).getElementType()", // This returns true if no padding is present, or it's present with a type that matches the element type of `result`. "!getPadding() || std::equal_to<>()"> ``` This is a little non-obvious, so after this patch you can instead do: ``` OptionalTypesMatchWith< "padding type matches element type of result (if present)", "result", "padding", "::llvm::cast<VectorType>($_self).getElementType()"> ```
1 parent e880e8a commit 03d1c99

File tree

6 files changed

+60
-45
lines changed

6 files changed

+60
-45
lines changed

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,8 @@ def Vector_ReductionOp :
215215
Vector_Op<"reduction", [Pure,
216216
PredOpTrait<"source operand and result have same element type",
217217
TCresVTEtIsSameAsOpBase<0, 0>>,
218+
OptionalTypesMatchWith<"dest and acc have the same type",
219+
"dest", "acc", "::llvm::cast<Type>($_self)">,
218220
DeclareOpInterfaceMethods<ArithFastMathInterface>,
219221
DeclareOpInterfaceMethods<MaskableOpInterface>,
220222
DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>
@@ -263,9 +265,8 @@ def Vector_ReductionOp :
263265
"::mlir::arith::FastMathFlags::none">:$fastMathFlags)>
264266
];
265267

266-
// TODO: Migrate to assemblyFormat once `AllTypesMatch` supports optional
267-
// operands.
268-
let hasCustomAssemblyFormat = 1;
268+
let assemblyFormat = "$kind `,` $vector (`,` $acc^)? (`fastmath` `` $fastmath^)?"
269+
" attr-dict `:` type($vector) `into` type($dest)";
269270
let hasCanonicalizer = 1;
270271
let hasVerifier = 1;
271272
}

mlir/include/mlir/IR/OpBase.td

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -568,6 +568,14 @@ class TypesMatchWith<string summary, string lhsArg, string rhsArg,
568568
string transformer = transform;
569569
}
570570

571+
// The same as TypesMatchWith but if either `lhsArg` or `rhsArg` are optional
572+
// and not present returns success.
573+
class OptionalTypesMatchWith<string summary, string lhsArg, string rhsArg,
574+
string transform, string comparator = "std::equal_to<>()">
575+
: TypesMatchWith<summary, lhsArg, rhsArg, transform,
576+
"!get" # snakeCaseToCamelCase<lhsArg>.ret # "()"
577+
# " || !get" # snakeCaseToCamelCase<rhsArg>.ret # "() || " # comparator>;
578+
571579
// Special variant of `TypesMatchWith` that provides a comparator suitable for
572580
// ranged arguments.
573581
class RangedTypesMatchWith<string summary, string lhsArg, string rhsArg,

mlir/include/mlir/IR/Utils.td

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,4 +66,28 @@ class CArg<string ty, string value = ""> {
6666
string defaultValue = value;
6767
}
6868

69+
// Helper which makes the first letter of a string uppercase.
70+
// e.g. cat -> Cat
71+
class firstCharToUpper<string str>
72+
{
73+
string ret = !if(!gt(!size(str), 0),
74+
!toupper(!substr(str, 0, 1)) # !substr(str, 1),
75+
"");
76+
}
77+
78+
class _snakeCaseHelper<string str> {
79+
int idx = !find(str, "_");
80+
string ret = !if(!ge(idx, 0),
81+
!substr(str, 0, idx) # firstCharToUpper<!substr(str, !add(idx, 1))>.ret,
82+
str);
83+
}
84+
85+
// Converts a snake_case string to CamelCase.
86+
// TODO: Replace with a !tocamelcase bang operator.
87+
class snakeCaseToCamelCase<string str>
88+
{
89+
string ret = !foldl(firstCharToUpper<str>.ret,
90+
!range(0, !size(str)), acc, idx, _snakeCaseHelper<acc>.ret);
91+
}
92+
6993
#endif // UTILS_TD

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 0 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -524,47 +524,6 @@ LogicalResult ReductionOp::verify() {
524524
return success();
525525
}
526526

527-
ParseResult ReductionOp::parse(OpAsmParser &parser, OperationState &result) {
528-
SmallVector<OpAsmParser::UnresolvedOperand, 2> operandsInfo;
529-
Type redType;
530-
Type resType;
531-
CombiningKindAttr kindAttr;
532-
arith::FastMathFlagsAttr fastMathAttr;
533-
if (parser.parseCustomAttributeWithFallback(kindAttr, Type{}, "kind",
534-
result.attributes) ||
535-
parser.parseComma() || parser.parseOperandList(operandsInfo) ||
536-
(succeeded(parser.parseOptionalKeyword("fastmath")) &&
537-
parser.parseCustomAttributeWithFallback(fastMathAttr, Type{}, "fastmath",
538-
result.attributes)) ||
539-
parser.parseColonType(redType) ||
540-
parser.parseKeywordType("into", resType) ||
541-
(!operandsInfo.empty() &&
542-
parser.resolveOperand(operandsInfo[0], redType, result.operands)) ||
543-
(operandsInfo.size() > 1 &&
544-
parser.resolveOperand(operandsInfo[1], resType, result.operands)) ||
545-
parser.addTypeToList(resType, result.types))
546-
return failure();
547-
if (operandsInfo.empty() || operandsInfo.size() > 2)
548-
return parser.emitError(parser.getNameLoc(),
549-
"unsupported number of operands");
550-
return success();
551-
}
552-
553-
void ReductionOp::print(OpAsmPrinter &p) {
554-
p << " ";
555-
getKindAttr().print(p);
556-
p << ", " << getVector();
557-
if (getAcc())
558-
p << ", " << getAcc();
559-
560-
if (getFastmathAttr() &&
561-
getFastmathAttr().getValue() != arith::FastMathFlags::none) {
562-
p << ' ' << getFastmathAttrName().getValue();
563-
p.printStrippedAttrOrType(getFastmathAttr());
564-
}
565-
p << " : " << getVector().getType() << " into " << getDest().getType();
566-
}
567-
568527
// MaskableOpInterface methods.
569528

570529
/// Returns the mask type expected by this operation.

mlir/test/Dialect/Vector/invalid.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1169,7 +1169,7 @@ func.func @reduce_unsupported_attr(%arg0: vector<16xf32>) -> i32 {
11691169
// -----
11701170

11711171
func.func @reduce_unsupported_third_argument(%arg0: vector<16xf32>, %arg1: f32) -> f32 {
1172-
// expected-error@+1 {{'vector.reduction' unsupported number of operands}}
1172+
// expected-error@+1 {{expected ':'}}
11731173
%0 = vector.reduction <add>, %arg0, %arg1, %arg1 : vector<16xf32> into f32
11741174
}
11751175

mlir/test/mlir-tblgen/utils.td

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
// RUN: mlir-tblgen -I %S/../../include %s | FileCheck %s
2+
3+
include "mlir/IR/Utils.td"
4+
5+
// CHECK-DAG: string value = "CamelCaseTest"
6+
class already_camel_case {
7+
string value = snakeCaseToCamelCase<"CamelCaseTest">.ret;
8+
}
9+
10+
// CHECK-DAG: string value = "Foo"
11+
class single_word {
12+
string value = snakeCaseToCamelCase<"foo">.ret;
13+
}
14+
15+
// CHECK-DAG: string value = "ThisIsATest"
16+
class snake_case {
17+
string value = snakeCaseToCamelCase<"this_is_a_test">.ret;
18+
}
19+
20+
// CHECK-DAG: string value = "ThisIsATestAgain"
21+
class extra_underscores {
22+
string value = snakeCaseToCamelCase<"__this__is_a_test__again__">.ret;
23+
}

0 commit comments

Comments
 (0)