Skip to content

Commit b3be43a

Browse files
committed
Fix incorrect impl of Iterator::size_hint and add impl of ExactSizeIterator the NumPy iterator wrappers.
1 parent a863dbd commit b3be43a

File tree

2 files changed

+78
-13
lines changed

2 files changed

+78
-13
lines changed

src/npyiter.rs

+31-11
Original file line numberDiff line numberDiff line change
@@ -285,7 +285,6 @@ impl<'py, T: Element, I: IterMode> NpySingleIterBuilder<'py, T, I> {
285285
pub struct NpySingleIter<'py, T, I> {
286286
iterator: ptr::NonNull<NpyIter>,
287287
iternext: unsafe extern "C" fn(*mut NpyIter) -> c_int,
288-
empty: bool,
289288
iter_size: npy_intp,
290289
dataptr: *mut *mut c_char,
291290
return_type: PhantomData<T>,
@@ -324,7 +323,6 @@ impl<'py, T, I> NpySingleIter<'py, T, I> {
324323
iterator,
325324
iternext,
326325
iter_size,
327-
empty: iter_size == 0,
328326
dataptr,
329327
return_type: PhantomData,
330328
mode: PhantomData,
@@ -334,15 +332,18 @@ impl<'py, T, I> NpySingleIter<'py, T, I> {
334332
}
335333

336334
fn iternext(&mut self) -> Option<*mut T> {
337-
if self.empty {
335+
if self.iter_size == 0 {
338336
None
339337
} else {
340338
// Note: This pointer is correct and doesn't need to be updated,
341339
// note that we're derefencing a **char into a *char casting to a *T
342340
// and then transforming that into a reference, the value that dataptr
343341
// points to is being updated by iternext to point to the next value.
344342
let ret = unsafe { *self.dataptr as *mut T };
345-
self.empty = unsafe { (self.iternext)(self.iterator.as_mut()) } == 0;
343+
let empty = unsafe { (self.iternext)(self.iterator.as_mut()) } == 0;
344+
debug_assert_ne!(self.iter_size, 0);
345+
self.iter_size -= 1;
346+
debug_assert!(self.iter_size > 0 || empty);
346347
Some(ret)
347348
}
348349
}
@@ -368,7 +369,13 @@ impl<'py, T: 'py> Iterator for NpySingleIter<'py, T, Readonly> {
368369
}
369370

370371
fn size_hint(&self) -> (usize, Option<usize>) {
371-
(self.iter_size as usize, Some(self.iter_size as usize))
372+
(self.len(), Some(self.len()))
373+
}
374+
}
375+
376+
impl<'py, T: 'py> ExactSizeIterator for NpySingleIter<'py, T, Readonly> {
377+
fn len(&self) -> usize {
378+
self.iter_size as usize
372379
}
373380
}
374381

@@ -380,7 +387,13 @@ impl<'py, T: 'py> Iterator for NpySingleIter<'py, T, ReadWrite> {
380387
}
381388

382389
fn size_hint(&self) -> (usize, Option<usize>) {
383-
(self.iter_size as usize, Some(self.iter_size as usize))
390+
(self.len(), Some(self.len()))
391+
}
392+
}
393+
394+
impl<'py, T: 'py> ExactSizeIterator for NpySingleIter<'py, T, ReadWrite> {
395+
fn len(&self) -> usize {
396+
self.iter_size as usize
384397
}
385398
}
386399

