generalize BitMatrix to be NxM and not just NxN

This commit is contained in:
Niko Matsakis 2016-08-05 20:12:53 -04:00
parent 8150494ac2
commit 9978cbc8f4
2 changed files with 69 additions and 17 deletions

View File

@ -124,32 +124,32 @@ impl FromIterator<bool> for BitVector {
}
}
/// A "bit matrix" is basically a square matrix of booleans
/// represented as one gigantic bitvector. In other words, it is as if
/// you have N bitvectors, each of length N. Note that `elements` here is `N`/
/// A "bit matrix" is basically a matrix of booleans represented as
/// one gigantic bitvector. In other words, it is as if you have
/// `rows` bitvectors, each of length `columns`.
#[derive(Clone)]
pub struct BitMatrix {
elements: usize,
columns: usize,
vector: Vec<u64>,
}
impl BitMatrix {
// Create a new `elements x elements` matrix, initially empty.
pub fn new(elements: usize) -> BitMatrix {
// Create a new `rows x columns` matrix, initially empty.
pub fn new(rows: usize, columns: usize) -> BitMatrix {
// For every element, we need one bit for every other
// element. Round up to an even number of u64s.
let u64s_per_elem = u64s(elements);
let u64s_per_row = u64s(columns);
BitMatrix {
elements: elements,
vector: vec![0; elements * u64s_per_elem],
columns: columns,
vector: vec![0; rows * u64s_per_row],
}
}
/// The range of bits for a given element.
fn range(&self, element: usize) -> (usize, usize) {
let u64s_per_elem = u64s(self.elements);
let start = element * u64s_per_elem;
(start, start + u64s_per_elem)
/// The range of bits for a given row.
fn range(&self, row: usize) -> (usize, usize) {
let u64s_per_row = u64s(self.columns);
let start = row * u64s_per_row;
(start, start + u64s_per_row)
}
pub fn add(&mut self, source: usize, target: usize) -> bool {
@ -179,7 +179,7 @@ impl BitMatrix {
pub fn intersection(&self, a: usize, b: usize) -> Vec<usize> {
let (a_start, a_end) = self.range(a);
let (b_start, b_end) = self.range(b);
let mut result = Vec::with_capacity(self.elements);
let mut result = Vec::with_capacity(self.columns);
for (base, (i, j)) in (a_start..a_end).zip(b_start..b_end).enumerate() {
let mut v = self.vector[i] & self.vector[j];
for bit in 0..64 {
@ -215,6 +215,15 @@ impl BitMatrix {
}
changed
}
pub fn iter<'a>(&'a self, row: usize) -> BitVectorIter<'a> {
let (start, end) = self.range(row);
BitVectorIter {
iter: self.vector[start..end].iter(),
current: 0,
idx: 0,
}
}
}
fn u64s(elements: usize) -> usize {
@ -300,7 +309,7 @@ fn grow() {
#[test]
fn matrix_intersection() {
let mut vec1 = BitMatrix::new(200);
let mut vec1 = BitMatrix::new(200, 200);
// (*) Elements reachable from both 2 and 65.
@ -328,3 +337,45 @@ fn matrix_intersection() {
let intersection = vec1.intersection(2, 65);
assert_eq!(intersection, &[10, 64, 160]);
}
#[test]
fn matrix_iter() {
let mut matrix = BitMatrix::new(64, 100);
matrix.add(3, 22);
matrix.add(3, 75);
matrix.add(2, 99);
matrix.add(4, 0);
matrix.merge(3, 5);
let expected = [99];
let mut iter = expected.iter();
for i in matrix.iter(2) {
let j = *iter.next().unwrap();
assert_eq!(i, j);
}
assert!(iter.next().is_none());
let expected = [22, 75];
let mut iter = expected.iter();
for i in matrix.iter(3) {
let j = *iter.next().unwrap();
assert_eq!(i, j);
}
assert!(iter.next().is_none());
let expected = [0];
let mut iter = expected.iter();
for i in matrix.iter(4) {
let j = *iter.next().unwrap();
assert_eq!(i, j);
}
assert!(iter.next().is_none());
let expected = [22, 75];
let mut iter = expected.iter();
for i in matrix.iter(5) {
let j = *iter.next().unwrap();
assert_eq!(i, j);
}
assert!(iter.next().is_none());
}

View File

@ -252,7 +252,8 @@ impl<T: Debug + PartialEq> TransitiveRelation<T> {
}
fn compute_closure(&self) -> BitMatrix {
let mut matrix = BitMatrix::new(self.elements.len());
let mut matrix = BitMatrix::new(self.elements.len(),
self.elements.len());
let mut changed = true;
while changed {
changed = false;