Skip to content

Small buffer optimization for InvertedAxes #275

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Feb 15, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 42 additions & 26 deletions src/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -358,12 +358,25 @@ impl<T, D> PyArray<T, D> {
}
}

struct InvertedAxises(Vec<Axis>);
struct InvertedAxes(u32);

impl InvertedAxises {
fn invert<S: RawData, D: Dimension>(self, array: &mut ArrayBase<S, D>) {
for axis in self.0 {
array.invert_axis(axis);
impl InvertedAxes {
fn new(len: usize) -> Self {
assert!(len <= 32, "Only dimensionalities of up to 32 are supported");
Self(0)
}

fn push(&mut self, axis: usize) {
debug_assert!(axis < 32);
self.0 |= 1 << axis;
}

fn invert<S: RawData, D: Dimension>(mut self, array: &mut ArrayBase<S, D>) {
while self.0 != 0 {
let axis = self.0.trailing_zeros() as usize;
self.0 &= !(1 << axis);

array.invert_axis(Axis(axis));
}
}
}
Expand All @@ -372,36 +385,39 @@ impl<T: Element, D: Dimension> PyArray<T, D> {
/// Same as [shape](#method.shape), but returns `D`
#[inline(always)]
pub fn dims(&self) -> D {
D::from_dimension(&Dim(self.shape())).expect("PyArray::dims different dimension")
D::from_dimension(&Dim(self.shape())).expect("mismatching dimensions")
}

fn ndarray_shape_ptr(&self) -> (StrideShape<D>, *mut T, InvertedAxises) {
let shape_slice = self.shape();
let shape: Shape<_> = Dim(self.dims()).into();
let sizeof_t = mem::size_of::<T>();
fn ndarray_shape_ptr(&self) -> (StrideShape<D>, *mut T, InvertedAxes) {
let shape = self.shape();
let strides = self.strides();

let mut new_strides = D::zeros(strides.len());
let mut data_ptr = unsafe { self.data() };
let mut inverted_axises = vec![];
let mut inverted_axes = InvertedAxes::new(strides.len());

for i in 0..strides.len() {
// TODO(kngwyu): Replace this hacky negative strides support with
// a proper constructor, when it's implemented.
// See https://github.com/rust-ndarray/ndarray/issues/842 for more.
if strides[i] < 0 {
// Move the pointer to the start position
let offset = strides[i] * (shape_slice[i] as isize - 1) / sizeof_t as isize;
let offset = strides[i] * (shape[i] as isize - 1) / mem::size_of::<T>() as isize;
unsafe {
data_ptr = data_ptr.offset(offset);
}
new_strides[i] = (-strides[i]) as usize / sizeof_t;
inverted_axises.push(Axis(i));
new_strides[i] = (-strides[i]) as usize / mem::size_of::<T>();

inverted_axes.push(i);
} else {
new_strides[i] = strides[i] as usize / sizeof_t;
new_strides[i] = strides[i] as usize / mem::size_of::<T>();
}
}
let st = D::from_dimension(&Dim(new_strides))
.expect("PyArray::ndarray_shape: dimension mismatching");
(shape.strides(st), data_ptr, InvertedAxises(inverted_axises))

let shape = Shape::from(D::from_dimension(&Dim(shape)).expect("mismatching dimensions"));
let new_strides = D::from_dimension(&Dim(new_strides)).expect("mismatching dimensions");

(shape.strides(new_strides), data_ptr, inverted_axes)
}

/// Creates a new uninitialized PyArray in python heap.
Expand Down Expand Up @@ -818,9 +834,9 @@ impl<T: Element, D: Dimension> PyArray<T, D> {
/// If the internal array is not readonly and can be mutated from Python code,
/// holding the `ArrayView` might cause undefined behavior.
pub unsafe fn as_array(&self) -> ArrayView<'_, T, D> {
let (shape, ptr, inverted_axises) = self.ndarray_shape_ptr();
let (shape, ptr, inverted_axes) = self.ndarray_shape_ptr();
let mut res = ArrayView::from_shape_ptr(shape, ptr);
inverted_axises.invert(&mut res);
inverted_axes.invert(&mut res);
res
}

Expand All @@ -830,25 +846,25 @@ impl<T: Element, D: Dimension> PyArray<T, D> {
/// If another reference to the internal data exists(e.g., `&[T]` or `ArrayView`),
/// it might cause undefined behavior.
pub unsafe fn as_array_mut(&self) -> ArrayViewMut<'_, T, D> {
let (shape, ptr, inverted_axises) = self.ndarray_shape_ptr();
let (shape, ptr, inverted_axes) = self.ndarray_shape_ptr();
let mut res = ArrayViewMut::from_shape_ptr(shape, ptr);
inverted_axises.invert(&mut res);
inverted_axes.invert(&mut res);
res
}

/// Returns the internal array as [`RawArrayView`] enabling element access via raw pointers
pub fn as_raw_array(&self) -> RawArrayView<T, D> {
let (shape, ptr, inverted_axises) = self.ndarray_shape_ptr();
let (shape, ptr, inverted_axes) = self.ndarray_shape_ptr();
let mut res = unsafe { RawArrayView::from_shape_ptr(shape, ptr) };
inverted_axises.invert(&mut res);
inverted_axes.invert(&mut res);
res
}

/// Returns the internal array as [`RawArrayViewMut`] enabling element access via raw pointers
pub fn as_raw_array_mut(&self) -> RawArrayViewMut<T, D> {
let (shape, ptr, inverted_axises) = self.ndarray_shape_ptr();
let (shape, ptr, inverted_axes) = self.ndarray_shape_ptr();
let mut res = unsafe { RawArrayViewMut::from_shape_ptr(shape, ptr) };
inverted_axises.invert(&mut res);
inverted_axes.invert(&mut res);
res
}

Expand Down