Skip to content

Commit e0a4452

Browse files
committed
Add map and set extract_if
1 parent 922c6ad commit e0a4452

File tree

7 files changed

+317
-5
lines changed

7 files changed

+317
-5
lines changed

src/map.rs

Lines changed: 41 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@ mod tests;
1616
pub use self::core::raw_entry_v1::{self, RawEntryApiV1};
1717
pub use self::core::{Entry, IndexedEntry, OccupiedEntry, VacantEntry};
1818
pub use self::iter::{
19-
Drain, IntoIter, IntoKeys, IntoValues, Iter, IterMut, IterMut2, Keys, Splice, Values, ValuesMut,
19+
Drain, ExtractIf, IntoIter, IntoKeys, IntoValues, Iter, IterMut, IterMut2, Keys, Splice,
20+
Values, ValuesMut,
2021
};
2122
pub use self::mutable::MutableEntryKey;
2223
pub use self::mutable::MutableKeys;
@@ -36,7 +37,7 @@ use alloc::vec::Vec;
3637
#[cfg(feature = "std")]
3738
use std::collections::hash_map::RandomState;
3839

39-
use self::core::IndexMapCore;
40+
pub(crate) use self::core::{ExtractCore, IndexMapCore};
4041
use crate::util::{third, try_simplify_range};
4142
use crate::{Bucket, Entries, Equivalent, HashValue, TryReserveError};
4243

@@ -306,6 +307,44 @@ impl<K, V, S> IndexMap<K, V, S> {
306307
Drain::new(self.core.drain(range))
307308
}
308309

310+
/// Creates an iterator which uses a closure to determine if an element should be removed.
311+
///
312+
/// If the closure returns true, the element is removed from the map and yielded.
313+
/// If the closure returns false, or panics, the element remains in the map and will not be
314+
/// yielded.
315+
///
316+
/// Note that `extract_if` lets you mutate every value in the filter closure, regardless of
317+
/// whether you choose to keep or remove it.
318+
///
319+
/// If the returned `ExtractIf` is not exhausted, e.g. because it is dropped without iterating
320+
/// or the iteration short-circuits, then the remaining elements will be retained.
321+
/// Use [`retain`] with a negated predicate if you do not need the returned iterator.
322+
///
323+
/// [`retain`]: IndexMap::retain
324+
///
325+
/// # Examples
326+
///
327+
/// Splitting a map into even and odd keys, reusing the original map:
328+
///
329+
/// ```
330+
/// use indexmap::IndexMap;
331+
///
332+
/// let mut map: IndexMap<i32, i32> = (0..8).map(|x| (x, x)).collect();
333+
/// let extracted: IndexMap<i32, i32> = map.extract_if(|k, _v| k % 2 == 0).collect();
334+
///
335+
/// let evens = extracted.keys().copied().collect::<Vec<_>>();
336+
/// let odds = map.keys().copied().collect::<Vec<_>>();
337+
///
338+
/// assert_eq!(evens, vec![0, 2, 4, 6]);
339+
/// assert_eq!(odds, vec![1, 3, 5, 7]);
340+
/// ```
341+
pub fn extract_if<F>(&mut self, pred: F) -> ExtractIf<'_, K, V, F>
342+
where
343+
F: FnMut(&K, &mut V) -> bool,
344+
{
345+
ExtractIf::new(&mut self.core, pred)
346+
}
347+
309348
/// Splits the collection into two at the given index.
310349
///
311350
/// Returns a newly allocated map containing the elements in the range

src/map/core.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ use crate::util::simplify_range;
2323
use crate::{Bucket, Entries, Equivalent, HashValue};
2424

2525
pub use entry::{Entry, IndexedEntry, OccupiedEntry, VacantEntry};
26+
pub(crate) use raw::ExtractCore;
2627

2728
/// Core of the map that does not depend on S
2829
pub(crate) struct IndexMapCore<K, V> {
@@ -145,6 +146,7 @@ impl<K, V> IndexMapCore<K, V> {
145146

146147
#[inline]
147148
pub(crate) fn len(&self) -> usize {
149+
debug_assert_eq!(self.entries.len(), self.indices.len());
148150
self.indices.len()
149151
}
150152

src/map/core/raw.rs

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,20 @@ impl<K, V> IndexMapCore<K, V> {
101101
// only the item references that are appropriately bound to `&mut self`.
102102
unsafe { self.indices.iter().map(|bucket| bucket.as_mut()) }
103103
}
104+
105+
pub(crate) fn extract(&mut self) -> ExtractCore<'_, K, V> {
106+
// SAFETY: We must have consistent lengths to start, so that's a hard assertion.
107+
// Then the worst `set_len(0)` can do is leak items if `ExtractCore` doesn't drop.
108+
assert_eq!(self.entries.len(), self.indices.len());
109+
unsafe {
110+
self.entries.set_len(0);
111+
}
112+
ExtractCore {
113+
map: self,
114+
current: 0,
115+
new_len: 0,
116+
}
117+
}
104118
}
105119

