Skip to content

[mlir] Add convertInstruction and getSupportedInstructions to LLVMImportInterface #86799

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Apr 7, 2024

Conversation

fabianmcg
Copy link
Contributor

This patch adds the convertInstruction and getSupportedInstructions to LLVMImportInterface, allowing any non-LLVM dialect to specify how to import LLVM IR instructions.

This patch is necessary for #73057

…VMImportInterface`

This patch adds the `convertInstruction` and `getSupportedInstructions` to
`LLVMImportInterface`, allowing any non-LLVM dialect to specify how to import
LLVM IR instructions.

This patch is necessary for llvm#73057
@fabianmcg fabianmcg marked this pull request as ready for review March 27, 2024 13:26
@fabianmcg fabianmcg requested review from gysit and zero9178 March 27, 2024 13:26
@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir:llvm mlir labels Mar 27, 2024
@llvmbot
Copy link
Member

llvmbot commented Mar 27, 2024

@llvm/pr-subscribers-mlir-llvm
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-core

Author: Fabian Mora (fabianmcg)

Changes

This patch adds the convertInstruction and getSupportedInstructions to LLVMImportInterface, allowing any non-LLVM dialect to specify how to import LLVM IR instructions.

This patch is necessary for #73057


Full diff: https://github.com/llvm/llvm-project/pull/86799.diff

6 Files Affected:

  • (modified) mlir/include/mlir/Target/LLVMIR/LLVMImportInterface.h (+53)
  • (modified) mlir/lib/Target/LLVMIR/ModuleImport.cpp (+6-2)
  • (added) mlir/test/Target/LLVMIR/Import/test.ll (+11)
  • (modified) mlir/test/lib/Dialect/Test/CMakeLists.txt (+18)
  • (added) mlir/test/lib/Dialect/Test/TestFromLLVMIRTranslation.cpp (+113)
  • (modified) mlir/tools/mlir-translate/mlir-translate.cpp (+2)
diff --git a/mlir/include/mlir/Target/LLVMIR/LLVMImportInterface.h b/mlir/include/mlir/Target/LLVMIR/LLVMImportInterface.h
index 9f8da83ae9c205..1bd81fcd9400cb 100644
--- a/mlir/include/mlir/Target/LLVMIR/LLVMImportInterface.h
+++ b/mlir/include/mlir/Target/LLVMIR/LLVMImportInterface.h
@@ -52,6 +52,15 @@ class LLVMImportDialectInterface
     return failure();
   }
 
+  /// Hook for derived dialect interfaces to implement the import of
+  /// instructions into MLIR.
+  virtual LogicalResult
+  convertInstruction(OpBuilder &builder, llvm::Instruction *inst,
+                     ArrayRef<llvm::Value *> llvmOperands,
+                     LLVM::ModuleImport &moduleImport) const {
+    return failure();
+  }
+
   /// Hook for derived dialect interfaces to implement the import of metadata
   /// into MLIR. Attaches the converted metadata kind and node to the provided
   /// operation.
@@ -66,6 +75,11 @@ class LLVMImportDialectInterface
   /// returns the list of supported intrinsic identifiers.
   virtual ArrayRef<unsigned> getSupportedIntrinsics() const { return {}; }
 
+  /// Hook for derived dialect interfaces to publish the supported instructions.
+  /// As every LLVM IR instructions has a unique integer identifier, the
+  /// function returns the list of supported instructions identifiers.
+  virtual ArrayRef<unsigned> getSupportedInstructions() const { return {}; }
+
   /// Hook for derived dialect interfaces to publish the supported metadata
   /// kinds. As every metadata kind has a unique integer identifier, the
   /// function returns the list of supported metadata identifiers.
@@ -100,9 +114,27 @@ class LLVMImportInterface
                           *it, iface.getDialect()->getNamespace(),
                           intrinsicToDialect.lookup(*it)->getNamespace()));
       }
+      const auto *instIt =
+          llvm::find_if(iface.getSupportedInstructions(), [&](unsigned id) {
+            return instructionToDialect.count(id);
+          });
+      if (instIt != iface.getSupportedInstructions().end()) {
+        return emitError(
+            UnknownLoc::get(iface.getContext()),
+            llvm::formatv(
+                "expected unique conversion for instruction ({0}), but "
+                "got conflicting {1} and {2} conversions",
+                *it, iface.getDialect()->getNamespace(),
+                instructionToDialect.lookup(*it)
+                    ->getDialect()
+                    ->getNamespace()));
+      }
       // Add a mapping for all supported intrinsic identifiers.
       for (unsigned id : iface.getSupportedIntrinsics())
         intrinsicToDialect[id] = iface.getDialect();
+      // Add a mapping for all supported instruction identifiers.
+      for (unsigned id : iface.getSupportedInstructions())
+        instructionToDialect[id] = &iface;
       // Add a mapping for all supported metadata kinds.
       for (unsigned kind : iface.getSupportedMetadata())
         metadataToDialect[kind].push_back(iface.getDialect());
@@ -132,6 +164,26 @@ class LLVMImportInterface
     return intrinsicToDialect.count(id);
   }
 
+  /// Converts the LLVM instruction to an MLIR operation if a conversion exists.
+  /// Returns failure otherwise.
+  LogicalResult convertInstruction(OpBuilder &builder, llvm::Instruction *inst,
+                                   ArrayRef<llvm::Value *> llvmOperands,
+                                   LLVM::ModuleImport &moduleImport) const {
+    // Lookup the dialect interface for the given instruction.
+    const LLVMImportDialectInterface *iface =
+        instructionToDialect.lookup(inst->getOpcode());
+    if (!iface)
+      return failure();
+
+    return iface->convertInstruction(builder, inst, llvmOperands, moduleImport);
+  }
+
+  /// Returns true if the given LLVM IR instruction is convertible to an MLIR
+  /// operation.
+  bool isConvertibleInstruction(unsigned id) {
+    return instructionToDialect.count(id);
+  }
+
   /// Attaches the given LLVM metadata to the imported operation if a conversion
   /// to one or more MLIR dialect attributes exists and succeeds. Returns
   /// success if at least one of the conversions is successful and failure if
@@ -166,6 +218,7 @@ class LLVMImportInterface
 
 private:
   DenseMap<unsigned, Dialect *> intrinsicToDialect;
+  DenseMap<unsigned, const LLVMImportDialectInterface *> instructionToDialect;
   DenseMap<unsigned, SmallVector<Dialect *, 1>> metadataToDialect;
 };
 
diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
index 6e70d52fa760b6..3320c6cd3ab24d 100644
--- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
@@ -123,12 +123,16 @@ static SmallVector<int64_t> getPositionFromIndices(ArrayRef<unsigned> indices) {
 /// access to the private module import methods.
 static LogicalResult convertInstructionImpl(OpBuilder &odsBuilder,
                                             llvm::Instruction *inst,
-                                            ModuleImport &moduleImport) {
+                                            ModuleImport &moduleImport,
+                                            LLVMImportInterface &importIface) {
   // Copy the operands to an LLVM operands array reference for conversion.
   SmallVector<llvm::Value *> operands(inst->operands());
   ArrayRef<llvm::Value *> llvmOperands(operands);
 
   // Convert all instructions that provide an MLIR builder.
+  if (importIface.isConvertibleInstruction(inst->getOpcode()))
+    return importIface.convertInstruction(odsBuilder, inst, llvmOperands,
+                                          moduleImport);
 #include "mlir/Dialect/LLVMIR/LLVMOpFromLLVMIRConversions.inc"
   return failure();
 }
@@ -1596,7 +1600,7 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
   }
 
   // Convert all instructions that have an mlirBuilder.
-  if (succeeded(convertInstructionImpl(builder, inst, *this)))
+  if (succeeded(convertInstructionImpl(builder, inst, *this, iface)))
     return success();
 
   return emitError(loc) << "unhandled instruction: " << diag(*inst);
diff --git a/mlir/test/Target/LLVMIR/Import/test.ll b/mlir/test/Target/LLVMIR/Import/test.ll
new file mode 100644
index 00000000000000..6f3dd1acf9586d
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/Import/test.ll
@@ -0,0 +1,11 @@
+; RUN: mlir-translate -test-import-llvmir %s | FileCheck %s
+
+; CHECK-LABEL: @custom_load
+; CHECK-SAME:  %[[PTR:[a-zA-Z0-9]+]]
+define double @custom_load(ptr %ptr) {
+  ; CHECK:  %[[LOAD:[0-9]+]] = llvm.load %[[PTR]] : !llvm.ptr -> f64
+  ; CHECK:  %[[TEST:[0-9]+]] = "test.same_operand_element_type"(%[[LOAD]], %[[LOAD]]) : (f64, f64) -> f64
+  %1 = load double, ptr %ptr
+  ; CHECK:   llvm.return %[[TEST]] : f64
+  ret double %1
+}
diff --git a/mlir/test/lib/Dialect/Test/CMakeLists.txt b/mlir/test/lib/Dialect/Test/CMakeLists.txt
index b82b1631eead59..47ddcf6524748c 100644
--- a/mlir/test/lib/Dialect/Test/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Test/CMakeLists.txt
@@ -2,6 +2,7 @@ set(LLVM_OPTIONAL_SOURCES
   TestDialect.cpp
   TestPatterns.cpp
   TestTraits.cpp
+  TestFromLLVMIRTranslation.cpp
   TestToLLVMIRTranslation.cpp
 )
 
@@ -86,6 +87,23 @@ add_mlir_library(MLIRTestDialect
   MLIRTransforms
 )
 
+add_mlir_translation_library(MLIRTestFromLLVMIRTranslation
+  TestFromLLVMIRTranslation.cpp
+
+  EXCLUDE_FROM_LIBMLIR
+
+  LINK_COMPONENTS
+  Core
+
+  LINK_LIBS PUBLIC
+  MLIRIR
+  MLIRLLVMDialect
+  MLIRTestDialect
+  MLIRSupport
+  MLIRTargetLLVMIRImport
+  MLIRLLVMIRToLLVMTranslation
+)
+
 add_mlir_translation_library(MLIRTestToLLVMIRTranslation
   TestToLLVMIRTranslation.cpp
 
diff --git a/mlir/test/lib/Dialect/Test/TestFromLLVMIRTranslation.cpp b/mlir/test/lib/Dialect/Test/TestFromLLVMIRTranslation.cpp
new file mode 100644
index 00000000000000..1ecbb5eb445060
--- /dev/null
+++ b/mlir/test/lib/Dialect/Test/TestFromLLVMIRTranslation.cpp
@@ -0,0 +1,113 @@
+//===- TestFromLLVMIRTranslation.cpp - Import Test dialect from LLVM IR ---===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a translation between LLVM IR and the MLIR Test dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "TestDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.h"
+#include "mlir/Target/LLVMIR/Import.h"
+#include "mlir/Target/LLVMIR/ModuleImport.h"
+#include "mlir/Tools/mlir-translate/Translation.h"
+
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/Verifier.h"
+#include "llvm/IRReader/IRReader.h"
+#include "llvm/Support/SourceMgr.h"
+
+using namespace mlir;
+using namespace test;
+
+namespace {
+inline ArrayRef<unsigned> getSupportedInstructionsImpl() {
+  static unsigned instructions[] = {llvm::Instruction::Load};
+  return instructions;
+}
+
+LogicalResult convertLoad(OpBuilder &builder, llvm::Instruction *inst,
+                          ArrayRef<llvm::Value *> llvmOperands,
+                          LLVM::ModuleImport &moduleImport) {
+  FailureOr<Value> addr = moduleImport.convertValue(llvmOperands[0]);
+  if (failed(addr))
+    return failure();
+  auto *loadInst = cast<llvm::LoadInst>(inst);
+  unsigned alignment = loadInst->getAlign().value();
+  // Create the LoadOp
+  Value loadOp = builder.create<LLVM::LoadOp>(
+      moduleImport.translateLoc(inst->getDebugLoc()),
+      moduleImport.convertType(inst->getType()), *addr);
+  moduleImport.mapValue(inst) = builder.create<SameOperandElementTypeOp>(
+      loadOp.getLoc(), loadOp.getType(), loadOp, loadOp);
+  return success();
+}
+
+class TestDialectLLVMImportDialectInterface
+    : public LLVMImportDialectInterface {
+public:
+  using LLVMImportDialectInterface::LLVMImportDialectInterface;
+
+  LogicalResult
+  convertInstruction(OpBuilder &builder, llvm::Instruction *inst,
+                     ArrayRef<llvm::Value *> llvmOperands,
+                     LLVM::ModuleImport &moduleImport) const override {
+    switch (inst->getOpcode()) {
+    case llvm::Instruction::Load:
+      return convertLoad(builder, inst, llvmOperands, moduleImport);
+    default:
+      break;
+    }
+    return failure();
+  }
+
+  ArrayRef<unsigned> getSupportedInstructions() const override {
+    return getSupportedInstructionsImpl();
+  }
+};
+} // namespace
+
+namespace mlir {
+void registerTestFromLLVMIR() {
+  TranslateToMLIRRegistration registration(
+      "test-import-llvmir", "test dialect from LLVM IR",
+      [](llvm::SourceMgr &sourceMgr,
+         MLIRContext *context) -> OwningOpRef<Operation *> {
+        llvm::SMDiagnostic err;
+        llvm::LLVMContext llvmContext;
+        std::unique_ptr<llvm::Module> llvmModule =
+            llvm::parseIR(*sourceMgr.getMemoryBuffer(sourceMgr.getMainFileID()),
+                          err, llvmContext);
+        if (!llvmModule) {
+          std::string errStr;
+          llvm::raw_string_ostream errStream(errStr);
+          err.print(/*ProgName=*/"", errStream);
+          emitError(UnknownLoc::get(context)) << errStream.str();
+          return {};
+        }
+        if (llvm::verifyModule(*llvmModule, &llvm::errs()))
+          return nullptr;
+
+        return translateLLVMIRToModule(std::move(llvmModule), context, false);
+      },
+      [](DialectRegistry &registry) {
+        registry.insert<DLTIDialect>();
+        registry.insert<test::TestDialect>();
+        registerLLVMDialectImport(registry);
+        registry.addExtension(
+            +[](MLIRContext *ctx, test::TestDialect *dialect) {
+              dialect->addInterfaces<TestDialectLLVMImportDialectInterface>();
+            });
+      });
+}
+} // namespace mlir
diff --git a/mlir/tools/mlir-translate/mlir-translate.cpp b/mlir/tools/mlir-translate/mlir-translate.cpp
index 4f9661c058c2d3..309def888a073c 100644
--- a/mlir/tools/mlir-translate/mlir-translate.cpp
+++ b/mlir/tools/mlir-translate/mlir-translate.cpp
@@ -23,6 +23,7 @@ void registerTestRoundtripSPIRV();
 void registerTestRoundtripDebugSPIRV();
 #ifdef MLIR_INCLUDE_TESTS
 void registerTestToLLVMIR();
+void registerTestFromLLVMIR();
 #endif
 } // namespace mlir
 
@@ -31,6 +32,7 @@ static void registerTestTranslations() {
   registerTestRoundtripDebugSPIRV();
 #ifdef MLIR_INCLUDE_TESTS
   registerTestToLLVMIR();
+  registerTestFromLLVMIR();
 #endif
 }
 

@fabianmcg fabianmcg requested a review from Dinistro March 27, 2024 13:39
Copy link
Contributor

@Dinistro Dinistro left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for this extension. This will make a lot of sense for you pointer dialect change, and will allow us to gradually transition into using it.

Once the transition period is over, we should consider to also add the same hook to the LLVM dialect import interface.

@Dinistro Dinistro requested a review from ftynse March 27, 2024 14:08
Copy link
Contributor

@Dinistro Dinistro left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, but please wait a bit for @gysit to take a look as well.

Copy link
Contributor

@gysit gysit left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice!

LGTM modulo one nit.

Comment on lines 127 to 128
*it, iface.getDialect()->getNamespace(),
instructionToDialect.lookup(*it)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
*it, iface.getDialect()->getNamespace(),
instructionToDialect.lookup(*it)
*instIt, iface.getDialect()->getNamespace(),
instructionToDialect.lookup(*instIt)

I believe this is the intrinsic iterator? Maybe rename the iterator above to intrIt or so that it is more obvious if things are mixed up?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I changed them to intrinsicIt and instructionIt, as instIt was already in use.

/// Hook for derived dialect interfaces to publish the supported instructions.
/// As every LLVM IR instructions has a unique integer identifier, the
/// function returns the list of supported instructions identifiers.
virtual ArrayRef<unsigned> getSupportedInstructions() const { return {}; }
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What if multiple dialect return overlapping results here?

Also the documentation is saying what this does, but not why it is useful or how this should be used, can you expand?

Copy link
Contributor Author

@fabianmcg fabianmcg Mar 27, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If multiple dialects try to register the same instruction the interface emits an error and aborts the entire translation, see the initializeImport method.
Small caveat: this mechanism doesn't check if it's colliding with LLVM tablegen conversions, only with respect to other interfaces -which I'd argue is okay, bc allows overriding LLVM behavior.

With respect to how should be used, I added a test where:

%1 = load double, ptr %ptr 

gets imported as:

 %1  = llvm.load %pt : !llvm.ptr -> f64
 %2 = "test.same_operand_element_type"(%1, %1) : 

The example shows how to use the mechanism.
I'll add to the docs that this allows overriding LLVM behavior.

@Dinistro
Copy link
Contributor

Dinistro commented Apr 4, 2024

Can we land this?

@fabianmcg
Copy link
Contributor Author

fabianmcg commented Apr 4, 2024

@gysit, @zero9178 ping

@gysit
Copy link
Contributor

gysit commented Apr 4, 2024

Yes this is good to go!

Thanks for adding this additional extensibility to the import.

Copy link
Member

@zero9178 zero9178 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Feel free to land!

@fabianmcg fabianmcg merged commit a2c4b7c into llvm:main Apr 7, 2024
@fabianmcg fabianmcg deleted the pr-module-import branch April 7, 2024 08:43
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:core MLIR Core Infrastructure mlir:llvm mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants