Skip to content

Commit 5ecc0ef

Browse files
authored
[mlir] Improve EnumProp, making it take an EnumInfo (#132349)
This commit improves the `EnumProp` class, causing it to wrap around an `EnumInfo` just like` EnumAttr` does. This EnumProp also has logic for converting to/from an integer attribute and for being read and written as bitcode. The following variants of `EnumProp` are provided: - `EnumPropWithAttrForm` - an EnumProp that can be constructed from (and will be converted to, if `storeInCustomAttribute` is true) a custom attribute, like an `EnumAttr`, instead of a plain integer. This is meant for backwards compatibility with code that uses enum attributes. `NamedEnumProp` adds a "`mnemonic` `<` $enum `>`" syntax around the enum, replicating a common pattern seen in MLIR printers and allowing for reduced ambiguity. `NamedEnumPropWithAttrForm` combines both of these extensions. (Sadly, bytecode auto-upgrade is hampered by the lack of the ability to optionally parse an attribute.) Depends on #132148
1 parent e038c54 commit 5ecc0ef

File tree

9 files changed

+332
-93
lines changed

9 files changed

+332
-93
lines changed

mlir/docs/DefiningDialects/Operations.md

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1756,6 +1756,23 @@ that it has a value within the valid range of the enum. If their
17561756
wrapper attribute instead of using a bare signless integer attribute
17571757
for storage.
17581758

1759+
### Enum properties
1760+
1761+
Enums can be wrapped in properties so that they can be stored inline.
1762+
This causes a value of the enum's C++ class to become a member of the operation's
1763+
property struct and for the operation's verifier to check that the enum's value
1764+
is a valid value for the enum.
1765+
1766+
The basic wrapper is `EnumProp`, which simply takes an `EnumInfo`.
1767+
1768+
A less ambiguous syntax, namely putting a mnemonic and `<>`s surrounding
1769+
the enum is generated with `NamedEnumProp`, which takes a `*EnumInfo`
1770+
and a mnemonic string, which becomes part of the property's syntax.
1771+
1772+
Both of these `EnumProp` types have a `*EnumPropWithAttrForm`, which allows for
1773+
transparently upgrading from `EnumAttr`s and optionally retaining those
1774+
attributes in the generic form.
1775+
17591776
## Debugging Tips
17601777

17611778
### Run `mlir-tblgen` to see the generated content

mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -485,17 +485,16 @@ def DISubprogramFlags : I32BitEnumAttr<
485485
// IntegerOverflowFlags
486486
//===----------------------------------------------------------------------===//
487487

488-
def IOFnone : I32BitEnumAttrCaseNone<"none">;
489-
def IOFnsw : I32BitEnumAttrCaseBit<"nsw", 0>;
490-
def IOFnuw : I32BitEnumAttrCaseBit<"nuw", 1>;
488+
def IOFnone : I32BitEnumCaseNone<"none">;
489+
def IOFnsw : I32BitEnumCaseBit<"nsw", 0>;
490+
def IOFnuw : I32BitEnumCaseBit<"nuw", 1>;
491491

492-
def IntegerOverflowFlags : I32BitEnumAttr<
492+
def IntegerOverflowFlags : I32BitEnum<
493493
"IntegerOverflowFlags",
494494
"LLVM integer overflow flags",
495495
[IOFnone, IOFnsw, IOFnuw]> {
496496
let separator = ", ";
497497
let cppNamespace = "::mlir::LLVM";
498-
let genSpecializedAttr = 0;
499498
let printBitEnumPrimaryGroups = 1;
500499
}
501500

@@ -504,6 +503,11 @@ def LLVM_IntegerOverflowFlagsAttr :
504503
let assemblyFormat = "`<` $value `>`";
505504
}
506505

506+
def LLVM_IntegerOverflowFlagsProp :
507+
NamedEnumPropWithAttrForm<IntegerOverflowFlags, "overflow", LLVM_IntegerOverflowFlagsAttr> {
508+
let defaultValue = enum.cppType # "::" # "none";
509+
}
510+
507511
//===----------------------------------------------------------------------===//
508512
// FastmathFlags
509513
//===----------------------------------------------------------------------===//

mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ class LLVM_IntArithmeticOpWithOverflowFlag<string mnemonic, string instName,
6060
list<Trait> traits = []> :
6161
LLVM_ArithmeticOpBase<AnySignlessInteger, mnemonic, instName,
6262
!listconcat([DeclareOpInterfaceMethods<IntegerOverflowFlagsInterface>], traits)> {
63-
dag iofArg = (ins EnumProp<"IntegerOverflowFlags", "", "IntegerOverflowFlags::none">:$overflowFlags);
63+
dag iofArg = (ins LLVM_IntegerOverflowFlagsProp:$overflowFlags);
6464
let arguments = !con(commonArgs, iofArg);
6565

6666
string mlirBuilder = [{
@@ -69,7 +69,7 @@ class LLVM_IntArithmeticOpWithOverflowFlag<string mnemonic, string instName,
6969
$res = op;
7070
}];
7171
let assemblyFormat = [{
72-
$lhs `,` $rhs `` custom<OverflowFlags>($overflowFlags) attr-dict `:` type($res)
72+
$lhs `,` $rhs ($overflowFlags^)? attr-dict `:` type($res)
7373
}];
7474
string llvmBuilder =
7575
"$res = builder.Create" # instName #
@@ -563,10 +563,10 @@ class LLVM_CastOpWithOverflowFlag<string mnemonic, string instName, Type type,
563563
Type resultType, list<Trait> traits = []> :
564564
LLVM_Op<mnemonic, !listconcat([Pure], [DeclareOpInterfaceMethods<IntegerOverflowFlagsInterface>], traits)>,
565565
LLVM_Builder<"$res = builder.Create" # instName # "($arg, $_resultType, /*Name=*/\"\", op.hasNoUnsignedWrap(), op.hasNoSignedWrap());"> {
566-
let arguments = (ins type:$arg, EnumProp<"IntegerOverflowFlags", "", "IntegerOverflowFlags::none">:$overflowFlags);
566+
let arguments = (ins type:$arg, LLVM_IntegerOverflowFlagsProp:$overflowFlags);
567567
let results = (outs resultType:$res);
568568
let builders = [LLVM_OneResultOpBuilder];
569-
let assemblyFormat = "$arg `` custom<OverflowFlags>($overflowFlags) attr-dict `:` type($arg) `to` type($res)";
569+
let assemblyFormat = "$arg ($overflowFlags^)? attr-dict `:` type($arg) `to` type($res)";
570570
string llvmInstName = instName;
571571
string mlirBuilder = [{
572572
auto op = $_builder.create<$_qualCppClassName>(

mlir/include/mlir/IR/EnumAttr.td

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#define ENUMATTR_TD
1111

1212
include "mlir/IR/AttrTypeBase.td"
13+
include "mlir/IR/Properties.td"
1314

1415
//===----------------------------------------------------------------------===//
1516
// Enum attribute kinds
@@ -552,6 +553,141 @@ class EnumAttr<Dialect dialect, EnumInfo enumInfo, string name = "",
552553
let assemblyFormat = "$value";
553554
}
554555

556+
// A property wrapping by a C++ enum. This class will automatically create bytecode
557+
// serialization logic for the given enum, as well as arranging for parser and
558+
// printer calls.
559+
class EnumProp<EnumInfo enumInfo> : Property<enumInfo.cppType, enumInfo.summary> {
560+
EnumInfo enum = enumInfo;
561+
562+
let description = enum.description;
563+
let predicate = !if(
564+
!isa<BitEnumBase>(enum),
565+
CPred<"(static_cast<" # enum.underlyingType # ">($_self) & ~" # !cast<BitEnumBase>(enum).validBits # ") == 0">,
566+
Or<!foreach(case, enum.enumerants, CPred<"$_self == " # enum.cppType # "::" # case.symbol>)>);
567+
568+
let convertFromAttribute = [{
569+
auto intAttr = ::mlir::dyn_cast_if_present<::mlir::IntegerAttr>($_attr);
570+
if (!intAttr) {
571+
return $_diag() << "expected IntegerAttr storage for }] #
572+
enum.cppType # [{";
573+
}
574+
$_storage = static_cast<}] # enum.cppType # [{>(intAttr.getValue().getZExtValue());
575+
return ::mlir::success();
576+
}];
577+
578+
let convertToAttribute = [{
579+
return ::mlir::IntegerAttr::get(::mlir::IntegerType::get($_ctxt, }] # enum.bitwidth
580+
# [{), static_cast<}] # enum.underlyingType #[{>($_storage));
581+
}];
582+
583+
let writeToMlirBytecode = [{
584+
$_writer.writeVarInt(static_cast<uint64_t>($_storage));
585+
}];
586+
587+
let readFromMlirBytecode = [{
588+
uint64_t rawValue;
589+
if (::mlir::failed($_reader.readVarInt(rawValue)))
590+
return ::mlir::failure();
591+
if (rawValue > std::numeric_limits<}] # enum.underlyingType # [{>::max())
592+
return ::mlir::failure();
593+
$_storage = static_cast<}] # enum.cppType # [{>(rawValue);
594+
}];
595+
596+
let optionalParser = [{
597+
auto value = ::mlir::FieldParser<std::optional<}] # enum.cppType # [{>>::parse($_parser);
598+
if (::mlir::failed(value))
599+
return ::mlir::failure();
600+
if (!(value->has_value()))
601+
return std::nullopt;
602+
$_storage = std::move(**value);
603+
}];
604+
}
605+
606+
// Enum property that can have been (or, if `storeInCustomAttribute` is true, will also
607+
// be stored as) an attribute, in addition to being stored as an integer attribute.
608+
class EnumPropWithAttrForm<EnumInfo enumInfo, Attr attributeForm>
609+
: EnumProp<enumInfo> {
610+
Attr attrForm = attributeForm;
611+
bit storeInCustomAttribute = 0;
612+
613+
let convertFromAttribute = [{
614+
auto customAttr = ::mlir::dyn_cast_if_present<}]
615+
# attrForm.storageType # [{>($_attr);
616+
if (customAttr) {
617+
$_storage = customAttr.getValue();
618+
return ::mlir::success();
619+
}
620+
auto intAttr = ::mlir::dyn_cast_if_present<::mlir::IntegerAttr>($_attr);
621+
if (!intAttr) {
622+
return $_diag() << "expected }] # attrForm.storageType
623+
# [{ or IntegerAttr storage for }] # enum.cppType # [{";
624+
}
625+
$_storage = static_cast<}] # enum.cppType # [{>(intAttr.getValue().getZExtValue());
626+
return ::mlir::success();
627+
}];
628+
629+
let convertToAttribute = !if(storeInCustomAttribute, [{
630+
return }] # attrForm.storageType # [{::get($_ctxt, $_storage);
631+
}], [{
632+
return ::mlir::IntegerAttr::get(::mlir::IntegerType::get($_ctxt, }] # enumInfo.bitwidth
633+
# [{), static_cast<}] # enum.underlyingType #[{>($_storage));
634+
}]);
635+
}
636+
637+
class _namedEnumPropFields<string cppType, string mnemonic> {
638+
code parser = [{
639+
if ($_parser.parseKeyword("}] # mnemonic # [{")
640+
|| $_parser.parseLess()) {
641+
return ::mlir::failure();
642+
}
643+
auto parseRes = ::mlir::FieldParser<}] # cppType # [{>::parse($_parser);
644+
if (::mlir::failed(parseRes) ||
645+
::mlir::failed($_parser.parseGreater())) {
646+
return ::mlir::failure();
647+
}
648+
$_storage = *parseRes;
649+
}];
650+
651+
code optionalParser = [{
652+
if ($_parser.parseOptionalKeyword("}] # mnemonic # [{")) {
653+
return std::nullopt;
654+
}
655+
if ($_parser.parseLess()) {
656+
return ::mlir::failure();
657+
}
658+
auto parseRes = ::mlir::FieldParser<}] # cppType # [{>::parse($_parser);
659+
if (::mlir::failed(parseRes) ||
660+
::mlir::failed($_parser.parseGreater())) {
661+
return ::mlir::failure();
662+
}
663+
$_storage = *parseRes;
664+
}];
665+
666+
code printer = [{
667+
$_printer << "}] # mnemonic # [{<" << $_storage << ">";
668+
}];
669+
}
670+
671+
// An EnumProp which, when printed, is surrounded by mnemonic<>.
672+
// For example, if the enum can be a, b, or c, and the mnemonic is foo,
673+
// the format of this property will be "foo<a>", "foo<b>", or "foo<c>".
674+
class NamedEnumProp<EnumInfo enumInfo, string name>
675+
: EnumProp<enumInfo> {
676+
string mnemonic = name;
677+
let parser = _namedEnumPropFields<enum.cppType, mnemonic>.parser;
678+
let optionalParser = _namedEnumPropFields<enum.cppType, mnemonic>.optionalParser;
679+
let printer = _namedEnumPropFields<enum.cppType, mnemonic>.printer;
680+
}
681+
682+
// A `NamedEnumProp` with an attribute form as in `EnumPropWithAttrForm`.
683+
class NamedEnumPropWithAttrForm<EnumInfo enumInfo, string name, Attr attributeForm>
684+
: EnumPropWithAttrForm<enumInfo, attributeForm> {
685+
string mnemonic = name;
686+
let parser = _namedEnumPropFields<enum.cppType, mnemonic>.parser;
687+
let optionalParser = _namedEnumPropFields<enum.cppType, mnemonic>.optionalParser;
688+
let printer = _namedEnumPropFields<enumInfo.cppType, mnemonic>.printer;
689+
}
690+
555691
class _symbolToValue<EnumInfo enumInfo, string case> {
556692
defvar cases =
557693
!filter(iter, enumInfo.enumerants, !eq(iter.str, case));

mlir/include/mlir/IR/Properties.td

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -239,25 +239,6 @@ def I64Prop : IntProp<"int64_t">;
239239
def I32Property : IntProp<"int32_t">, Deprecated<"moved to shorter name I32Prop">;
240240
def I64Property : IntProp<"int64_t">, Deprecated<"moved to shorter name I64Prop">;
241241

242-
class EnumProp<string storageTypeParam, string desc = "", string default = ""> :
243-
Property<storageTypeParam, desc> {
244-
// TODO: implement predicate for enum validity.
245-
let writeToMlirBytecode = [{
246-
$_writer.writeVarInt(static_cast<uint64_t>($_storage));
247-
}];
248-
let readFromMlirBytecode = [{
249-
uint64_t val;
250-
if (failed($_reader.readVarInt(val)))
251-
return ::mlir::failure();
252-
$_storage = static_cast<}] # storageTypeParam # [{>(val);
253-
}];
254-
let defaultValue = default;
255-
}
256-
257-
class EnumProperty<string storageTypeParam, string desc = "", string default = ""> :
258-
EnumProp<storageTypeParam, desc, default>,
259-
Deprecated<"moved to shorter name EnumProp">;
260-
261242
// Note: only a class so we can deprecate the old name
262243
class _cls_StringProp : Property<"std::string", "string"> {
263244
let interfaceType = "::llvm::StringRef";

mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp

Lines changed: 0 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -49,71 +49,6 @@ using mlir::LLVM::tailcallkind::getMaxEnumValForTailCallKind;
4949

5050
#include "mlir/Dialect/LLVMIR/LLVMOpsDialect.cpp.inc"
5151

52-
//===----------------------------------------------------------------------===//
53-
// Property Helpers
54-
//===----------------------------------------------------------------------===//
55-
56-
//===----------------------------------------------------------------------===//
57-
// IntegerOverflowFlags
58-
//===----------------------------------------------------------------------===//
59-
60-
namespace mlir {
61-
static Attribute convertToAttribute(MLIRContext *ctx,
62-
IntegerOverflowFlags flags) {
63-
return IntegerOverflowFlagsAttr::get(ctx, flags);
64-
}
65-
66-
static LogicalResult
67-
convertFromAttribute(IntegerOverflowFlags &flags, Attribute attr,
68-
function_ref<InFlightDiagnostic()> emitError) {
69-
auto flagsAttr = dyn_cast<IntegerOverflowFlagsAttr>(attr);
70-
if (!flagsAttr) {
71-
return emitError() << "expected 'overflowFlags' attribute to be an "
72-
"IntegerOverflowFlagsAttr, but got "
73-
<< attr;
74-
}
75-
flags = flagsAttr.getValue();
76-
return success();
77-
}
78-
} // namespace mlir
79-
80-
static ParseResult parseOverflowFlags(AsmParser &p,
81-
IntegerOverflowFlags &flags) {
82-
if (failed(p.parseOptionalKeyword("overflow"))) {
83-
flags = IntegerOverflowFlags::none;
84-
return success();
85-
}
86-
if (p.parseLess())
87-
return failure();
88-
do {
89-
StringRef kw;
90-
SMLoc loc = p.getCurrentLocation();
91-
if (p.parseKeyword(&kw))
92-
return failure();
93-
std::optional<IntegerOverflowFlags> flag =
94-
symbolizeIntegerOverflowFlags(kw);
95-
if (!flag)
96-
return p.emitError(loc,
97-
"invalid overflow flag: expected nsw, nuw, or none");
98-
flags = flags | *flag;
99-
} while (succeeded(p.parseOptionalComma()));
100-
return p.parseGreater();
101-
}
102-
103-
static void printOverflowFlags(AsmPrinter &p, Operation *op,
104-
IntegerOverflowFlags flags) {
105-
if (flags == IntegerOverflowFlags::none)
106-
return;
107-
p << " overflow<";
108-
SmallVector<StringRef, 2> strs;
109-
if (bitEnumContainsAny(flags, IntegerOverflowFlags::nsw))
110-
strs.push_back("nsw");
111-
if (bitEnumContainsAny(flags, IntegerOverflowFlags::nuw))
112-
strs.push_back("nuw");
113-
llvm::interleaveComma(strs, p);
114-
p << ">";
115-
}
116-
11752
//===----------------------------------------------------------------------===//
11853
// Attribute Helpers
11954
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)