Skip to content

Commit ef7a87e

Browse files
committed
Use the GCD of the strides to check if two array views could alias.
1 parent 935aeb4 commit ef7a87e

File tree

3 files changed

+157
-47
lines changed

3 files changed

+157
-47
lines changed

Cargo.toml

+1
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ license = "BSD-2-Clause"
1717
[dependencies]
1818
libc = "0.2"
1919
num-complex = ">= 0.2, < 0.5"
20+
num-integer = "0.1"
2021
num-traits = "0.2"
2122
ndarray = ">= 0.13, < 0.16"
2223
pyo3 = { version = "0.16", default-features = false, features = ["macros"] }

src/borrow.rs

+137-39
Original file line numberDiff line numberDiff line change
@@ -59,13 +59,9 @@
5959
//! });
6060
//! ```
6161
//!
62-
//! The second example shows that while non-overlapping views are supported,
63-
//! interleaved views which do not touch are currently not supported
64-
//! due to over-approximating which borrows are in conflict.
62+
//! The second example shows that non-overlapping and interleaved views are also supported.
6563
//!
6664
//! ```rust
67-
//! # use std::panic::{catch_unwind, AssertUnwindSafe};
68-
//! #
6965
//! use numpy::PyArray1;
7066
//! use pyo3::{types::IntoPyDict, Python};
7167
//!
@@ -78,16 +74,15 @@
7874
//! let view3 = py.eval("array[::2]", None, Some(locals)).unwrap().downcast::<PyArray1<f64>>().unwrap();
7975
//! let view4 = py.eval("array[1::2]", None, Some(locals)).unwrap().downcast::<PyArray1<f64>>().unwrap();
8076
//!
81-
//! let _view1 = view1.readwrite();
82-
//! let _view2 = view2.readwrite();
77+
//! {
78+
//! let _view1 = view1.readwrite();
79+
//! let _view2 = view2.readwrite();
80+
//! }
8381
//!
84-
//! // Will fail at runtime even though `view3` and `view4`
85-
//! // interleave as they are based on the same array.
86-
//! let res = catch_unwind(AssertUnwindSafe(|| {
82+
//! {
8783
//! let _view3 = view3.readwrite();
8884
//! let _view4 = view4.readwrite();
89-
//! }));
90-
//! assert!(res.is_err());
85+
//! }
9186
//! });
9287
//! ```
9388
//!
@@ -125,15 +120,17 @@
125120
//!
126121
//! # Limitations
127122
//!
128-
//! Note that the current implementation of this is an over-approximation: It will consider overlapping borrows
123+
//! Note that the current implementation of this is an over-approximation: It will consider borrows
129124
//! potentially conflicting if the initial arrays have the same object at the end of their [base object chain][base].
130-
//! For example, creating two views of the same underlying array by slicing can yield potentially conflicting borrows
131-
//! even if the slice indices are chosen so that the two views do not actually share any elements by interleaving along one of its axes.
125+
//! Then, multiple conditions which are sufficient but not necessary to show the absence of conflicts are checked,
126+
//! but there are cases which they do not handle, for example slicing an array with a step size
127+
//! that does not divide its dimension along that axis. In these situations, borrows are rejected even though the arrays
128+
//! do not actually share any elements.
132129
//!
133130
//! This does limit the set of programs that can be written using safe Rust in a way similar to rustc itself
134131
//! which ensures that all accepted programs are memory safe but does not necessarily accept all memory safe programs.
135-
//! The plan is to refine this checking to correctly handle more involved cases like interleaved views
136-
//! into the same array and until then the unsafe method [`PyArray::as_array_mut`] can be used as an escape hatch.
132+
//! The plan is to refine this checking to correctly handle more involved cases and until then
133+
//! the unsafe method [`PyArray::as_array_mut`] can be used as an escape hatch.
137134
//!
138135
//! [base]: https://numpy.org/doc/stable/reference/c-api/types-and-structures.html#c.NPY_AO.base
139136
#![deny(missing_docs)]
@@ -143,6 +140,7 @@ use std::collections::hash_map::{Entry, HashMap};
143140
use std::ops::{Deref, Range};
144141

