diff --git a/src/libcollections/bitv.rs b/src/libcollections/bitv.rs
index 20d7c3ef2cf..b480b88b4d4 100644
--- a/src/libcollections/bitv.rs
+++ b/src/libcollections/bitv.rs
@@ -24,22 +24,6 @@ use std::hash;
 use {Collection, Mutable, Set, MutableSet};
 use vec::Vec;
 
-/**
- * A mask that has a 1 for each defined bit in the n'th element of a `Bitv`,
- * assuming n bits.
- */
-#[inline]
-fn big_mask(nbits: uint, elem: uint) -> uint {
-    let rmd = nbits % uint::BITS;
-    let nelems = (nbits + uint::BITS - 1) / uint::BITS;
-
-    if elem < nelems - 1 || rmd == 0 {
-        !0
-    } else {
-        (1 << rmd) - 1
-    }
-}
-
 /// The bitvector type
 ///
 /// # Example
@@ -75,35 +59,54 @@ pub struct Bitv {
     nbits: uint
 }
 
-struct Words<'a> {
+/// Iterator over the words of a `Bitv`, yielding `(offset, word)` pairs;
+/// the final word yielded is ANDed with `last_word_mask`.
+struct MaskWords<'a> {
     iter: slice::Items<'a, uint>,
+    next_word: Option<&'a uint>,
+    last_word_mask: uint,
     offset: uint
 }
 
