Skip to content

Commit 229cad7

Browse files
committed
Collect code for create ndarray views in a single place to make it easier to follow.
1 parent 25ef24d commit 229cad7

File tree

2 files changed

+62
-77
lines changed

2 files changed

+62
-77
lines changed

src/array.rs

Lines changed: 57 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ use std::{
99

1010
use ndarray::{
1111
Array, ArrayBase, ArrayView, ArrayViewMut, Axis, Data, Dim, Dimension, IntoDimension, Ix0, Ix1,
12-
Ix2, Ix3, Ix4, Ix5, Ix6, IxDyn, RawArrayView, RawArrayViewMut, RawData, Shape, ShapeBuilder,
12+
Ix2, Ix3, Ix4, Ix5, Ix6, IxDyn, RawArrayView, RawArrayViewMut, RawData, ShapeBuilder,
1313
StrideShape,
1414
};
1515
use num_traits::AsPrimitive;
@@ -338,42 +338,19 @@ impl<T, D> PyArray<T, D> {
338338
}
339339
}
340340

341-
/// Calcurates the total number of elements in the array.
341+
/// Calculates the total number of elements in the array.
342342
pub fn len(&self) -> usize {
343343
self.shape().iter().product()
344344
}
345345

346+
/// Returns `true` if the there are no elements in the array.
346347
pub fn is_empty(&self) -> bool {
347-
self.len() == 0
348+
self.shape().iter().any(|dim| *dim == 0)
348349
}
349350

350-
/// Returns the pointer to the first element of the inner array.
351+
/// Returns the pointer to the first element of the array.
351352
pub(crate) fn data(&self) -> *mut T {
352-
let ptr = self.as_array_ptr();
353-
unsafe { (*ptr).data as *mut _ }
354-
}
355-
}
356-
357-
struct InvertedAxes(u32);
358-
359-
impl InvertedAxes {
360-
fn new(len: usize) -> Self {
361-
assert!(len <= 32, "Only dimensionalities of up to 32 are supported");
362-
Self(0)
363-
}
364-
365-
fn push(&mut self, axis: usize) {
366-
debug_assert!(axis < 32);
367-
self.0 |= 1 << axis;
368-
}
369-
370-
fn invert<S: RawData, D: Dimension>(mut self, array: &mut ArrayBase<S, D>) {
371-
while self.0 != 0 {
372-
let axis = self.0.trailing_zeros() as usize;
373-
self.0 &= !(1 << axis);
374-
375-
array.invert_axis(Axis(axis));
376-
}
353+
unsafe { (*self.as_array_ptr()).data as *mut _ }
377354
}
378355
}
379356

@@ -384,38 +361,6 @@ impl<T: Element, D: Dimension> PyArray<T, D> {
384361
D::from_dimension(&Dim(self.shape())).expect("mismatching dimensions")
385362
}
386363

387-
fn ndarray_shape_ptr(&self) -> (StrideShape<D>, *mut T, InvertedAxes) {
388-
let shape = self.shape();
389-
let strides = self.strides();
390-
391-
let mut new_strides = D::zeros(strides.len());
392-
let mut data_ptr = self.data();
393-
let mut inverted_axes = InvertedAxes::new(strides.len());
394-
395-
for i in 0..strides.len() {
396-
// FIXME(kngwyu): Replace this hacky negative strides support with
397-
// a proper constructor, when it's implemented.
398-
// See https://github.com/rust-ndarray/ndarray/issues/842 for more.
399-
if strides[i] < 0 {
400-
// Move the pointer to the start position
401-
let offset = strides[i] * (shape[i] as isize - 1) / mem::size_of::<T>() as isize;
402-
unsafe {
403-
data_ptr = data_ptr.offset(offset);
404-
}
405-
new_strides[i] = (-strides[i]) as usize / mem::size_of::<T>();
406-
407-
inverted_axes.push(i);
408-
} else {
409-
new_strides[i] = strides[i] as usize / mem::size_of::<T>();
410-
}
411-
}
412-
413-
let shape = Shape::from(D::from_dimension(&Dim(shape)).expect("mismatching dimensions"));
414-
let new_strides = D::from_dimension(&Dim(new_strides)).expect("mismatching dimensions");
415-
416-
(shape.strides(new_strides), data_ptr, inverted_axes)
417-
}
418-
419364
/// Creates a new uninitialized PyArray in python heap.
420365
///
421366
/// If `is_fortran == true`, returns Fortran-order array. Else, returns C-order array.
@@ -883,6 +828,53 @@ impl<T: Element, D: Dimension> PyArray<T, D> {
883828
self.try_readwrite().unwrap()
884829
}
885830

