Skip to content

[flang][cuda] Compute grid x when calling a kernel with <<<*, block>>> #115538

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Nov 8, 2024

Conversation

clementval
Copy link
Contributor

-1, 1, 1 is passed when calling a kernel with the <<<*, block>>> syntax. Query the device to compute the grid.x value.

@clementval clementval requested a review from wangzpgi November 8, 2024 19:43
@llvmbot llvmbot added flang:runtime flang Flang issues not falling into any other category labels Nov 8, 2024
@llvmbot
Copy link
Member

llvmbot commented Nov 8, 2024

@llvm/pr-subscribers-flang-runtime

Author: Valentin Clement (バレンタイン クレメン) (clementval)

Changes

-1, 1, 1 is passed when calling a kernel with the &lt;&lt;&lt;*, block&gt;&gt;&gt; syntax. Query the device to compute the grid.x value.


Full diff: https://github.com/llvm/llvm-project/pull/115538.diff

1 Files Affected:

  • (modified) flang/runtime/CUDA/kernel.cpp (+46)
diff --git a/flang/runtime/CUDA/kernel.cpp b/flang/runtime/CUDA/kernel.cpp
index abb7ebb72e5923..8881d8a524aac0 100644
--- a/flang/runtime/CUDA/kernel.cpp
+++ b/flang/runtime/CUDA/kernel.cpp
@@ -25,6 +25,29 @@ void RTDEF(CUFLaunchKernel)(const void *kernel, intptr_t gridX, intptr_t gridY,
   blockDim.x = blockX;
   blockDim.y = blockY;
   blockDim.z = blockZ;
+  bool gridIsStar = (gridX < 0); // <<<*, block>>> syntax was used.
+  if (gridIsStar) {
+    int maxBlocks, nbBlocks, dev, multiProcCount;
+    cudaError_t err1, err2;
+    nbBlocks = blockDim.x * blockDim.y * blockDim.z;
+    cudaGetDevice(&dev);
+    err1 = cudaDeviceGetAttribute(
+        &multiProcCount, cudaDevAttrMultiProcessorCount, dev);
+    err2 = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
+        &maxBlocks, kernel, nbBlocks, smem);
+    if (err1 == cudaSuccess && err2 == cudaSuccess)
+      maxBlocks = multiProcCount * maxBlocks;
+    if (maxBlocks > 0) {
+      if (gridDim.y > 0)
+        maxBlocks = maxBlocks / gridDim.y;
+      if (gridDim.z > 0)
+        maxBlocks = maxBlocks / gridDim.z;
+      if (maxBlocks < 1)
+        maxBlocks = 1;
+      if (gridIsStar)
+        gridDim.x = maxBlocks;
+    }
+  }
   cudaStream_t stream = 0; // TODO stream managment
   CUDA_REPORT_IF_ERROR(
       cudaLaunchKernel(kernel, gridDim, blockDim, params, smem, stream));
@@ -41,6 +64,29 @@ void RTDEF(CUFLaunchClusterKernel)(const void *kernel, intptr_t clusterX,
   config.blockDim.x = blockX;
   config.blockDim.y = blockY;
   config.blockDim.z = blockZ;
+  bool gridIsStar = (gridX < 0); // <<<*, block>>> syntax was used.
+  if (gridIsStar) {
+    int maxBlocks, nbBlocks, dev, multiProcCount;
+    cudaError_t err1, err2;
+    nbBlocks = config.blockDim.x * config.blockDim.y * config.blockDim.z;
+    cudaGetDevice(&dev);
+    err1 = cudaDeviceGetAttribute(
+        &multiProcCount, cudaDevAttrMultiProcessorCount, dev);
+    err2 = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
+        &maxBlocks, kernel, nbBlocks, smem);
+    if (err1 == cudaSuccess && err2 == cudaSuccess)
+      maxBlocks = multiProcCount * maxBlocks;
+    if (maxBlocks > 0) {
+      if (config.gridDim.y > 0)
+        maxBlocks = maxBlocks / config.gridDim.y;
+      if (config.gridDim.z > 0)
+        maxBlocks = maxBlocks / config.gridDim.z;
+      if (maxBlocks < 1)
+        maxBlocks = 1;
+      if (gridIsStar)
+        config.gridDim.x = maxBlocks;
+    }
+  }
   config.dynamicSmemBytes = smem;
   config.stream = 0; // TODO stream managment
   cudaLaunchAttribute launchAttr[1];

@clementval clementval merged commit 6b21cf8 into llvm:main Nov 8, 2024
6 of 7 checks passed
@clementval clementval deleted the cuf_launch_kernel_compute branch November 8, 2024 22:34
Groverkss pushed a commit to iree-org/llvm-project that referenced this pull request Nov 15, 2024
llvm#115538)

`-1, 1, 1` is passed when calling a kernel with the `<<<*, block>>>`
syntax. Query the device to compute the grid.x value.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
flang:runtime flang Flang issues not falling into any other category
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants