diff --git a/src/tools/miri/src/concurrency/thread.rs b/src/tools/miri/src/concurrency/thread.rs index 3c5a6786bd1..5e6fcbde69a 100644 --- a/src/tools/miri/src/concurrency/thread.rs +++ b/src/tools/miri/src/concurrency/thread.rs @@ -32,10 +32,12 @@ pub enum SchedulingAction { /// Timeout callbacks can be created by synchronization primitives to tell the /// scheduler that they should be called once some period of time passes. -pub trait TimeoutCallback<'mir, 'tcx>: VisitMachineValues + 'tcx { +pub trait MachineCallback<'mir, 'tcx>: VisitMachineValues { fn call(&self, ecx: &mut InterpCx<'mir, 'tcx, MiriMachine<'mir, 'tcx>>) -> InterpResult<'tcx>; } +type TimeoutCallback<'mir, 'tcx> = Box + 'tcx>; + /// A thread identifier. #[derive(Clone, Copy, Debug, PartialOrd, Ord, PartialEq, Eq, Hash)] pub struct ThreadId(u32); @@ -252,7 +254,7 @@ struct TimeoutCallbackInfo<'mir, 'tcx> { /// The callback should be called no earlier than this time. call_time: Time, /// The called function. - callback: Box>, + callback: TimeoutCallback<'mir, 'tcx>, } impl<'mir, 'tcx> std::fmt::Debug for TimeoutCallbackInfo<'mir, 'tcx> { @@ -303,10 +305,10 @@ impl VisitMachineValues for ThreadManager<'_, '_> { let ThreadManager { threads, thread_local_alloc_ids, + timeout_callbacks, active_thread: _, yield_active_thread: _, sync: _, - timeout_callbacks: _, } = self; for thread in threads { @@ -315,8 +317,9 @@ impl VisitMachineValues for ThreadManager<'_, '_> { for ptr in thread_local_alloc_ids.borrow().values().copied() { visit.visit(ptr); } - // FIXME: Do we need to do something for TimeoutCallback? That's a Box, not sure what - // to do. + for callback in timeout_callbacks.values() { + callback.callback.visit_machine_values(visit); + } } } @@ -542,7 +545,7 @@ impl<'mir, 'tcx: 'mir> ThreadManager<'mir, 'tcx> { &mut self, thread: ThreadId, call_time: Time, - callback: Box>, + callback: TimeoutCallback<'mir, 'tcx>, ) { self.timeout_callbacks .try_insert(thread, TimeoutCallbackInfo { call_time, callback }) @@ -558,7 +561,7 @@ impl<'mir, 'tcx: 'mir> ThreadManager<'mir, 'tcx> { fn get_ready_callback( &mut self, clock: &Clock, - ) -> Option<(ThreadId, Box>)> { + ) -> Option<(ThreadId, TimeoutCallback<'mir, 'tcx>)> { // We iterate over all threads in the order of their indices because // this allows us to have a deterministic scheduler. for thread in self.threads.indices() { @@ -931,7 +934,7 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> { &mut self, thread: ThreadId, call_time: Time, - callback: Box>, + callback: TimeoutCallback<'mir, 'tcx>, ) { let this = self.eval_context_mut(); if !this.machine.communicate() && matches!(call_time, Time::RealTime(..)) { diff --git a/src/tools/miri/src/shims/time.rs b/src/tools/miri/src/shims/time.rs index 43792024e2c..05eff3dfd59 100644 --- a/src/tools/miri/src/shims/time.rs +++ b/src/tools/miri/src/shims/time.rs @@ -1,6 +1,6 @@ use std::time::{Duration, SystemTime}; -use crate::concurrency::thread::TimeoutCallback; +use crate::concurrency::thread::MachineCallback; use crate::*; /// Returns the time elapsed between the provided time and the unix epoch as a `Duration`. @@ -257,7 +257,7 @@ impl VisitMachineValues for Callback { fn visit_machine_values(&self, _visit: &mut ProvenanceVisitor) {} } -impl<'mir, 'tcx: 'mir> TimeoutCallback<'mir, 'tcx> for Callback { +impl<'mir, 'tcx: 'mir> MachineCallback<'mir, 'tcx> for Callback { fn call(&self, ecx: &mut MiriInterpCx<'mir, 'tcx>) -> InterpResult<'tcx> { ecx.unblock_thread(self.active_thread); Ok(()) diff --git a/src/tools/miri/src/shims/unix/linux/sync.rs b/src/tools/miri/src/shims/unix/linux/sync.rs index 784fa12d18a..bbfb1c34db7 100644 --- a/src/tools/miri/src/shims/unix/linux/sync.rs +++ b/src/tools/miri/src/shims/unix/linux/sync.rs @@ -1,4 +1,4 @@ -use crate::concurrency::thread::{Time, TimeoutCallback}; +use crate::concurrency::thread::{MachineCallback, Time}; use crate::*; use rustc_target::abi::{Align, Size}; use std::time::SystemTime; @@ -268,7 +268,7 @@ impl<'tcx> VisitMachineValues for Callback<'tcx> { } } -impl<'mir, 'tcx: 'mir> TimeoutCallback<'mir, 'tcx> for Callback<'tcx> { +impl<'mir, 'tcx: 'mir> MachineCallback<'mir, 'tcx> for Callback<'tcx> { fn call(&self, this: &mut MiriInterpCx<'mir, 'tcx>) -> InterpResult<'tcx> { this.unblock_thread(self.thread); this.futex_remove_waiter(self.addr_usize, self.thread); diff --git a/src/tools/miri/src/shims/unix/sync.rs b/src/tools/miri/src/shims/unix/sync.rs index cdb3cdc4b9a..72b71ada8e0 100644 --- a/src/tools/miri/src/shims/unix/sync.rs +++ b/src/tools/miri/src/shims/unix/sync.rs @@ -3,7 +3,7 @@ use std::time::SystemTime; use rustc_hir::LangItem; use rustc_middle::ty::{layout::TyAndLayout, query::TyCtxtAt, Ty}; -use crate::concurrency::thread::{Time, TimeoutCallback}; +use crate::concurrency::thread::{MachineCallback, Time}; use crate::*; // pthread_mutexattr_t is either 4 or 8 bytes, depending on the platform. @@ -901,7 +901,7 @@ impl<'tcx> VisitMachineValues for Callback<'tcx> { } } -impl<'mir, 'tcx: 'mir> TimeoutCallback<'mir, 'tcx> for Callback<'tcx> { +impl<'mir, 'tcx: 'mir> MachineCallback<'mir, 'tcx> for Callback<'tcx> { fn call(&self, ecx: &mut MiriInterpCx<'mir, 'tcx>) -> InterpResult<'tcx> { // We are not waiting for the condvar any more, wait for the // mutex instead.