Skip to content

Collect scattered code for converting strides #305

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Mar 25, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
144 changes: 67 additions & 77 deletions src/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use std::{

use ndarray::{
Array, ArrayBase, ArrayView, ArrayViewMut, Axis, Data, Dim, Dimension, IntoDimension, Ix0, Ix1,
Ix2, Ix3, Ix4, Ix5, Ix6, IxDyn, RawArrayView, RawArrayViewMut, RawData, Shape, ShapeBuilder,
Ix2, Ix3, Ix4, Ix5, Ix6, IxDyn, RawArrayView, RawArrayViewMut, RawData, ShapeBuilder,
StrideShape,
};
use num_traits::AsPrimitive;
Expand Down Expand Up @@ -338,42 +338,19 @@ impl<T, D> PyArray<T, D> {
}
}

/// Calcurates the total number of elements in the array.
/// Calculates the total number of elements in the array.
pub fn len(&self) -> usize {
self.shape().iter().product()
}

/// Returns `true` if the there are no elements in the array.
pub fn is_empty(&self) -> bool {
self.len() == 0
self.shape().iter().any(|dim| *dim == 0)
}

/// Returns the pointer to the first element of the inner array.
/// Returns the pointer to the first element of the array.
pub(crate) fn data(&self) -> *mut T {
let ptr = self.as_array_ptr();
unsafe { (*ptr).data as *mut _ }
}
}

struct InvertedAxes(u32);

impl InvertedAxes {
fn new(len: usize) -> Self {
assert!(len <= 32, "Only dimensionalities of up to 32 are supported");
Self(0)
}

fn push(&mut self, axis: usize) {
debug_assert!(axis < 32);
self.0 |= 1 << axis;
}

fn invert<S: RawData, D: Dimension>(mut self, array: &mut ArrayBase<S, D>) {
while self.0 != 0 {
let axis = self.0.trailing_zeros() as usize;
self.0 &= !(1 << axis);

array.invert_axis(Axis(axis));
}
unsafe { (*self.as_array_ptr()).data as *mut _ }
}
}

Expand All @@ -384,38 +361,6 @@ impl<T: Element, D: Dimension> PyArray<T, D> {
D::from_dimension(&Dim(self.shape())).expect("mismatching dimensions")
}

fn ndarray_shape_ptr(&self) -> (StrideShape<D>, *mut T, InvertedAxes) {
let shape = self.shape();
let strides = self.strides();

let mut new_strides = D::zeros(strides.len());
let mut data_ptr = self.data();
let mut inverted_axes = InvertedAxes::new(strides.len());

for i in 0..strides.len() {
// FIXME(kngwyu): Replace this hacky negative strides support with
// a proper constructor, when it's implemented.
// See https://github.com/rust-ndarray/ndarray/issues/842 for more.
if strides[i] < 0 {
// Move the pointer to the start position
let offset = strides[i] * (shape[i] as isize - 1) / mem::size_of::<T>() as isize;
unsafe {
data_ptr = data_ptr.offset(offset);
}
new_strides[i] = (-strides[i]) as usize / mem::size_of::<T>();

inverted_axes.push(i);
} else {
new_strides[i] = strides[i] as usize / mem::size_of::<T>();
}
}

let shape = Shape::from(D::from_dimension(&Dim(shape)).expect("mismatching dimensions"));
let new_strides = D::from_dimension(&Dim(new_strides)).expect("mismatching dimensions");

(shape.strides(new_strides), data_ptr, inverted_axes)
}

/// Creates a new uninitialized PyArray in python heap.
///
/// If `is_fortran == true`, returns Fortran-order array. Else, returns C-order array.
Expand Down Expand Up @@ -883,6 +828,63 @@ impl<T: Element, D: Dimension> PyArray<T, D> {
self.try_readwrite().unwrap()
}

