Skip to content

Commit 19e9232

Browse files
authored
Merge pull request #305 from PyO3/converting-strides
Collect scattered code for converting strides
2 parents 25ef24d + 1961c5f commit 19e9232

File tree

3 files changed

+96
-132
lines changed

3 files changed

+96
-132
lines changed

src/array.rs

+67-77
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ use std::{
99

1010
use ndarray::{
1111
Array, ArrayBase, ArrayView, ArrayViewMut, Axis, Data, Dim, Dimension, IntoDimension, Ix0, Ix1,
12-
Ix2, Ix3, Ix4, Ix5, Ix6, IxDyn, RawArrayView, RawArrayViewMut, RawData, Shape, ShapeBuilder,
12+
Ix2, Ix3, Ix4, Ix5, Ix6, IxDyn, RawArrayView, RawArrayViewMut, RawData, ShapeBuilder,
1313
StrideShape,
1414
};
1515
use num_traits::AsPrimitive;
@@ -338,42 +338,19 @@ impl<T, D> PyArray<T, D> {
338338
}
339339
}
340340

341-
/// Calcurates the total number of elements in the array.
341+
/// Calculates the total number of elements in the array.
342342
pub fn len(&self) -> usize {
343343
self.shape().iter().product()
344344
}
345345

346+
/// Returns `true` if the there are no elements in the array.
346347
pub fn is_empty(&self) -> bool {
347-
self.len() == 0
348+
self.shape().iter().any(|dim| *dim == 0)
348349
}
349350

350-
/// Returns the pointer to the first element of the inner array.
351+
/// Returns the pointer to the first element of the array.
351352
pub(crate) fn data(&self) -> *mut T {
352-
let ptr = self.as_array_ptr();
353-
unsafe { (*ptr).data as *mut _ }
354-
}
355-
}
356-
357-
struct InvertedAxes(u32);
358-
359-
impl InvertedAxes {
360-
fn new(len: usize) -> Self {
361-
assert!(len <= 32, "Only dimensionalities of up to 32 are supported");
362-
Self(0)
363-
}
364-
365-
fn push(&mut self, axis: usize) {
366-
debug_assert!(axis < 32);
367-
self.0 |= 1 << axis;
368-
}
369-
370-
fn invert<S: RawData, D: Dimension>(mut self, array: &mut ArrayBase<S, D>) {
371-
while self.0 != 0 {
372-
let axis = self.0.trailing_zeros() as usize;
373-
self.0 &= !(1 << axis);
374-
375-
array.invert_axis(Axis(axis));
376-
}
353+
unsafe { (*self.as_array_ptr()).data as *mut _ }
377354
}
378355
}
379356

@@ -384,38 +361,6 @@ impl<T: Element, D: Dimension> PyArray<T, D> {
384361
D::from_dimension(&Dim(self.shape())).expect("mismatching dimensions")
385362
}
386363

387-
fn ndarray_shape_ptr(&self) -> (StrideShape<D>, *mut T, InvertedAxes) {
388-
let shape = self.shape();
389-
let strides = self.strides();
390-
391-
let mut new_strides = D::zeros(strides.len());
392-
let mut data_ptr = self.data();
393-
let mut inverted_axes = InvertedAxes::new(strides.len());
394-
395-
for i in 0..strides.len() {
396-
// FIXME(kngwyu): Replace this hacky negative strides support with
397-
// a proper constructor, when it's implemented.
398-
// See https://github.com/rust-ndarray/ndarray/issues/842 for more.
399-
if strides[i] < 0 {
400-
// Move the pointer to the start position
401-
let offset = strides[i] * (shape[i] as isize - 1) / mem::size_of::<T>() as isize;
402-
unsafe {
403-
data_ptr = data_ptr.offset(offset);
404-
}
405-
new_strides[i] = (-strides[i]) as usize / mem::size_of::<T>();
406-
407-
inverted_axes.push(i);
408-
} else {
409-
new_strides[i] = strides[i] as usize / mem::size_of::<T>();
410-
}
411-
}
412-
413-
let shape = Shape::from(D::from_dimension(&Dim(shape)).expect("mismatching dimensions"));
414-
let new_strides = D::from_dimension(&Dim(new_strides)).expect("mismatching dimensions");
415-
416-
(shape.strides(new_strides), data_ptr, inverted_axes)
417-
}
418-
419364
/// Creates a new uninitialized PyArray in python heap.
420365
///
421366
/// If `is_fortran == true`, returns Fortran-order array. Else, returns C-order array.
@@ -883,6 +828,63 @@ impl<T: Element, D: Dimension> PyArray<T, D> {
883828
self.try_readwrite().unwrap()
884829
}
885830

