From a698b81ebfe02a614f9b41e68c6b604597a81229 Mon Sep 17 00:00:00 2001 From: Andrew Poelstra Date: Wed, 2 Jul 2014 08:24:55 -0700 Subject: [PATCH] collections::bitv: ensure correct masking behaviour The internal masking behaviour for `Bitv` is now defined as: - Any entirely unused words in self.storage must be all zeroes. - Any partially used words may have anything at all in their unused bits. This means: - When decreasing self.nbits, care must be taken that any no-longer-used words are zeroed out. - When increasing self.nbits, care must be taken that any newly-unmasked bits are set to their correct values. - When reading words, care should be taken that the values of unused bits are not used. (Preferably, use `Bitv::mask_words` which zeroes them out for you.) The old behaviour was that every unused bit was always set to zero. The problem with this is that unused bits are almost never read, so forgetting to do this will result in very subtle and hard-to-track-down bugs. This way the responsibility for masking falls on the places which might cause unused bits to be read: for now, this is only `Bitv::mask_words` and `BitvSet::insert`. --- src/libcollections/bitv.rs | 170 +++++++++++++++++++------------------ 1 file changed, 87 insertions(+), 83 deletions(-) diff --git a/src/libcollections/bitv.rs b/src/libcollections/bitv.rs index 20d7c3ef2cf..b480b88b4d4 100644 --- a/src/libcollections/bitv.rs +++ b/src/libcollections/bitv.rs @@ -24,22 +24,6 @@ use std::hash; use {Collection, Mutable, Set, MutableSet}; use vec::Vec; -/** - * A mask that has a 1 for each defined bit in the n'th element of a `Bitv`, - * assuming n bits. 
- */ -#[inline] -fn big_mask(nbits: uint, elem: uint) -> uint { - let rmd = nbits % uint::BITS; - let nelems = (nbits + uint::BITS - 1) / uint::BITS; - - if elem < nelems - 1 || rmd == 0 { - !0 - } else { - (1 << rmd) - 1 - } -} - /// The bitvector type /// /// # Example @@ -75,35 +59,47 @@ pub struct Bitv { nbits: uint } -struct Words<'a> { +struct MaskWords<'a> { iter: slice::Items<'a, uint>, + next_word: Option<&'a uint>, + last_word_mask: uint, offset: uint } -impl<'a> Iterator<(uint, uint)> for Words<'a> { +impl<'a> Iterator<(uint, uint)> for MaskWords<'a> { /// Returns (offset, word) fn next<'a>(&'a mut self) -> Option<(uint, uint)> { - let ret = self.iter.next().map(|&n| (self.offset, n)); - self.offset += 1; - ret + let ret = self.next_word; + match ret { + Some(&w) => { + self.next_word = self.iter.next(); + self.offset += 1; + // The last word may need to be masked + if self.next_word.is_none() { + Some((self.offset - 1, w & self.last_word_mask)) + } else { + Some((self.offset - 1, w)) + } + }, + None => None + } } } impl Bitv { #[inline] - fn process(&mut self, other: &Bitv, nbits: uint, - op: |uint, uint| -> uint) -> bool { + fn process(&mut self, other: &Bitv, op: |uint, uint| -> uint) -> bool { let len = other.storage.len(); assert_eq!(self.storage.len(), len); let mut changed = false; - for (i, (a, b)) in self.storage.mut_iter() - .zip(other.storage.iter()) - .enumerate() { - let mask = big_mask(nbits, i); - let w0 = *a & mask; - let w1 = *b & mask; - let w = op(w0, w1) & mask; - if w0 != w { + // Notice: `a` is *not* masked here, which is fine as long as + // `op` is a bitwise operation, since any bits that should've + // been masked were fine to change anyway. `b` is masked to + // make sure its unmasked bits do not cause damage. 
+ for (a, (_, b)) in self.storage.mut_iter() + .zip(other.mask_words(0)) { + let w = op(*a, b); + if *a != w { changed = true; *a = w; } @@ -112,10 +108,20 @@ impl Bitv { } #[inline] - #[inline] - fn words<'a>(&'a self, start: uint) -> Words<'a> { - Words { - iter: self.storage.slice_from(start).iter(), + fn mask_words<'a>(&'a self, mut start: uint) -> MaskWords<'a> { + if start > self.storage.len() { + start = self.storage.len(); + } + let mut iter = self.storage.slice_from(start).iter(); + MaskWords { + next_word: iter.next(), + iter: iter, + last_word_mask: { + let rem = self.nbits % uint::BITS; + if rem > 0 { + (1 << rem) - 1 + } else { !0 } + }, offset: start } } @@ -124,15 +130,8 @@ impl Bitv { /// to `init`. pub fn new(nbits: uint, init: bool) -> Bitv { Bitv { - storage: { - let nelems = (nbits + uint::BITS - 1) / uint::BITS; - let mut v = Vec::from_elem(nelems, if init { !0u } else { 0u }); - // Zero out any remainder bits - if nbits % uint::BITS > 0 { - *v.get_mut(nelems - 1) &= (1 << nbits % uint::BITS) - 1; - } - v - }, + storage: Vec::from_elem((nbits + uint::BITS - 1) / uint::BITS, + if init { !0u } else { 0u }), nbits: nbits } } @@ -145,8 +144,7 @@ impl Bitv { */ #[inline] pub fn union(&mut self, other: &Bitv) -> bool { - let nbits = self.nbits; - self.process(other, nbits, |w1, w2| w1 | w2) + self.process(other, |w1, w2| w1 | w2) } /** @@ -157,8 +155,7 @@ impl Bitv { */ #[inline] pub fn intersect(&mut self, other: &Bitv) -> bool { - let nbits = self.nbits; - self.process(other, nbits, |w1, w2| w1 & w2) + self.process(other, |w1, w2| w1 & w2) } /** @@ -169,8 +166,7 @@ impl Bitv { */ #[inline] pub fn assign(&mut self, other: &Bitv) -> bool { - let nbits = self.nbits; - self.process(other, nbits, |_, w| w) + self.process(other, |_, w| w) } /// Retrieve the value at index `i` @@ -227,20 +223,18 @@ impl Bitv { */ #[inline] pub fn difference(&mut self, other: &Bitv) -> bool { - let nbits = self.nbits; - self.process(other, nbits, |w1, w2| w1 & !w2) + 
self.process(other, |w1, w2| w1 & !w2) } /// Returns `true` if all bits are 1 #[inline] pub fn all(&self) -> bool { - for (i, &elem) in self.storage.iter().enumerate() { - let mask = big_mask(self.nbits, i); - if elem & mask != mask { - return false; - } - } - true + let mut last_word = !0u; + // Check that every word but the last is all-ones... + self.mask_words(0).all(|(_, elem)| + { let tmp = last_word; last_word = elem; tmp == !0u }) && + // ...and that the last word is ones as far as it needs to be + (last_word == ((1 << self.nbits % uint::BITS) - 1) || last_word == !0u) } /// Returns an iterator over the elements of the vector in order. @@ -265,13 +259,7 @@ impl Bitv { /// Returns `true` if all bits are 0 pub fn none(&self) -> bool { - for (i, &elem) in self.storage.iter().enumerate() { - let mask = big_mask(self.nbits, i); - if elem & mask != 0 { - return false; - } - } - true + self.mask_words(0).all(|(_, w)| w == 0) } #[inline] @@ -397,8 +385,8 @@ impl fmt::Show for Bitv { impl hash::Hash for Bitv { fn hash(&self, state: &mut S) { self.nbits.hash(state); - for (i, elem) in self.storage.iter().enumerate() { - (elem & big_mask(self.nbits, i)).hash(state); + for (_, elem) in self.mask_words(0) { + elem.hash(state); } } } @@ -409,13 +397,7 @@ impl cmp::PartialEq for Bitv { if self.nbits != other.nbits { return false; } - for (i, (&w1, &w2)) in self.storage.iter().zip(other.storage.iter()).enumerate() { - let mask = big_mask(self.nbits, i); - if w1 & mask != w2 & mask { - return false; - } - } - true + self.mask_words(0).zip(other.mask_words(0)).all(|((_, w1), (_, w2))| w1 == w2) } } @@ -546,7 +528,7 @@ impl BitvSet { // Unwrap Bitvs let &BitvSet(ref mut self_bitv) = self; let &BitvSet(ref other_bitv) = other; - for (i, w) in other_bitv.words(0) { + for (i, w) in other_bitv.mask_words(0) { let old = *self_bitv.storage.get(i); let new = f(old, w); *self_bitv.storage.get_mut(i) = new; @@ -563,7 +545,7 @@ impl BitvSet { let n = 
bitv.storage.iter().rev().take_while(|&&n| n == 0).count(); // Truncate let trunc_len = cmp::max(old_len - n, 1); - bitv.storage.truncate(cmp::max(old_len - n, 1)); + bitv.storage.truncate(trunc_len); bitv.nbits = trunc_len * uint::BITS; } @@ -710,6 +692,12 @@ impl MutableSet for BitvSet { } let &BitvSet(ref mut bitv) = self; if value >= bitv.nbits { + // If we are increasing nbits, make sure we mask out any previously-unconsidered bits + let old_rem = bitv.nbits % uint::BITS; + if old_rem != 0 { + let old_last_word = (bitv.nbits + uint::BITS - 1) / uint::BITS - 1; + *bitv.storage.get_mut(old_last_word) &= (1 << old_rem) - 1; + } bitv.nbits = value + 1; } bitv.set(value, true); @@ -733,10 +721,10 @@ impl BitvSet { /// and w1/w2 are the words coming from the two vectors self, other. fn commons<'a>(&'a self, other: &'a BitvSet) -> Map<((uint, uint), (uint, uint)), (uint, uint, uint), - Zip, Words<'a>>> { + Zip, MaskWords<'a>>> { let &BitvSet(ref self_bitv) = self; let &BitvSet(ref other_bitv) = other; - self_bitv.words(0).zip(other_bitv.words(0)) + self_bitv.mask_words(0).zip(other_bitv.mask_words(0)) .map(|((i, w1), (_, w2))| (i * uint::BITS, w1, w2)) } @@ -748,17 +736,17 @@ impl BitvSet { /// is true if the word comes from `self`, and `false` if it comes from /// `other`. 
fn outliers<'a>(&'a self, other: &'a BitvSet) - -> Map<(uint, uint), (bool, uint, uint), Words<'a>> { + -> Map<(uint, uint), (bool, uint, uint), MaskWords<'a>> { let slen = self.capacity() / uint::BITS; let olen = other.capacity() / uint::BITS; let &BitvSet(ref self_bitv) = self; let &BitvSet(ref other_bitv) = other; if olen < slen { - self_bitv.words(olen) + self_bitv.mask_words(olen) .map(|(i, w)| (true, i * uint::BITS, w)) } else { - other_bitv.words(slen) + other_bitv.mask_words(slen) .map(|(i, w)| (false, i * uint::BITS, w)) } } @@ -1250,16 +1238,32 @@ mod tests { }); } + #[test] + fn test_bitv_masking() { + let b = Bitv::new(140, true); + let mut bs = BitvSet::from_bitv(b); + assert!(bs.contains(&139)); + assert!(!bs.contains(&140)); + assert!(bs.insert(150)); + assert!(!bs.contains(&140)); + assert!(!bs.contains(&149)); + assert!(bs.contains(&150)); + assert!(!bs.contains(&151)); + } + #[test] fn test_bitv_set_basic() { let mut b = BitvSet::new(); assert!(b.insert(3)); assert!(!b.insert(3)); assert!(b.contains(&3)); + assert!(b.insert(4)); + assert!(!b.insert(4)); + assert!(b.contains(&3)); assert!(b.insert(400)); assert!(!b.insert(400)); assert!(b.contains(&400)); - assert_eq!(b.len(), 2); + assert_eq!(b.len(), 3); } #[test]