Skip to content

Commit 404f0fe

Browse files
Implementing histogramdd
1 parent afa6980 commit 404f0fe

File tree

12 files changed

+1143
-98
lines changed

12 files changed

+1143
-98
lines changed

dpnp/backend/extensions/statistics/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -29,6 +29,7 @@ set(_module_src
2929
${CMAKE_CURRENT_SOURCE_DIR}/common.cpp
3030
${CMAKE_CURRENT_SOURCE_DIR}/bincount.cpp
3131
${CMAKE_CURRENT_SOURCE_DIR}/histogram.cpp
32+
${CMAKE_CURRENT_SOURCE_DIR}/histogramdd.cpp
3233
${CMAKE_CURRENT_SOURCE_DIR}/histogram_common.cpp
3334
${CMAKE_CURRENT_SOURCE_DIR}/statistics_py.cpp
3435
)

dpnp/backend/extensions/statistics/common.hpp

Lines changed: 53 additions & 41 deletions
Original file line number | Diff line number | Diff line change
@@ -34,8 +34,30 @@
3434
// so sycl.hpp must be included before math_utils.hpp
3535
#include <sycl/sycl.hpp>
3636
#include "utils/math_utils.hpp"
37+
#include "utils/type_utils.hpp"
3738
// clang-format on
3839

