Skip to content

Commit e768b07

Browse files
authored
[MLIR][TableGen] Use const pointers for various Init objects (#112562)
This reverts commit 0eed305 and applies additional fixes in `verifyArgument` in OmpOpGen.cpp for gcc-7 bot failures
1 parent 875afa9 commit e768b07

File tree

14 files changed

+78
-68
lines changed

14 files changed

+78
-68
lines changed

mlir/include/mlir/TableGen/AttrOrTypeDef.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ class AttrOrTypeParameter {
105105
std::optional<StringRef> getDefaultValue() const;
106106

107107
/// Return the underlying def of this parameter.
108-
llvm::Init *getDef() const;
108+
const llvm::Init *getDef() const;
109109

110110
/// The parameter is pointer-comparable.
111111
bool operator==(const AttrOrTypeParameter &other) const {

mlir/include/mlir/TableGen/Dialect.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ class Dialect {
9292
/// dialect.
9393
bool usePropertiesForAttributes() const;
9494

95-
llvm::DagInit *getDiscardableAttributes() const;
95+
const llvm::DagInit *getDiscardableAttributes() const;
9696

9797
const llvm::Record *getDef() const { return def; }
9898

mlir/include/mlir/TableGen/Operator.h

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -119,14 +119,15 @@ class Operator {
119119

120120
/// A utility iterator over a list of variable decorators.
121121
struct VariableDecoratorIterator
122-
: public llvm::mapped_iterator<llvm::Init *const *,
123-
VariableDecorator (*)(llvm::Init *)> {
122+
: public llvm::mapped_iterator<const llvm::Init *const *,
123+
VariableDecorator (*)(
124+
const llvm::Init *)> {
124125
/// Initializes the iterator to the specified iterator.
125-
VariableDecoratorIterator(llvm::Init *const *it)
126-
: llvm::mapped_iterator<llvm::Init *const *,
127-
VariableDecorator (*)(llvm::Init *)>(it,
128-
&unwrap) {}
129-
static VariableDecorator unwrap(llvm::Init *init);
126+
VariableDecoratorIterator(const llvm::Init *const *it)
127+
: llvm::mapped_iterator<const llvm::Init *const *,
128+
VariableDecorator (*)(const llvm::Init *)>(
129+
it, &unwrap) {}
130+
static VariableDecorator unwrap(const llvm::Init *init);
130131
};
131132
using var_decorator_iterator = VariableDecoratorIterator;
132133
using var_decorator_range = llvm::iterator_range<VariableDecoratorIterator>;

mlir/lib/TableGen/AttrOrTypeDef.cpp

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ AttrOrTypeDef::AttrOrTypeDef(const llvm::Record *def) : def(def) {
4040
auto *builderList =
4141
dyn_cast_or_null<llvm::ListInit>(def->getValueInit("builders"));
4242
if (builderList && !builderList->empty()) {
43-
for (llvm::Init *init : builderList->getValues()) {
43+
for (const llvm::Init *init : builderList->getValues()) {
4444
AttrOrTypeBuilder builder(cast<llvm::DefInit>(init)->getDef(),
4545
def->getLoc());
4646

@@ -58,8 +58,8 @@ AttrOrTypeDef::AttrOrTypeDef(const llvm::Record *def) : def(def) {
5858
if (auto *traitList = def->getValueAsListInit("traits")) {
5959
SmallPtrSet<const llvm::Init *, 32> traitSet;
6060
traits.reserve(traitSet.size());
61-
llvm::unique_function<void(llvm::ListInit *)> processTraitList =
62-
[&](llvm::ListInit *traitList) {
61+
llvm::unique_function<void(const llvm::ListInit *)> processTraitList =
62+
[&](const llvm::ListInit *traitList) {
6363
for (auto *traitInit : *traitList) {
6464
if (!traitSet.insert(traitInit).second)
6565
continue;
@@ -335,7 +335,9 @@ std::optional<StringRef> AttrOrTypeParameter::getDefaultValue() const {
335335
return result && !result->empty() ? result : std::nullopt;
336336
}
337337

338-
llvm::Init *AttrOrTypeParameter::getDef() const { return def->getArg(index); }
338+
const llvm::Init *AttrOrTypeParameter::getDef() const {
339+
return def->getArg(index);
340+
}
339341

340342
std::optional<Constraint> AttrOrTypeParameter::getConstraint() const {
341343
if (auto *param = dyn_cast<llvm::DefInit>(getDef()))
@@ -349,7 +351,7 @@ std::optional<Constraint> AttrOrTypeParameter::getConstraint() const {
349351
//===----------------------------------------------------------------------===//
350352

351353
bool AttributeSelfTypeParameter::classof(const AttrOrTypeParameter *param) {
352-
llvm::Init *paramDef = param->getDef();
354+
const llvm::Init *paramDef = param->getDef();
353355
if (auto *paramDefInit = dyn_cast<llvm::DefInit>(paramDef))
354356
return paramDefInit->getDef()->isSubClassOf("AttributeSelfTypeParameter");
355357
return false;

mlir/lib/TableGen/Attribute.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ StringRef Attribute::getDerivedCodeBody() const {
126126
Dialect Attribute::getDialect() const {
127127
const llvm::RecordVal *record = def->getValue("dialect");
128128
if (record && record->getValue()) {
129-
if (DefInit *init = dyn_cast<DefInit>(record->getValue()))
129+
if (const DefInit *init = dyn_cast<DefInit>(record->getValue()))
130130
return Dialect(init->getDef());
131131
}
132132
return Dialect(nullptr);

mlir/lib/TableGen/Dialect.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ bool Dialect::usePropertiesForAttributes() const {
106106
return def->getValueAsBit("usePropertiesForAttributes");
107107
}
108108

109-
llvm::DagInit *Dialect::getDiscardableAttributes() const {
109+
const llvm::DagInit *Dialect::getDiscardableAttributes() const {
110110
return def->getValueAsDag("discardableAttrs");
111111
}
112112

mlir/lib/TableGen/Interfaces.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ using namespace mlir::tblgen;
2222
//===----------------------------------------------------------------------===//
2323

2424
InterfaceMethod::InterfaceMethod(const llvm::Record *def) : def(def) {
25-
llvm::DagInit *args = def->getValueAsDag("arguments");
25+
const llvm::DagInit *args = def->getValueAsDag("arguments");
2626
for (unsigned i = 0, e = args->getNumArgs(); i != e; ++i) {
2727
arguments.push_back(
2828
{llvm::cast<llvm::StringInit>(args->getArg(i))->getValue(),
@@ -78,7 +78,7 @@ Interface::Interface(const llvm::Record *def) : def(def) {
7878

7979
// Initialize the interface methods.
8080
auto *listInit = dyn_cast<llvm::ListInit>(def->getValueInit("methods"));
81-
for (llvm::Init *init : listInit->getValues())
81+
for (const llvm::Init *init : listInit->getValues())
8282
methods.emplace_back(cast<llvm::DefInit>(init)->getDef());
8383

8484
// Initialize the interface base classes.
@@ -98,7 +98,7 @@ Interface::Interface(const llvm::Record *def) : def(def) {
9898
baseInterfaces.push_back(std::make_unique<Interface>(baseInterface));
9999
basesAdded.insert(baseInterface.getName());
100100
};
101-
for (llvm::Init *init : basesInit->getValues())
101+
for (const llvm::Init *init : basesInit->getValues())
102102
addBaseInterfaceFn(Interface(cast<llvm::DefInit>(init)->getDef()));
103103
}
104104

mlir/lib/TableGen/Operator.cpp

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@ std::string Operator::getQualCppClassName() const {
161161
StringRef Operator::getCppNamespace() const { return cppNamespace; }
162162

163163
int Operator::getNumResults() const {
164-
DagInit *results = def.getValueAsDag("results");
164+
const DagInit *results = def.getValueAsDag("results");
165165
return results->getNumArgs();
166166
}
167167

@@ -198,12 +198,12 @@ auto Operator::getResults() const -> const_value_range {
198198
}
199199

200200
TypeConstraint Operator::getResultTypeConstraint(int index) const {
201-
DagInit *results = def.getValueAsDag("results");
201+
const DagInit *results = def.getValueAsDag("results");
202202
return TypeConstraint(cast<DefInit>(results->getArg(index)));
203203
}
204204

205205
StringRef Operator::getResultName(int index) const {
206-
DagInit *results = def.getValueAsDag("results");
206+
const DagInit *results = def.getValueAsDag("results");
207207
return results->getArgNameStr(index);
208208
}
209209

@@ -241,7 +241,7 @@ Operator::arg_range Operator::getArgs() const {
241241
}
242242

243243
StringRef Operator::getArgName(int index) const {
244-
DagInit *argumentValues = def.getValueAsDag("arguments");
244+
const DagInit *argumentValues = def.getValueAsDag("arguments");
245245
return argumentValues->getArgNameStr(index);
246246
}
247247

@@ -557,7 +557,7 @@ void Operator::populateOpStructure() {
557557
auto *opVarClass = recordKeeper.getClass("OpVariable");
558558
numNativeAttributes = 0;
559559

560-
DagInit *argumentValues = def.getValueAsDag("arguments");
560+
const DagInit *argumentValues = def.getValueAsDag("arguments");
561561
unsigned numArgs = argumentValues->getNumArgs();
562562

563563
// Mapping from name of to argument or result index. Arguments are indexed
@@ -721,8 +721,8 @@ void Operator::populateOpStructure() {
721721
" to precede it in traits list");
722722
};
723723

724-
std::function<void(llvm::ListInit *)> insert;
725-
insert = [&](llvm::ListInit *traitList) {
724+
std::function<void(const llvm::ListInit *)> insert;
725+
insert = [&](const llvm::ListInit *traitList) {
726726
for (auto *traitInit : *traitList) {
727727
auto *def = cast<DefInit>(traitInit)->getDef();
728728
if (def->isSubClassOf("TraitList")) {
@@ -780,7 +780,7 @@ void Operator::populateOpStructure() {
780780
auto *builderList =
781781
dyn_cast_or_null<llvm::ListInit>(def.getValueInit("builders"));
782782
if (builderList && !builderList->empty()) {
783-
for (llvm::Init *init : builderList->getValues())
783+
for (const llvm::Init *init : builderList->getValues())
784784
builders.emplace_back(cast<llvm::DefInit>(init)->getDef(), def.getLoc());
785785
} else if (skipDefaultBuilders()) {
786786
PrintFatalError(
@@ -818,7 +818,8 @@ bool Operator::hasAssemblyFormat() const {
818818
}
819819

820820
StringRef Operator::getAssemblyFormat() const {
821-
return TypeSwitch<llvm::Init *, StringRef>(def.getValueInit("assemblyFormat"))
821+
return TypeSwitch<const llvm::Init *, StringRef>(
822+
def.getValueInit("assemblyFormat"))
822823
.Case<llvm::StringInit>([&](auto *init) { return init->getValue(); });
823824
}
824825

@@ -832,7 +833,7 @@ void Operator::print(llvm::raw_ostream &os) const {
832833
}
833834
}
834835

835-
auto Operator::VariableDecoratorIterator::unwrap(llvm::Init *init)
836+
auto Operator::VariableDecoratorIterator::unwrap(const llvm::Init *init)
836837
-> VariableDecorator {
837838
return VariableDecorator(cast<llvm::DefInit>(init)->getDef());
838839
}

mlir/lib/TableGen/Pattern.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -700,7 +700,7 @@ int Pattern::getBenefit() const {
700700
// The initial benefit value is a heuristic with number of ops in the source
701701
// pattern.
702702
int initBenefit = getSourcePattern().getNumOps();
703-
llvm::DagInit *delta = def.getValueAsDag("benefitDelta");
703+
const llvm::DagInit *delta = def.getValueAsDag("benefitDelta");
704704
if (delta->getNumArgs() != 1 || !isa<llvm::IntInit>(delta->getArg(0))) {
705705
PrintFatalError(&def,
706706
"The 'addBenefit' takes and only takes one integer value");

mlir/lib/TableGen/Type.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ std::optional<StringRef> TypeConstraint::getBuilderCall() const {
5050
const llvm::RecordVal *builderCall = baseType->getValue("builderCall");
5151
if (!builderCall || !builderCall->getValue())
5252
return std::nullopt;
53-
return TypeSwitch<llvm::Init *, std::optional<StringRef>>(
53+
return TypeSwitch<const llvm::Init *, std::optional<StringRef>>(
5454
builderCall->getValue())
5555
.Case<llvm::StringInit>([&](auto *init) {
5656
StringRef value = init->getValue();

mlir/lib/Tools/mlir-tblgen/MlirTblgenMain.cpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,8 @@ enum DeprecatedAction { None, Warn, Error };
3030
static DeprecatedAction actionOnDeprecatedValue;
3131

3232
// Returns if there is a use of `deprecatedInit` in `field`.
33-
static bool findUse(Init *field, Init *deprecatedInit,
34-
llvm::DenseMap<Init *, bool> &known) {
33+
static bool findUse(const Init *field, const Init *deprecatedInit,
34+
llvm::DenseMap<const Init *, bool> &known) {
3535
if (field == deprecatedInit)
3636
return true;
3737

@@ -64,13 +64,13 @@ static bool findUse(Init *field, Init *deprecatedInit,
6464
if (findUse(dagInit->getOperator(), deprecatedInit, known))
6565
return memoize(true);
6666

67-
return memoize(llvm::any_of(dagInit->getArgs(), [&](Init *arg) {
67+
return memoize(llvm::any_of(dagInit->getArgs(), [&](const Init *arg) {
6868
return findUse(arg, deprecatedInit, known);
6969
}));
7070
}
7171

72-
if (ListInit *li = dyn_cast<ListInit>(field)) {
73-
return memoize(llvm::any_of(li->getValues(), [&](Init *jt) {
72+
if (const ListInit *li = dyn_cast<ListInit>(field)) {
73+
return memoize(llvm::any_of(li->getValues(), [&](const Init *jt) {
7474
return findUse(jt, deprecatedInit, known);
7575
}));
7676
}
@@ -83,8 +83,8 @@ static bool findUse(Init *field, Init *deprecatedInit,
8383
}
8484

8585
// Returns if there is a use of `deprecatedInit` in `record`.
86-
static bool findUse(Record &record, Init *deprecatedInit,
87-
llvm::DenseMap<Init *, bool> &known) {
86+
static bool findUse(Record &record, const Init *deprecatedInit,
87+
llvm::DenseMap<const Init *, bool> &known) {
8888
return llvm::any_of(record.getValues(), [&](const RecordVal &val) {
8989
return findUse(val.getValue(), deprecatedInit, known);
9090
});
@@ -100,7 +100,7 @@ static void warnOfDeprecatedUses(const RecordKeeper &records) {
100100
if (!r || !r->getValue())
101101
continue;
102102

103-
llvm::DenseMap<Init *, bool> hasUse;
103+
llvm::DenseMap<const Init *, bool> hasUse;
104104
if (auto *si = dyn_cast<StringInit>(r->getValue())) {
105105
for (auto &jt : records.getDefs()) {
106106
// Skip anonymous defs.

mlir/tools/mlir-tblgen/BytecodeDialectGen.cpp

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,9 @@ class Generator {
4646
private:
4747
/// Emits parse calls to construct given kind.
4848
void emitParseHelper(StringRef kind, StringRef returnType, StringRef builder,
49-
ArrayRef<Init *> args, ArrayRef<std::string> argNames,
50-
StringRef failure, mlir::raw_indented_ostream &ios);
49+
ArrayRef<const Init *> args,
50+
ArrayRef<std::string> argNames, StringRef failure,
51+
mlir::raw_indented_ostream &ios);
5152

5253
/// Emits print instructions.
5354
void emitPrintHelper(const Record *memberRec, StringRef kind,
@@ -135,10 +136,12 @@ void Generator::emitParse(StringRef kind, const Record &x) {
135136
R"(static {0} read{1}(MLIRContext* context, DialectBytecodeReader &reader) )";
136137
mlir::raw_indented_ostream os(output);
137138
std::string returnType = getCType(&x);
138-
os << formatv(head, kind == "attribute" ? "::mlir::Attribute" : "::mlir::Type", x.getName());
139-
DagInit *members = x.getValueAsDag("members");
140-
SmallVector<std::string> argNames =
141-
llvm::to_vector(map_range(members->getArgNames(), [](StringInit *init) {
139+
os << formatv(head,
140+
kind == "attribute" ? "::mlir::Attribute" : "::mlir::Type",
141+
x.getName());
142+
const DagInit *members = x.getValueAsDag("members");
143+
SmallVector<std::string> argNames = llvm::to_vector(
144+
map_range(members->getArgNames(), [](const StringInit *init) {
142145
return init->getAsUnquotedString();
143146
}));
144147
StringRef builder = x.getValueAsString("cBuilder").trim();
@@ -148,7 +151,7 @@ void Generator::emitParse(StringRef kind, const Record &x) {
148151
}
149152

150153
void printParseConditional(mlir::raw_indented_ostream &ios,
151-
ArrayRef<Init *> args,
154+
ArrayRef<const Init *> args,
152155
ArrayRef<std::string> argNames) {
153156
ios << "if ";
154157
auto parenScope = ios.scope("(", ") {");
@@ -159,7 +162,7 @@ void printParseConditional(mlir::raw_indented_ostream &ios,
159162
};
160163

161164
auto parsedArgs =
162-
llvm::to_vector(make_filter_range(args, [](Init *const attr) {
165+
llvm::to_vector(make_filter_range(args, [](const Init *const attr) {
163166
const Record *def = cast<DefInit>(attr)->getDef();
164167
if (def->isSubClassOf("Array"))
165168
return true;
@@ -168,7 +171,7 @@ void printParseConditional(mlir::raw_indented_ostream &ios,
168171

169172
interleave(
170173
zip(parsedArgs, argNames),
171-
[&](std::tuple<llvm::Init *&, const std::string &> it) {
174+
[&](std::tuple<const Init *&, const std::string &> it) {
172175
const Record *attr = cast<DefInit>(std::get<0>(it))->getDef();
173176
std::string parser;
174177
if (auto optParser = attr->getValueAsOptionalString("cParser")) {
@@ -196,7 +199,7 @@ void printParseConditional(mlir::raw_indented_ostream &ios,
196199
}
197200

198201
void Generator::emitParseHelper(StringRef kind, StringRef returnType,
199-
StringRef builder, ArrayRef<Init *> args,
202+
StringRef builder, ArrayRef<const Init *> args,
200203
ArrayRef<std::string> argNames,
201204
StringRef failure,
202205
mlir::raw_indented_ostream &ios) {
@@ -210,7 +213,7 @@ void Generator::emitParseHelper(StringRef kind, StringRef returnType,
210213
// Print decls.
211214
std::string lastCType = "";
212215
for (auto [arg, name] : zip(args, argNames)) {
213-
DefInit *first = dyn_cast<DefInit>(arg);
216+
const DefInit *first = dyn_cast<DefInit>(arg);
214217
if (!first)
215218
PrintFatalError("Unexpected type for " + name);
216219
const Record *def = first->getDef();
@@ -251,13 +254,14 @@ void Generator::emitParseHelper(StringRef kind, StringRef returnType,
251254
std::string returnType = getCType(def);
252255
ios << "auto " << listHelperName(name) << " = [&]() -> FailureOr<"
253256
<< returnType << "> ";
254-
SmallVector<Init *> args;
257+
SmallVector<const Init *> args;
255258
SmallVector<std::string> argNames;
256259
if (def->isSubClassOf("CompositeBytecode")) {
257-
DagInit *members = def->getValueAsDag("members");
258-
args = llvm::to_vector(members->getArgs());
260+
const DagInit *members = def->getValueAsDag("members");
261+
args = llvm::to_vector(map_range(
262+
members->getArgs(), [](Init *init) { return (const Init *)init; }));
259263
argNames = llvm::to_vector(
260-
map_range(members->getArgNames(), [](StringInit *init) {
264+
map_range(members->getArgNames(), [](const StringInit *init) {
261265
return init->getAsUnquotedString();
262266
}));
263267
} else {
@@ -332,7 +336,7 @@ void Generator::emitPrint(StringRef kind, StringRef type,
332336
auto *members = rec->getValueAsDag("members");
333337
for (auto [arg, name] :
334338
llvm::zip(members->getArgs(), members->getArgNames())) {
335-
DefInit *def = dyn_cast<DefInit>(arg);
339+
const DefInit *def = dyn_cast<DefInit>(arg);
336340
assert(def);
337341
const Record *memberRec = def->getDef();
338342
emitPrintHelper(memberRec, kind, kind, name->getAsUnquotedString(), os);
@@ -385,7 +389,7 @@ void Generator::emitPrintHelper(const Record *memberRec, StringRef kind,
385389
auto *members = memberRec->getValueAsDag("members");
386390
for (auto [arg, argName] :
387391
zip(members->getArgs(), members->getArgNames())) {
388-
DefInit *def = dyn_cast<DefInit>(arg);
392+
const DefInit *def = dyn_cast<DefInit>(arg);
389393
assert(def);
390394
emitPrintHelper(def->getDef(), kind, parent,
391395
argName->getAsUnquotedString(), ios);

0 commit comments

Comments
 (0)