@@ -528,7 +541,6 @@ impl<'py, T: Element, S: MultiIterModeWithManyArrays> NpyMultiIterBuilder<'py, T
528541
pub struct NpyMultiIter<'py, T, S: MultiIterModeWithManyArrays> {
529542
iterator: ptr::NonNull<NpyIter>,
530543
iternext: unsafe extern "C" fn(*mut NpyIter) -> c_int,
531-
empty: bool,
532544
iter_size: npy_intp,
533545
dataptr: *mut *mut c_char,
534546
marker: PhantomData<(T, S)>,
@@ -568,7 +580,6 @@ impl<'py, T, S: MultiIterModeWithManyArrays> NpyMultiIter<'py, T, S> {
568580
iterator,
569581
iternext,
570582
iter_size,
571-
empty: iter_size == 0,
572583
dataptr,
573584
marker: PhantomData,
574585
arrays,
@@ -596,7 +607,7 @@ macro_rules! impl_multi_iter {
596607
type Item = ($($ty,)+);
597608

598609
fn next(&mut self) -> Option<Self::Item> {
599-
if self.empty {
610+
if self.iter_size == 0 {
600611
None
601612
} else {
602613
// Note: This pointer is correct and doesn't need to be updated,
@@ -605,13 +616,22 @@ macro_rules! impl_multi_iter {
605616
// points to is being updated by iternext to point to the next value.
606617
let ($($ptr,)+) = unsafe { $expand::<T>(self.dataptr) };
607618
let retval = Some(unsafe { $deref });
608-
self.empty = unsafe { (self.iternext)(self.iterator.as_mut()) } == 0;
619+
let empty = unsafe { (self.iternext)(self.iterator.as_mut()) } == 0;
620+
debug_assert_ne!(self.iter_size, 0);
621+
self.iter_size -= 1;
622+
debug_assert!(self.iter_size > 0 || empty);
609623
retval
610624
}
611625
}
612626

613627
fn size_hint(&self) -> (usize, Option<usize>) {
614-
(self.iter_size as usize, Some(self.iter_size as usize))
628+
(self.len(), Some(self.len()))
629+
}
630+
}
631+
632+
impl<'py, T: 'py> ExactSizeIterator for NpyMultiIter<'py, T, $structure> {
633+
fn len(&self) -> usize {
634+
self.iter_size as usize
615635
}
616636
}
617637
};

tests/iter.rs

+47-2
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
#![allow(deprecated)]
22

33
use ndarray::array;
4-
use numpy::{NpyMultiIterBuilder, NpySingleIterBuilder, PyArray};
5-
use pyo3::PyResult;
4+
use numpy::{pyarray, NpyMultiIterBuilder, NpySingleIterBuilder, PyArray};
5+
use pyo3::{PyResult, Python};
66

77
macro_rules! assert_approx_eq {
88
($x: expr, $y: expr) => {
@@ -96,3 +96,48 @@ fn multiiter_rw() -> PyResult<()> {
9696
Ok(())
9797
})
9898
}
99+
100+
#[test]
101+
fn single_iter_size_hint_len() {
102+
Python::with_gil(|py| {
103+
let arr = pyarray![py, [0, 1], [2, 3], [4, 5]];
104+
105+
let mut iter = NpySingleIterBuilder::readonly(arr.readonly())
106+
.build()
107+
.unwrap();
108+
109+
for len in (1..=6).rev() {
110+
assert_eq!(iter.len(), len);
111+
assert_eq!(iter.size_hint(), (len, Some(len)));
112+
assert!(iter.next().is_some());
113+
}
114+
115+
assert_eq!(iter.len(), 0);
116+
assert_eq!(iter.size_hint(), (0, Some(0)));
117+
assert!(iter.next().is_none());
118+
});
119+
}
120+
121+
#[test]
122+
fn multi_iter_size_hint_len() {
123+
Python::with_gil(|py| {
124+
let arr1 = pyarray![py, [0, 1], [2, 3], [4, 5]];
125+
let arr2 = pyarray![py, [0, 0], [0, 0], [0, 0]];
126+
127+
let mut iter = NpyMultiIterBuilder::new()
128+
.add_readonly(arr1.readonly())
129+
.add_readonly(arr2.readonly())
130+
.build()
131+
.unwrap();
132+
133+
for len in (1..=6).rev() {
134+
assert_eq!(iter.len(), len);
135+
assert_eq!(iter.size_hint(), (len, Some(len)));
136+
assert!(iter.next().is_some());
137+
}
138+
139+
assert_eq!(iter.len(), 0);
140+
assert_eq!(iter.size_hint(), (0, Some(0)));
141+
assert!(iter.next().is_none());
142+
});
143+
}

0 commit comments

Comments
 (0)