Skip to content

Commit 5406834

Browse files
authored
[flang][cuda] Add cuf.register_module operation (#112971)
Add a new operation, `cuf.register_module`, that registers the fatbin and returns a module pointer; that pointer is then passed as an operand to `cuf.register_kernel`.
1 parent 85df281 commit 5406834

File tree

5 files changed

+36
-10
lines changed

5 files changed

+36
-10
lines changed

flang/include/flang/Optimizer/Dialect/CUF/CUFOps.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include "flang/Optimizer/Dialect/CUF/Attributes/CUFAttr.h"
1313
#include "flang/Optimizer/Dialect/CUF/CUFDialect.h"
1414
#include "flang/Optimizer/Dialect/FIRType.h"
15+
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
1516
#include "mlir/IR/OpDefinition.h"
1617

1718
#define GET_OP_CLASSES

flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ include "flang/Optimizer/Dialect/CUF/CUFDialect.td"
1818
include "flang/Optimizer/Dialect/CUF/Attributes/CUFAttr.td"
1919
include "flang/Optimizer/Dialect/FIRTypes.td"
2020
include "flang/Optimizer/Dialect/FIRAttr.td"
21+
include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
2122
include "mlir/Interfaces/LoopLikeInterface.td"
2223
include "mlir/IR/BuiltinAttributes.td"
2324

@@ -288,15 +289,30 @@ def cuf_KernelOp : cuf_Op<"kernel", [AttrSizedOperandSegments,
288289
let hasVerifier = 1;
289290
}
290291

292+
// Registers the CUDA module referenced by the symbol $name.
// The single result, $modulePtr, is an LLVM pointer; cuf.register_kernel
// takes a value of this kind as its $modulePtr operand.
// Textual form: cuf.register_module @name attr-dict -> <result type>
def cuf_RegisterModuleOp : cuf_Op<"register_module", []> {
293+
let summary = "Register a CUDA module";
294+
295+
let arguments = (ins
296+
SymbolRefAttr:$name
297+
);
298+
299+
let assemblyFormat = [{
300+
$name attr-dict `->` type($modulePtr)
301+
}];
302+
303+
let results = (outs LLVM_AnyPointer:$modulePtr);
304+
}
305+
291306
def cuf_RegisterKernelOp : cuf_Op<"register_kernel", []> {
292307
let summary = "Register a CUDA kernel";
293308

294309
let arguments = (ins
295-
SymbolRefAttr:$name
310+
SymbolRefAttr:$name,
311+
LLVM_AnyPointer:$modulePtr
296312
);
297313

298314
let assemblyFormat = [{
299-
$name attr-dict
315+
$name `(` $modulePtr `:` type($modulePtr) `)`attr-dict
300316
}];
301317

302318
let hasVerifier = 1;

flang/lib/Optimizer/Transforms/CUFAddConstructor.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,12 +62,15 @@ struct CUFAddConstructor
6262
// Register kernels
6363
auto gpuMod = symTab.lookup<mlir::gpu::GPUModuleOp>(cudaModName);
6464
if (gpuMod) {
65+
auto llvmPtrTy = mlir::LLVM::LLVMPointerType::get(ctx);
66+
auto registeredMod = builder.create<cuf::RegisterModuleOp>(
67+
loc, llvmPtrTy, mlir::SymbolRefAttr::get(ctx, gpuMod.getName()));
6568
for (auto func : gpuMod.getOps<mlir::gpu::GPUFuncOp>()) {
6669
if (func.isKernel()) {
6770
auto kernelName = mlir::SymbolRefAttr::get(
6871
builder.getStringAttr(cudaModName),
6972
{mlir::SymbolRefAttr::get(builder.getContext(), func.getName())});
70-
builder.create<cuf::RegisterKernelOp>(loc, kernelName);
73+
builder.create<cuf::RegisterKernelOp>(loc, kernelName, registeredMod);
7174
}
7275
}
7376
}

flang/test/Fir/CUDA/cuda-register-func.fir

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,5 +12,6 @@ module attributes {gpu.container_module} {
1212
}
1313

1414
// CHECK-LABEL: llvm.func internal @__cudaFortranConstructor()
15-
// CHECK: cuf.register_kernel @cuda_device_mod::@_QPsub_device1
16-
// CHECK: cuf.register_kernel @cuda_device_mod::@_QPsub_device2
15+
// CHECK: %[[MOD_HANDLE:.*]] = cuf.register_module @cuda_device_mod -> !llvm.ptr
16+
// CHECK: cuf.register_kernel @cuda_device_mod::@_QPsub_device1(%[[MOD_HANDLE]] : !llvm.ptr)
17+
// CHECK: cuf.register_kernel @cuda_device_mod::@_QPsub_device2(%[[MOD_HANDLE]] : !llvm.ptr)

flang/test/Fir/cuf-invalid.fir

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -135,8 +135,9 @@ module attributes {gpu.container_module} {
135135
}
136136
}
137137
llvm.func internal @__cudaFortranConstructor() {
138+
%0 = cuf.register_module @cuda_device_mod -> !llvm.ptr
138139
// expected-error@+1{{'cuf.register_kernel' op only kernel gpu.func can be registered}}
139-
cuf.register_kernel @cuda_device_mod::@_QPsub_device1
140+
cuf.register_kernel @cuda_device_mod::@_QPsub_device1(%0 : !llvm.ptr)
140141
llvm.return
141142
}
142143
}
@@ -150,8 +151,9 @@ module attributes {gpu.container_module} {
150151
}
151152
}
152153
llvm.func internal @__cudaFortranConstructor() {
154+
%0 = cuf.register_module @cuda_device_mod -> !llvm.ptr
153155
// expected-error@+1{{'cuf.register_kernel' op device function not found}}
154-
cuf.register_kernel @cuda_device_mod::@_QPsub_device2
156+
cuf.register_kernel @cuda_device_mod::@_QPsub_device2(%0 : !llvm.ptr)
155157
llvm.return
156158
}
157159
}
@@ -160,8 +162,9 @@ module attributes {gpu.container_module} {
160162

161163
module attributes {gpu.container_module} {
162164
llvm.func internal @__cudaFortranConstructor() {
165+
%0 = cuf.register_module @cuda_device_mod -> !llvm.ptr
163166
// expected-error@+1{{'cuf.register_kernel' op gpu module not found}}
164-
cuf.register_kernel @cuda_device_mod::@_QPsub_device1
167+
cuf.register_kernel @cuda_device_mod::@_QPsub_device1(%0 : !llvm.ptr)
165168
llvm.return
166169
}
167170
}
@@ -170,8 +173,9 @@ module attributes {gpu.container_module} {
170173

171174
module attributes {gpu.container_module} {
172175
llvm.func internal @__cudaFortranConstructor() {
176+
%0 = cuf.register_module @cuda_device_mod -> !llvm.ptr
173177
// expected-error@+1{{'cuf.register_kernel' op expect a module and a kernel name}}
174-
cuf.register_kernel @_QPsub_device1
178+
cuf.register_kernel @_QPsub_device1(%0 : !llvm.ptr)
175179
llvm.return
176180
}
177181
}
@@ -185,8 +189,9 @@ module attributes {gpu.container_module} {
185189
}
186190
}
187191
llvm.func internal @__cudaFortranConstructor() {
192+
%0 = cuf.register_module @cuda_device_mod -> !llvm.ptr
188193
// expected-error@+1{{'cuf.register_kernel' op only gpu.kernel llvm.func can be registered}}
189-
cuf.register_kernel @cuda_device_mod::@_QPsub_device1
194+
cuf.register_kernel @cuda_device_mod::@_QPsub_device1(%0 : !llvm.ptr)
190195
llvm.return
191196
}
192197
}

0 commit comments

Comments
 (0)