From e84a3833077825893b5bf0fb5abf07bdbd58d988 Mon Sep 17 00:00:00 2001 From: Tobias Bucher Date: Sat, 13 Dec 2014 02:13:06 +0100 Subject: [PATCH] Add a new invariant to `Bitv` The length of the underlying vector must now be exactly as long as it needs to be. --- src/libcollections/bit.rs | 152 +++++++++++++++++++++----------------- 1 file changed, 83 insertions(+), 69 deletions(-) diff --git a/src/libcollections/bit.rs b/src/libcollections/bit.rs index 5cf5183d25d..3f33d85ba56 100644 --- a/src/libcollections/bit.rs +++ b/src/libcollections/bit.rs @@ -8,18 +8,25 @@ // option. This file may not be copied, modified, or distributed // except according to those terms. -// FIXME(Gankro): Bitv and BitvSet are very tightly coupled. Ideally (for maintenance), -// they should be in separate files/modules, with BitvSet only using Bitv's public API. +// FIXME(Gankro): Bitv and BitvSet are very tightly coupled. Ideally (for +// maintenance), they should be in separate files/modules, with BitvSet only +// using Bitv's public API. This will be hard for performance though, because +// `Bitv` will not want to leak its internal representation while its internal +// representation as `u32`s must be assumed for best performance. -// First rule of Bitv club: almost everything can actually overflow because we're working with -// bits and not bytes. -// -// Second rule of Bitv club: the last "block" of bits may be partially used. We must ensure that -// those unused bits are zeroed out, as other methods will assume this is the case. It may be -// the case that this isn't a great design, but having "undefined" bits is headache-inducing. -// -// Third rule of Bitv club: BitvSet is fairly tightly coupled to Bitv's implementation details. -// Make sure any changes to Bitv are properly addressed in BitvSet. +// FIXME(tbu-): `Bitv`'s methods shouldn't be `union`, `intersection`, but +// rather `or` and `and`. + +// (1) Be careful, most things can overflow here because the amount of bits in +// memory can overflow `uint`. +// (2) Make sure that the underlying vector has no excess length: +// E. g. `nbits == 16`, `storage.len() == 2` would be excess length, +// because the last word isn't used at all. This is important because some +// methods rely on it (for *CORRECTNESS*). +// (3) Make sure that the unused bits in the last word are zeroed out, again +// other methods rely on it for *CORRECTNESS*. +// (4) `BitvSet` is tightly coupled with `Bitv`, so any changes you make in +// `Bitv` will need to be reflected in `BitvSet`. //! Collections implemented with bit vectors. //! @@ -82,10 +89,10 @@ use core::iter::{Cloned, Chain, Enumerate, Repeat, Skip, Take}; use core::iter; use core::num::Int; use core::slice::{Items, MutItems}; -use core::{u32, uint}; -use std::hash; +use core::{u8, u32, uint}; -use vec::Vec; +use hash; +use Vec; type Blocks<'a> = Cloned>; type MutBlocks<'a> = MutItems<'a, u32>; @@ -181,17 +188,15 @@ fn mask_for_bits(bits: uint) -> u32 { } impl Bitv { - /// Applies the given operation to the blocks of self and other, and sets self to - /// be the result. + /// Applies the given operation to the blocks of self and other, and sets + /// self to be the result. This relies on the caller not to corrupt the + /// last word. #[inline] fn process(&mut self, other: &Bitv, mut op: F) -> bool where F: FnMut(u32, u32) -> u32 { - let len = other.storage.len(); - assert_eq!(self.storage.len(), len); + assert_eq!(self.len(), other.len()); + // This could theoretically be a `debug_assert!`. + assert_eq!(self.storage.len(), other.storage.len()); let mut changed = false; - // Notice: `a` is *not* masked here, which is fine as long as - // `op` is a bitwise operation, since any bits that should've - // been masked were fine to change anyway. `b` is masked to - // make sure its unmasked bits do not cause damage. for (a, b) in self.blocks_mut().zip(other.blocks()) { let w = op(*a, b); if *a != w { @@ -204,21 +209,20 @@ impl Bitv { /// Iterator over mutable refs to the underlying blocks of data. fn blocks_mut(&mut self) -> MutBlocks { - let blocks = blocks_for_bits(self.len()); - self.storage.slice_to_mut(blocks).iter_mut() + // (2) + self.storage.iter_mut() } /// Iterator over the underlying blocks of data fn blocks(&self) -> Blocks { - let blocks = blocks_for_bits(self.len()); - self.storage[..blocks].iter().cloned() + // (2) + self.storage.iter().cloned() } - /// An operation might screw up the unused bits in the last block of the Bitv. - /// It's assumed to be all 0's. This fixes it up. + /// An operation might screw up the unused bits in the last block of the + /// `Bitv`. As per (3), it's assumed to be all 0s. This method fixes it up. fn fix_last_block(&mut self) { - let len = self.len(); - let extra_bits = len % u32::BITS; + let extra_bits = self.len() % u32::BITS; if extra_bits > 0 { let mask = (1 << extra_bits) - 1; let storage_len = self.storage.len(); @@ -259,7 +263,6 @@ impl Bitv { storage: Vec::from_elem(nblocks, if bit { !0u32 } else { 0u32 }), nbits: nbits }; - bitv.fix_last_block(); bitv } @@ -295,15 +298,33 @@ impl Bitv { /// false, false, true, false])); /// ``` pub fn from_bytes(bytes: &[u8]) -> Bitv { - Bitv::from_fn(bytes.len() * 8, |i| { - let b = bytes[i / 8] as u32; - let offset = i % 8; - b >> (7 - offset) & 1 == 1 - }) + let len = bytes.len().checked_mul(u8::BITS).expect("capacity overflow"); + let mut bitv = Bitv::with_capacity(len); + let complete_words = bytes.len() / 4; + let extra_bytes = bytes.len() % 4; + + for i in range(0, complete_words) { + bitv.storage.push( + (bytes[i * 4 + 0] as u32 << 0) | + (bytes[i * 4 + 1] as u32 << 8) | + (bytes[i * 4 + 2] as u32 << 16) | + (bytes[i * 4 + 3] as u32 << 24) + ); + } + + if extra_bytes > 0 { + let mut last_word = 0u32; + for (i, &byte) in bytes[complete_words*4..].iter().enumerate() { + last_word |= byte as u32 << (i * 8); + } + bitv.storage.push(last_word); + } + + bitv } - /// Creates a `Bitv` of the specified length where the value at each - /// index is `f(index)`. + /// Creates a `Bitv` of the specified length where the value at each index + /// is `f(index)`. /// /// # Examples /// @@ -339,7 +360,9 @@ impl Bitv { #[inline] #[unstable = "panic semantics are likely to change in the future"] pub fn get(&self, i: uint) -> Option { - assert!(i < self.nbits); + if i >= self.nbits { + return None; + } let w = i / u32::BITS; let b = i % u32::BITS; self.storage.get(w).map(|&block| @@ -548,7 +571,7 @@ impl Bitv { #[inline] #[unstable = "matches collection reform specification, waiting for dust to settle"] pub fn iter<'a>(&'a self) -> Bits<'a> { - Bits {bitv: self, next_idx: 0, end_idx: self.nbits} + Bits { bitv: self, next_idx: 0, end_idx: self.nbits } } /// Returns `true` if all bits are 0. @@ -608,7 +631,7 @@ impl Bitv { /// assert_eq!(bv.to_bytes(), vec!(0b00100000, 0b10000000)); /// ``` pub fn to_bytes(&self) -> Vec { - fn bit (bitv: &Bitv, byte: uint, bit: uint) -> u8 { + fn bit(bitv: &Bitv, byte: uint, bit: uint) -> u8 { let offset = byte * 8 + bit; if offset >= bitv.nbits { 0 @@ -634,7 +657,7 @@ impl Bitv { /// Deprecated: Use `iter().collect()`. #[deprecated = "Use `iter().collect()`"] pub fn to_bools(&self) -> Vec { - Vec::from_fn(self.nbits, |i| self[i]) + self.iter().collect() } /// Compares a `Bitv` to a slice of `bool`s. @@ -656,12 +679,7 @@ impl Bitv { /// ``` pub fn eq_vec(&self, v: &[bool]) -> bool { assert_eq!(self.nbits, v.len()); - let mut i = 0; - while i < self.nbits { - if self[i] != v[i] { return false; } - i = i + 1; - } - true + iter::order::eq(self.iter(), v.iter().cloned()) } /// Shortens a `Bitv`, dropping excess elements. @@ -682,6 +700,7 @@ impl Bitv { pub fn truncate(&mut self, len: uint) { if len < self.len() { self.nbits = len; + // This fixes (2). self.storage.truncate(blocks_for_bits(len)); self.fix_last_block(); } @@ -707,13 +726,9 @@ impl Bitv { #[unstable = "matches collection reform specification, waiting for dust to settle"] pub fn reserve(&mut self, additional: uint) { let desired_cap = self.len().checked_add(additional).expect("capacity overflow"); - match self.storage.len().checked_mul(u32::BITS) { - None => {} // Vec has more initialized capacity than we can ever use - Some(initialized_cap) => { - if desired_cap > initialized_cap { - self.storage.reserve(blocks_for_bits(desired_cap - initialized_cap)); - } - } + let storage_len = self.storage.len(); + if desired_cap > self.capacity() { + self.storage.reserve(blocks_for_bits(desired_cap) - storage_len); } } @@ -741,13 +756,9 @@ impl Bitv { #[unstable = "matches collection reform specification, waiting for dust to settle"] pub fn reserve_exact(&mut self, additional: uint) { let desired_cap = self.len().checked_add(additional).expect("capacity overflow"); - match self.storage.len().checked_mul(u32::BITS) { - None => {} // Vec has more initialized capacity than we can ever use - Some(initialized_cap) => { - if desired_cap > initialized_cap { - self.storage.reserve_exact(blocks_for_bits(desired_cap - initialized_cap)); - } - } + let storage_len = self.storage.len(); + if desired_cap > self.capacity() { + self.storage.reserve_exact(blocks_for_bits(desired_cap) - storage_len); } } @@ -801,8 +812,7 @@ impl Bitv { if value { self.storage[old_last_word] |= !mask; } else { - // Extra bits are already supposed to be zero by invariant, but play it safe... - self.storage[old_last_word] &= mask; + // Extra bits are already zero by invariant. } } @@ -843,9 +853,13 @@ impl Bitv { } else { let i = self.nbits - 1; let ret = self[i]; - // Second rule of Bitv Club + // (3) self.set(i, false); self.nbits = i; + if self.nbits % u32::BITS == 0 { + // (2) + self.storage.pop(); + } Some(ret) } } @@ -864,11 +878,11 @@ impl Bitv { /// ``` #[unstable = "matches collection reform specification, waiting for dust to settle"] pub fn push(&mut self, elem: bool) { - let insert_pos = self.nbits; - self.nbits = self.nbits.checked_add(1).expect("Capacity overflow"); - if self.storage.len().checked_mul(u32::BITS).unwrap_or(uint::MAX) < self.nbits { + if self.nbits % u32::BITS == 0 { self.storage.push(0); } + let insert_pos = self.nbits; + self.nbits = self.nbits.checked_add(1).expect("Capacity overflow"); self.set(insert_pos, elem); } @@ -958,7 +972,7 @@ impl Ord for Bitv { impl fmt::Show for Bitv { fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { for bit in self.iter() { - try!(write!(fmt, "{}", if bit { 1u } else { 0u })); + try!(write!(fmt, "{}", if bit { 1u32 } else { 0u32 })); } Ok(()) }