diff --git a/src/libcollections/bitv.rs b/src/libcollections/bitv.rs index 234393ce561..59ad188b530 100644 --- a/src/libcollections/bitv.rs +++ b/src/libcollections/bitv.rs @@ -15,7 +15,8 @@ use core::prelude::*; use core::cmp; use core::default::Default; use core::fmt; -use core::iter::{Enumerate, Repeat, Map, Zip}; +use core::iter::{Map, Zip}; +use core::option; use core::ops; use core::slice; use core::uint; @@ -268,6 +269,33 @@ fn die() -> ! { fail!("Tried to do operation on bit vectors with different sizes"); } +enum WordsVariant<'a> { + NoneIter, + OneIter(option::Item<uint>), + VecIter(slice::Items<'a, uint>) +} + +struct Words<'a> { + rep: WordsVariant<'a>, + offset: uint +} + +impl<'a> Iterator<(uint, uint)> for Words<'a> { + /// Returns (offset, word) + fn next<'a>(&'a mut self) -> Option<(uint, uint)> { + let ret = match self.rep { + NoneIter => None, + OneIter(ref mut it) => it.next(), + VecIter(ref mut it) => it.next().map(|n| *n) + }; + self.offset += 1; + match ret { + Some(n) => Some((self.offset - 1, n)), + None => None + } + } +} + impl Bitv { #[inline] fn do_op(&mut self, op: Op, other: &Bitv) -> bool { @@ -295,6 +323,18 @@ impl Bitv { } } } + + #[inline] + fn words<'a>(&'a self, start: uint) -> Words<'a> { + Words { + rep: match self.rep { + Small(_) if start > 0 => NoneIter, + Small(ref s) => OneIter(Some(s.bits).move_iter()), + Big(ref b) => VecIter(b.storage.slice_from(start).iter()) + }, + offset: start + } + } } impl Bitv { @@ -687,15 +727,8 @@ impl<'a> RandomAccessIterator for Bits<'a> { /// It should also be noted that the amount of storage necessary for holding a /// set of objects is proportional to the maximum of the objects when viewed /// as a `uint`. 
-#[deriving(Clone)] -pub struct BitvSet { - size: uint, - - // In theory this is a `Bitv` instead of always a `BigBitv`, but knowing that - // there's an array of storage makes our lives a whole lot easier when - // performing union/intersection/etc operations - bitv: BigBitv -} +#[deriving(Clone, PartialEq, Eq)] +pub struct BitvSet(Bitv); impl Default for BitvSet { #[inline] @@ -705,56 +738,87 @@ impl Default for BitvSet { impl BitvSet { /// Creates a new bit vector set with initially no contents pub fn new() -> BitvSet { - BitvSet{ size: 0, bitv: BigBitv::new(vec!(0)) } + BitvSet(Bitv::new(0, false)) } /// Creates a new bit vector set from the given bit vector pub fn from_bitv(bitv: Bitv) -> BitvSet { - let mut size = 0; - bitv.ones(|_| { - size += 1; - true - }); - let Bitv{rep, ..} = bitv; - match rep { - Big(b) => BitvSet{ size: size, bitv: b }, - Small(SmallBitv{bits}) => - BitvSet{ size: size, bitv: BigBitv{ storage: vec!(bits) } }, - } + BitvSet(bitv) } /// Returns the capacity in bits for this bit vector. Inserting any /// element less than this amount will not trigger a resizing. 
- pub fn capacity(&self) -> uint { self.bitv.storage.len() * uint::BITS } + pub fn capacity(&self) -> uint { + let &BitvSet(ref bitv) = self; + match bitv.rep { + Small(_) => uint::BITS, + Big(ref s) => s.storage.len() * uint::BITS + } + } /// Consumes this set to return the underlying bit vector pub fn unwrap(self) -> Bitv { - let cap = self.capacity(); - let BitvSet{bitv, ..} = self; - return Bitv{ nbits:cap, rep: Big(bitv) }; + let BitvSet(bitv) = self; + bitv + } + + #[inline] + /// Grows the vector to be able to store bits with indices `[0, size - 1]` + fn grow(&mut self, size: uint) { + let &BitvSet(ref mut bitv) = self; + let small_to_big = match bitv.rep { Small(s) if size >= uint::BITS => Some(s.bits), _ => None }; + if small_to_big.is_some() { + bitv.rep = Big(BigBitv { storage: vec![small_to_big.unwrap()] }); + } + match bitv.rep { + Small(_) => {}, + Big(ref mut b) => { + let size = (size + uint::BITS - 1) / uint::BITS; + if b.storage.len() < size { + b.storage.grow(size, &0); + } + } + }; } #[inline] fn other_op(&mut self, other: &BitvSet, f: |uint, uint| -> uint) { - fn nbits(mut w: uint) -> uint { - let mut bits = 0; - for _ in range(0u, uint::BITS) { - if w == 0 { - break; + // Expand the vector if necessary + self.grow(other.capacity()); + // Unwrap Bitvs + let &BitvSet(ref mut self_bitv) = self; + let &BitvSet(ref other_bitv) = other; + for (i, w) in other_bitv.words(0) { + match self_bitv.rep { + Small(ref mut s) => { s.bits = f(s.bits, w); } + Big(ref mut b) => { + // Apply `f` exactly once per word, as the `Small` arm does; a second application is wrong for any `f` that is not idempotent (e.g. XOR). + let old = *b.storage.get(i); + let new = f(old, w); + *b.storage.get_mut(i) = new; } - bits += w & 1; - w >>= 1; } - return bits; } - if self.capacity() < other.capacity() { - self.bitv.storage.grow(other.capacity() / uint::BITS, &0); - } - for (i, &w) in other.bitv.storage.iter().enumerate() { - let old = *self.bitv.storage.get(i); - let new = f(old, w); - *self.bitv.storage.get_mut(i) = new; - self.size += nbits(new) - nbits(old); 
} + + #[inline] + /// Truncate the underlying vector to the least length required + pub fn shrink_to_fit(&mut self) { + let &BitvSet(ref mut bitv) = self; + // Two steps: we borrow b as immutable to get the length... + let old_len = match bitv.rep { + Small(_) => 1, + Big(ref b) => b.storage.len() + }; + // ...and as mutable to change it. + match bitv.rep { + Small(_) => {}, + Big(ref mut b) => { + let n = b.storage.iter().rev().take_while(|&&n| n == 0).count(); + let trunc_len = cmp::max(old_len - n, 1); + b.storage.truncate(trunc_len); + bitv.nbits = trunc_len * uint::BITS; + } } } @@ -818,29 +882,6 @@ impl BitvSet { } } -impl cmp::PartialEq for BitvSet { - fn eq(&self, other: &BitvSet) -> bool { - if self.size != other.size { - return false; - } - for (_, w1, w2) in self.commons(other) { - if w1 != w2 { - return false; - } - } - for (_, _, w) in self.outliers(other) { - if w != 0 { - return false; - } - } - return true; - } - - fn ne(&self, other: &BitvSet) -> bool { !self.eq(other) } -} - -impl cmp::Eq for BitvSet {} - impl fmt::Show for BitvSet { fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { try!(write!(fmt, "{{")); @@ -866,19 +907,26 @@ impl hash::Hash for BitvSet { impl Collection for BitvSet { #[inline] - fn len(&self) -> uint { self.size } + fn len(&self) -> uint { + let &BitvSet(ref bitv) = self; + match bitv.rep { + Small(ref s) => s.bits.count_ones(), + Big(ref b) => b.storage.iter().fold(0, |acc, &n| acc + n.count_ones()) + } + } } impl Mutable for BitvSet { fn clear(&mut self) { - self.bitv.each_storage(|w| { *w = 0; true }); - self.size = 0; + let &BitvSet(ref mut bitv) = self; + bitv.clear(); } } impl Set for BitvSet { fn contains(&self, value: &uint) -> bool { - *value < self.bitv.storage.len() * uint::BITS && self.bitv.get(*value) + let &BitvSet(ref bitv) = self; + *value < bitv.nbits && bitv.get(*value) } fn is_disjoint(&self, other: &BitvSet) -> bool { @@ -914,14 +962,15 @@ impl MutableSet for BitvSet { if self.contains(&value) { 
return false; } - let nbits = self.capacity(); - if value >= nbits { - let newsize = cmp::max(value, nbits * 2) / uint::BITS + 1; - assert!(newsize > self.bitv.storage.len()); - self.bitv.storage.grow(newsize, &0); + if value >= self.capacity() { + let new_cap = cmp::max(value + 1, self.capacity() * 2); + self.grow(new_cap); } - self.size += 1; - self.bitv.set(value, true); + let &BitvSet(ref mut bitv) = self; + if value >= bitv.nbits { + bitv.nbits = value + 1; + } + bitv.set(value, true); return true; } @@ -929,16 +978,8 @@ impl MutableSet for BitvSet { if !self.contains(value) { return false; } - self.size -= 1; - self.bitv.set(*value, false); - - // Attempt to truncate our storage - let mut i = self.bitv.storage.len(); - while i > 1 && *self.bitv.storage.get(i - 1) == 0 { - i -= 1; - } - self.bitv.storage.truncate(i); - + let &BitvSet(ref mut bitv) = self; + bitv.set(*value, false); return true; } } @@ -949,12 +990,12 @@ impl BitvSet { /// w1, w2) where the bit location is the number of bits offset so far, /// and w1/w2 are the words coming from the two vectors self, other. fn commons<'a>(&'a self, other: &'a BitvSet) - -> Map<'static, ((uint, &'a uint), &'a Vec<uint>), (uint, uint, uint), - Zip<Enumerate<slice::Items<'a, uint>>, Repeat<&'a Vec<uint>>>> { - let min = cmp::min(self.bitv.storage.len(), other.bitv.storage.len()); - self.bitv.storage.slice(0, min).iter().enumerate() - .zip(Repeat::new(&other.bitv.storage)) - .map(|((i, &w), o_store)| (i * uint::BITS, w, *o_store.get(i))) + -> Map<((uint, uint), (uint, uint)), (uint, uint, uint), + Zip<Words<'a>, Words<'a>>> { + let &BitvSet(ref self_bitv) = self; + let &BitvSet(ref other_bitv) = other; + self_bitv.words(0).zip(other_bitv.words(0)) + .map(|((i, w1), (_, w2))| (i * uint::BITS, w1, w2)) } /// Visits each word in `self` or `other` that extends beyond the other. This @@ -965,19 +1006,18 @@ impl BitvSet { /// is true if the word comes from `self`, and `false` if it comes from /// `other`. 
fn outliers<'a>(&'a self, other: &'a BitvSet) - -> Map<'static, ((uint, &'a uint), uint), (bool, uint, uint), - Zip<Enumerate<slice::Items<'a, uint>>, Repeat<uint>>> { - let slen = self.bitv.storage.len(); - let olen = other.bitv.storage.len(); + -> Map<(uint, uint), (bool, uint, uint), Words<'a>> { + let slen = self.capacity() / uint::BITS; + let olen = other.capacity() / uint::BITS; + let &BitvSet(ref self_bitv) = self; + let &BitvSet(ref other_bitv) = other; if olen < slen { - self.bitv.storage.slice_from(olen).iter().enumerate() - .zip(Repeat::new(olen)) - .map(|((i, &w), min)| (true, (i + min) * uint::BITS, w)) + self_bitv.words(olen) + .map(|(i, w)| (true, i * uint::BITS, w)) } else { - other.bitv.storage.slice_from(slen).iter().enumerate() - .zip(Repeat::new(slen)) - .map(|((i, &w), min)| (false, (i + min) * uint::BITS, w)) + other_bitv.words(slen) + .map(|(i, w)| (false, i * uint::BITS, w)) } } } @@ -1600,6 +1640,7 @@ mod tests { assert!(a.insert(1000)); assert!(a.remove(&1000)); + a.shrink_to_fit(); assert_eq!(a.capacity(), uint::BITS); }