831+
fn as_view<S: RawData, F>(&self, from_shape_ptr: F) -> ArrayBase<S, D>
832+
where
833+
F: FnOnce(StrideShape<D>, *mut T) -> ArrayBase<S, D>,
834+
{
835+
fn inner<D: Dimension>(
836+
shape: &[usize],
837+
strides: &[isize],
838+
itemsize: usize,
839+
mut data_ptr: *mut u8,
840+
) -> (StrideShape<D>, u32, *mut u8) {
841+
let shape = D::from_dimension(&Dim(shape)).expect("mismatching dimensions");
842+
843+
assert!(
844+
strides.len() <= 32,
845+
"Only dimensionalities of up to 32 are supported"
846+
);
847+
848+
let mut new_strides = D::zeros(strides.len());
849+
let mut inverted_axes = 0_u32;
850+
851+
for i in 0..strides.len() {
852+
// FIXME(kngwyu): Replace this hacky negative strides support with
853+
// a proper constructor, when it's implemented.
854+
// See https://github.com/rust-ndarray/ndarray/issues/842 for more.
855+
if strides[i] >= 0 {
856+
new_strides[i] = strides[i] as usize / itemsize;
857+
} else {
858+
// Move the pointer to the start position.
859+
data_ptr = unsafe { data_ptr.offset(strides[i] * (shape[i] as isize - 1)) };
860+
861+
new_strides[i] = (-strides[i]) as usize / itemsize;
862+
inverted_axes |= 1 << i;
863+
}
864+
}
865+
866+
(shape.strides(new_strides), inverted_axes, data_ptr)
867+
}
868+
869+
let (shape, mut inverted_axes, data_ptr) = inner(
870+
self.shape(),
871+
self.strides(),
872+
mem::size_of::<T>(),
873+
self.data() as _,
874+
);
875+
876+
let mut array = from_shape_ptr(shape, data_ptr as _);
877+
878+
while inverted_axes != 0 {
879+
let axis = inverted_axes.trailing_zeros() as usize;
880+
inverted_axes &= !(1 << axis);
881+
882+
array.invert_axis(Axis(axis));
883+
}
884+
885+
array
886+
}
887+
886888
/// Returns the internal array as [`ArrayView`].
887889
///
888890
/// See also [`PyReadonlyArray::as_array`].
@@ -891,10 +893,7 @@ impl<T: Element, D: Dimension> PyArray<T, D> {
891893
///
892894
/// The existence of an exclusive reference to the internal data, e.g. `&mut [T]` or `ArrayViewMut`, implies undefined behavior.
893895
pub unsafe fn as_array(&self) -> ArrayView<'_, T, D> {
894-
let (shape, ptr, inverted_axes) = self.ndarray_shape_ptr();
895-
let mut res = ArrayView::from_shape_ptr(shape, ptr);
896-
inverted_axes.invert(&mut res);
897-
res
896+
self.as_view(|shape, ptr| ArrayView::from_shape_ptr(shape, ptr))
898897
}
899898

900899
/// Returns the internal array as [`ArrayViewMut`].
@@ -905,26 +904,17 @@ impl<T: Element, D: Dimension> PyArray<T, D> {
905904
///
906905
/// The existence of another reference to the internal data, e.g. `&[T]` or `ArrayView`, implies undefined behavior.
907906
pub unsafe fn as_array_mut(&self) -> ArrayViewMut<'_, T, D> {
908-
let (shape, ptr, inverted_axes) = self.ndarray_shape_ptr();
909-
let mut res = ArrayViewMut::from_shape_ptr(shape, ptr);
910-
inverted_axes.invert(&mut res);
911-
res
907+
self.as_view(|shape, ptr| ArrayViewMut::from_shape_ptr(shape, ptr))
912908
}
913909

914910
/// Returns the internal array as [`RawArrayView`] enabling element access via raw pointers
915911
pub fn as_raw_array(&self) -> RawArrayView<T, D> {
916-
let (shape, ptr, inverted_axes) = self.ndarray_shape_ptr();
917-
let mut res = unsafe { RawArrayView::from_shape_ptr(shape, ptr) };
918-
inverted_axes.invert(&mut res);
919-
res
912+
self.as_view(|shape, ptr| unsafe { RawArrayView::from_shape_ptr(shape, ptr) })
920913
}
921914

922915
/// Returns the internal array as [`RawArrayViewMut`] enabling element access via raw pointers
923916
pub fn as_raw_array_mut(&self) -> RawArrayViewMut<T, D> {
924-
let (shape, ptr, inverted_axes) = self.ndarray_shape_ptr();
925-
let mut res = unsafe { RawArrayViewMut::from_shape_ptr(shape, ptr) };
926-
inverted_axes.invert(&mut res);
927-
res
917+
self.as_view(|shape, ptr| unsafe { RawArrayViewMut::from_shape_ptr(shape, ptr) })
928918
}
929919

930920
/// Get a copy of `PyArray` as

src/convert.rs

