Skip to content

[mlir][ODS] Add OptionalTypesMatchWith and remove a custom assemblyFormat #68876

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Oct 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,8 @@ def Vector_ReductionOp :
Vector_Op<"reduction", [Pure,
PredOpTrait<"source operand and result have same element type",
TCresVTEtIsSameAsOpBase<0, 0>>,
OptionalTypesMatchWith<"dest and acc have the same type",
"dest", "acc", "::llvm::cast<Type>($_self)">,
DeclareOpInterfaceMethods<ArithFastMathInterface>,
DeclareOpInterfaceMethods<MaskableOpInterface>,
DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>
Expand Down Expand Up @@ -263,9 +265,8 @@ def Vector_ReductionOp :
"::mlir::arith::FastMathFlags::none">:$fastMathFlags)>
];

// TODO: Migrate to assemblyFormat once `AllTypesMatch` supports optional
// operands.
let hasCustomAssemblyFormat = 1;
let assemblyFormat = "$kind `,` $vector (`,` $acc^)? (`fastmath` `` $fastmath^)?"
" attr-dict `:` type($vector) `into` type($dest)";
let hasCanonicalizer = 1;
let hasVerifier = 1;
}
Expand Down
8 changes: 8 additions & 0 deletions mlir/include/mlir/IR/OpBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -568,6 +568,14 @@ class TypesMatchWith<string summary, string lhsArg, string rhsArg,
string transformer = transform;
}

// The same as TypesMatchWith but if either `lhsArg` or `rhsArg` are optional
// and not present returns success.
class OptionalTypesMatchWith<string summary, string lhsArg, string rhsArg,
string transform, string comparator = "std::equal_to<>()">
: TypesMatchWith<summary, lhsArg, rhsArg, transform,
"!get" # snakeCaseToCamelCase<lhsArg>.ret # "()"
# " || !get" # snakeCaseToCamelCase<rhsArg>.ret # "() || " # comparator>;

// Special variant of `TypesMatchWith` that provides a comparator suitable for
// ranged arguments.
class RangedTypesMatchWith<string summary, string lhsArg, string rhsArg,
Expand Down
24 changes: 24 additions & 0 deletions mlir/include/mlir/IR/Utils.td
Original file line number Diff line number Diff line change
Expand Up @@ -66,4 +66,28 @@ class CArg<string ty, string value = ""> {
string defaultValue = value;
}

// Helper which makes the first letter of a string uppercase.
// e.g. cat -> Cat
class firstCharToUpper<string str>
{
string ret = !if(!gt(!size(str), 0),
!toupper(!substr(str, 0, 1)) # !substr(str, 1),
"");
}

class _snakeCaseHelper<string str> {
int idx = !find(str, "_");
string ret = !if(!ge(idx, 0),
!substr(str, 0, idx) # firstCharToUpper<!substr(str, !add(idx, 1))>.ret,
str);
}

// Converts a snake_case string to CamelCase.
// TODO: Replace with a !tocamelcase bang operator.
class snakeCaseToCamelCase<string str>
{
string ret = !foldl(firstCharToUpper<str>.ret,
!range(0, !size(str)), acc, idx, _snakeCaseHelper<acc>.ret);
}

#endif // UTILS_TD
41 changes: 0 additions & 41 deletions mlir/lib/Dialect/Vector/IR/VectorOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -485,47 +485,6 @@ LogicalResult ReductionOp::verify() {
return success();
}

ParseResult ReductionOp::parse(OpAsmParser &parser, OperationState &result) {
SmallVector<OpAsmParser::UnresolvedOperand, 2> operandsInfo;
Type redType;
Type resType;
CombiningKindAttr kindAttr;
arith::FastMathFlagsAttr fastMathAttr;
if (parser.parseCustomAttributeWithFallback(kindAttr, Type{}, "kind",
result.attributes) ||
parser.parseComma() || parser.parseOperandList(operandsInfo) ||
(succeeded(parser.parseOptionalKeyword("fastmath")) &&
parser.parseCustomAttributeWithFallback(fastMathAttr, Type{}, "fastmath",
result.attributes)) ||
parser.parseColonType(redType) ||
parser.parseKeywordType("into", resType) ||
(!operandsInfo.empty() &&
parser.resolveOperand(operandsInfo[0], redType, result.operands)) ||
(operandsInfo.size() > 1 &&
parser.resolveOperand(operandsInfo[1], resType, result.operands)) ||
parser.addTypeToList(resType, result.types))
return failure();
if (operandsInfo.empty() || operandsInfo.size() > 2)
return parser.emitError(parser.getNameLoc(),
"unsupported number of operands");
return success();
}

void ReductionOp::print(OpAsmPrinter &p) {
p << " ";
getKindAttr().print(p);
p << ", " << getVector();
if (getAcc())
p << ", " << getAcc();

if (getFastmathAttr() &&
getFastmathAttr().getValue() != arith::FastMathFlags::none) {
p << ' ' << getFastmathAttrName().getValue();
p.printStrippedAttrOrType(getFastmathAttr());
}
p << " : " << getVector().getType() << " into " << getDest().getType();
}

// MaskableOpInterface methods.

/// Returns the mask type expected by this operation.
Expand Down
2 changes: 1 addition & 1 deletion mlir/test/Dialect/Vector/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1169,7 +1169,7 @@ func.func @reduce_unsupported_attr(%arg0: vector<16xf32>) -> i32 {
// -----

func.func @reduce_unsupported_third_argument(%arg0: vector<16xf32>, %arg1: f32) -> f32 {
// expected-error@+1 {{'vector.reduction' unsupported number of operands}}
// expected-error@+1 {{expected ':'}}
%0 = vector.reduction <add>, %arg0, %arg1, %arg1 : vector<16xf32> into f32
}

Expand Down
23 changes: 23 additions & 0 deletions mlir/test/mlir-tblgen/utils.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
// RUN: mlir-tblgen -I %S/../../include %s | FileCheck %s

include "mlir/IR/Utils.td"

// CHECK-DAG: string value = "CamelCaseTest"
class already_camel_case {
string value = snakeCaseToCamelCase<"CamelCaseTest">.ret;
}

// CHECK-DAG: string value = "Foo"
class single_word {
string value = snakeCaseToCamelCase<"foo">.ret;
}

// CHECK-DAG: string value = "ThisIsATest"
class snake_case {
string value = snakeCaseToCamelCase<"this_is_a_test">.ret;
}

// CHECK-DAG: string value = "ThisIsATestAgain"
class extra_underscores {
string value = snakeCaseToCamelCase<"__this__is_a_test__again__">.ret;
}