Skip to content

[mlir][docgen] Emit OpInterface for Operations within Dialect Doc #104693

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 84 additions & 8 deletions mlir/tools/mlir-tblgen/OpDocGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "mlir/TableGen/AttrOrTypeDef.h"
#include "mlir/TableGen/Attribute.h"
#include "mlir/TableGen/GenInfo.h"
#include "mlir/TableGen/Interfaces.h"
#include "mlir/TableGen/Operator.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SetVector.h"
Expand Down Expand Up @@ -229,7 +230,8 @@ static void emitOpDoc(const Operator &op, raw_ostream &os) {
// Expandable description.
// This appears as just the summary, but when clicked shows the full
// description.
os << "<details>" << "<summary>" << it.attr.getSummary() << "</summary>"
os << "<details>"
<< "<summary>" << it.attr.getSummary() << "</summary>"
<< "{{% markdown %}}" << description << "{{% /markdown %}}"
<< "</details>";
} else {
Expand Down Expand Up @@ -381,6 +383,51 @@ static void emitAttrOrTypeDefDoc(const RecordKeeper &recordKeeper,
emitAttrOrTypeDefDoc(AttrOrTypeDef(def), os);
}

//===----------------------------------------------------------------------===//
// OpInterface Documentation
//===----------------------------------------------------------------------===//

static void emitOpInterfaceDoc(const Interface &opIf, raw_ostream &os) {
os << llvm::formatv("### {0}\n", opIf.getName());

ArrayRef<std::pair<Record *, SMRange>> superclasses =
opIf.getDef().getSuperClasses();
if (!superclasses.empty()) {
llvm::interleaveComma(superclasses, os << "Base Classes: ",
[&](const std::pair<Record *, SMRange> &p) {
os << p.first->getName();
});
os << "\n\n";
}

auto descriptionOpt = opIf.getDescription();
if (descriptionOpt) {
os << "#### Description:\n\n";
emitDescription(descriptionOpt.value(), os);
}

ArrayRef<InterfaceMethod> methods = opIf.getMethods();
if (!methods.empty()) {
os << "#### Methods:\n\n";
for (auto &method : methods) {
os << llvm::formatv("**{0}**\n\n", method.getName());
descriptionOpt = method.getDescription();
if (descriptionOpt) {
os << "Description: \n\n";
emitDescription(descriptionOpt.value(), os);
os << "\n";
}
os << llvm::formatv("Return Type: {0}\n\n", method.getReturnType());
if (!method.getArguments().empty()) {
llvm::interleaveComma(
method.getArguments(), os << "Arguments:\n",
[&](const InterfaceMethod::Argument &arg) { os << arg.type; });
}
os << "\n\n";
}
}
}

//===----------------------------------------------------------------------===//
// Enum Documentation
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -446,7 +493,9 @@ static void maybeNest(bool nest, llvm::function_ref<void(raw_ostream &os)> fn,
static void emitBlock(ArrayRef<Attribute> attributes, StringRef inputFilename,
ArrayRef<AttrDef> attrDefs, ArrayRef<OpDocGroup> ops,
ArrayRef<Type> types, ArrayRef<TypeDef> typeDefs,
ArrayRef<EnumAttr> enums, raw_ostream &os) {
ArrayRef<EnumAttr> enums,
ArrayRef<Interface> relevantOpInterfaces,
raw_ostream &os) {
if (!ops.empty()) {
os << "## Operations\n\n";
emitSourceLink(inputFilename, os);
Expand Down Expand Up @@ -498,13 +547,24 @@ static void emitBlock(ArrayRef<Attribute> attributes, StringRef inputFilename,
for (const EnumAttr &def : enums)
emitEnumDoc(def, os);
}

if (!relevantOpInterfaces.empty()) {
os << "## Op Interfaces\n\n";
os << "**The following Op Interfaces are used by the Operations in this "
"Dialect**\n";
for (const Interface &interface : relevantOpInterfaces) {
emitOpInterfaceDoc(interface, os);
}
}
}

static void emitDialectDoc(const Dialect &dialect, StringRef inputFilename,
ArrayRef<Attribute> attributes,
ArrayRef<AttrDef> attrDefs, ArrayRef<OpDocGroup> ops,
ArrayRef<Type> types, ArrayRef<TypeDef> typeDefs,
ArrayRef<EnumAttr> enums, raw_ostream &os) {
ArrayRef<EnumAttr> enums,
ArrayRef<Interface> relevantOpInterfaces,
raw_ostream &os) {
os << "# '" << dialect.getName() << "' Dialect\n\n";
emitIfNotEmpty(dialect.getSummary(), os);
emitIfNotEmpty(dialect.getDescription(), os);
Expand All @@ -515,7 +575,7 @@ static void emitDialectDoc(const Dialect &dialect, StringRef inputFilename,
os << "[TOC]\n\n";

emitBlock(attributes, inputFilename, attrDefs, ops, types, typeDefs, enums,
os);
relevantOpInterfaces, os);
}

static bool emitDialectDoc(const RecordKeeper &recordKeeper, raw_ostream &os) {
Expand Down Expand Up @@ -544,16 +604,19 @@ static bool emitDialectDoc(const RecordKeeper &recordKeeper, raw_ostream &os) {
std::vector<Type> dialectTypes;
std::vector<TypeDef> dialectTypeDefs;
std::vector<EnumAttr> dialectEnums;
std::vector<Interface> relevantOpInterfaces;

llvm::SmallDenseSet<Record *> seen;
auto addIfNotSeen = [&](llvm::Record *record, const auto &def, auto &vec) {
llvm::SmallDenseSet<const Record *> seen;
auto addIfNotSeen = [&](const llvm::Record *record, const auto &def,
auto &vec) {
if (seen.insert(record).second) {
vec.push_back(def);
return true;
}
return false;
};
auto addIfInDialect = [&](llvm::Record *record, const auto &def, auto &vec) {
auto addIfInDialect = [&](const llvm::Record *record, const auto &def,
auto &vec) {
return def.getDialect() == *dialect && addIfNotSeen(record, def, vec);
};

Expand Down Expand Up @@ -589,6 +652,19 @@ static bool emitDialectDoc(const RecordKeeper &recordKeeper, raw_ostream &os) {
for (Record *def : enumDefs)
addIfNotSeen(def, EnumAttr(def), dialectEnums);

for (Record *def : opDefs) {
Operator o(def);
for (auto &trait : o.getTraits()) {
if (const InterfaceTrait *ifTrait = dyn_cast<InterfaceTrait>(&trait)) {
if (trait.getDef().isSubClassOf("SideEffectsTraitBase")) {
continue;
}
addIfNotSeen(&ifTrait->getDef(), ifTrait->getInterface(),
relevantOpInterfaces);
}
}
}

// Sort alphabetically ignorning dialect for ops and section name for
// sections.
// TODO: The sorting order could be revised, currently attempting to sort of
Expand All @@ -606,7 +682,7 @@ static bool emitDialectDoc(const RecordKeeper &recordKeeper, raw_ostream &os) {
os << "<!-- Autogenerated by mlir-tblgen; don't manually edit -->\n";
emitDialectDoc(*dialect, recordKeeper.getInputFilename(), dialectAttrs,
dialectAttrDefs, dialectOps, dialectTypes, dialectTypeDefs,
dialectEnums, os);
dialectEnums, relevantOpInterfaces, os);
return false;
}

Expand Down
Loading