145142
use ndarray::{ArrayView, ArrayViewMut, Dimension, Ix1, Ix2, Ix3, Ix4, Ix5, Ix6, IxDyn};
143+
use num_integer::gcd;
146144
use pyo3::{FromPyObject, PyAny, PyResult};
147145

148146
use crate::array::PyArray;
@@ -155,9 +153,28 @@ use crate::npyffi::{self, PyArrayObject, NPY_ARRAY_WRITEABLE};
155153
/// Describes the memory touched by a borrowed array so that two borrows
/// can be checked for conflicts.
#[derive(Hash, Eq, PartialEq)]
struct BorrowKey {
    /// Address range covered by the array, as computed by `data_range`.
    range: Range<usize>,
    /// Address of the array's data pointer, stored as an integer.
    data_ptr: usize,
    /// GCD of the array's strides (1 if the array has no strides).
    gcd_strides: isize,
}
159159

160160
impl BorrowKey {
161+
fn from_array<T, D>(array: &PyArray<T, D>) -> Self
162+
where
163+
T: Element,
164+
D: Dimension,
165+
{
166+
let range = data_range(array);
167+
168+
let data_ptr = array.data() as usize;
169+
let gcd_strides = reduce(array.strides().iter().copied(), gcd).unwrap_or(1);
170+
171+
Self {
172+
range,
173+
data_ptr,
174+
gcd_strides,
175+
}
176+
}
177+
161178
fn conflicts(&self, other: &Self) -> bool {
162179
debug_assert!(self.range.start <= self.range.end);
163180
debug_assert!(other.range.start <= other.range.end);
@@ -166,6 +183,21 @@ impl BorrowKey {
166183
return false;
167184
}
168185

186+
// The Diophantine equation which describes whether any integers can combine the data pointers and strides of the two arrays s.t.
187+
// they yield the same element has a solution if and only if the GCD of all strides divides the difference of the data pointers.
188+
//
189+
// That solution could be out of bounds which means that this is still an over-approximation.
190+
// It appears sufficient to handle typical cases like the color channels of an image,
191+
// but fails when slicing an array with a step size that does not divide the dimension along that axis.
192+
//
193+
// https://users.rust-lang.org/t/math-for-borrow-checking-numpy-arrays/73303
194+
let ptr_diff = abs_diff(self.data_ptr, other.data_ptr) as isize;
195+
let gcd_strides = gcd(self.gcd_strides, other.gcd_strides);
196+
197+
if ptr_diff % gcd_strides != 0 {
198+
return false;
199+
}
200+
169201
true
170202
}
171203
}
@@ -192,10 +224,7 @@ impl BorrowFlags {
192224
D: Dimension,
193225
{
194226
let address = base_address(array);
195-
196-
let key = BorrowKey {
197-
range: data_range(array),
198-
};
227+
let key = BorrowKey::from_array(array);
199228

200229
// SAFETY: Access to `&PyArray<T, D>` implies holding the GIL
201230
// and we are not calling into user code which might re-enter this function.
@@ -242,10 +271,7 @@ impl BorrowFlags {
242271
D: Dimension,
243272
{
244273
let address = base_address(array);
245-
246-
let key = BorrowKey {
247-
range: data_range(array),
248-
};
274+
let key = BorrowKey::from_array(array);
249275

250276
// SAFETY: Access to `&PyArray<T, D>` implies holding the GIL
251277
// and we are not calling into user code which might re-enter this function.
@@ -272,10 +298,7 @@ impl BorrowFlags {
272298
D: Dimension,
273299
{
274300
let address = base_address(array);
275-
276-
let key = BorrowKey {
277-
range: data_range(array),
278-
};
301+
let key = BorrowKey::from_array(array);
279302

280303
// SAFETY: Access to `&PyArray<T, D>` implies holding the GIL
281304
// and we are not calling into user code which might re-enter this function.
@@ -320,10 +343,7 @@ impl BorrowFlags {
320343
D: Dimension,
321344
{
322345
let address = base_address(array);
323-
324-
let key = BorrowKey {
325-
range: data_range(array),
326-
};
346+
let key = BorrowKey::from_array(array);
327347

328348
// SAFETY: Access to `&PyArray<T, D>` implies holding the GIL
329349
// and we are not calling into user code which might re-enter this function.
@@ -628,6 +648,25 @@ where
628648
Range { start, end }
629649
}
630650

651+
/// Returns the absolute difference between `lhs` and `rhs`,
/// always subtracting the smaller value from the larger one.
// FIXME(adamreichold): Use `usize::abs_diff` from std when that becomes stable.
fn abs_diff(lhs: usize, rhs: usize) -> usize {
    let (larger, smaller) = if lhs < rhs { (rhs, lhs) } else { (lhs, rhs) };

    larger - smaller
}
659+
660+
/// Combines all items of `iter` into one value by repeatedly applying `f`,
/// using the first item as the initial accumulator.
///
/// Returns `None` if `iter` yields no items.
// FIXME(adamreichold): Use `Iterator::reduce` from std when our MSRV reaches 1.51.
fn reduce<I, F>(mut iter: I, f: F) -> Option<I::Item>
where
    I: Iterator,
    F: FnMut(I::Item, I::Item) -> I::Item,
{
    iter.next().map(|seed| iter.fold(seed, f))
}
669+
631670
#[cfg(test)]
632671
mod tests {
633672
use super::*;
@@ -650,7 +689,7 @@ mod tests {
650689
assert_eq!(base_address, array as *const _ as usize);
651690

652691
let data_range = data_range(array);
653-
assert_eq!(data_range.start, unsafe { array.data() } as usize);
692+
assert_eq!(data_range.start, array.data() as usize);
654693
assert_eq!(data_range.end, unsafe { array.data().add(15) } as usize);
655694
});
656695
}
@@ -668,7 +707,7 @@ mod tests {
668707
assert_eq!(base_address, base as usize);
669708

670709
let data_range = data_range(array);
671-
assert_eq!(data_range.start, unsafe { array.data() } as usize);
710+
assert_eq!(data_range.start, array.data() as usize);
672711
assert_eq!(data_range.end, unsafe { array.data().add(15) } as usize);
673712
});
674713
}
@@ -694,7 +733,7 @@ mod tests {
694733
assert_eq!(base_address, base as usize);
695734

696735
let data_range = data_range(view);
697-
assert_eq!(data_range.start, unsafe { view.data() } as usize);
736+
assert_eq!(data_range.start, view.data() as usize);
698737
assert_eq!(data_range.end, unsafe { view.data().add(12) } as usize);
699738
});
700739
}
@@ -724,7 +763,7 @@ mod tests {
724763
assert_eq!(base_address, base as usize);
725764

726765
let data_range = data_range(view);
727-
assert_eq!(data_range.start, unsafe { view.data() } as usize);
766+
assert_eq!(data_range.start, view.data() as usize);
728767
assert_eq!(data_range.end, unsafe { view.data().add(12) } as usize);
729768
});
730769
}
@@ -763,7 +802,7 @@ mod tests {
763802
assert_eq!(base_address, base as usize);
764803

765804
let data_range = data_range(view2);
766-
assert_eq!(data_range.start, unsafe { view2.data() } as usize);
805+
assert_eq!(data_range.start, view2.data() as usize);
767806
assert_eq!(data_range.end, unsafe { view2.data().add(6) } as usize);
768807
});
769808
}
@@ -806,7 +845,7 @@ mod tests {
806845
assert_eq!(base_address, base as usize);
807846

808847
let data_range = data_range(view2);
809-
assert_eq!(data_range.start, unsafe { view2.data() } as usize);
848+
assert_eq!(data_range.start, view2.data() as usize);
810849
assert_eq!(data_range.end, unsafe { view2.data().add(6) } as usize);
811850
});
812851
}
@@ -836,4 +875,63 @@ mod tests {
836875
assert_eq!(data_range.end, unsafe { view.data().offset(6) } as usize);
837876
});
838877
}
878+
879+
/// Exercises the GCD-based conflict check on column-sliced views of a single 10x10 array,
/// including a pair of views whose step size does not divide the length of the sliced axis.
#[test]
fn view_with_non_dividing_strides() {
    Python::with_gil(|py| {
        let array = PyArray::<f64, _>::zeros(py, (10, 10), false);
        let locals = [("array", array)].into_py_dict(py);

        let view1 = py
            .eval("array[:,::3]", None, Some(locals))
            .unwrap()
            .downcast::<PyArray2<f64>>()
            .unwrap();

        let key1 = BorrowKey::from_array(view1);

        assert_eq!(view1.strides(), &[80, 24]);
        assert_eq!(key1.gcd_strides, 8);

        let view2 = py
            .eval("array[:,1::3]", None, Some(locals))
            .unwrap()
            .downcast::<PyArray2<f64>>()
            .unwrap();

        let key2 = BorrowKey::from_array(view2);

        assert_eq!(view2.strides(), &[80, 24]);
        assert_eq!(key2.gcd_strides, 8);

        let view3 = py
            .eval("array[:,::2]", None, Some(locals))
            .unwrap()
            .downcast::<PyArray2<f64>>()
            .unwrap();

        let key3 = BorrowKey::from_array(view3);

        assert_eq!(view3.strides(), &[80, 16]);
        assert_eq!(key3.gcd_strides, 16);

        let view4 = py
            .eval("array[:,1::2]", None, Some(locals))
            .unwrap()
            .downcast::<PyArray2<f64>>()
            .unwrap();

        let key4 = BorrowKey::from_array(view4);

        assert_eq!(view4.strides(), &[80, 16]);
        assert_eq!(key4.gcd_strides, 16);

        // Pointer difference of 8 bytes is not divisible by the combined GCD of 16.
        assert!(!key3.conflicts(&key4));
        // Identical data pointers, so the divisibility check can never rule out aliasing.
        assert!(key1.conflicts(&key3));
        assert!(key2.conflicts(&key4));

        // This is a false conflict where all aliasing indices like (0,7) and (2,0) are out of bounds.
        // The data pointers differ by 8 bytes and the combined GCD of the strides is 8,
        // so the divisibility check cannot rule out aliasing and a conflict is reported.
        assert!(key1.conflicts(&key2));
    });
}
839937
}