fn as_view<S: RawData, F>(&self, from_shape_ptr: F) -> ArrayBase<S, D>
where
F: FnOnce(StrideShape<D>, *mut T) -> ArrayBase<S, D>,
{
fn inner<D: Dimension>(
shape: &[usize],
strides: &[isize],
itemsize: usize,
mut data_ptr: *mut u8,
) -> (StrideShape<D>, u32, *mut u8) {
let shape = D::from_dimension(&Dim(shape)).expect("mismatching dimensions");

assert!(
strides.len() <= 32,
"Only dimensionalities of up to 32 are supported"
);

let mut new_strides = D::zeros(strides.len());
let mut inverted_axes = 0_u32;

for i in 0..strides.len() {
// FIXME(kngwyu): Replace this hacky negative strides support with
// a proper constructor, when it's implemented.
// See https://github.com/rust-ndarray/ndarray/issues/842 for more.
if strides[i] >= 0 {
new_strides[i] = strides[i] as usize / itemsize;
} else {
// Move the pointer to the start position.
data_ptr = unsafe { data_ptr.offset(strides[i] * (shape[i] as isize - 1)) };

new_strides[i] = (-strides[i]) as usize / itemsize;
inverted_axes |= 1 << i;
}
}

(shape.strides(new_strides), inverted_axes, data_ptr)
}

let (shape, mut inverted_axes, data_ptr) = inner(
self.shape(),
self.strides(),
mem::size_of::<T>(),
self.data() as _,
);

let mut array = from_shape_ptr(shape, data_ptr as _);

while inverted_axes != 0 {
let axis = inverted_axes.trailing_zeros() as usize;
inverted_axes &= !(1 << axis);

array.invert_axis(Axis(axis));
}

array
}

/// Returns the internal array as [`ArrayView`].
///
/// See also [`PyReadonlyArray::as_array`].
Expand All @@ -891,10 +893,7 @@ impl<T: Element, D: Dimension> PyArray<T, D> {
///
/// The existence of an exclusive reference to the internal data, e.g. `&mut [T]` or `ArrayViewMut`, implies undefined behavior.
pub unsafe fn as_array(&self) -> ArrayView<'_, T, D> {
let (shape, ptr, inverted_axes) = self.ndarray_shape_ptr();
let mut res = ArrayView::from_shape_ptr(shape, ptr);
inverted_axes.invert(&mut res);
res
self.as_view(|shape, ptr| ArrayView::from_shape_ptr(shape, ptr))
}

/// Returns the internal array as [`ArrayViewMut`].
Expand All @@ -905,26 +904,17 @@ impl<T: Element, D: Dimension> PyArray<T, D> {
///
/// The existence of another reference to the internal data, e.g. `&[T]` or `ArrayView`, implies undefined behavior.
pub unsafe fn as_array_mut(&self) -> ArrayViewMut<'_, T, D> {
let (shape, ptr, inverted_axes) = self.ndarray_shape_ptr();
let mut res = ArrayViewMut::from_shape_ptr(shape, ptr);
inverted_axes.invert(&mut res);
res
self.as_view(|shape, ptr| ArrayViewMut::from_shape_ptr(shape, ptr))
}

/// Returns the internal array as [`RawArrayView`] enabling element access via raw pointers
pub fn as_raw_array(&self) -> RawArrayView<T, D> {
let (shape, ptr, inverted_axes) = self.ndarray_shape_ptr();
let mut res = unsafe { RawArrayView::from_shape_ptr(shape, ptr) };
inverted_axes.invert(&mut res);
res
self.as_view(|shape, ptr| unsafe { RawArrayView::from_shape_ptr(shape, ptr) })
}

/// Returns the internal array as [`RawArrayViewMut`] enabling element access via raw pointers
pub fn as_raw_array_mut(&self) -> RawArrayViewMut<T, D> {
let (shape, ptr, inverted_axes) = self.ndarray_shape_ptr();
let mut res = unsafe { RawArrayViewMut::from_shape_ptr(shape, ptr) };
inverted_axes.invert(&mut res);
res
self.as_view(|shape, ptr| unsafe { RawArrayViewMut::from_shape_ptr(shape, ptr) })
}

