Skip to content

Commit ef1faf4

Browse files
authored
Rollup merge of rust-lang#58577 - ssomers:btreeset_intersection, r=KodrAus
improve worst-case performance of BTreeSet intersection Major performance boost when comparing tiny and huge sets. Probably with controversial changes and I sure have questions: - Names and places of functions and types - How many comments to write where - Why does rustc tell me to `ref mut` and `ref` matches on the iterator, while the book says ref is old school. - (Why) do I have to write out the clone like that (`#[derive(Clone)]` doesn't work) - Am I allowed to use `#[derive(Debug)]` there at all? - I'd like to test function `are_proportionate_for_intersection` in test_intersection (or another test case next to it) itself, but I think the private function is inaccessible there. liballoc has other test cases not in the tests directory. PS I don't list these questions to start a discussion here, just to inspire reviewers and to remember myself.
2 parents 52fd337 + 4fd0cc1 commit ef1faf4

File tree

4 files changed

+189
-15
lines changed

4 files changed

+189
-15
lines changed

src/liballoc/benches/btree/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
11
mod map;
2+
mod set;

src/liballoc/benches/btree/set.rs

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
use std::collections::BTreeSet;
2+
3+
use rand::{thread_rng, Rng};
4+
use test::{black_box, Bencher};
5+
6+
fn random(n1: u32, n2: u32) -> [BTreeSet<usize>; 2] {
7+
let mut rng = thread_rng();
8+
let mut set1 = BTreeSet::new();
9+
let mut set2 = BTreeSet::new();
10+
for _ in 0..n1 {
11+
let i = rng.gen::<usize>();
12+
set1.insert(i);
13+
}
14+
for _ in 0..n2 {
15+
let i = rng.gen::<usize>();
16+
set2.insert(i);
17+
}
18+
[set1, set2]
19+
}
20+
21+
fn staggered(n1: u32, n2: u32) -> [BTreeSet<u32>; 2] {
22+
let mut even = BTreeSet::new();
23+
let mut odd = BTreeSet::new();
24+
for i in 0..n1 {
25+
even.insert(i * 2);
26+
}
27+
for i in 0..n2 {
28+
odd.insert(i * 2 + 1);
29+
}
30+
[even, odd]
31+
}
32+
33+
fn neg_vs_pos(n1: u32, n2: u32) -> [BTreeSet<i32>; 2] {
34+
let mut neg = BTreeSet::new();
35+
let mut pos = BTreeSet::new();
36+
for i in -(n1 as i32)..=-1 {
37+
neg.insert(i);
38+
}
39+
for i in 1..=(n2 as i32) {
40+
pos.insert(i);
41+
}
42+
[neg, pos]
43+
}
44+
45+
fn pos_vs_neg(n1: u32, n2: u32) -> [BTreeSet<i32>; 2] {
46+
let mut neg = BTreeSet::new();
47+
let mut pos = BTreeSet::new();
48+
for i in -(n1 as i32)..=-1 {
49+
neg.insert(i);
50+
}
51+
for i in 1..=(n2 as i32) {
52+
pos.insert(i);
53+
}
54+
[pos, neg]
55+
}
56+
57+
macro_rules! set_intersection_bench {
58+
($name: ident, $sets: expr) => {
59+
#[bench]
60+
pub fn $name(b: &mut Bencher) {
61+
// setup
62+
let sets = $sets;
63+
64+
// measure
65+
b.iter(|| {
66+
let x = sets[0].intersection(&sets[1]).count();
67+
black_box(x);
68+
})
69+
}
70+
};
71+
}
72+
73+
set_intersection_bench! {intersect_random_100, random(100, 100)}
74+
set_intersection_bench! {intersect_random_10k, random(10_000, 10_000)}
75+
set_intersection_bench! {intersect_random_10_vs_10k, random(10, 10_000)}
76+
set_intersection_bench! {intersect_random_10k_vs_10, random(10_000, 10)}
77+
set_intersection_bench! {intersect_staggered_100, staggered(100, 100)}
78+
set_intersection_bench! {intersect_staggered_10k, staggered(10_000, 10_000)}
79+
set_intersection_bench! {intersect_staggered_10_vs_10k, staggered(10, 10_000)}
80+
set_intersection_bench! {intersect_staggered_10k_vs_10, staggered(10_000, 10)}
81+
set_intersection_bench! {intersect_neg_vs_pos_100, neg_vs_pos(100, 100)}
82+
set_intersection_bench! {intersect_neg_vs_pos_10k, neg_vs_pos(10_000, 10_000)}
83+
set_intersection_bench! {intersect_neg_vs_pos_10_vs_10k,neg_vs_pos(10, 10_000)}
84+
set_intersection_bench! {intersect_neg_vs_pos_10k_vs_10,neg_vs_pos(10_000, 10)}
85+
set_intersection_bench! {intersect_pos_vs_neg_100, pos_vs_neg(100, 100)}
86+
set_intersection_bench! {intersect_pos_vs_neg_10k, pos_vs_neg(10_000, 10_000)}
87+
set_intersection_bench! {intersect_pos_vs_neg_10_vs_10k,pos_vs_neg(10, 10_000)}
88+
set_intersection_bench! {intersect_pos_vs_neg_10k_vs_10,pos_vs_neg(10_000, 10)}

src/liballoc/collections/btree/set.rs

Lines changed: 87 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,22 @@ impl<T: fmt::Debug> fmt::Debug for SymmetricDifference<'_, T> {
155155
}
156156
}
157157

