Skip to content

Commit 73aad37

Browse files
committed
Collect code for create ndarray views in a single place to make it easier to follow.
1 parent 25ef24d commit 73aad37

File tree

2 files changed

+64
-77
lines changed

2 files changed

+64
-77
lines changed

src/array.rs

Lines changed: 59 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ use std::{
99

1010
use ndarray::{
1111
Array, ArrayBase, ArrayView, ArrayViewMut, Axis, Data, Dim, Dimension, IntoDimension, Ix0, Ix1,
12-
Ix2, Ix3, Ix4, Ix5, Ix6, IxDyn, RawArrayView, RawArrayViewMut, RawData, Shape, ShapeBuilder,
12+
Ix2, Ix3, Ix4, Ix5, Ix6, IxDyn, RawArrayView, RawArrayViewMut, RawData, ShapeBuilder,
1313
StrideShape,
1414
};
1515
use num_traits::AsPrimitive;
@@ -338,42 +338,19 @@ impl<T, D> PyArray<T, D> {
338338
}
339339
}
340340

341-
/// Calcurates the total number of elements in the array.
341+
/// Calculates the total number of elements in the array.
342342
pub fn len(&self) -> usize {
343343
self.shape().iter().product()
344344
}
345345

346+
/// Returns `true` if the there are no elements in the array.
346347
pub fn is_empty(&self) -> bool {
347-
self.len() == 0
348+
self.shape().iter().any(|dim| *dim == 0)
348349
}
349350

350-
/// Returns the pointer to the first element of the inner array.
351+
/// Returns the pointer to the first element of the array.
351352
pub(crate) fn data(&self) -> *mut T {
352-
let ptr = self.as_array_ptr();
353-
unsafe { (*ptr).data as *mut _ }
354-
}
355-
}
356-
357-
struct InvertedAxes(u32);
358-
359-
impl InvertedAxes {
360-
fn new(len: usize) -> Self {
361-
assert!(len <= 32, "Only dimensionalities of up to 32 are supported");
362-
Self(0)
363-
}
364-
365-
fn push(&mut self, axis: usize) {
366-
debug_assert!(axis < 32);
367-
self.0 |= 1 << axis;
368-
}
369-
370-
fn invert<S: RawData, D: Dimension>(mut self, array: &mut ArrayBase<S, D>) {
371-
while self.0 != 0 {
372-
let axis = self.0.trailing_zeros() as usize;
373-
self.0 &= !(1 << axis);
374-
375-
array.invert_axis(Axis(axis));
376-
}
353+
unsafe { (*self.as_array_ptr()).data as *mut _ }
377354
}
378355
}
379356

@@ -384,38 +361,6 @@ impl<T: Element, D: Dimension> PyArray<T, D> {
384361
D::from_dimension(&Dim(self.shape())).expect("mismatching dimensions")
385362
}
386363

387-
fn ndarray_shape_ptr(&self) -> (StrideShape<D>, *mut T, InvertedAxes) {
388-
let shape = self.shape();
389-
let strides = self.strides();
390-
391-
let mut new_strides = D::zeros(strides.len());
392-
let mut data_ptr = self.data();
393-
let mut inverted_axes = InvertedAxes::new(strides.len());
394-
395-
for i in 0..strides.len() {
396-
// FIXME(kngwyu): Replace this hacky negative strides support with
397-
// a proper constructor, when it's implemented.
398-
// See https://github.com/rust-ndarray/ndarray/issues/842 for more.
399-
if strides[i] < 0 {
400-
// Move the pointer to the start position
401-
let offset = strides[i] * (shape[i] as isize - 1) / mem::size_of::<T>() as isize;
402-
unsafe {
403-
data_ptr = data_ptr.offset(offset);
404-
}
405-
new_strides[i] = (-strides[i]) as usize / mem::size_of::<T>();
406-
407-
inverted_axes.push(i);
408-
} else {
409-
new_strides[i] = strides[i] as usize / mem::size_of::<T>();
410-
}
411-
}
412-
413-
let shape = Shape::from(D::from_dimension(&Dim(shape)).expect("mismatching dimensions"));
414-
let new_strides = D::from_dimension(&Dim(new_strides)).expect("mismatching dimensions");
415-
416-
(shape.strides(new_strides), data_ptr, inverted_axes)
417-
}
418-
419364
/// Creates a new uninitialized PyArray in python heap.
420365
///
421366
/// If `is_fortran == true`, returns Fortran-order array. Else, returns C-order array.
@@ -883,6 +828,55 @@ impl<T: Element, D: Dimension> PyArray<T, D> {
883828
self.try_readwrite().unwrap()
884829
}
885830

