Skip to content

Commit a2c4b7c

Browse files
authored
[mlir] Add convertInstruction and getSupportedInstructions to LLVMImportInterface (#86799)
This patch adds the `convertInstruction` and `getSupportedInstructions` to `LLVMImportInterface`, allowing any non-LLVM dialect to specify how to import LLVM IR instructions and overriding the default import of LLVM instructions.
1 parent ccc0256 commit a2c4b7c

File tree

6 files changed

+213
-8
lines changed

6 files changed

+213
-8
lines changed

mlir/include/mlir/Target/LLVMIR/LLVMImportInterface.h

Lines changed: 63 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,15 @@ class LLVMImportDialectInterface
5252
return failure();
5353
}
5454

55+
/// Hook for derived dialect interfaces to implement the import of
56+
/// instructions into MLIR.
57+
virtual LogicalResult
58+
convertInstruction(OpBuilder &builder, llvm::Instruction *inst,
59+
ArrayRef<llvm::Value *> llvmOperands,
60+
LLVM::ModuleImport &moduleImport) const {
61+
return failure();
62+
}
63+
5564
/// Hook for derived dialect interfaces to implement the import of metadata
5665
/// into MLIR. Attaches the converted metadata kind and node to the provided
5766
/// operation.
@@ -66,6 +75,14 @@ class LLVMImportDialectInterface
6675
/// returns the list of supported intrinsic identifiers.
6776
virtual ArrayRef<unsigned> getSupportedIntrinsics() const { return {}; }
6877

78+
/// Hook for derived dialect interfaces to publish the supported instructions.
79+
/// As every LLVM IR instruction has a unique integer identifier, the function
80+
/// returns the list of supported instruction identifiers. These identifiers
81+
/// will then be used to match LLVM instructions to the appropriate import
82+
/// interface and `convertInstruction` method. It is an error to have multiple
83+
/// interfaces overriding the same instruction.
84+
virtual ArrayRef<unsigned> getSupportedInstructions() const { return {}; }
85+
6986
/// Hook for derived dialect interfaces to publish the supported metadata
7087
/// kinds. As every metadata kind has a unique integer identifier, the
7188
/// function returns the list of supported metadata identifiers.
@@ -88,21 +105,40 @@ class LLVMImportInterface
88105
LogicalResult initializeImport() {
89106
for (const LLVMImportDialectInterface &iface : *this) {
90107
// Verify the supported intrinsics have not been mapped before.
91-
const auto *it =
108+
const auto *intrinsicIt =
92109
llvm::find_if(iface.getSupportedIntrinsics(), [&](unsigned id) {
93110
return intrinsicToDialect.count(id);
94111
});
95-
if (it != iface.getSupportedIntrinsics().end()) {
112+
if (intrinsicIt != iface.getSupportedIntrinsics().end()) {
113+
return emitError(
114+
UnknownLoc::get(iface.getContext()),
115+
llvm::formatv(
116+
"expected unique conversion for intrinsic ({0}), but "
117+
"got conflicting {1} and {2} conversions",
118+
*intrinsicIt, iface.getDialect()->getNamespace(),
119+
intrinsicToDialect.lookup(*intrinsicIt)->getNamespace()));
120+
}
121+
const auto *instructionIt =
122+
llvm::find_if(iface.getSupportedInstructions(), [&](unsigned id) {
123+
return instructionToDialect.count(id);
124+
});
125+
if (instructionIt != iface.getSupportedInstructions().end()) {
96126
return emitError(
97127
UnknownLoc::get(iface.getContext()),
98-
llvm::formatv("expected unique conversion for intrinsic ({0}), but "
99-
"got conflicting {1} and {2} conversions",
100-
*it, iface.getDialect()->getNamespace(),
101-
intrinsicToDialect.lookup(*it)->getNamespace()));
128+
llvm::formatv(
129+
"expected unique conversion for instruction ({0}), but "
130+
"got conflicting {1} and {2} conversions",
131+
*intrinsicIt, iface.getDialect()->getNamespace(),
132+
instructionToDialect.lookup(*intrinsicIt)
133+
->getDialect()
134+
->getNamespace()));
102135
}
103136
// Add a mapping for all supported intrinsic identifiers.
104137
for (unsigned id : iface.getSupportedIntrinsics())
105138
intrinsicToDialect[id] = iface.getDialect();
139+
// Add a mapping for all supported instruction identifiers.
140+
for (unsigned id : iface.getSupportedInstructions())
141+
instructionToDialect[id] = &iface;
106142
// Add a mapping for all supported metadata kinds.
107143
for (unsigned kind : iface.getSupportedMetadata())
108144
metadataToDialect[kind].push_back(iface.getDialect());
@@ -132,6 +168,26 @@ class LLVMImportInterface
132168
return intrinsicToDialect.count(id);
133169
}
134170

171+
/// Converts the LLVM instruction to an MLIR operation if a conversion exists.
172+
/// Returns failure otherwise.
173+
LogicalResult convertInstruction(OpBuilder &builder, llvm::Instruction *inst,
174+
ArrayRef<llvm::Value *> llvmOperands,
175+
LLVM::ModuleImport &moduleImport) const {
176+
// Lookup the dialect interface for the given instruction.
177+
const LLVMImportDialectInterface *iface =
178+
instructionToDialect.lookup(inst->getOpcode());
179+
if (!iface)
180+
return failure();
181+
182+
return iface->convertInstruction(builder, inst, llvmOperands, moduleImport);
183+
}
184+
185+
/// Returns true if the given LLVM IR instruction is convertible to an MLIR
186+
/// operation.
187+
bool isConvertibleInstruction(unsigned id) {
188+
return instructionToDialect.count(id);
189+
}
190+
135191
/// Attaches the given LLVM metadata to the imported operation if a conversion
136192
/// to one or more MLIR dialect attributes exists and succeeds. Returns
137193
/// success if at least one of the conversions is successful and failure if
@@ -166,6 +222,7 @@ class LLVMImportInterface
166222

167223
private:
168224
DenseMap<unsigned, Dialect *> intrinsicToDialect;
225+
DenseMap<unsigned, const LLVMImportDialectInterface *> instructionToDialect;
169226
DenseMap<unsigned, SmallVector<Dialect *, 1>> metadataToDialect;
170227
};
171228

