Skip to content

Commit 542b587

Browse files
committed
WIP: Add dynamic borrow checking for dereferencing NumPy arrays.
1 parent 61882e3 commit 542b587

File tree

8 files changed

+221
-46
lines changed

8 files changed

+221
-46
lines changed

README.md

+11-14
Original file line numberDiff line numberDiff line change
@@ -44,40 +44,37 @@ numpy = "0.15"
4444
```
4545

4646
```rust
47-
use numpy::ndarray::{ArrayD, ArrayViewD, ArrayViewMutD};
48-
use numpy::{IntoPyArray, PyArrayDyn, PyReadonlyArrayDyn};
47+
use numpy::ndarray::{ArrayD, ArrayViewD, ArrayViewMutD, IxDyn};
48+
use numpy::{IntoPyArray, PyArrayDyn, PyArrayRef, PyArrayRefMut};
4949
use pyo3::prelude::{pymodule, PyModule, PyResult, Python};
5050

5151
#[pymodule]
5252
fn rust_ext(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
5353
// immutable example
54-
fn axpy(a: f64, x: ArrayViewD<'_, f64>, y: ArrayViewD<'_, f64>) -> ArrayD<f64> {
55-
a * &x + &y
54+
fn axpy(a: f64, x: &ArrayViewD<'_, f64>, y: &ArrayViewD<'_, f64>) -> ArrayD<f64> {
55+
a * x + y
5656
}
5757

5858
// mutable example (no return)
59-
fn mult(a: f64, mut x: ArrayViewMutD<'_, f64>) {
60-
x *= a;
59+
fn mult(a: f64, x: &mut ArrayViewMutD<'_, f64>) {
60+
*x *= a;
6161
}
6262

6363
// wrapper of `axpy`
6464
#[pyfn(m, "axpy")]
6565
fn axpy_py<'py>(
6666
py: Python<'py>,
6767
a: f64,
68-
x: PyReadonlyArrayDyn<f64>,
69-
y: PyReadonlyArrayDyn<f64>,
68+
x: PyArrayRef<f64, IxDyn>,
69+
y: PyArrayRef<f64, IxDyn>,
7070
) -> &'py PyArrayDyn<f64> {
71-
let x = x.as_array();
72-
let y = y.as_array();
73-
axpy(a, x, y).into_pyarray(py)
71+
axpy(a, &x, &y).into_pyarray(py)
7472
}
7573

7674
// wrapper of `mult`
7775
#[pyfn(m, "mult")]
78-
fn mult_py(_py: Python<'_>, a: f64, x: &PyArrayDyn<f64>) -> PyResult<()> {
79-
let x = unsafe { x.as_array_mut() };
80-
mult(a, x);
76+
fn mult_py(_py: Python<'_>, a: f64, mut x: PyArrayRefMut<f64, IxDyn>) -> PyResult<()> {
77+
mult(a, &mut x);
8178
Ok(())
8279
}
8380

examples/linalg/src/lib.rs

+2-3
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
11
use ndarray_linalg::solve::Inverse;
2-
use numpy::{IntoPyArray, PyArray2, PyReadonlyArray2};
2+
use numpy::{IntoPyArray, Ix2, PyArray2, PyArrayRef};
33
use pyo3::{exceptions::PyRuntimeError, pymodule, types::PyModule, PyErr, PyResult, Python};
44

55
#[pymodule]
66
fn rust_linalg(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
77
#[pyfn(m)]
8-
fn inv<'py>(py: Python<'py>, x: PyReadonlyArray2<'py, f64>) -> PyResult<&'py PyArray2<f64>> {
9-
let x = x.as_array();
8+
fn inv<'py>(py: Python<'py>, x: PyArrayRef<'py, f64, Ix2>) -> PyResult<&'py PyArray2<f64>> {
109
let y = x
1110
.inv()
1211
.map_err(|e| PyErr::new::<PyRuntimeError, _>(format!("[rust_linalg] {}", e)))?;

examples/parallel/src/lib.rs

+6-7
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,20 @@
11
// We need to link `blas_src` directly, c.f. https://github.com/rust-ndarray/ndarray#how-to-enable-blas-integration
22
extern crate blas_src;
33

4-
use ndarray::Zip;
5-
use numpy::{IntoPyArray, PyArray1, PyReadonlyArray1, PyReadonlyArray2};
4+
use numpy::ndarray::{ArrayView1, Zip};
5+
use numpy::{IntoPyArray, Ix1, Ix2, PyArray1, PyArrayRef};
66
use pyo3::{pymodule, types::PyModule, PyResult, Python};
77

88
#[pymodule]
99
fn rust_parallel(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
1010
#[pyfn(m)]
1111
fn rows_dot<'py>(
1212
py: Python<'py>,
13-
x: PyReadonlyArray2<'py, f64>,
14-
y: PyReadonlyArray1<'py, f64>,
13+
x: PyArrayRef<'py, f64, Ix2>,
14+
y: PyArrayRef<'py, f64, Ix1>,
1515
) -> &'py PyArray1<f64> {
16-
let x = x.as_array();
17-
let y = y.as_array();
18-
let z = Zip::from(x.rows()).par_map_collect(|row| row.dot(&y));
16+
let y: &ArrayView1<f64> = &y;
17+
let z = Zip::from(x.rows()).par_map_collect(|row| row.dot(y));
1918
z.into_pyarray(py)
2019
}
2120
Ok(())

examples/simple-extension/src/lib.rs

+9-9
Original file line numberDiff line numberDiff line change
@@ -9,17 +9,17 @@ use pyo3::{
99
#[pymodule]
1010
fn rust_ext(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
1111
// immutable example
12-
fn axpy(a: f64, x: ArrayViewD<'_, f64>, y: ArrayViewD<'_, f64>) -> ArrayD<f64> {
13-
a * &x + &y
12+
fn axpy(a: f64, x: &ArrayViewD<'_, f64>, y: &ArrayViewD<'_, f64>) -> ArrayD<f64> {
13+
a * x + y
1414
}
1515

1616
// mutable example (no return)
17-
fn mult(a: f64, mut x: ArrayViewMutD<'_, f64>) {
18-
x *= a;
17+
fn mult(a: f64, x: &mut ArrayViewMutD<'_, f64>) {
18+
*x *= a;
1919
}
2020

2121
// complex example
22-
fn conj(x: ArrayViewD<'_, Complex64>) -> ArrayD<Complex64> {
22+
fn conj(x: &ArrayViewD<'_, Complex64>) -> ArrayD<Complex64> {
2323
x.map(|c| c.conj())
2424
}
2525

@@ -34,16 +34,16 @@ fn rust_ext(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
3434
) -> &'py PyArrayDyn<f64> {
3535
let x = x.as_array();
3636
let y = y.as_array();
37-
let z = axpy(a, x, y);
37+
let z = axpy(a, &x, &y);
3838
z.into_pyarray(py)
3939
}
4040

4141
// wrapper of `mult`
4242
#[pyfn(m)]
4343
#[pyo3(name = "mult")]
4444
fn mult_py(a: f64, x: &PyArrayDyn<f64>) {
45-
let x = unsafe { x.as_array_mut() };
46-
mult(a, x);
45+
let mut x = x.as_array_mut();
46+
mult(a, &mut x);
4747
}
4848

4949
// wrapper of `conj`
@@ -53,7 +53,7 @@ fn rust_ext(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
5353
py: Python<'py>,
5454
x: PyReadonlyArrayDyn<'_, Complex64>,
5555
) -> &'py PyArrayDyn<Complex64> {
56-
conj(x.as_array()).into_pyarray(py)
56+
conj(&x.as_array()).into_pyarray(py)
5757
}
5858

5959
#[pyfn(m)]

src/array.rs

+20-12
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ use pyo3::{
1818
Python, ToPyObject,
1919
};
2020

21+
use crate::borrow::{PyArrayRef, PyArrayRefMut};
2122
use crate::convert::{ArrayExt, IntoPyArray, NpyIndex, ToNpyDims, ToPyArray};
2223
use crate::dtype::Element;
2324
use crate::error::{DimensionalityError, FromVecError, NotContiguousError, TypeError};
@@ -825,27 +826,34 @@ impl<T: Element, D: Dimension> PyArray<T, D> {
825826

826827
/// Get the immutable view of the internal data of `PyArray`, as
827828
/// [`ndarray::ArrayView`](https://docs.rs/ndarray/latest/ndarray/type.ArrayView.html).
828-
///
829-
/// Please consider the use of safe alternatives
830-
/// ([`PyReadonlyArray::as_array`](../struct.PyReadonlyArray.html#method.as_array)
831-
/// or [`to_array`](#method.to_array)) instead of this.
829+
pub fn as_array(&self) -> PyArrayRef<'_, T, D> {
830+
PyArrayRef::try_new(self).expect("NumPy array already borrowed")
831+
}
832+
833+
/// Get the immutable view of the internal data of `PyArray`, as
834+
/// [`ndarray::ArrayView`](https://docs.rs/ndarray/latest/ndarray/type.ArrayView.html).
835+
pub fn as_array_mut(&self) -> PyArrayRefMut<'_, T, D> {
836+
PyArrayRefMut::try_new(self).expect("NumPy array already borrowed")
837+
}
838+
839+
/// Returns the internal array as [`ArrayView`]. See also [`as_array_unchecked`].
832840
///
833841
/// # Safety
834-
/// If the internal array is not readonly and can be mutated from Python code,
835-
/// holding the `ArrayView` might cause undefined behavior.
836-
pub unsafe fn as_array(&self) -> ArrayView<'_, T, D> {
842+
///
843+
/// The existence of an exclusive reference to the internal data, e.g. `&mut [T]` or `ArrayViewMut`, implies undefined behavior.
844+
pub unsafe fn as_array_unchecked(&self) -> ArrayView<'_, T, D> {
837845
let (shape, ptr, inverted_axes) = self.ndarray_shape_ptr();
838846
let mut res = ArrayView::from_shape_ptr(shape, ptr);
839847
inverted_axes.invert(&mut res);
840848
res
841849
}
842850

843-
/// Returns the internal array as [`ArrayViewMut`]. See also [`as_array`](#method.as_array).
851+
/// Returns the internal array as [`ArrayViewMut`]. See also [`as_array_unchecked`].
844852
///
845853
/// # Safety
846-
/// If another reference to the internal data exists(e.g., `&[T]` or `ArrayView`),
847-
/// it might cause undefined behavior.
848-
pub unsafe fn as_array_mut(&self) -> ArrayViewMut<'_, T, D> {
854+
///
855+
/// The existence of another reference to the internal data, e.g. `&[T]` or `ArrayView`, implies undefined behavior.
856+
pub unsafe fn as_array_mut_unchecked(&self) -> ArrayViewMut<'_, T, D> {
849857
let (shape, ptr, inverted_axes) = self.ndarray_shape_ptr();
850858
let mut res = ArrayViewMut::from_shape_ptr(shape, ptr);
851859
inverted_axes.invert(&mut res);
@@ -884,7 +892,7 @@ impl<T: Element, D: Dimension> PyArray<T, D> {
884892
/// });
885893
/// ```
886894
pub fn to_owned_array(&self) -> Array<T, D> {
887-
unsafe { self.as_array() }.to_owned()
895+
unsafe { self.as_array_unchecked() }.to_owned()
888896
}
889897
}
890898

src/borrow.rs

+170
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,170 @@
1+
use std::cell::UnsafeCell;
2+
use std::collections::hash_map::{Entry, HashMap};
3+
use std::ops::{Deref, DerefMut};
4+
5+
use ndarray::{ArrayView, ArrayViewMut, Dimension};
6+
use pyo3::{FromPyObject, PyAny, PyResult};
7+
8+
use crate::array::PyArray;
9+
use crate::dtype::Element;
10+
11+
thread_local! {
12+
static BORROW_FLAGS: UnsafeCell<HashMap<usize, isize>> = UnsafeCell::new(HashMap::new());
13+
}
14+
15+
pub struct PyArrayRef<'a, T, D> {
16+
array: &'a PyArray<T, D>,
17+
view: ArrayView<'a, T, D>,
18+
}
19+
20+
impl<'a, T, D> Deref for PyArrayRef<'a, T, D> {
21+
type Target = ArrayView<'a, T, D>;
22+
23+
fn deref(&self) -> &Self::Target {
24+
&self.view
25+
}
26+
}
27+
28+
impl<'py, T: Element, D: Dimension> FromPyObject<'py> for PyArrayRef<'py, T, D> {
29+
fn extract(obj: &'py PyAny) -> PyResult<Self> {
30+
let array: &'py PyArray<T, D> = obj.extract()?;
31+
Ok(array.as_array())
32+
}
33+
}
34+
35+
impl<'a, T, D> PyArrayRef<'a, T, D>
36+
where
37+
T: Element,
38+
D: Dimension,
39+
{
40+
pub(crate) fn try_new(array: &'a PyArray<T, D>) -> Option<Self> {
41+
let address = array as *const PyArray<T, D> as usize;
42+
43+
BORROW_FLAGS.with(|borrow_flags| {
44+
// SAFETY: Called on a thread local variable in a leaf function.
45+
let borrow_flags = unsafe { &mut *borrow_flags.get() };
46+
47+
match borrow_flags.entry(address) {
48+
Entry::Occupied(entry) => {
49+
let readers = entry.into_mut();
50+
51+
let new_readers = readers.wrapping_add(1);
52+
53+
if new_readers <= 0 {
54+
cold();
55+
return None;
56+
}
57+
58+
*readers = new_readers;
59+
}
60+
Entry::Vacant(entry) => {
61+
entry.insert(1);
62+
}
63+
}
64+
65+
// SAFETY: Thread-local borrow flags ensure aliasing discipline on this thread,
66+
// and `PyArray` is neither `Send` nor `Sync`
67+
let view = unsafe { array.as_array_unchecked() };
68+
69+
Some(Self { array, view })
70+
})
71+
}
72+
}
73+
74+
impl<'a, T, D> Drop for PyArrayRef<'a, T, D> {
75+
fn drop(&mut self) {
76+
let address = self.array as *const PyArray<T, D> as usize;
77+
78+
BORROW_FLAGS.with(|borrow_flags| {
79+
// SAFETY: Called on a thread local variable in a leaf function.
80+
let borrow_flags = unsafe { &mut *borrow_flags.get() };
81+
82+
let readers = borrow_flags.get_mut(&address).unwrap();
83+
84+
*readers -= 1;
85+
86+
if *readers == 0 {
87+
borrow_flags.remove(&address).unwrap();
88+
}
89+
});
90+
}
91+
}
92+
93+
pub struct PyArrayRefMut<'a, T, D> {
94+
array: &'a PyArray<T, D>,
95+
view: ArrayViewMut<'a, T, D>,
96+
}
97+
98+
impl<'a, T, D> Deref for PyArrayRefMut<'a, T, D> {
99+
type Target = ArrayViewMut<'a, T, D>;
100+
101+
fn deref(&self) -> &Self::Target {
102+
&self.view
103+
}
104+
}
105+
106+
impl<'a, T, D> DerefMut for PyArrayRefMut<'a, T, D> {
107+
fn deref_mut(&mut self) -> &mut Self::Target {
108+
&mut self.view
109+
}
110+
}
111+
112+
impl<'py, T: Element, D: Dimension> FromPyObject<'py> for PyArrayRefMut<'py, T, D> {
113+
fn extract(obj: &'py PyAny) -> PyResult<Self> {
114+
let array: &'py PyArray<T, D> = obj.extract()?;
115+
Ok(array.as_array_mut())
116+
}
117+
}
118+
119+
impl<'a, T, D> PyArrayRefMut<'a, T, D>
120+
where
121+
T: Element,
122+
D: Dimension,
123+
{
124+
pub(crate) fn try_new(array: &'a PyArray<T, D>) -> Option<Self> {
125+
let address = array as *const PyArray<T, D> as usize;
126+
127+
BORROW_FLAGS.with(|borrow_flags| {
128+
// SAFETY: Called on a thread local variable in a leaf function.
129+
let borrow_flags = unsafe { &mut *borrow_flags.get() };
130+
131+
match borrow_flags.entry(address) {
132+
Entry::Occupied(entry) => {
133+
let writers = entry.into_mut();
134+
135+
if *writers != 0 {
136+
cold();
137+
return None;
138+
}
139+
140+
*writers = -1;
141+
}
142+
Entry::Vacant(entry) => {
143+
entry.insert(-1);
144+
}
145+
}
146+
147+
// SAFETY: Thread-local borrow flags ensure aliasing discipline on this thread,
148+
// and `PyArray` is neither `Send` nor `Sync`
149+
let view = unsafe { array.as_array_mut_unchecked() };
150+
151+
Some(Self { array, view })
152+
})
153+
}
154+
}
155+
156+
impl<'a, T, D> Drop for PyArrayRefMut<'a, T, D> {
157+
fn drop(&mut self) {
158+
let address = self.array as *const PyArray<T, D> as usize;
159+
160+
BORROW_FLAGS.with(|borrow_flags| {
161+
// SAFETY: Called on a thread local variable in a leaf function.
162+
let borrow_flags = unsafe { &mut *borrow_flags.get() };
163+
164+
borrow_flags.remove(&address).unwrap();
165+
});
166+
}
167+
}
168+
#[cold]
169+
#[inline(always)]
170+
fn cold() {}

src/lib.rs

+2
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
#![allow(clippy::needless_lifetimes)] // We often want to make the GIL lifetime explicit.
3131

3232
pub mod array;
33+
mod borrow;
3334
pub mod convert;
3435
mod dtype;
3536
mod error;
@@ -46,6 +47,7 @@ pub use crate::array::{
4647
get_array_module, PyArray, PyArray0, PyArray1, PyArray2, PyArray3, PyArray4, PyArray5,
4748
PyArray6, PyArrayDyn,
4849
};
50+
pub use crate::borrow::{PyArrayRef, PyArrayRefMut};
4951
pub use crate::convert::{IntoPyArray, NpyIndex, ToNpyDims, ToPyArray};
5052
pub use crate::dtype::{dtype, Complex32, Complex64, Element, PyArrayDescr};
5153
pub use crate::error::{DimensionalityError, FromVecError, NotContiguousError, TypeError};

0 commit comments

Comments
 (0)