Skip to content

Commit b57c0ba

Browse files
authored
[OpenMP] Update atomic helpers to just use headers (llvm#122185)
Summary: Previously we had some indirection here; this patch updates these utilities to be normal template functions defined in the headers. We use SFINAE to manage the special-case handling for floats. The helpers also strip address spaces from the pointee type so they can be used more generally.
1 parent 1739ba9 commit b57c0ba

File tree

4 files changed

+170
-217
lines changed

4 files changed

+170
-217
lines changed

offload/DeviceRTL/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ set(bc_flags -c -foffload-lto -std=c++17 -fvisibility=hidden
9999
${clang_opt_flags} --offload-device-only
100100
-nocudalib -nogpulib -nogpuinc -nostdlibinc
101101
-fopenmp -fopenmp-cuda-mode
102-
-Wno-unknown-cuda-version
102+
-Wno-unknown-cuda-version -Wno-openmp-target
103103
-DOMPTARGET_DEVICE_RUNTIME
104104
-I${include_directory}
105105
-I${devicertl_base_directory}/../include

offload/DeviceRTL/include/DeviceUtils.h

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,47 @@
1919

2020
namespace utils {
2121

22+
/// Exposes \p T unchanged as the nested member `type`. Used as a base class
/// by the other traits in this header.
template <typename T> struct type_identity {
  using type = T;
};

/// Wraps the compile-time constant \p v of type \p T as the static member
/// `value`.
template <typename T, T v> struct integral_constant {
  inline static constexpr T value = v;
};

/// Freestanding SFINAE helpers.

/// Strips top-level `const` and/or `volatile` from \p T.
template <class T> struct remove_cv : type_identity<T> {};
template <class T> struct remove_cv<const T> : type_identity<T> {};
template <class T> struct remove_cv<volatile T> : type_identity<T> {};
template <class T> struct remove_cv<const volatile T> : type_identity<T> {};
template <class T> using remove_cv_t = typename remove_cv<T>::type;

using true_type = integral_constant<bool, true>;
using false_type = integral_constant<bool, false>;

/// `value` is true only when \p T and \p U are exactly the same type.
template <typename T, typename U> struct is_same : false_type {};
template <typename T> struct is_same<T, T> : true_type {};
template <typename T, typename U>
inline constexpr bool is_same_v = is_same<T, U>::value;

/// `value` is true when \p T, ignoring cv-qualifiers, is `float` or `double`.
/// No other type is treated as floating point by this trait.
template <typename T> struct is_floating_point {
  // Compare the stripped type (remove_cv_t), not the remove_cv trait struct
  // itself: the struct is never the same type as float or double, which would
  // make this trait false for every T.
  inline static constexpr bool value =
      is_same_v<remove_cv_t<T>, float> || is_same_v<remove_cv_t<T>, double>;
};
template <typename T>
inline constexpr bool is_floating_point_v = is_floating_point<T>::value;
51+
52+
/// SFINAE gate: `enable_if<Cond, T>::type` names \p T only when \p Cond is
/// true. The primary template is declared but never defined, so naming
/// `::type` with a false condition is ill-formed (a substitution failure when
/// it occurs in a template signature).
template <bool Cond, typename T = void> struct enable_if;
template <typename T> struct enable_if<true, T> {
  using type = T;
};
template <bool Cond, typename T = void>
using enable_if_t = typename enable_if<Cond, T>::type;
56+
57+
template <class T> struct remove_addrspace : type_identity<T> {};
58+
template <class T, int N>
59+
struct remove_addrspace<T [[clang::address_space(N)]]> : type_identity<T> {};
60+
template <class T>
61+
using remove_addrspace_t = typename remove_addrspace<T>::type;
62+
2263
/// Return the value \p Var from thread Id \p SrcLane in the warp if the thread
2364
/// is identified by \p Mask.
2465
int32_t shuffle(uint64_t Mask, int32_t Var, int32_t SrcLane, int32_t Width);

offload/DeviceRTL/include/Synchronization.h

Lines changed: 123 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,11 @@
1313
#define OMPTARGET_DEVICERTL_SYNCHRONIZATION_H
1414

1515
#include "DeviceTypes.h"
16+
#include "DeviceUtils.h"
1617

17-
namespace ompx {
18+
#pragma omp begin declare target device_type(nohost)
1819

20+
namespace ompx {
1921
namespace atomic {
2022

2123
enum OrderingTy {
@@ -48,51 +50,124 @@ uint32_t inc(uint32_t *Addr, uint32_t V, OrderingTy Ordering,
4850
/// result is stored in \p *Addr;
4951
/// {
5052

51-
#define ATOMIC_COMMON_OP(TY) \
52-
TY add(TY *Addr, TY V, OrderingTy Ordering); \
53-
TY mul(TY *Addr, TY V, OrderingTy Ordering); \
54-
TY load(TY *Addr, OrderingTy Ordering); \
55-
void store(TY *Addr, TY V, OrderingTy Ordering); \
56-
bool cas(TY *Addr, TY ExpectedV, TY DesiredV, OrderingTy OrderingSucc, \
57-
OrderingTy OrderingFail);
58-
59-
#define ATOMIC_FP_ONLY_OP(TY) \
60-
TY min(TY *Addr, TY V, OrderingTy Ordering); \
61-
TY max(TY *Addr, TY V, OrderingTy Ordering);
62-
63-
#define ATOMIC_INT_ONLY_OP(TY) \
64-
TY min(TY *Addr, TY V, OrderingTy Ordering); \
65-
TY max(TY *Addr, TY V, OrderingTy Ordering); \
66-
TY bit_or(TY *Addr, TY V, OrderingTy Ordering); \
67-
TY bit_and(TY *Addr, TY V, OrderingTy Ordering); \
68-
TY bit_xor(TY *Addr, TY V, OrderingTy Ordering);
69-
70-
#define ATOMIC_FP_OP(TY) \
71-
ATOMIC_FP_ONLY_OP(TY) \
72-
ATOMIC_COMMON_OP(TY)
73-
74-
#define ATOMIC_INT_OP(TY) \
75-
ATOMIC_INT_ONLY_OP(TY) \
76-
ATOMIC_COMMON_OP(TY)
77-
78-
// This needs to be kept in sync with the header. Also the reason we don't use
79-
// templates here.
80-
ATOMIC_INT_OP(int8_t)
81-
ATOMIC_INT_OP(int16_t)
82-
ATOMIC_INT_OP(int32_t)
83-
ATOMIC_INT_OP(int64_t)
84-
ATOMIC_INT_OP(uint8_t)
85-
ATOMIC_INT_OP(uint16_t)
86-
ATOMIC_INT_OP(uint32_t)
87-
ATOMIC_INT_OP(uint64_t)
88-
ATOMIC_FP_OP(float)
89-
ATOMIC_FP_OP(double)
90-
91-
#undef ATOMIC_INT_ONLY_OP
92-
#undef ATOMIC_FP_ONLY_OP
93-
#undef ATOMIC_COMMON_OP
94-
#undef ATOMIC_INT_OP
95-
#undef ATOMIC_FP_OP
53+
template <typename Ty, typename V = utils::remove_addrspace_t<Ty>>
54+
bool cas(Ty *Address, V ExpectedV, V DesiredV, atomic::OrderingTy OrderingSucc,
55+
atomic::OrderingTy OrderingFail) {
56+
return __scoped_atomic_compare_exchange(Address, &ExpectedV, &DesiredV, false,
57+
OrderingSucc, OrderingFail,
58+
__MEMORY_SCOPE_DEVICE);
59+
}
60+
61+
template <typename Ty, typename V = utils::remove_addrspace_t<Ty>>
62+
V add(Ty *Address, V Val, atomic::OrderingTy Ordering) {
63+
return __scoped_atomic_fetch_add(Address, Val, Ordering,
64+
__MEMORY_SCOPE_DEVICE);
65+
}
66+
67+
template <typename Ty, typename V = utils::remove_addrspace_t<Ty>>
68+
V load(Ty *Address, atomic::OrderingTy Ordering) {
69+
return add(Address, Ty(0), Ordering);
70+
}
71+
72+
template <typename Ty, typename V = utils::remove_addrspace_t<Ty>>
73+
void store(Ty *Address, V Val, atomic::OrderingTy Ordering) {
74+
__scoped_atomic_store_n(Address, Val, Ordering, __MEMORY_SCOPE_DEVICE);
75+
}
76+
77+
template <typename Ty, typename V = utils::remove_addrspace_t<Ty>>
78+
V mul(Ty *Address, V Val, atomic::OrderingTy Ordering) {
79+
Ty TypedCurrentVal, TypedResultVal, TypedNewVal;
80+
bool Success;
81+
do {
82+
TypedCurrentVal = atomic::load(Address, Ordering);
83+
TypedNewVal = TypedCurrentVal * Val;
84+
Success = atomic::cas(Address, TypedCurrentVal, TypedNewVal, Ordering,
85+
atomic::relaxed);
86+
} while (!Success);
87+
return TypedResultVal;
88+
}
89+
90+
/// Atomic maximum for value types that `utils::is_floating_point_v` does not
/// classify as floating point. Forwards to `__scoped_atomic_fetch_max` with
/// `__MEMORY_SCOPE_DEVICE`.
/// \returns the result of the fetch-max builtin, converted to \p V.
template <typename Ty, typename V = utils::remove_addrspace_t<Ty>>
utils::enable_if_t<!utils::is_floating_point_v<V>, V>
max(Ty *Address, V Val, atomic::OrderingTy Ordering) {
  const V Fetched = __scoped_atomic_fetch_max(Address, Val, Ordering,
                                              __MEMORY_SCOPE_DEVICE);
  return Fetched;
}
96+
97+
/// Atomic maximum for `float`, built on the integer `max`/`min` overloads by
/// reinterpreting the bit pattern of \p Val and of \p *Address.
///
/// When `Val >= 0` the operation is a signed 32-bit integer `max`; otherwise
/// (including when the comparison is false for any other reason) it is an
/// unsigned 32-bit integer `min`. The integer result is punned back to
/// `float`.
template <typename Ty, typename V = utils::remove_addrspace_t<Ty>>
utils::enable_if_t<utils::is_same_v<V, float>, V>
max(Ty *Address, V Val, atomic::OrderingTy Ordering) {
  // NOTE: the punned value is passed inline on purpose. It keeps the inner
  // call dependent on the template parameter, so `min` (declared further
  // down) is looked up at instantiation time.
  if (Val >= 0) {
    return utils::convertViaPun<float>(
        max((int32_t *)Address, utils::convertViaPun<int32_t>(Val), Ordering));
  } else {
    return utils::convertViaPun<float>(min(
        (uint32_t *)Address, utils::convertViaPun<uint32_t>(Val), Ordering));
  }
}
106+
107+
/// Atomic maximum for `double`, built on the integer `max`/`min` overloads by
/// reinterpreting the bit pattern of \p Val and of \p *Address.
///
/// When `Val >= 0` the operation is a signed 64-bit integer `max`; otherwise
/// it is an unsigned 64-bit integer `min`. The integer result is punned back
/// to `double`.
template <typename Ty, typename V = utils::remove_addrspace_t<Ty>>
utils::enable_if_t<utils::is_same_v<V, double>, V>
max(Ty *Address, V Val, atomic::OrderingTy Ordering) {
  // NOTE: the punned value is passed inline on purpose. It keeps the inner
  // call dependent on the template parameter, so `min` (declared further
  // down) is looked up at instantiation time.
  if (Val >= 0) {
    return utils::convertViaPun<double>(
        max((int64_t *)Address, utils::convertViaPun<int64_t>(Val), Ordering));
  } else {
    return utils::convertViaPun<double>(min(
        (uint64_t *)Address, utils::convertViaPun<uint64_t>(Val), Ordering));
  }
}
116+
117+
/// Atomic minimum for value types that `utils::is_floating_point_v` does not
/// classify as floating point. Forwards to `__scoped_atomic_fetch_min` with
/// `__MEMORY_SCOPE_DEVICE`.
/// \returns the result of the fetch-min builtin, converted to \p V.
template <typename Ty, typename V = utils::remove_addrspace_t<Ty>>
utils::enable_if_t<!utils::is_floating_point_v<V>, V>
min(Ty *Address, V Val, atomic::OrderingTy Ordering) {
  const V Fetched = __scoped_atomic_fetch_min(Address, Val, Ordering,
                                              __MEMORY_SCOPE_DEVICE);
  return Fetched;
}
123+
124+
// TODO: Implement this with __atomic_fetch_max and remove the duplication.
125+
/// Atomic minimum for `float`, built on the integer `min`/`max` overloads by
/// reinterpreting the bit pattern of \p Val and of \p *Address.
///
/// When `Val >= 0` the operation is a signed 32-bit integer `min`; otherwise
/// it is an unsigned 32-bit integer `max`. The integer result is punned back
/// to `float`.
template <typename Ty, typename V = utils::remove_addrspace_t<Ty>>
utils::enable_if_t<utils::is_same_v<V, float>, V>
min(Ty *Address, V Val, atomic::OrderingTy Ordering) {
  // NOTE: the punned value is passed inline so the inner call stays dependent
  // on the template parameter.
  if (Val >= 0) {
    return utils::convertViaPun<float>(
        min((int32_t *)Address, utils::convertViaPun<int32_t>(Val), Ordering));
  } else {
    return utils::convertViaPun<float>(max(
        (uint32_t *)Address, utils::convertViaPun<uint32_t>(Val), Ordering));
  }
}
134+
135+
// TODO: Implement this with __atomic_fetch_max and remove the duplication.
136+
/// Atomic minimum for `double`, built on the integer `min`/`max` overloads by
/// reinterpreting the bit pattern of \p Val and of \p *Address.
///
/// When `Val >= 0` the operation is a signed 64-bit integer `min`; otherwise
/// it is an unsigned 64-bit integer `max`. The integer result is punned back
/// to `double`. Unlike the sibling overloads, \p Val is declared with the
/// address-space-stripped pointee type rather than \p V, so \p V is not
/// deduced from the argument here.
template <typename Ty, typename V = utils::remove_addrspace_t<Ty>>
utils::enable_if_t<utils::is_same_v<V, double>, V>
min(Ty *Address, utils::remove_addrspace_t<Ty> Val,
    atomic::OrderingTy Ordering) {
  // NOTE: the punned value is passed inline so the inner call stays dependent
  // on the template parameter.
  if (Val >= 0) {
    return utils::convertViaPun<double>(
        min((int64_t *)Address, utils::convertViaPun<int64_t>(Val), Ordering));
  } else {
    return utils::convertViaPun<double>(max(
        (uint64_t *)Address, utils::convertViaPun<uint64_t>(Val), Ordering));
  }
}
146+
147+
template <typename Ty, typename V = utils::remove_addrspace_t<Ty>>
148+
V bit_or(Ty *Address, V Val, atomic::OrderingTy Ordering) {
149+
return __scoped_atomic_fetch_or(Address, Val, Ordering,
150+
__MEMORY_SCOPE_DEVICE);
151+
}
152+
153+
template <typename Ty, typename V = utils::remove_addrspace_t<Ty>>
154+
V bit_and(Ty *Address, V Val, atomic::OrderingTy Ordering) {
155+
return __scoped_atomic_fetch_and(Address, Val, Ordering,
156+
__MEMORY_SCOPE_DEVICE);
157+
}
158+
159+
template <typename Ty, typename V = utils::remove_addrspace_t<Ty>>
160+
V bit_xor(Ty *Address, V Val, atomic::OrderingTy Ordering) {
161+
return __scoped_atomic_fetch_xor(Address, Val, Ordering,
162+
__MEMORY_SCOPE_DEVICE);
163+
}
164+
165+
static inline uint32_t atomicExchange(uint32_t *Address, uint32_t Val,
166+
atomic::OrderingTy Ordering) {
167+
uint32_t R;
168+
__scoped_atomic_exchange(Address, &Val, &R, Ordering, __MEMORY_SCOPE_DEVICE);
169+
return R;
170+
}
96171

97172
///}
98173

@@ -145,4 +220,6 @@ void system(atomic::OrderingTy Ordering);
145220

146221
} // namespace ompx
147222

223+
#pragma omp end declare target
224+
148225
#endif

0 commit comments

Comments
 (0)