Skip to content

Commit bc82793

Browse files
joker-ephj2kun
andauthored
[mlir] load dialects for non-namespaced attrs (#96242)
The mlir-translate tool calls into the parser without loading registered dependent dialects, and the parser only loads attributes if the fully-namespaced attribute is present in the textual IR. This causes parsing to break when an op has an attribute that prints/parses without the namespaced attribute. Co-authored-by: Jeremy Kun <[email protected]>
1 parent 739a960 commit bc82793

File tree

6 files changed

+52
-5
lines changed

6 files changed

+52
-5
lines changed

mlir/test/IR/parser.mlir

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1465,3 +1465,14 @@ test.dialect_custom_format_fallback custom_format_fallback
14651465
// CHECK: test.format_optional_result_d_op : f80
14661466
test.format_optional_result_d_op : f80
14671467

1468+
1469+
// -----
1470+
1471+
// This is a testing that a non-qualified attribute in a custom format
1472+
// correctly preload the dialect before creating the attribute.
1473+
#attr = #test.nested_polynomial<<1 + x**2>>
1474+
// CHECK-lABLE: @parse_correctly
1475+
llvm.func @parse_correctly() {
1476+
test.containing_int_polynomial_attr #attr
1477+
llvm.return
1478+
}

mlir/test/lib/Dialect/Test/CMakeLists.txt

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@ mlir_tablegen(TestOpInterfaces.cpp.inc -gen-op-interface-defs)
1616
add_public_tablegen_target(MLIRTestInterfaceIncGen)
1717

1818
set(LLVM_TARGET_DEFINITIONS TestOps.td)
19-
mlir_tablegen(TestAttrDefs.h.inc -gen-attrdef-decls)
20-
mlir_tablegen(TestAttrDefs.cpp.inc -gen-attrdef-defs)
19+
mlir_tablegen(TestAttrDefs.h.inc -gen-attrdef-decls -attrdefs-dialect=test)
20+
mlir_tablegen(TestAttrDefs.cpp.inc -gen-attrdef-defs -attrdefs-dialect=test)
2121
add_public_tablegen_target(MLIRTestAttrDefIncGen)
2222

2323
set(LLVM_TARGET_DEFINITIONS TestTypeDefs.td)
@@ -86,6 +86,7 @@ add_mlir_library(MLIRTestDialect
8686
MLIRLinalgTransforms
8787
MLIRLLVMDialect
8888
MLIRPass
89+
MLIRPolynomialDialect
8990
MLIRReduce
9091
MLIRTensorDialect
9192
MLIRTransformUtils

mlir/test/lib/Dialect/Test/TestAttrDefs.td

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
// To get the test dialect definition.
1717
include "TestDialect.td"
1818
include "TestEnumDefs.td"
19+
include "mlir/Dialect/Polynomial/IR/PolynomialAttributes.td"
1920
include "mlir/Dialect/Utils/StructuredOpsUtils.td"
2021
include "mlir/IR/AttrTypeBase.td"
2122
include "mlir/IR/BuiltinAttributeInterfaces.td"
@@ -351,4 +352,12 @@ def TestCustomFloatAttr : Test_Attr<"TestCustomFloat"> {
351352
}];
352353
}
353354

355+
def NestedPolynomialAttr : Test_Attr<"NestedPolynomialAttr"> {
356+
let mnemonic = "nested_polynomial";
357+
let parameters = (ins Polynomial_IntPolynomialAttr:$poly);
358+
let assemblyFormat = [{
359+
`<` $poly `>`
360+
}];
361+
}
362+
354363
#endif // TEST_ATTRDEFS

mlir/test/lib/Dialect/Test/TestAttributes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include <tuple>
1818

1919
#include "TestTraits.h"
20+
#include "mlir/Dialect/Polynomial/IR/PolynomialAttributes.h"
2021
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
2122
#include "mlir/IR/Attributes.h"
2223
#include "mlir/IR/Diagnostics.h"

mlir/test/lib/Dialect/Test/TestOps.td

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -232,6 +232,11 @@ def FloatElementsAttrOp : TEST_Op<"float_elements_attr"> {
232232
);
233233
}
234234

235+
def ContainingIntPolynomialAttrOp : TEST_Op<"containing_int_polynomial_attr"> {
236+
let arguments = (ins NestedPolynomialAttr:$attr);
237+
let assemblyFormat = "$attr attr-dict";
238+
}
239+
235240
// A pattern that updates dense<[3.0, 4.0]> to dense<[5.0, 6.0]>.
236241
// This tests both matching and generating float elements attributes.
237242
def UpdateFloatElementsAttr : Pat<
@@ -2215,7 +2220,7 @@ def ForwardBufferOp : TEST_Op<"forward_buffer", [Pure]> {
22152220
def ReifyBoundOp : TEST_Op<"reify_bound", [Pure]> {
22162221
let description = [{
22172222
Reify a bound for the given index-typed value or dimension size of a shaped
2218-
value. "LB", "EQ" and "UB" bounds are supported. If `scalable` is set,
2223+
value. "LB", "EQ" and "UB" bounds are supported. If `scalable` is set,
22192224
`vscale_min` and `vscale_max` must be provided, which allows computing
22202225
a bound in terms of "vector.vscale" for a given range of vscale.
22212226
}];

mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -164,8 +164,9 @@ static const char *const parserErrorStr =
164164
/// {2}: Code template for printing an error.
165165
/// {3}: Name of the attribute or type.
166166
/// {4}: C++ class of the parameter.
167+
/// {5}: Optional code to preload the dialect for this variable.
167168
static const char *const variableParser = R"(
168-
// Parse variable '{0}'
169+
// Parse variable '{0}'{5}
169170
_result_{0} = {1};
170171
if (::mlir::failed(_result_{0})) {{
171172
{2}"failed to parse {3} parameter '{0}' which is to be a `{4}`");
@@ -411,9 +412,28 @@ void DefFormat::genVariableParser(ParameterElement *el, FmtContext &ctx,
411412
auto customParser = param.getParser();
412413
auto parser =
413414
customParser ? *customParser : StringRef(defaultParameterParser);
415+
416+
// If the variable points to a dialect specific entity (type of attribute),
417+
// we force load the dialect now before trying to parse it.
418+
std::string dialectLoading;
419+
if (auto *defInit = dyn_cast<llvm::DefInit>(param.getDef())) {
420+
auto *dialectValue = defInit->getDef()->getValue("dialect");
421+
if (dialectValue) {
422+
if (auto *dialectInit =
423+
dyn_cast<llvm::DefInit>(dialectValue->getValue())) {
424+
Dialect dialect(dialectInit->getDef());
425+
auto cppNamespace = dialect.getCppNamespace();
426+
std::string name = dialect.getCppClassName();
427+
dialectLoading = ("\nodsParser.getContext()->getOrLoadDialect<" +
428+
cppNamespace + "::" + name + ">();")
429+
.str();
430+
}
431+
}
432+
}
414433
os << formatv(variableParser, param.getName(),
415434
tgfmt(parser, &ctx, param.getCppStorageType()),
416-
tgfmt(parserErrorStr, &ctx), def.getName(), param.getCppType());
435+
tgfmt(parserErrorStr, &ctx), def.getName(), param.getCppType(),
436+
dialectLoading);
417437
}
418438

419439
void DefFormat::genParamsParser(ParamsDirective *el, FmtContext &ctx,

0 commit comments

Comments
 (0)