Skip to content

Commit 61882e3

Browse files
authored
Merge pull request #275 from PyO3/sbo-inverted-axes
Small buffer optimization for InvertedAxes
2 parents cdeaca7 + 35e0bf2 commit 61882e3

File tree

1 file changed

+42
-26
lines changed

1 file changed

+42
-26
lines changed

src/array.rs

Lines changed: 42 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -358,12 +358,25 @@ impl<T, D> PyArray<T, D> {
358358
}
359359
}
360360

361-
struct InvertedAxises(Vec<Axis>);
361+
struct InvertedAxes(u32);
362362

363-
impl InvertedAxises {
364-
fn invert<S: RawData, D: Dimension>(self, array: &mut ArrayBase<S, D>) {
365-
for axis in self.0 {
366-
array.invert_axis(axis);
363+
impl InvertedAxes {
364+
fn new(len: usize) -> Self {
365+
assert!(len <= 32, "Only dimensionalities of up to 32 are supported");
366+
Self(0)
367+
}
368+
369+
fn push(&mut self, axis: usize) {
370+
debug_assert!(axis < 32);
371+
self.0 |= 1 << axis;
372+
}
373+
374+
fn invert<S: RawData, D: Dimension>(mut self, array: &mut ArrayBase<S, D>) {
375+
while self.0 != 0 {
376+
let axis = self.0.trailing_zeros() as usize;
377+
self.0 &= !(1 << axis);
378+
379+
array.invert_axis(Axis(axis));
367380
}
368381
}
369382
}
@@ -372,36 +385,39 @@ impl<T: Element, D: Dimension> PyArray<T, D> {
372385
/// Same as [shape](#method.shape), but returns `D`
373386
#[inline(always)]
374387
pub fn dims(&self) -> D {
375-
D::from_dimension(&Dim(self.shape())).expect("PyArray::dims different dimension")
388+
D::from_dimension(&Dim(self.shape())).expect("mismatching dimensions")
376389
}
377390

378-
fn ndarray_shape_ptr(&self) -> (StrideShape<D>, *mut T, InvertedAxises) {
379-
let shape_slice = self.shape();
380-
let shape: Shape<_> = Dim(self.dims()).into();
381-
let sizeof_t = mem::size_of::<T>();
391+
fn ndarray_shape_ptr(&self) -> (StrideShape<D>, *mut T, InvertedAxes) {
392+
let shape = self.shape();
382393
let strides = self.strides();
394+
383395
let mut new_strides = D::zeros(strides.len());
384396
let mut data_ptr = unsafe { self.data() };
385-
let mut inverted_axises = vec![];
397+
let mut inverted_axes = InvertedAxes::new(strides.len());
398+
386399
for i in 0..strides.len() {
387400
// TODO(kngwyu): Replace this hacky negative strides support with
388401
// a proper constructor, when it's implemented.
389402
// See https://github.com/rust-ndarray/ndarray/issues/842 for more.
390403
if strides[i] < 0 {
391404
// Move the pointer to the start position
392-
let offset = strides[i] * (shape_slice[i] as isize - 1) / sizeof_t as isize;
405+
let offset = strides[i] * (shape[i] as isize - 1) / mem::size_of::<T>() as isize;
393406
unsafe {
394407
data_ptr = data_ptr.offset(offset);
395408
}
396-
new_strides[i] = (-strides[i]) as usize / sizeof_t;
397-
inverted_axises.push(Axis(i));
409+
new_strides[i] = (-strides[i]) as usize / mem::size_of::<T>();
410+
411+
inverted_axes.push(i);
398412
} else {
399-
new_strides[i] = strides[i] as usize / sizeof_t;
413+
new_strides[i] = strides[i] as usize / mem::size_of::<T>();
400414
}
401415
}
402-
let st = D::from_dimension(&Dim(new_strides))
403-
.expect("PyArray::ndarray_shape: dimension mismatching");
404-
(shape.strides(st), data_ptr, InvertedAxises(inverted_axises))
416+
417+
let shape = Shape::from(D::from_dimension(&Dim(shape)).expect("mismatching dimensions"));
418+
let new_strides = D::from_dimension(&Dim(new_strides)).expect("mismatching dimensions");
419+
420+
(shape.strides(new_strides), data_ptr, inverted_axes)
405421
}
406422

407423
/// Creates a new uninitialized PyArray in python heap.
@@ -818,9 +834,9 @@ impl<T: Element, D: Dimension> PyArray<T, D> {
818834
/// If the internal array is not readonly and can be mutated from Python code,
819835
/// holding the `ArrayView` might cause undefined behavior.
820836
pub unsafe fn as_array(&self) -> ArrayView<'_, T, D> {
821-
let (shape, ptr, inverted_axises) = self.ndarray_shape_ptr();
837+
let (shape, ptr, inverted_axes) = self.ndarray_shape_ptr();
822838
let mut res = ArrayView::from_shape_ptr(shape, ptr);
823-
inverted_axises.invert(&mut res);
839+
inverted_axes.invert(&mut res);
824840
res
825841
}
826842

@@ -830,25 +846,25 @@ impl<T: Element, D: Dimension> PyArray<T, D> {
830846
/// If another reference to the internal data exists(e.g., `&[T]` or `ArrayView`),
831847
/// it might cause undefined behavior.
832848
pub unsafe fn as_array_mut(&self) -> ArrayViewMut<'_, T, D> {
833-
let (shape, ptr, inverted_axises) = self.ndarray_shape_ptr();
849+
let (shape, ptr, inverted_axes) = self.ndarray_shape_ptr();
834850
let mut res = ArrayViewMut::from_shape_ptr(shape, ptr);
835-
inverted_axises.invert(&mut res);
851+
inverted_axes.invert(&mut res);
836852
res
837853
}
838854

839855
/// Returns the internal array as [`RawArrayView`] enabling element access via raw pointers
840856
pub fn as_raw_array(&self) -> RawArrayView<T, D> {
841-
let (shape, ptr, inverted_axises) = self.ndarray_shape_ptr();
857+
let (shape, ptr, inverted_axes) = self.ndarray_shape_ptr();
842858
let mut res = unsafe { RawArrayView::from_shape_ptr(shape, ptr) };
843-
inverted_axises.invert(&mut res);
859+
inverted_axes.invert(&mut res);
844860
res
845861
}
846862

847863
/// Returns the internal array as [`RawArrayViewMut`] enabling element access via raw pointers
848864
pub fn as_raw_array_mut(&self) -> RawArrayViewMut<T, D> {
849-
let (shape, ptr, inverted_axises) = self.ndarray_shape_ptr();
865+
let (shape, ptr, inverted_axes) = self.ndarray_shape_ptr();
850866
let mut res = unsafe { RawArrayViewMut::from_shape_ptr(shape, ptr) };
851-
inverted_axises.invert(&mut res);
867+
inverted_axes.invert(&mut res);
852868
res
853869
}
854870

0 commit comments

Comments
 (0)