mlir/lib/Target/LLVMIR/ModuleImport.cpp

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -123,12 +123,18 @@ static SmallVector<int64_t> getPositionFromIndices(ArrayRef<unsigned> indices) {
123123
/// access to the private module import methods.
124124
static LogicalResult convertInstructionImpl(OpBuilder &odsBuilder,
125125
llvm::Instruction *inst,
126-
ModuleImport &moduleImport) {
126+
ModuleImport &moduleImport,
127+
LLVMImportInterface &iface) {
127128
// Copy the operands to an LLVM operands array reference for conversion.
128129
SmallVector<llvm::Value *> operands(inst->operands());
129130
ArrayRef<llvm::Value *> llvmOperands(operands);
130131

131132
// Convert all instructions that provide an MLIR builder.
133+
if (iface.isConvertibleInstruction(inst->getOpcode()))
134+
return iface.convertInstruction(odsBuilder, inst, llvmOperands,
135+
moduleImport);
136+
// TODO: Implement the `convertInstruction` hooks in the
137+
// `LLVMDialectLLVMIRImportInterface` and move the following include there.
132138
#include "mlir/Dialect/LLVMIR/LLVMOpFromLLVMIRConversions.inc"
133139
return failure();
134140
}
@@ -1596,7 +1602,7 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
15961602
}
15971603

15981604
// Convert all instructions that have an mlirBuilder.
1599-
if (succeeded(convertInstructionImpl(builder, inst, *this)))
1605+
if (succeeded(convertInstructionImpl(builder, inst, *this, iface)))
16001606
return success();
16011607

16021608
return emitError(loc) << "unhandled instruction: " << diag(*inst);
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
; RUN: mlir-translate -test-import-llvmir %s | FileCheck %s
2+
3+
; CHECK-LABEL: @custom_load
4+
; CHECK-SAME: %[[PTR:[[:alnum:]]+]]
5+
define double @custom_load(ptr %ptr) {
6+
; CHECK: %[[LOAD:[0-9]+]] = llvm.load %[[PTR]] : !llvm.ptr -> f64
7+
; CHECK: %[[TEST:[0-9]+]] = "test.same_operand_element_type"(%[[LOAD]], %[[LOAD]]) : (f64, f64) -> f64
8+
%1 = load double, ptr %ptr
9+
; CHECK: llvm.return %[[TEST]] : f64
10+
ret double %1
11+
}

mlir/test/lib/Dialect/Test/CMakeLists.txt

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ set(LLVM_OPTIONAL_SOURCES
22
TestDialect.cpp
33
TestPatterns.cpp
44
TestTraits.cpp
5+
TestFromLLVMIRTranslation.cpp
56
TestToLLVMIRTranslation.cpp
67
)
78

@@ -86,6 +87,23 @@ add_mlir_library(MLIRTestDialect
8687
MLIRTransforms
8788
)
8889

