Skip to content

Commit 8a072e4

Browse files
committed
WIP: Add dynamic borrow checking for dereferencing NumPy arrays.
1 parent 983514d commit 8a072e4

File tree

9 files changed

+758
-391
lines changed

9 files changed

+758
-391
lines changed

examples/simple/src/lib.rs

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
use numpy::ndarray::{ArrayD, ArrayViewD, ArrayViewMutD};
2-
use numpy::{Complex64, IntoPyArray, PyArray1, PyArrayDyn, PyReadonlyArrayDyn};
2+
use numpy::{
3+
Complex64, IntoPyArray, PyArray1, PyArrayDyn, PyReadonlyArrayDyn, PyReadwriteArrayDyn,
4+
};
35
use pyo3::{
46
pymodule,
57
types::{PyDict, PyModule},
@@ -41,8 +43,8 @@ fn rust_ext(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
4143
// wrapper of `mult`
4244
#[pyfn(m)]
4345
#[pyo3(name = "mult")]
44-
fn mult_py(a: f64, x: &PyArrayDyn<f64>) {
45-
let x = unsafe { x.as_array_mut() };
46+
fn mult_py(a: f64, mut x: PyReadwriteArrayDyn<f64>) {
47+
let x = x.as_array_mut();
4648
mult(a, x);
4749
}
4850

src/array.rs

Lines changed: 26 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,11 @@ use pyo3::{
1919
Python, ToPyObject,
2020
};
2121

22+
use crate::borrow::{PyReadonlyArray, PyReadwriteArray};
2223
use crate::convert::{ArrayExt, IntoPyArray, NpyIndex, ToNpyDims, ToPyArray};
2324
use crate::dtype::{Element, PyArrayDescr};
2425
use crate::error::{DimensionalityError, FromVecError, NotContiguousError, TypeError};
2526
use crate::npyffi::{self, npy_intp, NPY_ORDER, PY_ARRAY_API};
26-
#[allow(deprecated)]
27-
use crate::npyiter::{NpySingleIter, NpySingleIterBuilder, ReadWrite};
28-
use crate::readonly::PyReadonlyArray;
2927
use crate::slice_container::PySliceContainer;
3028

3129
/// A safe, static-typed interface for
@@ -194,18 +192,8 @@ impl<T, D> PyArray<T, D> {
194192
}
195193

196194
#[inline(always)]
197-
fn check_flag(&self, flag: c_int) -> bool {
198-
unsafe { *self.as_array_ptr() }.flags & flag == flag
199-
}
200-
201-
#[inline(always)]
202-
pub(crate) fn get_flag(&self) -> c_int {
203-
unsafe { *self.as_array_ptr() }.flags
204-
}
205-
206-
/// Returns a temporally unwriteable reference of the array.
207-
pub fn readonly(&self) -> PyReadonlyArray<T, D> {
208-
self.into()
195+
pub(crate) fn check_flags(&self, flags: c_int) -> bool {
196+
unsafe { *self.as_array_ptr() }.flags & flags != 0
209197
}
210198

211199
/// Returns `true` if the internal data of the array is C-style contiguous
@@ -227,18 +215,17 @@ impl<T, D> PyArray<T, D> {
227215
/// });
228216
/// ```
229217
pub fn is_contiguous(&self) -> bool {
230-
self.check_flag(npyffi::NPY_ARRAY_C_CONTIGUOUS)
231-
| self.check_flag(npyffi::NPY_ARRAY_F_CONTIGUOUS)
218+
self.check_flags(npyffi::NPY_ARRAY_C_CONTIGUOUS | npyffi::NPY_ARRAY_F_CONTIGUOUS)
232219
}
233220

234221
/// Returns `true` if the internal data of the array is Fortran-style contiguous.
235222
pub fn is_fortran_contiguous(&self) -> bool {
236-
self.check_flag(npyffi::NPY_ARRAY_F_CONTIGUOUS)
223+
self.check_flags(npyffi::NPY_ARRAY_F_CONTIGUOUS)
237224
}
238225

239226
/// Returns `true` if the internal data of the array is C-style contiguous.
240227
pub fn is_c_contiguous(&self) -> bool {
241-
self.check_flag(npyffi::NPY_ARRAY_C_CONTIGUOUS)
228+
self.check_flags(npyffi::NPY_ARRAY_C_CONTIGUOUS)
242229
}
243230

244231
/// Get `Py<PyArray>` from `&PyArray`, which is the owned wrapper of PyObject.
@@ -827,28 +814,37 @@ impl<T: Element, D: Dimension> PyArray<T, D> {
827814
ToPyArray::to_pyarray(arr, py)
828815
}
829816

830-
/// Get the immutable view of the internal data of `PyArray`, as
831-
/// [`ndarray::ArrayView`](https://docs.rs/ndarray/latest/ndarray/type.ArrayView.html).
817+
/// Get an immutable borrow of the NumPy array
818+
pub fn readonly(&self) -> PyReadonlyArray<'_, T, D> {
819+
PyReadonlyArray::try_new(self).unwrap()
820+
}
821+
822+
/// Get a mutable borrow of the NumPy array
823+
pub fn readwrite(&self) -> PyReadwriteArray<'_, T, D> {
824+
PyReadwriteArray::try_new(self).unwrap()
825+
}
826+
827+
/// Returns the internal array as [`ArrayView`].
832828
///
833-
/// Please consider the use of safe alternatives
834-
/// ([`PyReadonlyArray::as_array`](../struct.PyReadonlyArray.html#method.as_array)
835-
/// or [`to_array`](#method.to_array)) instead of this.
829+
/// See also [`PyArrayRef::as_array`].
836830
///
837831
/// # Safety
838-
/// If the internal array is not readonly and can be mutated from Python code,
839-
/// holding the `ArrayView` might cause undefined behavior.
832+
///
833+
/// The existence of an exclusive reference to the internal data, e.g. `&mut [T]` or `ArrayViewMut`, implies undefined behavior.
840834
pub unsafe fn as_array(&self) -> ArrayView<'_, T, D> {
841835
let (shape, ptr, inverted_axes) = self.ndarray_shape_ptr();
842836
let mut res = ArrayView::from_shape_ptr(shape, ptr);
843837
inverted_axes.invert(&mut res);
844838
res
845839
}
846840

847-
/// Returns the internal array as [`ArrayViewMut`]. See also [`as_array`](#method.as_array).
841+
/// Returns the internal array as [`ArrayViewMut`].
842+
///
843+
/// See also [`PyArrayRefMut::as_array_mut`].
848844
///
849845
/// # Safety
850-
/// If another reference to the internal data exists(e.g., `&[T]` or `ArrayView`),
851-
/// it might cause undefined behavior.
846+
///
847+
/// The existence of another reference to the internal data, e.g. `&[T]` or `ArrayView`, implies undefined behavior.
852848
pub unsafe fn as_array_mut(&self) -> ArrayViewMut<'_, T, D> {
853849
let (shape, ptr, inverted_axes) = self.ndarray_shape_ptr();
854850
let mut res = ArrayViewMut::from_shape_ptr(shape, ptr);
@@ -924,7 +920,7 @@ impl<D: Dimension> PyArray<PyObject, D> {
924920
///
925921
/// let pyarray = PyArray::from_owned_object_array(py, array);
926922
///
927-
/// assert!(pyarray.readonly().get(0).unwrap().as_ref(py).is_instance_of::<CustomElement>().unwrap());
923+
/// assert!(pyarray.readonly().as_array().get(0).unwrap().as_ref(py).is_instance_of::<CustomElement>().unwrap());
928924
/// });
929925
/// ```
930926
pub fn from_owned_object_array<'py, T>(py: Python<'py>, arr: Array<Py<T>, D>) -> &'py Self {
@@ -1073,21 +1069,6 @@ impl<T: Element> PyArray<T, Ix1> {
10731069
self.resize_(self.py(), [new_elems], 1, NPY_ORDER::NPY_ANYORDER)
10741070
}
10751071

1076-
/// Iterates all elements of this array.
1077-
/// See [NpySingleIter](../npyiter/struct.NpySingleIter.html) for more.
1078-
///
1079-
/// # Safety
1080-
///
1081-
/// The iterator will produce mutable references into the array which must not be
1082-
/// aliased by other references for the life time of the iterator.
1083-
#[deprecated(
1084-
note = "The wrappers of the array iterator API are deprecated, please use ndarray's `ArrayBase::iter_mut` instead."
1085-
)]
1086-
#[allow(deprecated)]
1087-
pub unsafe fn iter<'py>(&'py self) -> PyResult<NpySingleIter<'py, T, ReadWrite>> {
1088-
NpySingleIterBuilder::readwrite(self).build()
1089-
}
1090-
10911072
fn resize_<D: IntoDimension>(
10921073
&self,
10931074
py: Python,

0 commit comments

Comments
 (0)