Implement RFC 3151: Scoped threads.

This commit is contained in:
Mara Bos 2022-01-04 14:51:39 +01:00
parent a45b3ac183
commit 0e24ad537b
2 changed files with 206 additions and 30 deletions

View File

@ -180,6 +180,12 @@
#[macro_use]
mod local;
#[unstable(feature = "scoped_threads", issue = "none")]
mod scoped;
#[unstable(feature = "scoped_threads", issue = "none")]
pub use scoped::{scope, Scope, ScopedJoinHandle};
#[stable(feature = "rust1", since = "1.0.0")]
pub use self::local::{AccessError, LocalKey};
@ -446,6 +452,20 @@ pub unsafe fn spawn_unchecked<'a, F, T>(self, f: F) -> io::Result<JoinHandle<T>>
F: FnOnce() -> T,
F: Send + 'a,
T: Send + 'a,
{
Ok(JoinHandle(unsafe { self.spawn_unchecked_(f, None) }?))
}
unsafe fn spawn_unchecked_<'a, 'scope, F, T>(
self,
f: F,
scope_data: Option<&'scope scoped::ScopeData>,
) -> io::Result<JoinInner<'scope, T>>
where
F: FnOnce() -> T,
F: Send + 'a,
T: Send + 'a,
'scope: 'a,
{
let Builder { name, stack_size } = self;
@ -456,7 +476,8 @@ pub unsafe fn spawn_unchecked<'a, F, T>(self, f: F) -> io::Result<JoinHandle<T>>
}));
let their_thread = my_thread.clone();
let my_packet: Arc<UnsafeCell<Option<Result<T>>>> = Arc::new(UnsafeCell::new(None));
let my_packet: Arc<Packet<'scope, T>> =
Arc::new(Packet { scope: scope_data, result: UnsafeCell::new(None) });
let their_packet = my_packet.clone();
let output_capture = crate::io::set_output_capture(None);
@ -480,10 +501,14 @@ pub unsafe fn spawn_unchecked<'a, F, T>(self, f: F) -> io::Result<JoinHandle<T>>
// closure (it is an Arc<...>) and `my_packet` will be stored in the
// same `JoinInner` as this closure meaning the mutation will be
// safe (not modify it and affect a value far away).
unsafe { *their_packet.get() = Some(try_result) };
unsafe { *their_packet.result.get() = Some(try_result) };
};
Ok(JoinHandle(JoinInner {
if let Some(scope_data) = scope_data {
scope_data.increment_n_running_threads();
}
Ok(JoinInner {
// SAFETY:
//
// `imp::Thread::new` takes a closure with a `'static` lifetime, since it's passed
@ -506,8 +531,8 @@ pub unsafe fn spawn_unchecked<'a, F, T>(self, f: F) -> io::Result<JoinHandle<T>>
)?
},
thread: my_thread,
packet: Packet(my_packet),
}))
packet: my_packet,
})
}
}
@ -1239,34 +1264,53 @@ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
#[stable(feature = "rust1", since = "1.0.0")]
pub type Result<T> = crate::result::Result<T, Box<dyn Any + Send + 'static>>;
// This packet is used to communicate the return value between the spawned thread
// and the rest of the program. Memory is shared through the `Arc` within and there's
// no need for a mutex here because synchronization happens with `join()` (the
// caller will never read this packet until the thread has exited).
// This packet is used to communicate the return value between the spawned
// thread and the rest of the program. It is shared through an `Arc` and
// there's no need for a mutex here because synchronization happens with `join()`
// (the caller will never read this packet until the thread has exited).
//
// This packet itself is then stored into a `JoinInner` which in turns is placed
// in `JoinHandle` and `JoinGuard`. Due to the usage of `UnsafeCell` we need to
// manually worry about impls like Send and Sync. The type `T` should
// already always be Send (otherwise the thread could not have been created) and
// this type is inherently Sync because no methods take &self. Regardless,
// however, we add inheriting impls for Send/Sync to this type to ensure it's
// Send/Sync and that future modifications will still appropriately classify it.
struct Packet<T>(Arc<UnsafeCell<Option<Result<T>>>>);
unsafe impl<T: Send> Send for Packet<T> {}
unsafe impl<T: Sync> Sync for Packet<T> {}
/// Inner representation for JoinHandle
struct JoinInner<T> {
native: imp::Thread,
thread: Thread,
packet: Packet<T>,
// An Arc to the packet is stored into a `JoinInner` which in turns is placed
// in `JoinHandle`. Due to the usage of `UnsafeCell` we need to manually worry
// about impls like Send and Sync. The type `T` should already always be Send
// (otherwise the thread could not have been created) and this type is
// inherently Sync because no methods take &self. Regardless, however, we add
// inheriting impls for Send/Sync to this type to ensure it's Send/Sync and
// that future modifications will still appropriately classify it.
struct Packet<'scope, T> {
scope: Option<&'scope scoped::ScopeData>,
result: UnsafeCell<Option<Result<T>>>,
}
impl<T> JoinInner<T> {
unsafe impl<'scope, T: Send> Send for Packet<'scope, T> {}
unsafe impl<'scope, T: Sync> Sync for Packet<'scope, T> {}
impl<'scope, T> Drop for Packet<'scope, T> {
fn drop(&mut self) {
if let Some(scope) = self.scope {
// If this packet was for a thread that ran in a scope, the thread
// panicked, and nobody consumed the panic payload, we put the
// panic payload in the scope so it can re-throw it, if it didn't
// already capture any panic yet.
if let Some(Err(e)) = self.result.get_mut().take() {
scope.panic_payload.lock().unwrap().get_or_insert(e);
}
// Book-keeping so the scope knows when it's done.
scope.decrement_n_running_threads();
}
}
}
/// Inner representation for JoinHandle
struct JoinInner<'scope, T> {
native: imp::Thread,
thread: Thread,
packet: Arc<Packet<'scope, T>>,
}
impl<'scope, T> JoinInner<'scope, T> {
fn join(mut self) -> Result<T> {
self.native.join();
Arc::get_mut(&mut self.packet.0).unwrap().get_mut().take().unwrap()
Arc::get_mut(&mut self.packet).unwrap().result.get_mut().take().unwrap()
}
}
@ -1333,7 +1377,7 @@ fn join(mut self) -> Result<T> {
/// [`thread::Builder::spawn`]: Builder::spawn
/// [`thread::spawn`]: spawn
#[stable(feature = "rust1", since = "1.0.0")]
pub struct JoinHandle<T>(JoinInner<T>);
pub struct JoinHandle<T>(JoinInner<'static, T>);
#[stable(feature = "joinhandle_impl_send_sync", since = "1.29.0")]
unsafe impl<T> Send for JoinHandle<T> {}
@ -1407,7 +1451,7 @@ pub fn join(self) -> Result<T> {
/// function has returned, but before the thread itself has stopped running.
#[unstable(feature = "thread_is_running", issue = "90470")]
pub fn is_running(&self) -> bool {
Arc::strong_count(&self.0.packet.0) > 1
Arc::strong_count(&self.0.packet) > 1
}
}

View File

@ -0,0 +1,132 @@
use super::{current, park, Builder, JoinInner, Result, Thread};
use crate::any::Any;
use crate::fmt;
use crate::io;
use crate::marker::PhantomData;
use crate::panic::{catch_unwind, resume_unwind, AssertUnwindSafe};
use crate::sync::atomic::{AtomicUsize, Ordering};
use crate::sync::Mutex;
/// TODO: documentation
pub struct Scope<'env> {
data: ScopeData,
env: PhantomData<&'env ()>,
}
/// TODO: documentation
pub struct ScopedJoinHandle<'scope, T>(JoinInner<'scope, T>);
pub(super) struct ScopeData {
n_running_threads: AtomicUsize,
main_thread: Thread,
pub(super) panic_payload: Mutex<Option<Box<dyn Any + Send>>>,
}
impl ScopeData {
pub(super) fn increment_n_running_threads(&self) {
// We check for 'overflow' with usize::MAX / 2, to make sure there's no
// chance it overflows to 0, which would result in unsoundness.
if self.n_running_threads.fetch_add(1, Ordering::Relaxed) == usize::MAX / 2 {
// This can only reasonably happen by mem::forget()'ing many many ScopedJoinHandles.
self.decrement_n_running_threads();
panic!("too many running threads in thread scope");
}
}
pub(super) fn decrement_n_running_threads(&self) {
if self.n_running_threads.fetch_sub(1, Ordering::Release) == 1 {
self.main_thread.unpark();
}
}
}
/// TODO: documentation
pub fn scope<'env, F, T>(f: F) -> T
where
F: FnOnce(&Scope<'env>) -> T,
{
let mut scope = Scope {
data: ScopeData {
n_running_threads: AtomicUsize::new(0),
main_thread: current(),
panic_payload: Mutex::new(None),
},
env: PhantomData,
};
// Run `f`, but catch panics so we can make sure to wait for all the threads to join.
let result = catch_unwind(AssertUnwindSafe(|| f(&scope)));
// Wait until all the threads are finished.
while scope.data.n_running_threads.load(Ordering::Acquire) != 0 {
park();
}
// Throw any panic from `f` or from any panicked thread, or the return value of `f` otherwise.
match result {
Err(e) => {
// `f` itself panicked.
resume_unwind(e);
}
Ok(result) => {
if let Some(panic_payload) = scope.data.panic_payload.get_mut().unwrap().take() {
// A thread panicked.
resume_unwind(panic_payload);
} else {
// Nothing panicked.
result
}
}
}
}
impl<'env> Scope<'env> {
/// TODO: documentation
pub fn spawn<'scope, F, T>(&'scope self, f: F) -> ScopedJoinHandle<'scope, T>
where
F: FnOnce(&Scope<'env>) -> T + Send + 'env,
T: Send + 'env,
{
Builder::new().spawn_scoped(self, f).expect("failed to spawn thread")
}
}
impl Builder {
fn spawn_scoped<'scope, 'env, F, T>(
self,
scope: &'scope Scope<'env>,
f: F,
) -> io::Result<ScopedJoinHandle<'scope, T>>
where
F: FnOnce(&Scope<'env>) -> T + Send + 'env,
T: Send + 'env,
{
Ok(ScopedJoinHandle(unsafe { self.spawn_unchecked_(|| f(scope), Some(&scope.data)) }?))
}
}
impl<'scope, T> ScopedJoinHandle<'scope, T> {
/// TODO
pub fn join(self) -> Result<T> {
self.0.join()
}
/// TODO
pub fn thread(&self) -> &Thread {
&self.0.thread
}
}
impl<'env> fmt::Debug for Scope<'env> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("Scope")
.field("n_running_threads", &self.data.n_running_threads.load(Ordering::Relaxed))
.field("panic_payload", &self.data.panic_payload)
.finish_non_exhaustive()
}
}
impl<'scope, T> fmt::Debug for ScopedJoinHandle<'scope, T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("ScopedJoinHandle").finish_non_exhaustive()
}
}