Skip to content

Commit 51c0cd2

Browse files
committed
WIP: Add dynamic borrow checking for dereferencing NumPy arrays.
1 parent 61882e3 commit 51c0cd2

File tree

8 files changed

+345
-363
lines changed

8 files changed

+345
-363
lines changed

examples/simple-extension/src/lib.rs

+5-3
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
use numpy::ndarray::{ArrayD, ArrayViewD, ArrayViewMutD};
2-
use numpy::{Complex64, IntoPyArray, PyArray1, PyArrayDyn, PyReadonlyArrayDyn};
2+
use numpy::{
3+
Complex64, IntoPyArray, PyArray1, PyArrayDyn, PyReadonlyArrayDyn, PyReadwriteArrayDyn,
4+
};
35
use pyo3::{
46
pymodule,
57
types::{PyDict, PyModule},
@@ -41,8 +43,8 @@ fn rust_ext(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
4143
// wrapper of `mult`
4244
#[pyfn(m)]
4345
#[pyo3(name = "mult")]
44-
fn mult_py(a: f64, x: &PyArrayDyn<f64>) {
45-
let x = unsafe { x.as_array_mut() };
46+
fn mult_py(a: f64, mut x: PyReadwriteArrayDyn<f64>) {
47+
let x = x.as_array_mut();
4648
mult(a, x);
4749
}
4850

src/array.rs

+26-35
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ use pyo3::{
1818
Python, ToPyObject,
1919
};
2020

21+
use crate::borrow::{PyReadonlyArray, PyReadwriteArray};
2122
use crate::convert::{ArrayExt, IntoPyArray, NpyIndex, ToNpyDims, ToPyArray};
2223
use crate::dtype::Element;
2324
use crate::error::{DimensionalityError, FromVecError, NotContiguousError, TypeError};
@@ -190,18 +191,8 @@ impl<T, D> PyArray<T, D> {
190191
}
191192

192193
#[inline(always)]
193-
fn check_flag(&self, flag: c_int) -> bool {
194-
unsafe { *self.as_array_ptr() }.flags & flag == flag
195-
}
196-
197-
#[inline(always)]
198-
pub(crate) fn get_flag(&self) -> c_int {
199-
unsafe { *self.as_array_ptr() }.flags
200-
}
201-
202-
/// Returns a temporally unwriteable reference of the array.
203-
pub fn readonly(&self) -> crate::PyReadonlyArray<T, D> {
204-
self.into()
194+
pub(crate) fn check_flags(&self, flags: c_int) -> bool {
195+
unsafe { *self.as_array_ptr() }.flags & flags != 0
205196
}
206197

207198
/// Returns `true` if the internal data of the array is C-style contiguous
@@ -223,18 +214,17 @@ impl<T, D> PyArray<T, D> {
223214
/// });
224215
/// ```
225216
pub fn is_contiguous(&self) -> bool {
226-
self.check_flag(npyffi::NPY_ARRAY_C_CONTIGUOUS)
227-
| self.check_flag(npyffi::NPY_ARRAY_F_CONTIGUOUS)
217+
self.check_flags(npyffi::NPY_ARRAY_C_CONTIGUOUS | npyffi::NPY_ARRAY_F_CONTIGUOUS)
228218
}
229219

230220
/// Returns `true` if the internal data of the array is Fortran-style contiguous.
231221
pub fn is_fortran_contiguous(&self) -> bool {
232-
self.check_flag(npyffi::NPY_ARRAY_F_CONTIGUOUS)
222+
self.check_flags(npyffi::NPY_ARRAY_F_CONTIGUOUS)
233223
}
234224

235225
/// Returns `true` if the internal data of the array is C-style contiguous.
236226
pub fn is_c_contiguous(&self) -> bool {
237-
self.check_flag(npyffi::NPY_ARRAY_C_CONTIGUOUS)
227+
self.check_flags(npyffi::NPY_ARRAY_C_CONTIGUOUS)
238228
}
239229

240230
/// Get `Py<PyArray>` from `&PyArray`, which is the owned wrapper of PyObject.
@@ -823,28 +813,37 @@ impl<T: Element, D: Dimension> PyArray<T, D> {
823813
ToPyArray::to_pyarray(arr, py)
824814
}
825815

826-
/// Get the immutable view of the internal data of `PyArray`, as
827-
/// [`ndarray::ArrayView`](https://docs.rs/ndarray/latest/ndarray/type.ArrayView.html).
816+
/// Get an immutable borrow of the NumPy array
817+
pub fn readonly(&self) -> PyReadonlyArray<'_, T, D> {
818+
PyReadonlyArray::try_new(self).unwrap()
819+
}
820+
821+
/// Get a mutable borrow of the NumPy array
822+
pub fn readwrite(&self) -> PyReadwriteArray<'_, T, D> {
823+
PyReadwriteArray::try_new(self).unwrap()
824+
}
825+
826+
/// Returns the internal array as [`ArrayView`].
828827
///
829-
/// Please consider the use of safe alternatives
830-
/// ([`PyReadonlyArray::as_array`](../struct.PyReadonlyArray.html#method.as_array)
831-
/// or [`to_array`](#method.to_array)) instead of this.
828+
/// See also [`PyArrayRef::as_array`].
832829
///
833830
/// # Safety
834-
/// If the internal array is not readonly and can be mutated from Python code,
835-
/// holding the `ArrayView` might cause undefined behavior.
831+
///
832+
/// The existence of an exclusive reference to the internal data, e.g. `&mut [T]` or `ArrayViewMut`, implies undefined behavior.
836833
pub unsafe fn as_array(&self) -> ArrayView<'_, T, D> {
837834
let (shape, ptr, inverted_axes) = self.ndarray_shape_ptr();
838835
let mut res = ArrayView::from_shape_ptr(shape, ptr);
839836
inverted_axes.invert(&mut res);
840837
res
841838
}
842839

843-
/// Returns the internal array as [`ArrayViewMut`]. See also [`as_array`](#method.as_array).
840+
/// Returns the internal array as [`ArrayViewMut`].
841+
///
842+
/// See also [`PyArrayRefMut::as_array_mut`].
844843
///
845844
/// # Safety
846-
/// If another reference to the internal data exists(e.g., `&[T]` or `ArrayView`),
847-
/// it might cause undefined behavior.
845+
///
846+
/// The existence of another reference to the internal data, e.g. `&[T]` or `ArrayView`, implies undefined behavior.
848847
pub unsafe fn as_array_mut(&self) -> ArrayViewMut<'_, T, D> {
849848
let (shape, ptr, inverted_axes) = self.ndarray_shape_ptr();
850849
let mut res = ArrayViewMut::from_shape_ptr(shape, ptr);
@@ -920,7 +919,7 @@ impl<D: Dimension> PyArray<PyObject, D> {
920919
///
921920
/// let pyarray = PyArray::from_owned_object_array(py, array);
922921
///
923-
/// assert!(pyarray.readonly().get(0).unwrap().as_ref(py).is_instance::<CustomElement>().unwrap());
922+
/// assert!(pyarray.readonly().as_array().get(0).unwrap().as_ref(py).is_instance::<CustomElement>().unwrap());
924923
/// });
925924
/// ```
926925
pub fn from_owned_object_array<'py, T>(py: Python<'py>, arr: Array<Py<T>, D>) -> &'py Self {
@@ -1069,14 +1068,6 @@ impl<T: Element> PyArray<T, Ix1> {
10691068
self.resize_(self.py(), [new_elems], 1, NPY_ORDER::NPY_ANYORDER)
10701069
}
10711070

1072-
/// Iterates all elements of this array.
1073-
/// See [NpySingleIter](../npyiter/struct.NpySingleIter.html) for more.
1074-
pub fn iter<'py>(
1075-
&'py self,
1076-
) -> PyResult<crate::NpySingleIter<'py, T, crate::npyiter::ReadWrite>> {
1077-
crate::NpySingleIterBuilder::readwrite(self).build()
1078-
}
1079-
10801071
fn resize_<D: IntoDimension>(
10811072
&self,
10821073
py: Python,

src/borrow.rs

+229
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,229 @@
1+
use std::cell::UnsafeCell;
2+
use std::collections::hash_map::{Entry, HashMap};
3+
use std::ops::Deref;
4+
5+
use ndarray::{ArrayView, ArrayViewMut, Dimension, Ix1, Ix2, IxDyn};
6+
use pyo3::{FromPyObject, PyAny, PyResult};
7+
8+
use crate::array::PyArray;
9+
use crate::dtype::Element;
10+
use crate::error::{BorrowError, NotContiguousError};
11+
use crate::npyffi::{self, PyArrayObject, NPY_ARRAY_WRITEABLE};
12+
13+
struct BorrowFlags(UnsafeCell<Option<HashMap<usize, isize>>>);
14+
15+
unsafe impl Sync for BorrowFlags {}
16+
17+
impl BorrowFlags {
18+
const fn new() -> Self {
19+
Self(UnsafeCell::new(None))
20+
}
21+
22+
#[allow(clippy::mut_from_ref)]
23+
unsafe fn get(&self) -> &mut HashMap<usize, isize> {
24+
(*self.0.get()).get_or_insert_with(HashMap::new)
25+
}
26+
}
27+
28+
static BORROW_FLAGS: BorrowFlags = BorrowFlags::new();
29+
30+
pub struct PyReadonlyArray<'py, T, D>(&'py PyArray<T, D>);
31+
32+
pub type PyReadonlyArray1<'py, T> = PyReadonlyArray<'py, T, Ix1>;
33+
34+
pub type PyReadonlyArray2<'py, T> = PyReadonlyArray<'py, T, Ix2>;
35+
36+
pub type PyReadonlyArrayDyn<'py, T> = PyReadonlyArray<'py, T, IxDyn>;
37+
38+
impl<'py, T, D> Deref for PyReadonlyArray<'py, T, D> {
39+
type Target = PyArray<T, D>;
40+
41+
fn deref(&self) -> &Self::Target {
42+
self.0
43+
}
44+
}
45+
46+
impl<'py, T: Element, D: Dimension> FromPyObject<'py> for PyReadonlyArray<'py, T, D> {
47+
fn extract(obj: &'py PyAny) -> PyResult<Self> {
48+
let array: &'py PyArray<T, D> = obj.extract()?;
49+
Ok(array.readonly())
50+
}
51+
}
52+
53+
impl<'py, T, D> PyReadonlyArray<'py, T, D>
54+
where
55+
T: Element,
56+
D: Dimension,
57+
{
58+
pub(crate) fn try_new(array: &'py PyArray<T, D>) -> Result<Self, BorrowError> {
59+
let address = base_address(array);
60+
61+
// SAFETY: Access to a `&'py PyArray<T, D>` implies holding the GIL
62+
// and we are not calling into user code which might re-enter this function.
63+
let borrow_flags = unsafe { BORROW_FLAGS.get() };
64+
65+
match borrow_flags.entry(address) {
66+
Entry::Occupied(entry) => {
67+
let readers = entry.into_mut();
68+
69+
let new_readers = readers.wrapping_add(1);
70+
71+
if new_readers <= 0 {
72+
cold();
73+
return Err(BorrowError::AlreadyBorrowed);
74+
}
75+
76+
*readers = new_readers;
77+
}
78+
Entry::Vacant(entry) => {
79+
entry.insert(1);
80+
}
81+
}
82+
83+
Ok(Self(array))
84+
}
85+
86+
pub fn as_array(&self) -> ArrayView<T, D> {
87+
// SAFETY: Thread-local borrow flags ensure aliasing discipline on this thread,
88+
// and `PyArray` is neither `Send` nor `Sync`
89+
unsafe { self.0.as_array() }
90+
}
91+
92+
pub fn as_slice(&self) -> Result<&[T], NotContiguousError> {
93+
// SAFETY: Thread-local borrow flags ensure aliasing discipline on this thread,
94+
// and `PyArray` is neither `Send` nor `Sync`
95+
unsafe { self.0.as_slice() }
96+
}
97+
}
98+
99+
impl<'a, T, D> Drop for PyReadonlyArray<'a, T, D> {
100+
fn drop(&mut self) {
101+
let address = base_address(self.0);
102+
103+
// SAFETY: Access to a `&'py PyArray<T, D>` implies holding the GIL
104+
// and we are not calling into user code which might re-enter this function.
105+
let borrow_flags = unsafe { BORROW_FLAGS.get() };
106+
107+
let readers = borrow_flags.get_mut(&address).unwrap();
108+
109+
*readers -= 1;
110+
111+
if *readers == 0 {
112+
borrow_flags.remove(&address).unwrap();
113+
}
114+
}
115+
}
116+
117+
pub struct PyReadwriteArray<'py, T, D>(&'py PyArray<T, D>);
118+
119+
pub type PyReadwriteArrayDyn<'py, T> = PyReadwriteArray<'py, T, IxDyn>;
120+
121+
impl<'py, T, D> Deref for PyReadwriteArray<'py, T, D> {
122+
type Target = PyArray<T, D>;
123+
124+
fn deref(&self) -> &Self::Target {
125+
self.0
126+
}
127+
}
128+
129+
impl<'py, T: Element, D: Dimension> FromPyObject<'py> for PyReadwriteArray<'py, T, D> {
130+
fn extract(obj: &'py PyAny) -> PyResult<Self> {
131+
let array: &'py PyArray<T, D> = obj.extract()?;
132+
Ok(array.readwrite())
133+
}
134+
}
135+
136+
impl<'py, T, D> PyReadwriteArray<'py, T, D>
137+
where
138+
T: Element,
139+
D: Dimension,
140+
{
141+
pub(crate) fn try_new(array: &'py PyArray<T, D>) -> Result<Self, BorrowError> {
142+
if !array.check_flags(NPY_ARRAY_WRITEABLE) {
143+
return Err(BorrowError::NotWriteable);
144+
}
145+
146+
let address = base_address(array);
147+
148+
// SAFETY: Access to a `&'py PyArray<T, D>` implies holding the GIL
149+
// and we are not calling into user code which might re-enter this function.
150+
let borrow_flags = unsafe { BORROW_FLAGS.get() };
151+
152+
match borrow_flags.entry(address) {
153+
Entry::Occupied(entry) => {
154+
let writers = entry.into_mut();
155+
156+
if *writers != 0 {
157+
cold();
158+
return Err(BorrowError::AlreadyBorrowed);
159+
}
160+
161+
*writers = -1;
162+
}
163+
Entry::Vacant(entry) => {
164+
entry.insert(-1);
165+
}
166+
}
167+
168+
Ok(Self(array))
169+
}
170+
171+
pub fn as_array(&self) -> ArrayView<T, D> {
172+
// SAFETY: Thread-local borrow flags ensure aliasing discipline on this thread,
173+
// and `PyArray` is neither `Send` nor `Sync`
174+
unsafe { self.0.as_array() }
175+
}
176+
177+
pub fn as_slice(&self) -> Result<&[T], NotContiguousError> {
178+
// SAFETY: Thread-local borrow flags ensure aliasing discipline on this thread,
179+
// and `PyArray` is neither `Send` nor `Sync`
180+
unsafe { self.0.as_slice() }
181+
}
182+
183+
pub fn as_array_mut(&mut self) -> ArrayViewMut<T, D> {
184+
// SAFETY: Thread-local borrow flags ensure aliasing discipline on this thread,
185+
// and `PyArray` is neither `Send` nor `Sync`
186+
unsafe { self.0.as_array_mut() }
187+
}
188+
189+
pub fn as_slice_mut(&self) -> Result<&mut [T], NotContiguousError> {
190+
// SAFETY: Thread-local borrow flags ensure aliasing discipline on this thread,
191+
// and `PyArray` is neither `Send` nor `Sync`
192+
unsafe { self.0.as_slice_mut() }
193+
}
194+
}
195+
196+
impl<'a, T, D> Drop for PyReadwriteArray<'a, T, D> {
197+
fn drop(&mut self) {
198+
let address = base_address(self.0);
199+
200+
// SAFETY: Access to a `&'py PyArray<T, D>` implies holding the GIL
201+
// and we are not calling into user code which might re-enter this function.
202+
let borrow_flags = unsafe { BORROW_FLAGS.get() };
203+
204+
borrow_flags.remove(&address).unwrap();
205+
}
206+
}
207+
208+
// FIXME(adamreichold): This is a coarse approximation and needs to be refined,
209+
// i.e. borrows of non-overlapping views into the same base should not be considered conflicting.
210+
fn base_address<T, D>(array: &PyArray<T, D>) -> usize {
211+
let py = array.py();
212+
let mut array = array.as_array_ptr();
213+
214+
loop {
215+
let base = unsafe { (*array).base };
216+
217+
if base.is_null() {
218+
return unsafe { (*array).data } as usize;
219+
} else if unsafe { npyffi::PyArray_Check(py, base) } != 0 {
220+
array = base as *mut PyArrayObject;
221+
} else {
222+
return base as usize;
223+
}
224+
}
225+
}
226+
227+
#[cold]
228+
#[inline(always)]
229+
fn cold() {}

src/error.rs

+19
Original file line numberDiff line numberDiff line change
@@ -113,3 +113,22 @@ impl fmt::Display for NotContiguousError {
113113
}
114114

115115
impl_pyerr!(NotContiguousError);
116+
117+
/// Inidcates why borrowing an array failed.
118+
#[derive(Debug)]
119+
#[non_exhaustive]
120+
pub enum BorrowError {
121+
AlreadyBorrowed,
122+
NotWriteable,
123+
}
124+
125+
impl fmt::Display for BorrowError {
126+
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
127+
match self {
128+
Self::AlreadyBorrowed => write!(f, "The given array is already borrowed"),
129+
Self::NotWriteable => write!(f, "The given is not writeable"),
130+
}
131+
}
132+
}
133+
134+
impl_pyerr!(BorrowError);

0 commit comments

Comments
 (0)