831+
fn as_view<S: RawData, F>(&self, from_shape_ptr: F) -> ArrayBase<S, D>
832+
where
833+
F: FnOnce(StrideShape<D>, *mut T) -> ArrayBase<S, D>,
834+
{
835+
let shape = D::from_dimension(&Dim(self.shape())).expect("mismatching dimensions");
836+
837+
let strides = self.strides();
838+
let itemsize = mem::size_of::<T>();
839+
840+
assert!(
841+
strides.len() <= 32,
842+
"Only dimensionalities of up to 32 are supported"
843+
);
844+
845+
let mut new_strides = D::zeros(strides.len());
846+
let mut data_ptr = self.data();
847+
let mut inverted_axes = 0_u32;
848+
849+
for i in 0..strides.len() {
850+
// FIXME(kngwyu): Replace this hacky negative strides support with
851+
// a proper constructor, when it's implemented.
852+
// See https://github.com/rust-ndarray/ndarray/issues/842 for more.
853+
if strides[i] >= 0 {
854+
new_strides[i] = strides[i] as usize / itemsize;
855+
} else {
856+
// Move the pointer to the start position.
857+
let offset = strides[i] * (shape[i] as isize - 1) / itemsize as isize;
858+
data_ptr = unsafe { data_ptr.offset(offset) };
859+
860+
new_strides[i] = (-strides[i]) as usize / itemsize;
861+
862+
inverted_axes |= 1 << i;
863+
}
864+
}
865+
866+
let mut array = from_shape_ptr(shape.strides(new_strides), data_ptr);
867+
868+
while inverted_axes != 0 {
869+
let axis = inverted_axes.trailing_zeros() as usize;
870+
inverted_axes &= !(1 << axis);
871+
872+
array.invert_axis(Axis(axis));
873+
}
874+
875+
array
876+
}
877+
886878
/// Returns the internal array as [`ArrayView`].
887879
///
888880
/// See also [`PyReadonlyArray::as_array`].
@@ -891,10 +883,7 @@ impl<T: Element, D: Dimension> PyArray<T, D> {
891883
///
892884
/// The existence of an exclusive reference to the internal data, e.g. `&mut [T]` or `ArrayViewMut`, implies undefined behavior.
893885
pub unsafe fn as_array(&self) -> ArrayView<'_, T, D> {
894-
let (shape, ptr, inverted_axes) = self.ndarray_shape_ptr();
895-
let mut res = ArrayView::from_shape_ptr(shape, ptr);
896-
inverted_axes.invert(&mut res);
897-
res
886+
self.as_view(|shape, ptr| ArrayView::from_shape_ptr(shape, ptr))
898887
}
899888

900889
/// Returns the internal array as [`ArrayViewMut`].
@@ -905,26 +894,17 @@ impl<T: Element, D: Dimension> PyArray<T, D> {
905894
///
906895
/// The existence of another reference to the internal data, e.g. `&[T]` or `ArrayView`, implies undefined behavior.
907896
pub unsafe fn as_array_mut(&self) -> ArrayViewMut<'_, T, D> {
908-
let (shape, ptr, inverted_axes) = self.ndarray_shape_ptr();
909-
let mut res = ArrayViewMut::from_shape_ptr(shape, ptr);
910-
inverted_axes.invert(&mut res);
911-
res
897+
self.as_view(|shape, ptr| ArrayViewMut::from_shape_ptr(shape, ptr))
912898
}
913899

914900
/// Returns the internal array as [`RawArrayView`] enabling element access via raw pointers
915901
pub fn as_raw_array(&self) -> RawArrayView<T, D> {
916-
let (shape, ptr, inverted_axes) = self.ndarray_shape_ptr();
917-
let mut res = unsafe { RawArrayView::from_shape_ptr(shape, ptr) };
918-
inverted_axes.invert(&mut res);
919-
res
902+
self.as_view(|shape, ptr| unsafe { RawArrayView::from_shape_ptr(shape, ptr) })
920903
}
921904

922905
/// Returns the internal array as [`RawArrayViewMut`] enabling element access via raw pointers
923906
pub fn as_raw_array_mut(&self) -> RawArrayViewMut<T, D> {
924-
let (shape, ptr, inverted_axes) = self.ndarray_shape_ptr();
925-
let mut res = unsafe { RawArrayViewMut::from_shape_ptr(shape, ptr) };
926-
inverted_axes.invert(&mut res);
927-
res
907+
self.as_view(|shape, ptr| unsafe { RawArrayViewMut::from_shape_ptr(shape, ptr) })
928908
}
929909

930910
/// Get a copy of `PyArray` as

tests/array.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,11 @@ fn rank_zero_array_has_invalid_strides_dimensions() {
9090
assert_eq!(arr.ndim(), 0);
9191
assert_eq!(arr.strides(), &[]);
9292
assert_eq!(arr.shape(), &[]);
93+
94+
assert_eq!(arr.len(), 1);
95+
assert!(!arr.is_empty());
96+
97+
assert_eq!(arr.item(), 0.0);
9398
})
9499
}
95100

0 commit comments

Comments
 (0)