106120
/// A view into an occupied raw entry in an `IndexMap`.
@@ -162,3 +176,80 @@ impl<'a, K, V> RawTableEntry<'a, K, V> {
162176
(self.map, index)
163177
}
164178
}
179+
180+
pub(crate) struct ExtractCore<'a, K, V> {
181+
map: &'a mut IndexMapCore<K, V>,
182+
current: usize,
183+
new_len: usize,
184+
}
185+
186+
impl<K, V> Drop for ExtractCore<'_, K, V> {
187+
fn drop(&mut self) {
188+
let old_len = self.map.indices.len();
189+
let mut new_len = self.new_len;
190+
debug_assert!(new_len <= self.current);
191+
debug_assert!(self.current <= old_len);
192+
debug_assert!(old_len <= self.map.entries.capacity());
193+
194+
// SAFETY: We assume `new_len` and `current` were correctly maintained by the iterator.
195+
// So `entries[new_len..current]` were extracted, but the rest before and after are valid.
196+
unsafe {
197+
if new_len == self.current {
198+
// Nothing was extracted, so any remaining items can be left in place.
199+
new_len = old_len;
200+
} else if self.current < old_len {
201+
// Need to shift the remaining items down.
202+
let tail_len = old_len - self.current;
203+
let base = self.map.entries.as_mut_ptr();
204+
let src = base.add(self.current);
205+
let dest = base.add(new_len);
206+
src.copy_to(dest, tail_len);
207+
new_len += tail_len;
208+
}
209+
self.map.entries.set_len(new_len);
210+
}
211+
212+
if new_len != old_len {
213+
// We don't keep track of *which* items were extracted, so reindex everything.
214+
self.map.rebuild_hash_table();
215+
}
216+
}
217+
}
218+
219+
impl<K, V> ExtractCore<'_, K, V> {
220+
pub(crate) fn extract_if<F>(&mut self, mut pred: F) -> Option<Bucket<K, V>>
221+
where
222+
F: FnMut(&mut Bucket<K, V>) -> bool,
223+
{
224+
let old_len = self.map.indices.len();
225+
debug_assert!(old_len <= self.map.entries.capacity());
226+
227+
let base = self.map.entries.as_mut_ptr();
228+
while self.current < old_len {
229+
// SAFETY: We're maintaining both indices within bounds of the original entries, so
230+
// 0..new_len and current..old_len are always valid items for our Drop to keep.
231+
unsafe {
232+
let item = base.add(self.current);
233+
if pred(&mut *item) {
234+
// Extract it!
235+
self.current += 1;
236+
return Some(item.read());
237+
} else {
238+
// Keep it, shifting it down if needed.
239+
if self.new_len != self.current {
240+
debug_assert!(self.new_len < self.current);
241+
let dest = base.add(self.new_len);
242+
item.copy_to_nonoverlapping(dest, 1);
243+
}
244+
self.current += 1;
245+
self.new_len += 1;
246+
}
247+
}
248+
}
249+
None
250+
}
251+
252+
pub(crate) fn remaining(&self) -> usize {
253+
self.map.indices.len() - self.current
254+
}
255+
}

src/map/iter.rs

Lines changed: 54 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
1-
use super::core::IndexMapCore;
2-
use super::{Bucket, Entries, IndexMap, Slice};
1+
use super::{Bucket, Entries, ExtractCore, IndexMap, IndexMapCore, Slice};
32

43
use alloc::vec::{self, Vec};
54
use core::fmt;
@@ -772,3 +771,56 @@ where
772771
.finish()
773772
}
774773
}
774+
775+
/// An extracting iterator for `IndexMap`.
776+
///
777+
/// This `struct` is created by [`IndexMap::extract_if()`].
778+
/// See its documentation for more.
779+
pub struct ExtractIf<'a, K, V, F>
780+
where
781+
F: FnMut(&K, &mut V) -> bool,
782+
{
783+
inner: ExtractCore<'a, K, V>,
784+
pred: F,
785+
}
786+
787+
impl<K, V, F> ExtractIf<'_, K, V, F>
788+
where
789+
F: FnMut(&K, &mut V) -> bool,
790+
{
791+
pub(super) fn new(core: &mut IndexMapCore<K, V>, pred: F) -> ExtractIf<'_, K, V, F> {
792+
ExtractIf {
793+
inner: core.extract(),
794+
pred,
795+
}
796+
}
797+
}
798+
799+
impl<K, V, F> Iterator for ExtractIf<'_, K, V, F>
800+
where
801+
F: FnMut(&K, &mut V) -> bool,
802+
{
803+
type Item = (K, V);
804+
805+
fn next(&mut self) -> Option<Self::Item> {
806+
self.inner
807+
.extract_if(|bucket| {
808+
let (key, value) = bucket.ref_mut();
809+
(self.pred)(key, value)
810+
})
811+
.map(Bucket::key_value)
812+
}
813+
814+
fn size_hint(&self) -> (usize, Option<usize>) {
815+
(0, Some(self.inner.remaining()))
816+
}
817+
}
818+
819+
impl<'a, K, V, F> fmt::Debug for ExtractIf<'a, K, V, F>
820+
where
821+
F: FnMut(&K, &mut V) -> bool,
822+
{
823+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
824+
f.debug_struct("ExtractIf").finish_non_exhaustive()
825+
}
826+
}