+24-55
Original file line numberDiff line numberDiff line change
@@ -144,27 +144,20 @@ where
144144
fn to_pyarray<'py>(&self, py: Python<'py>) -> &'py PyArray<Self::Item, Self::Dim> {
145145
let len = self.len();
146146
match self.order() {
147-
Some(order) if A::IS_COPY => {
147+
Some(flag) if A::IS_COPY => {
148148
// if the array is contiguous, copy it by `copy_nonoverlapping`.
149149
let strides = self.npy_strides();
150150
unsafe {
151-
let array =
152-
PyArray::new_(py, self.raw_dim(), strides.as_ptr(), order.to_flag());
151+
let array = PyArray::new_(py, self.raw_dim(), strides.as_ptr(), flag);
153152
ptr::copy_nonoverlapping(self.as_ptr(), array.data(), len);
154153
array
155154
}
156155
}
157156
_ => {
158157
// if the array is not contiguous, copy all elements by `ArrayBase::iter`.
159158
let dim = self.raw_dim();
160-
let strides = NpyStrides::new::<_, A>(
161-
dim.default_strides()
162-
.slice()
163-
.iter()
164-
.map(|&x| x as npyffi::npy_intp),
165-
);
166159
unsafe {
167-
let array = PyArray::<A, _>::new_(py, dim, strides.as_ptr(), 0);
160+
let array = PyArray::<A, _>::new(py, dim, false);
168161
let mut data_ptr = array.data();
169162
for item in self.iter() {
170163
data_ptr.write(item.clone());
@@ -177,69 +170,45 @@ where
177170
}
178171
}
179172

180-
pub(crate) enum Order {
181-
Standard,
182-
Fortran,
183-
}
184-
185-
impl Order {
186-
fn to_flag(&self) -> c_int {
187-
match self {
188-
Order::Standard => 0,
189-
Order::Fortran => 1,
190-
}
191-
}
192-
}
193-
194173
pub(crate) trait ArrayExt {
195-
fn npy_strides(&self) -> NpyStrides;
196-
fn order(&self) -> Option<Order>;
174+
fn npy_strides(&self) -> [npyffi::npy_intp; 32];
175+
fn order(&self) -> Option<c_int>;
197176
}
198177

199178
impl<A, S, D> ArrayExt for ArrayBase<S, D>
200179
where
201180
S: Data<Elem = A>,
202181
D: Dimension,
203182
{
204-
fn npy_strides(&self) -> NpyStrides {
205-
NpyStrides::new::<_, A>(self.strides().iter().map(|&x| x as npyffi::npy_intp))
183+
fn npy_strides(&self) -> [npyffi::npy_intp; 32] {
184+
let strides = self.strides();
185+
let itemsize = mem::size_of::<A>() as isize;
186+
187+
assert!(
188+
strides.len() <= 32,
189+
"Only dimensionalities of up to 32 are supported"
190+
);
191+
192+
let mut new_strides = [0; 32];
193+
194+
for i in 0..strides.len() {
195+
new_strides[i] = (strides[i] * itemsize) as npyffi::npy_intp;
196+
}
197+
198+
new_strides
206199
}
207200

208-
fn order(&self) -> Option<Order> {
201+
fn order(&self) -> Option<c_int> {
209202
if self.is_standard_layout() {
210-
Some(Order::Standard)
203+
Some(npyffi::NPY_ORDER::NPY_CORDER as _)
211204
} else if self.ndim() > 1 && self.raw_view().reversed_axes().is_standard_layout() {
212-
Some(Order::Fortran)
205+
Some(npyffi::NPY_ORDER::NPY_FORTRANORDER as _)
213206
} else {
214207
None
215208
}
216209
}
217210
}
218211

219-
/// An array of strides sufficiently large for [any NumPy array][NPY_MAXDIMS]
220-
///
221-
/// [NPY_MAXDIMS]: https://github.com/numpy/numpy/blob/4c60b3263ac50e5e72f6a909e156314fc3c9cba0/numpy/core/include/numpy/ndarraytypes.h#L40
222-
pub(crate) struct NpyStrides([npyffi::npy_intp; 32]);
223-
224-
impl NpyStrides {
225-
pub(crate) fn as_ptr(&self) -> *const npy_intp {
226-
self.0.as_ptr()
227-
}
228-
229-
fn new<S, A>(strides: S) -> Self
230-
where
231-
S: Iterator<Item = npyffi::npy_intp>,
232-
{
233-
let type_size = mem::size_of::<A>() as npyffi::npy_intp;
234-
let mut res = [0; 32];
235-
for (i, s) in strides.enumerate() {
236-
*res.get_mut(i)
237-
.expect("Only dimensionalities of up to 32 are supported") = s * type_size;
238-
}
239-
Self(res)
240-
}
241-
}
242-
243212
/// Utility trait to specify the dimensions of an array.
244213
pub trait ToNpyDims: Dimension + Sealed {
245214
#[doc(hidden)]

tests/array.rs

+5
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,11 @@ fn rank_zero_array_has_invalid_strides_dimensions() {
9090
assert_eq!(arr.ndim(), 0);
9191
assert_eq!(arr.strides(), &[]);
9292
assert_eq!(arr.shape(), &[]);
93+
94+
assert_eq!(arr.len(), 1);
95+
assert!(!arr.is_empty());
96+
97+
assert_eq!(arr.item(), 0.0);
9398
})
9499
}
95100

0 commit comments

Comments
 (0)