diff --git a/src/libextra/sync.rs b/src/libextra/sync.rs index 044e5e9e509..c4277bddac7 100644 --- a/src/libextra/sync.rs +++ b/src/libextra/sync.rs @@ -27,6 +27,8 @@ use std::unstable::finally::Finally; use std::util; use std::util::NonCopyable; +use arc::MutexArc; + /**************************************************************************** * Internals ****************************************************************************/ @@ -682,6 +684,67 @@ impl<'a> RWLockReadMode<'a> { pub fn read(&self, blk: || -> U) -> U { blk() } } +/// A barrier enables multiple tasks to synchronize the beginning +/// of some computation. +/// ```rust +/// use extra::sync::Barrier; +/// +/// let barrier = Barrier::new(10); +/// 10.times(|| { +/// let c = barrier.clone(); +/// // The same messages will be printed together. +/// // You will NOT see any interleaving. +/// do spawn { +/// println!("before wait"); +/// c.wait(); +/// println!("after wait"); +/// } +/// }); +/// ``` +#[deriving(Clone)] +pub struct Barrier { + priv arc: MutexArc, + priv num_tasks: uint, +} + +// The inner state of a double barrier +struct BarrierState { + priv count: uint, + priv generation_id: uint, +} + +impl Barrier { + /// Create a new barrier that can block a given number of tasks. + pub fn new(num_tasks: uint) -> Barrier { + Barrier { + arc: MutexArc::new(BarrierState { + count: 0, + generation_id: 0, + }), + num_tasks: num_tasks, + } + } + + /// Block the current task until a certain number of tasks is waiting. + pub fn wait(&self) { + self.arc.access_cond(|state, cond| { + let local_gen = state.generation_id; + state.count += 1; + if state.count < self.num_tasks { + // We need a while loop to guard against spurious wakeups. + // http://en.wikipedia.org/wiki/Spurious_wakeup + while local_gen == state.generation_id && state.count < self.num_tasks { + cond.wait(); + } + } else { + state.count = 0; + state.generation_id += 1; + cond.broadcast(); + } + }); + } +} + /**************************************************************************** * Tests ****************************************************************************/ @@ -693,6 +756,7 @@ mod tests { use std::cast; use std::result; use std::task; + use std::comm::{SharedChan, Empty}; /************************************************************************ * Semaphore tests @@ -1315,4 +1379,35 @@ mod tests { }) }) } + + /************************************************************************ + * Barrier tests + ************************************************************************/ + #[test] + fn test_barrier() { + let barrier = Barrier::new(10); + let (port, chan) = SharedChan::new(); + + 9.times(|| { + let c = barrier.clone(); + let chan = chan.clone(); + do spawn { + c.wait(); + chan.send(true); + } + }); + + // At this point, all spawned tasks should be blocked, + // so we shouldn't get anything from the port + assert!(match port.try_recv() { + Empty => true, + _ => false, + }); + + barrier.wait(); + // Now, the barrier is cleared and we should get data. + 9.times(|| { + port.recv(); + }); + } }