Skip to content

RFC: Mark the read-write variants of the NumPy iterators unsafe #279

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 3, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
- Support borrowing arrays that are part of other Python objects via `PyArray::borrow_from_array` ([#230](https://github.com/PyO3/rust-numpy/pull/216))
- Fixed downcasting ignoring element type and dimensionality ([#265](https://github.com/PyO3/rust-numpy/pull/265))
- `PyArray::new` is now `unsafe`, as it produces uninitialized arrays ([#220](https://github.com/PyO3/rust-numpy/pull/220))
- `PyArray::iter`, `NpySingleIterBuilder::readwrite` and `NpyMultiIterBuilder::add_readwrite` are now `unsafe`, as they allow aliasing mutable references to be created ([#278/](https://github.com/PyO3/rust-numpy/pull/278))
- `PyArray::from_exact_iter` does not unsoundly trust `ExactSizeIterator::len` any more ([#262](https://github.com/PyO3/rust-numpy/pull/262))
- `PyArray::as_cell_slice` was removed as it unsoundly interacts with `PyReadonlyArray` allowing safe code to violate aliasing rules ([#260](https://github.com/PyO3/rust-numpy/pull/260))
- `rayon` feature is now removed, and directly specifying the feature via `ndarray` dependency is recommended ([#250](https://github.com/PyO3/rust-numpy/pull/250))
Expand All @@ -19,10 +20,8 @@
- `i32`, `i64`, `u32`, `u64` are now guaranteed to map to `np.u?int{32,64}`.
- Removed `cfg_if` dependency
- Removed `DataType` enum
- Added `PyArrayDescr::new` constructor
([#266](https://github.com/PyO3/rust-numpy/pull/266))
- New `PyArrayDescr` methods
([#266](https://github.com/PyO3/rust-numpy/pull/261)):
- Added `PyArrayDescr::new` constructor ([#266](https://github.com/PyO3/rust-numpy/pull/266))
- New `PyArrayDescr` methods ([#266](https://github.com/PyO3/rust-numpy/pull/261)):
- `num`, `base`, `ndim`, `shape`, `byteorder`, `char`, `kind`, `itemsize`,
`alignment`, `flags`, `has_object`, `is_aligned_struct`, `names`,
`get_field`, `has_subarray`, `has_fields`, `is_native_byteorder`
Expand Down
16 changes: 10 additions & 6 deletions src/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,10 @@ use pyo3::{
};

use crate::convert::{ArrayExt, IntoPyArray, NpyIndex, ToNpyDims, ToPyArray};
use crate::dtype::Element;
use crate::dtype::{Element, PyArrayDescr};
use crate::error::{DimensionalityError, FromVecError, NotContiguousError, TypeError};
use crate::npyffi::{self, npy_intp, NPY_ORDER, PY_ARRAY_API};
use crate::npyiter::{NpySingleIter, NpySingleIterBuilder, ReadWrite};
use crate::readonly::PyReadonlyArray;
use crate::slice_container::PySliceContainer;

Expand Down Expand Up @@ -186,7 +187,7 @@ impl<T, D> PyArray<T, D> {
/// assert!(dtype.is_equiv_to(numpy::dtype::<i32>(py)));
/// });
/// ```
pub fn dtype(&self) -> &crate::PyArrayDescr {
pub fn dtype(&self) -> &PyArrayDescr {
let descr_ptr = unsafe { (*self.as_array_ptr()).descr };
unsafe { pyo3::FromPyPointer::from_borrowed_ptr(self.py(), descr_ptr as _) }
}
Expand Down Expand Up @@ -1073,10 +1074,13 @@ impl<T: Element> PyArray<T, Ix1> {

/// Iterates all elements of this array.
/// See [NpySingleIter](../npyiter/struct.NpySingleIter.html) for more.
pub fn iter<'py>(
&'py self,
) -> PyResult<crate::NpySingleIter<'py, T, crate::npyiter::ReadWrite>> {
crate::NpySingleIterBuilder::readwrite(self).build()
///
/// # Safety
///
/// The iterator will produce mutable references into the array which must not be
/// aliased by other references for the life time of the iterator.
pub unsafe fn iter<'py>(&'py self) -> PyResult<NpySingleIter<'py, T, ReadWrite>> {
NpySingleIterBuilder::readwrite(self).build()
}

fn resize_<D: IntoDimension>(
Expand Down
33 changes: 22 additions & 11 deletions src/npyiter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,12 @@ impl<'py, T: Element> NpySingleIterBuilder<'py, T, Readonly> {

impl<'py, T: Element> NpySingleIterBuilder<'py, T, ReadWrite> {
/// Makes a new builder for a writable iterator.
pub fn readwrite<D: ndarray::Dimension>(array: &'py PyArray<T, D>) -> Self {
///
/// # Safety
///
/// The iterator will produce mutable references into the array which must not be
/// aliased by other references for the life time of the iterator.
pub unsafe fn readwrite<D: ndarray::Dimension>(array: &'py PyArray<T, D>) -> Self {
Self {
flags: NPY_ITER_READWRITE,
array: array.to_dyn(),
Expand Down Expand Up @@ -230,7 +235,7 @@ impl<'py, T: Element, I: IterMode> NpySingleIterBuilder<'py, T, I> {
/// use numpy::NpySingleIterBuilder;
/// pyo3::Python::with_gil(|py| {
/// let array = numpy::PyArray::arange(py, 0, 10, 1);
/// let iter = NpySingleIterBuilder::readwrite(array).build().unwrap();
/// let iter = unsafe { NpySingleIterBuilder::readwrite(array).build().unwrap() };
/// for (i, elem) in iter.enumerate() {
/// assert_eq!(*elem, i as i64);
/// *elem = *elem * 2; // elements are mutable
Expand All @@ -242,8 +247,7 @@ impl<'py, T: Element, I: IterMode> NpySingleIterBuilder<'py, T, I> {
/// # use numpy::NpySingleIterBuilder;
/// # pyo3::Python::with_gil(|py| {
/// # let array = numpy::PyArray::arange(py, 0, 10, 1);
/// # let iter = NpySingleIterBuilder::readwrite(array).build().unwrap();
/// for (i, elem) in array.iter().unwrap().enumerate() {
/// for (i, elem) in unsafe { array.iter().unwrap().enumerate() } {
/// assert_eq!(*elem, i as i64);
/// *elem = *elem * 2; // elements are mutable
/// }
Expand Down Expand Up @@ -416,7 +420,12 @@ impl<'py, T: Element, S: MultiIterMode> NpyMultiIterBuilder<'py, T, S> {
}

/// Adds a writable array to the resulting iterator.
pub fn add_readwrite<D: ndarray::Dimension>(
///
/// # Safety
///
/// The iterator will produce mutable references into the array which must not be
/// aliased by other references for the life time of the iterator.
pub unsafe fn add_readwrite<D: ndarray::Dimension>(
mut self,
array: &'py PyArray<T, D>,
) -> NpyMultiIterBuilder<'py, T, RW<S>> {
Expand Down Expand Up @@ -483,12 +492,14 @@ impl<'py, T: Element, S: MultiIterModeWithManyArrays> NpyMultiIterBuilder<'py, T
/// let array1 = numpy::PyArray::arange(py, 0, 10, 1);
/// let array2 = numpy::PyArray::arange(py, 10, 20, 1);
/// let array3 = numpy::PyArray::arange(py, 10, 30, 2);
/// let iter = NpyMultiIterBuilder::new()
/// .add_readonly(array1.readonly())
/// .add_readwrite(array2)
/// .add_readonly(array3.readonly())
/// .build()
/// .unwrap();
/// let iter = unsafe {
/// NpyMultiIterBuilder::new()
/// .add_readonly(array1.readonly())
/// .add_readwrite(array2)
/// .add_readonly(array3.readonly())
/// .build()
/// .unwrap()
/// };
/// for (i, j, k) in iter {
/// assert_eq!(*i + *j, *k);
/// *j += *i + *k; // The third element is only mutable.
Expand Down
12 changes: 7 additions & 5 deletions tests/iter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ fn mutable_iter() -> PyResult<()> {
let data = array![[0.0, 1.0], [2.0, 3.0], [4.0, 5.0]];
pyo3::Python::with_gil(|py| {
let arr = PyArray::from_array(py, &data);
let iter = NpySingleIterBuilder::readwrite(arr).build()?;
let iter = unsafe { NpySingleIterBuilder::readwrite(arr).build()? };
for elem in iter {
*elem *= 2.0;
}
Expand Down Expand Up @@ -71,10 +71,12 @@ fn multiiter_rw() -> PyResult<()> {
pyo3::Python::with_gil(|py| {
let arr1 = PyArray::from_array(py, &data1);
let arr2 = PyArray::from_array(py, &data2);
let iter = NpyMultiIterBuilder::new()
.add_readonly(arr1.readonly())
.add_readwrite(arr2)
.build()?;
let iter = unsafe {
NpyMultiIterBuilder::new()
.add_readonly(arr1.readonly())
.add_readwrite(arr2)
.build()?
};

for (x, y) in iter {
*y = *x * 2.0;
Expand Down