From a1d94d43ac876975e15b5069da1f9b7a81e29733 Mon Sep 17 00:00:00 2001 From: DrMeepster <19316085+DrMeepster@users.noreply.github.com> Date: Sat, 29 Oct 2022 22:58:34 -0700 Subject: [PATCH] impl condvars for windows --- src/tools/miri/src/concurrency/sync.rs | 14 +- src/tools/miri/src/shims/unix/sync.rs | 14 +- .../miri/src/shims/windows/foreign_items.rs | 19 +++ src/tools/miri/src/shims/windows/sync.rs | 146 ++++++++++++++++++ src/tools/miri/tests/pass/concurrency/sync.rs | 20 +-- .../tests/pass/concurrency/sync_nopreempt.rs | 1 - .../miri/tests/pass/panic/concurrent-panic.rs | 1 - 7 files changed, 185 insertions(+), 30 deletions(-) diff --git a/src/tools/miri/src/concurrency/sync.rs b/src/tools/miri/src/concurrency/sync.rs index e76610e7302..48f9e605276 100644 --- a/src/tools/miri/src/concurrency/sync.rs +++ b/src/tools/miri/src/concurrency/sync.rs @@ -121,8 +121,10 @@ declare_id!(CondvarId); struct CondvarWaiter { /// The thread that is waiting on this variable. thread: ThreadId, - /// The mutex on which the thread is waiting. - mutex: MutexId, + /// The mutex or rwlock on which the thread is waiting. + lock: u32, + /// If the lock is shared or exclusive + shared: bool, } /// The conditional variable state. @@ -569,16 +571,16 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> { } /// Mark that the thread is waiting on the conditional variable. - fn condvar_wait(&mut self, id: CondvarId, thread: ThreadId, mutex: MutexId) { + fn condvar_wait(&mut self, id: CondvarId, thread: ThreadId, lock: u32, shared: bool) { let this = self.eval_context_mut(); let waiters = &mut this.machine.threads.sync.condvars[id].waiters; assert!(waiters.iter().all(|waiter| waiter.thread != thread), "thread is already waiting"); - waiters.push_back(CondvarWaiter { thread, mutex }); + waiters.push_back(CondvarWaiter { thread, lock, shared }); } /// Wake up some thread (if there is any) sleeping on the conditional /// variable. - fn condvar_signal(&mut self, id: CondvarId) -> Option<(ThreadId, MutexId)> { + fn condvar_signal(&mut self, id: CondvarId) -> Option<(ThreadId, u32, bool)> { let this = self.eval_context_mut(); let current_thread = this.get_active_thread(); let condvar = &mut this.machine.threads.sync.condvars[id]; @@ -592,7 +594,7 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> { if let Some(data_race) = data_race { data_race.validate_lock_acquire(&condvar.data_race, waiter.thread); } - (waiter.thread, waiter.mutex) + (waiter.thread, waiter.lock, waiter.shared) }) } diff --git a/src/tools/miri/src/shims/unix/sync.rs b/src/tools/miri/src/shims/unix/sync.rs index fcb00692079..d24e1a56bd5 100644 --- a/src/tools/miri/src/shims/unix/sync.rs +++ b/src/tools/miri/src/shims/unix/sync.rs @@ -696,8 +696,9 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> { fn pthread_cond_signal(&mut self, cond_op: &OpTy<'tcx, Provenance>) -> InterpResult<'tcx, i32> { let this = self.eval_context_mut(); let id = this.condvar_get_or_create_id(cond_op, CONDVAR_ID_OFFSET)?; - if let Some((thread, mutex)) = this.condvar_signal(id) { - post_cond_signal(this, thread, mutex)?; + if let Some((thread, mutex, shared)) = this.condvar_signal(id) { + assert!(!shared); + post_cond_signal(this, thread, MutexId::from_u32(mutex))?; } Ok(0) @@ -710,8 +711,9 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> { let this = self.eval_context_mut(); let id = this.condvar_get_or_create_id(cond_op, CONDVAR_ID_OFFSET)?; - while let Some((thread, mutex)) = this.condvar_signal(id) { - post_cond_signal(this, thread, mutex)?; + while let Some((thread, mutex, shared)) = this.condvar_signal(id) { + assert!(!shared); + post_cond_signal(this, thread, MutexId::from_u32(mutex))?; } Ok(0) @@ -729,7 +731,7 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> { let active_thread = this.get_active_thread(); release_cond_mutex_and_block(this, active_thread, mutex_id)?; - this.condvar_wait(id, active_thread, mutex_id); + this.condvar_wait(id, active_thread, mutex_id.to_u32(), false); Ok(0) } @@ -768,7 +770,7 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> { }; release_cond_mutex_and_block(this, active_thread, mutex_id)?; - this.condvar_wait(id, active_thread, mutex_id); + this.condvar_wait(id, active_thread, mutex_id.to_u32(), false); // We return success for now and override it in the timeout callback. this.write_scalar(Scalar::from_i32(0), dest)?; diff --git a/src/tools/miri/src/shims/windows/foreign_items.rs b/src/tools/miri/src/shims/windows/foreign_items.rs index 2a34a3a47bb..e16749c986b 100644 --- a/src/tools/miri/src/shims/windows/foreign_items.rs +++ b/src/tools/miri/src/shims/windows/foreign_items.rs @@ -273,6 +273,25 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> { let result = this.InitOnceComplete(ptr, flags, context)?; this.write_scalar(result, dest)?; } + "SleepConditionVariableSRW" => { + let [condvar, lock, timeout, flags] = + this.check_shim(abi, Abi::System { unwind: false }, link_name, args)?; + + let result = this.SleepConditionVariableSRW(condvar, lock, timeout, flags, dest)?; + this.write_scalar(result, dest)?; + } + "WakeConditionVariable" => { + let [condvar] = + this.check_shim(abi, Abi::System { unwind: false }, link_name, args)?; + + this.WakeConditionVariable(condvar)?; + } + "WakeAllConditionVariable" => { + let [condvar] = + this.check_shim(abi, Abi::System { unwind: false }, link_name, args)?; + + this.WakeAllConditionVariable(condvar)?; + } // Dynamic symbol loading "GetProcAddress" => { diff --git a/src/tools/miri/src/shims/windows/sync.rs b/src/tools/miri/src/shims/windows/sync.rs index 098804626f2..2eab1794c4f 100644 --- a/src/tools/miri/src/shims/windows/sync.rs +++ b/src/tools/miri/src/shims/windows/sync.rs @@ -8,6 +8,38 @@ use crate::*; const SRWLOCK_ID_OFFSET: u64 = 0; const INIT_ONCE_ID_OFFSET: u64 = 0; +const CONDVAR_ID_OFFSET: u64 = 0; + +impl<'mir, 'tcx> EvalContextExtPriv<'mir, 'tcx> for crate::MiriInterpCx<'mir, 'tcx> {} +pub trait EvalContextExtPriv<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> { + /// Try to reacquire the lock associated with the condition variable after we + /// were signaled. + fn reacquire_cond_lock( + &mut self, + thread: ThreadId, + lock: RwLockId, + shared: bool, + ) -> InterpResult<'tcx> { + let this = self.eval_context_mut(); + this.unblock_thread(thread); + + if shared { + if this.rwlock_is_locked(lock) { + this.rwlock_enqueue_and_block_reader(lock, thread); + } else { + this.rwlock_reader_lock(lock, thread); + } + } else { + if this.rwlock_is_write_locked(lock) { + this.rwlock_enqueue_and_block_writer(lock, thread); + } else { + this.rwlock_writer_lock(lock, thread); + } + } + + Ok(()) + } +} impl<'mir, 'tcx> EvalContextExt<'mir, 'tcx> for crate::MiriInterpCx<'mir, 'tcx> {} #[allow(non_snake_case)] @@ -327,4 +359,118 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> { Ok(()) } + + fn SleepConditionVariableSRW( + &mut self, + condvar_op: &OpTy<'tcx, Provenance>, + lock_op: &OpTy<'tcx, Provenance>, + timeout_op: &OpTy<'tcx, Provenance>, + flags_op: &OpTy<'tcx, Provenance>, + dest: &PlaceTy<'tcx, Provenance>, + ) -> InterpResult<'tcx, Scalar> { + let this = self.eval_context_mut(); + + let condvar_id = this.condvar_get_or_create_id(condvar_op, CONDVAR_ID_OFFSET)?; + let lock_id = this.rwlock_get_or_create_id(lock_op, SRWLOCK_ID_OFFSET)?; + let timeout_ms = this.read_scalar(timeout_op)?.to_u32()?; + let flags = this.read_scalar(flags_op)?.to_u32()?; + + let timeout_time = if timeout_ms == this.eval_windows("c", "INFINITE")?.to_u32()? { + None + } else { + let duration = Duration::from_millis(timeout_ms.into()); + Some(this.machine.clock.now().checked_add(duration).unwrap()) + }; + + let shared_mode = 0x1; // CONDITION_VARIABLE_LOCKMODE_SHARED is not in std + let shared = flags == shared_mode; + + let active_thread = this.get_active_thread(); + + let was_locked = if shared { + this.rwlock_reader_unlock(lock_id, active_thread) + } else { + this.rwlock_writer_unlock(lock_id, active_thread) + }; + + if !was_locked { + throw_ub_format!( + "calling SleepConditionVariableSRW with an SRWLock that is not locked by the current thread" + ); + } + + this.block_thread(active_thread); + this.condvar_wait(condvar_id, active_thread, lock_id.to_u32(), shared); + + if let Some(timeout_time) = timeout_time { + struct Callback<'tcx> { + thread: ThreadId, + condvar_id: CondvarId, + lock_id: RwLockId, + shared: bool, + dest: PlaceTy<'tcx, Provenance>, + } + + impl<'tcx> VisitTags for Callback<'tcx> { + fn visit_tags(&self, visit: &mut dyn FnMut(SbTag)) { + let Callback { thread: _, condvar_id: _, lock_id: _, shared: _, dest } = self; + dest.visit_tags(visit); + } + } + + impl<'mir, 'tcx: 'mir> MachineCallback<'mir, 'tcx> for Callback<'tcx> { + fn call(&self, this: &mut MiriInterpCx<'mir, 'tcx>) -> InterpResult<'tcx> { + this.reacquire_cond_lock(self.thread, self.lock_id, self.shared)?; + + this.condvar_remove_waiter(self.condvar_id, self.thread); + + let error_timeout = this.eval_windows("c", "ERROR_TIMEOUT")?; + this.set_last_error(error_timeout)?; + this.write_scalar(this.eval_windows("c", "FALSE")?, &self.dest)?; + Ok(()) + } + } + + this.register_timeout_callback( + active_thread, + Time::Monotonic(timeout_time), + Box::new(Callback { + thread: active_thread, + condvar_id, + lock_id, + shared, + dest: dest.clone(), + }), + ); + } + + this.eval_windows("c", "TRUE") + } + + fn WakeConditionVariable(&mut self, condvar_op: &OpTy<'tcx, Provenance>) -> InterpResult<'tcx> { + let this = self.eval_context_mut(); + let condvar_id = this.condvar_get_or_create_id(condvar_op, CONDVAR_ID_OFFSET)?; + + if let Some((thread, lock, shared)) = this.condvar_signal(condvar_id) { + this.reacquire_cond_lock(thread, RwLockId::from_u32(lock), shared)?; + this.unregister_timeout_callback_if_exists(thread); + } + + Ok(()) + } + + fn WakeAllConditionVariable( + &mut self, + condvar_op: &OpTy<'tcx, Provenance>, + ) -> InterpResult<'tcx> { + let this = self.eval_context_mut(); + let condvar_id = this.condvar_get_or_create_id(condvar_op, CONDVAR_ID_OFFSET)?; + + while let Some((thread, lock, shared)) = this.condvar_signal(condvar_id) { + this.reacquire_cond_lock(thread, RwLockId::from_u32(lock), shared)?; + this.unregister_timeout_callback_if_exists(thread); + } + + Ok(()) + } } diff --git a/src/tools/miri/tests/pass/concurrency/sync.rs b/src/tools/miri/tests/pass/concurrency/sync.rs index b1518a49fbb..19ea6c130bd 100644 --- a/src/tools/miri/tests/pass/concurrency/sync.rs +++ b/src/tools/miri/tests/pass/concurrency/sync.rs @@ -230,20 +230,8 @@ fn main() { check_once(); park_timeout(); park_unpark(); - - if !cfg!(windows) { - // ignore-target-windows: Condvars on Windows are not supported yet - check_barriers(); - check_conditional_variables_notify_one(); - check_conditional_variables_timed_wait_timeout(); - check_conditional_variables_timed_wait_notimeout(); - } else { - // We need to fake the same output... - for _ in 0..10 { - println!("before wait"); - } - for _ in 0..10 { - println!("after wait"); - } - } + check_barriers(); + check_conditional_variables_notify_one(); + check_conditional_variables_timed_wait_timeout(); + check_conditional_variables_timed_wait_notimeout(); } diff --git a/src/tools/miri/tests/pass/concurrency/sync_nopreempt.rs b/src/tools/miri/tests/pass/concurrency/sync_nopreempt.rs index 55206f4bfc5..c6cff038f81 100644 --- a/src/tools/miri/tests/pass/concurrency/sync_nopreempt.rs +++ b/src/tools/miri/tests/pass/concurrency/sync_nopreempt.rs @@ -1,4 +1,3 @@ -//@ignore-target-windows: Condvars on Windows are not supported yet. // We are making scheduler assumptions here. //@compile-flags: -Zmiri-strict-provenance -Zmiri-preemption-rate=0 diff --git a/src/tools/miri/tests/pass/panic/concurrent-panic.rs b/src/tools/miri/tests/pass/panic/concurrent-panic.rs index 342269c6acb..776bc2057f3 100644 --- a/src/tools/miri/tests/pass/panic/concurrent-panic.rs +++ b/src/tools/miri/tests/pass/panic/concurrent-panic.rs @@ -1,4 +1,3 @@ -//@ignore-target-windows: Condvars on Windows are not supported yet. // We are making scheduler assumptions here. //@compile-flags: -Zmiri-preemption-rate=0