rust/library/std/src/sync/barrier.rs
Yuki Okushi fb9232b453
Rollup merge of #87440 - twetzel59:fix-barrier-no-op, r=yaahc
Remove unnecessary condition in Barrier::wait()

This is my first pull request for Rust, so feel free to call me out if anything is amiss.

After some examination, I realized that the second condition of the "spurious-wakeup-handler" loop in ``std::sync::Barrier::wait()`` should always evaluate to ``true``, making it redundant in the ``&&`` expression.

Here is the affected function before the fix:
```rust
#[stable(feature = "rust1", since = "1.0.0")]
pub fn wait(&self) -> BarrierWaitResult {
    let mut lock = self.lock.lock().unwrap();
    let local_gen = lock.generation_id;
    lock.count += 1;
    if lock.count < self.num_threads {
        // We need a while loop to guard against spurious wakeups.
        // https://en.wikipedia.org/wiki/Spurious_wakeup
        while local_gen == lock.generation_id && lock.count < self.num_threads {
            lock = self.cvar.wait(lock).unwrap();
        }
        BarrierWaitResult(false)
    } else {
        lock.count = 0;
        lock.generation_id = lock.generation_id.wrapping_add(1);
        self.cvar.notify_all();
        BarrierWaitResult(true)
    }
}
```

At first glance, it seems that the check that ``lock.count < self.num_threads`` would be necessary in order for a thread A to detect when another thread B has caused the barrier to reach its thread count, making thread B the "leader".

However, the control flow implicitly establishes an invariant that makes it impossible for thread A to observe ``!(lock.count < self.num_threads)``, i.e. ``lock.count >= self.num_threads``.

When thread B, which will be the leader, calls ``.wait()`` on this shared instance of the ``Barrier``, it locks the mutex in the first line and saves the ``MutexGuard`` in the ``lock`` variable. It then increments ``lock.count`` and checks whether ``lock.count < self.num_threads``. Since B is the leader, ``lock.count`` is *equal* to the number of threads after the increment. Thus, the second branch is taken immediately and ``lock.count`` is zeroed. Additionally, the generation ID is incremented (with wrap). Then, the condition variable is signalled. However, the other threads are waiting at the line ``lock = self.cvar.wait(lock).unwrap();``, and each must re-acquire the mutex before it can resume, so none of them can proceed until thread B's call to ``Barrier::wait()`` returns, which drops the ``MutexGuard`` acquired in the first ``let`` statement and unlocks the mutex.
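As a minimal single-threaded sketch of that bookkeeping (the `State` struct and the `arrive` function are hypothetical stand-ins for illustration, not part of `std`), the work done while the mutex is held can be modelled as below. Every call leaves `count` strictly below `num_threads`, which is the only state another thread can see once the mutex is released:

```rust
// Hypothetical copy of the barrier's inner state, for illustration only.
#[derive(Debug, PartialEq)]
struct State {
    count: usize,
    generation_id: usize,
}

// Models what one call to `wait()` does to the state while it holds the lock.
// Returns true if the caller would be the leader.
fn arrive(state: &mut State, num_threads: usize) -> bool {
    state.count += 1;
    if state.count < num_threads {
        // Non-leader: count stays below num_threads.
        false
    } else {
        // Leader: reset the count and advance the generation (with wrap).
        state.count = 0;
        state.generation_id = state.generation_id.wrapping_add(1);
        true
    }
}

fn main() {
    let num_threads = 3;
    // Start at usize::MAX to show the wrapping increment.
    let mut state = State { count: 0, generation_id: usize::MAX };
    assert!(!arrive(&mut state, num_threads)); // count is now 1
    assert!(!arrive(&mut state, num_threads)); // count is now 2
    assert!(arrive(&mut state, num_threads)); // leader resets the state
    assert_eq!(state, State { count: 0, generation_id: 0 });
}
```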

The order of events is thus:
1. A thread A calls `.wait()`
2. `.wait()` acquires the mutex, increments `lock.count`, and takes the first branch
3. Thread A enters the ``while`` loop since the generation ID has not changed and the count is less than the number of threads for the ``Barrier``
4. If spurious wakeups occur, both conditions still hold, so thread A goes back to waiting on the condition variable
5. This process repeats N - 2 more times for the other non-leader threads A', A'', etc.
6. *Meanwhile*, thread B calls ``Barrier::wait()`` on the same barrier that threads A, A', A'', etc. are waiting on. B acquires the mutex and increments ``lock.count``, only to find that it is not less than ``self.num_threads``: the count has reached the number of threads for the ``Barrier``, so all threads should now proceed, with B being the leader. Thus, B immediately resets ``lock.count`` to 0 and increments the generation. Then, it signals the condvar to tell the A, A', A''... threads that they may continue.
7. The A, A', A''... threads wake up and attempt to re-acquire the ``lock`` as per the internal operation of a condition variable. When each A has exclusive access to the mutex, it finds that ``lock.generation_id`` no longer matches ``local_gen``, **so the ``&&`` expression short-circuits. Even if the second condition were evaluated, ``lock.count`` is definitely less than ``self.num_threads`` because it has been reset to ``0`` by thread B *before* B dropped its ``MutexGuard``**.
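The sequence above can be exercised with a small stand-alone model. This is only a sketch under stated assumptions: `SketchBarrier`, `State`, and `count_leaders` are hypothetical names, and the type mirrors the shape of the function quoted earlier rather than the real `std` internals. The removed condition is kept as an `assert!` that runs after every wakeup, so a violation would panic instead of going unnoticed:

```rust
use std::sync::{Arc, Condvar, Mutex};
use std::thread;

// Hypothetical stand-in for the barrier's inner state; not the std type.
struct State {
    count: usize,
    generation_id: usize,
}

struct SketchBarrier {
    lock: Mutex<State>,
    cvar: Condvar,
    num_threads: usize,
}

impl SketchBarrier {
    fn new(n: usize) -> Self {
        SketchBarrier {
            lock: Mutex::new(State { count: 0, generation_id: 0 }),
            cvar: Condvar::new(),
            num_threads: n,
        }
    }

    // Returns true for the leader and false for every other thread.
    fn wait(&self) -> bool {
        let mut lock = self.lock.lock().unwrap();
        let local_gen = lock.generation_id;
        lock.count += 1;
        if lock.count < self.num_threads {
            while local_gen == lock.generation_id {
                lock = self.cvar.wait(lock).unwrap();
                // The removed condition, kept only as a check: it should hold
                // after every wakeup, spurious or not.
                assert!(lock.count < self.num_threads);
            }
            false
        } else {
            lock.count = 0;
            lock.generation_id = lock.generation_id.wrapping_add(1);
            self.cvar.notify_all();
            true
        }
    }
}

// Runs `rounds` rendezvous on `n` threads and returns how many times any
// thread was reported as the leader.
fn count_leaders(n: usize, rounds: usize) -> usize {
    let barrier = Arc::new(SketchBarrier::new(n));
    let handles: Vec<_> = (0..n)
        .map(|_| {
            let b = Arc::clone(&barrier);
            thread::spawn(move || (0..rounds).filter(|_| b.wait()).count())
        })
        .collect();
    handles.into_iter().map(|h| h.join().unwrap()).sum()
}

fn main() {
    // Four threads meeting 100 times: exactly one leader per round.
    assert_eq!(count_leaders(4, 100), 100);
}
```

A passing run does not prove the invariant, since thread scheduling varies between runs, but it is consistent with the argument that the second condition can never be observed as `false`.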

Therefore, it is my understanding that it would be impossible for the non-leader threads to ever see the second boolean expression evaluate to anything other than ``true``. This PR simply removes that condition.
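With a single condition left, the loop has the same shape that `Condvar::wait_while` expresses: keep waiting while a predicate over the guarded state holds. The sketch below illustrates that equivalence on a reduced example. It is not what this PR changes and not how `Barrier` is written; `Gate`, `wait_past`, `advance`, and `demo` are hypothetical names used only for illustration.

```rust
use std::sync::{Arc, Condvar, Mutex};
use std::thread;

// Hypothetical pair standing in for the barrier's generation counter.
struct Gate {
    generation_id: Mutex<usize>,
    cvar: Condvar,
}

// Blocks until the generation differs from `local_gen`, then returns the new
// generation. `wait_while` re-checks the predicate after every wakeup, so
// spurious wakeups are handled without a hand-written loop.
fn wait_past(gate: &Gate, local_gen: usize) -> usize {
    let guard = gate.generation_id.lock().unwrap();
    let guard = gate.cvar.wait_while(guard, |g| *g == local_gen).unwrap();
    *guard
}

// Plays the leader's role: advance the generation (with wrap) and wake everyone.
fn advance(gate: &Gate) {
    let mut guard = gate.generation_id.lock().unwrap();
    *guard = (*guard).wrapping_add(1);
    gate.cvar.notify_all();
}

// One waiter and one leader; returns the generation the waiter observed.
fn demo() -> usize {
    let gate = Arc::new(Gate { generation_id: Mutex::new(0), cvar: Condvar::new() });
    let waiter = {
        let gate = Arc::clone(&gate);
        // The waiter was "born" in generation 0 and waits for it to change.
        thread::spawn(move || wait_past(&gate, 0))
    };
    advance(&gate);
    waiter.join().unwrap()
}

fn main() {
    assert_eq!(demo(), 1);
}
```

The waiter is handed the starting generation explicitly so that the example behaves the same whether `advance` runs before or after the waiter takes the lock.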

Any input would be appreciated. Sorry if this is terribly verbose. I'm new to the Rust community and concurrency can be hard to explain in words. Thanks!
2021-10-21 14:11:02 +09:00

175 lines
5.0 KiB
Rust

#[cfg(test)]
mod tests;

use crate::fmt;
use crate::sync::{Condvar, Mutex};

/// A barrier enables multiple threads to synchronize the beginning
/// of some computation.
///
/// # Examples
///
/// ```
/// use std::sync::{Arc, Barrier};
/// use std::thread;
///
/// let mut handles = Vec::with_capacity(10);
/// let barrier = Arc::new(Barrier::new(10));
/// for _ in 0..10 {
///     let c = Arc::clone(&barrier);
///     // The same messages will be printed together.
///     // You will NOT see any interleaving.
///     handles.push(thread::spawn(move|| {
///         println!("before wait");
///         c.wait();
///         println!("after wait");
///     }));
/// }
/// // Wait for other threads to finish.
/// for handle in handles {
///     handle.join().unwrap();
/// }
/// ```
#[stable(feature = "rust1", since = "1.0.0")]
pub struct Barrier {
    lock: Mutex<BarrierState>,
    cvar: Condvar,
    num_threads: usize,
}

// The inner state of a double barrier
struct BarrierState {
    count: usize,
    generation_id: usize,
}

/// A `BarrierWaitResult` is returned by [`Barrier::wait()`] when all threads
/// in the [`Barrier`] have rendezvoused.
///
/// # Examples
///
/// ```
/// use std::sync::Barrier;
///
/// let barrier = Barrier::new(1);
/// let barrier_wait_result = barrier.wait();
/// ```
#[stable(feature = "rust1", since = "1.0.0")]
pub struct BarrierWaitResult(bool);

#[stable(feature = "std_debug", since = "1.16.0")]
impl fmt::Debug for Barrier {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("Barrier").finish_non_exhaustive()
    }
}

impl Barrier {
    /// Creates a new barrier that can block a given number of threads.
    ///
    /// A barrier will block `n`-1 threads which call [`wait()`] and then wake
    /// up all threads at once when the `n`th thread calls [`wait()`].
    ///
    /// [`wait()`]: Barrier::wait
    ///
    /// # Examples
    ///
    /// ```
    /// use std::sync::Barrier;
    ///
    /// let barrier = Barrier::new(10);
    /// ```
    #[stable(feature = "rust1", since = "1.0.0")]
    #[must_use]
    pub fn new(n: usize) -> Barrier {
        Barrier {
            lock: Mutex::new(BarrierState { count: 0, generation_id: 0 }),
            cvar: Condvar::new(),
            num_threads: n,
        }
    }

    /// Blocks the current thread until all threads have rendezvoused here.
    ///
    /// Barriers are re-usable after all threads have rendezvoused once, and can
    /// be used continuously.
    ///
    /// A single (arbitrary) thread will receive a [`BarrierWaitResult`] that
    /// returns `true` from [`BarrierWaitResult::is_leader()`] when returning
    /// from this function, and all other threads will receive a result that
    /// will return `false` from [`BarrierWaitResult::is_leader()`].
    ///
    /// # Examples
    ///
    /// ```
    /// use std::sync::{Arc, Barrier};
    /// use std::thread;
    ///
    /// let mut handles = Vec::with_capacity(10);
    /// let barrier = Arc::new(Barrier::new(10));
    /// for _ in 0..10 {
    ///     let c = Arc::clone(&barrier);
    ///     // The same messages will be printed together.
    ///     // You will NOT see any interleaving.
    ///     handles.push(thread::spawn(move|| {
    ///         println!("before wait");
    ///         c.wait();
    ///         println!("after wait");
    ///     }));
    /// }
    /// // Wait for other threads to finish.
    /// for handle in handles {
    ///     handle.join().unwrap();
    /// }
    /// ```
    #[stable(feature = "rust1", since = "1.0.0")]
    pub fn wait(&self) -> BarrierWaitResult {
        let mut lock = self.lock.lock().unwrap();
        let local_gen = lock.generation_id;
        lock.count += 1;
        if lock.count < self.num_threads {
            // We need a while loop to guard against spurious wakeups.
            // https://en.wikipedia.org/wiki/Spurious_wakeup
            while local_gen == lock.generation_id {
                lock = self.cvar.wait(lock).unwrap();
            }
            BarrierWaitResult(false)
        } else {
            lock.count = 0;
            lock.generation_id = lock.generation_id.wrapping_add(1);
            self.cvar.notify_all();
            BarrierWaitResult(true)
        }
    }
}

#[stable(feature = "std_debug", since = "1.16.0")]
impl fmt::Debug for BarrierWaitResult {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("BarrierWaitResult").field("is_leader", &self.is_leader()).finish()
    }
}

impl BarrierWaitResult {
    /// Returns `true` if this thread is the "leader thread" for the call to
    /// [`Barrier::wait()`].
    ///
    /// Only one thread will have `true` returned from their result, all other
    /// threads will have `false` returned.
    ///
    /// # Examples
    ///
    /// ```
    /// use std::sync::Barrier;
    ///
    /// let barrier = Barrier::new(1);
    /// let barrier_wait_result = barrier.wait();
    /// println!("{:?}", barrier_wait_result.is_leader());
    /// ```
    #[stable(feature = "rust1", since = "1.0.0")]
    #[must_use]
    pub fn is_leader(&self) -> bool {
        self.0
    }
}