Skip to content

Commit 8caea5d

Browse files
committed
Fix incorrect impl of Iterator::size_hint and add impl of ExactSizeIterator the NumPy iterator wrappers.
1 parent dc00e40 commit 8caea5d

File tree

2 files changed

+78
-13
lines changed

2 files changed

+78
-13
lines changed

src/npyiter.rs

+31-11
Original file line numberDiff line numberDiff line change
@@ -282,7 +282,6 @@ impl<'py, T: Element, I: IterMode> NpySingleIterBuilder<'py, T, I> {
282282
pub struct NpySingleIter<'py, T, I> {
283283
iterator: ptr::NonNull<NpyIter>,
284284
iternext: unsafe extern "C" fn(*mut NpyIter) -> c_int,
285-
empty: bool,
286285
iter_size: npy_intp,
287286
dataptr: *mut *mut c_char,
288287
return_type: PhantomData<T>,
@@ -321,7 +320,6 @@ impl<'py, T, I> NpySingleIter<'py, T, I> {
321320
iterator,
322321
iternext,
323322
iter_size,
324-
empty: iter_size == 0,
325323
dataptr,
326324
return_type: PhantomData,
327325
mode: PhantomData,
@@ -331,15 +329,18 @@ impl<'py, T, I> NpySingleIter<'py, T, I> {
331329
}
332330

333331
fn iternext(&mut self) -> Option<*mut T> {
334-
if self.empty {
332+
if self.iter_size == 0 {
335333
None
336334
} else {
337335
// Note: This pointer is correct and doesn't need to be updated,
338336
// note that we're derefencing a **char into a *char casting to a *T
339337
// and then transforming that into a reference, the value that dataptr
340338
// points to is being updated by iternext to point to the next value.
341339
let ret = unsafe { *self.dataptr as *mut T };
342-
self.empty = unsafe { (self.iternext)(self.iterator.as_mut()) } == 0;
340+
let empty = unsafe { (self.iternext)(self.iterator.as_mut()) } == 0;
341+
debug_assert_ne!(self.iter_size, 0);
342+
self.iter_size -= 1;
343+
debug_assert!(self.iter_size > 0 || empty);
343344
Some(ret)
344345
}
345346
}
@@ -365,7 +366,13 @@ impl<'py, T: 'py> Iterator for NpySingleIter<'py, T, Readonly> {
365366
}
366367

367368
fn size_hint(&self) -> (usize, Option<usize>) {
368-
(self.iter_size as usize, Some(self.iter_size as usize))
369+
(self.len(), Some(self.len()))
370+
}
371+
}
372+
373+
impl<'py, T: 'py> ExactSizeIterator for NpySingleIter<'py, T, Readonly> {
374+
fn len(&self) -> usize {
375+
self.iter_size as usize
369376
}
370377
}
371378

@@ -377,7 +384,13 @@ impl<'py, T: 'py> Iterator for NpySingleIter<'py, T, ReadWrite> {
377384
}
378385

379386
fn size_hint(&self) -> (usize, Option<usize>) {
380-
(self.iter_size as usize, Some(self.iter_size as usize))
387+
(self.len(), Some(self.len()))
388+
}
389+
}
390+
391+
impl<'py, T: 'py> ExactSizeIterator for NpySingleIter<'py, T, ReadWrite> {
392+
fn len(&self) -> usize {
393+
self.iter_size as usize
381394
}
382395
}
383396

