Skip to content

[mlir] Add convertInstruction and getSupportedInstructions to LLVMImportInterface #86799

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Apr 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 63 additions & 6 deletions mlir/include/mlir/Target/LLVMIR/LLVMImportInterface.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,15 @@ class LLVMImportDialectInterface
return failure();
}

/// Hook for derived dialect interfaces to implement the import of
/// instructions into MLIR.
virtual LogicalResult
convertInstruction(OpBuilder &builder, llvm::Instruction *inst,
ArrayRef<llvm::Value *> llvmOperands,
LLVM::ModuleImport &moduleImport) const {
return failure();
}

/// Hook for derived dialect interfaces to implement the import of metadata
/// into MLIR. Attaches the converted metadata kind and node to the provided
/// operation.
Expand All @@ -66,6 +75,14 @@ class LLVMImportDialectInterface
/// returns the list of supported intrinsic identifiers.
virtual ArrayRef<unsigned> getSupportedIntrinsics() const { return {}; }

/// Hook for derived dialect interfaces to publish the supported instructions.
/// As every LLVM IR instruction has a unique integer identifier, the function
/// returns the list of supported instruction identifiers. These identifiers
/// will then be used to match LLVM instructions to the appropriate import
/// interface and `convertInstruction` method. It is an error to have multiple
/// interfaces overriding the same instruction.
virtual ArrayRef<unsigned> getSupportedInstructions() const { return {}; }
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What if multiple dialect return overlapping results here?

Also the documentation is saying what this does, but not why it is useful or how this should be used, can you expand?

Copy link
Contributor Author

@fabianmcg fabianmcg Mar 27, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If multiple dialects try to register the same instruction the interface emits an error and aborts the entire translation, see the initializeImport method.
Small caveat: this mechanism doesn't check if it's colliding with LLVM tablegen conversions, only with respect to other interfaces -which I'd argue is okay, bc allows overriding LLVM behavior.

With respect to how should be used, I added a test where:

%1 = load double, ptr %ptr 

gets imported as:

 %1  = llvm.load %pt : !llvm.ptr -> f64
 %2 = "test.same_operand_element_type"(%1, %1) : 

The example shows how to use the mechanism.
I'll add to the docs that this allows overriding LLVM behavior.


/// Hook for derived dialect interfaces to publish the supported metadata
/// kinds. As every metadata kind has a unique integer identifier, the
/// function returns the list of supported metadata identifiers.
Expand All @@ -88,21 +105,40 @@ class LLVMImportInterface
LogicalResult initializeImport() {
for (const LLVMImportDialectInterface &iface : *this) {
// Verify the supported intrinsics have not been mapped before.
const auto *it =
const auto *intrinsicIt =
llvm::find_if(iface.getSupportedIntrinsics(), [&](unsigned id) {
return intrinsicToDialect.count(id);
});
if (it != iface.getSupportedIntrinsics().end()) {
if (intrinsicIt != iface.getSupportedIntrinsics().end()) {
return emitError(
UnknownLoc::get(iface.getContext()),
llvm::formatv(
"expected unique conversion for intrinsic ({0}), but "
"got conflicting {1} and {2} conversions",
*intrinsicIt, iface.getDialect()->getNamespace(),
intrinsicToDialect.lookup(*intrinsicIt)->getNamespace()));
}
const auto *instructionIt =
llvm::find_if(iface.getSupportedInstructions(), [&](unsigned id) {
return instructionToDialect.count(id);
});
if (instructionIt != iface.getSupportedInstructions().end()) {
return emitError(
UnknownLoc::get(iface.getContext()),
llvm::formatv("expected unique conversion for intrinsic ({0}), but "
"got conflicting {1} and {2} conversions",
*it, iface.getDialect()->getNamespace(),
intrinsicToDialect.lookup(*it)->getNamespace()));
llvm::formatv(
"expected unique conversion for instruction ({0}), but "
"got conflicting {1} and {2} conversions",
*intrinsicIt, iface.getDialect()->getNamespace(),
instructionToDialect.lookup(*intrinsicIt)
->getDialect()
->getNamespace()));
}
// Add a mapping for all supported intrinsic identifiers.
for (unsigned id : iface.getSupportedIntrinsics())
intrinsicToDialect[id] = iface.getDialect();
// Add a mapping for all supported instruction identifiers.
for (unsigned id : iface.getSupportedInstructions())
instructionToDialect[id] = &iface;
// Add a mapping for all supported metadata kinds.
for (unsigned kind : iface.getSupportedMetadata())
metadataToDialect[kind].push_back(iface.getDialect());
Expand Down Expand Up @@ -132,6 +168,26 @@ class LLVMImportInterface
return intrinsicToDialect.count(id);
}

/// Converts the LLVM instruction to an MLIR operation if a conversion exists.
/// Returns failure otherwise.
LogicalResult convertInstruction(OpBuilder &builder, llvm::Instruction *inst,
ArrayRef<llvm::Value *> llvmOperands,
LLVM::ModuleImport &moduleImport) const {
// Lookup the dialect interface for the given instruction.
const LLVMImportDialectInterface *iface =
instructionToDialect.lookup(inst->getOpcode());
if (!iface)
return failure();

return iface->convertInstruction(builder, inst, llvmOperands, moduleImport);
}

/// Returns true if the given LLVM IR instruction is convertible to an MLIR
/// operation.
bool isConvertibleInstruction(unsigned id) {
return instructionToDialect.count(id);
}

/// Attaches the given LLVM metadata to the imported operation if a conversion
/// to one or more MLIR dialect attributes exists and succeeds. Returns
/// success if at least one of the conversions is successful and failure if
Expand Down Expand Up @@ -166,6 +222,7 @@ class LLVMImportInterface

private:
DenseMap<unsigned, Dialect *> intrinsicToDialect;
DenseMap<unsigned, const LLVMImportDialectInterface *> instructionToDialect;
DenseMap<unsigned, SmallVector<Dialect *, 1>> metadataToDialect;
};

