Skip to content

Commit 9c9cb93

Browse files
sergey-kozubGoogle-ML-Automation
authored andcommitted
PR #22029: [XLA:GPU] Add support for SM101a and SM120a architectures (Blackwell)
Imported from GitHub PR #22029 In addition to SM120a, also add SM101a mentioned in the PTX 8.7 spec (https://docs.nvidia.com/cuda/parallel-thread-execution/#release-notes), which is a slight variation of SM100a. Bumping the max supported PTX version to 8.7, as the LLVM PR (llvm/llvm-project#124155) adding the support is now integrated to OpenXLA. Copybara import of the project: -- be59b7a by Sergey Kozub <[email protected]>: [XLA:GPU] Add support for SM101a and SM120a architectures (Blackwell) Merging this change closes #22029 FUTURE_COPYBARA_INTEGRATE_REVIEW=#22029 from openxla:devel/sm120a be59b7a PiperOrigin-RevId: 721049239
1 parent 114273d commit 9c9cb93

File tree

3 files changed

+14
-8
lines changed

3 files changed

+14
-8
lines changed

xla/service/gpu/llvm_gpu_backend/nvptx_backend.cc

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -238,8 +238,8 @@ std::string GetSmName(se::CudaComputeCapability compute_capability) {
238238
int sm_version = 30;
239239
// If the current compute capability isn't known, fallback to the
240240
// most recent version before it.
241-
int supported_versions[] = {100, 90, 89, 87, 86, 80, 75, 72, 70, 62,
242-
61, 60, 53, 52, 50, 37, 35, 32, 30};
241+
int supported_versions[] = {120, 101, 100, 90, 89, 87, 86, 80, 75, 72, 70,
242+
62, 61, 60, 53, 52, 50, 37, 35, 32, 30};
243243
for (int v : supported_versions) {
244244
if (v <= compute_capability_version) {
245245
sm_version = v;
@@ -261,7 +261,7 @@ std::string GetSmName(se::CudaComputeCapability compute_capability) {
261261
// On Hopper, default to sm_90a so that all instructions can be used. But
262262
// only sm_90 is forward compatible, so don't use sm_90a with newer hardware:
263263
// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#ptx-compatibility
264-
// Similarly for sm_100a (Blackwell).
264+
// Similarly for sm_100a, sm_101a and sm_120a (Blackwell).
265265
absl::string_view extension =
266266
stream_executor::ShouldUsePtxExtension(compute_capability) ? "a" : "";
267267
return absl::StrCat("sm_", sm_version, extension);
@@ -333,7 +333,7 @@ absl::StatusOr<std::string> CompileToPtx(
333333

334334
namespace {
335335
constexpr stream_executor::SemanticVersion kFallbackPtxVersion{6, 5, 0};
336-
constexpr stream_executor::SemanticVersion kMaxPtxVersion{8, 6, 0};
336+
constexpr stream_executor::SemanticVersion kMaxPtxVersion{8, 7, 0};
337337
} // namespace
338338

339339
stream_executor::SemanticVersion
@@ -357,7 +357,8 @@ DetermineHighestSupportedPtxVersionFromCudaVersion(
357357
return {cuda_version.major() - 4, cuda_version.minor(), 0};
358358
}
359359
// CUDA 12.6 -> PTX 8.5
360-
if (cuda_version < stream_executor::SemanticVersion{12, 7, 0}) {
360+
// CUDA 12.8 -> PTX 8.7
361+
if (cuda_version < stream_executor::SemanticVersion{12, 9, 0}) {
361362
return {cuda_version.major() - 4, cuda_version.minor() - 1, 0};
362363
}
363364

xla/service/gpu/llvm_gpu_backend/nvptx_backend_test.cc

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,11 @@ namespace se = ::stream_executor;
2929
TEST(UtilsTest, TestGetSmName) {
3030
ASSERT_EQ(nvptx::GetSmName(se::CudaComputeCapability{9, 0}), "sm_90a");
3131
ASSERT_EQ(nvptx::GetSmName(se::CudaComputeCapability{10, 0}), "sm_100a");
32+
ASSERT_EQ(nvptx::GetSmName(se::CudaComputeCapability{10, 1}), "sm_101a");
33+
ASSERT_EQ(nvptx::GetSmName(se::CudaComputeCapability{12, 0}), "sm_120a");
3234
// Do not use the extension for a yet-unknown compute capability.
3335
// https://docs.nvidia.com/cuda/parallel-thread-execution/#release-notes-ptx-release-history
34-
ASSERT_EQ(nvptx::GetSmName(se::CudaComputeCapability{10, 9}), "sm_100");
36+
ASSERT_EQ(nvptx::GetSmName(se::CudaComputeCapability{12, 9}), "sm_120");
3537
}
3638

3739
using VersionPair = std::pair<se::SemanticVersion, se::SemanticVersion>;
@@ -63,6 +65,7 @@ INSTANTIATE_TEST_SUITE_P(VersionTest, PtxVersionFromCudaVersionTest,
6365
{{12, 4, 0}, {8, 4, 0}},
6466
{{12, 5, 0}, {8, 5, 0}},
6567
{{12, 6, 0}, {8, 5, 0}},
68+
{{12, 8, 0}, {8, 7, 0}},
6669
}),
6770
[](::testing::TestParamInfo<VersionPair> data) {
6871
se::SemanticVersion cuda_version = data.param.first;

xla/stream_executor/cuda/ptx_compiler_helpers.cc

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,10 +101,12 @@ void WarnIfBadPtxasVersion(absl::string_view method,
101101
});
102102
}
103103

104-
// The extension is used for compute capabilities 9.0 and 10.0.
104+
// The extension is used for compute capabilities 9.0, 10.0, 10.1 and 12.0.
105105
// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#ptx-compatibility
106106
bool ShouldUsePtxExtension(const CudaComputeCapability& cc) {
107-
return (cc.major == 9 && cc.minor == 0) || (cc.major == 10 && cc.minor == 0);
107+
return (cc.major == 9 && cc.minor == 0) ||
108+
(cc.major == 10 && (cc.minor == 0 || cc.minor == 1)) ||
109+
(cc.major == 12 && cc.minor == 0);
108110
}
109111

110112
} // namespace stream_executor

0 commit comments

Comments
 (0)