Skip to content

[MLIR][TableGen] Use const pointers for various Init objects #112562

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion mlir/include/mlir/TableGen/AttrOrTypeDef.h
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ class AttrOrTypeParameter {
std::optional<StringRef> getDefaultValue() const;

/// Return the underlying def of this parameter.
llvm::Init *getDef() const;
const llvm::Init *getDef() const;

/// The parameter is pointer-comparable.
bool operator==(const AttrOrTypeParameter &other) const {
Expand Down
2 changes: 1 addition & 1 deletion mlir/include/mlir/TableGen/Dialect.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ class Dialect {
/// dialect.
bool usePropertiesForAttributes() const;

llvm::DagInit *getDiscardableAttributes() const;
const llvm::DagInit *getDiscardableAttributes() const;

const llvm::Record *getDef() const { return def; }

Expand Down
15 changes: 8 additions & 7 deletions mlir/include/mlir/TableGen/Operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -119,14 +119,15 @@ class Operator {

/// A utility iterator over a list of variable decorators.
struct VariableDecoratorIterator
: public llvm::mapped_iterator<llvm::Init *const *,
VariableDecorator (*)(llvm::Init *)> {
: public llvm::mapped_iterator<const llvm::Init *const *,
VariableDecorator (*)(
const llvm::Init *)> {
/// Initializes the iterator to the specified iterator.
VariableDecoratorIterator(llvm::Init *const *it)
: llvm::mapped_iterator<llvm::Init *const *,
VariableDecorator (*)(llvm::Init *)>(it,
&unwrap) {}
static VariableDecorator unwrap(llvm::Init *init);
VariableDecoratorIterator(const llvm::Init *const *it)
: llvm::mapped_iterator<const llvm::Init *const *,
VariableDecorator (*)(const llvm::Init *)>(
it, &unwrap) {}
static VariableDecorator unwrap(const llvm::Init *init);
};
using var_decorator_iterator = VariableDecoratorIterator;
using var_decorator_range = llvm::iterator_range<VariableDecoratorIterator>;
Expand Down
12 changes: 7 additions & 5 deletions mlir/lib/TableGen/AttrOrTypeDef.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ AttrOrTypeDef::AttrOrTypeDef(const llvm::Record *def) : def(def) {
auto *builderList =
dyn_cast_or_null<llvm::ListInit>(def->getValueInit("builders"));
if (builderList && !builderList->empty()) {
for (llvm::Init *init : builderList->getValues()) {
for (const llvm::Init *init : builderList->getValues()) {
AttrOrTypeBuilder builder(cast<llvm::DefInit>(init)->getDef(),
def->getLoc());

Expand All @@ -58,8 +58,8 @@ AttrOrTypeDef::AttrOrTypeDef(const llvm::Record *def) : def(def) {
if (auto *traitList = def->getValueAsListInit("traits")) {
SmallPtrSet<const llvm::Init *, 32> traitSet;
traits.reserve(traitSet.size());
llvm::unique_function<void(llvm::ListInit *)> processTraitList =
[&](llvm::ListInit *traitList) {
llvm::unique_function<void(const llvm::ListInit *)> processTraitList =
[&](const llvm::ListInit *traitList) {
for (auto *traitInit : *traitList) {
if (!traitSet.insert(traitInit).second)
continue;
Expand Down Expand Up @@ -335,7 +335,9 @@ std::optional<StringRef> AttrOrTypeParameter::getDefaultValue() const {
return result && !result->empty() ? result : std::nullopt;
}

llvm::Init *AttrOrTypeParameter::getDef() const { return def->getArg(index); }
const llvm::Init *AttrOrTypeParameter::getDef() const {
return def->getArg(index);
}

std::optional<Constraint> AttrOrTypeParameter::getConstraint() const {
if (auto *param = dyn_cast<llvm::DefInit>(getDef()))
Expand All @@ -349,7 +351,7 @@ std::optional<Constraint> AttrOrTypeParameter::getConstraint() const {
//===----------------------------------------------------------------------===//

bool AttributeSelfTypeParameter::classof(const AttrOrTypeParameter *param) {
llvm::Init *paramDef = param->getDef();
const llvm::Init *paramDef = param->getDef();
if (auto *paramDefInit = dyn_cast<llvm::DefInit>(paramDef))
return paramDefInit->getDef()->isSubClassOf("AttributeSelfTypeParameter");
return false;
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/TableGen/Attribute.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ StringRef Attribute::getDerivedCodeBody() const {
Dialect Attribute::getDialect() const {
const llvm::RecordVal *record = def->getValue("dialect");
if (record && record->getValue()) {
if (DefInit *init = dyn_cast<DefInit>(record->getValue()))
if (const DefInit *init = dyn_cast<DefInit>(record->getValue()))
return Dialect(init->getDef());
}
return Dialect(nullptr);
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/TableGen/Dialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ bool Dialect::usePropertiesForAttributes() const {
return def->getValueAsBit("usePropertiesForAttributes");
}

llvm::DagInit *Dialect::getDiscardableAttributes() const {
const llvm::DagInit *Dialect::getDiscardableAttributes() const {
return def->getValueAsDag("discardableAttrs");
}

Expand Down
6 changes: 3 additions & 3 deletions mlir/lib/TableGen/Interfaces.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ using namespace mlir::tblgen;
//===----------------------------------------------------------------------===//

InterfaceMethod::InterfaceMethod(const llvm::Record *def) : def(def) {
llvm::DagInit *args = def->getValueAsDag("arguments");
const llvm::DagInit *args = def->getValueAsDag("arguments");
for (unsigned i = 0, e = args->getNumArgs(); i != e; ++i) {
arguments.push_back(
{llvm::cast<llvm::StringInit>(args->getArg(i))->getValue(),
Expand Down Expand Up @@ -78,7 +78,7 @@ Interface::Interface(const llvm::Record *def) : def(def) {

// Initialize the interface methods.
auto *listInit = dyn_cast<llvm::ListInit>(def->getValueInit("methods"));
for (llvm::Init *init : listInit->getValues())
for (const llvm::Init *init : listInit->getValues())
methods.emplace_back(cast<llvm::DefInit>(init)->getDef());

// Initialize the interface base classes.
Expand All @@ -98,7 +98,7 @@ Interface::Interface(const llvm::Record *def) : def(def) {
baseInterfaces.push_back(std::make_unique<Interface>(baseInterface));
basesAdded.insert(baseInterface.getName());
};
for (llvm::Init *init : basesInit->getValues())
for (const llvm::Init *init : basesInit->getValues())
addBaseInterfaceFn(Interface(cast<llvm::DefInit>(init)->getDef()));
}

Expand Down
21 changes: 11 additions & 10 deletions mlir/lib/TableGen/Operator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ std::string Operator::getQualCppClassName() const {
StringRef Operator::getCppNamespace() const { return cppNamespace; }

int Operator::getNumResults() const {
DagInit *results = def.getValueAsDag("results");
const DagInit *results = def.getValueAsDag("results");
return results->getNumArgs();
}

Expand Down Expand Up @@ -198,12 +198,12 @@ auto Operator::getResults() const -> const_value_range {
}

TypeConstraint Operator::getResultTypeConstraint(int index) const {
DagInit *results = def.getValueAsDag("results");
const DagInit *results = def.getValueAsDag("results");
return TypeConstraint(cast<DefInit>(results->getArg(index)));
}

StringRef Operator::getResultName(int index) const {
DagInit *results = def.getValueAsDag("results");
const DagInit *results = def.getValueAsDag("results");
return results->getArgNameStr(index);
}

Expand Down Expand Up @@ -241,7 +241,7 @@ Operator::arg_range Operator::getArgs() const {
}

StringRef Operator::getArgName(int index) const {
DagInit *argumentValues = def.getValueAsDag("arguments");
const DagInit *argumentValues = def.getValueAsDag("arguments");
return argumentValues->getArgNameStr(index);
}

Expand Down Expand Up @@ -557,7 +557,7 @@ void Operator::populateOpStructure() {
auto *opVarClass = recordKeeper.getClass("OpVariable");
numNativeAttributes = 0;

DagInit *argumentValues = def.getValueAsDag("arguments");
const DagInit *argumentValues = def.getValueAsDag("arguments");
unsigned numArgs = argumentValues->getNumArgs();

// Mapping from name of to argument or result index. Arguments are indexed
Expand Down Expand Up @@ -721,8 +721,8 @@ void Operator::populateOpStructure() {
" to precede it in traits list");
};

std::function<void(llvm::ListInit *)> insert;
insert = [&](llvm::ListInit *traitList) {
std::function<void(const llvm::ListInit *)> insert;
insert = [&](const llvm::ListInit *traitList) {
for (auto *traitInit : *traitList) {
auto *def = cast<DefInit>(traitInit)->getDef();
if (def->isSubClassOf("TraitList")) {
Expand Down Expand Up @@ -780,7 +780,7 @@ void Operator::populateOpStructure() {
auto *builderList =
dyn_cast_or_null<llvm::ListInit>(def.getValueInit("builders"));
if (builderList && !builderList->empty()) {
for (llvm::Init *init : builderList->getValues())
for (const llvm::Init *init : builderList->getValues())
builders.emplace_back(cast<llvm::DefInit>(init)->getDef(), def.getLoc());
} else if (skipDefaultBuilders()) {
PrintFatalError(
Expand Down Expand Up @@ -818,7 +818,8 @@ bool Operator::hasAssemblyFormat() const {
}

StringRef Operator::getAssemblyFormat() const {
return TypeSwitch<llvm::Init *, StringRef>(def.getValueInit("assemblyFormat"))
return TypeSwitch<const llvm::Init *, StringRef>(
def.getValueInit("assemblyFormat"))
.Case<llvm::StringInit>([&](auto *init) { return init->getValue(); });
}

Expand All @@ -832,7 +833,7 @@ void Operator::print(llvm::raw_ostream &os) const {
}
}

auto Operator::VariableDecoratorIterator::unwrap(llvm::Init *init)
auto Operator::VariableDecoratorIterator::unwrap(const llvm::Init *init)
-> VariableDecorator {
return VariableDecorator(cast<llvm::DefInit>(init)->getDef());
}
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/TableGen/Pattern.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -700,7 +700,7 @@ int Pattern::getBenefit() const {
// The initial benefit value is a heuristic with number of ops in the source
// pattern.
int initBenefit = getSourcePattern().getNumOps();
llvm::DagInit *delta = def.getValueAsDag("benefitDelta");
const llvm::DagInit *delta = def.getValueAsDag("benefitDelta");
if (delta->getNumArgs() != 1 || !isa<llvm::IntInit>(delta->getArg(0))) {
PrintFatalError(&def,
"The 'addBenefit' takes and only takes one integer value");
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/TableGen/Type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ std::optional<StringRef> TypeConstraint::getBuilderCall() const {
const llvm::RecordVal *builderCall = baseType->getValue("builderCall");
if (!builderCall || !builderCall->getValue())
return std::nullopt;
return TypeSwitch<llvm::Init *, std::optional<StringRef>>(
return TypeSwitch<const llvm::Init *, std::optional<StringRef>>(
builderCall->getValue())
.Case<llvm::StringInit>([&](auto *init) {
StringRef value = init->getValue();
Expand Down
16 changes: 8 additions & 8 deletions mlir/lib/Tools/mlir-tblgen/MlirTblgenMain.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ enum DeprecatedAction { None, Warn, Error };
static DeprecatedAction actionOnDeprecatedValue;

// Returns if there is a use of `deprecatedInit` in `field`.
static bool findUse(Init *field, Init *deprecatedInit,
llvm::DenseMap<Init *, bool> &known) {
static bool findUse(const Init *field, const Init *deprecatedInit,
llvm::DenseMap<const Init *, bool> &known) {
if (field == deprecatedInit)
return true;

Expand Down Expand Up @@ -64,13 +64,13 @@ static bool findUse(Init *field, Init *deprecatedInit,
if (findUse(dagInit->getOperator(), deprecatedInit, known))
return memoize(true);

return memoize(llvm::any_of(dagInit->getArgs(), [&](Init *arg) {
return memoize(llvm::any_of(dagInit->getArgs(), [&](const Init *arg) {
return findUse(arg, deprecatedInit, known);
}));
}

if (ListInit *li = dyn_cast<ListInit>(field)) {
return memoize(llvm::any_of(li->getValues(), [&](Init *jt) {
if (const ListInit *li = dyn_cast<ListInit>(field)) {
return memoize(llvm::any_of(li->getValues(), [&](const Init *jt) {
return findUse(jt, deprecatedInit, known);
}));
}
Expand All @@ -83,8 +83,8 @@ static bool findUse(Init *field, Init *deprecatedInit,
}

// Returns if there is a use of `deprecatedInit` in `record`.
static bool findUse(Record &record, Init *deprecatedInit,
llvm::DenseMap<Init *, bool> &known) {
static bool findUse(Record &record, const Init *deprecatedInit,
llvm::DenseMap<const Init *, bool> &known) {
return llvm::any_of(record.getValues(), [&](const RecordVal &val) {
return findUse(val.getValue(), deprecatedInit, known);
});
Expand All @@ -100,7 +100,7 @@ static void warnOfDeprecatedUses(const RecordKeeper &records) {
if (!r || !r->getValue())
continue;

llvm::DenseMap<Init *, bool> hasUse;
llvm::DenseMap<const Init *, bool> hasUse;
if (auto *si = dyn_cast<StringInit>(r->getValue())) {
for (auto &jt : records.getDefs()) {
// Skip anonymous defs.
Expand Down
38 changes: 21 additions & 17 deletions mlir/tools/mlir-tblgen/BytecodeDialectGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,9 @@ class Generator {
private:
/// Emits parse calls to construct given kind.
void emitParseHelper(StringRef kind, StringRef returnType, StringRef builder,
ArrayRef<Init *> args, ArrayRef<std::string> argNames,
StringRef failure, mlir::raw_indented_ostream &ios);
ArrayRef<const Init *> args,
ArrayRef<std::string> argNames, StringRef failure,
mlir::raw_indented_ostream &ios);

/// Emits print instructions.
void emitPrintHelper(const Record *memberRec, StringRef kind,
Expand Down Expand Up @@ -135,10 +136,12 @@ void Generator::emitParse(StringRef kind, const Record &x) {
R"(static {0} read{1}(MLIRContext* context, DialectBytecodeReader &reader) )";
mlir::raw_indented_ostream os(output);
std::string returnType = getCType(&x);
os << formatv(head, kind == "attribute" ? "::mlir::Attribute" : "::mlir::Type", x.getName());
DagInit *members = x.getValueAsDag("members");
SmallVector<std::string> argNames =
llvm::to_vector(map_range(members->getArgNames(), [](StringInit *init) {
os << formatv(head,
kind == "attribute" ? "::mlir::Attribute" : "::mlir::Type",
x.getName());
const DagInit *members = x.getValueAsDag("members");
SmallVector<std::string> argNames = llvm::to_vector(
map_range(members->getArgNames(), [](const StringInit *init) {
return init->getAsUnquotedString();
}));
StringRef builder = x.getValueAsString("cBuilder").trim();
Expand All @@ -148,7 +151,7 @@ void Generator::emitParse(StringRef kind, const Record &x) {
}

void printParseConditional(mlir::raw_indented_ostream &ios,
ArrayRef<Init *> args,
ArrayRef<const Init *> args,
ArrayRef<std::string> argNames) {
ios << "if ";
auto parenScope = ios.scope("(", ") {");
Expand All @@ -159,7 +162,7 @@ void printParseConditional(mlir::raw_indented_ostream &ios,
};

auto parsedArgs =
llvm::to_vector(make_filter_range(args, [](Init *const attr) {
llvm::to_vector(make_filter_range(args, [](const Init *const attr) {
const Record *def = cast<DefInit>(attr)->getDef();
if (def->isSubClassOf("Array"))
return true;
Expand All @@ -168,7 +171,7 @@ void printParseConditional(mlir::raw_indented_ostream &ios,

interleave(
zip(parsedArgs, argNames),
[&](std::tuple<llvm::Init *&, const std::string &> it) {
[&](std::tuple<const Init *&, const std::string &> it) {
const Record *attr = cast<DefInit>(std::get<0>(it))->getDef();
std::string parser;
if (auto optParser = attr->getValueAsOptionalString("cParser")) {
Expand Down Expand Up @@ -196,7 +199,7 @@ void printParseConditional(mlir::raw_indented_ostream &ios,
}

void Generator::emitParseHelper(StringRef kind, StringRef returnType,
StringRef builder, ArrayRef<Init *> args,
StringRef builder, ArrayRef<const Init *> args,
ArrayRef<std::string> argNames,
StringRef failure,
mlir::raw_indented_ostream &ios) {
Expand All @@ -210,7 +213,7 @@ void Generator::emitParseHelper(StringRef kind, StringRef returnType,
// Print decls.
std::string lastCType = "";
for (auto [arg, name] : zip(args, argNames)) {
DefInit *first = dyn_cast<DefInit>(arg);
const DefInit *first = dyn_cast<DefInit>(arg);
if (!first)
PrintFatalError("Unexpected type for " + name);
const Record *def = first->getDef();
Expand Down Expand Up @@ -251,13 +254,14 @@ void Generator::emitParseHelper(StringRef kind, StringRef returnType,
std::string returnType = getCType(def);
ios << "auto " << listHelperName(name) << " = [&]() -> FailureOr<"
<< returnType << "> ";
SmallVector<Init *> args;
SmallVector<const Init *> args;
SmallVector<std::string> argNames;
if (def->isSubClassOf("CompositeBytecode")) {
DagInit *members = def->getValueAsDag("members");
args = llvm::to_vector(members->getArgs());
const DagInit *members = def->getValueAsDag("members");
args = llvm::to_vector(map_range(
members->getArgs(), [](Init *init) { return (const Init *)init; }));
argNames = llvm::to_vector(
map_range(members->getArgNames(), [](StringInit *init) {
map_range(members->getArgNames(), [](const StringInit *init) {
return init->getAsUnquotedString();
}));
} else {
Expand Down Expand Up @@ -332,7 +336,7 @@ void Generator::emitPrint(StringRef kind, StringRef type,
auto *members = rec->getValueAsDag("members");
for (auto [arg, name] :
llvm::zip(members->getArgs(), members->getArgNames())) {
DefInit *def = dyn_cast<DefInit>(arg);
const DefInit *def = dyn_cast<DefInit>(arg);
assert(def);
const Record *memberRec = def->getDef();
emitPrintHelper(memberRec, kind, kind, name->getAsUnquotedString(), os);
Expand Down Expand Up @@ -385,7 +389,7 @@ void Generator::emitPrintHelper(const Record *memberRec, StringRef kind,
auto *members = memberRec->getValueAsDag("members");
for (auto [arg, argName] :
zip(members->getArgs(), members->getArgNames())) {
DefInit *def = dyn_cast<DefInit>(arg);
const DefInit *def = dyn_cast<DefInit>(arg);
assert(def);
emitPrintHelper(def->getDef(), kind, parent,
argName->getAsUnquotedString(), ios);
Expand Down
Loading
Loading