tests/borrow.rs

+19-8
Original file line numberDiff line numberDiff line change
@@ -212,28 +212,39 @@ fn conflict_due_reborrow_of_overlapping_views() {
212212
}
213213

214214
/// Exclusive borrows of the views `array[:,:,0]`, `array[:,:,1]` and `array[:,:,2]`
/// into the same array must be able to coexist.
#[test]
fn interleaved_views_do_not_conflict() {
    Python::with_gil(|py| {
        let array = PyArray::<f64, _>::zeros(py, (23, 42, 3), false);
        let locals = [("array", array)].into_py_dict(py);

        let view1 = py
            .eval("array[:,:,0]", None, Some(locals))
            .unwrap()
            .downcast::<PyArray2<f64>>()
            .unwrap();
        assert_eq!(view1.shape(), [23, 42]);

        let view2 = py
            .eval("array[:,:,1]", None, Some(locals))
            .unwrap()
            .downcast::<PyArray2<f64>>()
            .unwrap();
        assert_eq!(view2.shape(), [23, 42]);

        let view3 = py
            .eval("array[:,:,2]", None, Some(locals))
            .unwrap()
            .downcast::<PyArray2<f64>>()
            .unwrap();
        // Check the third view itself rather than re-checking `view2`.
        assert_eq!(view3.shape(), [23, 42]);

        // All three exclusive borrows are alive at the same time.
        let exclusive1 = view1.readwrite();
        let exclusive2 = view2.readwrite();
        let exclusive3 = view3.readwrite();

        assert_eq!(exclusive3.len(), 23 * 42);
        assert_eq!(exclusive2.len(), 23 * 42);
        assert_eq!(exclusive1.len(), 23 * 42);
    });
}
239250

0 commit comments

Comments
 (0)