-impl<'a> Iterator<(uint, uint)> for Words<'a> {
+impl<'a> Iterator<(uint, uint)> for MaskWords<'a> {
     /// Returns (offset, word)
     fn next<'a>(&'a mut self) -> Option<(uint, uint)> {
-        let ret = self.iter.next().map(|&n| (self.offset, n));
-        self.offset += 1;
-        ret
+        let ret = self.next_word;
+        match ret {
+            Some(&w) => {
+                self.next_word = self.iter.next();
+                self.offset += 1;
+                // The last word may need to be masked
+                if self.next_word.is_none() {
+                    Some((self.offset - 1, w & self.last_word_mask))
+                } else {
+                    Some((self.offset - 1, w))
+                }
+            },
+            None => None
+        }
     }
 }
 
 impl Bitv {
     #[inline]
-    fn process(&mut self, other: &Bitv, nbits: uint,
-               op: |uint, uint| -> uint) -> bool {
+    fn process(&mut self, other: &Bitv, op: |uint, uint| -> uint) -> bool {
         let len = other.storage.len();
         assert_eq!(self.storage.len(), len);
         let mut changed = false;
-        for (i, (a, b)) in self.storage.mut_iter()
-                            .zip(other.storage.iter())
-                            .enumerate() {
-            let mask = big_mask(nbits, i);
-            let w0 = *a & mask;
-            let w1 = *b & mask;
-            let w = op(w0, w1) & mask;
-            if w0 != w {
+        // Notice: `a` is *not* masked before applying `op`, which is fine
+        // as long as `op` is a bitwise operation, since any bits that
+        // should've been masked were fine to change anyway. `b` is masked
+        // to make sure its unmasked bits do not cause damage. Only a
+        // difference in the *defined* bits of `a` counts as a change.
+        let rem = self.nbits % uint::BITS;
+        let last_mask = if rem > 0 { (1 << rem) - 1 } else { !0u };
+        for (i, (a, (_, b))) in self.storage.mut_iter()
+                                .zip(other.mask_words(0))
+                                .enumerate() {
+            let mask = if i + 1 == len { last_mask } else { !0u };
+            let w = op(*a, b);
+            if (*a ^ w) & mask != 0 {
                 changed = true;
                 *a = w;
             }
@@ -112,10 +115,22 @@ impl Bitv {
     }
 
     #[inline]
-    #[inline]
-    fn words<'a>(&'a self, start: uint) -> Words<'a> {
-        Words {
-            iter: self.storage.slice_from(start).iter(),
+    /// Iterates over `(offset, word)` pairs starting at word `start`
+    /// (clamped to the storage length), masking the last word to `nbits`.
+    fn mask_words<'a>(&'a self, mut start: uint) -> MaskWords<'a> {
+        if start > self.storage.len() {
+            start = self.storage.len();
+        }
+        let mut iter = self.storage.slice_from(start).iter();
+        MaskWords {
+            next_word: iter.next(),
+            iter: iter,
+            last_word_mask: {
+                let rem = self.nbits % uint::BITS;
+                if rem > 0 {
+                    (1 << rem) - 1
+                } else { !0 }
+            },
             offset: start
         }
     }
@@ -124,15 +139,8 @@ impl Bitv {
     /// to `init`.
     pub fn new(nbits: uint, init: bool) -> Bitv {
         Bitv {
-            storage: {
-                let nelems = (nbits + uint::BITS - 1) / uint::BITS;
-                let mut v = Vec::from_elem(nelems, if init { !0u } else { 0u });
-                // Zero out any remainder bits
-                if nbits % uint::BITS > 0 {
-                    *v.get_mut(nelems - 1) &= (1 << nbits % uint::BITS) - 1;
-                }
-                v
-            },
+            storage: Vec::from_elem((nbits + uint::BITS - 1) / uint::BITS,
+                                    if init { !0u } else { 0u }),
             nbits: nbits
         }
     }
@@ -145,8 +153,7 @@ impl Bitv {
      */
     #[inline]
     pub fn union(&mut self, other: &Bitv) -> bool {
-        let nbits = self.nbits;
-        self.process(other, nbits, |w1, w2| w1 | w2)
+        self.process(other, |w1, w2| w1 | w2)
     }
 
     /**
@@ -157,8 +164,7 @@ impl Bitv {
      */
     #[inline]
     pub fn intersect(&mut self, other: &Bitv) -> bool {
-        let nbits = self.nbits;
-        self.process(other, nbits, |w1, w2| w1 & w2)
+        self.process(other, |w1, w2| w1 & w2)
     }
 
     /**
@@ -169,8 +175,7 @@ impl Bitv {
      */
     #[inline]
     pub fn assign(&mut self, other: &Bitv) -> bool {
-        let nbits = self.nbits;
-        self.process(other, nbits, |_, w| w)
+        self.process(other, |_, w| w)
     }
 
     /// Retrieve the value at index `i`
@@ -227,20 +232,21 @@ impl Bitv {
      */
     #[inline]
     pub fn difference(&mut self, other: &Bitv) -> bool {
-        let nbits = self.nbits;
-        self.process(other, nbits, |w1, w2| w1 & !w2)
+        self.process(other, |w1, w2| w1 & !w2)
     }
 
     /// Returns `true` if all bits are 1
     #[inline]
     pub fn all(&self) -> bool {
-        for (i, &elem) in self.storage.iter().enumerate() {
-            let mask = big_mask(self.nbits, i);
-            if elem & mask != mask {
-                return false;
-            }
-        }
-        true
+        let mut last_word = !0u;
+        // Check that every word but the last is all-ones...
+        self.mask_words(0).all(|(_, elem)|
+            { let tmp = last_word; last_word = elem; tmp == !0u }) &&
+        // ...and that the last word is ones as far as it needs to be.
+        // With no remainder the last word is full, so only `!0u` passes;
+        // comparing against `(1 << 0) - 1 == 0` would accept an all-zero word.
+        (last_word == !0u || (self.nbits % uint::BITS != 0 &&
+                              last_word == (1 << self.nbits % uint::BITS) - 1))
     }
 
     /// Returns an iterator over the elements of the vector in order.
@@ -265,13 +271,7 @@ impl Bitv {
 
     /// Returns `true` if all bits are 0
     pub fn none(&self) -> bool {
-        for (i, &elem) in self.storage.iter().enumerate() {
-            let mask = big_mask(self.nbits, i);
-            if elem & mask != 0 {
-                return false;
-            }
-        }
-        true
+        self.mask_words(0).all(|(_, w)| w == 0)
     }
 
     #[inline]
@@ -397,8 +397,8 @@ impl fmt::Show for Bitv {
 impl<S: hash::Writer> hash::Hash<S> for Bitv {
     fn hash(&self, state: &mut S) {
         self.nbits.hash(state);
-        for (i, elem) in self.storage.iter().enumerate() {
-            (elem & big_mask(self.nbits, i)).hash(state);
+        for (_, elem) in self.mask_words(0) {
+            elem.hash(state);
         }
     }
 }
@@ -409,13 +409,7 @@ impl cmp::PartialEq for Bitv {
         if self.nbits != other.nbits {
             return false;
         }
-        for (i, (&w1, &w2)) in self.storage.iter().zip(other.storage.iter()).enumerate() {
-            let mask = big_mask(self.nbits, i);
-            if w1 & mask != w2 & mask {
-                return false;
-            }
-        }
-        true
+        self.mask_words(0).zip(other.mask_words(0)).all(|((_, w1), (_, w2))| w1 == w2)
     }
 }
 
@@ -546,7 +540,7 @@ impl BitvSet {
         // Unwrap Bitvs
         let &BitvSet(ref mut self_bitv) = self;
         let &BitvSet(ref other_bitv) = other;
-        for (i, w) in other_bitv.words(0) {
+        for (i, w) in other_bitv.mask_words(0) {
             let old = *self_bitv.storage.get(i);
             let new = f(old, w);
             *self_bitv.storage.get_mut(i) = new;
@@ -563,7 +557,7 @@ impl BitvSet {
         let n = bitv.storage.iter().rev().take_while(|&&n| n == 0).count();
         // Truncate
         let trunc_len = cmp::max(old_len - n, 1);
-        bitv.storage.truncate(cmp::max(old_len - n, 1));
+        bitv.storage.truncate(trunc_len);
         bitv.nbits = trunc_len * uint::BITS;
     }
 
@@ -710,6 +704,12 @@ impl MutableSet for BitvSet {
         }
         let &BitvSet(ref mut bitv) = self;
         if value >= bitv.nbits {
+            // If we are increasing nbits, make sure we mask out any previously-unconsidered bits
+            let old_rem = bitv.nbits % uint::BITS;
+            if old_rem != 0 {
+                let old_last_word = (bitv.nbits + uint::BITS - 1) / uint::BITS - 1;
+                *bitv.storage.get_mut(old_last_word) &= (1 << old_rem) - 1;
+            }
             bitv.nbits = value + 1;
         }
         bitv.set(value, true);
@@ -733,10 +733,10 @@ impl BitvSet {
     /// and w1/w2 are the words coming from the two vectors self, other.
     fn commons<'a>(&'a self, other: &'a BitvSet)
         -> Map<((uint, uint), (uint, uint)), (uint, uint, uint),
-               Zip<Words<'a>, Words<'a>>> {
+               Zip<MaskWords<'a>, MaskWords<'a>>> {
         let &BitvSet(ref self_bitv) = self;
         let &BitvSet(ref other_bitv) = other;
-        self_bitv.words(0).zip(other_bitv.words(0))
+        self_bitv.mask_words(0).zip(other_bitv.mask_words(0))
             .map(|((i, w1), (_, w2))| (i * uint::BITS, w1, w2))
     }
 
@@ -748,17 +748,17 @@ impl BitvSet {
     /// is true if the word comes from `self`, and `false` if it comes from
     /// `other`.
     fn outliers<'a>(&'a self, other: &'a BitvSet)
-        -> Map<(uint, uint), (bool, uint, uint), Words<'a>> {
+        -> Map<(uint, uint), (bool, uint, uint), MaskWords<'a>> {
         let slen = self.capacity() / uint::BITS;
         let olen = other.capacity() / uint::BITS;
         let &BitvSet(ref self_bitv) = self;
         let &BitvSet(ref other_bitv) = other;
 
         if olen < slen {
-            self_bitv.words(olen)
+            self_bitv.mask_words(olen)
                 .map(|(i, w)| (true, i * uint::BITS, w))
         } else {
-            other_bitv.words(slen)
+            other_bitv.mask_words(slen)
                 .map(|(i, w)| (false, i * uint::BITS, w))
         }
     }
@@ -1250,16 +1250,53 @@ mod tests {
         });
     }
 
+    #[test]
+    fn test_bitv_masking() {
+        let b = Bitv::new(140, true);
+        let mut bs = BitvSet::from_bitv(b);
+        assert!(bs.contains(&139));
+        assert!(!bs.contains(&140));
+        assert!(bs.insert(150));
+        assert!(!bs.contains(&140));
+        assert!(!bs.contains(&149));
+        assert!(bs.contains(&150));
+        assert!(!bs.contains(&151));
+    }
+
+    #[test]
+    fn test_bitv_all_full_last_word() {
+        // 64 bits fill whole words on both 32- and 64-bit targets, so the
+        // last word has no remainder bits to mask.
+        assert!(Bitv::new(64, true).all());
+        assert!(!Bitv::new(64, false).all());
+        assert!(Bitv::new(64, false).none());
+    }
+
+    #[test]
+    fn test_bitv_process_ignores_undefined_bits() {
+        // Both vectors hold ten 1-bits; the bits past `nbits` in the
+        // last word must not make a no-op report a change.
+        let mut a = Bitv::new(10, true);
+        let b = Bitv::new(10, true);
+        assert!(!a.intersect(&b));
+        assert!(!a.assign(&b));
+        assert!(a.all());
+    }
+
     #[test]
     fn test_bitv_set_basic() {
         let mut b = BitvSet::new();
         assert!(b.insert(3));
         assert!(!b.insert(3));
         assert!(b.contains(&3));
+        assert!(b.insert(4));
+        assert!(!b.insert(4));
+        assert!(b.contains(&3));
+        assert!(b.contains(&4));
         assert!(b.insert(400));
         assert!(!b.insert(400));
         assert!(b.contains(&400));
-        assert_eq!(b.len(), 2);
+        assert_eq!(b.len(), 3);
     }
 
     #[test]