Skip to content

Commit 702127f

Browse files
committed
rollup merge of rust-lang#19296: csouth3/trieset-union
TrieSet doesn't yet have union, intersection, difference, and symmetric difference functions implemented. Luckily, TrieSet is largely similar to TreeSet, so I was able to reference the implementations of these functions in the latter, and adapt them as necessary to make them work for TrieSet. One thing that I thought was interesting is that the Iterator yielded by `iter()` for TrieSet iterates over the set's values directly rather than references to the values (whereas I think in most cases I see the Iterator given by `iter()` iterating over immutable references), so for consistency within TrieSet's interface, all of these Iterators also iterate over the values directly. Let me know if all of these should be instead iterating over references.
2 parents f40fa83 + 2a6f197 commit 702127f

File tree

1 file changed

+268
-1
lines changed

1 file changed

+268
-1
lines changed

src/libcollections/trie/set.rs

Lines changed: 268 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
// except according to those terms.
1010

1111
// FIXME(conventions): implement bounded iterators
12-
// FIXME(conventions): implement union family of fns
1312
// FIXME(conventions): implement BitOr, BitAnd, BitXor, and Sub
1413
// FIXME(conventions): replace each_reverse by making iter DoubleEnded
1514
// FIXME(conventions): implement iter_mut and into_iter
@@ -19,6 +18,7 @@ use core::prelude::*;
1918
use core::default::Default;
2019
use core::fmt;
2120
use core::fmt::Show;
21+
use core::iter::Peekable;
2222
use std::hash::Hash;
2323

2424
use trie_map::{TrieMap, Entries};
@@ -172,6 +172,106 @@ impl TrieSet {
172172
SetItems{iter: self.map.upper_bound(val)}
173173
}
174174

175+
/// Visits the values representing the difference, in ascending order.
176+
///
177+
/// # Example
178+
///
179+
/// ```
180+
/// use std::collections::TrieSet;
181+
///
182+
/// let a: TrieSet = [1, 2, 3].iter().map(|&x| x).collect();
183+
/// let b: TrieSet = [3, 4, 5].iter().map(|&x| x).collect();
184+
///
185+
/// // Can be seen as `a - b`.
186+
/// for x in a.difference(&b) {
187+
/// println!("{}", x); // Print 1 then 2
188+
/// }
189+
///
190+
/// let diff1: TrieSet = a.difference(&b).collect();
191+
/// assert_eq!(diff1, [1, 2].iter().map(|&x| x).collect());
192+
///
193+
/// // Note that difference is not symmetric,
194+
/// // and `b - a` means something else:
195+
/// let diff2: TrieSet = b.difference(&a).collect();
196+
/// assert_eq!(diff2, [4, 5].iter().map(|&x| x).collect());
197+
/// ```
198+
#[unstable = "matches collection reform specification, waiting for dust to settle"]
199+
pub fn difference<'a>(&'a self, other: &'a TrieSet) -> DifferenceItems<'a> {
200+
DifferenceItems{a: self.iter().peekable(), b: other.iter().peekable()}
201+
}
202+
203+
/// Visits the values representing the symmetric difference, in ascending order.
204+
///
205+
/// # Example
206+
///
207+
/// ```
208+
/// use std::collections::TrieSet;
209+
///
210+
/// let a: TrieSet = [1, 2, 3].iter().map(|&x| x).collect();
211+
/// let b: TrieSet = [3, 4, 5].iter().map(|&x| x).collect();
212+
///
213+
/// // Print 1, 2, 4, 5 in ascending order.
214+
/// for x in a.symmetric_difference(&b) {
215+
/// println!("{}", x);
216+
/// }
217+
///
218+
/// let diff1: TrieSet = a.symmetric_difference(&b).collect();
219+
/// let diff2: TrieSet = b.symmetric_difference(&a).collect();
220+
///
221+
/// assert_eq!(diff1, diff2);
222+
/// assert_eq!(diff1, [1, 2, 4, 5].iter().map(|&x| x).collect());
223+
/// ```
224+
#[unstable = "matches collection reform specification, waiting for dust to settle."]
225+
pub fn symmetric_difference<'a>(&'a self, other: &'a TrieSet) -> SymDifferenceItems<'a> {
226+
SymDifferenceItems{a: self.iter().peekable(), b: other.iter().peekable()}
227+
}
228+
229+
/// Visits the values representing the intersection, in ascending order.
230+
///
231+
/// # Example
232+
///
233+
/// ```
234+
/// use std::collections::TrieSet;
235+
///
236+
/// let a: TrieSet = [1, 2, 3].iter().map(|&x| x).collect();
237+
/// let b: TrieSet = [2, 3, 4].iter().map(|&x| x).collect();
238+
///
239+
/// // Print 2, 3 in ascending order.
240+
/// for x in a.intersection(&b) {
241+
/// println!("{}", x);
242+
/// }
243+
///
244+
/// let diff: TrieSet = a.intersection(&b).collect();
245+
/// assert_eq!(diff, [2, 3].iter().map(|&x| x).collect());
246+
/// ```
247+
#[unstable = "matches collection reform specification, waiting for dust to settle"]
248+
pub fn intersection<'a>(&'a self, other: &'a TrieSet) -> IntersectionItems<'a> {
249+
IntersectionItems{a: self.iter().peekable(), b: other.iter().peekable()}
250+
}
251+
252+
/// Visits the values representing the union, in ascending order.
253+
///
254+
/// # Example
255+
///
256+
/// ```
257+
/// use std::collections::TrieSet;
258+
///
259+
/// let a: TrieSet = [1, 2, 3].iter().map(|&x| x).collect();
260+
/// let b: TrieSet = [3, 4, 5].iter().map(|&x| x).collect();
261+
///
262+
/// // Print 1, 2, 3, 4, 5 in ascending order.
263+
/// for x in a.union(&b) {
264+
/// println!("{}", x);
265+
/// }
266+
///
267+
/// let diff: TrieSet = a.union(&b).collect();
268+
/// assert_eq!(diff, [1, 2, 3, 4, 5].iter().map(|&x| x).collect());
269+
/// ```
270+
#[unstable = "matches collection reform specification, waiting for dust to settle"]
271+
pub fn union<'a>(&'a self, other: &'a TrieSet) -> UnionItems<'a> {
272+
UnionItems{a: self.iter().peekable(), b: other.iter().peekable()}
273+
}
274+
175275
/// Return the number of elements in the set
176276
///
177277
/// # Example
@@ -368,6 +468,39 @@ pub struct SetItems<'a> {
368468
iter: Entries<'a, ()>
369469
}
370470

