Skip to content

Commit a5dfccc

Browse files
authored
[MLIR][TableGen] Change MLIR TableGen to use const Record * (#110687)
This is a part of effort to have better const correctness in TableGen backends: https://discourse.llvm.org/t/psa-planned-changes-to-tablegen-getallderiveddefinitions-api-potential-downstream-breakages/81089
1 parent fef3566 commit a5dfccc

File tree

7 files changed

+33
-26
lines changed

7 files changed

+33
-26
lines changed

mlir/include/mlir/TableGen/Predicate.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ class CombinedPred : public Pred {
104104
const llvm::Record *getCombinerDef() const;
105105

106106
// Get the predicates that are combined by this predicate.
107-
std::vector<llvm::Record *> getChildren() const;
107+
std::vector<const llvm::Record *> getChildren() const;
108108
};
109109

110110
// A combined predicate that requires all child predicates of 'CPred' type to

mlir/lib/TableGen/Predicate.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,10 +79,10 @@ const llvm::Record *CombinedPred::getCombinerDef() const {
7979
return def->getValueAsDef("kind");
8080
}
8181

82-
std::vector<llvm::Record *> CombinedPred::getChildren() const {
82+
std::vector<const llvm::Record *> CombinedPred::getChildren() const {
8383
assert(def->getValue("children") &&
8484
"CombinedPred must have a value 'children'");
85-
return def->getValueAsListOfDefs("children");
85+
return def->getValueAsListOfConstDefs("children");
8686
}
8787

8888
namespace {

mlir/tools/mlir-tblgen/BytecodeDialectGen.cpp

Lines changed: 19 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -32,14 +32,14 @@ class Generator {
3232
Generator(raw_ostream &output) : output(output) {}
3333

3434
/// Returns whether successfully emitted attribute/type parsers.
35-
void emitParse(StringRef kind, Record &x);
35+
void emitParse(StringRef kind, const Record &x);
3636

3737
/// Returns whether successfully emitted attribute/type printers.
3838
void emitPrint(StringRef kind, StringRef type,
39-
ArrayRef<std::pair<int64_t, Record *>> vec);
39+
ArrayRef<std::pair<int64_t, const Record *>> vec);
4040

4141
/// Emits parse dispatch table.
42-
void emitParseDispatch(StringRef kind, ArrayRef<Record *> vec);
42+
void emitParseDispatch(StringRef kind, ArrayRef<const Record *> vec);
4343

4444
/// Emits print dispatch table.
4545
void emitPrintDispatch(StringRef kind, ArrayRef<std::string> vec);
@@ -51,8 +51,9 @@ class Generator {
5151
StringRef failure, mlir::raw_indented_ostream &ios);
5252

5353
/// Emits print instructions.
54-
void emitPrintHelper(Record *memberRec, StringRef kind, StringRef parent,
55-
StringRef name, mlir::raw_indented_ostream &ios);
54+
void emitPrintHelper(const Record *memberRec, StringRef kind,
55+
StringRef parent, StringRef name,
56+
mlir::raw_indented_ostream &ios);
5657

5758
raw_ostream &output;
5859
};
@@ -75,7 +76,7 @@ static std::string capitalize(StringRef str) {
7576
}
7677

7778
/// Return the C++ type for the given record.
78-
static std::string getCType(Record *def) {
79+
static std::string getCType(const Record *def) {
7980
std::string format = "{0}";
8081
if (def->isSubClassOf("Array")) {
8182
def = def->getValueAsDef("elemT");
@@ -92,7 +93,8 @@ static std::string getCType(Record *def) {
9293
return formatv(format.c_str(), cType.str());
9394
}
9495

95-
void Generator::emitParseDispatch(StringRef kind, ArrayRef<Record *> vec) {
96+
void Generator::emitParseDispatch(StringRef kind,
97+
ArrayRef<const Record *> vec) {
9698
mlir::raw_indented_ostream os(output);
9799
char const *head =
98100
R"(static {0} read{0}(MLIRContext* context, DialectBytecodeReader &reader))";
@@ -126,7 +128,7 @@ void Generator::emitParseDispatch(StringRef kind, ArrayRef<Record *> vec) {
126128
os << "return " << capitalize(kind) << "();\n";
127129
}
128130

129-
void Generator::emitParse(StringRef kind, Record &x) {
131+
void Generator::emitParse(StringRef kind, const Record &x) {
130132
if (x.getNameInitAsString() == "ReservedOrDead")
131133
return;
132134

@@ -293,7 +295,7 @@ void Generator::emitParseHelper(StringRef kind, StringRef returnType,
293295
}
294296

295297
void Generator::emitPrint(StringRef kind, StringRef type,
296-
ArrayRef<std::pair<int64_t, Record *>> vec) {
298+
ArrayRef<std::pair<int64_t, const Record *>> vec) {
297299
if (type == "ReservedOrDead")
298300
return;
299301

@@ -304,7 +306,7 @@ void Generator::emitPrint(StringRef kind, StringRef type,
304306
auto funScope = os.scope("{\n", "}\n\n");
305307

306308
// Check that predicates specified if multiple bytecode instances.
307-
for (llvm::Record *rec : make_second_range(vec)) {
309+
for (const llvm::Record *rec : make_second_range(vec)) {
308310
StringRef pred = rec->getValueAsString("printerPredicate");
309311
if (vec.size() > 1 && pred.empty()) {
310312
for (auto [index, rec] : vec) {
@@ -344,7 +346,7 @@ void Generator::emitPrint(StringRef kind, StringRef type,
344346
}
345347
}
346348

347-
void Generator::emitPrintHelper(Record *memberRec, StringRef kind,
349+
void Generator::emitPrintHelper(const Record *memberRec, StringRef kind,
348350
StringRef parent, StringRef name,
349351
mlir::raw_indented_ostream &ios) {
350352
std::string getter;
@@ -423,7 +425,7 @@ void Generator::emitPrintDispatch(StringRef kind, ArrayRef<std::string> vec) {
423425
namespace {
424426
/// Container of Attribute or Type for Dialect.
425427
struct AttrOrType {
426-
std::vector<Record *> attr, type;
428+
std::vector<const Record *> attr, type;
427429
};
428430
} // namespace
429431

@@ -435,14 +437,14 @@ static bool emitBCRW(const RecordKeeper &records, raw_ostream &os) {
435437
it->getValueAsString("dialect") != selectedBcDialect)
436438
continue;
437439
dialectAttrOrType[it->getValueAsString("dialect")].attr =
438-
it->getValueAsListOfDefs("elems");
440+
it->getValueAsListOfConstDefs("elems");
439441
}
440442
for (const Record *it : records.getAllDerivedDefinitions("DialectTypes")) {
441443
if (!selectedBcDialect.empty() &&
442444
it->getValueAsString("dialect") != selectedBcDialect)
443445
continue;
444446
dialectAttrOrType[it->getValueAsString("dialect")].type =
445-
it->getValueAsListOfDefs("elems");
447+
it->getValueAsListOfConstDefs("elems");
446448
}
447449

448450
if (dialectAttrOrType.size() != 1)
@@ -452,15 +454,16 @@ static bool emitBCRW(const RecordKeeper &records, raw_ostream &os) {
452454
auto it = dialectAttrOrType.front();
453455
Generator gen(os);
454456

455-
SmallVector<std::vector<Record *> *, 2> vecs;
457+
SmallVector<std::vector<const Record *> *, 2> vecs;
456458
SmallVector<std::string, 2> kinds;
457459
vecs.push_back(&it.second.attr);
458460
kinds.push_back("attribute");
459461
vecs.push_back(&it.second.type);
460462
kinds.push_back("type");
461463
for (auto [vec, kind] : zip(vecs, kinds)) {
462464
// Handle Attribute/Type emission.
463-
std::map<std::string, std::vector<std::pair<int64_t, Record *>>> perType;
465+
std::map<std::string, std::vector<std::pair<int64_t, const Record *>>>
466+
perType;
464467
for (auto kt : llvm::enumerate(*vec))
465468
perType[getCType(kt.value())].emplace_back(kt.index(), kt.value());
466469
for (const auto &jt : perType) {

mlir/tools/mlir-tblgen/OmpOpGen.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -318,7 +318,7 @@ static void genOperandsDef(const Record *op, raw_ostream &os) {
318318
return;
319319

320320
SmallVector<std::string> clauseNames;
321-
for (Record *clause : op->getValueAsListOfDefs("clauseList"))
321+
for (const Record *clause : op->getValueAsListOfDefs("clauseList"))
322322
clauseNames.push_back((extractOmpClauseName(clause) + "ClauseOps").str());
323323

324324
StringRef opName = stripPrefixAndSuffix(

mlir/tools/mlir-tblgen/OpDocGen.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ static void emitOpTraitsDoc(const Operator &op, raw_ostream &os) {
155155
llvm::raw_string_ostream os(effectStr);
156156
os << effectName << "{";
157157
auto list = trait.getDef().getValueAsListOfDefs("effects");
158-
llvm::interleaveComma(list, os, [&](Record *rec) {
158+
llvm::interleaveComma(list, os, [&](const Record *rec) {
159159
StringRef effect = rec->getValueAsString("effect");
160160
effect.consume_front("::");
161161
effect.consume_front("mlir::");

mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,8 @@ std::vector<Availability> getAvailabilities(const Record &def) {
166166
std::vector<Availability> availabilities;
167167

168168
if (def.getValue("availability")) {
169-
std::vector<Record *> availDefs = def.getValueAsListOfDefs("availability");
169+
std::vector<const Record *> availDefs =
170+
def.getValueAsListOfConstDefs("availability");
170171
availabilities.reserve(availDefs.size());
171172
for (const Record *avail : availDefs)
172173
availabilities.emplace_back(avail);
@@ -1449,7 +1450,8 @@ static bool emitCapabilityImplication(const RecordKeeper &recordKeeper,
14491450
if (!def.getValue("implies"))
14501451
continue;
14511452

1452-
std::vector<Record *> impliedCapsDefs = def.getValueAsListOfDefs("implies");
1453+
std::vector<const Record *> impliedCapsDefs =
1454+
def.getValueAsListOfConstDefs("implies");
14531455
os << " case spirv::Capability::" << enumerant.getSymbol()
14541456
<< ": {static const spirv::Capability implies[" << impliedCapsDefs.size()
14551457
<< "] = {";

mlir/tools/tblgen-to-irdl/OpDefinitionsGen.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -249,7 +249,7 @@ Value createTypeConstraint(OpBuilder &builder, tblgen::Constraint constraint) {
249249
std::vector<Value> constraints;
250250
constraints.push_back(createTypeConstraint(
251251
builder, tblgen::Constraint(predRec.getValueAsDef("baseType"))));
252-
for (Record *child : predRec.getValueAsListOfDefs("predicateList")) {
252+
for (const Record *child : predRec.getValueAsListOfDefs("predicateList")) {
253253
constraints.push_back(createPredicate(builder, tblgen::Pred(child)));
254254
}
255255
auto op = builder.create<irdl::AllOfOp>(UnknownLoc::get(ctx), constraints);
@@ -273,7 +273,8 @@ Value createAttrConstraint(OpBuilder &builder, tblgen::Constraint constraint) {
273273
std::vector<Value> constraints;
274274
constraints.push_back(createAttrConstraint(
275275
builder, tblgen::Constraint(predRec.getValueAsDef("baseAttr"))));
276-
for (Record *child : predRec.getValueAsListOfDefs("attrConstraints")) {
276+
for (const Record *child :
277+
predRec.getValueAsListOfDefs("attrConstraints")) {
277278
constraints.push_back(createPredicate(
278279
builder, tblgen::Pred(child->getValueAsDef("predicate"))));
279280
}
@@ -283,7 +284,8 @@ Value createAttrConstraint(OpBuilder &builder, tblgen::Constraint constraint) {
283284

284285
if (predRec.isSubClassOf("AnyAttrOf")) {
285286
std::vector<Value> constraints;
286-
for (Record *child : predRec.getValueAsListOfDefs("allowedAttributes")) {
287+
for (const Record *child :
288+
predRec.getValueAsListOfDefs("allowedAttributes")) {
287289
constraints.push_back(
288290
createAttrConstraint(builder, tblgen::Constraint(child)));
289291
}

0 commit comments

Comments
 (0)