Skip to content

Fix BitvSet union_with and symmetric_difference_with #17558

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 9, 2014
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
188 changes: 131 additions & 57 deletions src/libcollections/bitv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -235,11 +235,20 @@ impl Bitv {
/// }
/// ```
pub fn with_capacity(nbits: uint, init: bool) -> Bitv {
Bitv {
let mut bitv = Bitv {
storage: Vec::from_elem((nbits + uint::BITS - 1) / uint::BITS,
if init { !0u } else { 0u }),
nbits: nbits
};

// Zero out any unused bits in the highest word if necessary
let used_bits = bitv.nbits % uint::BITS;
if init && used_bits != 0 {
let largest_used_word = (bitv.nbits + uint::BITS - 1) / uint::BITS - 1;
*bitv.storage.get_mut(largest_used_word) &= (1 << used_bits) - 1;
}

bitv
}

/// Retrieves the value at index `i`.
Expand Down Expand Up @@ -629,9 +638,9 @@ impl Bitv {
/// ```
pub fn reserve(&mut self, size: uint) {
let old_size = self.storage.len();
let size = (size + uint::BITS - 1) / uint::BITS;
if old_size < size {
self.storage.grow(size - old_size, 0);
let new_size = (size + uint::BITS - 1) / uint::BITS;
if old_size < new_size {
self.storage.grow(new_size - old_size, 0);
}
}

Expand Down Expand Up @@ -686,8 +695,15 @@ impl Bitv {
}
// Allocate new words, if needed
if new_nwords > self.storage.len() {
let to_add = new_nwords - self.storage.len();
self.storage.grow(to_add, full_value);
let to_add = new_nwords - self.storage.len();
self.storage.grow(to_add, full_value);

// Zero out and unused bits in the new tail word
if value {
let tail_word = new_nwords - 1;
let used_bits = new_nbits % uint::BITS;
*self.storage.get_mut(tail_word) &= (1 << used_bits) - 1;
}
}
// Adjust internal bit count
self.nbits = new_nbits;
Expand Down Expand Up @@ -970,9 +986,8 @@ impl<'a> RandomAccessIterator<bool> for Bits<'a> {
/// }
///
/// // Can convert back to a `Bitv`
/// let bv: Bitv = s.unwrap();
/// assert!(bv.eq_vec([true, true, false, true,
/// false, false, false, false]));
/// let bv: Bitv = s.into_bitv();
/// assert!(bv.get(3));
/// ```
#[deriving(Clone)]
pub struct BitvSet(Bitv);
Expand All @@ -993,7 +1008,8 @@ impl FromIterator<bool> for BitvSet {
impl Extendable<bool> for BitvSet {
#[inline]
fn extend<I: Iterator<bool>>(&mut self, iterator: I) {
self.get_mut_ref().extend(iterator);
let &BitvSet(ref mut self_bitv) = self;
self_bitv.extend(iterator);
}
}

Expand Down Expand Up @@ -1049,7 +1065,8 @@ impl BitvSet {
/// ```
#[inline]
pub fn with_capacity(nbits: uint) -> BitvSet {
BitvSet(Bitv::with_capacity(nbits, false))
let bitv = Bitv::with_capacity(nbits, false);
BitvSet::from_bitv(bitv)
}

/// Creates a new bit vector set from the given bit vector.
Expand All @@ -1068,7 +1085,9 @@ impl BitvSet {
/// }
/// ```
#[inline]
pub fn from_bitv(bitv: Bitv) -> BitvSet {
pub fn from_bitv(mut bitv: Bitv) -> BitvSet {
// Mark every bit as valid
bitv.nbits = bitv.capacity();
BitvSet(bitv)
}

Expand Down Expand Up @@ -1102,7 +1121,10 @@ impl BitvSet {
/// ```
pub fn reserve(&mut self, size: uint) {
let &BitvSet(ref mut bitv) = self;
bitv.reserve(size)
bitv.reserve(size);
if bitv.nbits < size {
bitv.nbits = bitv.capacity();
}
}

/// Consumes this set to return the underlying bit vector.
Expand All @@ -1116,11 +1138,12 @@ impl BitvSet {
/// s.insert(0);
/// s.insert(3);
///
/// let bv = s.unwrap();
/// assert!(bv.eq_vec([true, false, false, true]));
/// let bv = s.into_bitv();
/// assert!(bv.get(0));
/// assert!(bv.get(3));
/// ```
#[inline]
pub fn unwrap(self) -> Bitv {
pub fn into_bitv(self) -> Bitv {
let BitvSet(bitv) = self;
bitv
}
Expand All @@ -1144,38 +1167,15 @@ impl BitvSet {
bitv
}

/// Returns a mutable reference to the underlying bit vector.
///
/// # Example
///
/// ```
/// use std::collections::BitvSet;
///
/// let mut s = BitvSet::new();
/// s.insert(0);
/// assert_eq!(s.contains(&0), true);
/// {
/// // Will free the set during bv's lifetime
/// let bv = s.get_mut_ref();
/// bv.set(0, false);
/// }
/// assert_eq!(s.contains(&0), false);
/// ```
#[inline]
pub fn get_mut_ref<'a>(&'a mut self) -> &'a mut Bitv {
let &BitvSet(ref mut bitv) = self;
bitv
}

#[inline]
fn other_op(&mut self, other: &BitvSet, f: |uint, uint| -> uint) {
// Expand the vector if necessary
self.reserve(other.capacity());

// Unwrap Bitvs
let &BitvSet(ref mut self_bitv) = self;
let &BitvSet(ref other_bitv) = other;

// Expand the vector if necessary
self_bitv.reserve(other_bitv.capacity());

// virtually pad other with 0's for equal lengths
let mut other_words = {
let (_, result) = match_words(self_bitv, other_bitv);
Expand Down Expand Up @@ -1376,9 +1376,10 @@ impl BitvSet {
///
/// let mut a = BitvSet::from_bitv(bitv::from_bytes([a]));
/// let b = BitvSet::from_bitv(bitv::from_bytes([b]));
/// let res = BitvSet::from_bitv(bitv::from_bytes([res]));
///
/// a.union_with(&b);
/// assert_eq!(a.unwrap(), bitv::from_bytes([res]));
/// assert_eq!(a, res);
/// ```
#[inline]
pub fn union_with(&mut self, other: &BitvSet) {
Expand All @@ -1399,9 +1400,10 @@ impl BitvSet {
///
/// let mut a = BitvSet::from_bitv(bitv::from_bytes([a]));
/// let b = BitvSet::from_bitv(bitv::from_bytes([b]));
/// let res = BitvSet::from_bitv(bitv::from_bytes([res]));
///
/// a.intersect_with(&b);
/// assert_eq!(a.unwrap(), bitv::from_bytes([res]));
/// assert_eq!(a, res);
/// ```
#[inline]
pub fn intersect_with(&mut self, other: &BitvSet) {
Expand All @@ -1424,15 +1426,17 @@ impl BitvSet {
///
/// let mut bva = BitvSet::from_bitv(bitv::from_bytes([a]));
/// let bvb = BitvSet::from_bitv(bitv::from_bytes([b]));
/// let bva_b = BitvSet::from_bitv(bitv::from_bytes([a_b]));
/// let bvb_a = BitvSet::from_bitv(bitv::from_bytes([b_a]));
///
/// bva.difference_with(&bvb);
/// assert_eq!(bva.unwrap(), bitv::from_bytes([a_b]));
/// assert_eq!(bva, bva_b);
///
/// let bva = BitvSet::from_bitv(bitv::from_bytes([a]));
/// let mut bvb = BitvSet::from_bitv(bitv::from_bytes([b]));
///
/// bvb.difference_with(&bva);
/// assert_eq!(bvb.unwrap(), bitv::from_bytes([b_a]));
/// assert_eq!(bvb, bvb_a);
/// ```
#[inline]
pub fn difference_with(&mut self, other: &BitvSet) {
Expand All @@ -1454,9 +1458,10 @@ impl BitvSet {
///
/// let mut a = BitvSet::from_bitv(bitv::from_bytes([a]));
/// let b = BitvSet::from_bitv(bitv::from_bytes([b]));
/// let res = BitvSet::from_bitv(bitv::from_bytes([res]));
///
/// a.symmetric_difference_with(&b);
/// assert_eq!(a.unwrap(), bitv::from_bytes([res]));
/// assert_eq!(a, res);
/// ```
#[inline]
pub fn symmetric_difference_with(&mut self, other: &BitvSet) {
Expand Down Expand Up @@ -1538,20 +1543,14 @@ impl MutableSet<uint> for BitvSet {
if self.contains(&value) {
return false;
}

// Ensure we have enough space to hold the new element
if value >= self.capacity() {
let new_cap = cmp::max(value + 1, self.capacity() * 2);
self.reserve(new_cap);
}

let &BitvSet(ref mut bitv) = self;
if value >= bitv.nbits {
// If we are increasing nbits, make sure we mask out any previously-unconsidered bits
let old_rem = bitv.nbits % uint::BITS;
if old_rem != 0 {
let old_last_word = (bitv.nbits + uint::BITS - 1) / uint::BITS - 1;
*bitv.storage.get_mut(old_last_word) &= (1 << old_rem) - 1;
}
bitv.nbits = value + 1;
}
bitv.set(value, true);
return true;
}
Expand Down Expand Up @@ -2225,14 +2224,15 @@ mod tests {
assert!(a.insert(160));
assert!(a.insert(19));
assert!(a.insert(24));
assert!(a.insert(200));

assert!(b.insert(1));
assert!(b.insert(5));
assert!(b.insert(9));
assert!(b.insert(13));
assert!(b.insert(19));

let expected = [1, 3, 5, 9, 11, 13, 19, 24, 160];
let expected = [1, 3, 5, 9, 11, 13, 19, 24, 160, 200];
let actual = a.union(&b).collect::<Vec<uint>>();
assert_eq!(actual.as_slice(), expected.as_slice());
}
Expand Down Expand Up @@ -2281,6 +2281,27 @@ mod tests {
assert!(c.is_disjoint(&b))
}

#[test]
fn test_bitv_set_union_with() {
//a should grow to include larger elements
let mut a = BitvSet::new();
a.insert(0);
let mut b = BitvSet::new();
b.insert(5);
let expected = BitvSet::from_bitv(from_bytes([0b10000100]));
a.union_with(&b);
assert_eq!(a, expected);

// Standard
let mut a = BitvSet::from_bitv(from_bytes([0b10100010]));
let mut b = BitvSet::from_bitv(from_bytes([0b01100010]));
let c = a.clone();
a.union_with(&b);
b.union_with(&c);
assert_eq!(a.len(), 4);
assert_eq!(b.len(), 4);
}

#[test]
fn test_bitv_set_intersect_with() {
// Explicitly 0'ed bits
Expand Down Expand Up @@ -2311,6 +2332,59 @@ mod tests {
assert_eq!(b.len(), 2);
}

#[test]
fn test_bitv_set_difference_with() {
// Explicitly 0'ed bits
let mut a = BitvSet::from_bitv(from_bytes([0b00000000]));
let b = BitvSet::from_bitv(from_bytes([0b10100010]));
a.difference_with(&b);
assert!(a.is_empty());

// Uninitialized bits should behave like 0's
let mut a = BitvSet::new();
let b = BitvSet::from_bitv(from_bytes([0b11111111]));
a.difference_with(&b);
assert!(a.is_empty());

// Standard
let mut a = BitvSet::from_bitv(from_bytes([0b10100010]));
let mut b = BitvSet::from_bitv(from_bytes([0b01100010]));
let c = a.clone();
a.difference_with(&b);
b.difference_with(&c);
assert_eq!(a.len(), 1);
assert_eq!(b.len(), 1);
}

#[test]
fn test_bitv_set_symmetric_difference_with() {
//a should grow to include larger elements
let mut a = BitvSet::new();
a.insert(0);
a.insert(1);
let mut b = BitvSet::new();
b.insert(1);
b.insert(5);
let expected = BitvSet::from_bitv(from_bytes([0b10000100]));
a.symmetric_difference_with(&b);
assert_eq!(a, expected);

let mut a = BitvSet::from_bitv(from_bytes([0b10100010]));
let b = BitvSet::new();
let c = a.clone();
a.symmetric_difference_with(&b);
assert_eq!(a, c);

// Standard
let mut a = BitvSet::from_bitv(from_bytes([0b11100010]));
let mut b = BitvSet::from_bitv(from_bytes([0b01101010]));
let c = a.clone();
a.symmetric_difference_with(&b);
b.symmetric_difference_with(&c);
assert_eq!(a.len(), 2);
assert_eq!(b.len(), 2);
}

#[test]
fn test_bitv_set_eq() {
let a = BitvSet::from_bitv(from_bytes([0b10100010]));
Expand Down