Skip to content

Commit 9b50844

Browse files
author
Vladislav Vinogradov
committed
[mlir] Fix delayed object interfaces registration
Store both interfaceID and objectID as key for interface registration callback. Otherwise the implementation allows to register only one external model per one object in the single dialect. Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D107274
1 parent 4f4f278 commit 9b50844

File tree

3 files changed

+51
-21
lines changed

3 files changed

+51
-21
lines changed

mlir/include/mlir/IR/Dialect.h

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include "mlir/Support/TypeID.h"
1818

1919
#include <map>
20+
#include <tuple>
2021

2122
namespace mlir {
2223
class DialectAsmParser;
@@ -285,7 +286,7 @@ class DialectRegistry {
285286
SmallVector<std::pair<TypeID, DialectInterfaceAllocatorFunction>, 2>
286287
dialectInterfaces;
287288
/// Attribute/Operation/Type interfaces.
288-
SmallVector<std::pair<TypeID, ObjectInterfaceAllocatorFunction>, 2>
289+
SmallVector<std::tuple<TypeID, TypeID, ObjectInterfaceAllocatorFunction>, 2>
289290
objectInterfaces;
290291
};
291292

@@ -367,7 +368,8 @@ class DialectRegistry {
367368
void addOpInterface() {
368369
StringRef opName = OpTy::getOperationName();
369370
StringRef dialectName = opName.split('.').first;
370-
addObjectInterface(dialectName, ModelTy::Interface::getInterfaceID(),
371+
addObjectInterface(dialectName, TypeID::get<OpTy>(),
372+
ModelTy::Interface::getInterfaceID(),
371373
[](MLIRContext *context) {
372374
OpTy::template attachInterface<ModelTy>(*context);
373375
});
@@ -401,14 +403,16 @@ class DialectRegistry {
401403

402404
/// Add an attribute/operation/type interface constructible with the given
403405
/// allocation function to the dialect identified by its namespace.
404-
void addObjectInterface(StringRef dialectName, TypeID interfaceTypeID,
406+
void addObjectInterface(StringRef dialectName, TypeID objectID,
407+
TypeID interfaceTypeID,
405408
ObjectInterfaceAllocatorFunction allocator);
406409

407410
/// Add an external model for an attribute/type interface to the dialect
408411
/// identified by its namespace.
409412
template <typename ObjectTy, typename ModelTy>
410413
void addStorageUserInterface(StringRef dialectName) {
411-
addObjectInterface(dialectName, ModelTy::Interface::getInterfaceID(),
414+
addObjectInterface(dialectName, TypeID::get<ObjectTy>(),
415+
ModelTy::Interface::getInterfaceID(),
412416
[](MLIRContext *context) {
413417
ObjectTy::template attachInterface<ModelTy>(*context);
414418
});

mlir/lib/IR/Dialect.cpp

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -58,24 +58,27 @@ void DialectRegistry::addDialectInterface(
5858
}
5959

6060
void DialectRegistry::addObjectInterface(
61-
StringRef dialectName, TypeID interfaceTypeID,
61+
StringRef dialectName, TypeID objectID, TypeID interfaceTypeID,
6262
ObjectInterfaceAllocatorFunction allocator) {
6363
assert(allocator && "unexpected null interface allocation function");
64+
6465
auto it = registry.find(dialectName.str());
6566
assert(it != registry.end() &&
6667
"adding an interface for an op from an unregistered dialect");
6768

68-
auto &ifaces = interfaces[it->second.first];
69-
for (const auto &kvp : ifaces.objectInterfaces) {
70-
if (kvp.first == interfaceTypeID) {
69+
auto dialectID = it->second.first;
70+
auto &ifaces = interfaces[dialectID];
71+
72+
for (const auto &info : ifaces.objectInterfaces) {
73+
if (std::get<0>(info) == objectID && std::get<1>(info) == interfaceTypeID) {
7174
LLVM_DEBUG(llvm::dbgs()
7275
<< "[" DEBUG_TYPE
7376
"] repeated interface object interface registration");
7477
return;
7578
}
7679
}
7780

78-
ifaces.objectInterfaces.emplace_back(interfaceTypeID, allocator);
81+
ifaces.objectInterfaces.emplace_back(objectID, interfaceTypeID, allocator);
7982
}
8083

8184
DialectAllocatorFunctionRef
@@ -110,8 +113,8 @@ void DialectRegistry::registerDelayedInterfaces(Dialect *dialect) const {
110113
}
111114

112115
// Add attribute, operation and type interfaces.
113-
for (const auto &kvp : it->getSecond().objectInterfaces)
114-
kvp.second(dialect->getContext());
116+
for (const auto &info : it->getSecond().objectInterfaces)
117+
std::get<2>(info)(dialect->getContext());
115118
}
116119

117120
//===----------------------------------------------------------------------===//

mlir/unittests/IR/InterfaceAttachmentTest.cpp

Lines changed: 33 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -321,55 +321,78 @@ TEST(InterfaceAttachment, Operation) {
321321
ASSERT_FALSE(isa<TestExternalOpInterface>(otherModuleOp.getOperation()));
322322
}
323323

324+
template <class ConcreteOp>
324325
struct TestExternalTestOpModel
325-
: public TestExternalOpInterface::ExternalModel<TestExternalTestOpModel,
326-
test::OpJ> {
326+
: public TestExternalOpInterface::ExternalModel<
327+
TestExternalTestOpModel<ConcreteOp>, ConcreteOp> {
327328
unsigned getNameLengthPlusArg(Operation *op, unsigned arg) const {
328329
return op->getName().getStringRef().size() + arg;
329330
}
330331

331332
static unsigned getNameLengthPlusArgTwice(unsigned arg) {
332-
return test::OpJ::getOperationName().size() + 2 * arg;
333+
return ConcreteOp::getOperationName().size() + 2 * arg;
333334
}
334335
};
335336

336337
TEST(InterfaceAttachment, OperationDelayedContextConstruct) {
337338
DialectRegistry registry;
338339
registry.insert<test::TestDialect>();
339340
registry.addOpInterface<ModuleOp, TestExternalOpModel>();
340-
registry.addOpInterface<test::OpJ, TestExternalTestOpModel>();
341+
registry.addOpInterface<test::OpJ, TestExternalTestOpModel<test::OpJ>>();
342+
registry.addOpInterface<test::OpH, TestExternalTestOpModel<test::OpH>>();
341343

342344
// Construct the context directly from a registry. The interfaces are expected
343345
// to be readily available on operations.
344346
MLIRContext context(registry);
345347
context.loadDialect<test::TestDialect>();
348+
346349
ModuleOp module = ModuleOp::create(UnknownLoc::get(&context));
347350
OpBuilder builder(module);
348-
auto op =
351+
auto opJ =
349352
builder.create<test::OpJ>(builder.getUnknownLoc(), builder.getI32Type());
353+
auto opH =
354+
builder.create<test::OpH>(builder.getUnknownLoc(), opJ.getResult());
355+
auto opI =
356+
builder.create<test::OpI>(builder.getUnknownLoc(), opJ.getResult());
357+
350358
EXPECT_TRUE(isa<TestExternalOpInterface>(module.getOperation()));
351-
EXPECT_TRUE(isa<TestExternalOpInterface>(op.getOperation()));
359+
EXPECT_TRUE(isa<TestExternalOpInterface>(opJ.getOperation()));
360+
EXPECT_TRUE(isa<TestExternalOpInterface>(opH.getOperation()));
361+
EXPECT_FALSE(isa<TestExternalOpInterface>(opI.getOperation()));
352362
}
353363

354364
TEST(InterfaceAttachment, OperationDelayedContextAppend) {
355365
DialectRegistry registry;
356366
registry.insert<test::TestDialect>();
357367
registry.addOpInterface<ModuleOp, TestExternalOpModel>();
358-
registry.addOpInterface<test::OpJ, TestExternalTestOpModel>();
368+
registry.addOpInterface<test::OpJ, TestExternalTestOpModel<test::OpJ>>();
369+
registry.addOpInterface<test::OpH, TestExternalTestOpModel<test::OpH>>();
359370

360371
// Construct the context, create ops, and only then append the registry. The
361372
// interfaces are expected to be available after appending the registry.
362373
MLIRContext context;
363374
context.loadDialect<test::TestDialect>();
375+
364376
ModuleOp module = ModuleOp::create(UnknownLoc::get(&context));
365377
OpBuilder builder(module);
366-
auto op =
378+
auto opJ =
367379
builder.create<test::OpJ>(builder.getUnknownLoc(), builder.getI32Type());
380+
auto opH =
381+
builder.create<test::OpH>(builder.getUnknownLoc(), opJ.getResult());
382+
auto opI =
383+
builder.create<test::OpI>(builder.getUnknownLoc(), opJ.getResult());
384+
368385
EXPECT_FALSE(isa<TestExternalOpInterface>(module.getOperation()));
369-
EXPECT_FALSE(isa<TestExternalOpInterface>(op.getOperation()));
386+
EXPECT_FALSE(isa<TestExternalOpInterface>(opJ.getOperation()));
387+
EXPECT_FALSE(isa<TestExternalOpInterface>(opH.getOperation()));
388+
EXPECT_FALSE(isa<TestExternalOpInterface>(opI.getOperation()));
389+
370390
context.appendDialectRegistry(registry);
391+
371392
EXPECT_TRUE(isa<TestExternalOpInterface>(module.getOperation()));
372-
EXPECT_TRUE(isa<TestExternalOpInterface>(op.getOperation()));
393+
EXPECT_TRUE(isa<TestExternalOpInterface>(opJ.getOperation()));
394+
EXPECT_TRUE(isa<TestExternalOpInterface>(opH.getOperation()));
395+
EXPECT_FALSE(isa<TestExternalOpInterface>(opI.getOperation()));
373396
}
374397

375398
} // end namespace

0 commit comments

Comments
 (0)