collections::bitv: ensure correct masking behaviour

The internal masking behaviour for `Bitv` is now defined as:
  - Any entirely words in self.storage must be all zeroes.
  - Any partially used words may have anything at all in their
    unused bits.

This means:
  - When decreasing self.nbits, care must be taken that any
    no-longer-used words are zeroed out.

  - When increasing self.nbits, care must be taken that any
    newly-unmasked bits are set to their correct values.

  - When reading words, care should be taken that the values of
    unused bits are not used. (Preferably, use `Bitv::mask_words`
    which zeroes them out for you.)

The old behaviour was that every unused bit was always set to
zero. The problem with this is that unused bits are almost never
read, so forgetting to do this will result in very subtle and
hard-to-track down bugs. This way the responsibility for masking
falls on the places which might cause unused bits to be read: for
now, this is only `Bitv::mask_words` and `BitvSet::insert`.
This commit is contained in:
Andrew Poelstra 2014-07-02 08:24:55 -07:00
parent 2d23319e33
commit a698b81ebf

@ -24,22 +24,6 @@ use std::hash;
use {Collection, Mutable, Set, MutableSet};
use vec::Vec;
/**
* A mask that has a 1 for each defined bit in the n'th element of a `Bitv`,
* assuming n bits.
*/
#[inline]
fn big_mask(nbits: uint, elem: uint) -> uint {
let rmd = nbits % uint::BITS;
let nelems = (nbits + uint::BITS - 1) / uint::BITS;
if elem < nelems - 1 || rmd == 0 {
!0
} else {
(1 << rmd) - 1
}
}
/// The bitvector type
///
/// # Example
@ -75,35 +59,47 @@ pub struct Bitv {
nbits: uint
}
struct Words<'a> {
struct MaskWords<'a> {
iter: slice::Items<'a, uint>,
next_word: Option<&'a uint>,
last_word_mask: uint,
offset: uint
}
impl<'a> Iterator<(uint, uint)> for Words<'a> {
impl<'a> Iterator<(uint, uint)> for MaskWords<'a> {
/// Returns (offset, word)
fn next<'a>(&'a mut self) -> Option<(uint, uint)> {
let ret = self.iter.next().map(|&n| (self.offset, n));
self.offset += 1;
ret
let ret = self.next_word;
match ret {
Some(&w) => {
self.next_word = self.iter.next();
self.offset += 1;
// The last word may need to be masked
if self.next_word.is_none() {
Some((self.offset - 1, w & self.last_word_mask))
} else {
Some((self.offset - 1, w))
}
},
None => None
}
}
}
impl Bitv {
#[inline]
fn process(&mut self, other: &Bitv, nbits: uint,
op: |uint, uint| -> uint) -> bool {
fn process(&mut self, other: &Bitv, op: |uint, uint| -> uint) -> bool {
let len = other.storage.len();
assert_eq!(self.storage.len(), len);
let mut changed = false;
for (i, (a, b)) in self.storage.mut_iter()
.zip(other.storage.iter())
.enumerate() {
let mask = big_mask(nbits, i);
let w0 = *a & mask;
let w1 = *b & mask;
let w = op(w0, w1) & mask;
if w0 != w {
// Notice: `a` is *not* masked here, which is fine as long as
// `op` is a bitwise operation, since any bits that should've
// been masked were fine to change anyway. `b` is masked to
// make sure its unmasked bits do not cause damage.
for (a, (_, b)) in self.storage.mut_iter()
.zip(other.mask_words(0)) {
let w = op(*a, b);
if *a != w {
changed = true;
*a = w;
}
@ -112,10 +108,20 @@ impl Bitv {
}
#[inline]
#[inline]
fn words<'a>(&'a self, start: uint) -> Words<'a> {
Words {
iter: self.storage.slice_from(start).iter(),
fn mask_words<'a>(&'a self, mut start: uint) -> MaskWords<'a> {
if start > self.storage.len() {
start = self.storage.len();
}
let mut iter = self.storage.slice_from(start).iter();
MaskWords {
next_word: iter.next(),
iter: iter,
last_word_mask: {
let rem = self.nbits % uint::BITS;
if rem > 0 {
(1 << rem) - 1
} else { !0 }
},
offset: start
}
}
@ -124,15 +130,8 @@ impl Bitv {
/// to `init`.
pub fn new(nbits: uint, init: bool) -> Bitv {
Bitv {
storage: {
let nelems = (nbits + uint::BITS - 1) / uint::BITS;
let mut v = Vec::from_elem(nelems, if init { !0u } else { 0u });
// Zero out any remainder bits
if nbits % uint::BITS > 0 {
*v.get_mut(nelems - 1) &= (1 << nbits % uint::BITS) - 1;
}
v
},
storage: Vec::from_elem((nbits + uint::BITS - 1) / uint::BITS,
if init { !0u } else { 0u }),
nbits: nbits
}
}
@ -145,8 +144,7 @@ impl Bitv {
*/
#[inline]
pub fn union(&mut self, other: &Bitv) -> bool {
let nbits = self.nbits;
self.process(other, nbits, |w1, w2| w1 | w2)
self.process(other, |w1, w2| w1 | w2)
}
/**
@ -157,8 +155,7 @@ impl Bitv {
*/
#[inline]
pub fn intersect(&mut self, other: &Bitv) -> bool {
let nbits = self.nbits;
self.process(other, nbits, |w1, w2| w1 & w2)
self.process(other, |w1, w2| w1 & w2)
}
/**
@ -169,8 +166,7 @@ impl Bitv {
*/
#[inline]
pub fn assign(&mut self, other: &Bitv) -> bool {
let nbits = self.nbits;
self.process(other, nbits, |_, w| w)
self.process(other, |_, w| w)
}
/// Retrieve the value at index `i`
@ -227,20 +223,18 @@ impl Bitv {
*/
#[inline]
pub fn difference(&mut self, other: &Bitv) -> bool {
let nbits = self.nbits;
self.process(other, nbits, |w1, w2| w1 & !w2)
self.process(other, |w1, w2| w1 & !w2)
}
/// Returns `true` if all bits are 1
#[inline]
pub fn all(&self) -> bool {
for (i, &elem) in self.storage.iter().enumerate() {
let mask = big_mask(self.nbits, i);
if elem & mask != mask {
return false;
}
}
true
let mut last_word = !0u;
// Check that every word but the last is all-ones...
self.mask_words(0).all(|(_, elem)|
{ let tmp = last_word; last_word = elem; tmp == !0u }) &&
// ...and that the last word is ones as far as it needs to be
(last_word == ((1 << self.nbits % uint::BITS) - 1) || last_word == !0u)
}
/// Returns an iterator over the elements of the vector in order.
@ -265,13 +259,7 @@ impl Bitv {
/// Returns `true` if all bits are 0
pub fn none(&self) -> bool {
for (i, &elem) in self.storage.iter().enumerate() {
let mask = big_mask(self.nbits, i);
if elem & mask != 0 {
return false;
}
}
true
self.mask_words(0).all(|(_, w)| w == 0)
}
#[inline]
@ -397,8 +385,8 @@ impl fmt::Show for Bitv {
impl<S: hash::Writer> hash::Hash<S> for Bitv {
fn hash(&self, state: &mut S) {
self.nbits.hash(state);
for (i, elem) in self.storage.iter().enumerate() {
(elem & big_mask(self.nbits, i)).hash(state);
for (_, elem) in self.mask_words(0) {
elem.hash(state);
}
}
}
@ -409,13 +397,7 @@ impl cmp::PartialEq for Bitv {
if self.nbits != other.nbits {
return false;
}
for (i, (&w1, &w2)) in self.storage.iter().zip(other.storage.iter()).enumerate() {
let mask = big_mask(self.nbits, i);
if w1 & mask != w2 & mask {
return false;
}
}
true
self.mask_words(0).zip(other.mask_words(0)).all(|((_, w1), (_, w2))| w1 == w2)
}
}
@ -546,7 +528,7 @@ impl BitvSet {
// Unwrap Bitvs
let &BitvSet(ref mut self_bitv) = self;
let &BitvSet(ref other_bitv) = other;
for (i, w) in other_bitv.words(0) {
for (i, w) in other_bitv.mask_words(0) {
let old = *self_bitv.storage.get(i);
let new = f(old, w);
*self_bitv.storage.get_mut(i) = new;
@ -563,7 +545,7 @@ impl BitvSet {
let n = bitv.storage.iter().rev().take_while(|&&n| n == 0).count();
// Truncate
let trunc_len = cmp::max(old_len - n, 1);
bitv.storage.truncate(cmp::max(old_len - n, 1));
bitv.storage.truncate(trunc_len);
bitv.nbits = trunc_len * uint::BITS;
}
@ -710,6 +692,12 @@ impl MutableSet<uint> for BitvSet {
}
let &BitvSet(ref mut bitv) = self;
if value >= bitv.nbits {
// If we are increasing nbits, make sure we mask out any previously-unconsidered bits
let old_rem = bitv.nbits % uint::BITS;
if old_rem != 0 {
let old_last_word = (bitv.nbits + uint::BITS - 1) / uint::BITS - 1;
*bitv.storage.get_mut(old_last_word) &= (1 << old_rem) - 1;
}
bitv.nbits = value + 1;
}
bitv.set(value, true);
@ -733,10 +721,10 @@ impl BitvSet {
/// and w1/w2 are the words coming from the two vectors self, other.
fn commons<'a>(&'a self, other: &'a BitvSet)
-> Map<((uint, uint), (uint, uint)), (uint, uint, uint),
Zip<Words<'a>, Words<'a>>> {
Zip<MaskWords<'a>, MaskWords<'a>>> {
let &BitvSet(ref self_bitv) = self;
let &BitvSet(ref other_bitv) = other;
self_bitv.words(0).zip(other_bitv.words(0))
self_bitv.mask_words(0).zip(other_bitv.mask_words(0))
.map(|((i, w1), (_, w2))| (i * uint::BITS, w1, w2))
}
@ -748,17 +736,17 @@ impl BitvSet {
/// is true if the word comes from `self`, and `false` if it comes from
/// `other`.
fn outliers<'a>(&'a self, other: &'a BitvSet)
-> Map<(uint, uint), (bool, uint, uint), Words<'a>> {
-> Map<(uint, uint), (bool, uint, uint), MaskWords<'a>> {
let slen = self.capacity() / uint::BITS;
let olen = other.capacity() / uint::BITS;
let &BitvSet(ref self_bitv) = self;
let &BitvSet(ref other_bitv) = other;
if olen < slen {
self_bitv.words(olen)
self_bitv.mask_words(olen)
.map(|(i, w)| (true, i * uint::BITS, w))
} else {
other_bitv.words(slen)
other_bitv.mask_words(slen)
.map(|(i, w)| (false, i * uint::BITS, w))
}
}
@ -1250,16 +1238,32 @@ mod tests {
});
}
#[test]
fn test_bitv_masking() {
let b = Bitv::new(140, true);
let mut bs = BitvSet::from_bitv(b);
assert!(bs.contains(&139));
assert!(!bs.contains(&140));
assert!(bs.insert(150));
assert!(!bs.contains(&140));
assert!(!bs.contains(&149));
assert!(bs.contains(&150));
assert!(!bs.contains(&151));
}
#[test]
fn test_bitv_set_basic() {
let mut b = BitvSet::new();
assert!(b.insert(3));
assert!(!b.insert(3));
assert!(b.contains(&3));
assert!(b.insert(4));
assert!(!b.insert(4));
assert!(b.contains(&3));
assert!(b.insert(400));
assert!(!b.insert(400));
assert!(b.contains(&400));
assert_eq!(b.len(), 2);
assert_eq!(b.len(), 3);
}
#[test]