diff --git a/src/libcollections/bitv.rs b/src/libcollections/bitv.rs index 1b3c6e148cd..e89d66578b0 100644 --- a/src/libcollections/bitv.rs +++ b/src/libcollections/bitv.rs @@ -66,7 +66,7 @@ use core::prelude::*; use core::cmp; use core::default::Default; use core::fmt; -use core::iter::Take; +use core::iter::{Chain, Enumerate, Repeat, Skip, Take}; use core::iter; use core::slice; use core::uint; @@ -75,25 +75,22 @@ use std::hash; use {Mutable, Set, MutableSet, MutableSeq}; use vec::Vec; +type MatchWords<'a> = Chain, Skip>>>>; // Take two BitV's, and return iterators of their words, where the shorter one // has been padded with 0's -macro_rules! match_words( - ($a_expr:expr, $b_expr:expr) => ({ - let a = $a_expr; - let b = $b_expr; - let a_len = a.storage.len(); - let b_len = b.storage.len(); +fn match_words <'a,'b>(a: &'a Bitv, b: &'b Bitv) -> (MatchWords<'a>, MatchWords<'b>) { + let a_len = a.storage.len(); + let b_len = b.storage.len(); - // have to uselessly pretend to pad the longer one for type matching - if a_len < b_len { - (a.mask_words(0).chain(iter::Repeat::new(0u).enumerate().take(b_len).skip(a_len)), - b.mask_words(0).chain(iter::Repeat::new(0u).enumerate().take(0).skip(0))) - } else { - (a.mask_words(0).chain(iter::Repeat::new(0u).enumerate().take(0).skip(0)), - b.mask_words(0).chain(iter::Repeat::new(0u).enumerate().take(a_len).skip(b_len))) - } - }) -) + // have to uselessly pretend to pad the longer one for type matching + if a_len < b_len { + (a.mask_words(0).chain(Repeat::new(0u).enumerate().take(b_len).skip(a_len)), + b.mask_words(0).chain(Repeat::new(0u).enumerate().take(0).skip(0))) + } else { + (a.mask_words(0).chain(Repeat::new(0u).enumerate().take(0).skip(0)), + b.mask_words(0).chain(Repeat::new(0u).enumerate().take(a_len).skip(b_len))) + } +} static TRUE: bool = true; static FALSE: bool = false; @@ -1014,7 +1011,7 @@ impl Extendable for BitvSet { impl PartialOrd for BitvSet { #[inline] fn partial_cmp(&self, other: &BitvSet) -> Option { - let (a_iter, b_iter) = match_words!(self.get_ref(), other.get_ref()); + let (a_iter, b_iter) = match_words(self.get_ref(), other.get_ref()); iter::order::partial_cmp(a_iter, b_iter) } } @@ -1022,7 +1019,7 @@ impl PartialOrd for BitvSet { impl Ord for BitvSet { #[inline] fn cmp(&self, other: &BitvSet) -> Ordering { - let (a_iter, b_iter) = match_words!(self.get_ref(), other.get_ref()); + let (a_iter, b_iter) = match_words(self.get_ref(), other.get_ref()); iter::order::cmp(a_iter, b_iter) } } @@ -1030,7 +1027,7 @@ impl Ord for BitvSet { impl cmp::PartialEq for BitvSet { #[inline] fn eq(&self, other: &BitvSet) -> bool { - let (a_iter, b_iter) = match_words!(self.get_ref(), other.get_ref()); + let (a_iter, b_iter) = match_words(self.get_ref(), other.get_ref()); iter::order::eq(a_iter, b_iter) } } @@ -1191,10 +1188,10 @@ impl BitvSet { self_bitv.reserve(other_bitv.capacity()); // virtually pad other with 0's for equal lengths - let self_len = self_bitv.storage.len(); - let other_len = other_bitv.storage.len(); - let mut other_words = other_bitv.mask_words(0) - .chain(iter::Repeat::new(0u).enumerate().take(self_len).skip(other_len)); + let mut other_words = { + let (_, result) = match_words(self_bitv, other_bitv); + result + }; // Apply values found in other for (i, w) in other_words { @@ -1524,7 +1521,7 @@ impl Set for BitvSet { #[inline] fn is_disjoint(&self, other: &BitvSet) -> bool { - self.intersection(other).count() > 0 + self.intersection(other).next().is_none() } #[inline] @@ -2266,6 +2263,24 @@ mod tests { assert!(set1.is_subset(&set2)); // { 2 } { 2, 4 } } + #[test] + fn test_bitv_set_is_disjoint() { + let a = BitvSet::from_bitv(from_bytes([0b10100010])); + let b = BitvSet::from_bitv(from_bytes([0b01000000])); + let c = BitvSet::new(); + let d = BitvSet::from_bitv(from_bytes([0b00110000])); + + assert!(!a.is_disjoint(&d)); + assert!(!d.is_disjoint(&a)); + + assert!(a.is_disjoint(&b)) + assert!(a.is_disjoint(&c)) + assert!(b.is_disjoint(&a)) + assert!(b.is_disjoint(&c)) + assert!(c.is_disjoint(&a)) + assert!(c.is_disjoint(&b)) + } + #[test] fn test_bitv_set_intersect_with() { // Explicitly 0'ed bits