@@ -134,7 +134,12 @@ where
134
134
_ => {
135
135
// if the array is not contiguous, copy all elements by `ArrayBase::iter`.
136
136
let dim = self . raw_dim ( ) ;
137
- let strides = NpyStrides :: from_dim ( & dim, mem:: size_of :: < A > ( ) ) ;
137
+ let strides = NpyStrides :: new :: < _ , A > (
138
+ dim. default_strides ( )
139
+ . slice ( )
140
+ . iter ( )
141
+ . map ( |& x| x as npyffi:: npy_intp ) ,
142
+ ) ;
138
143
unsafe {
139
144
let array = PyArray :: < A , _ > :: new_ ( py, dim, strides. as_ptr ( ) , 0 ) ;
140
145
let data_ptr = array. data ( ) ;
@@ -173,10 +178,7 @@ where
173
178
D : Dimension ,
174
179
{
175
180
fn npy_strides ( & self ) -> NpyStrides {
176
- NpyStrides :: new (
177
- self . strides ( ) . iter ( ) . map ( |& x| x as npyffi:: npy_intp ) ,
178
- mem:: size_of :: < A > ( ) ,
179
- )
181
+ NpyStrides :: new :: < _ , A > ( self . strides ( ) . iter ( ) . map ( |& x| x as npyffi:: npy_intp ) )
180
182
}
181
183
182
184
fn order ( & self ) -> Option < Order > {
@@ -190,40 +192,27 @@ where
190
192
}
191
193
}
192
194
193
- /// Numpy strides with short array optimization
194
- pub ( crate ) enum NpyStrides {
195
- Short ( [ npyffi:: npy_intp ; 8 ] ) ,
196
- Long ( Vec < npyffi:: npy_intp > ) ,
197
- }
195
+ /// An array of strides sufficiently large for [any NumPy array][NPY_MAXDIMS]
196
+ ///
197
+ /// [NPY_MAXDIMS]: https://github.com/numpy/numpy/blob/4c60b3263ac50e5e72f6a909e156314fc3c9cba0/numpy/core/include/numpy/ndarraytypes.h#L40
198
+ pub ( crate ) struct NpyStrides ( [ npyffi:: npy_intp ; 32 ] ) ;
198
199
199
200
impl NpyStrides {
200
201
pub ( crate ) fn as_ptr ( & self ) -> * const npy_intp {
201
- match self {
202
- NpyStrides :: Short ( inner) => inner. as_ptr ( ) ,
203
- NpyStrides :: Long ( inner) => inner. as_ptr ( ) ,
204
- }
202
+ self . 0 . as_ptr ( )
205
203
}
206
- fn from_dim < D : Dimension > ( dim : & D , type_size : usize ) -> Self {
207
- Self :: new (
208
- dim. default_strides ( )
209
- . slice ( )
210
- . iter ( )
211
- . map ( |& x| x as npyffi:: npy_intp ) ,
212
- type_size,
213
- )
214
- }
215
- fn new ( strides : impl ExactSizeIterator < Item = npyffi:: npy_intp > , type_size : usize ) -> Self {
216
- let len = strides. len ( ) ;
217
- let type_size = type_size as npyffi:: npy_intp ;
218
- if len <= 8 {
219
- let mut res = [ 0 ; 8 ] ;
220
- for ( i, s) in strides. enumerate ( ) {
221
- res[ i] = s * type_size;
222
- }
223
- NpyStrides :: Short ( res)
224
- } else {
225
- NpyStrides :: Long ( strides. map ( |n| n as npyffi:: npy_intp * type_size) . collect ( ) )
204
+
205
+ fn new < S , A > ( strides : S ) -> Self
206
+ where
207
+ S : Iterator < Item = npyffi:: npy_intp > ,
208
+ {
209
+ let type_size = mem:: size_of :: < A > ( ) as npyffi:: npy_intp ;
210
+ let mut res = [ 0 ; 32 ] ;
211
+ for ( i, s) in strides. enumerate ( ) {
212
+ * res. get_mut ( i)
213
+ . expect ( "Only dimensionalities of up to 32 are supported" ) = s * type_size;
226
214
}
215
+ Self ( res)
227
216
}
228
217
}
229
218
0 commit comments