Skip to content

Commit e317dc9

Browse files
committed
Mark the read-write variants of the NumPy iterators unsafe for the same reason that PyArray::as_array is unsafe.
1 parent bff49a4 commit e317dc9

File tree

3 files changed

+39
-22
lines changed

3 files changed

+39
-22
lines changed

src/array.rs

+10-6
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,10 @@ use pyo3::{
2020
};
2121

2222
use crate::convert::{ArrayExt, IntoPyArray, NpyIndex, ToNpyDims, ToPyArray};
23-
use crate::dtype::Element;
23+
use crate::dtype::{Element, PyArrayDescr};
2424
use crate::error::{DimensionalityError, FromVecError, NotContiguousError, TypeError};
2525
use crate::npyffi::{self, npy_intp, NPY_ORDER, PY_ARRAY_API};
26+
use crate::npyiter::{NpySingleIter, NpySingleIterBuilder, ReadWrite};
2627
use crate::readonly::PyReadonlyArray;
2728
use crate::slice_container::PySliceContainer;
2829

@@ -186,7 +187,7 @@ impl<T, D> PyArray<T, D> {
186187
/// assert!(dtype.is_equiv_to(numpy::dtype::<i32>(py)));
187188
/// });
188189
/// ```
189-
pub fn dtype(&self) -> &crate::PyArrayDescr {
190+
pub fn dtype(&self) -> &PyArrayDescr {
190191
let descr_ptr = unsafe { (*self.as_array_ptr()).descr };
191192
unsafe { pyo3::FromPyPointer::from_borrowed_ptr(self.py(), descr_ptr as _) }
192193
}
@@ -1073,10 +1074,13 @@ impl<T: Element> PyArray<T, Ix1> {
10731074

10741075
/// Iterates all elements of this array.
10751076
/// See [NpySingleIter](../npyiter/struct.NpySingleIter.html) for more.
1076-
pub fn iter<'py>(
1077-
&'py self,
1078-
) -> PyResult<crate::NpySingleIter<'py, T, crate::npyiter::ReadWrite>> {
1079-
crate::NpySingleIterBuilder::readwrite(self).build()
1077+
///
1078+
/// # Safety
1079+
///
1080+
/// The iterator will produce mutable references into the array which must not be
1081+
/// aliased by other references for the life time of the iterator.
1082+
pub unsafe fn iter<'py>(&'py self) -> PyResult<NpySingleIter<'py, T, ReadWrite>> {
1083+
NpySingleIterBuilder::readwrite(self).build()
10801084
}
10811085

