@@ -155,6 +155,22 @@ impl<T: fmt::Debug> fmt::Debug for SymmetricDifference<'_, T> {
155
155
}
156
156
}
157
157
158
+ /// Whether the sizes of two sets are roughly the same order of magnitude.
159
+ ///
160
+ /// If they are, or if either set is empty, then their intersection
161
+ /// is efficiently calculated by iterating both sets jointly.
162
+ /// If they aren't, then it is more scalable to iterate over the small set
163
+ /// and find matches in the large set (except if the largest element in
164
+ /// the small set hardly surpasses the smallest element in the large set).
165
+ fn are_proportionate_for_intersection ( len1 : usize , len2 : usize ) -> bool {
166
+ let ( small, large) = if len1 <= len2 {
167
+ ( len1, len2)
168
+ } else {
169
+ ( len2, len1)
170
+ } ;
171
+ ( large >> 7 ) <= small
172
+ }
173
+
158
174
/// A lazy iterator producing elements in the intersection of `BTreeSet`s.
159
175
///
160
176
/// This `struct` is created by the [`intersection`] method on [`BTreeSet`].
@@ -165,7 +181,13 @@ impl<T: fmt::Debug> fmt::Debug for SymmetricDifference<'_, T> {
165
181
#[ stable( feature = "rust1" , since = "1.0.0" ) ]
166
182
pub struct Intersection < ' a , T : ' a > {
167
183
a : Peekable < Iter < ' a , T > > ,
168
- b : Peekable < Iter < ' a , T > > ,
184
+ b : IntersectionOther < ' a , T > ,
185
+ }
186
+
187
+ #[ derive( Debug ) ]
188
+ enum IntersectionOther < ' a , T > {
189
+ Stitch ( Peekable < Iter < ' a , T > > ) ,
190
+ Search ( & ' a BTreeSet < T > ) ,
169
191
}
170
192
171
193
#[ stable( feature = "collection_debug" , since = "1.17.0" ) ]
@@ -326,9 +348,21 @@ impl<T: Ord> BTreeSet<T> {
326
348
/// ```
327
349
#[ stable( feature = "rust1" , since = "1.0.0" ) ]
328
350
pub fn intersection < ' a > ( & ' a self , other : & ' a BTreeSet < T > ) -> Intersection < ' a , T > {
329
- Intersection {
330
- a : self . iter ( ) . peekable ( ) ,
331
- b : other. iter ( ) . peekable ( ) ,
351
+ if are_proportionate_for_intersection ( self . len ( ) , other. len ( ) ) {
352
+ Intersection {
353
+ a : self . iter ( ) . peekable ( ) ,
354
+ b : IntersectionOther :: Stitch ( other. iter ( ) . peekable ( ) ) ,
355
+ }
356
+ } else if self . len ( ) <= other. len ( ) {
357
+ Intersection {
358
+ a : self . iter ( ) . peekable ( ) ,
359
+ b : IntersectionOther :: Search ( & other) ,
360
+ }
361
+ } else {
362
+ Intersection {
363
+ a : other. iter ( ) . peekable ( ) ,
364
+ b : IntersectionOther :: Search ( & self ) ,
365
+ }
332
366
}
333
367
}
334
368
@@ -1069,6 +1103,14 @@ impl<'a, T: Ord> Iterator for SymmetricDifference<'a, T> {
1069
1103
#[ stable( feature = "fused" , since = "1.26.0" ) ]
1070
1104
impl < T : Ord > FusedIterator for SymmetricDifference < ' _ , T > { }
1071
1105
1106
+ impl < ' a , T > Clone for IntersectionOther < ' a , T > {
1107
+ fn clone ( & self ) -> IntersectionOther < ' a , T > {
1108
+ match self {
1109
+ IntersectionOther :: Stitch ( ref iter) => IntersectionOther :: Stitch ( iter. clone ( ) ) ,
1110
+ IntersectionOther :: Search ( set) => IntersectionOther :: Search ( set) ,
1111
+ }
1112
+ }
1113
+ }
1072
1114
#[ stable( feature = "rust1" , since = "1.0.0" ) ]
1073
1115
impl < T > Clone for Intersection < ' _ , T > {
1074
1116
fn clone ( & self ) -> Self {
@@ -1083,24 +1125,36 @@ impl<'a, T: Ord> Iterator for Intersection<'a, T> {
1083
1125
type Item = & ' a T ;
1084
1126
1085
1127
fn next ( & mut self ) -> Option < & ' a T > {
1086
- loop {
1087
- match Ord :: cmp ( self . a . peek ( ) ?, self . b . peek ( ) ?) {
1088
- Less => {
1089
- self . a . next ( ) ;
1090
- }
1091
- Equal => {
1092
- self . b . next ( ) ;
1093
- return self . a . next ( ) ;
1128
+ match self . b {
1129
+ IntersectionOther :: Stitch ( ref mut self_b) => loop {
1130
+ match Ord :: cmp ( self . a . peek ( ) ?, self_b. peek ( ) ?) {
1131
+ Less => {
1132
+ self . a . next ( ) ;
1133
+ }
1134
+ Equal => {
1135
+ self_b. next ( ) ;
1136
+ return self . a . next ( ) ;
1137
+ }
1138
+ Greater => {
1139
+ self_b. next ( ) ;
1140
+ }
1094
1141
}
1095
- Greater => {
1096
- self . b . next ( ) ;
1142
+ }
1143
+ IntersectionOther :: Search ( set) => loop {
1144
+ let e = self . a . next ( ) ?;
1145
+ if set. contains ( & e) {
1146
+ return Some ( e) ;
1097
1147
}
1098
1148
}
1099
1149
}
1100
1150
}
1101
1151
1102
1152
fn size_hint ( & self ) -> ( usize , Option < usize > ) {
1103
- ( 0 , Some ( min ( self . a . len ( ) , self . b . len ( ) ) ) )
1153
+ let b_len = match self . b {
1154
+ IntersectionOther :: Stitch ( ref iter) => iter. len ( ) ,
1155
+ IntersectionOther :: Search ( set) => set. len ( ) ,
1156
+ } ;
1157
+ ( 0 , Some ( min ( self . a . len ( ) , b_len) ) )
1104
1158
}
1105
1159
}
1106
1160
@@ -1140,3 +1194,21 @@ impl<'a, T: Ord> Iterator for Union<'a, T> {
1140
1194
1141
1195
#[ stable( feature = "fused" , since = "1.26.0" ) ]
1142
1196
impl < T : Ord > FusedIterator for Union < ' _ , T > { }
1197
+
1198
+ #[ cfg( test) ]
1199
+ mod tests {
1200
+ use super :: * ;
1201
+
1202
+ #[ test]
1203
+ fn test_are_proportionate_for_intersection ( ) {
1204
+ assert ! ( are_proportionate_for_intersection( 0 , 0 ) ) ;
1205
+ assert ! ( are_proportionate_for_intersection( 0 , 127 ) ) ;
1206
+ assert ! ( !are_proportionate_for_intersection( 0 , 128 ) ) ;
1207
+ assert ! ( are_proportionate_for_intersection( 1 , 255 ) ) ;
1208
+ assert ! ( !are_proportionate_for_intersection( 1 , 256 ) ) ;
1209
+ assert ! ( are_proportionate_for_intersection( 127 , 0 ) ) ;
1210
+ assert ! ( !are_proportionate_for_intersection( 128 , 0 ) ) ;
1211
+ assert ! ( are_proportionate_for_intersection( 255 , 1 ) ) ;
1212
+ assert ! ( !are_proportionate_for_intersection( 256 , 1 ) ) ;
1213
+ }
1214
+ }
0 commit comments