Expand Down
10 changes: 8 additions & 2 deletions mlir/lib/Target/LLVMIR/ModuleImport.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -123,12 +123,18 @@ static SmallVector<int64_t> getPositionFromIndices(ArrayRef<unsigned> indices) {
/// access to the private module import methods.
static LogicalResult convertInstructionImpl(OpBuilder &odsBuilder,
llvm::Instruction *inst,
ModuleImport &moduleImport) {
ModuleImport &moduleImport,
LLVMImportInterface &iface) {
// Copy the operands to an LLVM operands array reference for conversion.
SmallVector<llvm::Value *> operands(inst->operands());
ArrayRef<llvm::Value *> llvmOperands(operands);

// Convert all instructions that provide an MLIR builder.
if (iface.isConvertibleInstruction(inst->getOpcode()))
return iface.convertInstruction(odsBuilder, inst, llvmOperands,
moduleImport);
// TODO: Implement the `convertInstruction` hooks in the
// `LLVMDialectLLVMIRImportInterface` and move the following include there.
#include "mlir/Dialect/LLVMIR/LLVMOpFromLLVMIRConversions.inc"
return failure();
}
Expand Down Expand Up @@ -1596,7 +1602,7 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
}

// Convert all instructions that have an mlirBuilder.
if (succeeded(convertInstructionImpl(builder, inst, *this)))
if (succeeded(convertInstructionImpl(builder, inst, *this, iface)))
return success();

return emitError(loc) << "unhandled instruction: " << diag(*inst);
Expand Down
11 changes: 11 additions & 0 deletions mlir/test/Target/LLVMIR/Import/test.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
; RUN: mlir-translate -test-import-llvmir %s | FileCheck %s

; CHECK-LABEL: @custom_load
; CHECK-SAME: %[[PTR:[[:alnum:]]+]]
define double @custom_load(ptr %ptr) {
; CHECK: %[[LOAD:[0-9]+]] = llvm.load %[[PTR]] : !llvm.ptr -> f64
; CHECK: %[[TEST:[0-9]+]] = "test.same_operand_element_type"(%[[LOAD]], %[[LOAD]]) : (f64, f64) -> f64
%1 = load double, ptr %ptr
; CHECK: llvm.return %[[TEST]] : f64
ret double %1
}
18 changes: 18 additions & 0 deletions mlir/test/lib/Dialect/Test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ set(LLVM_OPTIONAL_SOURCES
TestDialect.cpp
TestPatterns.cpp
TestTraits.cpp
TestFromLLVMIRTranslation.cpp
TestToLLVMIRTranslation.cpp
)

Expand Down Expand Up @@ -86,6 +87,23 @@ add_mlir_library(MLIRTestDialect
MLIRTransforms
)

add_mlir_translation_library(MLIRTestFromLLVMIRTranslation
TestFromLLVMIRTranslation.cpp

EXCLUDE_FROM_LIBMLIR

LINK_COMPONENTS
Core

LINK_LIBS PUBLIC
MLIRIR
MLIRLLVMDialect
MLIRTestDialect
MLIRSupport
MLIRTargetLLVMIRImport
MLIRLLVMIRToLLVMTranslation
)

