Skip to content

Commit bccd37f

Browse files
authored
[NFC][MLIR][TableGen] Eliminate llvm:: for commonly used types (#110841)
Eliminate `llvm::` namespace qualifier for commonly used types in MLIR TableGen backends to reduce code clutter.
1 parent 906ffc4 commit bccd37f

17 files changed

+448
-473
lines changed

mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp

Lines changed: 21 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222

2323
using namespace mlir;
2424
using namespace mlir::tblgen;
25+
using llvm::Record;
26+
using llvm::RecordKeeper;
2527

2628
//===----------------------------------------------------------------------===//
2729
// Utility Functions
@@ -30,14 +32,14 @@ using namespace mlir::tblgen;
3032
/// Find all the AttrOrTypeDef for the specified dialect. If no dialect
3133
/// specified and can only find one dialect's defs, use that.
3234
static void collectAllDefs(StringRef selectedDialect,
33-
ArrayRef<const llvm::Record *> records,
35+
ArrayRef<const Record *> records,
3436
SmallVectorImpl<AttrOrTypeDef> &resultDefs) {
3537
// Nothing to do if no defs were found.
3638
if (records.empty())
3739
return;
3840

3941
auto defs = llvm::map_range(
40-
records, [&](const llvm::Record *rec) { return AttrOrTypeDef(rec); });
42+
records, [&](const Record *rec) { return AttrOrTypeDef(rec); });
4143
if (selectedDialect.empty()) {
4244
// If a dialect was not specified, ensure that all found defs belong to the
4345
// same dialect.
@@ -690,15 +692,14 @@ class DefGenerator {
690692
bool emitDefs(StringRef selectedDialect);
691693

692694
protected:
693-
DefGenerator(ArrayRef<const llvm::Record *> defs, raw_ostream &os,
695+
DefGenerator(ArrayRef<const Record *> defs, raw_ostream &os,
694696
StringRef defType, StringRef valueType, bool isAttrGenerator)
695697
: defRecords(defs), os(os), defType(defType), valueType(valueType),
696698
isAttrGenerator(isAttrGenerator) {
697699
// Sort by occurrence in file.
698-
llvm::sort(defRecords,
699-
[](const llvm::Record *lhs, const llvm::Record *rhs) {
700-
return lhs->getID() < rhs->getID();
701-
});
700+
llvm::sort(defRecords, [](const Record *lhs, const Record *rhs) {
701+
return lhs->getID() < rhs->getID();
702+
});
702703
}
703704

704705
/// Emit the list of def type names.
@@ -707,7 +708,7 @@ class DefGenerator {
707708
void emitParsePrintDispatch(ArrayRef<AttrOrTypeDef> defs);
708709

709710
/// The set of def records to emit.
710-
std::vector<const llvm::Record *> defRecords;
711+
std::vector<const Record *> defRecords;
711712
/// The attribute or type class to emit.
712713
/// The stream to emit to.
713714
raw_ostream &os;
@@ -722,13 +723,13 @@ class DefGenerator {
722723

723724
/// A specialized generator for AttrDefs.
724725
struct AttrDefGenerator : public DefGenerator {
725-
AttrDefGenerator(const llvm::RecordKeeper &records, raw_ostream &os)
726+
AttrDefGenerator(const RecordKeeper &records, raw_ostream &os)
726727
: DefGenerator(records.getAllDerivedDefinitionsIfDefined("AttrDef"), os,
727728
"Attr", "Attribute", /*isAttrGenerator=*/true) {}
728729
};
729730
/// A specialized generator for TypeDefs.
730731
struct TypeDefGenerator : public DefGenerator {
731-
TypeDefGenerator(const llvm::RecordKeeper &records, raw_ostream &os)
732+
TypeDefGenerator(const RecordKeeper &records, raw_ostream &os)
732733
: DefGenerator(records.getAllDerivedDefinitionsIfDefined("TypeDef"), os,
733734
"Type", "Type", /*isAttrGenerator=*/false) {}
734735
};
@@ -1030,9 +1031,9 @@ bool DefGenerator::emitDefs(StringRef selectedDialect) {
10301031

10311032
/// Find all type constraints for which a C++ function should be generated.
10321033
static std::vector<Constraint>
1033-
getAllTypeConstraints(const llvm::RecordKeeper &records) {
1034+
getAllTypeConstraints(const RecordKeeper &records) {
10341035
std::vector<Constraint> result;
1035-
for (const llvm::Record *def :
1036+
for (const Record *def :
10361037
records.getAllDerivedDefinitionsIfDefined("TypeConstraint")) {
10371038
// Ignore constraints defined outside of the top-level file.
10381039
if (llvm::SrcMgr.FindBufferContainingLoc(def->getLoc()[0]) !=
@@ -1047,7 +1048,7 @@ getAllTypeConstraints(const llvm::RecordKeeper &records) {
10471048
return result;
10481049
}
10491050

1050-
static void emitTypeConstraintDecls(const llvm::RecordKeeper &records,
1051+
static void emitTypeConstraintDecls(const RecordKeeper &records,
10511052
raw_ostream &os) {
10521053
static const char *const typeConstraintDecl = R"(
10531054
bool {0}(::mlir::Type type);
@@ -1057,7 +1058,7 @@ bool {0}(::mlir::Type type);
10571058
os << strfmt(typeConstraintDecl, *constr.getCppFunctionName());
10581059
}
10591060

1060-
static void emitTypeConstraintDefs(const llvm::RecordKeeper &records,
1061+
static void emitTypeConstraintDefs(const RecordKeeper &records,
10611062
raw_ostream &os) {
10621063
static const char *const typeConstraintDef = R"(
10631064
bool {0}(::mlir::Type type) {
@@ -1088,13 +1089,13 @@ static llvm::cl::opt<std::string>
10881089

10891090
static mlir::GenRegistration
10901091
genAttrDefs("gen-attrdef-defs", "Generate AttrDef definitions",
1091-
[](const llvm::RecordKeeper &records, raw_ostream &os) {
1092+
[](const RecordKeeper &records, raw_ostream &os) {
10921093
AttrDefGenerator generator(records, os);
10931094
return generator.emitDefs(attrDialect);
10941095
});
10951096
static mlir::GenRegistration
10961097
genAttrDecls("gen-attrdef-decls", "Generate AttrDef declarations",
1097-
[](const llvm::RecordKeeper &records, raw_ostream &os) {
1098+
[](const RecordKeeper &records, raw_ostream &os) {
10981099
AttrDefGenerator generator(records, os);
10991100
return generator.emitDecls(attrDialect);
11001101
});
@@ -1110,28 +1111,28 @@ static llvm::cl::opt<std::string>
11101111

11111112
static mlir::GenRegistration
11121113
genTypeDefs("gen-typedef-defs", "Generate TypeDef definitions",
1113-
[](const llvm::RecordKeeper &records, raw_ostream &os) {
1114+
[](const RecordKeeper &records, raw_ostream &os) {
11141115
TypeDefGenerator generator(records, os);
11151116
return generator.emitDefs(typeDialect);
11161117
});
11171118
static mlir::GenRegistration
11181119
genTypeDecls("gen-typedef-decls", "Generate TypeDef declarations",
1119-
[](const llvm::RecordKeeper &records, raw_ostream &os) {
1120+
[](const RecordKeeper &records, raw_ostream &os) {
11201121
TypeDefGenerator generator(records, os);
11211122
return generator.emitDecls(typeDialect);
11221123
});
11231124

11241125
static mlir::GenRegistration
11251126
genTypeConstrDefs("gen-type-constraint-defs",
11261127
"Generate type constraint definitions",
1127-
[](const llvm::RecordKeeper &records, raw_ostream &os) {
1128+
[](const RecordKeeper &records, raw_ostream &os) {
11281129
emitTypeConstraintDefs(records, os);
11291130
return false;
11301131
});
11311132
static mlir::GenRegistration
11321133
genTypeConstrDecls("gen-type-constraint-decls",
11331134
"Generate type constraint declarations",
1134-
[](const llvm::RecordKeeper &records, raw_ostream &os) {
1135+
[](const RecordKeeper &records, raw_ostream &os) {
11351136
emitTypeConstraintDecls(records, os);
11361137
return false;
11371138
});

mlir/tools/mlir-tblgen/BytecodeDialectGen.cpp

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,10 @@
1818

1919
using namespace llvm;
2020

21-
static llvm::cl::OptionCategory dialectGenCat("Options for -gen-bytecode");
22-
static llvm::cl::opt<std::string>
23-
selectedBcDialect("bytecode-dialect",
24-
llvm::cl::desc("The dialect to gen for"),
25-
llvm::cl::cat(dialectGenCat), llvm::cl::CommaSeparated);
21+
static cl::OptionCategory dialectGenCat("Options for -gen-bytecode");
22+
static cl::opt<std::string>
23+
selectedBcDialect("bytecode-dialect", cl::desc("The dialect to gen for"),
24+
cl::cat(dialectGenCat), cl::CommaSeparated);
2625

2726
namespace {
2827

@@ -306,7 +305,7 @@ void Generator::emitPrint(StringRef kind, StringRef type,
306305
auto funScope = os.scope("{\n", "}\n\n");
307306

308307
// Check that predicates specified if multiple bytecode instances.
309-
for (const llvm::Record *rec : make_second_range(vec)) {
308+
for (const Record *rec : make_second_range(vec)) {
310309
StringRef pred = rec->getValueAsString("printerPredicate");
311310
if (vec.size() > 1 && pred.empty()) {
312311
for (auto [index, rec] : vec) {

mlir/tools/mlir-tblgen/DialectGen.cpp

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@
3030

3131
using namespace mlir;
3232
using namespace mlir::tblgen;
33+
using llvm::Record;
34+
using llvm::RecordKeeper;
3335

3436
static llvm::cl::OptionCategory dialectGenCat("Options for -gen-dialect-*");
3537
llvm::cl::opt<std::string>
@@ -39,8 +41,8 @@ llvm::cl::opt<std::string>
3941
/// Utility iterator used for filtering records for a specific dialect.
4042
namespace {
4143
using DialectFilterIterator =
42-
llvm::filter_iterator<ArrayRef<llvm::Record *>::iterator,
43-
std::function<bool(const llvm::Record *)>>;
44+
llvm::filter_iterator<ArrayRef<Record *>::iterator,
45+
std::function<bool(const Record *)>>;
4446
} // namespace
4547

4648
static void populateDiscardableAttributes(
@@ -62,8 +64,8 @@ static void populateDiscardableAttributes(
6264
/// the given dialect.
6365
template <typename T>
6466
static iterator_range<DialectFilterIterator>
65-
filterForDialect(ArrayRef<llvm::Record *> records, Dialect &dialect) {
66-
auto filterFn = [&](const llvm::Record *record) {
67+
filterForDialect(ArrayRef<Record *> records, Dialect &dialect) {
68+
auto filterFn = [&](const Record *record) {
6769
return T(record).getDialect() == dialect;
6870
};
6971
return {DialectFilterIterator(records.begin(), records.end(), filterFn),
@@ -295,7 +297,7 @@ static void emitDialectDecl(Dialect &dialect, raw_ostream &os) {
295297
<< "::" << dialect.getCppClassName() << ")\n";
296298
}
297299

298-
static bool emitDialectDecls(const llvm::RecordKeeper &recordKeeper,
300+
static bool emitDialectDecls(const RecordKeeper &recordKeeper,
299301
raw_ostream &os) {
300302
emitSourceFileHeader("Dialect Declarations", os, recordKeeper);
301303

@@ -340,8 +342,7 @@ static const char *const dialectDestructorStr = R"(
340342
341343
)";
342344

343-
static void emitDialectDef(Dialect &dialect,
344-
const llvm::RecordKeeper &recordKeeper,
345+
static void emitDialectDef(Dialect &dialect, const RecordKeeper &recordKeeper,
345346
raw_ostream &os) {
346347
std::string cppClassName = dialect.getCppClassName();
347348

@@ -389,8 +390,7 @@ static void emitDialectDef(Dialect &dialect,
389390
os << llvm::formatv(dialectDestructorStr, cppClassName);
390391
}
391392

392-
static bool emitDialectDefs(const llvm::RecordKeeper &recordKeeper,
393-
raw_ostream &os) {
393+
static bool emitDialectDefs(const RecordKeeper &recordKeeper, raw_ostream &os) {
394394
emitSourceFileHeader("Dialect Definitions", os, recordKeeper);
395395

396396
auto dialectDefs = recordKeeper.getAllDerivedDefinitions("Dialect");
@@ -411,12 +411,12 @@ static bool emitDialectDefs(const llvm::RecordKeeper &recordKeeper,
411411

412412
static mlir::GenRegistration
413413
genDialectDecls("gen-dialect-decls", "Generate dialect declarations",
414-
[](const llvm::RecordKeeper &records, raw_ostream &os) {
414+
[](const RecordKeeper &records, raw_ostream &os) {
415415
return emitDialectDecls(records, os);
416416
});
417417

418418
static mlir::GenRegistration
419419
genDialectDefs("gen-dialect-defs", "Generate dialect definitions",
420-
[](const llvm::RecordKeeper &records, raw_ostream &os) {
420+
[](const RecordKeeper &records, raw_ostream &os) {
421421
return emitDialectDefs(records, os);
422422
});

mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp

Lines changed: 41 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@
2121

2222
using namespace mlir;
2323
using namespace mlir::tblgen;
24+
using llvm::formatv;
25+
using llvm::Record;
26+
using llvm::RecordKeeper;
2427

2528
/// File header and includes.
2629
constexpr const char *fileHeader = R"Py(
@@ -42,44 +45,42 @@ static std::string makePythonEnumCaseName(StringRef name) {
4245

4346
/// Emits the Python class for the given enum.
4447
static void emitEnumClass(EnumAttr enumAttr, raw_ostream &os) {
45-
os << llvm::formatv("class {0}({1}):\n", enumAttr.getEnumClassName(),
46-
enumAttr.isBitEnum() ? "IntFlag" : "IntEnum");
48+
os << formatv("class {0}({1}):\n", enumAttr.getEnumClassName(),
49+
enumAttr.isBitEnum() ? "IntFlag" : "IntEnum");
4750
if (!enumAttr.getSummary().empty())
48-
os << llvm::formatv(" \"\"\"{0}\"\"\"\n", enumAttr.getSummary());
51+
os << formatv(" \"\"\"{0}\"\"\"\n", enumAttr.getSummary());
4952
os << "\n";
5053

5154
for (const EnumAttrCase &enumCase : enumAttr.getAllCases()) {
52-
os << llvm::formatv(
53-
" {0} = {1}\n", makePythonEnumCaseName(enumCase.getSymbol()),
54-
enumCase.getValue() >= 0 ? std::to_string(enumCase.getValue())
55-
: "auto()");
55+
os << formatv(" {0} = {1}\n",
56+
makePythonEnumCaseName(enumCase.getSymbol()),
57+
enumCase.getValue() >= 0 ? std::to_string(enumCase.getValue())
58+
: "auto()");
5659
}
5760

5861
os << "\n";
5962

6063
if (enumAttr.isBitEnum()) {
61-
os << llvm::formatv(" def __iter__(self):\n"
62-
" return iter([case for case in type(self) if "
63-
"(self & case) is case])\n");
64-
os << llvm::formatv(" def __len__(self):\n"
65-
" return bin(self).count(\"1\")\n");
64+
os << formatv(" def __iter__(self):\n"
65+
" return iter([case for case in type(self) if "
66+
"(self & case) is case])\n");
67+
os << formatv(" def __len__(self):\n"
68+
" return bin(self).count(\"1\")\n");
6669
os << "\n";
6770
}
6871

69-
os << llvm::formatv(" def __str__(self):\n");
72+
os << formatv(" def __str__(self):\n");
7073
if (enumAttr.isBitEnum())
71-
os << llvm::formatv(" if len(self) > 1:\n"
72-
" return \"{0}\".join(map(str, self))\n",
73-
enumAttr.getDef().getValueAsString("separator"));
74+
os << formatv(" if len(self) > 1:\n"
75+
" return \"{0}\".join(map(str, self))\n",
76+
enumAttr.getDef().getValueAsString("separator"));
7477
for (const EnumAttrCase &enumCase : enumAttr.getAllCases()) {
75-
os << llvm::formatv(" if self is {0}.{1}:\n",
76-
enumAttr.getEnumClassName(),
77-
makePythonEnumCaseName(enumCase.getSymbol()));
78-
os << llvm::formatv(" return \"{0}\"\n", enumCase.getStr());
78+
os << formatv(" if self is {0}.{1}:\n", enumAttr.getEnumClassName(),
79+
makePythonEnumCaseName(enumCase.getSymbol()));
80+
os << formatv(" return \"{0}\"\n", enumCase.getStr());
7981
}
80-
os << llvm::formatv(
81-
" raise ValueError(\"Unknown {0} enum entry.\")\n\n\n",
82-
enumAttr.getEnumClassName());
82+
os << formatv(" raise ValueError(\"Unknown {0} enum entry.\")\n\n\n",
83+
enumAttr.getEnumClassName());
8384
os << "\n";
8485
}
8586

@@ -105,15 +106,13 @@ static bool emitAttributeBuilder(const EnumAttr &enumAttr, raw_ostream &os) {
105106
return true;
106107
}
107108

108-
os << llvm::formatv("@register_attribute_builder(\"{0}\")\n",
109-
enumAttr.getAttrDefName());
110-
os << llvm::formatv("def _{0}(x, context):\n",
111-
enumAttr.getAttrDefName().lower());
112-
os << llvm::formatv(
113-
" return "
114-
"_ods_ir.IntegerAttr.get(_ods_ir.IntegerType.get_signless({0}, "
115-
"context=context), int(x))\n\n",
116-
bitwidth);
109+
os << formatv("@register_attribute_builder(\"{0}\")\n",
110+
enumAttr.getAttrDefName());
111+
os << formatv("def _{0}(x, context):\n", enumAttr.getAttrDefName().lower());
112+
os << formatv(" return "
113+
"_ods_ir.IntegerAttr.get(_ods_ir.IntegerType.get_signless({0}, "
114+
"context=context), int(x))\n\n",
115+
bitwidth);
117116
return false;
118117
}
119118

@@ -123,26 +122,25 @@ static bool emitAttributeBuilder(const EnumAttr &enumAttr, raw_ostream &os) {
123122
static bool emitDialectEnumAttributeBuilder(StringRef attrDefName,
124123
StringRef formatString,
125124
raw_ostream &os) {
126-
os << llvm::formatv("@register_attribute_builder(\"{0}\")\n", attrDefName);
127-
os << llvm::formatv("def _{0}(x, context):\n", attrDefName.lower());
128-
os << llvm::formatv(" return "
129-
"_ods_ir.Attribute.parse(f'{0}', context=context)\n\n",
130-
formatString);
125+
os << formatv("@register_attribute_builder(\"{0}\")\n", attrDefName);
126+
os << formatv("def _{0}(x, context):\n", attrDefName.lower());
127+
os << formatv(" return "
128+
"_ods_ir.Attribute.parse(f'{0}', context=context)\n\n",
129+
formatString);
131130
return false;
132131
}
133132

134133
/// Emits Python bindings for all enums in the record keeper. Returns
135134
/// `false` on success, `true` on failure.
136-
static bool emitPythonEnums(const llvm::RecordKeeper &recordKeeper,
137-
raw_ostream &os) {
135+
static bool emitPythonEnums(const RecordKeeper &recordKeeper, raw_ostream &os) {
138136
os << fileHeader;
139-
for (const llvm::Record *it :
137+
for (const Record *it :
140138
recordKeeper.getAllDerivedDefinitionsIfDefined("EnumAttrInfo")) {
141139
EnumAttr enumAttr(*it);
142140
emitEnumClass(enumAttr, os);
143141
emitAttributeBuilder(enumAttr, os);
144142
}
145-
for (const llvm::Record *it :
143+
for (const Record *it :
146144
recordKeeper.getAllDerivedDefinitionsIfDefined("EnumAttr")) {
147145
AttrOrTypeDef attr(&*it);
148146
if (!attr.getMnemonic()) {
@@ -156,11 +154,11 @@ static bool emitPythonEnums(const llvm::RecordKeeper &recordKeeper,
156154
if (assemblyFormat == "`<` $value `>`") {
157155
emitDialectEnumAttributeBuilder(
158156
attr.getName(),
159-
llvm::formatv("#{0}.{1}<{{str(x)}>", dialect, mnemonic).str(), os);
157+
formatv("#{0}.{1}<{{str(x)}>", dialect, mnemonic).str(), os);
160158
} else if (assemblyFormat == "$value") {
161159
emitDialectEnumAttributeBuilder(
162160
attr.getName(),
163-
llvm::formatv("#{0}<{1} {{str(x)}>", dialect, mnemonic).str(), os);
161+
formatv("#{0}<{1} {{str(x)}>", dialect, mnemonic).str(), os);
164162
} else {
165163
llvm::errs()
166164
<< "unsupported assembly format for python enum bindings generation";

0 commit comments

Comments
 (0)