Skip to content

Commit 2d4ceab

Browse files
committed
Add Slice type and use it for .slice_axis*()
1 parent 8413825 commit 2d4ceab

File tree

4 files changed

+99
-41
lines changed

4 files changed

+99
-41
lines changed

src/impl_methods.rs

+10-27
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ use zip::Zip;
3131

3232
use {
3333
NdIndex,
34+
Slice,
3435
SliceInfo,
3536
SliceOrIndex
3637
};
@@ -307,7 +308,7 @@ impl<A, S, D> ArrayBase<S, D> where S: Data<Elem=A>, D: Dimension
307308
.enumerate()
308309
.for_each(|(axis, slice_or_index)| match slice_or_index {
309310
&SliceOrIndex::Slice(start, end, step) => {
310-
self.slice_axis_inplace(Axis(axis), start, end, step)
311+
self.slice_axis_inplace(Axis(axis), Slice(start, end, step))
311312
}
312313
&SliceOrIndex::Index(index) => {
313314
let i_usize = abs_index(self.len_of(Axis(axis)), index);
@@ -320,54 +321,36 @@ impl<A, S, D> ArrayBase<S, D> where S: Data<Elem=A>, D: Dimension
320321
///
321322
/// **Panics** if an index is out of bounds or step size is zero.<br>
322323
/// **Panics** if `axis` is out of bounds.
323-
pub fn slice_axis(
324-
&self,
325-
axis: Axis,
326-
start: Ixs,
327-
end: Option<Ixs>,
328-
step: Ixs,
329-
) -> ArrayView<A, D> {
324+
pub fn slice_axis(&self, axis: Axis, indices: Slice) -> ArrayView<A, D> {
330325
let mut view = self.view();
331-
view.slice_axis_inplace(axis, start, end, step);
326+
view.slice_axis_inplace(axis, indices);
332327
view
333328
}
334329

335330
/// Return a mutable view of the array, sliced along the specified axis.
336331
///
337332
/// **Panics** if an index is out of bounds or step size is zero.<br>
338333
/// **Panics** if `axis` is out of bounds.
339-
pub fn slice_axis_mut(
340-
&mut self,
341-
axis: Axis,
342-
start: Ixs,
343-
end: Option<Ixs>,
344-
step: Ixs,
345-
) -> ArrayViewMut<A, D>
334+
pub fn slice_axis_mut(&mut self, axis: Axis, indices: Slice) -> ArrayViewMut<A, D>
346335
where
347336
S: DataMut,
348337
{
349338
let mut view_mut = self.view_mut();
350-
view_mut.slice_axis_inplace(axis, start, end, step);
339+
view_mut.slice_axis_inplace(axis, indices);
351340
view_mut
352341
}
353342

354343
/// Slice the array in place along the specified axis.
355344
///
356345
/// **Panics** if an index is out of bounds or step size is zero.<br>
357346
/// **Panics** if `axis` is out of bounds.
358-
pub fn slice_axis_inplace(
359-
&mut self,
360-
axis: Axis,
361-
start: Ixs,
362-
end: Option<Ixs>,
363-
step: Ixs,
364-
) {
347+
pub fn slice_axis_inplace(&mut self, axis: Axis, indices: Slice) {
365348
let offset = do_slice(
366349
&mut self.dim.slice_mut()[axis.index()],
367350
&mut self.strides.slice_mut()[axis.index()],
368-
start,
369-
end,
370-
step,
351+
indices.0,
352+
indices.1,
353+
indices.2,
371354
);
372355
unsafe {
373356
self.ptr = self.ptr.offset(offset);

src/lib.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ pub use dimension::NdIndex;
105105
pub use dimension::IxDynImpl;
106106
pub use indexes::{indices, indices_of};
107107
pub use error::{ShapeError, ErrorKind};
108-
pub use slice::{SliceInfo, SliceNextDim, SliceOrIndex};
108+
pub use slice::{Slice, SliceInfo, SliceNextDim, SliceOrIndex};
109109

110110
use iterators::Baseiter;
111111
use iterators::{ElementsBase, ElementsBaseMut, Iter, IterMut};

src/slice.rs

+82-7
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,70 @@ use std::fmt;
1010
use std::marker::PhantomData;
1111
use super::{Dimension, Ixs};
1212

13+
/// A slice (range with step size).
14+
///
15+
/// ## Examples
16+
///
17+
/// `Slice(0, None, 1)` is the full range of an axis. It can also be created
18+
/// with `Slice::from(..)`. The Python equivalent is `[:]`.
19+
///
20+
/// `Slice(a, Some(b), 2)` is every second element from `a` until `b`. It can
21+
/// also be created with `Slice::from(a..b).step(2)`. The Python equivalent is
22+
/// `[a:b:2]`.
23+
///
24+
/// `Slice(a, None, -1)` is every element, from `a` until the end, in reverse
25+
/// order. It can also be created with `Slice::from(a..).step(-1)`. The Python
26+
/// equivalent is `[a::-1]`.
27+
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
28+
pub struct Slice(pub Ixs, pub Option<Ixs>, pub Ixs);
29+
30+
impl Slice {
31+
/// Returns a new `Slice` with the given step size.
32+
#[inline]
33+
pub fn step(self, step: Ixs) -> Self {
34+
Slice(self.0, self.1, step)
35+
}
36+
}
37+
38+
impl From<Range<Ixs>> for Slice {
39+
#[inline]
40+
fn from(r: Range<Ixs>) -> Slice {
41+
Slice(r.start, Some(r.end), 1)
42+
}
43+
}
44+
45+
impl From<RangeFrom<Ixs>> for Slice {
46+
#[inline]
47+
fn from(r: RangeFrom<Ixs>) -> Slice {
48+
Slice(r.start, None, 1)
49+
}
50+
}
51+
52+
impl From<RangeTo<Ixs>> for Slice {
53+
#[inline]
54+
fn from(r: RangeTo<Ixs>) -> Slice {
55+
Slice(0, Some(r.end), 1)
56+
}
57+
}
58+
59+
impl From<RangeFull> for Slice {
60+
#[inline]
61+
fn from(_: RangeFull) -> Slice {
62+
Slice(0, None, 1)
63+
}
64+
}
65+
1366
/// A slice (range with step) or an index.
1467
///
1568
/// See also the [`s![]`](macro.s!.html) macro for a convenient way to create a
1669
/// `&SliceInfo<[SliceOrIndex; n], D>`.
1770
///
1871
/// ## Examples
1972
///
73+
/// `SliceOrIndex::Index(a)` is the index `a`. It can also be created with
74+
/// `SliceOrIndex::from(a)`. The Python equivalent is `[a]`. The macro
75+
/// equivalent is `s![a]`.
76+
///
2077
/// `SliceOrIndex::Slice(0, None, 1)` is the full range of an axis. It can also
2178
/// be created with `SliceOrIndex::from(..)`. The Python equivalent is `[:]`.
2279
/// The macro equivalent is `s![..]`.
@@ -89,6 +146,13 @@ impl fmt::Display for SliceOrIndex {
89146
}
90147
}
91148

149+
impl From<Slice> for SliceOrIndex {
150+
#[inline]
151+
fn from(s: Slice) -> SliceOrIndex {
152+
SliceOrIndex::Slice(s.0, s.1, s.2)
153+
}
154+
}
155+
92156
impl From<Range<Ixs>> for SliceOrIndex {
93157
#[inline]
94158
fn from(r: Range<Ixs>) -> SliceOrIndex {
@@ -261,6 +325,12 @@ pub trait SliceNextDim<D1, D2> {
261325
fn next_dim(&self, PhantomData<D1>) -> PhantomData<D2>;
262326
}
263327

328+
impl<D1: Dimension> SliceNextDim<D1, D1::Larger> for Slice {
329+
fn next_dim(&self, _: PhantomData<D1>) -> PhantomData<D1::Larger> {
330+
PhantomData
331+
}
332+
}
333+
264334
impl<D1: Dimension> SliceNextDim<D1, D1::Larger> for Range<Ixs> {
265335
fn next_dim(&self, _: PhantomData<D1>) -> PhantomData<D1::Larger> {
266336
PhantomData
@@ -293,26 +363,31 @@ impl<D1: Dimension> SliceNextDim<D1, D1> for Ixs {
293363

294364
/// Slice argument constructor.
295365
///
296-
/// `s![]` takes a list of ranges/indices, separated by comma, with optional
297-
/// step sizes that are separated from the range by a semicolon. It is
366+
/// `s![]` takes a list of ranges/slices/indices, separated by comma, with
367+
/// optional step sizes that are separated from the range by a semicolon. It is
298368
/// converted into a [`&SliceInfo`] instance.
299369
///
300370
/// [`&SliceInfo`]: struct.SliceInfo.html
301371
///
302-
/// Each range/index uses signed indices, where a negative value is counted
303-
/// from the end of the axis. Step sizes are also signed and may be negative,
304-
/// but must not be zero.
372+
/// Each range/slice/index uses signed indices, where a negative value is
373+
/// counted from the end of the axis. Step sizes are also signed and may be
374+
/// negative, but must not be zero.
305375
///
306376
/// The syntax is `s![` *[ axis-slice-or-index [, axis-slice-or-index [ , ... ]
307377
/// ] ]* `]`, where *axis-slice-or-index* is any of the following:
308378
///
309379
/// * *index*: an index to use for taking a subview with respect to that axis
310380
/// * *range*: a range with step size 1 to use for slicing that axis
311381
/// * *range* `;` *step*: a range with step size *step* to use for slicing that axis
382+
/// * *slice*: a [`Slice`] instance to use for slicing that axis
383+
/// * *slice* `;` *step*: a range constructed from the start and end of a [`Slice`]
384+
/// instance, with new step size *step*, to use for slicing that axis
385+
///
386+
/// [`Slice`]: struct.Slice.html
312387
///
313388
/// The number of *axis-slice-or-index* must match the number of axes in the
314-
/// array. *index*, *range*, and *step* can be expressions. *index* and *step*
315-
/// must be of type [`Ixs`]. *range* can be of type `Range<Ixs>`,
389+
/// array. *index*, *range*, *slice*, and *step* can be expressions. *index*
390+
/// and *step* must be of type [`Ixs`]. *range* can be of type `Range<Ixs>`,
316391
/// `RangeTo<Ixs>`, `RangeFrom<Ixs>`, or `RangeFull`.
317392
///
318393
/// [`Ixs`]: type.Ixs.html

tests/array.rs

+6-6
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ extern crate ndarray;
66
extern crate defmac;
77
extern crate itertools;
88

9-
use ndarray::{SliceInfo, SliceOrIndex};
9+
use ndarray::{Slice, SliceInfo, SliceOrIndex};
1010
use ndarray::prelude::*;
1111
use ndarray::{
1212
rcarr2,
@@ -55,14 +55,14 @@ fn test_mat_mul() {
5555
#[test]
5656
fn test_slice()
5757
{
58-
let mut A = RcArray::<usize, _>::zeros((3, 4));
58+
let mut A = RcArray::<usize, _>::zeros((3, 4, 5));
5959
for (i, elt) in A.iter_mut().enumerate() {
6060
*elt = i;
6161
}
6262

63-
let vi = A.slice(s![1.., ..;2]);
64-
assert_eq!(vi.shape(), &[2, 2]);
65-
let vi = A.slice(s![.., ..]);
63+
let vi = A.slice(s![1.., ..;2, Slice(0, None, 2)]);
64+
assert_eq!(vi.shape(), &[2, 2, 3]);
65+
let vi = A.slice(s![.., .., ..]);
6666
assert_eq!(vi.shape(), A.shape());
6767
assert!(vi.iter().zip(A.iter()).all(|(a, b)| a == b));
6868
}
@@ -252,7 +252,7 @@ fn slice_oob()
252252
#[test]
253253
fn slice_axis_oob() {
254254
let a = RcArray::<i32, _>::zeros((3, 4));
255-
let _vi = a.slice_axis(Axis(0), 0, Some(10), 1);
255+
let _vi = a.slice_axis(Axis(0), Slice(0, Some(10), 1));
256256
}
257257

258258
#[should_panic]

0 commit comments

Comments
 (0)