Skip to content

Commit a379504

Browse files
committed
improve worst-case performance of BTreeSet intersection
1 parent 913ad6d commit a379504

File tree

3 files changed

+175
-72
lines changed

3 files changed

+175
-72
lines changed

src/liballoc/benches/btree/set.rs

+104-46
Original file line numberDiff line numberDiff line change
@@ -1,36 +1,35 @@
11
use std::collections::BTreeSet;
2+
use std::collections::btree_set::Intersection;
23

34
use rand::{thread_rng, Rng};
45
use test::{black_box, Bencher};
56

6-
fn random(n1: u32, n2: u32) -> [BTreeSet<usize>; 2] {
7+
// Benchmark fixture: returns two sets of random `usize` values holding exactly
// `n1` and `n2` elements (checked by the assertions at the end). The `u32`
// signature above and the `set1`/`set2` loops below are the removed version.
fn random(n1: usize, n2: usize) -> [BTreeSet<usize>; 2] {
78
let mut rng = thread_rng();
8-
let mut set1 = BTreeSet::new();
9-
let mut set2 = BTreeSet::new();
10-
for _ in 0..n1 {
11-
let i = rng.gen::<usize>();
12-
set1.insert(i);
13-
}
14-
for _ in 0..n2 {
15-
let i = rng.gen::<usize>();
16-
set2.insert(i);
9+
let mut sets = [BTreeSet::new(), BTreeSet::new()];
10+
for i in 0..2 {
11+
// Insert until the set reaches the requested length. The removed loops
// inserted a fixed number of times, so a repeated random value would have
// left the set short.
while sets[i].len() < [n1, n2][i] {
12+
sets[i].insert(rng.gen());
13+
}
1714
}
18-
[set1, set2]
15+
assert_eq!(sets[0].len(), n1);
16+
assert_eq!(sets[1].len(), n2);
17+
sets
1918
}
2019

21-
fn staggered(n1: u32, n2: u32) -> [BTreeSet<u32>; 2] {
22-
let mut even = BTreeSet::new();
23-
let mut odd = BTreeSet::new();
24-
for i in 0..n1 {
25-
even.insert(i * 2);
26-
}
27-
for i in 0..n2 {
28-
odd.insert(i * 2 + 1);
20+
// Benchmark fixture replacing the removed `staggered`: splits the integers
// `0..(n1 + n2)` into two sets with no common element, where `sets[0]` gets
// `n1` values and `sets[1]` gets `n2 = n1 * factor` values (both lengths are
// asserted before returning).
fn stagger(n1: usize, factor: usize) -> [BTreeSet<u32>; 2] {
21+
let n2 = n1 * factor;
22+
let mut sets = [BTreeSet::new(), BTreeSet::new()];
23+
for i in 0..(n1 + n2) {
24+
// Multiples of `factor + 1` go to `sets[0]` (`b == false`); every other
// value goes to `sets[1]`, so the two sets interleave.
let b = i % (factor + 1) != 0;
25+
sets[b as usize].insert(i as u32);
2926
}
30-
[even, odd]
27+
assert_eq!(sets[0].len(), n1);
28+
assert_eq!(sets[1].len(), n2);
29+
sets
3130
}
3231

33-
fn neg_vs_pos(n1: u32, n2: u32) -> [BTreeSet<i32>; 2] {
32+
fn neg_vs_pos(n1: usize, n2: usize) -> [BTreeSet<i32>; 2] {
3433
let mut neg = BTreeSet::new();
3534
let mut pos = BTreeSet::new();
3635
for i in -(n1 as i32)..=-1 {
@@ -39,22 +38,38 @@ fn neg_vs_pos(n1: u32, n2: u32) -> [BTreeSet<i32>; 2] {
3938
for i in 1..=(n2 as i32) {
4039
pos.insert(i);
4140
}
41+
assert_eq!(neg.len(), n1);
42+
assert_eq!(pos.len(), n2);
4243
[neg, pos]
4344
}
4445

45-
fn pos_vs_neg(n1: u32, n2: u32) -> [BTreeSet<i32>; 2] {
46-
let mut neg = BTreeSet::new();
47-
let mut pos = BTreeSet::new();
48-
for i in -(n1 as i32)..=-1 {
49-
neg.insert(i);
46+
// Benchmark fixture: mirror image of `neg_vs_pos`. It calls `neg_vs_pos` with
// the sizes swapped and reverses the pair, so `sets[0]` is the positive set
// with `n1` elements and `sets[1]` is the negative set with `n2` elements.
// NOTE(review): the removed version (signature above) returned `[pos, neg]`
// with `n2` positives and `n1` negatives, so the meaning of the arguments
// changed; the new assertions pin the new meaning.
fn pos_vs_neg(n1: usize, n2: usize) -> [BTreeSet<i32>; 2] {
47+
let mut sets = neg_vs_pos(n2, n1);
48+
sets.reverse();
49+
assert_eq!(sets[0].len(), n1);
50+
assert_eq!(sets[1].len(), n2);
51+
sets
52+
}
53+
54+
// Builds the `Search` variant of `Intersection` directly instead of going
// through `BTreeSet::intersection`, so the strategy is fixed by the caller:
// `sets[0]` is iterated and `sets[1]` is the set that gets queried.
fn intersection_search<T>(sets: &[BTreeSet<T>; 2]) -> Intersection<T>
55+
where T: std::cmp::Ord
56+
{
57+
Intersection::Search {
58+
a_iter: sets[0].iter(),
59+
b_set: &sets[1],
5060
}
51-
// (The two removed lines here belong to the body of the old `pos_vs_neg`,
// not to `intersection_search`.)
for i in 1..=(n2 as i32) {
52-
pos.insert(i);
61+
}
62+
63+
// Builds the `Stitch` variant of `Intersection` directly instead of going
// through `BTreeSet::intersection`, so the strategy is fixed by the caller:
// both `sets[0]` and `sets[1]` are walked through their iterators.
fn intersection_stitch<T>(sets: &[BTreeSet<T>; 2]) -> Intersection<T>
64+
where T: std::cmp::Ord
65+
{
66+
Intersection::Stitch {
67+
a_iter: sets[0].iter(),
68+
b_iter: sets[1].iter(),
5369
}
54-
// (The removed line here is the return value of the old `pos_vs_neg`.)
[pos, neg]
5570
}
5671

57-
macro_rules! set_intersection_bench {
72+
macro_rules! intersection_bench {
5873
($name: ident, $sets: expr) => {
5974
#[bench]
6075
pub fn $name(b: &mut Bencher) {
@@ -68,21 +83,64 @@ macro_rules! set_intersection_bench {
6883
})
6984
}
7085
};
86+
($name: ident, $sets: expr, $intersection_kind: ident) => {
87+
#[bench]
88+
pub fn $name(b: &mut Bencher) {
89+
// setup
90+
let sets = $sets;
91+
assert!(sets[0].len() >= 1);
92+
assert!(sets[1].len() >= sets[0].len());
93+
94+
// measure
95+
b.iter(|| {
96+
let x = $intersection_kind(&sets).count();
97+
black_box(x);
98+
})
99+
}
100+
};
71101
}
72102

73-
set_intersection_bench! {intersect_random_100, random(100, 100)}
74-
set_intersection_bench! {intersect_random_10k, random(10_000, 10_000)}
75-
set_intersection_bench! {intersect_random_10_vs_10k, random(10, 10_000)}
76-
set_intersection_bench! {intersect_random_10k_vs_10, random(10_000, 10)}
77-
set_intersection_bench! {intersect_staggered_100, staggered(100, 100)}
78-
set_intersection_bench! {intersect_staggered_10k, staggered(10_000, 10_000)}
79-
set_intersection_bench! {intersect_staggered_10_vs_10k, staggered(10, 10_000)}
80-
set_intersection_bench! {intersect_staggered_10k_vs_10, staggered(10_000, 10)}
81-
set_intersection_bench! {intersect_neg_vs_pos_100, neg_vs_pos(100, 100)}
82-
set_intersection_bench! {intersect_neg_vs_pos_10k, neg_vs_pos(10_000, 10_000)}
83-
set_intersection_bench! {intersect_neg_vs_pos_10_vs_10k,neg_vs_pos(10, 10_000)}
84-
set_intersection_bench! {intersect_neg_vs_pos_10k_vs_10,neg_vs_pos(10_000, 10)}
85-
set_intersection_bench! {intersect_pos_vs_neg_100, pos_vs_neg(100, 100)}
86-
set_intersection_bench! {intersect_pos_vs_neg_10k, pos_vs_neg(10_000, 10_000)}
87-
set_intersection_bench! {intersect_pos_vs_neg_10_vs_10k,pos_vs_neg(10, 10_000)}
88-
set_intersection_bench! {intersect_pos_vs_neg_10k_vs_10,pos_vs_neg(10_000, 10)}
103+
// Disjoint-range inputs (all negatives vs. all positives), every combination
// of 100 and 10k elements on either side. These use the two-argument macro
// arm only; that arm's body is not visible in this diff, presumably it calls
// `BTreeSet::intersection` (TODO confirm).
intersection_bench! {intersect_100_neg_vs_100_pos, neg_vs_pos(100, 100)}
104+
intersection_bench! {intersect_100_neg_vs_10k_pos, neg_vs_pos(100, 10_000)}
105+
intersection_bench! {intersect_100_pos_vs_100_neg, pos_vs_neg(100, 100)}
106+
intersection_bench! {intersect_100_pos_vs_10k_neg, pos_vs_neg(100, 10_000)}
107+
intersection_bench! {intersect_10k_neg_vs_100_pos, neg_vs_pos(10_000, 100)}
108+
intersection_bench! {intersect_10k_neg_vs_10k_pos, neg_vs_pos(10_000, 10_000)}
109+
intersection_bench! {intersect_10k_pos_vs_100_neg, pos_vs_neg(10_000, 100)}
110+
intersection_bench! {intersect_10k_pos_vs_10k_neg, pos_vs_neg(10_000, 10_000)}
111+
// From here on each input is benchmarked three ways: `_actual` (two-argument
// arm), `_search` and `_stitch` (three-argument arm, which calls the named
// constructor on the sets and counts the resulting iterator). The
// three-argument arm asserts that `sets[0]` is non-empty and not larger than
// `sets[1]`, so the smaller set always comes first below.
intersection_bench! {intersect_random_100_vs_100_actual,random(100, 100)}
112+
intersection_bench! {intersect_random_100_vs_100_search,random(100, 100), intersection_search}
113+
intersection_bench! {intersect_random_100_vs_100_stitch,random(100, 100), intersection_stitch}
114+
intersection_bench! {intersect_random_100_vs_10k_actual,random(100, 10_000)}
115+
intersection_bench! {intersect_random_100_vs_10k_search,random(100, 10_000), intersection_search}
116+
intersection_bench! {intersect_random_100_vs_10k_stitch,random(100, 10_000), intersection_stitch}
117+
intersection_bench! {intersect_random_10k_vs_10k_actual,random(10_000, 10_000)}
118+
intersection_bench! {intersect_random_10k_vs_10k_search,random(10_000, 10_000), intersection_search}
119+
intersection_bench! {intersect_random_10k_vs_10k_stitch,random(10_000, 10_000), intersection_stitch}
120+
// Interleaved sets of equal size (`factor == 1`): 100, 10k and 1 element each.
intersection_bench! {intersect_stagger_100_actual, stagger(100, 1)}
121+
intersection_bench! {intersect_stagger_100_search, stagger(100, 1), intersection_search}
122+
intersection_bench! {intersect_stagger_100_stitch, stagger(100, 1), intersection_stitch}
123+
intersection_bench! {intersect_stagger_10k_actual, stagger(10_000, 1)}
124+
intersection_bench! {intersect_stagger_10k_search, stagger(10_000, 1), intersection_search}
125+
intersection_bench! {intersect_stagger_10k_stitch, stagger(10_000, 1), intersection_stitch}
126+
intersection_bench! {intersect_stagger_1_actual, stagger(1, 1)}
127+
intersection_bench! {intersect_stagger_1_search, stagger(1, 1), intersection_search}
128+
intersection_bench! {intersect_stagger_1_stitch, stagger(1, 1), intersection_stitch}
129+
// `diffN`: the small set stays at 100 elements while the large set is
// `1 << N` times bigger (factors 2 through 64). These span both sides of the
// size-ratio test in `BTreeSet::intersection` (`a_set.len() > b_set.len() / 16`).
intersection_bench! {intersect_stagger_diff1_actual, stagger(100, 1 << 1)}
130+
intersection_bench! {intersect_stagger_diff1_search, stagger(100, 1 << 1), intersection_search}
131+
intersection_bench! {intersect_stagger_diff1_stitch, stagger(100, 1 << 1), intersection_stitch}
132+
intersection_bench! {intersect_stagger_diff2_actual, stagger(100, 1 << 2)}
133+
intersection_bench! {intersect_stagger_diff2_search, stagger(100, 1 << 2), intersection_search}
134+
intersection_bench! {intersect_stagger_diff2_stitch, stagger(100, 1 << 2), intersection_stitch}
135+
intersection_bench! {intersect_stagger_diff3_actual, stagger(100, 1 << 3)}
136+
intersection_bench! {intersect_stagger_diff3_search, stagger(100, 1 << 3), intersection_search}
137+
intersection_bench! {intersect_stagger_diff3_stitch, stagger(100, 1 << 3), intersection_stitch}
138+
intersection_bench! {intersect_stagger_diff4_actual, stagger(100, 1 << 4)}
139+
intersection_bench! {intersect_stagger_diff4_search, stagger(100, 1 << 4), intersection_search}
140+
intersection_bench! {intersect_stagger_diff4_stitch, stagger(100, 1 << 4), intersection_stitch}
141+
intersection_bench! {intersect_stagger_diff5_actual, stagger(100, 1 << 5)}
142+
intersection_bench! {intersect_stagger_diff5_search, stagger(100, 1 << 5), intersection_search}
143+
intersection_bench! {intersect_stagger_diff5_stitch, stagger(100, 1 << 5), intersection_stitch}
144+
intersection_bench! {intersect_stagger_diff6_actual, stagger(100, 1 << 6)}
145+
intersection_bench! {intersect_stagger_diff6_search, stagger(100, 1 << 6), intersection_search}
146+
intersection_bench! {intersect_stagger_diff6_stitch, stagger(100, 1 << 6), intersection_stitch}

src/liballoc/benches/lib.rs

+1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#![feature(repr_simd)]
22
#![feature(test)]
3+
#![feature(benches_btree_set)]
34

45
extern crate test;
56

src/liballoc/collections/btree/set.rs

+70-26
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
use core::borrow::Borrow;
55
use core::cmp::Ordering::{self, Less, Greater, Equal};
6-
use core::cmp::{min, max};
6+
use core::cmp::max;
77
use core::fmt::{self, Debug};
88
use core::iter::{Peekable, FromIterator, FusedIterator};
99
use core::ops::{BitOr, BitAnd, BitXor, Sub, RangeBounds};
@@ -163,18 +163,34 @@ impl<T: fmt::Debug> fmt::Debug for SymmetricDifference<'_, T> {
163163
/// [`BTreeSet`]: struct.BTreeSet.html
164164
/// [`intersection`]: struct.BTreeSet.html#method.intersection
165165
#[stable(feature = "rust1", since = "1.0.0")]
166-
pub struct Intersection<'a, T: 'a> {
167-
a: Peekable<Iter<'a, T>>,
168-
b: Peekable<Iter<'a, T>>,
166+
// The two-`Peekable` struct above is replaced by an enum with one variant per
// iteration strategy; `BTreeSet::intersection` picks the variant.
// NOTE(review): the variants of a `pub enum` (and their fields) are public.
// Only `#[doc(hidden)]` and the `benches_btree_set` unstable feature keep this
// implementation detail out of the stable API; a private inner enum wrapped in
// a struct would hide it completely, at the cost of the benches no longer
// being able to construct the variants directly.
pub enum Intersection<'a, T: 'a> {
167+
#[doc(hidden)]
168+
#[unstable(feature = "benches_btree_set", reason = "benchmarks for pull #58577", issue = "0")]
169+
// Walk both sets through their iterators, advancing whichever side is behind.
Stitch {
170+
a_iter: Iter<'a, T>, // for size_hint, should be the smaller of the sets
171+
b_iter: Iter<'a, T>,
172+
},
173+
#[doc(hidden)]
174+
#[unstable(feature = "benches_btree_set", reason = "benchmarks for pull #58577", issue = "0")]
175+
// Iterate one set and test each element for membership in the other set.
Search {
176+
a_iter: Iter<'a, T>, // for size_hint, should be the smaller of the sets
177+
b_set: &'a BTreeSet<T>,
178+
},
169179
}
170180

171181
#[stable(feature = "collection_debug", since = "1.17.0")]
172182
impl<T: fmt::Debug> fmt::Debug for Intersection<'_, T> {
173183
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
174-
f.debug_tuple("Intersection")
175-
.field(&self.a)
176-
.field(&self.b)
177-
.finish()
184+
// Both variants print as a tuple named "Intersection", so the variant name
// itself never appears in the output.
match self {
185+
// Stitch: two fields, one per iterator (same arity as the removed code).
Intersection::Stitch { a_iter, b_iter } => f
186+
.debug_tuple("Intersection")
187+
.field(&a_iter)
188+
.field(&b_iter)
189+
.finish(),
190+
// Search: only the iterator is printed; the queried set is ignored
// (`b_set: _`), so this variant's output has a single field.
Intersection::Search { a_iter, b_set: _ } => {
191+
f.debug_tuple("Intersection").field(&a_iter).finish()
192+
}
193+
}
178194
}
179195
}
180196

@@ -326,9 +342,22 @@ impl<T: Ord> BTreeSet<T> {
326342
/// ```
327343
#[stable(feature = "rust1", since = "1.0.0")]
328344
pub fn intersection<'a>(&'a self, other: &'a BTreeSet<T>) -> Intersection<'a, T> {
329-
Intersection {
330-
a: self.iter().peekable(),
331-
b: other.iter().peekable(),
345+
// Order the operands by size: `a_set` is the smaller set (or `self` on a
// tie), `b_set` the larger one.
// NOTE(review): the iterator yields references taken from `a_iter`, so the
// returned items now come from whichever set is smaller rather than always
// from `self`. This is only observable for types whose `Ord`-equal values
// can be told apart.
let (a_set, b_set) = if self.len() <= other.len() {
346+
(self, other)
347+
} else {
348+
(other, self)
349+
};
350+
// Choose the strategy from the size ratio: walk both sets in step unless the
// small set is at most 1/16 (integer division) of the large one. Two empty
// sets fail the `>` test and take the Search branch.
// NOTE(review): the constant 16 is not explained in the code; presumably it
// was tuned with the benchmarks added in the same commit.
if a_set.len() > b_set.len() / 16 {
351+
Intersection::Stitch {
352+
a_iter: a_set.iter(),
353+
b_iter: b_set.iter(),
354+
}
355+
} else {
356+
// Iterate small set only and find matches in large set.
357+
Intersection::Search {
358+
a_iter: a_set.iter(),
359+
b_set,
360+
}
332361
}
333362
}
334363

@@ -1072,9 +1101,15 @@ impl<T: Ord> FusedIterator for SymmetricDifference<'_, T> {}
10721101
#[stable(feature = "rust1", since = "1.0.0")]
10731102
impl<T> Clone for Intersection<'_, T> {
10741103
fn clone(&self) -> Self {
1075-
Intersection {
1076-
a: self.a.clone(),
1077-
b: self.b.clone(),
1104+
// Hand-written so that no `T: Clone` bound is needed (the impl header has
// none): only the iterators are cloned, never the elements. The clone keeps
// the same variant as the original.
match self {
1105+
Intersection::Stitch { a_iter, b_iter } => Intersection::Stitch {
1106+
a_iter: a_iter.clone(),
1107+
b_iter: b_iter.clone(),
1108+
},
1109+
// `b_set` is a shared reference to the set, so it is reused as-is.
Intersection::Search { a_iter, b_set } => Intersection::Search {
1110+
a_iter: a_iter.clone(),
1111+
b_set,
1112+
},
10781113
}
10791114
}
10801115
}
@@ -1083,24 +1118,33 @@ impl<'a, T: Ord> Iterator for Intersection<'a, T> {
10831118
type Item = &'a T;
10841119

10851120
fn next(&mut self) -> Option<&'a T> {
1086-
loop {
1087-
match Ord::cmp(self.a.peek()?, self.b.peek()?) {
1088-
Less => {
1089-
self.a.next();
1090-
}
1091-
Equal => {
1092-
self.b.next();
1093-
return self.a.next();
1094-
}
1095-
Greater => {
1096-
self.b.next();
1121+
// Returns the next element present in both sets, or `None` once either side
// is exhausted. The peek-based loop above is the removed implementation.
match self {
1122+
Intersection::Stitch { a_iter, b_iter } => {
1123+
// Each call starts by taking one fresh element from each iterator.
// Nothing is carried over between calls: the previous call either
// returned on a match, which consumed both elements, or hit the end of
// an iterator. That is why `Peekable` is no longer needed.
let mut a_next = a_iter.next()?;
1124+
// If `b_iter` is already exhausted, the element just taken from `a_iter`
// is dropped; it cannot be in the intersection anyway.
let mut b_next = b_iter.next()?;
1125+
loop {
1126+
match Ord::cmp(a_next, b_next) {
1127+
// Advance whichever side holds the smaller element; `?` ends the
// iteration as soon as that side runs out.
Less => a_next = a_iter.next()?,
1128+
Greater => b_next = b_iter.next()?,
1129+
Equal => return Some(a_next),
1130+
}
10971131
}
10981132
}
1133+
// Search: take elements from `a_iter` one by one and return the first one
// that `b_set` contains; `?` ends the iteration when `a_iter` is empty.
Intersection::Search { a_iter, b_set } => loop {
1134+
let a_next = a_iter.next()?;
1135+
if b_set.contains(&a_next) {
1136+
return Some(a_next);
1137+
}
1138+
},
10991139
}
11001140
}
11011141

11021142
fn size_hint(&self) -> (usize, Option<usize>) {
1103-
(0, Some(min(self.a.len(), self.b.len())))
1143+
// Lower bound is 0 (the sets may share nothing). The upper bound is the
// number of elements left in `a_iter`, which `BTreeSet::intersection` fills
// from the smaller set.
// NOTE(review): unlike the removed `min(...)` version, `b_iter` is not
// consulted in Stitch mode, so the bound can be looser once `b_iter` has
// fewer elements left than `a_iter`. It is still a valid upper bound.
let max_size = match self {
1144+
Intersection::Stitch { a_iter, .. } => a_iter.len(),
1145+
Intersection::Search { a_iter, .. } => a_iter.len(),
1146+
};
1147+
(0, Some(max_size))
11041148
}
11051149
}
11061150

0 commit comments

Comments
 (0)