10821086
fn resize_<D: IntoDimension>(

src/npyiter.rs

+22-11
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,12 @@ impl<'py, T: Element> NpySingleIterBuilder<'py, T, Readonly> {
174174

175175
impl<'py, T: Element> NpySingleIterBuilder<'py, T, ReadWrite> {
176176
/// Makes a new builder for a writable iterator.
177-
pub fn readwrite<D: ndarray::Dimension>(array: &'py PyArray<T, D>) -> Self {
177+
///
178+
/// # Safety
179+
///
180+
/// The iterator will produce mutable references into the array which must not be
181+
/// aliased by other references for the life time of the iterator.
182+
pub unsafe fn readwrite<D: ndarray::Dimension>(array: &'py PyArray<T, D>) -> Self {
178183
Self {
179184
flags: NPY_ITER_READWRITE,
180185
array: array.to_dyn(),
@@ -230,7 +235,7 @@ impl<'py, T: Element, I: IterMode> NpySingleIterBuilder<'py, T, I> {
230235
/// use numpy::NpySingleIterBuilder;
231236
/// pyo3::Python::with_gil(|py| {
232237
/// let array = numpy::PyArray::arange(py, 0, 10, 1);
233-
/// let iter = NpySingleIterBuilder::readwrite(array).build().unwrap();
238+
/// let iter = unsafe { NpySingleIterBuilder::readwrite(array).build().unwrap() };
234239
/// for (i, elem) in iter.enumerate() {
235240
/// assert_eq!(*elem, i as i64);
236241
/// *elem = *elem * 2; // elements are mutable
@@ -242,8 +247,7 @@ impl<'py, T: Element, I: IterMode> NpySingleIterBuilder<'py, T, I> {
242247
/// # use numpy::NpySingleIterBuilder;
243248
/// # pyo3::Python::with_gil(|py| {
244249
/// # let array = numpy::PyArray::arange(py, 0, 10, 1);
245-
/// # let iter = NpySingleIterBuilder::readwrite(array).build().unwrap();
246-
/// for (i, elem) in array.iter().unwrap().enumerate() {
250+
/// for (i, elem) in unsafe { array.iter().unwrap().enumerate() } {
247251
/// assert_eq!(*elem, i as i64);
248252
/// *elem = *elem * 2; // elements are mutable
249253
/// }
@@ -416,7 +420,12 @@ impl<'py, T: Element, S: MultiIterMode> NpyMultiIterBuilder<'py, T, S> {
416420
}
417421

418422
/// Adds a writable array to the resulting iterator.
419-
pub fn add_readwrite<D: ndarray::Dimension>(
423+
///
424+
/// # Safety
425+
///
426+
/// The iterator will produce mutable references into the array which must not be
427+
/// aliased by other references for the life time of the iterator.
428+
pub unsafe fn add_readwrite<D: ndarray::Dimension>(
420429
mut self,
421430
array: &'py PyArray<T, D>,
422431
) -> NpyMultiIterBuilder<'py, T, RW<S>> {
@@ -483,12 +492,14 @@ impl<'py, T: Element, S: MultiIterModeWithManyArrays> NpyMultiIterBuilder<'py, T
483492
/// let array1 = numpy::PyArray::arange(py, 0, 10, 1);
484493
/// let array2 = numpy::PyArray::arange(py, 10, 20, 1);
485494
/// let array3 = numpy::PyArray::arange(py, 10, 30, 2);
486-
/// let iter = NpyMultiIterBuilder::new()
487-
/// .add_readonly(array1.readonly())
488-
/// .add_readwrite(array2)
489-
/// .add_readonly(array3.readonly())
490-
/// .build()
491-
/// .unwrap();
495+
/// let iter = unsafe {
496+
/// NpyMultiIterBuilder::new()
497+
/// .add_readonly(array1.readonly())
498+
/// .add_readwrite(array2)
499+
/// .add_readonly(array3.readonly())
500+
/// .build()
501+
/// .unwrap()
502+
/// };
492503
/// for (i, j, k) in iter {
493504
/// assert_eq!(*i + *j, *k);
494505
/// *j += *i + *k; // The third element is only mutable.

tests/iter.rs

+7-5
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ fn mutable_iter() -> PyResult<()> {
2828
let data = array![[0.0, 1.0], [2.0, 3.0], [4.0, 5.0]];
2929
pyo3::Python::with_gil(|py| {
3030
let arr = PyArray::from_array(py, &data);
31-
let iter = NpySingleIterBuilder::readwrite(arr).build()?;
31+
let iter = unsafe { NpySingleIterBuilder::readwrite(arr).build()? };
3232
for elem in iter {
3333
*elem *= 2.0;
3434
}
@@ -71,10 +71,12 @@ fn multiiter_rw() -> PyResult<()> {
7171
pyo3::Python::with_gil(|py| {
7272
let arr1 = PyArray::from_array(py, &data1);
7373
let arr2 = PyArray::from_array(py, &data2);
74-
let iter = NpyMultiIterBuilder::new()
75-
.add_readonly(arr1.readonly())
76-
.add_readwrite(arr2)
77-
.build()?;
74+
let iter = unsafe {
75+
NpyMultiIterBuilder::new()
76+
.add_readonly(arr1.readonly())
77+
.add_readwrite(arr2)
78+
.build()?
79+
};
7880

7981
for (x, y) in iter {
8082
*y = *x * 2.0;

0 commit comments

Comments
 (0)