Skip to content

[mlir][gpu] Change GPU modules to globals #135478

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Apr 22, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 3 additions & 7 deletions mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

#include "mlir/ExecutionEngine/CRunnerUtils.h"

#include <stdio.h>
#include <cstdio>

#include "cuda.h"
#include "cuda_bf16.h"
Expand Down Expand Up @@ -56,14 +56,10 @@

thread_local static int32_t defaultDevice = 0;

const char *kDebugEnvironmentVariable = "MLIR_CUDA_DEBUG";

/// Helper method that checks environment value for debugging.
bool isDebugEnabled() {
static bool isInitialized = false;
static bool isEnabled = false;
if (!isInitialized)
isEnabled = getenv(kDebugEnvironmentVariable) != nullptr;
const char *kDebugEnvironmentVariable = "MLIR_CUDA_DEBUG";
static bool isEnabled = getenv(kDebugEnvironmentVariable) != nullptr;
return isEnabled;
}

Expand Down
258 changes: 128 additions & 130 deletions mlir/lib/Target/LLVMIR/Dialect/GPU/SelectObjectAttr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,13 @@
#include "mlir/Target/LLVMIR/Export.h"
#include "mlir/Target/LLVMIR/ModuleTranslation.h"

#include "llvm/ADT/ScopeExit.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Transforms/Utils/ModuleUtils.h"

using namespace mlir;

Expand All @@ -31,9 +33,13 @@ namespace {
class SelectObjectAttrImpl
: public gpu::OffloadingLLVMTranslationAttrInterface::FallbackModel<
SelectObjectAttrImpl> {
// Returns the selected object for embedding.
gpu::ObjectAttr getSelectedObject(gpu::BinaryOp op) const;

public:
// Translates a `gpu.binary`, embedding the binary into a host LLVM module as
// global binary string.
// global binary string which gets loaded/unloaded into a global module
// object through a global ctor/dtor.
LogicalResult embedBinary(Attribute attribute, Operation *operation,
llvm::IRBuilderBase &builder,
LLVM::ModuleTranslation &moduleTranslation) const;
Expand All @@ -45,23 +51,9 @@ class SelectObjectAttrImpl
Operation *binaryOperation,
llvm::IRBuilderBase &builder,
LLVM::ModuleTranslation &moduleTranslation) const;

// Returns the selected object for embedding.
gpu::ObjectAttr getSelectedObject(gpu::BinaryOp op) const;
};
// Returns an identifier for the global string holding the binary.
std::string getBinaryIdentifier(StringRef binaryName) {
return binaryName.str() + "_bin_cst";
}
} // namespace

void mlir::gpu::registerOffloadingLLVMTranslationInterfaceExternalModels(
DialectRegistry &registry) {
registry.addExtension(+[](MLIRContext *ctx, gpu::GPUDialect *dialect) {
SelectObjectAttr::attachInterface<SelectObjectAttrImpl>(*ctx);
});
}

