Skip to content

[OpenMP] Use __builtin_bit_cast instead of UB type punning #122325

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 9, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions offload/DeviceRTL/include/DeviceUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,11 @@ struct remove_addrspace<T [[clang::address_space(N)]]> : type_identity<T> {};
template <class T>
using remove_addrspace_t = typename remove_addrspace<T>::type;

/// Reinterpret the object representation of \p V as a value of type \p To.
///
/// \p To and \p From must have the same size; a mismatch is rejected at
/// compile time. The conversion is done by the compiler builtin, so no object
/// is ever accessed through a pointer of a different type.
template <typename To, typename From>
inline To bitCast(From V) {
  constexpr bool SameSize = sizeof(From) == sizeof(To);
  static_assert(SameSize, "Bad conversion");
  To Result = __builtin_bit_cast(To, V);
  return Result;
}

/// Return the value \p Var from thread Id \p SrcLane in the warp if the thread
/// is identified by \p Mask.
int32_t shuffle(uint64_t Mask, int32_t Var, int32_t SrcLane, int32_t Width);
Expand Down
32 changes: 16 additions & 16 deletions offload/DeviceRTL/include/Synchronization.h
Original file line number Diff line number Diff line change
Expand Up @@ -98,20 +98,20 @@ template <typename Ty, typename V = utils::remove_addrspace_t<Ty>>
utils::enable_if_t<utils::is_same_v<V, float>, V>
max(Ty *Address, V Val, atomic::OrderingTy Ordering) {
if (Val >= 0)
return utils::convertViaPun<float>(
max((int32_t *)Address, utils::convertViaPun<int32_t>(Val), Ordering));
return utils::convertViaPun<float>(
min((uint32_t *)Address, utils::convertViaPun<uint32_t>(Val), Ordering));
return utils::bitCast<float>(
max((int32_t *)Address, utils::bitCast<int32_t>(Val), Ordering));
return utils::bitCast<float>(
min((uint32_t *)Address, utils::bitCast<uint32_t>(Val), Ordering));
}

/// Overload of max that is enabled when \p V, the pointee type of \p Address
/// with its address space removed, is double.
///
/// The operation is forwarded to the integer overloads on the raw 64-bit
/// pattern of \p Val: a max on an int64_t view of \p Address when \p Val >= 0,
/// otherwise a min on a uint64_t view. The integer result is reinterpreted as
/// a double and returned. \p Ordering is passed through unchanged.
template <typename Ty, typename V = utils::remove_addrspace_t<Ty>>
utils::enable_if_t<utils::is_same_v<V, double>, V>
max(Ty *Address, V Val, atomic::OrderingTy Ordering) {
// Non-negative values take the signed-max path, everything else the
// unsigned-min path (presumably relying on how IEEE-754 bit patterns order
// when viewed as integers -- TODO confirm).
if (Val >= 0)
return utils::convertViaPun<double>(
max((int64_t *)Address, utils::convertViaPun<int64_t>(Val), Ordering));
return utils::convertViaPun<double>(
min((uint64_t *)Address, utils::convertViaPun<uint64_t>(Val), Ordering));
// NOTE(review): the two returns above use utils::convertViaPun and the two
// below use utils::bitCast with otherwise identical arguments. This looks like
// the removed and added sides of a diff shown together -- confirm which pair
// is meant to remain. As written, the returns below are unreachable.
return utils::bitCast<double>(
max((int64_t *)Address, utils::bitCast<int64_t>(Val), Ordering));
return utils::bitCast<double>(
min((uint64_t *)Address, utils::bitCast<uint64_t>(Val), Ordering));
}

template <typename Ty, typename V = utils::remove_addrspace_t<Ty>>
Expand All @@ -126,10 +126,10 @@ template <typename Ty, typename V = utils::remove_addrspace_t<Ty>>
utils::enable_if_t<utils::is_same_v<V, float>, V>
min(Ty *Address, V Val, atomic::OrderingTy Ordering) {
if (Val >= 0)
return utils::convertViaPun<float>(
min((int32_t *)Address, utils::convertViaPun<int32_t>(Val), Ordering));
return utils::convertViaPun<float>(
max((uint32_t *)Address, utils::convertViaPun<uint32_t>(Val), Ordering));
return utils::bitCast<float>(
min((int32_t *)Address, utils::bitCast<int32_t>(Val), Ordering));
return utils::bitCast<float>(
max((uint32_t *)Address, utils::bitCast<uint32_t>(Val), Ordering));
}

