Skip to content

Commit 0e50556

Browse files
antonwolfyvtavana
andauthored
Add implementation of dpnp.i0 (#2118)
* Implement dpnp.i0 function * Add sycl::half type to matrix * Add third party test * Add more tests * Add CFD tests * Remove redundant code for i0 kernel * Fix typo in docstring * Add tests with NaN and Inf values * Add a proper math include based on compiler version * Define proper DPC++ version where sycl::ext::intel::math::cyl_bessel_i0(x) works --------- Co-authored-by: vtavana <[email protected]>
1 parent 3b235ac commit 0e50556

File tree

11 files changed

+574
-7
lines changed

11 files changed

+574
-7
lines changed

dpnp/backend/extensions/ufunc/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ set(_elementwise_sources
3434
${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/fmod.cpp
3535
${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/gcd.cpp
3636
${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/heaviside.cpp
37+
${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/i0.cpp
3738
${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/lcm.cpp
3839
${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/ldexp.cpp
3940
${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/logaddexp2.cpp

dpnp/backend/extensions/ufunc/elementwise_functions/common.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
#include "fmod.hpp"
3535
#include "gcd.hpp"
3636
#include "heaviside.hpp"
37+
#include "i0.hpp"
3738
#include "lcm.hpp"
3839
#include "ldexp.hpp"
3940
#include "logaddexp2.hpp"
@@ -59,6 +60,7 @@ void init_elementwise_functions(py::module_ m)
5960
init_fmod(m);
6061
init_gcd(m);
6162
init_heaviside(m);
63+
init_i0(m);
6264
init_lcm(m);
6365
init_ldexp(m);
6466
init_logaddexp2(m);
Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2024, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
//
13+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
14+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
15+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
16+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
17+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23+
// THE POSSIBILITY OF SUCH DAMAGE.
24+
//*****************************************************************************
25+
26+
#include <sycl/sycl.hpp>
27+
28+
#include "dpctl4pybind11.hpp"
29+
30+
#include "i0.hpp"
31+
#include "kernels/elementwise_functions/i0.hpp"
32+
#include "populate.hpp"
33+
34+
// include a local copy of elementwise common header from dpctl tensor:
35+
// dpctl/tensor/libtensor/source/elementwise_functions/elementwise_functions.hpp
36+
// TODO: replace by including dpctl header once available
37+
#include "../../elementwise_functions/elementwise_functions.hpp"
38+
39+
// dpctl tensor headers
40+
#include "kernels/elementwise_functions/common.hpp"
41+
#include "utils/type_dispatch.hpp"
42+
43+
namespace dpnp::extensions::ufunc
44+
{
45+
namespace py = pybind11;
46+
namespace py_int = dpnp::extensions::py_internal;
47+
48+
namespace impl
49+
{
50+
namespace ew_cmn_ns = dpctl::tensor::kernels::elementwise_common;
51+
namespace td_ns = dpctl::tensor::type_dispatch;
52+
53+
/**
54+
* @brief A factory to define pairs of supported types for which
55+
* sycl::i0<T> function is available.
56+
*
57+
* @tparam T Type of input vector `a` and of result vector `y`.
58+
*/
59+
template <typename T>
60+
struct OutputType
61+
{
62+
using value_type =
63+
typename std::disjunction<td_ns::TypeMapResultEntry<T, sycl::half>,
64+
td_ns::TypeMapResultEntry<T, float>,
65+
td_ns::TypeMapResultEntry<T, double>,
66+
td_ns::DefaultResultEntry<void>>::result_type;
67+
};
68+
69+
using dpnp::kernels::i0::I0Functor;
70+
71+
template <typename argT,
72+
typename resT = argT,
73+
unsigned int vec_sz = 4,
74+
unsigned int n_vecs = 2,
75+
bool enable_sg_loadstore = true>
76+
using ContigFunctor = ew_cmn_ns::UnaryContigFunctor<argT,
77+
resT,
78+
I0Functor<argT, resT>,
79+
vec_sz,
80+
n_vecs,
81+
enable_sg_loadstore>;
82+
83+
template <typename argTy, typename resTy, typename IndexerT>
84+
using StridedFunctor = ew_cmn_ns::
85+
UnaryStridedFunctor<argTy, resTy, IndexerT, I0Functor<argTy, resTy>>;
86+
87+
using ew_cmn_ns::unary_contig_impl_fn_ptr_t;
88+
using ew_cmn_ns::unary_strided_impl_fn_ptr_t;
89+
90+
static unary_contig_impl_fn_ptr_t i0_contig_dispatch_vector[td_ns::num_types];
91+
static int i0_output_typeid_vector[td_ns::num_types];
92+
static unary_strided_impl_fn_ptr_t i0_strided_dispatch_vector[td_ns::num_types];
93+
94+
MACRO_POPULATE_DISPATCH_VECTORS(i0);
95+
} // namespace impl
96+
97+
void init_i0(py::module_ m)
98+
{
99+
using arrayT = dpctl::tensor::usm_ndarray;
100+
using event_vecT = std::vector<sycl::event>;
101+
{
102+
impl::populate_i0_dispatch_vectors();
103+
using impl::i0_contig_dispatch_vector;
104+
using impl::i0_output_typeid_vector;
105+
using impl::i0_strided_dispatch_vector;
106+
107+
auto i0_pyapi = [&](const arrayT &src, const arrayT &dst,
108+
sycl::queue &exec_q,
109+
const event_vecT &depends = {}) {
110+
return py_int::py_unary_ufunc(
111+
src, dst, exec_q, depends, i0_output_typeid_vector,
112+
i0_contig_dispatch_vector, i0_strided_dispatch_vector);
113+
};
114+
m.def("_i0", i0_pyapi, "", py::arg("src"), py::arg("dst"),
115+
py::arg("sycl_queue"), py::arg("depends") = py::list());
116+
117+
auto i0_result_type_pyapi = [&](const py::dtype &dtype) {
118+
return py_int::py_unary_ufunc_result_type(dtype,
119+
i0_output_typeid_vector);
120+
};
121+
m.def("_i0_result_type", i0_result_type_pyapi);
122+
}
123+
}
124+
} // namespace dpnp::extensions::ufunc
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2024, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
//
13+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
14+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
15+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
16+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
17+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23+
// THE POSSIBILITY OF SUCH DAMAGE.
24+
//*****************************************************************************
25+
26+
#pragma once
27+
28+
#include <pybind11/pybind11.h>
29+
30+
namespace py = pybind11;
31+
32+
namespace dpnp::extensions::ufunc
33+
{
34+
void init_i0(py::module_ m);
35+
} // namespace dpnp::extensions::ufunc

0 commit comments

Comments
 (0)