gpu::ObjectAttr
SelectObjectAttrImpl::getSelectedObject(gpu::BinaryOp op) const {
ArrayRef<Attribute> objects = op.getObjectsAttr().getValue();
Expand Down Expand Up @@ -96,6 +88,94 @@ SelectObjectAttrImpl::getSelectedObject(gpu::BinaryOp op) const {
return mlir::dyn_cast<gpu::ObjectAttr>(objects[index]);
}

static Twine getModuleIdentifier(StringRef moduleName) {
return moduleName + "_module";
}

namespace llvm {
static LogicalResult embedBinaryImpl(StringRef moduleName,
gpu::ObjectAttr object, Module &module) {

// Embed the object as a global string.
// Add null for assembly output for JIT paths that expect null-terminated
// strings.
bool addNull = (object.getFormat() == gpu::CompilationTarget::Assembly);
StringRef serializedStr = object.getObject().getValue();
Constant *serializedCst =
ConstantDataArray::getString(module.getContext(), serializedStr, addNull);
GlobalVariable *serializedObj =
new GlobalVariable(module, serializedCst->getType(), true,
GlobalValue::LinkageTypes::InternalLinkage,
serializedCst, moduleName + "_binary");
serializedObj->setAlignment(MaybeAlign(8));
serializedObj->setUnnamedAddr(GlobalValue::UnnamedAddr::None);

// Default JIT optimization level.
auto optLevel = APInt::getZero(32);

if (DictionaryAttr objectProps = object.getProperties()) {
if (auto section = dyn_cast_or_null<StringAttr>(
objectProps.get(gpu::elfSectionName))) {
serializedObj->setSection(section.getValue());
}
// Check if there's an optimization level embedded in the object.
if (auto optAttr = dyn_cast_or_null<IntegerAttr>(objectProps.get("O")))
optLevel = optAttr.getValue();
}

IRBuilder<> builder(module.getContext());
auto i32Ty = builder.getInt32Ty();
auto i64Ty = builder.getInt64Ty();
auto ptrTy = builder.getPtrTy(0);
auto voidTy = builder.getVoidTy();

// Embed the module as a global object.
auto *modulePtr = new GlobalVariable(
module, ptrTy, /*isConstant=*/false, GlobalValue::InternalLinkage,
/*Initializer=*/ConstantPointerNull::get(ptrTy),
getModuleIdentifier(moduleName));

auto *loadFn = Function::Create(FunctionType::get(voidTy, /*IsVarArg=*/false),
GlobalValue::InternalLinkage,
moduleName + "_load", module);
loadFn->setSection(".text.startup");
auto *loadBlock = BasicBlock::Create(module.getContext(), "entry", loadFn);
builder.SetInsertPoint(loadBlock);
Value *moduleObj = [&] {
if (object.getFormat() == gpu::CompilationTarget::Assembly) {
FunctionCallee moduleLoadFn = module.getOrInsertFunction(
"mgpuModuleLoadJIT", FunctionType::get(ptrTy, {ptrTy, i32Ty}, false));
Constant *optValue = ConstantInt::get(i32Ty, optLevel);
return builder.CreateCall(moduleLoadFn, {serializedObj, optValue});
} else {
FunctionCallee moduleLoadFn = module.getOrInsertFunction(
"mgpuModuleLoad", FunctionType::get(ptrTy, {ptrTy, i64Ty}, false));
Constant *binarySize =
ConstantInt::get(i64Ty, serializedStr.size() + (addNull ? 1 : 0));
return builder.CreateCall(moduleLoadFn, {serializedObj, binarySize});
}
}();
builder.CreateStore(moduleObj, modulePtr);
builder.CreateRetVoid();
appendToGlobalCtors(module, loadFn, /*Priority=*/123);

auto *unloadFn = Function::Create(
FunctionType::get(voidTy, /*IsVarArg=*/false),
GlobalValue::InternalLinkage, moduleName + "_unload", module);
unloadFn->setSection(".text.startup");
auto *unloadBlock =
BasicBlock::Create(module.getContext(), "entry", unloadFn);
builder.SetInsertPoint(unloadBlock);
FunctionCallee moduleUnloadFn = module.getOrInsertFunction(
"mgpuModuleUnload", FunctionType::get(voidTy, ptrTy, false));
builder.CreateCall(moduleUnloadFn, builder.CreateLoad(ptrTy, modulePtr));
builder.CreateRetVoid();
appendToGlobalDtors(module, unloadFn, /*Priority=*/123);

return success();
}
} // namespace llvm

LogicalResult SelectObjectAttrImpl::embedBinary(
Attribute attribute, Operation *operation, llvm::IRBuilderBase &builder,
LLVM::ModuleTranslation &moduleTranslation) const {
Expand All @@ -113,29 +193,8 @@ LogicalResult SelectObjectAttrImpl::embedBinary(
if (!object)
return failure();

llvm::Module *module = moduleTranslation.getLLVMModule();

// Embed the object as a global string.
// Add null for assembly output for JIT paths that expect null-terminated
// strings.
bool addNull = (object.getFormat() == gpu::CompilationTarget::Assembly);
llvm::Constant *binary = llvm::ConstantDataArray::getString(
builder.getContext(), object.getObject().getValue(), addNull);
llvm::GlobalVariable *serializedObj =
new llvm::GlobalVariable(*module, binary->getType(), true,
llvm::GlobalValue::LinkageTypes::InternalLinkage,
binary, getBinaryIdentifier(op.getName()));

if (object.getProperties()) {
if (auto section = mlir::dyn_cast_or_null<mlir::StringAttr>(
object.getProperties().get(gpu::elfSectionName))) {
serializedObj->setSection(section.getValue());
}
}
serializedObj->setLinkage(llvm::GlobalValue::LinkageTypes::InternalLinkage);
serializedObj->setAlignment(llvm::MaybeAlign(8));
serializedObj->setUnnamedAddr(llvm::GlobalValue::UnnamedAddr::None);
return success();
return embedBinaryImpl(op.getName(), object,
*moduleTranslation.getLLVMModule());
}

namespace llvm {
Expand All @@ -153,15 +212,6 @@ class LaunchKernel {
// Get the module function callee.
FunctionCallee getModuleFunctionFn();

// Get the module load callee.
FunctionCallee getModuleLoadFn();

// Get the module load JIT callee.
FunctionCallee getModuleLoadJITFn();

// Get the module unload callee.
FunctionCallee getModuleUnloadFn();

// Get the stream create callee.
FunctionCallee getStreamCreateFn();

Expand Down Expand Up @@ -261,24 +311,6 @@ llvm::FunctionCallee llvm::LaunchKernel::getModuleFunctionFn() {
FunctionType::get(ptrTy, ArrayRef<Type *>({ptrTy, ptrTy}), false));
}

llvm::FunctionCallee llvm::LaunchKernel::getModuleLoadFn() {
return module.getOrInsertFunction(
"mgpuModuleLoad",
FunctionType::get(ptrTy, ArrayRef<Type *>({ptrTy, i64Ty}), false));
}

llvm::FunctionCallee llvm::LaunchKernel::getModuleLoadJITFn() {
return module.getOrInsertFunction(
"mgpuModuleLoadJIT",
FunctionType::get(ptrTy, ArrayRef<Type *>({ptrTy, i32Ty}), false));
}

llvm::FunctionCallee llvm::LaunchKernel::getModuleUnloadFn() {
return module.getOrInsertFunction(
"mgpuModuleUnload",
FunctionType::get(voidTy, ArrayRef<Type *>({ptrTy}), false));
}

llvm::FunctionCallee llvm::LaunchKernel::getStreamCreateFn() {
return module.getOrInsertFunction("mgpuStreamCreate",
FunctionType::get(ptrTy, false));
Expand All @@ -301,9 +333,9 @@ llvm::FunctionCallee llvm::LaunchKernel::getStreamSyncFn() {
llvm::Value *llvm::LaunchKernel::getOrCreateFunctionName(StringRef moduleName,
StringRef kernelName) {
std::string globalName =
std::string(formatv("{0}_{1}_kernel_name", moduleName, kernelName));
std::string(formatv("{0}_{1}_name", moduleName, kernelName));

if (GlobalVariable *gv = module.getGlobalVariable(globalName))
if (GlobalVariable *gv = module.getGlobalVariable(globalName, true))
return gv;

return builder.CreateGlobalString(kernelName, globalName);
Expand Down Expand Up @@ -346,16 +378,13 @@ llvm::LaunchKernel::createKernelArgArray(mlir::gpu::LaunchFuncOp op) {
}

// Emits LLVM IR to launch a kernel function:
// %0 = call %binarygetter
// %1 = call %moduleLoad(%0)
// %2 = <see generateKernelNameConstant>
// %3 = call %moduleGetFunction(%1, %2)
// %4 = call %streamCreate()
// %5 = <see generateParamsArray>
// call %launchKernel(%3, <launchOp operands 0..5>, 0, %4, %5, nullptr)
// call %streamSynchronize(%4)
// call %streamDestroy(%4)
// call %moduleUnload(%1)
// %1 = load %global_module_object
// %2 = call @mgpuModuleGetFunction(%1, %global_kernel_name)
// %3 = call @mgpuStreamCreate()
// %4 = <see createKernelArgArray()>
// call @mgpuLaunchKernel(%2, ..., %3, %4, ...)
// call @mgpuStreamSynchronize(%3)
// call @mgpuStreamDestroy(%3)
llvm::LogicalResult
llvm::LaunchKernel::createKernelLaunch(mlir::gpu::LaunchFuncOp op,
mlir::gpu::ObjectAttr object) {
Expand Down Expand Up @@ -385,58 +414,29 @@ llvm::LaunchKernel::createKernelLaunch(mlir::gpu::LaunchFuncOp op,
// Create the argument array.
Value *argArray = createKernelArgArray(op);

// Default JIT optimization level.
llvm::Constant *optV = llvm::ConstantInt::get(i32Ty, 0);
// Check if there's an optimization level embedded in the object.
DictionaryAttr objectProps = object.getProperties();
mlir::Attribute optAttr;
if (objectProps && (optAttr = objectProps.get("O"))) {
auto optLevel = dyn_cast<IntegerAttr>(optAttr);
if (!optLevel)
return op.emitError("the optimization level must be an integer");
optV = llvm::ConstantInt::get(i32Ty, optLevel.getValue());
}

// Load the kernel module.
StringRef moduleName = op.getKernelModuleName().getValue();
std::string binaryIdentifier = getBinaryIdentifier(moduleName);
Value *binary = module.getGlobalVariable(binaryIdentifier, true);
if (!binary)
return op.emitError() << "Couldn't find the binary: " << binaryIdentifier;

auto binaryVar = dyn_cast<llvm::GlobalVariable>(binary);
if (!binaryVar)
return op.emitError() << "Binary is not a global variable: "
<< binaryIdentifier;
llvm::Constant *binaryInit = binaryVar->getInitializer();
auto binaryDataSeq =
dyn_cast_if_present<llvm::ConstantDataSequential>(binaryInit);
if (!binaryDataSeq)
return op.emitError() << "Couldn't find binary data array: "
<< binaryIdentifier;
llvm::Constant *binarySize =
llvm::ConstantInt::get(i64Ty, binaryDataSeq->getNumElements() *
binaryDataSeq->getElementByteSize());

Value *moduleObject =
object.getFormat() == gpu::CompilationTarget::Assembly
? builder.CreateCall(getModuleLoadJITFn(), {binary, optV})
: builder.CreateCall(getModuleLoadFn(), {binary, binarySize});

// Load the kernel function.
Value *moduleFunction = builder.CreateCall(
getModuleFunctionFn(),
{moduleObject,
getOrCreateFunctionName(moduleName, op.getKernelName().getValue())});
StringRef moduleName = op.getKernelModuleName().getValue();
Twine moduleIdentifier = getModuleIdentifier(moduleName);
Value *modulePtr = module.getGlobalVariable(moduleIdentifier.str(), true);
if (!modulePtr)
return op.emitError() << "Couldn't find the binary: " << moduleIdentifier;
Value *moduleObj = builder.CreateLoad(ptrTy, modulePtr);
Value *functionName = getOrCreateFunctionName(moduleName, op.getKernelName());
Value *moduleFunction =
builder.CreateCall(getModuleFunctionFn(), {moduleObj, functionName});

// Get the stream to use for execution. If there's no async object then create
// a stream to make a synchronous kernel launch.
Value *stream = nullptr;
bool handleStream = false;
// Sync & destroy the stream, for synchronous launches.
auto destroyStream = make_scope_exit([&]() {
builder.CreateCall(getStreamSyncFn(), {stream});
builder.CreateCall(getStreamDestroyFn(), {stream});
});
if (mlir::Value asyncObject = op.getAsyncObject()) {
stream = llvmValue(asyncObject);
destroyStream.release();
} else {
handleStream = true;
stream = builder.CreateCall(getStreamCreateFn(), {});
}

Expand All @@ -462,14 +462,12 @@ llvm::LaunchKernel::createKernelLaunch(mlir::gpu::LaunchFuncOp op,
argArray, nullPtr, paramsCount}));
}

// Sync & destroy the stream, for synchronous launches.
if (handleStream) {
builder.CreateCall(getStreamSyncFn(), {stream});
builder.CreateCall(getStreamDestroyFn(), {stream});
}

// Unload the kernel module.
builder.CreateCall(getModuleUnloadFn(), {moduleObject});

return success();
}

void mlir::gpu::registerOffloadingLLVMTranslationInterfaceExternalModels(
DialectRegistry &registry) {
registry.addExtension(+[](MLIRContext *ctx, gpu::GPUDialect *dialect) {
SelectObjectAttr::attachInterface<SelectObjectAttrImpl>(*ctx);
});
}
Loading