Skip to content

Commit cdeaca7

Browse files
authored
Merge pull request #276 from PyO3/simpler-npy-strides
NpyStrides does not require SBO as the upper limit is statically known.
2 parents 4fbdebd + 7852fbc commit cdeaca7

File tree

1 file changed

+23
-34
lines changed

1 file changed

+23
-34
lines changed

src/convert.rs

Lines changed: 23 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,12 @@ where
134134
_ => {
135135
// if the array is not contiguous, copy all elements by `ArrayBase::iter`.
136136
let dim = self.raw_dim();
137-
let strides = NpyStrides::from_dim(&dim, mem::size_of::<A>());
137+
let strides = NpyStrides::new::<_, A>(
138+
dim.default_strides()
139+
.slice()
140+
.iter()
141+
.map(|&x| x as npyffi::npy_intp),
142+
);
138143
unsafe {
139144
let array = PyArray::<A, _>::new_(py, dim, strides.as_ptr(), 0);
140145
let data_ptr = array.data();
@@ -173,10 +178,7 @@ where
173178
D: Dimension,
174179
{
175180
fn npy_strides(&self) -> NpyStrides {
176-
NpyStrides::new(
177-
self.strides().iter().map(|&x| x as npyffi::npy_intp),
178-
mem::size_of::<A>(),
179-
)
181+
NpyStrides::new::<_, A>(self.strides().iter().map(|&x| x as npyffi::npy_intp))
180182
}
181183

182184
fn order(&self) -> Option<Order> {
@@ -190,40 +192,27 @@ where
190192
}
191193
}
192194

193-
/// Numpy strides with short array optimization
194-
pub(crate) enum NpyStrides {
195-
Short([npyffi::npy_intp; 8]),
196-
Long(Vec<npyffi::npy_intp>),
197-
}
195+
/// An array of strides sufficiently large for [any NumPy array][NPY_MAXDIMS]
196+
///
197+
/// [NPY_MAXDIMS]: https://github.com/numpy/numpy/blob/4c60b3263ac50e5e72f6a909e156314fc3c9cba0/numpy/core/include/numpy/ndarraytypes.h#L40
198+
pub(crate) struct NpyStrides([npyffi::npy_intp; 32]);
198199

199200
impl NpyStrides {
200201
pub(crate) fn as_ptr(&self) -> *const npy_intp {
201-
match self {
202-
NpyStrides::Short(inner) => inner.as_ptr(),
203-
NpyStrides::Long(inner) => inner.as_ptr(),
204-
}
202+
self.0.as_ptr()
205203
}
206-
fn from_dim<D: Dimension>(dim: &D, type_size: usize) -> Self {
207-
Self::new(
208-
dim.default_strides()
209-
.slice()
210-
.iter()
211-
.map(|&x| x as npyffi::npy_intp),
212-
type_size,
213-
)
214-
}
215-
fn new(strides: impl ExactSizeIterator<Item = npyffi::npy_intp>, type_size: usize) -> Self {
216-
let len = strides.len();
217-
let type_size = type_size as npyffi::npy_intp;
218-
if len <= 8 {
219-
let mut res = [0; 8];
220-
for (i, s) in strides.enumerate() {
221-
res[i] = s * type_size;
222-
}
223-
NpyStrides::Short(res)
224-
} else {
225-
NpyStrides::Long(strides.map(|n| n as npyffi::npy_intp * type_size).collect())
204+
205+
fn new<S, A>(strides: S) -> Self
206+
where
207+
S: Iterator<Item = npyffi::npy_intp>,
208+
{
209+
let type_size = mem::size_of::<A>() as npyffi::npy_intp;
210+
let mut res = [0; 32];
211+
for (i, s) in strides.enumerate() {
212+
*res.get_mut(i)
213+
.expect("Only dimensionalities of up to 32 are supported") = s * type_size;
226214
}
215+
Self(res)
227216
}
228217
}
229218

0 commit comments

Comments
 (0)