collections::bitv: implement BitvSet directly as a Bitv

This commit is contained in:
Andrew Poelstra 2014-06-23 17:35:43 -07:00 committed by Andrew Poelstra
parent f728ad0134
commit a4c0468a21

View File

@ -15,7 +15,8 @@ use core::prelude::*;
use core::cmp;
use core::default::Default;
use core::fmt;
use core::iter::{Enumerate, Repeat, Map, Zip};
use core::iter::{Map, Zip};
use core::option;
use core::ops;
use core::slice;
use core::uint;
@ -268,6 +269,33 @@ fn die() -> ! {
fail!("Tried to do operation on bit vectors with different sizes");
}
/// The underlying source of words for a `Words` iterator.
enum WordsVariant<'a> {
/// No words at all; `Words::next` returns `None` immediately.
NoneIter,
/// At most one word, supplied by an `Option` iterator.
OneIter(option::Item<uint>),
/// Words read (by copy) from a borrowed slice of `uint`s.
VecIter(slice::Items<'a, uint>)
}
/// Iterator over the storage words of a bit vector, yielding
/// `(word index, word)` pairs.
struct Words<'a> {
// Where the words actually come from.
rep: WordsVariant<'a>,
// Word index that is paired with the next word yielded by `next`.
offset: uint
}
impl<'a> Iterator<(uint, uint)> for Words<'a> {
    /// Returns `(offset, word)`: the index of the word within the bit
    /// vector's storage, followed by the word itself. Returns `None` once the
    /// underlying source is exhausted.
    // No lifetime parameter is declared on the method itself: naming one `'a`
    // would shadow the `'a` already introduced by the `impl` header.
    fn next(&mut self) -> Option<(uint, uint)> {
        let word = match self.rep {
            NoneIter => None,
            OneIter(ref mut it) => it.next(),
            // The slice iterator hands out references; copy the word out.
            VecIter(ref mut it) => it.next().map(|n| *n)
        };
        match word {
            Some(w) => {
                // Only advance the index when a word was actually produced,
                // so polling an exhausted iterator leaves `offset` untouched.
                let index = self.offset;
                self.offset += 1;
                Some((index, w))
            }
            None => None
        }
    }
}
impl Bitv {
#[inline]
fn do_op(&mut self, op: Op, other: &Bitv) -> bool {
@ -295,6 +323,18 @@ impl Bitv {
}
}
}
/// Iterates over the storage words of this bit vector, beginning at word
/// index `start`. Each item is `(word index, word)`.
///
/// A small bit vector holds exactly one word, so any `start` other than
/// zero produces an empty iterator.
#[inline]
fn words<'a>(&'a self, start: uint) -> Words<'a> {
    let source = match self.rep {
        Big(ref big) => VecIter(big.storage.slice_from(start).iter()),
        Small(ref small) => {
            if start == 0 {
                OneIter(Some(small.bits).move_iter())
            } else {
                NoneIter
            }
        }
    };
    Words { rep: source, offset: start }
}
}
impl Bitv {
@ -687,15 +727,8 @@ impl<'a> RandomAccessIterator<bool> for Bits<'a> {
/// It should also be noted that the amount of storage necessary for holding a
/// set of objects is proportional to the maximum of the objects when viewed
/// as a `uint`.
#[deriving(Clone)]
pub struct BitvSet {
size: uint,
// In theory this is a `Bitv` instead of always a `BigBitv`, but knowing that
// there's an array of storage makes our lives a whole lot easier when
// performing union/intersection/etc operations
bitv: BigBitv
}
#[deriving(Clone, PartialEq, Eq)]
pub struct BitvSet(Bitv);
impl Default for BitvSet {
#[inline]
@ -705,56 +738,87 @@ impl Default for BitvSet {
impl BitvSet {
/// Creates a new bit vector set with initially no contents
pub fn new() -> BitvSet {
BitvSet{ size: 0, bitv: BigBitv::new(vec!(0)) }
BitvSet(Bitv::new(0, false))
}
/// Creates a new bit vector set from the given bit vector
pub fn from_bitv(bitv: Bitv) -> BitvSet {
let mut size = 0;
bitv.ones(|_| {
size += 1;
true
});
let Bitv{rep, ..} = bitv;
match rep {
Big(b) => BitvSet{ size: size, bitv: b },
Small(SmallBitv{bits}) =>
BitvSet{ size: size, bitv: BigBitv{ storage: vec!(bits) } },
}
BitvSet(bitv)
}
/// Returns the capacity in bits for this bit vector. Inserting any
/// element less than this amount will not trigger a resizing.
pub fn capacity(&self) -> uint { self.bitv.storage.len() * uint::BITS }
/// Number of bits the current storage can hold: one word for a small
/// bit vector, otherwise the length of the word vector, multiplied by
/// `uint::BITS`.
pub fn capacity(&self) -> uint {
    let &BitvSet(ref inner) = self;
    let word_count = match inner.rep {
        Big(ref big) => big.storage.len(),
        Small(_) => 1
    };
    word_count * uint::BITS
}
/// Consumes this set to return the underlying bit vector
pub fn unwrap(self) -> Bitv {
let cap = self.capacity();
let BitvSet{bitv, ..} = self;
return Bitv{ nbits:cap, rep: Big(bitv) };
let BitvSet(bitv) = self;
bitv
}
/// Grows the vector to be able to store bits with indices `[0, size - 1]`.
///
/// A small (single-word) representation is promoted to a big one as soon as
/// `size` reaches `uint::BITS`; the existing word becomes the first storage
/// word. The storage is then extended with zeroed words until it holds at
/// least `size` bits. Storage is never shrunk, and `nbits` is left untouched.
#[inline]
fn grow(&mut self, size: uint) {
    let &BitvSet(ref mut bitv) = self;
    // Decide on the promotion in a separate step so that the borrow of
    // `bitv.rep` has ended before the representation is overwritten.
    let promoted = match bitv.rep {
        Small(s) if size >= uint::BITS => Some(s.bits),
        _ => None
    };
    match promoted {
        Some(bits) => { bitv.rep = Big(BigBitv { storage: vec![bits] }); }
        None => {}
    }
    match bitv.rep {
        Small(_) => {},
        Big(ref mut b) => {
            // Round up to a whole number of words.
            let wanted = (size + uint::BITS - 1) / uint::BITS;
            let have = b.storage.len();
            if have < wanted {
                // `Vec::grow` appends the given number of elements, so pass
                // only the shortfall rather than the target length.
                b.storage.grow(wanted - have, &0);
            }
        }
    };
}
#[inline]
fn other_op(&mut self, other: &BitvSet, f: |uint, uint| -> uint) {
fn nbits(mut w: uint) -> uint {
let mut bits = 0;
for _ in range(0u, uint::BITS) {
if w == 0 {
break;
// Expand the vector if necessary
self.grow(other.capacity());
// Unwrap Bitvs
let &BitvSet(ref mut self_bitv) = self;
let &BitvSet(ref other_bitv) = other;
for (i, w) in other_bitv.words(0) {
match self_bitv.rep {
Small(ref mut s) => { s.bits = f(s.bits, w); }
Big(ref mut b) => {
let old = *b.storage.get(i);
let new = f(old, w);
*b.storage.get_mut(i) = new;
*b.storage.get_mut(i) = f(*b.storage.get(i), w);
}
bits += w & 1;
w >>= 1;
}
return bits;
}
if self.capacity() < other.capacity() {
self.bitv.storage.grow(other.capacity() / uint::BITS, &0);
}
for (i, &w) in other.bitv.storage.iter().enumerate() {
let old = *self.bitv.storage.get(i);
let new = f(old, w);
*self.bitv.storage.get_mut(i) = new;
self.size += nbits(new) - nbits(old);
}
/// Drops trailing all-zero words from the underlying storage.
///
/// At least one word is always retained, and `nbits` is reset to the
/// number of bits in the remaining words. A small (single-word) bit
/// vector is left as it is.
#[inline]
pub fn shrink_to_fit(&mut self) {
    let &BitvSet(ref mut bitv) = self;
    match bitv.rep {
        Small(_) => {},
        Big(ref mut big) => {
            let word_count = big.storage.len();
            // Count the run of zero words at the high end of the storage.
            let trailing_zeros = big.storage.iter().rev()
                                    .take_while(|&&word| word == 0)
                                    .count();
            let keep = cmp::max(word_count - trailing_zeros, 1);
            big.storage.truncate(keep);
            bitv.nbits = keep * uint::BITS;
        }
    }
}
@ -818,29 +882,6 @@ impl BitvSet {
}
}
// Equality for sets that track `size` separately from their storage words.
// Two sets compare equal when they have the same `size`, agree on every word
// they both have, and any extra words held by only one of them are zero.
impl cmp::PartialEq for BitvSet {
fn eq(&self, other: &BitvSet) -> bool {
// Cheap rejection before looking at any storage words.
if self.size != other.size {
return false;
}
// Words present in both sets must match exactly.
for (_, w1, w2) in self.commons(other) {
if w1 != w2 {
return false;
}
}
// Words that only the longer set has must not contain any set bit.
for (_, _, w) in self.outliers(other) {
if w != 0 {
return false;
}
}
return true;
}
// Defined as the exact negation of `eq`.
fn ne(&self, other: &BitvSet) -> bool { !self.eq(other) }
}
impl cmp::Eq for BitvSet {}
impl fmt::Show for BitvSet {
fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
try!(write!(fmt, "{{"));
@ -866,19 +907,26 @@ impl<S: hash::Writer> hash::Hash<S> for BitvSet {
impl Collection for BitvSet {
#[inline]
fn len(&self) -> uint { self.size }
fn len(&self) -> uint {
let &BitvSet(ref bitv) = self;
match bitv.rep {
Small(ref s) => s.bits.count_ones(),
Big(ref b) => b.storage.iter().fold(0, |acc, &n| acc + n.count_ones())
}
}
}
impl Mutable for BitvSet {
fn clear(&mut self) {
self.bitv.each_storage(|w| { *w = 0; true });
self.size = 0;
let &BitvSet(ref mut bitv) = self;
bitv.clear();
}
}
impl Set<uint> for BitvSet {
fn contains(&self, value: &uint) -> bool {
*value < self.bitv.storage.len() * uint::BITS && self.bitv.get(*value)
let &BitvSet(ref bitv) = self;
*value < bitv.nbits && bitv.get(*value)
}
fn is_disjoint(&self, other: &BitvSet) -> bool {
@ -914,14 +962,15 @@ impl MutableSet<uint> for BitvSet {
if self.contains(&value) {
return false;
}
let nbits = self.capacity();
if value >= nbits {
let newsize = cmp::max(value, nbits * 2) / uint::BITS + 1;
assert!(newsize > self.bitv.storage.len());
self.bitv.storage.grow(newsize, &0);
if value >= self.capacity() {
let new_cap = cmp::max(value + 1, self.capacity() * 2);
self.grow(new_cap);
}
self.size += 1;
self.bitv.set(value, true);
let &BitvSet(ref mut bitv) = self;
if value >= bitv.nbits {
bitv.nbits = value + 1;
}
bitv.set(value, true);
return true;
}
@ -929,16 +978,8 @@ impl MutableSet<uint> for BitvSet {
if !self.contains(value) {
return false;
}
self.size -= 1;
self.bitv.set(*value, false);
// Attempt to truncate our storage
let mut i = self.bitv.storage.len();
while i > 1 && *self.bitv.storage.get(i - 1) == 0 {
i -= 1;
}
self.bitv.storage.truncate(i);
let &BitvSet(ref mut bitv) = self;
bitv.set(*value, false);
return true;
}
}
@ -949,12 +990,12 @@ impl BitvSet {
/// w1, w2) where the bit location is the number of bits offset so far,
/// and w1/w2 are the words coming from the two vectors self, other.
fn commons<'a>(&'a self, other: &'a BitvSet)
-> Map<'static, ((uint, &'a uint), &'a Vec<uint>), (uint, uint, uint),
Zip<Enumerate<slice::Items<'a, uint>>, Repeat<&'a Vec<uint>>>> {
let min = cmp::min(self.bitv.storage.len(), other.bitv.storage.len());
self.bitv.storage.slice(0, min).iter().enumerate()
.zip(Repeat::new(&other.bitv.storage))
.map(|((i, &w), o_store)| (i * uint::BITS, w, *o_store.get(i)))
-> Map<((uint, uint), (uint, uint)), (uint, uint, uint),
Zip<Words<'a>, Words<'a>>> {
let &BitvSet(ref self_bitv) = self;
let &BitvSet(ref other_bitv) = other;
self_bitv.words(0).zip(other_bitv.words(0))
.map(|((i, w1), (_, w2))| (i * uint::BITS, w1, w2))
}
/// Visits each word in `self` or `other` that extends beyond the other. This
@ -965,19 +1006,18 @@ impl BitvSet {
/// is true if the word comes from `self`, and `false` if it comes from
/// `other`.
fn outliers<'a>(&'a self, other: &'a BitvSet)
-> Map<'static, ((uint, &'a uint), uint), (bool, uint, uint),
Zip<Enumerate<slice::Items<'a, uint>>, Repeat<uint>>> {
let slen = self.bitv.storage.len();
let olen = other.bitv.storage.len();
-> Map<(uint, uint), (bool, uint, uint), Words<'a>> {
let slen = self.capacity() / uint::BITS;
let olen = other.capacity() / uint::BITS;
let &BitvSet(ref self_bitv) = self;
let &BitvSet(ref other_bitv) = other;
if olen < slen {
self.bitv.storage.slice_from(olen).iter().enumerate()
.zip(Repeat::new(olen))
.map(|((i, &w), min)| (true, (i + min) * uint::BITS, w))
self_bitv.words(olen)
.map(|(i, w)| (true, i * uint::BITS, w))
} else {
other.bitv.storage.slice_from(slen).iter().enumerate()
.zip(Repeat::new(slen))
.map(|((i, &w), min)| (false, (i + min) * uint::BITS, w))
other_bitv.words(slen)
.map(|(i, w)| (false, i * uint::BITS, w))
}
}
}
@ -1600,6 +1640,7 @@ mod tests {
assert!(a.insert(1000));
assert!(a.remove(&1000));
a.shrink_to_fit();
assert_eq!(a.capacity(), uint::BITS);
}