add_mlir_translation_library(MLIRTestToLLVMIRTranslation
TestToLLVMIRTranslation.cpp

Expand Down
111 changes: 111 additions & 0 deletions mlir/test/lib/Dialect/Test/TestFromLLVMIRTranslation.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
//===- TestFromLLVMIRTranslation.cpp - Import Test dialect from LLVM IR ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a translation between LLVM IR and the MLIR Test dialect.
//
//===----------------------------------------------------------------------===//

#include "TestDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.h"
#include "mlir/Target/LLVMIR/Import.h"
#include "mlir/Target/LLVMIR/ModuleImport.h"
#include "mlir/Tools/mlir-translate/Translation.h"

#include "llvm/IR/Instructions.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Verifier.h"
#include "llvm/IRReader/IRReader.h"
#include "llvm/Support/SourceMgr.h"

using namespace mlir;
using namespace test;

static ArrayRef<unsigned> getSupportedInstructionsImpl() {
static unsigned instructions[] = {llvm::Instruction::Load};
return instructions;
}

static LogicalResult convertLoad(OpBuilder &builder, llvm::Instruction *inst,
ArrayRef<llvm::Value *> llvmOperands,
LLVM::ModuleImport &moduleImport) {
FailureOr<Value> addr = moduleImport.convertValue(llvmOperands[0]);
if (failed(addr))
return failure();
// Create the LoadOp
Value loadOp = builder.create<LLVM::LoadOp>(
moduleImport.translateLoc(inst->getDebugLoc()),
moduleImport.convertType(inst->getType()), *addr);
moduleImport.mapValue(inst) = builder.create<SameOperandElementTypeOp>(
loadOp.getLoc(), loadOp.getType(), loadOp, loadOp);
return success();
}

namespace {
class TestDialectLLVMImportDialectInterface
: public LLVMImportDialectInterface {
public:
using LLVMImportDialectInterface::LLVMImportDialectInterface;

LogicalResult
convertInstruction(OpBuilder &builder, llvm::Instruction *inst,
ArrayRef<llvm::Value *> llvmOperands,
LLVM::ModuleImport &moduleImport) const override {
switch (inst->getOpcode()) {
case llvm::Instruction::Load:
return convertLoad(builder, inst, llvmOperands, moduleImport);
default:
break;
}
return failure();
}

ArrayRef<unsigned> getSupportedInstructions() const override {
return getSupportedInstructionsImpl();
}
};
} // namespace

namespace mlir {
void registerTestFromLLVMIR() {
TranslateToMLIRRegistration registration(
"test-import-llvmir", "test dialect from LLVM IR",
[](llvm::SourceMgr &sourceMgr,
MLIRContext *context) -> OwningOpRef<Operation *> {
llvm::SMDiagnostic err;
llvm::LLVMContext llvmContext;
std::unique_ptr<llvm::Module> llvmModule =
llvm::parseIR(*sourceMgr.getMemoryBuffer(sourceMgr.getMainFileID()),
err, llvmContext);
if (!llvmModule) {
std::string errStr;
llvm::raw_string_ostream errStream(errStr);
err.print(/*ProgName=*/"", errStream);
emitError(UnknownLoc::get(context)) << errStream.str();
return {};
}
if (llvm::verifyModule(*llvmModule, &llvm::errs()))
return nullptr;

return translateLLVMIRToModule(std::move(llvmModule), context, false);
},
[](DialectRegistry &registry) {
registry.insert<DLTIDialect>();
registry.insert<test::TestDialect>();
registerLLVMDialectImport(registry);
registry.addExtension(
+[](MLIRContext *ctx, test::TestDialect *dialect) {
dialect->addInterfaces<TestDialectLLVMImportDialectInterface>();
});
});
}
} // namespace mlir
2 changes: 2 additions & 0 deletions mlir/tools/mlir-translate/mlir-translate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ void registerTestRoundtripSPIRV();
void registerTestRoundtripDebugSPIRV();
#ifdef MLIR_INCLUDE_TESTS
void registerTestToLLVMIR();
void registerTestFromLLVMIR();
#endif
} // namespace mlir

Expand All @@ -31,6 +32,7 @@ static void registerTestTranslations() {
registerTestRoundtripDebugSPIRV();
#ifdef MLIR_INCLUDE_TESTS
registerTestToLLVMIR();
registerTestFromLLVMIR();
#endif
}

Expand Down