59
59
//! });
60
60
//! ```
61
61
//!
62
- //! The second example shows that while non-overlapping views are supported,
63
- //! interleaved views which do not touch are currently not supported
64
- //! due to over-approximating which borrows are in conflict.
62
+ //! The second example shows that non-overlapping and interleaved views are also supported.
65
63
//!
66
64
//! ```rust
67
- //! # use std::panic::{catch_unwind, AssertUnwindSafe};
68
- //! #
69
65
//! use numpy::PyArray1;
70
66
//! use pyo3::{types::IntoPyDict, Python};
71
67
//!
78
74
//! let view3 = py.eval("array[::2]", None, Some(locals)).unwrap().downcast::<PyArray1<f64>>().unwrap();
79
75
//! let view4 = py.eval("array[1::2]", None, Some(locals)).unwrap().downcast::<PyArray1<f64>>().unwrap();
80
76
//!
81
- //! let _view1 = view1.readwrite();
82
- //! let _view2 = view2.readwrite();
77
+ //! {
78
+ //! let _view1 = view1.readwrite();
79
+ //! let _view2 = view2.readwrite();
80
+ //! }
83
81
//!
84
- //! // Will fail at runtime even though `view3` and `view4`
85
- //! // interleave as they are based on the same array.
86
- //! let res = catch_unwind(AssertUnwindSafe(|| {
82
+ //! {
87
83
//! let _view3 = view3.readwrite();
88
84
//! let _view4 = view4.readwrite();
89
- //! }));
90
- //! assert!(res.is_err());
85
+ //! }
91
86
//! });
92
87
//! ```
93
88
//!
125
120
//!
126
121
//! # Limitations
127
122
//!
128
- //! Note that the current implementation of this is an over-approximation: It will consider overlapping borrows
123
+ //! Note that the current implementation of this is an over-approximation: It will consider borrows
129
124
//! potentially conflicting if the initial arrays have the same object at the end of their [base object chain][base].
130
- //! For example, creating two views of the same underlying array by slicing can yield potentially conflicting borrows
131
- //! even if the slice indices are chosen so that the two views do not actually share any elements by interleaving along one of its axes.
125
+ //! Then, multiple conditions which are sufficient but not necessary to show the absence of conflicts are checked,
126
+ //! but there are cases which they do not handle, for example slicing an array with a step size
127
+ //! that does not divide its dimension along that axis. In these situations, borrows are rejected even though the arrays
128
+ //! do not actually share any elements.
132
129
//!
133
130
//! This does limit the set of programs that can be written using safe Rust in way similar to rustc itself
134
131
//! which ensures that all accepted programs are memory safe but does not necessarily accept all memory safe programs.
135
- //! The plan is to refine this checking to correctly handle more involved cases like interleaved views
136
- //! into the same array and until then the unsafe method [`PyArray::as_array_mut`] can be used as an escape hatch.
132
+ //! The plan is to refine this checking to correctly handle more involved cases and until then
133
+ //! the unsafe method [`PyArray::as_array_mut`] can be used as an escape hatch.
137
134
//!
138
135
//! [base]: https://numpy.org/doc/stable/reference/c-api/types-and-structures.html#c.NPY_AO.base
139
136
#![ deny( missing_docs) ]
@@ -143,6 +140,7 @@ use std::collections::hash_map::{Entry, HashMap};
143
140
use std:: ops:: { Deref , Range } ;
144
141
145
142
use ndarray:: { ArrayView , ArrayViewMut , Dimension , Ix1 , Ix2 , Ix3 , Ix4 , Ix5 , Ix6 , IxDyn } ;
143
+ use num_integer:: gcd;
146
144
use pyo3:: { FromPyObject , PyAny , PyResult } ;
147
145
148
146
use crate :: array:: PyArray ;
@@ -155,9 +153,28 @@ use crate::npyffi::{self, PyArrayObject, NPY_ARRAY_WRITEABLE};
155
153
#[ derive( PartialEq , Eq , Hash ) ]
156
154
struct BorrowKey {
157
155
range : Range < usize > ,
156
+ data_ptr : usize ,
157
+ gcd_strides : isize ,
158
158
}
159
159
160
160
impl BorrowKey {
161
+ fn from_array < T , D > ( array : & PyArray < T , D > ) -> Self
162
+ where
163
+ T : Element ,
164
+ D : Dimension ,
165
+ {
166
+ let range = data_range ( array) ;
167
+
168
+ let data_ptr = array. data ( ) as usize ;
169
+ let gcd_strides = reduce ( array. strides ( ) . iter ( ) . copied ( ) , gcd) . unwrap_or ( 1 ) ;
170
+
171
+ Self {
172
+ range,
173
+ data_ptr,
174
+ gcd_strides,
175
+ }
176
+ }
177
+
161
178
fn conflicts ( & self , other : & Self ) -> bool {
162
179
debug_assert ! ( self . range. start <= self . range. end) ;
163
180
debug_assert ! ( other. range. start <= other. range. end) ;
@@ -166,6 +183,21 @@ impl BorrowKey {
166
183
return false ;
167
184
}
168
185
186
+ // The Diophantine equation which describes whether any integers can combine the data pointers and strides of the two arrays s.t.
187
+ // they yield the same element has a solution if and only if the GCD of all strides divides the difference of the data pointers.
188
+ //
189
+ // That solution could be out of bounds which mean that this is still an over-approximation.
190
+ // It appears sufficient to handle typical cases like the color channels of an image,
191
+ // but fails when slicing an array with a step size that does not divide the dimension along that axis.
192
+ //
193
+ // https://users.rust-lang.org/t/math-for-borrow-checking-numpy-arrays/73303
194
+ let ptr_diff = abs_diff ( self . data_ptr , other. data_ptr ) as isize ;
195
+ let gcd_strides = gcd ( self . gcd_strides , other. gcd_strides ) ;
196
+
197
+ if ptr_diff % gcd_strides != 0 {
198
+ return false ;
199
+ }
200
+
169
201
true
170
202
}
171
203
}
@@ -192,10 +224,7 @@ impl BorrowFlags {
192
224
D : Dimension ,
193
225
{
194
226
let address = base_address ( array) ;
195
-
196
- let key = BorrowKey {
197
- range : data_range ( array) ,
198
- } ;
227
+ let key = BorrowKey :: from_array ( array) ;
199
228
200
229
// SAFETY: Access to `&PyArray<T, D>` implies holding the GIL
201
230
// and we are not calling into user code which might re-enter this function.
@@ -242,10 +271,7 @@ impl BorrowFlags {
242
271
D : Dimension ,
243
272
{
244
273
let address = base_address ( array) ;
245
-
246
- let key = BorrowKey {
247
- range : data_range ( array) ,
248
- } ;
274
+ let key = BorrowKey :: from_array ( array) ;
249
275
250
276
// SAFETY: Access to `&PyArray<T, D>` implies holding the GIL
251
277
// and we are not calling into user code which might re-enter this function.
@@ -272,10 +298,7 @@ impl BorrowFlags {
272
298
D : Dimension ,
273
299
{
274
300
let address = base_address ( array) ;
275
-
276
- let key = BorrowKey {
277
- range : data_range ( array) ,
278
- } ;
301
+ let key = BorrowKey :: from_array ( array) ;
279
302
280
303
// SAFETY: Access to `&PyArray<T, D>` implies holding the GIL
281
304
// and we are not calling into user code which might re-enter this function.
@@ -320,10 +343,7 @@ impl BorrowFlags {
320
343
D : Dimension ,
321
344
{
322
345
let address = base_address ( array) ;
323
-
324
- let key = BorrowKey {
325
- range : data_range ( array) ,
326
- } ;
346
+ let key = BorrowKey :: from_array ( array) ;
327
347
328
348
// SAFETY: Access to `&PyArray<T, D>` implies holding the GIL
329
349
// and we are not calling into user code which might re-enter this function.
@@ -628,6 +648,25 @@ where
628
648
Range { start, end }
629
649
}
630
650
651
+ // FIXME(adamreichold): Use `usize::abs_diff` from std when that becomes stable.
652
+ fn abs_diff ( lhs : usize , rhs : usize ) -> usize {
653
+ if lhs >= rhs {
654
+ lhs - rhs
655
+ } else {
656
+ rhs - lhs
657
+ }
658
+ }
659
+
660
+ // FIXME(adamreichold): Use `Iterator::reduce` from std when our MSRV reaches 1.51.
661
+ fn reduce < I , F > ( mut iter : I , f : F ) -> Option < I :: Item >
662
+ where
663
+ I : Iterator ,
664
+ F : FnMut ( I :: Item , I :: Item ) -> I :: Item ,
665
+ {
666
+ let first = iter. next ( ) ?;
667
+ Some ( iter. fold ( first, f) )
668
+ }
669
+
631
670
#[ cfg( test) ]
632
671
mod tests {
633
672
use super :: * ;
@@ -650,7 +689,7 @@ mod tests {
650
689
assert_eq ! ( base_address, array as * const _ as usize ) ;
651
690
652
691
let data_range = data_range ( array) ;
653
- assert_eq ! ( data_range. start, unsafe { array. data( ) } as usize ) ;
692
+ assert_eq ! ( data_range. start, array. data( ) as usize ) ;
654
693
assert_eq ! ( data_range. end, unsafe { array. data( ) . add( 15 ) } as usize ) ;
655
694
} ) ;
656
695
}
@@ -668,7 +707,7 @@ mod tests {
668
707
assert_eq ! ( base_address, base as usize ) ;
669
708
670
709
let data_range = data_range ( array) ;
671
- assert_eq ! ( data_range. start, unsafe { array. data( ) } as usize ) ;
710
+ assert_eq ! ( data_range. start, array. data( ) as usize ) ;
672
711
assert_eq ! ( data_range. end, unsafe { array. data( ) . add( 15 ) } as usize ) ;
673
712
} ) ;
674
713
}
@@ -694,7 +733,7 @@ mod tests {
694
733
assert_eq ! ( base_address, base as usize ) ;
695
734
696
735
let data_range = data_range ( view) ;
697
- assert_eq ! ( data_range. start, unsafe { view. data( ) } as usize ) ;
736
+ assert_eq ! ( data_range. start, view. data( ) as usize ) ;
698
737
assert_eq ! ( data_range. end, unsafe { view. data( ) . add( 12 ) } as usize ) ;
699
738
} ) ;
700
739
}
@@ -724,7 +763,7 @@ mod tests {
724
763
assert_eq ! ( base_address, base as usize ) ;
725
764
726
765
let data_range = data_range ( view) ;
727
- assert_eq ! ( data_range. start, unsafe { view. data( ) } as usize ) ;
766
+ assert_eq ! ( data_range. start, view. data( ) as usize ) ;
728
767
assert_eq ! ( data_range. end, unsafe { view. data( ) . add( 12 ) } as usize ) ;
729
768
} ) ;
730
769
}
@@ -763,7 +802,7 @@ mod tests {
763
802
assert_eq ! ( base_address, base as usize ) ;
764
803
765
804
let data_range = data_range ( view2) ;
766
- assert_eq ! ( data_range. start, unsafe { view2. data( ) } as usize ) ;
805
+ assert_eq ! ( data_range. start, view2. data( ) as usize ) ;
767
806
assert_eq ! ( data_range. end, unsafe { view2. data( ) . add( 6 ) } as usize ) ;
768
807
} ) ;
769
808
}
@@ -806,7 +845,7 @@ mod tests {
806
845
assert_eq ! ( base_address, base as usize ) ;
807
846
808
847
let data_range = data_range ( view2) ;
809
- assert_eq ! ( data_range. start, unsafe { view2. data( ) } as usize ) ;
848
+ assert_eq ! ( data_range. start, view2. data( ) as usize ) ;
810
849
assert_eq ! ( data_range. end, unsafe { view2. data( ) . add( 6 ) } as usize ) ;
811
850
} ) ;
812
851
}
@@ -836,4 +875,63 @@ mod tests {
836
875
assert_eq ! ( data_range. end, unsafe { view. data( ) . offset( 6 ) } as usize ) ;
837
876
} ) ;
838
877
}
878
+
879
+ #[ test]
880
+ fn view_with_non_dividing_strides ( ) {
881
+ Python :: with_gil ( |py| {
882
+ let array = PyArray :: < f64 , _ > :: zeros ( py, ( 10 , 10 ) , false ) ;
883
+ let locals = [ ( "array" , array) ] . into_py_dict ( py) ;
884
+
885
+ let view1 = py
886
+ . eval ( "array[:,::3]" , None , Some ( locals) )
887
+ . unwrap ( )
888
+ . downcast :: < PyArray2 < f64 > > ( )
889
+ . unwrap ( ) ;
890
+
891
+ let key1 = BorrowKey :: from_array ( view1) ;
892
+
893
+ assert_eq ! ( view1. strides( ) , & [ 80 , 24 ] ) ;
894
+ assert_eq ! ( key1. gcd_strides, 8 ) ;
895
+
896
+ let view2 = py
897
+ . eval ( "array[:,1::3]" , None , Some ( locals) )
898
+ . unwrap ( )
899
+ . downcast :: < PyArray2 < f64 > > ( )
900
+ . unwrap ( ) ;
901
+
902
+ let key2 = BorrowKey :: from_array ( view2) ;
903
+
904
+ assert_eq ! ( view2. strides( ) , & [ 80 , 24 ] ) ;
905
+ assert_eq ! ( key2. gcd_strides, 8 ) ;
906
+
907
+ let view3 = py
908
+ . eval ( "array[:,::2]" , None , Some ( locals) )
909
+ . unwrap ( )
910
+ . downcast :: < PyArray2 < f64 > > ( )
911
+ . unwrap ( ) ;
912
+
913
+ let key3 = BorrowKey :: from_array ( view3) ;
914
+
915
+ assert_eq ! ( view3. strides( ) , & [ 80 , 16 ] ) ;
916
+ assert_eq ! ( key3. gcd_strides, 16 ) ;
917
+
918
+ let view4 = py
919
+ . eval ( "array[:,1::2]" , None , Some ( locals) )
920
+ . unwrap ( )
921
+ . downcast :: < PyArray2 < f64 > > ( )
922
+ . unwrap ( ) ;
923
+
924
+ let key4 = BorrowKey :: from_array ( view4) ;
925
+
926
+ assert_eq ! ( view4. strides( ) , & [ 80 , 16 ] ) ;
927
+ assert_eq ! ( key4. gcd_strides, 16 ) ;
928
+
929
+ assert ! ( !key3. conflicts( & key4) ) ;
930
+ assert ! ( key1. conflicts( & key3) ) ;
931
+ assert ! ( key2. conflicts( & key4) ) ;
932
+
933
+ // This is a false conflict where all aliasing indices like (0,7) and (2,0) are out of bounds.
934
+ assert ! ( !key1. conflicts( & key2) ) ;
935
+ } ) ;
936
+ }
839
937
}
0 commit comments