59
59
//! });
60
60
//! ```
61
61
//!
62
- //! The second example shows that while non-overlapping views are supported,
63
- //! interleaved views which do not touch are currently not supported
64
- //! due to over-approximating which borrows are in conflict.
62
+ //! The second example shows that non-overlapping and interleaved views are also supported.
65
63
//!
66
64
//! ```rust
67
- //! # use std::panic::{catch_unwind, AssertUnwindSafe};
68
- //! #
69
65
//! use numpy::PyArray1;
70
66
//! use pyo3::{types::IntoPyDict, Python};
71
67
//!
78
74
//! let view3 = py.eval("array[::2]", None, Some(locals)).unwrap().downcast::<PyArray1<f64>>().unwrap();
79
75
//! let view4 = py.eval("array[1::2]", None, Some(locals)).unwrap().downcast::<PyArray1<f64>>().unwrap();
80
76
//!
81
- //! let _view1 = view1.readwrite();
82
- //! let _view2 = view2.readwrite();
77
+ //! {
78
+ //! let _view1 = view1.readwrite();
79
+ //! let _view2 = view2.readwrite();
80
+ //! }
83
81
//!
84
- //! // Will fail at runtime even though `view3` and `view4`
85
- //! // interleave as they are based on the same array.
86
- //! let res = catch_unwind(AssertUnwindSafe(|| {
82
+ //! {
87
83
//! let _view3 = view3.readwrite();
88
84
//! let _view4 = view4.readwrite();
89
- //! }));
90
- //! assert!(res.is_err());
85
+ //! }
91
86
//! });
92
87
//! ```
93
88
//!
125
120
//!
126
121
//! # Limitations
127
122
//!
123
+ //! TODO: We only leave the case of aliasing, but only out of bounds. Can this actually happen for array views?
124
+ //!
128
125
//! Note that the current implementation of this is an over-approximation: It will consider overlapping borrows
129
126
//! potentially conflicting if the initial arrays have the same object at the end of their [base object chain][base].
130
127
//! For example, creating two views of the same underlying array by slicing can yield potentially conflicting borrows
@@ -143,6 +140,7 @@ use std::collections::hash_map::{Entry, HashMap};
143
140
use std:: ops:: { Deref , Range } ;
144
141
145
142
use ndarray:: { ArrayView , ArrayViewMut , Dimension , Ix1 , Ix2 , Ix3 , Ix4 , Ix5 , Ix6 , IxDyn } ;
143
+ use num_integer:: gcd;
146
144
use pyo3:: { FromPyObject , PyAny , PyResult } ;
147
145
148
146
use crate :: array:: PyArray ;
@@ -155,9 +153,28 @@ use crate::npyffi::{self, PyArrayObject, NPY_ARRAY_WRITEABLE};
155
153
#[ derive( PartialEq , Eq , Hash ) ]
156
154
struct BorrowKey {
157
155
range : Range < usize > ,
156
+ data_ptr : usize ,
157
+ gcd_strides : isize ,
158
158
}
159
159
160
160
impl BorrowKey {
161
+ fn from_array < T , D > ( array : & PyArray < T , D > ) -> Self
162
+ where
163
+ T : Element ,
164
+ D : Dimension ,
165
+ {
166
+ let range = data_range ( array) ;
167
+
168
+ let data_ptr = array. data ( ) as usize ;
169
+ let gcd_strides = reduce ( array. strides ( ) . iter ( ) . copied ( ) , gcd) . unwrap_or ( 1 ) ;
170
+
171
+ Self {
172
+ range,
173
+ data_ptr,
174
+ gcd_strides,
175
+ }
176
+ }
177
+
161
178
fn conflicts ( & self , other : & Self ) -> bool {
162
179
debug_assert ! ( self . range. start <= self . range. end) ;
163
180
debug_assert ! ( other. range. start <= other. range. end) ;
@@ -166,6 +183,20 @@ impl BorrowKey {
166
183
return false ;
167
184
}
168
185
186
+ // The Diophantine equation which describes whether any integers can combine the data pointers and strides of the two arrays s.t.
187
+ // they yield the same element has a solution if and only if the GCD of all strides divides the difference of the data pointers.
188
+ //
189
+ // That solution could be out of bounds which mean that this is still an approximation,
190
+ // but it seems sufficient to handle typical cases like the color channels of an image.
191
+ //
192
+ // https://users.rust-lang.org/t/math-for-borrow-checking-numpy-arrays/73303
193
+ let ptr_diff = abs_diff ( self . data_ptr , other. data_ptr ) as isize ;
194
+ let gcd_strides = gcd ( self . gcd_strides , other. gcd_strides ) ;
195
+
196
+ if ptr_diff % gcd_strides != 0 {
197
+ return false ;
198
+ }
199
+
169
200
true
170
201
}
171
202
}
@@ -192,10 +223,7 @@ impl BorrowFlags {
192
223
D : Dimension ,
193
224
{
194
225
let address = base_address ( array) ;
195
-
196
- let key = BorrowKey {
197
- range : data_range ( array) ,
198
- } ;
226
+ let key = BorrowKey :: from_array ( array) ;
199
227
200
228
// SAFETY: Access to `&PyArray<T, D>` implies holding the GIL
201
229
// and we are not calling into user code which might re-enter this function.
@@ -242,10 +270,7 @@ impl BorrowFlags {
242
270
D : Dimension ,
243
271
{
244
272
let address = base_address ( array) ;
245
-
246
- let key = BorrowKey {
247
- range : data_range ( array) ,
248
- } ;
273
+ let key = BorrowKey :: from_array ( array) ;
249
274
250
275
// SAFETY: Access to `&PyArray<T, D>` implies holding the GIL
251
276
// and we are not calling into user code which might re-enter this function.
@@ -272,10 +297,7 @@ impl BorrowFlags {
272
297
D : Dimension ,
273
298
{
274
299
let address = base_address ( array) ;
275
-
276
- let key = BorrowKey {
277
- range : data_range ( array) ,
278
- } ;
300
+ let key = BorrowKey :: from_array ( array) ;
279
301
280
302
// SAFETY: Access to `&PyArray<T, D>` implies holding the GIL
281
303
// and we are not calling into user code which might re-enter this function.
@@ -320,10 +342,7 @@ impl BorrowFlags {
320
342
D : Dimension ,
321
343
{
322
344
let address = base_address ( array) ;
323
-
324
- let key = BorrowKey {
325
- range : data_range ( array) ,
326
- } ;
345
+ let key = BorrowKey :: from_array ( array) ;
327
346
328
347
// SAFETY: Access to `&PyArray<T, D>` implies holding the GIL
329
348
// and we are not calling into user code which might re-enter this function.
@@ -628,6 +647,25 @@ where
628
647
Range { start, end }
629
648
}
630
649
650
+ // FIXME(adamreichold): Use `usize::abs_diff` from std when that becomes stable.
651
+ fn abs_diff ( lhs : usize , rhs : usize ) -> usize {
652
+ if lhs >= rhs {
653
+ lhs - rhs
654
+ } else {
655
+ rhs - lhs
656
+ }
657
+ }
658
+
659
+ // FIXME(adamreichold): Use `Iterator::reduce` from std when our MSRV reaches 1.51.
660
+ fn reduce < I , F > ( mut iter : I , f : F ) -> Option < I :: Item >
661
+ where
662
+ I : Iterator ,
663
+ F : FnMut ( I :: Item , I :: Item ) -> I :: Item ,
664
+ {
665
+ let first = iter. next ( ) ?;
666
+ Some ( iter. fold ( first, f) )
667
+ }
668
+
631
669
#[ cfg( test) ]
632
670
mod tests {
633
671
use super :: * ;
@@ -650,7 +688,7 @@ mod tests {
650
688
assert_eq ! ( base_address, array as * const _ as usize ) ;
651
689
652
690
let data_range = data_range ( array) ;
653
- assert_eq ! ( data_range. start, unsafe { array. data( ) } as usize ) ;
691
+ assert_eq ! ( data_range. start, array. data( ) as usize ) ;
654
692
assert_eq ! ( data_range. end, unsafe { array. data( ) . add( 15 ) } as usize ) ;
655
693
} ) ;
656
694
}
@@ -668,7 +706,7 @@ mod tests {
668
706
assert_eq ! ( base_address, base as usize ) ;
669
707
670
708
let data_range = data_range ( array) ;
671
- assert_eq ! ( data_range. start, unsafe { array. data( ) } as usize ) ;
709
+ assert_eq ! ( data_range. start, array. data( ) as usize ) ;
672
710
assert_eq ! ( data_range. end, unsafe { array. data( ) . add( 15 ) } as usize ) ;
673
711
} ) ;
674
712
}
@@ -694,7 +732,7 @@ mod tests {
694
732
assert_eq ! ( base_address, base as usize ) ;
695
733
696
734
let data_range = data_range ( view) ;
697
- assert_eq ! ( data_range. start, unsafe { view. data( ) } as usize ) ;
735
+ assert_eq ! ( data_range. start, view. data( ) as usize ) ;
698
736
assert_eq ! ( data_range. end, unsafe { view. data( ) . add( 12 ) } as usize ) ;
699
737
} ) ;
700
738
}
@@ -724,7 +762,7 @@ mod tests {
724
762
assert_eq ! ( base_address, base as usize ) ;
725
763
726
764
let data_range = data_range ( view) ;
727
- assert_eq ! ( data_range. start, unsafe { view. data( ) } as usize ) ;
765
+ assert_eq ! ( data_range. start, view. data( ) as usize ) ;
728
766
assert_eq ! ( data_range. end, unsafe { view. data( ) . add( 12 ) } as usize ) ;
729
767
} ) ;
730
768
}
@@ -763,7 +801,7 @@ mod tests {
763
801
assert_eq ! ( base_address, base as usize ) ;
764
802
765
803
let data_range = data_range ( view2) ;
766
- assert_eq ! ( data_range. start, unsafe { view2. data( ) } as usize ) ;
804
+ assert_eq ! ( data_range. start, view2. data( ) as usize ) ;
767
805
assert_eq ! ( data_range. end, unsafe { view2. data( ) . add( 6 ) } as usize ) ;
768
806
} ) ;
769
807
}
@@ -806,7 +844,7 @@ mod tests {
806
844
assert_eq ! ( base_address, base as usize ) ;
807
845
808
846
let data_range = data_range ( view2) ;
809
- assert_eq ! ( data_range. start, unsafe { view2. data( ) } as usize ) ;
847
+ assert_eq ! ( data_range. start, view2. data( ) as usize ) ;
810
848
assert_eq ! ( data_range. end, unsafe { view2. data( ) . add( 6 ) } as usize ) ;
811
849
} ) ;
812
850
}
0 commit comments