Skip to content

Commit ee09087

Browse files
committed
Change the printing/parsing behavior for Attributes used in declarative assembly format
The new form of printing attribute in the declarative assembly is eliding the `#dialect.mnemonic` prefix to only keep the `<....>` part. Differential Revision: https://reviews.llvm.org/D113873
1 parent 63cd184 commit ee09087

32 files changed

+574
-170
lines changed

mlir/include/mlir/Dialect/ArmSVE/ArmSVE.td

Lines changed: 1 addition & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ def ArmSVE_Dialect : Dialect {
3131
vector operations, including a scalable vector type and intrinsics for
3232
some Arm SVE instructions.
3333
}];
34+
let useDefaultTypePrinterParser = 1;
3435
}
3536

3637
//===----------------------------------------------------------------------===//
@@ -66,20 +67,6 @@ def ScalableVectorType : ArmSVE_Type<"ScalableVector"> {
6667
"Type":$elementType
6768
);
6869

69-
let printer = [{
70-
$_printer << "<";
71-
for (int64_t dim : getShape())
72-
$_printer << dim << 'x';
73-
$_printer << getElementType() << '>';
74-
}];
75-
76-
let parser = [{
77-
VectorType vector;
78-
if ($_parser.parseType(vector))
79-
return Type();
80-
return get($_ctxt, vector.getShape(), vector.getElementType());
81-
}];
82-
8370
let extraClassDeclaration = [{
8471
bool hasStaticShape() const {
8572
return llvm::none_of(getShape(), ShapedType::isDynamic);

mlir/include/mlir/IR/DialectImplementation.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -64,19 +64,19 @@ struct FieldParser<
6464
AttributeT>> {
6565
static FailureOr<AttributeT> parse(AsmParser &parser) {
6666
AttributeT value;
67-
if (parser.parseAttribute(value))
67+
if (parser.parseCustomAttributeWithFallback(value))
6868
return failure();
6969
return value;
7070
}
7171
};
7272

73-
/// Parse a type.
73+
/// Parse an attribute.
7474
template <typename TypeT>
7575
struct FieldParser<
7676
TypeT, std::enable_if_t<std::is_base_of<Type, TypeT>::value, TypeT>> {
7777
static FailureOr<TypeT> parse(AsmParser &parser) {
7878
TypeT value;
79-
if (parser.parseType(value))
79+
if (parser.parseCustomTypeWithFallback(value))
8080
return failure();
8181
return value;
8282
}

mlir/include/mlir/IR/OpBase.td

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2984,6 +2984,9 @@ class TypeDef<Dialect dialect, string name, list<Trait> traits = [],
29842984
string baseCppClass = "::mlir::Type">
29852985
: DialectType<dialect, CPred<"">, /*descr*/"", name # "Type">,
29862986
AttrOrTypeDef<"Type", name, traits, baseCppClass> {
2987+
// Make it possible to use such type as parameters for other types.
2988+
string cppType = dialect.cppNamespace # "::" # cppClassName;
2989+
29872990
// A constant builder provided when the type has no parameters.
29882991
let builderCall = !if(!empty(parameters),
29892992
"$_builder.getType<" # dialect.cppNamespace #

mlir/include/mlir/IR/OpImplementation.h

Lines changed: 169 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,36 @@ class AsmPrinter {
5050
virtual void printType(Type type);
5151
virtual void printAttribute(Attribute attr);
5252

53+
/// Trait to check if `AttrType` provides a `print` method.
54+
template <typename AttrOrType>
55+
using has_print_method =
56+
decltype(std::declval<AttrOrType>().print(std::declval<AsmPrinter &>()));
57+
template <typename AttrOrType>
58+
using detect_has_print_method =
59+
llvm::is_detected<has_print_method, AttrOrType>;
60+
61+
/// Print the provided attribute in the context of an operation custom
62+
/// printer/parser: this will invoke directly the print method on the
63+
/// attribute class and skip the `#dialect.mnemonic` prefix in most cases.
64+
template <typename AttrOrType,
65+
std::enable_if_t<detect_has_print_method<AttrOrType>::value>
66+
*sfinae = nullptr>
67+
void printStrippedAttrOrType(AttrOrType attrOrType) {
68+
if (succeeded(printAlias(attrOrType)))
69+
return;
70+
attrOrType.print(*this);
71+
}
72+
73+
/// SFINAE for printing the provided attribute in the context of an operation
74+
/// custom printer in the case where the attribute does not define a print
75+
/// method.
76+
template <typename AttrOrType,
77+
std::enable_if_t<!detect_has_print_method<AttrOrType>::value>
78+
*sfinae = nullptr>
79+
void printStrippedAttrOrType(AttrOrType attrOrType) {
80+
*this << attrOrType;
81+
}
82+
5383
/// Print the given attribute without its type. The corresponding parser must
5484
/// provide a valid type for the attribute.
5585
virtual void printAttributeWithoutType(Attribute attr);
@@ -102,6 +132,14 @@ class AsmPrinter {
102132
AsmPrinter(const AsmPrinter &) = delete;
103133
void operator=(const AsmPrinter &) = delete;
104134

135+
/// Print the alias for the given attribute, return failure if no alias could
136+
/// be printed.
137+
virtual LogicalResult printAlias(Attribute attr);
138+
139+
/// Print the alias for the given type, return failure if no alias could
140+
/// be printed.
141+
virtual LogicalResult printAlias(Type type);
142+
105143
/// The internal implementation of the printer.
106144
Impl *impl;
107145
};
@@ -608,6 +646,13 @@ class AsmParser {
608646
/// Parse an arbitrary attribute of a given type and return it in result.
609647
virtual ParseResult parseAttribute(Attribute &result, Type type = {}) = 0;
610648

649+
/// Parse a custom attribute with the provided callback, unless the next
650+
/// token is `#`, in which case the generic parser is invoked.
651+
virtual ParseResult parseCustomAttributeWithFallback(
652+
Attribute &result, Type type,
653+
function_ref<ParseResult(Attribute &result, Type type)>
654+
parseAttribute) = 0;
655+
611656
/// Parse an attribute of a specific kind and type.
612657
template <typename AttrType>
613658
ParseResult parseAttribute(AttrType &result, Type type = {}) {
@@ -639,9 +684,9 @@ class AsmParser {
639684
return parseAttribute(result, Type(), attrName, attrs);
640685
}
641686

642-
/// Parse an arbitrary attribute of a given type and return it in result. This
643-
/// also adds the attribute to the specified attribute list with the specified
644-
/// name.
687+
/// Parse an arbitrary attribute of a given type and populate it in `result`.
688+
/// This also adds the attribute to the specified attribute list with the
689+
/// specified name.
645690
template <typename AttrType>
646691
ParseResult parseAttribute(AttrType &result, Type type, StringRef attrName,
647692
NamedAttrList &attrs) {
@@ -661,6 +706,82 @@ class AsmParser {
661706
return success();
662707
}
663708

709+
/// Trait to check if `AttrType` provides a `parse` method.
710+
template <typename AttrType>
711+
using has_parse_method = decltype(AttrType::parse(std::declval<AsmParser &>(),
712+
std::declval<Type>()));
713+
template <typename AttrType>
714+
using detect_has_parse_method = llvm::is_detected<has_parse_method, AttrType>;
715+
716+
/// Parse a custom attribute of a given type unless the next token is `#`, in
717+
/// which case the generic parser is invoked. The parsed attribute is
718+
/// populated in `result` and also added to the specified attribute list with
719+
/// the specified name.
720+
template <typename AttrType>
721+
std::enable_if_t<detect_has_parse_method<AttrType>::value, ParseResult>
722+
parseCustomAttributeWithFallback(AttrType &result, Type type,
723+
StringRef attrName, NamedAttrList &attrs) {
724+
llvm::SMLoc loc = getCurrentLocation();
725+
726+
// Parse any kind of attribute.
727+
Attribute attr;
728+
if (parseCustomAttributeWithFallback(
729+
attr, type, [&](Attribute &result, Type type) -> ParseResult {
730+
result = AttrType::parse(*this, type);
731+
if (!result)
732+
return failure();
733+
return success();
734+
}))
735+
return failure();
736+
737+
// Check for the right kind of attribute.
738+
result = attr.dyn_cast<AttrType>();
739+
if (!result)
740+
return emitError(loc, "invalid kind of attribute specified");
741+
742+
attrs.append(attrName, result);
743+
return success();
744+
}
745+
746+
/// SFINAE parsing method for Attribute that don't implement a parse method.
747+
template <typename AttrType>
748+
std::enable_if_t<!detect_has_parse_method<AttrType>::value, ParseResult>
749+
parseCustomAttributeWithFallback(AttrType &result, Type type,
750+
StringRef attrName, NamedAttrList &attrs) {
751+
return parseAttribute(result, type, attrName, attrs);
752+
}
753+
754+
/// Parse a custom attribute of a given type unless the next token is `#`, in
755+
/// which case the generic parser is invoked. The parsed attribute is
756+
/// populated in `result`.
757+
template <typename AttrType>
758+
std::enable_if_t<detect_has_parse_method<AttrType>::value, ParseResult>
759+
parseCustomAttributeWithFallback(AttrType &result) {
760+
llvm::SMLoc loc = getCurrentLocation();
761+
762+
// Parse any kind of attribute.
763+
Attribute attr;
764+
if (parseCustomAttributeWithFallback(
765+
attr, {}, [&](Attribute &result, Type type) -> ParseResult {
766+
result = AttrType::parse(*this, type);
767+
return success(!!result);
768+
}))
769+
return failure();
770+
771+
// Check for the right kind of attribute.
772+
result = attr.dyn_cast<AttrType>();
773+
if (!result)
774+
return emitError(loc, "invalid kind of attribute specified");
775+
return success();
776+
}
777+
778+
/// SFINAE parsing method for Attribute that don't implement a parse method.
779+
template <typename AttrType>
780+
std::enable_if_t<!detect_has_parse_method<AttrType>::value, ParseResult>
781+
parseCustomAttributeWithFallback(AttrType &result) {
782+
return parseAttribute(result);
783+
}
784+
664785
/// Parse an arbitrary optional attribute of a given type and return it in
665786
/// result.
666787
virtual OptionalParseResult parseOptionalAttribute(Attribute &result,
@@ -740,6 +861,11 @@ class AsmParser {
740861
/// Parse a type.
741862
virtual ParseResult parseType(Type &result) = 0;
742863

864+
/// Parse a custom type with the provided callback, unless the next
865+
/// token is `#`, in which case the generic parser is invoked.
866+
virtual ParseResult parseCustomTypeWithFallback(
867+
Type &result, function_ref<ParseResult(Type &result)> parseType) = 0;
868+
743869
/// Parse an optional type.
744870
virtual OptionalParseResult parseOptionalType(Type &result) = 0;
745871

@@ -753,14 +879,52 @@ class AsmParser {
753879
if (parseType(type))
754880
return failure();
755881

756-
// Check for the right kind of attribute.
882+
// Check for the right kind of type.
757883
result = type.dyn_cast<TypeT>();
758884
if (!result)
759885
return emitError(loc, "invalid kind of type specified");
760886

761887
return success();
762888
}
763889

890+
/// Trait to check if `TypeT` provides a `parse` method.
891+
template <typename TypeT>
892+
using type_has_parse_method =
893+
decltype(TypeT::parse(std::declval<AsmParser &>()));
894+
template <typename TypeT>
895+
using detect_type_has_parse_method =
896+
llvm::is_detected<type_has_parse_method, TypeT>;
897+
898+
/// Parse a custom Type of a given type unless the next token is `#`, in
899+
/// which case the generic parser is invoked. The parsed Type is
900+
/// populated in `result`.
901+
template <typename TypeT>
902+
std::enable_if_t<detect_type_has_parse_method<TypeT>::value, ParseResult>
903+
parseCustomTypeWithFallback(TypeT &result) {
904+
llvm::SMLoc loc = getCurrentLocation();
905+
906+
// Parse any kind of Type.
907+
Type type;
908+
if (parseCustomTypeWithFallback(type, [&](Type &result) -> ParseResult {
909+
result = TypeT::parse(*this);
910+
return success(!!result);
911+
}))
912+
return failure();
913+
914+
// Check for the right kind of Type.
915+
result = type.dyn_cast<TypeT>();
916+
if (!result)
917+
return emitError(loc, "invalid kind of Type specified");
918+
return success();
919+
}
920+
921+
/// SFINAE parsing method for Type that don't implement a parse method.
922+
template <typename TypeT>
923+
std::enable_if_t<!detect_type_has_parse_method<TypeT>::value, ParseResult>
924+
parseCustomTypeWithFallback(TypeT &result) {
925+
return parseType(result);
926+
}
927+
764928
/// Parse a type list.
765929
ParseResult parseTypeList(SmallVectorImpl<Type> &result) {
766930
do {
@@ -792,7 +956,7 @@ class AsmParser {
792956
if (parseColonType(type))
793957
return failure();
794958

795-
// Check for the right kind of attribute.
959+
// Check for the right kind of type.
796960
result = type.dyn_cast<TypeType>();
797961
if (!result)
798962
return emitError(loc, "invalid kind of type specified");

mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -53,21 +53,21 @@ void ArmSVEDialect::initialize() {
5353
// ScalableVectorType
5454
//===----------------------------------------------------------------------===//
5555

56-
Type ArmSVEDialect::parseType(DialectAsmParser &parser) const {
57-
llvm::SMLoc typeLoc = parser.getCurrentLocation();
58-
{
59-
Type genType;
60-
auto parseResult = generatedTypeParser(parser, "vector", genType);
61-
if (parseResult.hasValue())
62-
return genType;
63-
}
64-
parser.emitError(typeLoc, "unknown type in ArmSVE dialect");
65-
return Type();
56+
void ScalableVectorType::print(AsmPrinter &printer) const {
57+
printer << "<";
58+
for (int64_t dim : getShape())
59+
printer << dim << 'x';
60+
printer << getElementType() << '>';
6661
}
6762

68-
void ArmSVEDialect::printType(Type type, DialectAsmPrinter &os) const {
69-
if (failed(generatedTypePrinter(type, os)))
70-
llvm_unreachable("unexpected 'arm_sve' type kind");
63+
Type ScalableVectorType::parse(AsmParser &parser) {
64+
SmallVector<int64_t> dims;
65+
Type eltType;
66+
if (parser.parseLess() ||
67+
parser.parseDimensionList(dims, /*allowDynamic=*/false) ||
68+
parser.parseType(eltType) || parser.parseGreater())
69+
return {};
70+
return ScalableVectorType::get(eltType.getContext(), dims, eltType);
7171
}
7272

7373
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/Vector/VectorOps.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,7 @@ static constexpr const CombiningKind combiningKindsList[] = {
170170
};
171171

172172
void CombiningKindAttr::print(AsmPrinter &printer) const {
173-
printer << "kind<";
173+
printer << "<";
174174
auto kinds = llvm::make_filter_range(combiningKindsList, [&](auto kind) {
175175
return bitEnumContains(this->getKind(), kind);
176176
});
@@ -215,10 +215,12 @@ Attribute VectorDialect::parseAttribute(DialectAsmParser &parser,
215215

216216
void VectorDialect::printAttribute(Attribute attr,
217217
DialectAsmPrinter &os) const {
218-
if (auto ck = attr.dyn_cast<CombiningKindAttr>())
218+
if (auto ck = attr.dyn_cast<CombiningKindAttr>()) {
219+
os << "kind";
219220
ck.print(os);
220-
else
221-
llvm_unreachable("Unknown attribute type");
221+
return;
222+
}
223+
llvm_unreachable("Unknown attribute type");
222224
}
223225

224226
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)