Skip to content

Commit f2607ef

Browse files
committed
WIP: Use the GCD of the strides to check if two array views could alias.
1 parent 513daec commit f2607ef

File tree

3 files changed

+92
-42
lines changed

3 files changed

+92
-42
lines changed

Cargo.toml

+1
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ license = "BSD-2-Clause"
1717
[dependencies]
1818
libc = "0.2"
1919
num-complex = ">= 0.2, < 0.5"
20+
num-integer = "0.1"
2021
num-traits = "0.2"
2122
ndarray = ">= 0.13, < 0.16"
2223
pyo3 = { version = "0.16", default-features = false, features = ["macros"] }

src/borrow.rs

+72-34
Original file line numberDiff line numberDiff line change
@@ -59,13 +59,9 @@
5959
//! });
6060
//! ```
6161
//!
62-
//! The second example shows that while non-overlapping views are supported,
63-
//! interleaved views which do not touch are currently not supported
64-
//! due to over-approximating which borrows are in conflict.
62+
//! The second example shows that non-overlapping and interleaved views are also supported.
6563
//!
6664
//! ```rust
67-
//! # use std::panic::{catch_unwind, AssertUnwindSafe};
68-
//! #
6965
//! use numpy::PyArray1;
7066
//! use pyo3::{types::IntoPyDict, Python};
7167
//!
@@ -78,16 +74,15 @@
7874
//! let view3 = py.eval("array[::2]", None, Some(locals)).unwrap().downcast::<PyArray1<f64>>().unwrap();
7975
//! let view4 = py.eval("array[1::2]", None, Some(locals)).unwrap().downcast::<PyArray1<f64>>().unwrap();
8076
//!
81-
//! let _view1 = view1.readwrite();
82-
//! let _view2 = view2.readwrite();
77+
//! {
78+
//! let _view1 = view1.readwrite();
79+
//! let _view2 = view2.readwrite();
80+
//! }
8381
//!
84-
//! // Will fail at runtime even though `view3` and `view4`
85-
//! // interleave as they are based on the same array.
86-
//! let res = catch_unwind(AssertUnwindSafe(|| {
82+
//! {
8783
//! let _view3 = view3.readwrite();
8884
//! let _view4 = view4.readwrite();
89-
//! }));
90-
//! assert!(res.is_err());
85+
//! }
9186
//! });
9287
//! ```
9388
//!
@@ -125,6 +120,8 @@
125120
//!
126121
//! # Limitations
127122
//!
123+
//! TODO: We only leave the case of aliasing, but only out of bounds. Can this actually happen for array views?
124+
//!
128125
//! Note that the current implementation of this is an over-approximation: It will consider overlapping borrows
129126
//! potentially conflicting if the initial arrays have the same object at the end of their [base object chain][base].
130127
//! For example, creating two views of the same underlying array by slicing can yield potentially conflicting borrows
@@ -143,6 +140,7 @@ use std::collections::hash_map::{Entry, HashMap};
143140
use std::ops::{Deref, Range};
144141

145142
use ndarray::{ArrayView, ArrayViewMut, Dimension, Ix1, Ix2, Ix3, Ix4, Ix5, Ix6, IxDyn};
143+
use num_integer::gcd;
146144
use pyo3::{FromPyObject, PyAny, PyResult};
147145

148146
use crate::array::PyArray;
@@ -155,9 +153,28 @@ use crate::npyffi::{self, PyArrayObject, NPY_ARRAY_WRITEABLE};
155153
#[derive(PartialEq, Eq, Hash)]
156154
struct BorrowKey {
157155
range: Range<usize>,
156+
data_ptr: usize,
157+
gcd_strides: isize,
158158
}
159159

