Skip to content

Commit 2dde2b4

Browse files
committed
Make nanobind adapter backward compatible (#1790)
1 parent 3884a02 commit 2dde2b4

File tree

1 file changed

+22
-12
lines changed

1 file changed

+22
-12
lines changed

include/gridtools/storage/adapter/nanobind_adapter.hpp

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -24,25 +24,33 @@
2424

2525
namespace gridtools {
2626
namespace nanobind_sid_adapter_impl_ {
27+
#if NB_VERSION_MAJOR >= 2
28+
using array_size_t = nanobind::ssize_t;
29+
#else
30+
using array_size_t = std::size_t;
31+
#endif
2732

28-
// Use `-1` for dynamic stride, use an integral value for static stride.
29-
template <nanobind::ssize_t... Values>
30-
using stride_spec = std::integer_sequence<nanobind::ssize_t, Values...>;
33+
inline constexpr array_size_t dynamic_size = -1;
34+
35+
// Use `dynamic_size` for dynamic stride, use an integral value for static
36+
// stride.
37+
template <array_size_t... Values>
38+
using stride_spec = std::integer_sequence<array_size_t, Values...>;
3139

3240
template <class IndexSequence>
3341
struct dynamic_strides_helper;
3442

3543
template <std::size_t... Indices>
3644
struct dynamic_strides_helper<std::index_sequence<Indices...>> {
37-
using type = stride_spec<(void(Indices), -1)...>;
45+
using type = stride_spec<(void(Indices), dynamic_size)...>;
3846
};
3947

4048
template <std::size_t N>
4149
using fully_dynamic_strides = typename dynamic_strides_helper<std::make_index_sequence<N>>::type;
4250

43-
template <nanobind::ssize_t SpecValue>
44-
auto select_static_stride_value(std::size_t dyn_value) {
45-
if constexpr (SpecValue == -1) {
51+
template <array_size_t SpecValue>
52+
constexpr auto select_static_stride_value(std::size_t dyn_value) {
53+
if constexpr (SpecValue == dynamic_size) {
4654
return dyn_value;
4755
} else {
4856
if (SpecValue != dyn_value) {
@@ -52,20 +60,20 @@ namespace gridtools {
5260
}
5361
}
5462

55-
template <nanobind::ssize_t... SpecValues, std::size_t... IndexValues>
56-
auto select_static_strides_helper(
63+
template <array_size_t... SpecValues, std::size_t... IndexValues>
64+
constexpr auto select_static_strides_helper(
5765
stride_spec<SpecValues...>, const std::size_t *dyn_values, std::index_sequence<IndexValues...>) {
5866

5967
return gridtools::tuple{select_static_stride_value<SpecValues>(dyn_values[IndexValues])...};
6068
}
6169

62-
template <nanobind::ssize_t... SpecValues>
63-
auto select_static_strides(stride_spec<SpecValues...> spec, const std::size_t *dyn_values) {
70+
template <array_size_t... SpecValues>
71+
constexpr auto select_static_strides(stride_spec<SpecValues...> spec, const std::size_t *dyn_values) {
6472
return select_static_strides_helper(spec, dyn_values, std::make_index_sequence<sizeof...(SpecValues)>{});
6573
}
6674

6775
template <class T,
68-
nanobind::ssize_t... Sizes,
76+
array_size_t... Sizes,
6977
class... Args,
7078
class Strides = fully_dynamic_strides<sizeof...(Sizes)>,
7179
class StridesKind = sid::unknown_kind>
@@ -93,7 +101,9 @@ namespace gridtools {
93101

94102
namespace nanobind {
95103
using nanobind_sid_adapter_impl_::as_sid;
104+
using nanobind_sid_adapter_impl_::dynamic_size;
96105
using nanobind_sid_adapter_impl_::fully_dynamic_strides;
97106
using nanobind_sid_adapter_impl_::stride_spec;
107+
98108
} // namespace nanobind
99109
} // namespace gridtools

0 commit comments

Comments
 (0)