Skip to content

Commit d37bc32

Browse files
authored
[flang][cuda] Translate cuf.register_kernel and cuf.register_module (#112972)
Add LLVM IR Translation for `cuf.register_module` and `cuf.register_kernel`. These are lowered to function call to the CUF runtime entries.
1 parent 5406834 commit d37bc32

File tree

8 files changed

+197
-0
lines changed

8 files changed

+197
-0
lines changed
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
//===- CUFToLLVMIRTranslation.h - CUF Dialect to LLVM IR --------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This provides registration calls for GPU dialect to LLVM IR translation.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#ifndef FLANG_OPTIMIZER_DIALECT_CUF_GPUTOLLVMIRTRANSLATION_H_
14+
#define FLANG_OPTIMIZER_DIALECT_CUF_GPUTOLLVMIRTRANSLATION_H_
15+
16+
namespace mlir {
17+
class DialectRegistry;
18+
class MLIRContext;
19+
} // namespace mlir
20+
21+
namespace cuf {
22+
23+
/// Register the CUF dialect and the translation from it to the LLVM IR in
24+
/// the given registry.
25+
void registerCUFDialectTranslation(mlir::DialectRegistry &registry);
26+
27+
} // namespace cuf
28+
29+
#endif // FLANG_OPTIMIZER_DIALECT_CUF_GPUTOLLVMIRTRANSLATION_H_

flang/include/flang/Optimizer/Support/InitFIR.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#define FORTRAN_OPTIMIZER_SUPPORT_INITFIR_H
1515

1616
#include "flang/Optimizer/Dialect/CUF/CUFDialect.h"
17+
#include "flang/Optimizer/Dialect/CUF/CUFToLLVMIRTranslation.h"
1718
#include "flang/Optimizer/Dialect/FIRDialect.h"
1819
#include "flang/Optimizer/HLFIR/HLFIRDialect.h"
1920
#include "mlir/Conversion/Passes.h"
@@ -61,6 +62,7 @@ inline void addFIRExtensions(mlir::DialectRegistry &registry,
6162
if (addFIRInlinerInterface)
6263
addFIRInlinerExtension(registry);
6364
addFIRToLLVMIRExtension(registry);
65+
cuf::registerCUFDialectTranslation(registry);
6466
}
6567

6668
inline void loadNonCodegenDialects(mlir::MLIRContext &context) {
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
//===-- include/flang/Runtime/CUDA/registration.h ---------------*- C -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef FORTRAN_RUNTIME_CUDA_REGISTRATION_H_
10+
#define FORTRAN_RUNTIME_CUDA_REGISTRATION_H_
11+
12+
#include "flang/Runtime/entry-names.h"
13+
#include <cstddef>
14+
15+
namespace Fortran::runtime::cuda {
16+
17+
extern "C" {
18+
19+
/// Register a CUDA module.
20+
void *RTDECL(CUFRegisterModule)(void *data);
21+
22+
/// Register a device function.
23+
void RTDECL(CUFRegisterFunction)(void **module, const char *fct);
24+
25+
} // extern "C"
26+
27+
} // namespace Fortran::runtime::cuda
28+
#endif // FORTRAN_RUNTIME_CUDA_REGISTRATION_H_

flang/lib/Optimizer/Dialect/CUF/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ add_subdirectory(Attributes)
33
add_flang_library(CUFDialect
44
CUFDialect.cpp
55
CUFOps.cpp
6+
CUFToLLVMIRTranslation.cpp
67

78
DEPENDS
89
MLIRIR
Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
//===- CUFToLLVMIRTranslation.cpp - Translate CUF dialect to LLVM IR ------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file implements a translation between the MLIR CUF dialect and LLVM IR.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#include "flang/Optimizer/Dialect/CUF/CUFToLLVMIRTranslation.h"
14+
#include "flang/Optimizer/Dialect/CUF/CUFOps.h"
15+
#include "flang/Runtime/entry-names.h"
16+
#include "mlir/Target/LLVMIR/LLVMTranslationInterface.h"
17+
#include "mlir/Target/LLVMIR/ModuleTranslation.h"
18+
#include "llvm/ADT/TypeSwitch.h"
19+
#include "llvm/IR/IRBuilder.h"
20+
#include "llvm/IR/Module.h"
21+
#include "llvm/Support/FormatVariadic.h"
22+
23+
using namespace mlir;
24+
25+
namespace {
26+
27+
LogicalResult registerModule(cuf::RegisterModuleOp op,
28+
llvm::IRBuilderBase &builder,
29+
LLVM::ModuleTranslation &moduleTranslation) {
30+
std::string binaryIdentifier =
31+
op.getName().getLeafReference().str() + "_bin_cst";
32+
llvm::Module *module = moduleTranslation.getLLVMModule();
33+
llvm::Value *binary = module->getGlobalVariable(binaryIdentifier, true);
34+
if (!binary)
35+
return op.emitError() << "Couldn't find the binary: " << binaryIdentifier;
36+
37+
llvm::Type *ptrTy = builder.getPtrTy(0);
38+
llvm::FunctionCallee fct = module->getOrInsertFunction(
39+
RTNAME_STRING(CUFRegisterModule),
40+
llvm::FunctionType::get(ptrTy, ArrayRef<llvm::Type *>({ptrTy}), false));
41+
auto *handle = builder.CreateCall(fct, {binary});
42+
moduleTranslation.mapValue(op->getResults().front()) = handle;
43+
return mlir::success();
44+
}
45+
46+
llvm::Value *getOrCreateFunctionName(llvm::Module *module,
47+
llvm::IRBuilderBase &builder,
48+
llvm::StringRef moduleName,
49+
llvm::StringRef kernelName) {
50+
std::string globalName =
51+
std::string(llvm::formatv("{0}_{1}_kernel_name", moduleName, kernelName));
52+
53+
if (llvm::GlobalVariable *gv = module->getGlobalVariable(globalName))
54+
return gv;
55+
56+
return builder.CreateGlobalString(kernelName, globalName);
57+
}
58+
59+
LogicalResult registerKernel(cuf::RegisterKernelOp op,
60+
llvm::IRBuilderBase &builder,
61+
LLVM::ModuleTranslation &moduleTranslation) {
62+
llvm::Module *module = moduleTranslation.getLLVMModule();
63+
llvm::Type *ptrTy = builder.getPtrTy(0);
64+
llvm::FunctionCallee fct = module->getOrInsertFunction(
65+
RTNAME_STRING(CUFRegisterFunction),
66+
llvm::FunctionType::get(ptrTy, ArrayRef<llvm::Type *>({ptrTy, ptrTy}),
67+
false));
68+
llvm::Value *modulePtr = moduleTranslation.lookupValue(op.getModulePtr());
69+
builder.CreateCall(
70+
fct, {modulePtr, getOrCreateFunctionName(module, builder,
71+
op.getKernelModuleName().str(),
72+
op.getKernelName().str())});
73+
return mlir::success();
74+
}
75+
76+
class CUFDialectLLVMIRTranslationInterface
77+
: public LLVMTranslationDialectInterface {
78+
public:
79+
using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface;
80+
81+
LogicalResult
82+
convertOperation(Operation *operation, llvm::IRBuilderBase &builder,
83+
LLVM::ModuleTranslation &moduleTranslation) const override {
84+
return llvm::TypeSwitch<Operation *, LogicalResult>(operation)
85+
.Case([&](cuf::RegisterModuleOp op) {
86+
return registerModule(op, builder, moduleTranslation);
87+
})
88+
.Case([&](cuf::RegisterKernelOp op) {
89+
return registerKernel(op, builder, moduleTranslation);
90+
})
91+
.Default([&](Operation *op) {
92+
return op->emitError("unsupported GPU operation: ") << op->getName();
93+
});
94+
}
95+
};
96+
97+
} // namespace
98+
99+
void cuf::registerCUFDialectTranslation(DialectRegistry &registry) {
100+
registry.insert<cuf::CUFDialect>();
101+
registry.addExtension(+[](MLIRContext *ctx, cuf::CUFDialect *dialect) {
102+
dialect->addInterfaces<CUFDialectLLVMIRTranslationInterface>();
103+
});
104+
}

flang/lib/Optimizer/Transforms/CufOpConversion.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#include "flang/Runtime/CUDA/descriptor.h"
2121
#include "flang/Runtime/CUDA/memory.h"
2222
#include "flang/Runtime/allocatable.h"
23+
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
2324
#include "mlir/Pass/Pass.h"
2425
#include "mlir/Transforms/DialectConversion.h"
2526
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

flang/runtime/CUDA/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ add_flang_library(${CUFRT_LIBNAME}
1818
allocatable.cpp
1919
descriptor.cpp
2020
memory.cpp
21+
registration.cpp
2122
)
2223

2324
if (BUILD_SHARED_LIBS)

flang/runtime/CUDA/registration.cpp

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
//===-- runtime/CUDA/registration.cpp -------------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "flang/Runtime/CUDA/registration.h"
10+
11+
#include "cuda_runtime.h"
12+
13+
namespace Fortran::runtime::cuda {
14+
15+
extern "C" {
16+
17+
extern void **__cudaRegisterFatBinary(void *data);
18+
extern void __cudaRegisterFunction(void **fatCubinHandle, const char *hostFun,
19+
char *deviceFun, const char *deviceName, int thread_limit, uint3 *tid,
20+
uint3 *bid, dim3 *bDim, dim3 *gDim, int *wSize);
21+
22+
void *RTDECL(CUFRegisterModule)(void *data) {
23+
return __cudaRegisterFatBinary(data);
24+
}
25+
26+
void RTDEF(CUFRegisterFunction)(void **module, const char *fct) {
27+
__cudaRegisterFunction(module, fct, const_cast<char *>(fct), fct, -1,
28+
(uint3 *)0, (uint3 *)0, (dim3 *)0, (dim3 *)0, (int *)0);
29+
}
30+
}
31+
} // namespace Fortran::runtime::cuda

0 commit comments

Comments
 (0)