90+
add_mlir_translation_library(MLIRTestFromLLVMIRTranslation
91+
TestFromLLVMIRTranslation.cpp
92+
93+
EXCLUDE_FROM_LIBMLIR
94+
95+
LINK_COMPONENTS
96+
Core
97+
98+
LINK_LIBS PUBLIC
99+
MLIRIR
100+
MLIRLLVMDialect
101+
MLIRTestDialect
102+
MLIRSupport
103+
MLIRTargetLLVMIRImport
104+
MLIRLLVMIRToLLVMTranslation
105+
)
106+
89107
add_mlir_translation_library(MLIRTestToLLVMIRTranslation
90108
TestToLLVMIRTranslation.cpp
91109

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
//===- TestFromLLVMIRTranslation.cpp - Import Test dialect from LLVM IR ---===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file implements a translation between LLVM IR and the MLIR Test dialect.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#include "TestDialect.h"
14+
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
15+
#include "mlir/IR/Builders.h"
16+
#include "mlir/IR/BuiltinAttributes.h"
17+
#include "mlir/IR/BuiltinOps.h"
18+
#include "mlir/Support/LLVM.h"
19+
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.h"
20+
#include "mlir/Target/LLVMIR/Import.h"
21+
#include "mlir/Target/LLVMIR/ModuleImport.h"
22+
#include "mlir/Tools/mlir-translate/Translation.h"
23+
24+
#include "llvm/IR/Instructions.h"
25+
#include "llvm/IR/Module.h"
26+
#include "llvm/IR/Verifier.h"
27+
#include "llvm/IRReader/IRReader.h"
28+
#include "llvm/Support/SourceMgr.h"
29+
30+
using namespace mlir;
31+
using namespace test;
32+
33+
static ArrayRef<unsigned> getSupportedInstructionsImpl() {
34+
static unsigned instructions[] = {llvm::Instruction::Load};
35+
return instructions;
36+
}
37+
38+
static LogicalResult convertLoad(OpBuilder &builder, llvm::Instruction *inst,
39+
ArrayRef<llvm::Value *> llvmOperands,
40+
LLVM::ModuleImport &moduleImport) {
41+
FailureOr<Value> addr = moduleImport.convertValue(llvmOperands[0]);
42+
if (failed(addr))
43+
return failure();
44+
// Create the LoadOp
45+
Value loadOp = builder.create<LLVM::LoadOp>(
46+
moduleImport.translateLoc(inst->getDebugLoc()),
47+
moduleImport.convertType(inst->getType()), *addr);
48+
moduleImport.mapValue(inst) = builder.create<SameOperandElementTypeOp>(
49+
loadOp.getLoc(), loadOp.getType(), loadOp, loadOp);
50+
return success();
51+
}
52+
53+
namespace {
54+
class TestDialectLLVMImportDialectInterface
55+
: public LLVMImportDialectInterface {
56+
public:
57+
using LLVMImportDialectInterface::LLVMImportDialectInterface;
58+
59+
LogicalResult
60+
convertInstruction(OpBuilder &builder, llvm::Instruction *inst,
61+
ArrayRef<llvm::Value *> llvmOperands,
62+
LLVM::ModuleImport &moduleImport) const override {
63+
switch (inst->getOpcode()) {
64+
case llvm::Instruction::Load:
65+
return convertLoad(builder, inst, llvmOperands, moduleImport);
66+
default:
67+
break;
68+
}
69+
return failure();
70+
}
71+
72+
ArrayRef<unsigned> getSupportedInstructions() const override {
73+
return getSupportedInstructionsImpl();
74+
}
75+
};
76+
} // namespace
77+
78+
namespace mlir {
79+
void registerTestFromLLVMIR() {
80+
TranslateToMLIRRegistration registration(
81+
"test-import-llvmir", "test dialect from LLVM IR",
82+
[](llvm::SourceMgr &sourceMgr,
83+
MLIRContext *context) -> OwningOpRef<Operation *> {
84+
llvm::SMDiagnostic err;
85+
llvm::LLVMContext llvmContext;
86+
std::unique_ptr<llvm::Module> llvmModule =
87+
llvm::parseIR(*sourceMgr.getMemoryBuffer(sourceMgr.getMainFileID()),
88+
err, llvmContext);
89+
if (!llvmModule) {
90+
std::string errStr;
91+
llvm::raw_string_ostream errStream(errStr);
92+
err.print(/*ProgName=*/"", errStream);
93+
emitError(UnknownLoc::get(context)) << errStream.str();
94+
return {};
95+
}
96+
if (llvm::verifyModule(*llvmModule, &llvm::errs()))
97+
return nullptr;
98+
99+
return translateLLVMIRToModule(std::move(llvmModule), context, false);
100+
},
101+
[](DialectRegistry &registry) {
102+
registry.insert<DLTIDialect>();
103+
registry.insert<test::TestDialect>();
104+
registerLLVMDialectImport(registry);
105+
registry.addExtension(
106+
+[](MLIRContext *ctx, test::TestDialect *dialect) {
107+
dialect->addInterfaces<TestDialectLLVMImportDialectInterface>();
108+
});
109+
});
110+
}
111+
} // namespace mlir

mlir/tools/mlir-translate/mlir-translate.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ void registerTestRoundtripSPIRV();
2323
void registerTestRoundtripDebugSPIRV();
2424
#ifdef MLIR_INCLUDE_TESTS
2525
void registerTestToLLVMIR();
26+
void registerTestFromLLVMIR();
2627
#endif
2728
} // namespace mlir
2829

@@ -31,6 +32,7 @@ static void registerTestTranslations() {
3132
registerTestRoundtripDebugSPIRV();
3233
#ifdef MLIR_INCLUDE_TESTS
3334
registerTestToLLVMIR();
35+
registerTestFromLLVMIR();
3436
#endif
3537
}
3638

0 commit comments

Comments
 (0)