Implement basic support for concurrency (Linux only).

This commit is contained in:
Vytautas Astrauskas 2020-03-16 16:48:44 -07:00
parent e06df3a881
commit 82f17ab917
9 changed files with 568 additions and 73 deletions

View File

@ -205,7 +205,8 @@ pub fn eval_main<'tcx>(tcx: TyCtxt<'tcx>, main_id: DefId, config: MiriConfig) ->
// Perform the main execution.
let res: InterpResult<'_, i64> = (|| {
// Main loop.
while ecx.step()? {
while ecx.schedule()? {
assert!(ecx.step()?);
ecx.process_diagnostics();
}
// Read the return code pointer *before* we run TLS destructors, to assert

View File

@ -12,6 +12,7 @@ extern crate rustc_ast;
#[macro_use] extern crate rustc_middle;
extern crate rustc_data_structures;
extern crate rustc_hir;
extern crate rustc_index;
extern crate rustc_mir;
extern crate rustc_span;
extern crate rustc_target;
@ -26,6 +27,7 @@ mod operator;
mod range_map;
mod shims;
mod stacked_borrows;
mod threads;
// Make all those symbols available in the same place as our own.
pub use rustc_mir::interpret::*;
@ -60,6 +62,7 @@ pub use crate::range_map::RangeMap;
pub use crate::stacked_borrows::{
EvalContextExt as StackedBorEvalContextExt, Item, Permission, PtrId, Stack, Stacks, Tag,
};
pub use crate::threads::EvalContextExt as ThreadsEvalContextExt;
/// Insert rustc arguments at the beginning of the argument list that Miri wants to be
/// set per default, for maximal validation power.

View File

@ -26,6 +26,8 @@ use rustc_target::abi::{LayoutOf, Size};
use crate::*;
pub use crate::threads::{ThreadId, ThreadSet, ThreadLocalStorage};
// Some global facts about the emulated machine.
pub const PAGE_SIZE: u64 = 4 * 1024; // FIXME: adjust to target architecture
pub const STACK_ADDR: u64 = 32 * PAGE_SIZE; // not really about the "stack", but where we start assigning integer addresses to allocations
@ -107,6 +109,7 @@ pub struct AllocExtra {
pub struct MemoryExtra {
pub stacked_borrows: Option<stacked_borrows::MemoryExtra>,
pub intptrcast: intptrcast::MemoryExtra,
pub tls: ThreadLocalStorage,
/// Mapping extern static names to their canonical allocation.
extern_statics: FxHashMap<Symbol, AllocId>,
@ -143,6 +146,7 @@ impl MemoryExtra {
rng: RefCell::new(rng),
tracked_alloc_id,
check_alignment,
tls: Default::default(),
}
}
@ -251,8 +255,8 @@ pub struct Evaluator<'mir, 'tcx> {
/// The "time anchor" for this machine's monotone clock (for `Instant` simulation).
pub(crate) time_anchor: Instant,
/// The call stack.
pub(crate) stack: Vec<Frame<'mir, 'tcx, Tag, FrameData<'tcx>>>,
/// The set of threads.
pub(crate) threads: ThreadSet<'mir, 'tcx>,
/// Precomputed `TyLayout`s for primitive data types that are commonly used inside Miri.
pub(crate) layouts: PrimitiveLayouts<'tcx>,
@ -282,7 +286,7 @@ impl<'mir, 'tcx> Evaluator<'mir, 'tcx> {
panic_payload: None,
time_anchor: Instant::now(),
layouts,
stack: Vec::default(),
threads: Default::default(),
}
}
}
@ -326,6 +330,19 @@ impl<'mir, 'tcx> Machine<'mir, 'tcx> for Evaluator<'mir, 'tcx> {
memory_extra.check_alignment
}
#[inline(always)]
fn stack<'a>(
ecx: &'a InterpCx<'mir, 'tcx, Self>
) -> &'a [Frame<'mir, 'tcx, Self::PointerTag, Self::FrameExtra>] {
ecx.active_thread_stack()
}
fn stack_mut<'a>(
ecx: &'a mut InterpCx<'mir, 'tcx, Self>
) -> &'a mut Vec<Frame<'mir, 'tcx, Self::PointerTag, Self::FrameExtra>> {
ecx.active_thread_stack_mut()
}
#[inline(always)]
fn enforce_validity(ecx: &InterpCx<'mir, 'tcx, Self>) -> bool {
ecx.machine.validate
@ -418,29 +435,39 @@ impl<'mir, 'tcx> Machine<'mir, 'tcx> for Evaluator<'mir, 'tcx> {
fn canonical_alloc_id(mem: &Memory<'mir, 'tcx, Self>, id: AllocId) -> AllocId {
let tcx = mem.tcx;
// Figure out if this is an extern static, and if yes, which one.
let def_id = match tcx.alloc_map.lock().get(id) {
Some(GlobalAlloc::Static(def_id)) if tcx.is_foreign_item(def_id) => def_id,
let alloc = tcx.alloc_map.lock().get(id);
match alloc {
Some(GlobalAlloc::Static(def_id)) if tcx.is_foreign_item(def_id) => {
// Figure out if this is an extern static, and if yes, which one.
let attrs = tcx.get_attrs(def_id);
let link_name = match attr::first_attr_value_str_by_name(&attrs, sym::link_name) {
Some(name) => name,
None => tcx.item_name(def_id),
};
// Check if we know this one.
if let Some(canonical_id) = mem.extra.extern_statics.get(&link_name) {
trace!("canonical_alloc_id: {:?} ({}) -> {:?}", id, link_name, canonical_id);
*canonical_id
} else {
// Return original id; `Memory::get_static_alloc` will throw an error.
id
}
},
Some(GlobalAlloc::Static(def_id)) if tcx.has_attr(def_id, sym::thread_local) => {
// We have a thread local, so we need to get a unique allocation id for it.
mem.extra.tls.get_or_register_allocation(*tcx, id)
},
_ => {
// No need to canonicalize anything.
return id;
id
}
};
let attrs = tcx.get_attrs(def_id);
let link_name = match attr::first_attr_value_str_by_name(&attrs, sym::link_name) {
Some(name) => name,
None => tcx.item_name(def_id),
};
// Check if we know this one.
if let Some(canonical_id) = mem.extra.extern_statics.get(&link_name) {
trace!("canonical_alloc_id: {:?} ({}) -> {:?}", id, link_name, canonical_id);
*canonical_id
} else {
// Return original id; `Memory::get_static_alloc` will throw an error.
id
}
}
fn resolve_thread_local_allocation_id(extra: &Self::MemoryExtra, id: AllocId) -> AllocId {
extra.tls.resolve_allocation(id)
}
fn init_allocation_extra<'b>(
memory_extra: &MemoryExtra,
id: AllocId,

View File

@ -6,6 +6,7 @@ use std::convert::TryFrom;
use log::trace;
use crate::*;
use rustc_index::vec::Idx;
use rustc_middle::mir;
use rustc_target::abi::{Align, LayoutOf, Size};
@ -221,13 +222,15 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriEvalContextExt<'mir, 'tcx
}
"pthread_getspecific" => {
let key = this.force_bits(this.read_scalar(args[0])?.not_undef()?, args[0].layout.size)?;
let ptr = this.machine.tls.load_tls(key, this)?;
let active_thread = this.get_active_thread()?;
let ptr = this.machine.tls.load_tls(key, active_thread, this)?;
this.write_scalar(ptr, dest)?;
}
"pthread_setspecific" => {
let key = this.force_bits(this.read_scalar(args[0])?.not_undef()?, args[0].layout.size)?;
let active_thread = this.get_active_thread()?;
let new_ptr = this.read_scalar(args[1])?.not_undef()?;
this.machine.tls.store_tls(key, this.test_null(new_ptr)?)?;
this.machine.tls.store_tls(key, active_thread, this.test_null(new_ptr)?)?;
// Return success (`0`).
this.write_null(dest)?;
@ -291,11 +294,6 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriEvalContextExt<'mir, 'tcx
this.write_scalar(Scalar::from_i32(result), dest)?;
}
// Better error for attempts to create a thread
"pthread_create" => {
throw_unsup_format!("Miri does not support threading");
}
// Miscellaneous
"isatty" => {
let _fd = this.read_scalar(args[0])?.to_i32()?;
@ -316,7 +314,94 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriEvalContextExt<'mir, 'tcx
this.write_null(dest)?;
}
// Incomplete shims that we "stub out" just to get pre-main initialization code to work.
// Threading
"pthread_create" => {
assert_eq!(args.len(), 4);
let func = args[2];
let fn_ptr = this.read_scalar(func)?.not_undef()?;
let fn_val = this.memory.get_fn(fn_ptr)?;
let instance = match fn_val {
rustc_mir::interpret::FnVal::Instance(instance) => instance,
_ => unreachable!(),
};
let thread_info_place = this.deref_operand(args[0])?;
let thread_info_type = args[0].layout.ty
.builtin_deref(true)
.ok_or_else(|| err_ub_format!(
"wrong signature used for `pthread_create`: first argument must be a raw pointer."
))?
.ty;
let thread_info_layout = this.layout_of(thread_info_type)?;
let func_arg = match *args[3] {
rustc_mir::interpret::Operand::Immediate(immediate) => immediate,
_ => unreachable!(),
};
let func_args = [func_arg];
let ret_place =
this.allocate(this.layout_of(this.tcx.types.usize)?, MiriMemoryKind::Machine.into());
let new_thread_id = this.create_thread()?;
let old_thread_id = this.set_active_thread(new_thread_id)?;
this.call_function(
instance,
&func_args[..],
Some(ret_place.into()),
StackPopCleanup::None { cleanup: true },
)?;
this.set_active_thread(old_thread_id)?;
this.write_scalar(
Scalar::from_uint(new_thread_id.index() as u128, thread_info_layout.size),
thread_info_place.into(),
)?;
// Return success (`0`).
this.write_null(dest)?;
}
"pthread_join" => {
assert_eq!(args.len(), 2);
assert!(
this.is_null(this.read_scalar(args[1])?.not_undef()?)?,
"Miri supports pthread_join only with retval==NULL"
);
let thread = this.read_scalar(args[0])?.not_undef()?.to_machine_usize(this)?;
this.join_thread(thread.into())?;
// Return success (`0`).
this.write_null(dest)?;
}
"pthread_detach" => {
let thread = this.read_scalar(args[0])?.not_undef()?.to_machine_usize(this)?;
this.detach_thread(thread.into())?;
// Return success (`0`).
this.write_null(dest)?;
}
"pthread_attr_getguardsize" => {
assert_eq!(args.len(), 2);
let guard_size = this.deref_operand(args[1])?;
let guard_size_type = args[1].layout.ty
.builtin_deref(true)
.ok_or_else(|| err_ub_format!(
"wrong signature used for `pthread_attr_getguardsize`: first argument must be a raw pointer."
))?
.ty;
let guard_size_layout = this.layout_of(guard_size_type)?;
this.write_scalar(Scalar::from_uint(crate::PAGE_SIZE, guard_size_layout.size), guard_size.into())?;
// Return success (`0`).
this.write_null(dest)?;
}
"prctl" => {
let option = this.read_scalar(args[0])?.not_undef()?.to_i32()?;
assert_eq!(option, 0xf, "Miri supports only PR_SET_NAME");
// Return success (`0`).
this.write_null(dest)?;
}
// Incomplete shims that we "stub out" just to get pre-main initialziation code to work.
// These shims are enabled only when the caller is in the standard library.
| "pthread_attr_init"
| "pthread_attr_destroy"

View File

@ -144,13 +144,15 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriEvalContextExt<'mir, 'tcx
}
"TlsGetValue" => {
let key = u128::from(this.read_scalar(args[0])?.to_u32()?);
let ptr = this.machine.tls.load_tls(key, this)?;
let active_thread = this.get_active_thread()?;
let ptr = this.machine.tls.load_tls(key, active_thread, this)?;
this.write_scalar(ptr, dest)?;
}
"TlsSetValue" => {
let key = u128::from(this.read_scalar(args[0])?.to_u32()?);
let active_thread = this.get_active_thread()?;
let new_ptr = this.read_scalar(args[1])?.not_undef()?;
this.machine.tls.store_tls(key, this.test_null(new_ptr)?)?;
this.machine.tls.store_tls(key, active_thread, this.test_null(new_ptr)?)?;
// Return success (`1`).
this.write_scalar(Scalar::from_i32(1), dest)?;

View File

@ -1,22 +1,24 @@
//! Implement thread-local storage.
use std::collections::BTreeMap;
use std::collections::btree_map::Entry;
use log::trace;
use rustc_middle::ty;
use rustc_target::abi::{Size, HasDataLayout};
use crate::{HelpersEvalContextExt, InterpResult, MPlaceTy, Scalar, StackPopCleanup, Tag};
use crate::{HelpersEvalContextExt, ThreadsEvalContextExt, InterpResult, MPlaceTy, Scalar, StackPopCleanup, Tag};
use crate::machine::ThreadId;
pub type TlsKey = u128;
#[derive(Copy, Clone, Debug)]
#[derive(Clone, Debug)]
pub struct TlsEntry<'tcx> {
/// The data for this key. None is used to represent NULL.
/// (We normalize this early to avoid having to do a NULL-ptr-test each time we access the data.)
/// Will eventually become a map from thread IDs to `Scalar`s, if we ever support more than one thread.
data: Option<Scalar<Tag>>,
data: BTreeMap<ThreadId, Scalar<Tag>>,
dtor: Option<ty::Instance<'tcx>>,
}
@ -52,7 +54,7 @@ impl<'tcx> TlsData<'tcx> {
pub fn create_tls_key(&mut self, dtor: Option<ty::Instance<'tcx>>, max_size: Size) -> InterpResult<'tcx, TlsKey> {
let new_key = self.next_key;
self.next_key += 1;
self.keys.insert(new_key, TlsEntry { data: None, dtor }).unwrap_none();
self.keys.insert(new_key, TlsEntry { data: Default::default(), dtor }).unwrap_none();
trace!("New TLS key allocated: {} with dtor {:?}", new_key, dtor);
if max_size.bits() < 128 && new_key >= (1u128 << max_size.bits() as u128) {
@ -74,22 +76,34 @@ impl<'tcx> TlsData<'tcx> {
pub fn load_tls(
&self,
key: TlsKey,
thread_id: ThreadId,
cx: &impl HasDataLayout,
) -> InterpResult<'tcx, Scalar<Tag>> {
match self.keys.get(&key) {
Some(&TlsEntry { data, .. }) => {
trace!("TLS key {} loaded: {:?}", key, data);
Ok(data.unwrap_or_else(|| Scalar::null_ptr(cx).into()))
Some(TlsEntry { data, .. }) => {
let value = data.get(&thread_id).cloned();
trace!("TLS key {} for thread {:?} loaded: {:?}", key, thread_id, value);
Ok(value.unwrap_or_else(|| Scalar::null_ptr(cx).into()))
}
None => throw_ub_format!("loading from a non-existing TLS key: {}", key),
}
}
pub fn store_tls(&mut self, key: TlsKey, new_data: Option<Scalar<Tag>>) -> InterpResult<'tcx> {
pub fn store_tls(
&mut self,
key: TlsKey, thread_id: ThreadId, new_data: Option<Scalar<Tag>>) -> InterpResult<'tcx> {
match self.keys.get_mut(&key) {
Some(TlsEntry { data, .. }) => {
trace!("TLS key {} stored: {:?}", key, new_data);
*data = new_data;
match new_data {
Some(ptr) => {
trace!("TLS key {} for thread {:?} stored: {:?}", key, thread_id, ptr);
data.insert(thread_id, ptr);
}
None => {
trace!("TLS key {} for thread {:?} removed", key, thread_id);
data.remove(&thread_id);
}
}
Ok(())
}
None => throw_ub_format!("storing to a non-existing TLS key: {}", key),
@ -131,7 +145,8 @@ impl<'tcx> TlsData<'tcx> {
fn fetch_tls_dtor(
&mut self,
key: Option<TlsKey>,
) -> Option<(ty::Instance<'tcx>, Scalar<Tag>, TlsKey)> {
thread_id: ThreadId,
) -> Option<(ty::Instance<'tcx>, ThreadId, Scalar<Tag>, TlsKey)> {
use std::collections::Bound::*;
let thread_local = &mut self.keys;
@ -142,12 +157,15 @@ impl<'tcx> TlsData<'tcx> {
for (&key, TlsEntry { data, dtor }) in
thread_local.range_mut((start, Unbounded))
{
if let Some(data_scalar) = *data {
if let Some(dtor) = dtor {
let ret = Some((*dtor, data_scalar, key));
*data = None;
return ret;
match data.entry(thread_id) {
Entry::Occupied(entry) => {
let (thread_id, data_scalar) = entry.remove_entry();
if let Some(dtor) = dtor {
let ret = Some((dtor, thread_id, data_scalar, key));
return ret;
}
}
Entry::Vacant(_) => {}
}
}
None
@ -156,6 +174,7 @@ impl<'tcx> TlsData<'tcx> {
impl<'mir, 'tcx: 'mir> EvalContextExt<'mir, 'tcx> for crate::MiriEvalContext<'mir, 'tcx> {}
pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriEvalContextExt<'mir, 'tcx> {
/// Run TLS destructors for the currently active thread.
fn run_tls_dtors(&mut self) -> InterpResult<'tcx> {
let this = self.eval_context_mut();
assert!(!this.machine.tls.dtors_running, "running TLS dtors twice");
@ -204,28 +223,31 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriEvalContextExt<'mir, 'tcx
}
// Now run the "keyed" destructors.
let mut dtor = this.machine.tls.fetch_tls_dtor(None);
while let Some((instance, ptr, key)) = dtor {
trace!("Running TLS dtor {:?} on {:?}", instance, ptr);
assert!(!this.is_null(ptr).unwrap(), "data can't be NULL when dtor is called!");
for thread_id in this.get_all_thread_ids() {
this.set_active_thread(thread_id)?;
let mut dtor = this.machine.tls.fetch_tls_dtor(None, thread_id);
while let Some((instance, thread_id, ptr, key)) = dtor {
trace!("Running TLS dtor {:?} on {:?} at {:?}", instance, ptr, thread_id);
assert!(!this.is_null(ptr).unwrap(), "Data can't be NULL when dtor is called!");
let ret_place = MPlaceTy::dangling(this.machine.layouts.unit, this).into();
this.call_function(
instance,
&[ptr.into()],
Some(ret_place),
StackPopCleanup::None { cleanup: true },
)?;
let ret_place = MPlaceTy::dangling(this.layout_of(this.tcx.mk_unit())?, this).into();
this.call_function(
instance,
&[ptr.into()],
Some(ret_place),
StackPopCleanup::None { cleanup: true },
)?;
// step until out of stackframes
this.run()?;
// step until out of stackframes
this.run()?;
// Fetch next dtor after `key`.
dtor = match this.machine.tls.fetch_tls_dtor(Some(key)) {
dtor @ Some(_) => dtor,
// We ran each dtor once, start over from the beginning.
None => this.machine.tls.fetch_tls_dtor(None),
};
// Fetch next dtor after `key`.
dtor = match this.machine.tls.fetch_tls_dtor(Some(key), thread_id) {
dtor @ Some(_) => dtor,
// We ran each dtor once, start over from the beginning.
None => this.machine.tls.fetch_tls_dtor(None, thread_id),
};
}
}
Ok(())
}

303
src/threads.rs Normal file
View File

@ -0,0 +1,303 @@
//! Implements threads.
use std::cell::RefCell;
use std::collections::hash_map::Entry;
use log::trace;
use rustc_middle::ty;
use rustc_data_structures::fx::FxHashMap;
use rustc_index::vec::{Idx, IndexVec};
use crate::*;
/// A thread identifier.
#[derive(Clone, Copy, Debug, PartialOrd, Ord, PartialEq, Eq, Hash)]
pub struct ThreadId(usize);
impl Idx for ThreadId {
fn new(idx: usize) -> Self {
ThreadId(idx)
}
fn index(self) -> usize {
self.0
}
}
impl From<u64> for ThreadId {
fn from(id: u64) -> Self {
Self(id as usize)
}
}
/// The state of a thread.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum ThreadState {
/// The thread is enabled and can be executed.
Enabled,
/// The thread tried to join the specified thread and is blocked until that
/// thread terminates.
Blocked(ThreadId),
/// The thread has terminated its execution (we do not delete terminated
/// threads.)
Terminated,
}
/// A thread.
pub struct Thread<'mir, 'tcx> {
state: ThreadState,
/// The virtual call stack.
stack: Vec<Frame<'mir, 'tcx, Tag, FrameData<'tcx>>>,
/// Is the thread detached?
///
/// A thread is detached if its join handle was destroyed and no other
/// thread can join it.
detached: bool,
}
impl<'mir, 'tcx> Thread<'mir, 'tcx> {
/// Check if the thread terminated. If yes, change the state to terminated
/// and return `true`.
fn check_terminated(&mut self) -> bool {
if self.state == ThreadState::Enabled {
if self.stack.is_empty() {
self.state = ThreadState::Terminated;
return true;
}
}
false
}
}
impl<'mir, 'tcx> std::fmt::Debug for Thread<'mir, 'tcx> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{:?}", self.state)
}
}
impl<'mir, 'tcx> Default for Thread<'mir, 'tcx> {
fn default() -> Self {
Self { state: ThreadState::Enabled, stack: Vec::new(), detached: false }
}
}
/// A set of threads.
#[derive(Debug)]
pub struct ThreadSet<'mir, 'tcx> {
/// Identifier of the currently active thread.
active_thread: ThreadId,
/// Threads used in the program.
///
/// Note that this vector also contains terminated threads.
threads: IndexVec<ThreadId, Thread<'mir, 'tcx>>,
/// List of threads that just terminated. TODO: Cleanup.
terminated_threads: Vec<ThreadId>,
}
impl<'mir, 'tcx> Default for ThreadSet<'mir, 'tcx> {
fn default() -> Self {
let mut threads = IndexVec::new();
threads.push(Default::default());
Self {
active_thread: ThreadId::new(0),
threads: threads,
terminated_threads: Default::default(),
}
}
}
impl<'mir, 'tcx: 'mir> ThreadSet<'mir, 'tcx> {
/// Borrow the stack of the active thread.
fn active_thread_stack(&self) -> &[Frame<'mir, 'tcx, Tag, FrameData<'tcx>>] {
&self.threads[self.active_thread].stack
}
/// Mutably borrow the stack of the active thread.
fn active_thread_stack_mut(&mut self) -> &mut Vec<Frame<'mir, 'tcx, Tag, FrameData<'tcx>>> {
&mut self.threads[self.active_thread].stack
}
/// Create a new thread and returns its id.
fn create_thread(&mut self) -> ThreadId {
let new_thread_id = ThreadId::new(self.threads.len());
self.threads.push(Default::default());
new_thread_id
}
/// Set an active thread and return the id of the thread that was active before.
fn set_active_thread(&mut self, id: ThreadId) -> ThreadId {
let active_thread_id = self.active_thread;
self.active_thread = id;
assert!(self.active_thread.index() < self.threads.len());
active_thread_id
}
/// Get the id of the currently active thread.
fn get_active_thread(&self) -> ThreadId {
self.active_thread
}
/// Mark the thread as detached, which means that no other thread will try
/// to join it and the thread is responsible for cleaning up.
fn detach_thread(&mut self, id: ThreadId) {
self.threads[id].detached = true;
}
/// Mark that the active thread tries to join the thread with `joined_thread_id`.
fn join_thread(&mut self, joined_thread_id: ThreadId) {
assert!(!self.threads[joined_thread_id].detached, "Bug: trying to join a detached thread.");
assert_ne!(joined_thread_id, self.active_thread, "Bug: trying to join itself");
assert!(
self.threads
.iter()
.all(|thread| thread.state != ThreadState::Blocked(joined_thread_id)),
"Bug: multiple threads try to join the same thread."
);
if self.threads[joined_thread_id].state != ThreadState::Terminated {
// The joined thread is still running, we need to wait for it.
self.threads[self.active_thread].state = ThreadState::Blocked(joined_thread_id);
trace!(
"{:?} blocked on {:?} when trying to join",
self.active_thread,
joined_thread_id
);
}
}
/// Get ids of all threads ever allocated.
fn get_all_thread_ids(&mut self) -> Vec<ThreadId> {
(0..self.threads.len()).map(ThreadId::new).collect()
}
/// Decide which thread to run next.
///
/// Returns `false` if all threads terminated.
fn schedule(&mut self) -> InterpResult<'tcx, bool> {
if self.threads[self.active_thread].check_terminated() {
// Check if we need to unblock any threads.
for (i, thread) in self.threads.iter_enumerated_mut() {
if thread.state == ThreadState::Blocked(self.active_thread) {
trace!("unblocking {:?} because {:?} terminated", i, self.active_thread);
thread.state = ThreadState::Enabled;
}
}
}
if self.threads[self.active_thread].state == ThreadState::Enabled {
return Ok(true);
}
if let Some(enabled_thread) =
self.threads.iter().position(|thread| thread.state == ThreadState::Enabled)
{
self.active_thread = ThreadId::new(enabled_thread);
return Ok(true);
}
if self.threads.iter().all(|thread| thread.state == ThreadState::Terminated) {
Ok(false)
} else {
throw_machine_stop!(TerminationInfo::Abort(Some(format!("execution deadlocked"))))
}
}
}
/// In Rust, a thread local variable is just a specially marked static. To
/// ensure a property that each memory allocation has a globally unique
/// allocation identifier, we create a fresh allocation id for each thread. This
/// data structure keeps the track of the created allocation identifiers and
/// their relation to the original static allocations.
#[derive(Clone, Debug, Default)]
pub struct ThreadLocalStorage {
/// A map from a thread local allocation identifier to the static from which
/// it was created.
thread_local_origin: RefCell<FxHashMap<AllocId, AllocId>>,
/// A map from a thread local static and thread id to the unique thread
/// local allocation.
thread_local_allocations: RefCell<FxHashMap<(AllocId, ThreadId), AllocId>>,
/// The currently active thread.
active_thread: Option<ThreadId>,
}
impl ThreadLocalStorage {
/// For static allocation identifier `original_id` get a thread local
/// allocation identifier. If it is not allocated yet, allocate.
pub fn get_or_register_allocation(&self, tcx: ty::TyCtxt<'_>, original_id: AllocId) -> AllocId {
match self
.thread_local_allocations
.borrow_mut()
.entry((original_id, self.active_thread.unwrap()))
{
Entry::Occupied(entry) => *entry.get(),
Entry::Vacant(entry) => {
let fresh_id = tcx.alloc_map.lock().reserve();
entry.insert(fresh_id);
self.thread_local_origin.borrow_mut().insert(fresh_id, original_id);
trace!(
"get_or_register_allocation(original_id={:?}) -> {:?}",
original_id,
fresh_id
);
fresh_id
}
}
}
/// For thread local allocation identifier `alloc_id`, retrieve the original
/// static allocation identifier from which it was created.
pub fn resolve_allocation(&self, alloc_id: AllocId) -> AllocId {
trace!("resolve_allocation(alloc_id: {:?})", alloc_id);
if let Some(original_id) = self.thread_local_origin.borrow().get(&alloc_id) {
trace!("resolve_allocation(alloc_id: {:?}) -> {:?}", alloc_id, original_id);
*original_id
} else {
alloc_id
}
}
/// Set which thread is currently active.
fn set_active_thread(&mut self, active_thread: ThreadId) {
self.active_thread = Some(active_thread);
}
}
impl<'mir, 'tcx: 'mir> EvalContextExt<'mir, 'tcx> for crate::MiriEvalContext<'mir, 'tcx> {}
pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriEvalContextExt<'mir, 'tcx> {
fn create_thread(&mut self) -> InterpResult<'tcx, ThreadId> {
let this = self.eval_context_mut();
Ok(this.machine.threads.create_thread())
}
fn detach_thread(&mut self, thread_id: ThreadId) -> InterpResult<'tcx> {
let this = self.eval_context_mut();
this.machine.threads.detach_thread(thread_id);
Ok(())
}
fn join_thread(&mut self, joined_thread_id: ThreadId) -> InterpResult<'tcx> {
let this = self.eval_context_mut();
this.machine.threads.join_thread(joined_thread_id);
Ok(())
}
fn set_active_thread(&mut self, thread_id: ThreadId) -> InterpResult<'tcx, ThreadId> {
let this = self.eval_context_mut();
this.memory.extra.tls.set_active_thread(thread_id);
Ok(this.machine.threads.set_active_thread(thread_id))
}
fn get_active_thread(&self) -> InterpResult<'tcx, ThreadId> {
let this = self.eval_context_ref();
Ok(this.machine.threads.get_active_thread())
}
fn active_thread_stack(&self) -> &[Frame<'mir, 'tcx, Tag, FrameData<'tcx>>] {
let this = self.eval_context_ref();
this.machine.threads.active_thread_stack()
}
fn active_thread_stack_mut(&mut self) -> &mut Vec<Frame<'mir, 'tcx, Tag, FrameData<'tcx>>> {
let this = self.eval_context_mut();
this.machine.threads.active_thread_stack_mut()
}
fn get_all_thread_ids(&mut self) -> Vec<ThreadId> {
let this = self.eval_context_mut();
this.machine.threads.get_all_thread_ids()
}
/// Decide which thread to run next.
///
/// Returns `false` if all threads terminated.
fn schedule(&mut self) -> InterpResult<'tcx, bool> {
let this = self.eval_context_mut();
// Find the next thread to run.
if this.machine.threads.schedule()? {
let active_thread = this.machine.threads.get_active_thread();
this.memory.extra.tls.set_active_thread(active_thread);
Ok(true)
} else {
Ok(false)
}
}
}

