Skip to content

Commit edae8f6

Browse files
authored
[mlir] Make classof substitution in interface use an instance (#65492)
The substitution supported by `extraClassOf` is currently limited to only the base instance, i.e. `Operation*`, `Type` or `Attribute`, which limits the kind of checks you can perform in the `classof` implementation. Since prior to the user code, the interface concept is fetched, we can use it to construct an instance of the interface, allowing use of its methods in the `classof` check. Since an instance of the interface allows access to the base class methods through the `->` operator, I've gone ahead and replaced the substitution of `$_op/$_type/$_attr` with an interface instance. This is also consistent with `extraSharedClassDeclaration` and other methods created in the interface class which do the same.
1 parent c47c480 commit edae8f6

File tree

6 files changed

+40
-5
lines changed

6 files changed

+40
-5
lines changed

mlir/include/mlir/IR/Interfaces.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ class Interface<string name, list<Interface> baseInterfacesArg = []> {
114114
// be used to better enable "optional" interfaces, where an entity only
115115
// implements the interface if some dynamic characteristic holds.
116116
// `$_attr`/`$_op`/`$_type` may be used to refer to an instance of the
117-
// entity being checked.
117+
// interface instance being checked.
118118
code extraClassOf = "";
119119

120120
// An optional set of base interfaces that this interface

mlir/test/lib/Dialect/Test/TestInterfaces.td

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,4 +134,17 @@ class TestEffects<list<TestEffect> effects = []>
134134

135135
def TestConcreteEffect : TestEffect<"TestEffects::Concrete">;
136136

137+
def TestOptionallyImplementedOpInterface
138+
: OpInterface<"TestOptionallyImplementedOpInterface"> {
139+
let cppNamespace = "::mlir";
140+
141+
let methods = [
142+
InterfaceMethod<"", "bool", "getImplementsInterface", (ins)>,
143+
];
144+
145+
let extraClassOf = [{
146+
return $_op.getImplementsInterface();
147+
}];
148+
}
149+
137150
#endif // MLIR_TEST_DIALECT_TEST_INTERFACES

mlir/test/lib/Dialect/Test/TestOps.td

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2952,4 +2952,10 @@ def TestStoreWithARegionTerminator : TEST_Op<"store_with_a_region_terminator",
29522952
let assemblyFormat = "attr-dict";
29532953
}
29542954

2955+
def TestOpOptionallyImplementingInterface
2956+
: TEST_Op<"op_optionally_implementing_interface",
2957+
[TestOptionallyImplementedOpInterface]> {
2958+
let arguments = (ins BoolAttr:$implementsInterface);
2959+
}
2960+
29552961
#endif // TEST_OPS

mlir/test/mlir-tblgen/op-interface.td

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,11 @@ def ExtraClassOfInterface : OpInterface<"ExtraClassOfInterface"> {
1111

1212
// DECL: class ExtraClassOfInterface
1313
// DECL: static bool classof(::mlir::Operation * base) {
14-
// DECL-NEXT: if (!getInterfaceFor(base))
14+
// DECL-NEXT: auto* concept = getInterfaceFor(base);
15+
// DECL-NEXT: if (!concept)
1516
// DECL-NEXT: return false;
16-
// DECL-NEXT: return base->someOtherMethod();
17+
// DECL-NEXT: ExtraClassOfInterface odsInterfaceInstance(base, concept);
18+
// DECL-NEXT: return odsInterfaceInstance->someOtherMethod();
1719
// DECL-NEXT: }
1820

1921
def ExtraShardDeclsInterface : OpInterface<"ExtraShardDeclsInterface"> {

mlir/tools/mlir-tblgen/OpInterfacesGen.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -582,10 +582,12 @@ void InterfaceGenerator::emitInterfaceDecl(const Interface &interface) {
582582
// Emit classof code if necessary.
583583
if (std::optional<StringRef> extraClassOf = interface.getExtraClassOf()) {
584584
auto extraClassOfFmt = tblgen::FmtContext();
585-
extraClassOfFmt.addSubst(substVar, "base");
585+
extraClassOfFmt.addSubst(substVar, "odsInterfaceInstance");
586586
os << " static bool classof(" << valueType << " base) {\n"
587-
<< " if (!getInterfaceFor(base))\n"
587+
<< " auto* concept = getInterfaceFor(base);\n"
588+
<< " if (!concept)\n"
588589
" return false;\n"
590+
" " << interfaceName << " odsInterfaceInstance(base, concept);\n"
589591
<< " " << tblgen::tgfmt(extraClassOf->trim(), &extraClassOfFmt)
590592
<< "\n }\n";
591593
}

mlir/unittests/IR/InterfaceTest.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,3 +57,15 @@ TEST(InterfaceTest, TypeInterfaceDenseMapKey) {
5757
EXPECT_TRUE(typeSet.contains(type2));
5858
EXPECT_FALSE(typeSet.contains(type3));
5959
}
60+
61+
TEST(InterfaceTest, TestCustomClassOf) {
62+
MLIRContext context;
63+
context.loadDialect<test::TestDialect>();
64+
65+
OpBuilder builder(&context);
66+
auto op = builder.create<TestOpOptionallyImplementingInterface>(
67+
builder.getUnknownLoc(), /*implementsInterface=*/true);
68+
EXPECT_TRUE(isa<TestOptionallyImplementedOpInterface>(*op));
69+
op.setImplementsInterface(false);
70+
EXPECT_FALSE(isa<TestOptionallyImplementedOpInterface>(*op));
71+
}

0 commit comments

Comments
 (0)