src/set.rs

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ mod slice;
88
mod tests;
99

1010
pub use self::iter::{
11-
Difference, Drain, Intersection, IntoIter, Iter, Splice, SymmetricDifference, Union,
11+
Difference, Drain, ExtractIf, Intersection, IntoIter, Iter, Splice, SymmetricDifference, Union,
1212
};
1313
pub use self::mutable::MutableValues;
1414
pub use self::slice::Slice;
@@ -257,6 +257,41 @@ impl<T, S> IndexSet<T, S> {
257257
Drain::new(self.map.core.drain(range))
258258
}
259259

260+
/// Creates an iterator which uses a closure to determine if a value should be removed.
261+
///
262+
/// If the closure returns true, then the value is removed and yielded.
263+
/// If the closure returns false, the value will remain in the list and will not be yielded
264+
/// by the iterator.
265+
///
266+
/// If the returned `ExtractIf` is not exhausted, e.g. because it is dropped without iterating
267+
/// or the iteration short-circuits, then the remaining elements will be retained.
268+
/// Use [`retain`] with a negated predicate if you do not need the returned iterator.
269+
///
270+
/// [`retain`]: IndexSet::retain
271+
///
272+
/// # Examples
273+
///
274+
/// Splitting a set into even and odd values, reusing the original set:
275+
///
276+
/// ```
277+
/// use indexmap::IndexSet;
278+
///
279+
/// let mut set: IndexSet<i32> = (0..8).collect();
280+
/// let extracted: IndexSet<i32> = set.extract_if(|v| v % 2 == 0).collect();
281+
///
282+
/// let evens = extracted.into_iter().collect::<Vec<_>>();
283+
/// let odds = set.into_iter().collect::<Vec<_>>();
284+
///
285+
/// assert_eq!(evens, vec![0, 2, 4, 6]);
286+
/// assert_eq!(odds, vec![1, 3, 5, 7]);
287+
/// ```
288+
pub fn extract_if<F>(&mut self, pred: F) -> ExtractIf<'_, T, F>
289+
where
290+
F: FnMut(&T) -> bool,
291+
{
292+
ExtractIf::new(&mut self.map.core, pred)
293+
}
294+
260295
/// Splits the collection into two at the given index.
261296
///
262297
/// Returns a newly allocated set containing the elements in the range

src/set/iter.rs

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
use crate::map::{ExtractCore, IndexMapCore};
2+
13
use super::{Bucket, Entries, IndexSet, Slice};
24

35
use alloc::vec::{self, Vec};
@@ -624,3 +626,53 @@ impl<I: fmt::Debug> fmt::Debug for UnitValue<I> {
624626
fmt::Debug::fmt(&self.0, f)
625627
}
626628
}
629+
630+
/// An extracting iterator for `IndexSet`.
631+
///
632+
/// This `struct` is created by [`IndexSet::extract_if()`].
633+
/// See its documentation for more.
634+
pub struct ExtractIf<'a, T, F>
635+
where
636+
F: FnMut(&T) -> bool,
637+
{
638+
inner: ExtractCore<'a, T, ()>,
639+
pred: F,
640+
}
641+
642+
impl<T, F> ExtractIf<'_, T, F>
643+
where
644+
F: FnMut(&T) -> bool,
645+
{
646+
pub(super) fn new(core: &mut IndexMapCore<T, ()>, pred: F) -> ExtractIf<'_, T, F> {
647+
ExtractIf {
648+
inner: core.extract(),
649+
pred,
650+
}
651+
}
652+
}
653+
654+
impl<T, F> Iterator for ExtractIf<'_, T, F>
655+
where
656+
F: FnMut(&T) -> bool,
657+
{
658+
type Item = T;
659+
660+
fn next(&mut self) -> Option<Self::Item> {
661+
self.inner
662+
.extract_if(|bucket| (self.pred)(bucket.key_ref()))
663+
.map(Bucket::key)
664+
}
665+
666+
fn size_hint(&self) -> (usize, Option<usize>) {
667+
(0, Some(self.inner.remaining()))
668+
}
669+
}
670+
671+
impl<'a, T, F> fmt::Debug for ExtractIf<'a, T, F>
672+
where
673+
F: FnMut(&T) -> bool,
674+
{
675+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
676+
f.debug_struct("ExtractIf").finish_non_exhaustive()
677+
}
678+
}

0 commit comments

Comments
 (0)