158+
/// Whether the sizes of two sets are roughly the same order of magnitude.
159+
///
160+
/// If they are, or if either set is empty, then their intersection
161+
/// is efficiently calculated by iterating both sets jointly.
162+
/// If they aren't, then it is more scalable to iterate over the small set
163+
/// and find matches in the large set (except if the largest element in
164+
/// the small set hardly surpasses the smallest element in the large set).
165+
fn are_proportionate_for_intersection(len1: usize, len2: usize) -> bool {
166+
let (small, large) = if len1 <= len2 {
167+
(len1, len2)
168+
} else {
169+
(len2, len1)
170+
};
171+
(large >> 7) <= small
172+
}
173+
158174
/// A lazy iterator producing elements in the intersection of `BTreeSet`s.
159175
///
160176
/// This `struct` is created by the [`intersection`] method on [`BTreeSet`].
@@ -165,7 +181,13 @@ impl<T: fmt::Debug> fmt::Debug for SymmetricDifference<'_, T> {
165181
#[stable(feature = "rust1", since = "1.0.0")]
166182
pub struct Intersection<'a, T: 'a> {
167183
a: Peekable<Iter<'a, T>>,
168-
b: Peekable<Iter<'a, T>>,
184+
b: IntersectionOther<'a, T>,
185+
}
186+
187+
#[derive(Debug)]
188+
enum IntersectionOther<'a, T> {
189+
Stitch(Peekable<Iter<'a, T>>),
190+
Search(&'a BTreeSet<T>),
169191
}
170192

171193
#[stable(feature = "collection_debug", since = "1.17.0")]
@@ -326,9 +348,21 @@ impl<T: Ord> BTreeSet<T> {
326348
/// ```
327349
#[stable(feature = "rust1", since = "1.0.0")]
328350
pub fn intersection<'a>(&'a self, other: &'a BTreeSet<T>) -> Intersection<'a, T> {
329-
Intersection {
330-
a: self.iter().peekable(),
331-
b: other.iter().peekable(),
351+
if are_proportionate_for_intersection(self.len(), other.len()) {
352+
Intersection {
353+
a: self.iter().peekable(),
354+
b: IntersectionOther::Stitch(other.iter().peekable()),
355+
}
356+
} else if self.len() <= other.len() {
357+
Intersection {
358+
a: self.iter().peekable(),
359+
b: IntersectionOther::Search(&other),
360+
}
361+
} else {
362+
Intersection {
363+
a: other.iter().peekable(),
364+
b: IntersectionOther::Search(&self),
365+
}
332366
}
333367
}
334368

@@ -1069,6 +1103,14 @@ impl<'a, T: Ord> Iterator for SymmetricDifference<'a, T> {
10691103
#[stable(feature = "fused", since = "1.26.0")]
10701104
impl<T: Ord> FusedIterator for SymmetricDifference<'_, T> {}
10711105

1106+
impl<'a, T> Clone for IntersectionOther<'a, T> {
1107+
fn clone(&self) -> IntersectionOther<'a, T> {
1108+
match self {
1109+
IntersectionOther::Stitch(ref iter) => IntersectionOther::Stitch(iter.clone()),
1110+
IntersectionOther::Search(set) => IntersectionOther::Search(set),
1111+
}
1112+
}
1113+
}
10721114
#[stable(feature = "rust1", since = "1.0.0")]
10731115
impl<T> Clone for Intersection<'_, T> {
10741116
fn clone(&self) -> Self {
@@ -1083,24 +1125,36 @@ impl<'a, T: Ord> Iterator for Intersection<'a, T> {
10831125
type Item = &'a T;
10841126

10851127
fn next(&mut self) -> Option<&'a T> {
1086-
loop {
1087-
match Ord::cmp(self.a.peek()?, self.b.peek()?) {
1088-
Less => {
1089-
self.a.next();
1090-
}
1091-
Equal => {
1092-
self.b.next();
1093-
return self.a.next();
1128+
match self.b {
1129+
IntersectionOther::Stitch(ref mut self_b) => loop {
1130+
match Ord::cmp(self.a.peek()?, self_b.peek()?) {
1131+
Less => {
1132+
self.a.next();
1133+
}
1134+
Equal => {
1135+
self_b.next();
1136+
return self.a.next();
1137+
}
1138+
Greater => {
1139+
self_b.next();
1140+
}
10941141
}
1095-
Greater => {
1096-
self.b.next();
1142+
}
1143+
IntersectionOther::Search(set) => loop {
1144+
let e = self.a.next()?;
1145+
if set.contains(&e) {
1146+
return Some(e);
10971147
}
10981148
}
10991149
}
11001150
}
11011151

11021152
fn size_hint(&self) -> (usize, Option<usize>) {
1103-
(0, Some(min(self.a.len(), self.b.len())))
1153+
let b_len = match self.b {
1154+
IntersectionOther::Stitch(ref iter) => iter.len(),
1155+
IntersectionOther::Search(set) => set.len(),
1156+
};
1157+
(0, Some(min(self.a.len(), b_len)))
11041158
}
11051159
}
11061160

@@ -1140,3 +1194,21 @@ impl<'a, T: Ord> Iterator for Union<'a, T> {
11401194

11411195
#[stable(feature = "fused", since = "1.26.0")]
11421196
impl<T: Ord> FusedIterator for Union<'_, T> {}
1197+
1198+
#[cfg(test)]
1199+
mod tests {
1200+
use super::*;
1201+
1202+
#[test]
1203+
fn test_are_proportionate_for_intersection() {
1204+
assert!(are_proportionate_for_intersection(0, 0));
1205+
assert!(are_proportionate_for_intersection(0, 127));
1206+
assert!(!are_proportionate_for_intersection(0, 128));
1207+
assert!(are_proportionate_for_intersection(1, 255));
1208+
assert!(!are_proportionate_for_intersection(1, 256));
1209+
assert!(are_proportionate_for_intersection(127, 0));
1210+
assert!(!are_proportionate_for_intersection(128, 0));
1211+
assert!(are_proportionate_for_intersection(255, 1));
1212+
assert!(!are_proportionate_for_intersection(256, 1));
1213+
}
1214+
}

src/liballoc/tests/btree/set.rs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,19 @@ fn test_intersection() {
6969
check_intersection(&[11, 1, 3, 77, 103, 5, -5],
7070
&[2, 11, 77, -9, -42, 5, 3],
7171
&[3, 5, 11, 77]);
72+
73+
let mut large = [0i32; 512];
74+
for i in 0..512 {
75+
large[i] = i as i32
76+
}
77+
check_intersection(&large[..], &[], &[]);
78+
check_intersection(&large[..], &[-1], &[]);
79+
check_intersection(&large[..], &[42], &[42]);
80+
check_intersection(&large[..], &[4, 2], &[2, 4]);
81+
check_intersection(&[], &large[..], &[]);
82+
check_intersection(&[-1], &large[..], &[]);
83+
check_intersection(&[42], &large[..], &[42]);
84+
check_intersection(&[4, 2], &large[..], &[2, 4]);
7285
}
7386

7487
#[test]

0 commit comments

Comments
 (0)