@@ -525,7 +538,6 @@ impl<'py, T: Element, S: MultiIterModeWithManyArrays> NpyMultiIterBuilder<'py, T
525538
pub struct NpyMultiIter<'py, T, S: MultiIterModeWithManyArrays> {
526539
iterator: ptr::NonNull<NpyIter>,
527540
iternext: unsafe extern "C" fn(*mut NpyIter) -> c_int,
528-
empty: bool,
529541
iter_size: npy_intp,
530542
dataptr: *mut *mut c_char,
531543
marker: PhantomData<(T, S)>,
@@ -565,7 +577,6 @@ impl<'py, T, S: MultiIterModeWithManyArrays> NpyMultiIter<'py, T, S> {
565577
iterator,
566578
iternext,
567579
iter_size,
568-
empty: iter_size == 0,
569580
dataptr,
570581
marker: PhantomData,
571582
arrays,
@@ -593,7 +604,7 @@ macro_rules! impl_multi_iter {
593604
type Item = ($($ty,)+);
594605

595606
fn next(&mut self) -> Option<Self::Item> {
596-
if self.empty {
607+
if self.iter_size == 0 {
597608
None
598609
} else {
599610
// Note: This pointer is correct and doesn't need to be updated,
@@ -602,13 +613,22 @@ macro_rules! impl_multi_iter {
602613
// points to is being updated by iternext to point to the next value.
603614
let ($($ptr,)+) = unsafe { $expand::<T>(self.dataptr) };
604615
let retval = Some(unsafe { $deref });
605-
self.empty = unsafe { (self.iternext)(self.iterator.as_mut()) } == 0;
616+
let empty = unsafe { (self.iternext)(self.iterator.as_mut()) } == 0;
617+
debug_assert_ne!(self.iter_size, 0);
618+
self.iter_size -= 1;
619+
debug_assert!(self.iter_size > 0 || empty);
606620
retval
607621
}
608622
}
609623

610624
fn size_hint(&self) -> (usize, Option<usize>) {
611-
(self.iter_size as usize, Some(self.iter_size as usize))
625+
(self.len(), Some(self.len()))
626+
}
627+
}
628+
629+
impl<'py, T: 'py> ExactSizeIterator for NpyMultiIter<'py, T, $structure> {
630+
fn len(&self) -> usize {
631+
self.iter_size as usize
612632
}
613633
}
614634
};

tests/iter.rs

+47-2
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
#![allow(deprecated)]
22

33
use ndarray::array;
4-
use numpy::{NpyMultiIterBuilder, NpySingleIterBuilder, PyArray};
5-
use pyo3::PyResult;
4+
use numpy::{pyarray, NpyMultiIterBuilder, NpySingleIterBuilder, PyArray};
5+
use pyo3::{PyResult, Python};
66

77
macro_rules! assert_approx_eq {
88
($x: expr, $y: expr) => {
@@ -96,3 +96,48 @@ fn multiiter_rw() -> PyResult<()> {
9696
Ok(())
9797
})
9898
}
99+
100+
#[test]
101+
fn single_iter_size_hint_len() {
102+
Python::with_gil(|py| {
103+
let arr = pyarray![py, [0, 1], [2, 3], [4, 5]];
104+
105+
let mut iter = NpySingleIterBuilder::readonly(arr.readonly())
106+
.build()
107+
.unwrap();
108+
109+
for len in (1..=6).rev() {
110+
assert_eq!(iter.len(), len);
111+
assert_eq!(iter.size_hint(), (len, Some(len)));
112+
assert!(iter.next().is_some());
113+
}
114+
115+
assert_eq!(iter.len(), 0);
116+
assert_eq!(iter.size_hint(), (0, Some(0)));
117+
assert!(iter.next().is_none());
118+
});
119+
}
120+
121+
#[test]
122+
fn multi_iter_size_hint_len() {
123+
Python::with_gil(|py| {
124+
let arr1 = pyarray![py, [0, 1], [2, 3], [4, 5]];
125+
let arr2 = pyarray![py, [0, 0], [0, 0], [0, 0]];
126+
127+
let mut iter = NpyMultiIterBuilder::new()
128+
.add_readonly(arr1.readonly())
129+
.add_readonly(arr2.readonly())
130+
.build()
131+
.unwrap();
132+
133+
for len in (1..=6).rev() {
134+
assert_eq!(iter.len(), len);
135+
assert_eq!(iter.size_hint(), (len, Some(len)));
136+
assert!(iter.next().is_some());
137+
}
138+
139+
assert_eq!(iter.len(), 0);
140+
assert_eq!(iter.size_hint(), (0, Some(0)));
141+
assert!(iter.next().is_none());
142+
});
143+
}

0 commit comments

Comments
 (0)