160160
impl BorrowKey {
161+
fn from_array<T, D>(array: &PyArray<T, D>) -> Self
162+
where
163+
T: Element,
164+
D: Dimension,
165+
{
166+
let range = data_range(array);
167+
168+
let data_ptr = array.data() as usize;
169+
let gcd_strides = reduce(array.strides().iter().copied(), gcd).unwrap_or(1);
170+
171+
Self {
172+
range,
173+
data_ptr,
174+
gcd_strides,
175+
}
176+
}
177+
161178
fn conflicts(&self, other: &Self) -> bool {
162179
debug_assert!(self.range.start <= self.range.end);
163180
debug_assert!(other.range.start <= other.range.end);
@@ -166,6 +183,20 @@ impl BorrowKey {
166183
return false;
167184
}
168185

186+
// The Diophantine equation which describes whether any integers can combine the data pointers and strides of the two arrays s.t.
187+
// they yield the same element has a solution if and only if the GCD of all strides divides the difference of the data pointers.
188+
//
189+
// That solution could be out of bounds which mean that this is still an approximation,
190+
// but it seems sufficient to handle typical cases like the color channels of an image.
191+
//
192+
// https://users.rust-lang.org/t/math-for-borrow-checking-numpy-arrays/73303
193+
let ptr_diff = abs_diff(self.data_ptr, other.data_ptr) as isize;
194+
let gcd_strides = gcd(self.gcd_strides, other.gcd_strides);
195+
196+
if ptr_diff % gcd_strides != 0 {
197+
return false;
198+
}
199+
169200
true
170201
}
171202
}
@@ -192,10 +223,7 @@ impl BorrowFlags {
192223
D: Dimension,
193224
{
194225
let address = base_address(array);
195-
196-
let key = BorrowKey {
197-
range: data_range(array),
198-
};
226+
let key = BorrowKey::from_array(array);
199227

200228
// SAFETY: Access to `&PyArray<T, D>` implies holding the GIL
201229
// and we are not calling into user code which might re-enter this function.
@@ -242,10 +270,7 @@ impl BorrowFlags {
242270
D: Dimension,
243271
{
244272
let address = base_address(array);
245-
246-
let key = BorrowKey {
247-
range: data_range(array),
248-
};
273+
let key = BorrowKey::from_array(array);
249274

250275
// SAFETY: Access to `&PyArray<T, D>` implies holding the GIL
251276
// and we are not calling into user code which might re-enter this function.
@@ -272,10 +297,7 @@ impl BorrowFlags {
272297
D: Dimension,
273298
{
274299
let address = base_address(array);
275-
276-
let key = BorrowKey {
277-
range: data_range(array),
278-
};
300+
let key = BorrowKey::from_array(array);
279301

280302
// SAFETY: Access to `&PyArray<T, D>` implies holding the GIL
281303
// and we are not calling into user code which might re-enter this function.
@@ -320,10 +342,7 @@ impl BorrowFlags {
320342
D: Dimension,
321343
{
322344
let address = base_address(array);
323-
324-
let key = BorrowKey {
325-
range: data_range(array),
326-
};
345+
let key = BorrowKey::from_array(array);
327346

328347
// SAFETY: Access to `&PyArray<T, D>` implies holding the GIL
329348
// and we are not calling into user code which might re-enter this function.
@@ -628,6 +647,25 @@ where
628647
Range { start, end }
629648
}
630649

650+
// FIXME(adamreichold): Use `usize::abs_diff` from std when that becomes stable.
651+
fn abs_diff(lhs: usize, rhs: usize) -> usize {
652+
if lhs >= rhs {
653+
lhs - rhs
654+
} else {
655+
rhs - lhs
656+
}
657+
}
658+
659+
// FIXME(adamreichold): Use `Iterator::reduce` from std when our MSRV reaches 1.51.
660+
fn reduce<I, F>(mut iter: I, f: F) -> Option<I::Item>
661+
where
662+
I: Iterator,
663+
F: FnMut(I::Item, I::Item) -> I::Item,
664+
{
665+
let first = iter.next()?;
666+
Some(iter.fold(first, f))
667+
}
668+
631669
#[cfg(test)]
632670
mod tests {
633671
use super::*;
@@ -650,7 +688,7 @@ mod tests {
650688
assert_eq!(base_address, array as *const _ as usize);
651689

652690
let data_range = data_range(array);
653-
assert_eq!(data_range.start, unsafe { array.data() } as usize);
691+
assert_eq!(data_range.start, array.data() as usize);
654692
assert_eq!(data_range.end, unsafe { array.data().add(15) } as usize);
655693
});
656694
}
@@ -668,7 +706,7 @@ mod tests {
668706
assert_eq!(base_address, base as usize);
669707

670708
let data_range = data_range(array);
671-
assert_eq!(data_range.start, unsafe { array.data() } as usize);
709+
assert_eq!(data_range.start, array.data() as usize);
672710
assert_eq!(data_range.end, unsafe { array.data().add(15) } as usize);
673711
});
674712
}
@@ -694,7 +732,7 @@ mod tests {
694732
assert_eq!(base_address, base as usize);
695733

696734
let data_range = data_range(view);
697-
assert_eq!(data_range.start, unsafe { view.data() } as usize);
735+
assert_eq!(data_range.start, view.data() as usize);
698736
assert_eq!(data_range.end, unsafe { view.data().add(12) } as usize);
699737
});
700738
}
@@ -724,7 +762,7 @@ mod tests {
724762
assert_eq!(base_address, base as usize);
725763

726764
let data_range = data_range(view);
727-
assert_eq!(data_range.start, unsafe { view.data() } as usize);
765+
assert_eq!(data_range.start, view.data() as usize);
728766
assert_eq!(data_range.end, unsafe { view.data().add(12) } as usize);
729767
});
730768
}
@@ -763,7 +801,7 @@ mod tests {
763801
assert_eq!(base_address, base as usize);
764802

765803
let data_range = data_range(view2);
766-
assert_eq!(data_range.start, unsafe { view2.data() } as usize);
804+
assert_eq!(data_range.start, view2.data() as usize);
767805
assert_eq!(data_range.end, unsafe { view2.data().add(6) } as usize);
768806
});
769807
}
@@ -806,7 +844,7 @@ mod tests {
806844
assert_eq!(base_address, base as usize);
807845

808846
let data_range = data_range(view2);
809-
assert_eq!(data_range.start, unsafe { view2.data() } as usize);
847+
assert_eq!(data_range.start, view2.data() as usize);
810848
assert_eq!(data_range.end, unsafe { view2.data().add(6) } as usize);
811849
});
812850
}

