Skip to content

Commit e2766b2

Browse files
authored
[flang][cuda] Add entry point to launch cuda fortran kernel (#113490)
1 parent 2dfb1c6 commit e2766b2

File tree

4 files changed

+62
-1
lines changed

4 files changed

+62
-1
lines changed

flang/include/flang/Runtime/CUDA/common.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ static constexpr unsigned kDeviceToDevice = 2;
3030
const char *name = cudaGetErrorName(err); \
3131
if (!name) \
3232
name = "<unknown>"; \
33-
Terminator terminator{__FILE__, __LINE__}; \
33+
Fortran::runtime::Terminator terminator{__FILE__, __LINE__}; \
3434
terminator.Crash("'%s' failed with '%s'", #expr, name); \
3535
}(expr)
3636

flang/include/flang/Runtime/CUDA/kernel.h

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
//===-- include/flang/Runtime/CUDA/kernel.h ---------------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef FORTRAN_RUNTIME_CUDA_KERNEL_H_
10+
#define FORTRAN_RUNTIME_CUDA_KERNEL_H_
11+
12+
#include "flang/Runtime/entry-names.h"
13+
#include <cstddef>
14+
#include <stdint.h>
15+
16+
extern "C" {
17+
18+
// This function uses intptr_t instead of CUDA's unsigned int to match
19+
// the type of MLIR's index type. This avoids the need for casts in the
20+
// generated MLIR code.
21+
void RTDEF(CUFLaunchKernel)(const void *kernelName, intptr_t gridX,
22+
intptr_t gridY, intptr_t gridZ, intptr_t blockX, intptr_t blockY,
23+
intptr_t blockZ, int32_t smem, void **params, void **extra);
24+
25+
} // extern "C"
26+
27+
#endif // FORTRAN_RUNTIME_CUDA_KERNEL_H_

flang/runtime/CUDA/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ add_flang_library(${CUFRT_LIBNAME}
1717
allocator.cpp
1818
allocatable.cpp
1919
descriptor.cpp
20+
kernel.cpp
2021
memory.cpp
2122
registration.cpp
2223
)

flang/runtime/CUDA/kernel.cpp

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
//===-- runtime/CUDA/kernel.cpp -------------------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "flang/Runtime/CUDA/kernel.h"
10+
#include "../terminator.h"
11+
#include "flang/Runtime/CUDA/common.h"
12+
13+
#include "cuda_runtime.h"
14+
15+
extern "C" {
16+
17+
// Launches `kernel` through cudaLaunchKernel on stream 0.
//
// The grid and block extents arrive as intptr_t and have to be narrowed to
// the unsigned components of dim3. An extent that does not survive that
// narrowing (negative, or wider than unsigned int) crashes the program
// instead of being silently truncated into a different launch configuration.
//
//   kernel                 passed unchanged to cudaLaunchKernel as the
//                          function to launch
//   gridX, gridY, gridZ    grid extents
//   blockX, blockY, blockZ block extents
//   smem                   forwarded as cudaLaunchKernel's shared-memory size
//   params                 forwarded as cudaLaunchKernel's argument array
//   extra                  not used
//
// A failing cudaLaunchKernel call is handled by CUDA_REPORT_IF_ERROR.
void RTDEF(CUFLaunchKernel)(const void *kernel, intptr_t gridX, intptr_t gridY,
    intptr_t gridZ, intptr_t blockX, intptr_t blockY, intptr_t blockZ,
    int32_t smem, void **params, void **extra) {
  // cudaLaunchKernel takes no argument that `extra` could be forwarded to.
  (void)extra;

  // Narrow one extent to unsigned, refusing values that would change when
  // converted back (i.e. values that the narrowing would truncate).
  auto toExtent = [](intptr_t value, const char *what) -> unsigned {
    unsigned narrowed = static_cast<unsigned>(value);
    if (static_cast<intptr_t>(narrowed) != value) {
      Fortran::runtime::Terminator terminator{__FILE__, __LINE__};
      terminator.Crash(
          "CUFLaunchKernel: %s extent %jd does not fit in an unsigned int",
          what, static_cast<intmax_t>(value));
    }
    return narrowed;
  };

  dim3 gridDim;
  gridDim.x = toExtent(gridX, "gridX");
  gridDim.y = toExtent(gridY, "gridY");
  gridDim.z = toExtent(gridZ, "gridZ");
  dim3 blockDim;
  blockDim.x = toExtent(blockX, "blockX");
  blockDim.y = toExtent(blockY, "blockY");
  blockDim.z = toExtent(blockZ, "blockZ");
  cudaStream_t stream = 0;
  CUDA_REPORT_IF_ERROR(
      cudaLaunchKernel(kernel, gridDim, blockDim, params, smem, stream));
}
32+
33+
} // extern "C"

0 commit comments

Comments
 (0)