/// Get a copy of `PyArray` as
Expand Down
79 changes: 24 additions & 55 deletions src/convert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -144,27 +144,20 @@ where
fn to_pyarray<'py>(&self, py: Python<'py>) -> &'py PyArray<Self::Item, Self::Dim> {
let len = self.len();
match self.order() {
Some(order) if A::IS_COPY => {
Some(flag) if A::IS_COPY => {
// if the array is contiguous, copy it by `copy_nonoverlapping`.
let strides = self.npy_strides();
unsafe {
let array =
PyArray::new_(py, self.raw_dim(), strides.as_ptr(), order.to_flag());
let array = PyArray::new_(py, self.raw_dim(), strides.as_ptr(), flag);
ptr::copy_nonoverlapping(self.as_ptr(), array.data(), len);
array
}
}
_ => {
// if the array is not contiguous, copy all elements by `ArrayBase::iter`.
let dim = self.raw_dim();
let strides = NpyStrides::new::<_, A>(
dim.default_strides()
.slice()
.iter()
.map(|&x| x as npyffi::npy_intp),
);
unsafe {
let array = PyArray::<A, _>::new_(py, dim, strides.as_ptr(), 0);
let array = PyArray::<A, _>::new(py, dim, false);
let mut data_ptr = array.data();
for item in self.iter() {
data_ptr.write(item.clone());
Expand All @@ -177,69 +170,45 @@ where
}
}

pub(crate) enum Order {
Standard,
Fortran,
}

impl Order {
fn to_flag(&self) -> c_int {
match self {
Order::Standard => 0,
Order::Fortran => 1,
}
}
}

pub(crate) trait ArrayExt {
fn npy_strides(&self) -> NpyStrides;
fn order(&self) -> Option<Order>;
fn npy_strides(&self) -> [npyffi::npy_intp; 32];
fn order(&self) -> Option<c_int>;
}

impl<A, S, D> ArrayExt for ArrayBase<S, D>
where
S: Data<Elem = A>,
D: Dimension,
{
fn npy_strides(&self) -> NpyStrides {
NpyStrides::new::<_, A>(self.strides().iter().map(|&x| x as npyffi::npy_intp))
fn npy_strides(&self) -> [npyffi::npy_intp; 32] {
let strides = self.strides();
let itemsize = mem::size_of::<A>() as isize;

assert!(
strides.len() <= 32,
"Only dimensionalities of up to 32 are supported"
);

let mut new_strides = [0; 32];

for i in 0..strides.len() {
new_strides[i] = (strides[i] * itemsize) as npyffi::npy_intp;
}

new_strides
}

fn order(&self) -> Option<Order> {
fn order(&self) -> Option<c_int> {
if self.is_standard_layout() {
Some(Order::Standard)
Some(npyffi::NPY_ORDER::NPY_CORDER as _)
} else if self.ndim() > 1 && self.raw_view().reversed_axes().is_standard_layout() {
Some(Order::Fortran)
Some(npyffi::NPY_ORDER::NPY_FORTRANORDER as _)
} else {
None
}
}
}

/// An array of strides sufficiently large for [any NumPy array][NPY_MAXDIMS]
///
/// [NPY_MAXDIMS]: https://github.com/numpy/numpy/blob/4c60b3263ac50e5e72f6a909e156314fc3c9cba0/numpy/core/include/numpy/ndarraytypes.h#L40
pub(crate) struct NpyStrides([npyffi::npy_intp; 32]);

impl NpyStrides {
pub(crate) fn as_ptr(&self) -> *const npy_intp {
self.0.as_ptr()
}

fn new<S, A>(strides: S) -> Self
where
S: Iterator<Item = npyffi::npy_intp>,
{
let type_size = mem::size_of::<A>() as npyffi::npy_intp;
let mut res = [0; 32];
for (i, s) in strides.enumerate() {
*res.get_mut(i)
.expect("Only dimensionalities of up to 32 are supported") = s * type_size;
}
Self(res)
}
}

/// Utility trait to specify the dimensions of an array.
pub trait ToNpyDims: Dimension + Sealed {
#[doc(hidden)]
Expand Down
5 changes: 5 additions & 0 deletions tests/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,11 @@ fn rank_zero_array_has_invalid_strides_dimensions() {
assert_eq!(arr.ndim(), 0);
assert_eq!(arr.strides(), &[]);
assert_eq!(arr.shape(), &[]);

assert_eq!(arr.len(), 1);
assert!(!arr.is_empty());

assert_eq!(arr.item(), 0.0);
})
}

Expand Down