tests/borrow.rs

+19-8
Original file line numberDiff line numberDiff line change
@@ -212,28 +212,39 @@ fn conflict_due_reborrow_of_overlapping_views() {
212212
}
213213

214214
#[test]
215-
#[should_panic(expected = "AlreadyBorrowed")]
216-
fn interleaved_views_conflict() {
215+
fn interleaved_views_do_not_conflict() {
217216
Python::with_gil(|py| {
218-
let array = PyArray::<f64, _>::zeros(py, (1, 2, 3), false);
217+
let array = PyArray::<f64, _>::zeros(py, (23, 42, 3), false);
219218
let locals = [("array", array)].into_py_dict(py);
220219

221220
let view1 = py
222-
.eval("array[:,:,1]", None, Some(locals))
221+
.eval("array[:,:,0]", None, Some(locals))
223222
.unwrap()
224223
.downcast::<PyArray2<f64>>()
225224
.unwrap();
226-
assert_eq!(view1.shape(), [1, 2]);
225+
assert_eq!(view1.shape(), [23, 42]);
227226

228227
let view2 = py
228+
.eval("array[:,:,1]", None, Some(locals))
229+
.unwrap()
230+
.downcast::<PyArray2<f64>>()
231+
.unwrap();
232+
assert_eq!(view2.shape(), [23, 42]);
233+
234+
let view3 = py
229235
.eval("array[:,:,2]", None, Some(locals))
230236
.unwrap()
231237
.downcast::<PyArray2<f64>>()
232238
.unwrap();
233-
assert_eq!(view2.shape(), [1, 2]);
239+
assert_eq!(view2.shape(), [23, 42]);
234240

235-
let _exclusive1 = view1.readwrite();
236-
let _exclusive2 = view2.readwrite();
241+
let exclusive1 = view1.readwrite();
242+
let exclusive2 = view2.readwrite();
243+
let exclusive3 = view3.readwrite();
244+
245+
assert_eq!(exclusive3.len(), 23 * 42);
246+
assert_eq!(exclusive2.len(), 23 * 42);
247+
assert_eq!(exclusive1.len(), 23 * 42);
237248
});
238249
}
239250

0 commit comments

Comments
 (0)