// TODO: Implement this with __atomic_fetch_max and remove the duplication.
Expand All @@ -138,10 +138,10 @@ utils::enable_if_t<utils::is_same_v<V, double>, V>
min(Ty *Address, utils::remove_addrspace_t<Ty> Val,
atomic::OrderingTy Ordering) {
if (Val >= 0)
return utils::convertViaPun<double>(
min((int64_t *)Address, utils::convertViaPun<int64_t>(Val), Ordering));
return utils::convertViaPun<double>(
max((uint64_t *)Address, utils::convertViaPun<uint64_t>(Val), Ordering));
return utils::bitCast<double>(
min((int64_t *)Address, utils::bitCast<int64_t>(Val), Ordering));
return utils::bitCast<double>(
max((uint64_t *)Address, utils::bitCast<uint64_t>(Val), Ordering));
}

template <typename Ty, typename V = utils::remove_addrspace_t<Ty>>
Expand Down
8 changes: 4 additions & 4 deletions offload/DeviceRTL/src/Mapping.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -371,8 +371,8 @@ int ompx_shfl_down_sync_i(uint64_t mask, int var, unsigned delta, int width) {

/// Float variant of the shuffle-down wrapper.
///
/// Reinterprets \p var as an int32_t, hands it to utils::shuffleDown together
/// with \p mask, \p delta and \p width unchanged, and reinterprets the int32_t
/// result as a float.
float ompx_shfl_down_sync_f(uint64_t mask, float var, unsigned delta,
int width) {
return utils::convertViaPun<float>(utils::shuffleDown(
mask, utils::convertViaPun<int32_t>(var), delta, width));
// NOTE(review): the return above (utils::convertViaPun) and the one below
// (utils::bitCast) take the same arguments; this looks like the removed and
// added sides of a diff shown together -- confirm which one is meant to
// remain. As written, the return below is unreachable.
return utils::bitCast<float>(
utils::shuffleDown(mask, utils::bitCast<int32_t>(var), delta, width));
}

long ompx_shfl_down_sync_l(uint64_t mask, long var, unsigned delta, int width) {
Expand All @@ -381,8 +381,8 @@ long ompx_shfl_down_sync_l(uint64_t mask, long var, unsigned delta, int width) {

/// Double variant of the shuffle-down wrapper.
///
/// Reinterprets \p var as an int64_t, hands it to utils::shuffleDown together
/// with \p mask, \p delta and \p width unchanged, and reinterprets the int64_t
/// result as a double.
double ompx_shfl_down_sync_d(uint64_t mask, double var, unsigned delta,
int width) {
return utils::convertViaPun<double>(utils::shuffleDown(
mask, utils::convertViaPun<int64_t>(var), delta, width));
// NOTE(review): the return above (utils::convertViaPun) and the one below
// (utils::bitCast) take the same arguments; this looks like the removed and
// added sides of a diff shown together -- confirm which one is meant to
// remain. As written, the return below is unreachable.
return utils::bitCast<double>(
utils::shuffleDown(mask, utils::bitCast<int64_t>(var), delta, width));
}
}

Expand Down
5 changes: 0 additions & 5 deletions offload/include/Shared/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,11 +68,6 @@ inline uint32_t popc(uint64_t V) {
return __builtin_popcountl(V);
}

/// Reinterpret the bytes of \p V as a value of type \p DstTy.
///
/// \p DstTy and \p SrcTy must have the same size; a mismatch is rejected at
/// compile time. The bytes are copied into a fresh \p DstTy object instead of
/// being read through a casted pointer: dereferencing a `DstTy *` that really
/// points at a `SrcTy` object violates the strict aliasing rules and is
/// undefined behavior, which lets the optimizer miscompile the conversion.
template <typename DstTy, typename SrcTy> inline DstTy convertViaPun(SrcTy V) {
  static_assert(sizeof(DstTy) == sizeof(SrcTy), "Bad conversion");
  DstTy Dst;
  // The builtin form needs no header and is folded to a plain register move.
  __builtin_memcpy(&Dst, &V, sizeof(DstTy));
  return Dst;
}

} // namespace utils

#endif // OMPTARGET_SHARED_UTILS_H
Loading