Commit 894909a

Revert "[CUDA] Only use vec128 if CUDA version is newer than 12.8" (#150855)
Revert "[CUDA] Only use vec128 if CUDA version is newer than 12.8 (#150818)" This reverts commit 3f236f1.
1 parent ef2b139 commit 894909a

4 files changed, 11 insertions(+), 23 deletions(-)

aten/src/ATen/native/cuda/CUDALoops.cuh (+2, -4)

```diff
@@ -61,7 +61,7 @@ constexpr auto sum_of_sizes(args_t args, std::index_sequence<Is...>) {
   }
 }
 
-#if defined(USE_ROCM) || (defined(CUDA_VERSION) && CUDA_VERSION < 12080)
+#ifdef USE_ROCM
 template <int io_sizes>
 constexpr auto elems_per_thread(){
   if constexpr (io_sizes == 1) {
@@ -202,7 +202,7 @@ static inline void launch_vectorized_kernel(
   constexpr auto io_size = calc_io_size<func_t>();
   int64_t grid = (N + io_block_work_size<io_size>() - 1) / io_block_work_size<io_size>();
   auto stream = at::cuda::getCurrentCUDAStream();
-#if defined(USE_ROCM) || (defined(CUDA_VERSION) && CUDA_VERSION < 12080)
+#ifdef USE_ROCM
   int vec_size = memory::can_vectorize_up_to<func_t>(data);
 #else
   using cpp_type = typename function_traits<func_t>::result_type;
@@ -224,13 +224,11 @@ static inline void launch_vectorized_kernel(
       C10_CUDA_KERNEL_LAUNCH_CHECK();
       break;
 #endif
-#if defined(USE_ROCM) || (defined(CUDA_VERSION) && CUDA_VERSION >= 12080)
     case 8:
      vectorized_elementwise_kernel<8, func_t, array_t>
          <<<grid, num_threads(), 0, stream>>>(N, f, data);
      C10_CUDA_KERNEL_LAUNCH_CHECK();
      break;
-#endif
     case 4:
       vectorized_elementwise_kernel<4, func_t, array_t>
           <<<grid, num_threads(), 0, stream>>>(N, f, data);
```
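For context on these hunks: launch_vectorized_kernel computes a vector width at runtime and switches into the matching kernel instantiation, and the revert makes the case 8 arm unconditional instead of gated on CUDA 12.8. Below is a minimal standalone sketch of that dispatch pattern; the kernel and launcher names are illustrative stand-ins, not the actual ATen code, and tail handling is simplified.

```cuda
#include <cuda_runtime.h>
#include <cstdint>

// Illustrative stand-in for ATen's vectorized_elementwise_kernel: each thread
// processes vec_size contiguous elements. func_t must be device-callable.
template <int vec_size, typename func_t>
__global__ void elementwise_kernel_sketch(int64_t N, func_t f, float* data) {
  int64_t base =
      (static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x) * vec_size;
#pragma unroll
  for (int i = 0; i < vec_size; ++i) {
    if (base + i < N) {
      data[base + i] = f(data[base + i]);
    }
  }
}

// Dispatch on the runtime vector width, in the spirit of
// launch_vectorized_kernel. After this revert, the case 8 arm is compiled
// on all CUDA versions, not only 12.8 and newer.
template <typename func_t>
void launch_sketch(int vec_size, int64_t N, func_t f, float* data,
                   int64_t grid, int threads, cudaStream_t stream) {
  switch (vec_size) {
    case 8:
      elementwise_kernel_sketch<8><<<grid, threads, 0, stream>>>(N, f, data);
      break;
    case 4:
      elementwise_kernel_sketch<4><<<grid, threads, 0, stream>>>(N, f, data);
      break;
    case 2:
      elementwise_kernel_sketch<2><<<grid, threads, 0, stream>>>(N, f, data);
      break;
    default:
      elementwise_kernel_sketch<1><<<grid, threads, 0, stream>>>(N, f, data);
      break;
  }
}
```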

aten/src/ATen/native/cuda/MemoryAccess.cuh (+1, -3)

```diff
@@ -351,9 +351,7 @@ inline C10_HOST_DEVICE int can_vectorize_up_to(const char *pointer) {
   uint64_t address = reinterpret_cast<uint64_t>(pointer);
   constexpr int vec2_alignment = std::alignment_of_v<aligned_vector<scalar_t, 2>>;
   constexpr int vec4_alignment = std::alignment_of_v<aligned_vector<scalar_t, 4>>;
-#if defined(USE_ROCM) || (defined(CUDA_VERSION) && CUDA_VERSION >= 12080)
   constexpr int vec8_alignment = std::alignment_of_v<aligned_vector<scalar_t, 8>>;
-#endif
 #ifdef USE_ROCM
   constexpr int vec16_alignment = std::alignment_of_v<aligned_vector<scalar_t, 16>>;
   constexpr int type_size = sizeof(scalar_t);
@@ -362,7 +360,7 @@ inline C10_HOST_DEVICE int can_vectorize_up_to(const char *pointer) {
   } else if (type_size <= 2 && (address % vec8_alignment == 0)) {
     return 8;
   } else
-#elif defined(CUDA_VERSION) && CUDA_VERSION >= 12080
+#else
   if (address % vec8_alignment == 0) {
     return 8;
   } else
```
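The function edited here, can_vectorize_up_to, returns the widest vector width whose alignment requirement the pointer satisfies; after the revert the vec8 probe is compiled on every CUDA version instead of only 12.8 and newer. A self-contained sketch of the alignment ladder follows; aligned_vector mirrors the ATen helper of the same name, and this simplified version stops at width 8, whereas the ROCm branch above also probes 16.

```cpp
#include <cstdint>

// Mirrors ATen's aligned_vector<scalar_t, N>: N elements, aligned to the
// full vector size so a single wide load/store is legal.
template <typename scalar_t, int N>
struct alignas(sizeof(scalar_t) * N) aligned_vector {
  scalar_t val[N];
};

// Simplified can_vectorize_up_to: walk from the widest vector down and
// return the first width whose alignment the address satisfies.
template <typename scalar_t>
int can_vectorize_up_to_sketch(const char* pointer) {
  uint64_t address = reinterpret_cast<uint64_t>(pointer);
  if (address % alignof(aligned_vector<scalar_t, 8>) == 0) return 8;
  if (address % alignof(aligned_vector<scalar_t, 4>) == 0) return 4;
  if (address % alignof(aligned_vector<scalar_t, 2>) == 0) return 2;
  return 1;
}
```

For int16_t, for example, aligned_vector<int16_t, 8> is 16 bytes, so an address at offset 8 fails the vec8 probe but passes the 8-byte vec4 probe and returns 4, which matches the ptr + 8 expectations in the test file below.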

aten/src/ATen/native/cuda/thread_constants.h (+1, -4)

```diff
@@ -18,11 +18,8 @@ constexpr int thread_work_size() { return 4; }
 constexpr uint32_t num_threads() {
   return C10_WARP_SIZE * 4;
 }
-#if defined(CUDA_VERSION) && CUDA_VERSION < 12080
-constexpr int thread_work_size() { return 4; }
-#else
+
 constexpr int thread_work_size() { return 8; }
 #endif
-#endif
 
 constexpr int block_work_size() { return thread_work_size() * num_threads(); }
```
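These constants size every elementwise launch: one block covers thread_work_size() * num_threads() elements, and the grid is the ceiling division of N by that product. A small worked example, assuming C10_WARP_SIZE is 32 (it is 64 on some AMD GPUs, so the numbers shift there):

```cpp
#include <cstdint>

// Assumed warp size for this example; C10_WARP_SIZE is hardware-dependent.
constexpr uint32_t warp_size = 32;
constexpr uint32_t num_threads() { return warp_size * 4; }  // 128 threads per block
constexpr int thread_work_size() { return 8; }  // unconditional non-ROCm value after this revert
constexpr int block_work_size() { return thread_work_size() * num_threads(); }  // 1024 elements per block

// Grid sizing as in launch_vectorized_kernel: ceil-divide N by the
// per-block element count.
constexpr int64_t grid_for(int64_t N) {
  return (N + block_work_size() - 1) / block_work_size();
}

static_assert(grid_for(1) == 1);
static_assert(grid_for(1024) == 1);
static_assert(grid_for(1025) == 2);
```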

aten/src/ATen/test/cuda_vectorized_test.cu (+7, -12)

```diff
@@ -46,17 +46,12 @@ TEST(TestLoops, HasSameArgTypes) {
 
 TEST(TestVectorizedMemoryAccess, CanVectorizeUpTo) {
   char *ptr = reinterpret_cast<char *>(buffer1);
-#if defined(CUDA_VERSION) && CUDA_VERSION < 12080
-  constexpr auto vectorize_limit = 4;
-#else
-  constexpr auto vectorize_limit= 8;
-#endif
 
-  ASSERT_EQ(memory::can_vectorize_up_to<bool>(ptr), vectorize_limit);
-  ASSERT_EQ(memory::can_vectorize_up_to<int8_t>(ptr), vectorize_limit);
-  ASSERT_EQ(memory::can_vectorize_up_to<int16_t>(ptr), vectorize_limit);
-  ASSERT_EQ(memory::can_vectorize_up_to<int>(ptr), vectorize_limit);
-  ASSERT_EQ(memory::can_vectorize_up_to<int64_t>(ptr), vectorize_limit);
+  ASSERT_EQ(memory::can_vectorize_up_to<bool>(ptr), 8);
+  ASSERT_EQ(memory::can_vectorize_up_to<int8_t>(ptr), 8);
+  ASSERT_EQ(memory::can_vectorize_up_to<int16_t>(ptr), 8);
+  ASSERT_EQ(memory::can_vectorize_up_to<int>(ptr), 8);
+  ASSERT_EQ(memory::can_vectorize_up_to<int64_t>(ptr), 8);
 
   ASSERT_EQ(memory::can_vectorize_up_to<bool>(ptr + 1), 1);
   ASSERT_EQ(memory::can_vectorize_up_to<int8_t>(ptr + 1), 1);
@@ -70,8 +65,8 @@ TEST(TestVectorizedMemoryAccess, CanVectorizeUpTo) {
   ASSERT_EQ(memory::can_vectorize_up_to<int16_t>(ptr + 4), 2);
   ASSERT_EQ(memory::can_vectorize_up_to<int>(ptr + 4), 1);
 
-  ASSERT_EQ(memory::can_vectorize_up_to<bool>(ptr + 8), vectorize_limit);
-  ASSERT_EQ(memory::can_vectorize_up_to<int8_t>(ptr + 8), vectorize_limit);
+  ASSERT_EQ(memory::can_vectorize_up_to<bool>(ptr + 8), 8);
+  ASSERT_EQ(memory::can_vectorize_up_to<int8_t>(ptr + 8), 8);
   ASSERT_EQ(memory::can_vectorize_up_to<int16_t>(ptr + 8), 4);
   ASSERT_EQ(memory::can_vectorize_up_to<int>(ptr + 8), 2);
   ASSERT_EQ(memory::can_vectorize_up_to<int64_t>(ptr + 8), 1);
```
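The hard-coded 8s restored above follow from the alignment rule in MemoryAccess.cuh: aligned_vector<scalar_t, N> is aligned to sizeof(scalar_t) * N, so at ptr + 8 a 1-byte type still admits an 8-wide, 8-byte vector, a 2-byte type only a 4-wide one, a 4-byte type a 2-wide one, and int64_t falls back to scalar access because even a 2-wide vector would need 16-byte alignment. The width-8 expectations at ptr itself assume buffer1 is generously aligned (the int64_t case needs 64-byte alignment).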