831+
fn as_view<S: RawData, F>(&self, from_shape_ptr: F) -> ArrayBase<S, D>
832+
where
833+
F: FnOnce(StrideShape<D>, *mut T) -> ArrayBase<S, D>,
834+
{
835+
let shape = self.shape();
836+
let strides = self.strides();
837+
let itemsize = mem::size_of::<T>();
838+
839+
assert!(
840+
strides.len() <= 32,
841+
"Only dimensionalities of up to 32 are supported"
842+
);
843+
844+
let mut new_strides = D::zeros(strides.len());
845+
let mut data_ptr = self.data();
846+
let mut inverted_axes = 0_u32;
847+
848+
for i in 0..strides.len() {
849+
// FIXME(kngwyu): Replace this hacky negative strides support with
850+
// a proper constructor, when it's implemented.
851+
// See https://github.com/rust-ndarray/ndarray/issues/842 for more.
852+
if strides[i] >= 0 {
853+
new_strides[i] = strides[i] as usize / itemsize;
854+
} else {
855+
// Move the pointer to the start position.
856+
let offset = strides[i] * (shape[i] as isize - 1) / itemsize as isize;
857+
data_ptr = unsafe { data_ptr.offset(offset) };
858+
859+
new_strides[i] = (-strides[i]) as usize / itemsize;
860+
861+
inverted_axes |= 1 << i;
862+
}
863+
}
864+
865+
let shape = D::from_dimension(&Dim(shape)).expect("mismatching dimensions");
866+
let new_strides = D::from_dimension(&Dim(new_strides)).expect("mismatching dimensions");
867+
868+
let mut array = from_shape_ptr(shape.strides(new_strides), data_ptr);
869+
870+
while inverted_axes != 0 {
871+
let axis = inverted_axes.trailing_zeros() as usize;
872+
inverted_axes &= !(1 << axis);
873+
874+
array.invert_axis(Axis(axis));
875+
}
876+
877+
array
878+
}
879+
886880
/// Returns the internal array as [`ArrayView`].
887881
///
888882
/// See also [`PyReadonlyArray::as_array`].
@@ -891,10 +885,7 @@ impl<T: Element, D: Dimension> PyArray<T, D> {
891885
///
892886
/// The existence of an exclusive reference to the internal data, e.g. `&mut [T]` or `ArrayViewMut`, implies undefined behavior.
893887
pub unsafe fn as_array(&self) -> ArrayView<'_, T, D> {
894-
let (shape, ptr, inverted_axes) = self.ndarray_shape_ptr();
895-
let mut res = ArrayView::from_shape_ptr(shape, ptr);
896-
inverted_axes.invert(&mut res);
897-
res
888+
self.as_view(|shape, ptr| ArrayView::from_shape_ptr(shape, ptr))
898889
}
899890

900891
/// Returns the internal array as [`ArrayViewMut`].
@@ -905,26 +896,17 @@ impl<T: Element, D: Dimension> PyArray<T, D> {
905896
///
906897
/// The existence of another reference to the internal data, e.g. `&[T]` or `ArrayView`, implies undefined behavior.
907898
pub unsafe fn as_array_mut(&self) -> ArrayViewMut<'_, T, D> {
908-
let (shape, ptr, inverted_axes) = self.ndarray_shape_ptr();
909-
let mut res = ArrayViewMut::from_shape_ptr(shape, ptr);
910-
inverted_axes.invert(&mut res);
911-
res
899+
self.as_view(|shape, ptr| ArrayViewMut::from_shape_ptr(shape, ptr))
912900
}
913901

914902
/// Returns the internal array as [`RawArrayView`] enabling element access via raw pointers
915903
pub fn as_raw_array(&self) -> RawArrayView<T, D> {
916-
let (shape, ptr, inverted_axes) = self.ndarray_shape_ptr();
917-
let mut res = unsafe { RawArrayView::from_shape_ptr(shape, ptr) };
918-
inverted_axes.invert(&mut res);
919-
res
904+
self.as_view(|shape, ptr| unsafe { RawArrayView::from_shape_ptr(shape, ptr) })
920905
}
921906

922907
/// Returns the internal array as [`RawArrayViewMut`] enabling element access via raw pointers
923908
pub fn as_raw_array_mut(&self) -> RawArrayViewMut<T, D> {
924-
let (shape, ptr, inverted_axes) = self.ndarray_shape_ptr();
925-
let mut res = unsafe { RawArrayViewMut::from_shape_ptr(shape, ptr) };
926-
inverted_axes.invert(&mut res);
927-
res
909+
self.as_view(|shape, ptr| unsafe { RawArrayViewMut::from_shape_ptr(shape, ptr) })
928910
}
929911

930912
/// Get a copy of `PyArray` as

tests/array.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,11 @@ fn rank_zero_array_has_invalid_strides_dimensions() {
9090
assert_eq!(arr.ndim(), 0);
9191
assert_eq!(arr.strides(), &[]);
9292
assert_eq!(arr.shape(), &[]);
93+
94+
assert_eq!(arr.len(), 1);
95+
assert!(!arr.is_empty());
96+
97+
assert_eq!(arr.item(), 0.0);
9398
})
9499
}
95100

0 commit comments

Comments
 (0)