Skip to content

Commit b7a6b67

Browse files
authored
Merge pull request #2063 from IntelPython/use-constant-value-in-imag
Use constant value in `imag` implementation for real-valued data types
2 parents ac7c007 + b909228 commit b7a6b67

File tree

2 files changed

+8
-5
lines changed

2 files changed

+8
-5
lines changed

dpctl/tensor/libtensor/include/kernels/elementwise_functions/imag.hpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -53,14 +53,16 @@ using dpctl::tensor::ssize_t;
5353
namespace td_ns = dpctl::tensor::type_dispatch;
5454

5555
using dpctl::tensor::type_utils::is_complex;
56+
using dpctl::tensor::type_utils::is_complex_v;
5657

5758
template <typename argT, typename resT> struct ImagFunctor
5859
{
5960

6061
// is function constant for given argT
61-
using is_constant = typename std::false_type;
62+
using is_constant =
63+
typename std::is_same<is_complex<argT>, std::false_type>;
6264
// constant value, if constant
63-
// constexpr resT constant_value = resT{};
65+
static constexpr resT constant_value = resT{0};
6466
// is function defined for sycl::vec
6567
using supports_vec = typename std::false_type;
6668
// do both argTy and resTy support sugroup store/load operation
@@ -69,12 +71,12 @@ template <typename argT, typename resT> struct ImagFunctor
6971

7072
resT operator()(const argT &in) const
7173
{
72-
if constexpr (is_complex<argT>::value) {
74+
if constexpr (is_complex_v<argT>) {
7375
return std::imag(in);
7476
}
7577
else {
7678
static_assert(std::is_same_v<resT, argT>);
77-
return resT{0};
79+
return constant_value;
7880
}
7981
}
8082
};

dpctl/tensor/libtensor/include/kernels/elementwise_functions/real.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ using dpctl::tensor::ssize_t;
5353
namespace td_ns = dpctl::tensor::type_dispatch;
5454

5555
using dpctl::tensor::type_utils::is_complex;
56+
using dpctl::tensor::type_utils::is_complex_v;
5657

5758
template <typename argT, typename resT> struct RealFunctor
5859
{
@@ -69,7 +70,7 @@ template <typename argT, typename resT> struct RealFunctor
6970

7071
resT operator()(const argT &in) const
7172
{
72-
if constexpr (is_complex<argT>::value) {
73+
if constexpr (is_complex_v<argT>) {
7374
return std::real(in);
7475
}
7576
else {

0 commit comments

Comments
 (0)