40+
namespace dpctl
41+
{
42+
namespace tensor
43+
{
44+
namespace type_utils
45+
{
46+
// Upstream to dpctl
47+
template <class T>
48+
struct is_complex<const std::complex<T>> : std::true_type
49+
{
50+
};
51+
52+
template <typename T>
53+
constexpr bool is_complex_v = is_complex<T>::value;
54+
55+
} // namespace type_utils
56+
} // namespace tensor
57+
} // namespace dpctl
58+
59+
namespace type_utils = dpctl::tensor::type_utils;
60+
3961
namespace statistics
4062
{
4163
namespace common
@@ -56,24 +78,20 @@ constexpr auto Align(N n, D d)
5678
template <typename T, sycl::memory_order Order, sycl::memory_scope Scope>
5779
struct AtomicOp
5880
{
59-
static void add(T &lhs, const T value)
81+
static void add(T &lhs, const T &value)
6082
{
61-
sycl::atomic_ref<T, Order, Scope> lh(lhs);
62-
lh += value;
63-
}
64-
};
83+
if constexpr (type_utils::is_complex_v<T>) {
84+
using vT = typename T::value_type;
85+
vT *_lhs = reinterpret_cast<vT(&)[2]>(lhs);
86+
const vT *_val = reinterpret_cast<const vT(&)[2]>(value);
6587

66-
template <typename T, sycl::memory_order Order, sycl::memory_scope Scope>
67-
struct AtomicOp<std::complex<T>, Order, Scope>
68-
{
69-
static void add(std::complex<T> &lhs, const std::complex<T> value)
70-
{
71-
T *_lhs = reinterpret_cast<T(&)[2]>(lhs);
72-
const T *_val = reinterpret_cast<const T(&)[2]>(value);
73-
sycl::atomic_ref<T, Order, Scope> lh0(_lhs[0]);
74-
lh0 += _val[0];
75-
sycl::atomic_ref<T, Order, Scope> lh1(_lhs[1]);
76-
lh1 += _val[1];
88+
AtomicOp<vT, Order, Scope>::add(_lhs[0], _val[0]);
89+
AtomicOp<vT, Order, Scope>::add(_lhs[1], _val[1]);
90+
}
91+
else {
92+
sycl::atomic_ref<T, Order, Scope> lh(lhs);
93+
lh += value;
94+
}
7795
}
7896
};
7997

@@ -82,17 +100,12 @@ struct Less
82100
{
83101
bool operator()(const T &lhs, const T &rhs) const
84102
{
85-
return std::less{}(lhs, rhs);
86-
}
87-
};
88-
89-
template <typename T>
90-
struct Less<std::complex<T>>
91-
{
92-
bool operator()(const std::complex<T> &lhs,
93-
const std::complex<T> &rhs) const
94-
{
95-
return dpctl::tensor::math_utils::less_complex(lhs, rhs);
103+
if constexpr (type_utils::is_complex_v<T>) {
104+
return dpctl::tensor::math_utils::less_complex(lhs, rhs);
105+
}
106+
else {
107+
return std::less{}(lhs, rhs);
108+
}
96109
}
97110
};
98111

@@ -101,26 +114,25 @@ struct IsNan
101114
{
102115
static bool isnan(const T &v)
103116
{
104-
if constexpr (std::is_floating_point_v<T> ||
105-
std::is_same_v<T, sycl::half>) {
106-
return sycl::isnan(v);
117+
if constexpr (type_utils::is_complex_v<T>) {
118+
const auto real1 = std::real(v);
119+
const auto imag1 = std::imag(v);
120+
121+
using vT = typename T::value_type;
122+
123+
return IsNan<vT>::isnan(real1) || IsNan<vT>::isnan(imag1);
124+
}
125+
else {
126+
if constexpr (std::is_floating_point_v<T> ||
127+
std::is_same_v<T, sycl::half>) {
128+
return sycl::isnan(v);
129+
}
107130
}
108131

109132
return false;
110133
}
111134
};
112135

113-
template <typename T>
114-
struct IsNan<std::complex<T>>
115-
{
116-
static bool isnan(const std::complex<T> &v)
117-
{
118-
T real1 = std::real(v);
119-
T imag1 = std::imag(v);
120-
return sycl::isnan(real1) || sycl::isnan(imag1);
121-
}
122-
};
123-
124136
size_t get_max_local_size(const sycl::device &device);
125137
size_t get_max_local_size(const sycl::device &device,
126138
int cpu_local_size_limit,

dpnp/backend/extensions/statistics/histogram_common.cpp

Lines changed: 22 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -137,22 +137,16 @@ void validate(const usm_ndarray &sample,
137137
" parameter must have at least 1 element");
138138
}
139139

140-
if (histogram.get_ndim() != 1) {
141-
throw py::value_error(get_name(&histogram) +
142-
" parameter must be 1d. Actual " +
143-
std::to_string(histogram.get_ndim()) + "d");
144-
}
145-
146140
if (weights_ptr) {
147141
if (weights_ptr->get_ndim() != 1) {
148142
throw py::value_error(
149143
get_name(weights_ptr) + " parameter must be 1d. Actual " +
150144
std::to_string(weights_ptr->get_ndim()) + "d");
151145
}
152146

153-
auto sample_size = sample.get_size();
147+
auto sample_size = sample.get_shape(0);
154148
auto weights_size = weights_ptr->get_size();
155-
if (sample.get_size() != weights_ptr->get_size()) {
149+
if (sample_size != weights_ptr->get_size()) {
156150
throw py::value_error(
157151
get_name(&sample) + " size (" + std::to_string(sample_size) +
158152
") and " + get_name(weights_ptr) + " size (" +
@@ -168,42 +162,37 @@ void validate(const usm_ndarray &sample,
168162
}
169163

170164
if (sample.get_ndim() == 1) {
171-
if (bins_ptr != nullptr && bins_ptr->get_ndim() != 1) {
165+
if (histogram.get_ndim() != 1) {
172166
throw py::value_error(get_name(&sample) + " parameter is 1d, but " +
173-
get_name(bins_ptr) + " is " +
174-
std::to_string(bins_ptr->get_ndim()) + "d");
167+
get_name(&histogram) + " is " +
168+
std::to_string(histogram.get_ndim()) + "d");
169+
}
170+
171+
if (bins_ptr && histogram.get_size() != bins_ptr->get_size() - 1) {
172+
auto hist_size = histogram.get_size();
173+
auto bins_size = bins_ptr->get_size();
174+
throw py::value_error(
175+
get_name(&histogram) + " parameter and " + get_name(bins_ptr) +
176+
" parameters shape mismatch. " + get_name(&histogram) +
177+
" size is " + std::to_string(hist_size) + get_name(bins_ptr) +
178+
" must have size " + std::to_string(hist_size + 1) +
179+
" but have " + std::to_string(bins_size));
175180
}
176181
}
177182
else if (sample.get_ndim() == 2) {
178183
auto sample_count = sample.get_shape(0);
179184
auto expected_dims = sample.get_shape(1);
180185

181-
if (bins_ptr != nullptr && bins_ptr->get_ndim() != expected_dims) {
182-
throw py::value_error(get_name(&sample) + " parameter has shape {" +
183-
std::to_string(sample_count) + "x" +
184-
std::to_string(expected_dims) + "}" +
185-
", so " + get_name(bins_ptr) +
186+
if (histogram.get_ndim() != expected_dims) {
187+
throw py::value_error(get_name(&sample) + " parameter has shape (" +
188+
std::to_string(sample_count) + ", " +
189+
std::to_string(expected_dims) + ")" +
190+
", so " + get_name(&histogram) +
186191
" parameter expected to be " +
187192
std::to_string(expected_dims) +
188193
"d. "
189194
"Actual " +
190-
std::to_string(bins->get_ndim()) + "d");
191-
}
192-
}
193-
194-
if (bins_ptr != nullptr) {
195-
py::ssize_t expected_hist_size = 1;
196-
for (int i = 0; i < bins_ptr->get_ndim(); ++i) {
197-
expected_hist_size *= (bins_ptr->get_shape(i) - 1);
198-
}
199-
200-
if (histogram.get_size() != expected_hist_size) {
201-
throw py::value_error(
202-
get_name(&histogram) + " and " + get_name(bins_ptr) +
203-
" shape mismatch. " + get_name(&histogram) +
204-
" expected to have size = " +
205-
std::to_string(expected_hist_size) + ". Actual " +
206-
std::to_string(histogram.get_size()));
195+
std::to_string(histogram.get_ndim()) + "d");
207196
}
208197
}
209198

dpnp/backend/extensions/statistics/histogram_common.hpp

Lines changed: 35 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -52,12 +52,15 @@ template <typename T, int Dims>
5252
struct CachedData
5353
{
5454
static constexpr bool const sync_after_init = true;
55-
using pointer_type = T *;
55+
using Shape = sycl::range<Dims>;
56+
using value_type = T;
57+
using pointer_type = value_type *;
58+
static constexpr auto dims = Dims;
5659

57-
using ncT = typename std::remove_const<T>::type;
60+
using ncT = typename std::remove_const<value_type>::type;
5861
using LocalData = sycl::local_accessor<ncT, Dims>;
5962

60-
CachedData(T *global_data, sycl::range<Dims> shape, sycl::handler &cgh)
63+
CachedData(T *global_data, Shape shape, sycl::handler &cgh)
6164
{
6265
this->global_data = global_data;
6366
local_data = LocalData(shape, cgh);
@@ -87,17 +90,30 @@ struct CachedData
8790
return local_data.size();
8891
}
8992

93+
T &operator[](const sycl::id<Dims> &id) const
94+
{
95+
return local_data[id];
96+
}
97+
98+
template <typename = std::enable_if_t<Dims == 1>>
99+
T &operator[](const size_t id) const
100+
{
101+
return local_data[id];
102+
}
103+
90104
private:
91105
LocalData local_data;
92-
T *global_data = nullptr;
106+
value_type *global_data = nullptr;
93107
};
94108

95109
template <typename T, int Dims>
96110
struct UncachedData
97111
{
98112
static constexpr bool const sync_after_init = false;
99113
using Shape = sycl::range<Dims>;
100-
using pointer_type = T *;
114+
using value_type = T;
115+
using pointer_type = value_type *;
116+
static constexpr auto dims = Dims;
101117

102118
UncachedData(T *global_data, const Shape &shape, sycl::handler &)
103119
{
@@ -120,6 +136,17 @@ struct UncachedData
120136
return _shape.size();
121137
}
122138

139+
T &operator[](const sycl::id<Dims> &id) const
140+
{
141+
return global_data[id];
142+
}
143+
144+
template <typename = std::enable_if_t<Dims == 1>>
145+
T &operator[](const size_t id) const
146+
{
147+
return global_data[id];
148+
}
149+
123150
private:
124151
T *global_data = nullptr;
125152
Shape _shape;
@@ -290,9 +317,9 @@ class histogram_kernel;
290317

291318
template <typename T, typename HistImpl, typename Edges, typename Weights>
292319
void submit_histogram(const T *in,
293-
size_t size,
294-
size_t dims,
295-
uint32_t WorkPI,
320+
const size_t size,
321+
const size_t dims,
322+
const uint32_t WorkPI,
296323
const HistImpl &hist,
297324
const Edges &edges,
298325
const Weights &weights,

0 commit comments

Comments
 (0)