471+
/// An iterator producing elements in the set difference (in-order).
472+
pub struct DifferenceItems<'a> {
473+
a: Peekable<uint, SetItems<'a>>,
474+
b: Peekable<uint, SetItems<'a>>,
475+
}
476+
477+
/// An iterator producing elements in the set symmetric difference (in-order).
478+
pub struct SymDifferenceItems<'a> {
479+
a: Peekable<uint, SetItems<'a>>,
480+
b: Peekable<uint, SetItems<'a>>,
481+
}
482+
483+
/// An iterator producing elements in the set intersection (in-order).
484+
pub struct IntersectionItems<'a> {
485+
a: Peekable<uint, SetItems<'a>>,
486+
b: Peekable<uint, SetItems<'a>>,
487+
}
488+
489+
/// An iterator producing elements in the set union (in-order).
490+
pub struct UnionItems<'a> {
491+
a: Peekable<uint, SetItems<'a>>,
492+
b: Peekable<uint, SetItems<'a>>,
493+
}
494+
495+
/// Compare `x` and `y`, but return `short` if x is None and `long` if y is None
496+
fn cmp_opt(x: Option<&uint>, y: Option<&uint>, short: Ordering, long: Ordering) -> Ordering {
497+
match (x, y) {
498+
(None , _ ) => short,
499+
(_ , None ) => long,
500+
(Some(x1), Some(y1)) => x1.cmp(y1),
501+
}
502+
}
503+
371504
impl<'a> Iterator<uint> for SetItems<'a> {
372505
fn next(&mut self) -> Option<uint> {
373506
self.iter.next().map(|(key, _)| key)
@@ -378,6 +511,60 @@ impl<'a> Iterator<uint> for SetItems<'a> {
378511
}
379512
}
380513

514+
impl<'a> Iterator<uint> for DifferenceItems<'a> {
515+
fn next(&mut self) -> Option<uint> {
516+
loop {
517+
match cmp_opt(self.a.peek(), self.b.peek(), Less, Less) {
518+
Less => return self.a.next(),
519+
Equal => { self.a.next(); self.b.next(); }
520+
Greater => { self.b.next(); }
521+
}
522+
}
523+
}
524+
}
525+
526+
impl<'a> Iterator<uint> for SymDifferenceItems<'a> {
527+
fn next(&mut self) -> Option<uint> {
528+
loop {
529+
match cmp_opt(self.a.peek(), self.b.peek(), Greater, Less) {
530+
Less => return self.a.next(),
531+
Equal => { self.a.next(); self.b.next(); }
532+
Greater => return self.b.next(),
533+
}
534+
}
535+
}
536+
}
537+
538+
impl<'a> Iterator<uint> for IntersectionItems<'a> {
539+
fn next(&mut self) -> Option<uint> {
540+
loop {
541+
let o_cmp = match (self.a.peek(), self.b.peek()) {
542+
(None , _ ) => None,
543+
(_ , None ) => None,
544+
(Some(a1), Some(b1)) => Some(a1.cmp(b1)),
545+
};
546+
match o_cmp {
547+
None => return None,
548+
Some(Less) => { self.a.next(); }
549+
Some(Equal) => { self.b.next(); return self.a.next() }
550+
Some(Greater) => { self.b.next(); }
551+
}
552+
}
553+
}
554+
}
555+
556+
impl<'a> Iterator<uint> for UnionItems<'a> {
557+
fn next(&mut self) -> Option<uint> {
558+
loop {
559+
match cmp_opt(self.a.peek(), self.b.peek(), Greater, Less) {
560+
Less => return self.a.next(),
561+
Equal => { self.b.next(); return self.a.next() }
562+
Greater => return self.b.next(),
563+
}
564+
}
565+
}
566+
}
567+
381568
#[cfg(test)]
382569
mod test {
383570
use std::prelude::*;
@@ -471,4 +658,84 @@ mod test {
471658
assert!(b > a && b >= a);
472659
assert!(a < b && a <= b);
473660
}
661+
662+
fn check(a: &[uint],
663+
b: &[uint],
664+
expected: &[uint],
665+
f: |&TrieSet, &TrieSet, f: |uint| -> bool| -> bool) {
666+
let mut set_a = TrieSet::new();
667+
let mut set_b = TrieSet::new();
668+
669+
for x in a.iter() { assert!(set_a.insert(*x)) }
670+
for y in b.iter() { assert!(set_b.insert(*y)) }
671+
672+
let mut i = 0;
673+
f(&set_a, &set_b, |x| {
674+
assert_eq!(x, expected[i]);
675+
i += 1;
676+
true
677+
});
678+
assert_eq!(i, expected.len());
679+
}
680+
681+
#[test]
682+
fn test_intersection() {
683+
fn check_intersection(a: &[uint], b: &[uint], expected: &[uint]) {
684+
check(a, b, expected, |x, y, f| x.intersection(y).all(f))
685+
}
686+
687+
check_intersection(&[], &[], &[]);
688+
check_intersection(&[1, 2, 3], &[], &[]);
689+
check_intersection(&[], &[1, 2, 3], &[]);
690+
check_intersection(&[2], &[1, 2, 3], &[2]);
691+
check_intersection(&[1, 2, 3], &[2], &[2]);
692+
check_intersection(&[11, 1, 3, 77, 103, 5],
693+
&[2, 11, 77, 5, 3],
694+
&[3, 5, 11, 77]);
695+
}
696+
697+
#[test]
698+
fn test_difference() {
699+
fn check_difference(a: &[uint], b: &[uint], expected: &[uint]) {
700+
check(a, b, expected, |x, y, f| x.difference(y).all(f))
701+
}
702+
703+
check_difference(&[], &[], &[]);
704+
check_difference(&[1, 12], &[], &[1, 12]);
705+
check_difference(&[], &[1, 2, 3, 9], &[]);
706+
check_difference(&[1, 3, 5, 9, 11],
707+
&[3, 9],
708+
&[1, 5, 11]);
709+
check_difference(&[11, 22, 33, 40, 42],
710+
&[14, 23, 34, 38, 39, 50],
711+
&[11, 22, 33, 40, 42]);
712+
}
713+
714+
#[test]
715+
fn test_symmetric_difference() {
716+
fn check_symmetric_difference(a: &[uint], b: &[uint], expected: &[uint]) {
717+
check(a, b, expected, |x, y, f| x.symmetric_difference(y).all(f))
718+
}
719+
720+
check_symmetric_difference(&[], &[], &[]);
721+
check_symmetric_difference(&[1, 2, 3], &[2], &[1, 3]);
722+
check_symmetric_difference(&[2], &[1, 2, 3], &[1, 3]);
723+
check_symmetric_difference(&[1, 3, 5, 9, 11],
724+
&[3, 9, 14, 22],
725+
&[1, 5, 11, 14, 22]);
726+
}
727+
728+
#[test]
729+
fn test_union() {
730+
fn check_union(a: &[uint], b: &[uint], expected: &[uint]) {
731+
check(a, b, expected, |x, y, f| x.union(y).all(f))
732+
}
733+
734+
check_union(&[], &[], &[]);
735+
check_union(&[1, 2, 3], &[2], &[1, 2, 3]);
736+
check_union(&[2], &[1, 2, 3], &[1, 2, 3]);
737+
check_union(&[1, 3, 5, 9, 11, 16, 19, 24],
738+
&[1, 5, 9, 13, 19],
739+
&[1, 3, 5, 9, 11, 13, 16, 19, 24]);
740+
}
474741
}

0 commit comments

Comments
 (0)