From 2a6f197bf4358f6ed9211777e59553128caa459b Mon Sep 17 00:00:00 2001 From: Chase Southwood Date: Tue, 25 Nov 2014 02:15:28 -0600 Subject: [PATCH] Implement union, intersection, and difference functions for TrieSet. --- src/libcollections/trie/set.rs | 269 ++++++++++++++++++++++++++++++++- 1 file changed, 268 insertions(+), 1 deletion(-) diff --git a/src/libcollections/trie/set.rs b/src/libcollections/trie/set.rs index f40c0db5edf..dd884b6ee41 100644 --- a/src/libcollections/trie/set.rs +++ b/src/libcollections/trie/set.rs @@ -9,7 +9,6 @@ // except according to those terms. // FIXME(conventions): implement bounded iterators -// FIXME(conventions): implement union family of fns // FIXME(conventions): implement BitOr, BitAnd, BitXor, and Sub // FIXME(conventions): replace each_reverse by making iter DoubleEnded // FIXME(conventions): implement iter_mut and into_iter @@ -19,6 +18,7 @@ use core::prelude::*; use core::default::Default; use core::fmt; use core::fmt::Show; +use core::iter::Peekable; use std::hash::Hash; use trie_map::{TrieMap, Entries}; @@ -172,6 +172,106 @@ impl TrieSet { SetItems{iter: self.map.upper_bound(val)} } + /// Visits the values representing the difference, in ascending order. + /// + /// # Example + /// + /// ``` + /// use std::collections::TrieSet; + /// + /// let a: TrieSet = [1, 2, 3].iter().map(|&x| x).collect(); + /// let b: TrieSet = [3, 4, 5].iter().map(|&x| x).collect(); + /// + /// // Can be seen as `a - b`. + /// for x in a.difference(&b) { + /// println!("{}", x); // Print 1 then 2 + /// } + /// + /// let diff1: TrieSet = a.difference(&b).collect(); + /// assert_eq!(diff1, [1, 2].iter().map(|&x| x).collect()); + /// + /// // Note that difference is not symmetric, + /// // and `b - a` means something else: + /// let diff2: TrieSet = b.difference(&a).collect(); + /// assert_eq!(diff2, [4, 5].iter().map(|&x| x).collect()); + /// ``` + #[unstable = "matches collection reform specification, waiting for dust to settle"] + pub fn difference<'a>(&'a self, other: &'a TrieSet) -> DifferenceItems<'a> { + DifferenceItems{a: self.iter().peekable(), b: other.iter().peekable()} + } + + /// Visits the values representing the symmetric difference, in ascending order. + /// + /// # Example + /// + /// ``` + /// use std::collections::TrieSet; + /// + /// let a: TrieSet = [1, 2, 3].iter().map(|&x| x).collect(); + /// let b: TrieSet = [3, 4, 5].iter().map(|&x| x).collect(); + /// + /// // Print 1, 2, 4, 5 in ascending order. + /// for x in a.symmetric_difference(&b) { + /// println!("{}", x); + /// } + /// + /// let diff1: TrieSet = a.symmetric_difference(&b).collect(); + /// let diff2: TrieSet = b.symmetric_difference(&a).collect(); + /// + /// assert_eq!(diff1, diff2); + /// assert_eq!(diff1, [1, 2, 4, 5].iter().map(|&x| x).collect()); + /// ``` + #[unstable = "matches collection reform specification, waiting for dust to settle."] + pub fn symmetric_difference<'a>(&'a self, other: &'a TrieSet) -> SymDifferenceItems<'a> { + SymDifferenceItems{a: self.iter().peekable(), b: other.iter().peekable()} + } + + /// Visits the values representing the intersection, in ascending order. + /// + /// # Example + /// + /// ``` + /// use std::collections::TrieSet; + /// + /// let a: TrieSet = [1, 2, 3].iter().map(|&x| x).collect(); + /// let b: TrieSet = [2, 3, 4].iter().map(|&x| x).collect(); + /// + /// // Print 2, 3 in ascending order. + /// for x in a.intersection(&b) { + /// println!("{}", x); + /// } + /// + /// let diff: TrieSet = a.intersection(&b).collect(); + /// assert_eq!(diff, [2, 3].iter().map(|&x| x).collect()); + /// ``` + #[unstable = "matches collection reform specification, waiting for dust to settle"] + pub fn intersection<'a>(&'a self, other: &'a TrieSet) -> IntersectionItems<'a> { + IntersectionItems{a: self.iter().peekable(), b: other.iter().peekable()} + } + + /// Visits the values representing the union, in ascending order. + /// + /// # Example + /// + /// ``` + /// use std::collections::TrieSet; + /// + /// let a: TrieSet = [1, 2, 3].iter().map(|&x| x).collect(); + /// let b: TrieSet = [3, 4, 5].iter().map(|&x| x).collect(); + /// + /// // Print 1, 2, 3, 4, 5 in ascending order. + /// for x in a.union(&b) { + /// println!("{}", x); + /// } + /// + /// let diff: TrieSet = a.union(&b).collect(); + /// assert_eq!(diff, [1, 2, 3, 4, 5].iter().map(|&x| x).collect()); + /// ``` + #[unstable = "matches collection reform specification, waiting for dust to settle"] + pub fn union<'a>(&'a self, other: &'a TrieSet) -> UnionItems<'a> { + UnionItems{a: self.iter().peekable(), b: other.iter().peekable()} + } + /// Return the number of elements in the set /// /// # Example @@ -368,6 +468,39 @@ pub struct SetItems<'a> { iter: Entries<'a, ()> } +/// An iterator producing elements in the set difference (in-order). +pub struct DifferenceItems<'a> { + a: Peekable>, + b: Peekable>, +} + +/// An iterator producing elements in the set symmetric difference (in-order). +pub struct SymDifferenceItems<'a> { + a: Peekable>, + b: Peekable>, +} + +/// An iterator producing elements in the set intersection (in-order). +pub struct IntersectionItems<'a> { + a: Peekable>, + b: Peekable>, +} + +/// An iterator producing elements in the set union (in-order). +pub struct UnionItems<'a> { + a: Peekable>, + b: Peekable>, +} + +/// Compare `x` and `y`, but return `short` if x is None and `long` if y is None +fn cmp_opt(x: Option<&uint>, y: Option<&uint>, short: Ordering, long: Ordering) -> Ordering { + match (x, y) { + (None , _ ) => short, + (_ , None ) => long, + (Some(x1), Some(y1)) => x1.cmp(y1), + } +} + impl<'a> Iterator for SetItems<'a> { fn next(&mut self) -> Option { self.iter.next().map(|(key, _)| key) @@ -378,6 +511,60 @@ impl<'a> Iterator for SetItems<'a> { } } +impl<'a> Iterator for DifferenceItems<'a> { + fn next(&mut self) -> Option { + loop { + match cmp_opt(self.a.peek(), self.b.peek(), Less, Less) { + Less => return self.a.next(), + Equal => { self.a.next(); self.b.next(); } + Greater => { self.b.next(); } + } + } + } +} + +impl<'a> Iterator for SymDifferenceItems<'a> { + fn next(&mut self) -> Option { + loop { + match cmp_opt(self.a.peek(), self.b.peek(), Greater, Less) { + Less => return self.a.next(), + Equal => { self.a.next(); self.b.next(); } + Greater => return self.b.next(), + } + } + } +} + +impl<'a> Iterator for IntersectionItems<'a> { + fn next(&mut self) -> Option { + loop { + let o_cmp = match (self.a.peek(), self.b.peek()) { + (None , _ ) => None, + (_ , None ) => None, + (Some(a1), Some(b1)) => Some(a1.cmp(b1)), + }; + match o_cmp { + None => return None, + Some(Less) => { self.a.next(); } + Some(Equal) => { self.b.next(); return self.a.next() } + Some(Greater) => { self.b.next(); } + } + } + } +} + +impl<'a> Iterator for UnionItems<'a> { + fn next(&mut self) -> Option { + loop { + match cmp_opt(self.a.peek(), self.b.peek(), Greater, Less) { + Less => return self.a.next(), + Equal => { self.b.next(); return self.a.next() } + Greater => return self.b.next(), + } + } + } +} + #[cfg(test)] mod test { use std::prelude::*; @@ -471,4 +658,84 @@ mod test { assert!(b > a && b >= a); assert!(a < b && a <= b); } + + fn check(a: &[uint], + b: &[uint], + expected: &[uint], + f: |&TrieSet, &TrieSet, f: |uint| -> bool| -> bool) { + let mut set_a = TrieSet::new(); + let mut set_b = TrieSet::new(); + + for x in a.iter() { assert!(set_a.insert(*x)) } + for y in b.iter() { assert!(set_b.insert(*y)) } + + let mut i = 0; + f(&set_a, &set_b, |x| { + assert_eq!(x, expected[i]); + i += 1; + true + }); + assert_eq!(i, expected.len()); + } + + #[test] + fn test_intersection() { + fn check_intersection(a: &[uint], b: &[uint], expected: &[uint]) { + check(a, b, expected, |x, y, f| x.intersection(y).all(f)) + } + + check_intersection(&[], &[], &[]); + check_intersection(&[1, 2, 3], &[], &[]); + check_intersection(&[], &[1, 2, 3], &[]); + check_intersection(&[2], &[1, 2, 3], &[2]); + check_intersection(&[1, 2, 3], &[2], &[2]); + check_intersection(&[11, 1, 3, 77, 103, 5], + &[2, 11, 77, 5, 3], + &[3, 5, 11, 77]); + } + + #[test] + fn test_difference() { + fn check_difference(a: &[uint], b: &[uint], expected: &[uint]) { + check(a, b, expected, |x, y, f| x.difference(y).all(f)) + } + + check_difference(&[], &[], &[]); + check_difference(&[1, 12], &[], &[1, 12]); + check_difference(&[], &[1, 2, 3, 9], &[]); + check_difference(&[1, 3, 5, 9, 11], + &[3, 9], + &[1, 5, 11]); + check_difference(&[11, 22, 33, 40, 42], + &[14, 23, 34, 38, 39, 50], + &[11, 22, 33, 40, 42]); + } + + #[test] + fn test_symmetric_difference() { + fn check_symmetric_difference(a: &[uint], b: &[uint], expected: &[uint]) { + check(a, b, expected, |x, y, f| x.symmetric_difference(y).all(f)) + } + + check_symmetric_difference(&[], &[], &[]); + check_symmetric_difference(&[1, 2, 3], &[2], &[1, 3]); + check_symmetric_difference(&[2], &[1, 2, 3], &[1, 3]); + check_symmetric_difference(&[1, 3, 5, 9, 11], + &[3, 9, 14, 22], + &[1, 5, 11, 14, 22]); + } + + #[test] + fn test_union() { + fn check_union(a: &[uint], b: &[uint], expected: &[uint]) { + check(a, b, expected, |x, y, f| x.union(y).all(f)) + } + + check_union(&[], &[], &[]); + check_union(&[1, 2, 3], &[2], &[1, 2, 3]); + check_union(&[2], &[1, 2, 3], &[1, 2, 3]); + check_union(&[1, 3, 5, 9, 11, 16, 19, 24], + &[1, 5, 9, 13, 19], + &[1, 3, 5, 9, 11, 13, 16, 19, 24]); + } }