View File

@ -1,7 +0,0 @@
use std::thread;
// error-pattern: Miri does not support threading
fn main() {
thread::spawn(|| {});
}

View File

@ -0,0 +1,59 @@
use std::thread;
fn create_and_detach() {
thread::spawn(|| ());
}
fn create_and_join() {
thread::spawn(|| ()).join().unwrap();
}
fn create_and_get_result() {
let nine = thread::spawn(|| 5 + 4).join().unwrap();
assert_eq!(nine, 9);
}
fn create_and_leak_result() {
thread::spawn(|| 7);
}
fn create_nested_and_detach() {
thread::spawn(|| {
thread::spawn(|| ());
});
}
fn create_nested_and_join() {
let handle = thread::spawn(|| thread::spawn(|| ()));
let handle_nested = handle.join().unwrap();
handle_nested.join().unwrap();
}
fn create_move_in() {
let x = String::from("Hello!");
thread::spawn(move || {
assert_eq!(x.len(), 6);
})
.join()
.unwrap();
}
fn create_move_out() {
let result = thread::spawn(|| {
String::from("Hello!")
})
.join()
.unwrap();
assert_eq!(result.len(), 6);
}
fn main() {
create_and_detach();
create_and_join();
create_and_get_result();
create_and_leak_result();
create_nested_and_detach();
create_nested_and_join();
create_move_in();
create_move_out();
}