Improve code around SGX waitqueue

Followed up of d36e390d81
See https://github.com/rust-lang/rust/pull/109732#issuecomment-1543574908
for more details.

Co-authored-by: Jethro Beekman <jethro@fortanix.com>
This commit is contained in:
Urgau 2023-05-11 10:53:16 +02:00
parent e280df556d
commit f5aede9c82

View File

@ -202,12 +202,18 @@ pub fn wait_timeout<T, F: FnOnce()>(
pub fn notify_one<T>( pub fn notify_one<T>(
mut guard: SpinMutexGuard<'_, WaitVariable<T>>, mut guard: SpinMutexGuard<'_, WaitVariable<T>>,
) -> Result<WaitGuard<'_, T>, SpinMutexGuard<'_, WaitVariable<T>>> { ) -> Result<WaitGuard<'_, T>, SpinMutexGuard<'_, WaitVariable<T>>> {
// SAFETY: lifetime of the pop() return value is limited to the map
// closure (The closure return value is 'static). The underlying
// stack frame won't be freed until after the WaitGuard created below
// is dropped.
unsafe { unsafe {
if let Some(entry) = guard.queue.inner.pop() { let tcs = guard.queue.inner.pop().map(|entry| -> Tcs {
let mut entry_guard = entry.lock(); let mut entry_guard = entry.lock();
let tcs = entry_guard.tcs;
entry_guard.wake = true; entry_guard.wake = true;
drop(entry_guard); entry_guard.tcs
});
if let Some(tcs) = tcs {
Ok(WaitGuard { mutex_guard: Some(guard), notified_tcs: NotifiedTcs::Single(tcs) }) Ok(WaitGuard { mutex_guard: Some(guard), notified_tcs: NotifiedTcs::Single(tcs) })
} else { } else {
Err(guard) Err(guard)
@ -223,6 +229,9 @@ pub fn notify_one<T>(
pub fn notify_all<T>( pub fn notify_all<T>(
mut guard: SpinMutexGuard<'_, WaitVariable<T>>, mut guard: SpinMutexGuard<'_, WaitVariable<T>>,
) -> Result<WaitGuard<'_, T>, SpinMutexGuard<'_, WaitVariable<T>>> { ) -> Result<WaitGuard<'_, T>, SpinMutexGuard<'_, WaitVariable<T>>> {
// SAFETY: lifetime of the pop() return values are limited to the
// while loop body. The underlying stack frames won't be freed until
// after the WaitGuard created below is dropped.
unsafe { unsafe {
let mut count = 0; let mut count = 0;
while let Some(entry) = guard.queue.inner.pop() { while let Some(entry) = guard.queue.inner.pop() {
@ -230,6 +239,7 @@ pub fn notify_all<T>(
let mut entry_guard = entry.lock(); let mut entry_guard = entry.lock();
entry_guard.wake = true; entry_guard.wake = true;
} }
if let Some(count) = NonZeroUsize::new(count) { if let Some(count) = NonZeroUsize::new(count) {
Ok(WaitGuard { mutex_guard: Some(guard), notified_tcs: NotifiedTcs::All { count } }) Ok(WaitGuard { mutex_guard: Some(guard), notified_tcs: NotifiedTcs::All { count } })
} else { } else {