diff --git a/src/tools/miri/src/concurrency/thread.rs b/src/tools/miri/src/concurrency/thread.rs index 3432f10f7a9..ac5dcbf0f4f 100644 --- a/src/tools/miri/src/concurrency/thread.rs +++ b/src/tools/miri/src/concurrency/thread.rs @@ -870,6 +870,7 @@ fn active_thread_stack_mut( this.machine.threads.active_thread_stack_mut() } + /// Set the name of the current thread. The buffer must not include the null terminator. #[inline] fn set_thread_name(&mut self, thread: ThreadId, new_thread_name: Vec) { let this = self.eval_context_mut(); diff --git a/src/tools/miri/src/helpers.rs b/src/tools/miri/src/helpers.rs index 4bc38d2dc36..f98727186c4 100644 --- a/src/tools/miri/src/helpers.rs +++ b/src/tools/miri/src/helpers.rs @@ -1,6 +1,7 @@ pub mod convert; use std::cmp; +use std::iter; use std::mem; use std::num::NonZeroUsize; use std::time::Duration; @@ -735,6 +736,7 @@ fn read_timespec( }) } + /// Read a sequence of bytes until the first null terminator. fn read_c_str<'a>(&'a self, ptr: Pointer>) -> InterpResult<'tcx, &'a [u8]> where 'tcx: 'a, @@ -761,6 +763,30 @@ fn read_c_str<'a>(&'a self, ptr: Pointer>) -> InterpResult<'t this.read_bytes_ptr_strip_provenance(ptr, len) } + /// Helper function to write a sequence of bytes with an added null-terminator, which is what + /// the Unix APIs usually handle. This function returns `Ok((false, length))` without trying + /// to write if `size` is not large enough to fit the contents of `c_str` plus a null + /// terminator. It returns `Ok((true, length))` if the writing process was successful. The + /// string length returned does include the null terminator. + fn write_c_str( + &mut self, + c_str: &[u8], + ptr: Pointer>, + size: u64, + ) -> InterpResult<'tcx, (bool, u64)> { + // If `size` is smaller or equal than `bytes.len()`, writing `bytes` plus the required null + // terminator to memory using the `ptr` pointer would cause an out-of-bounds access. + let string_length = u64::try_from(c_str.len()).unwrap(); + let string_length = string_length.checked_add(1).unwrap(); + if size < string_length { + return Ok((false, string_length)); + } + self.eval_context_mut() + .write_bytes_ptr(ptr, c_str.iter().copied().chain(iter::once(0u8)))?; + Ok((true, string_length)) + } + + /// Read a sequence of u16 until the first null terminator. fn read_wide_str(&self, mut ptr: Pointer>) -> InterpResult<'tcx, Vec> { let this = self.eval_context_ref(); let size2 = Size::from_bytes(2); @@ -783,6 +809,39 @@ fn read_wide_str(&self, mut ptr: Pointer>) -> InterpResult<'t Ok(wchars) } + /// Helper function to write a sequence of u16 with an added 0x0000-terminator, which is what + /// the Windows APIs usually handle. This function returns `Ok((false, length))` without trying + /// to write if `size` is not large enough to fit the contents of `os_string` plus a null + /// terminator. It returns `Ok((true, length))` if the writing process was successful. The + /// string length returned does include the null terminator. Length is measured in units of + /// `u16.` + fn write_wide_str( + &mut self, + wide_str: &[u16], + ptr: Pointer>, + size: u64, + ) -> InterpResult<'tcx, (bool, u64)> { + // If `size` is smaller or equal than `bytes.len()`, writing `bytes` plus the required + // 0x0000 terminator to memory would cause an out-of-bounds access. + let string_length = u64::try_from(wide_str.len()).unwrap(); + let string_length = string_length.checked_add(1).unwrap(); + if size < string_length { + return Ok((false, string_length)); + } + + // Store the UTF-16 string. + let size2 = Size::from_bytes(2); + let this = self.eval_context_mut(); + let mut alloc = this + .get_ptr_alloc_mut(ptr, size2 * string_length, Align::from_bytes(2).unwrap())? + .unwrap(); // not a ZST, so we will get a result + for (offset, wchar) in wide_str.iter().copied().chain(iter::once(0x0000)).enumerate() { + let offset = u64::try_from(offset).unwrap(); + alloc.write_scalar(alloc_range(size2 * offset, size2), Scalar::from_u16(wchar))?; + } + Ok((true, string_length)) + } + /// Check that the ABI is what we expect. fn check_abi<'a>(&self, abi: Abi, exp_abi: Abi) -> InterpResult<'a, ()> { if self.eval_context_ref().machine.enforce_abi && abi != exp_abi { diff --git a/src/tools/miri/src/shims/os_str.rs b/src/tools/miri/src/shims/os_str.rs index 407dab970ad..99b3605c601 100644 --- a/src/tools/miri/src/shims/os_str.rs +++ b/src/tools/miri/src/shims/os_str.rs @@ -1,6 +1,5 @@ use std::borrow::Cow; use std::ffi::{OsStr, OsString}; -use std::iter; use std::path::{Path, PathBuf}; #[cfg(unix)] @@ -9,7 +8,6 @@ use std::os::windows::ffi::{OsStrExt, OsStringExt}; use rustc_middle::ty::layout::LayoutOf; -use rustc_target::abi::{Align, Size}; use crate::*; @@ -100,16 +98,7 @@ fn write_os_str_to_c_str( size: u64, ) -> InterpResult<'tcx, (bool, u64)> { let bytes = os_str_to_bytes(os_str)?; - // If `size` is smaller or equal than `bytes.len()`, writing `bytes` plus the required null - // terminator to memory using the `ptr` pointer would cause an out-of-bounds access. - let string_length = u64::try_from(bytes.len()).unwrap(); - let string_length = string_length.checked_add(1).unwrap(); - if size < string_length { - return Ok((false, string_length)); - } - self.eval_context_mut() - .write_bytes_ptr(ptr, bytes.iter().copied().chain(iter::once(0u8)))?; - Ok((true, string_length)) + self.eval_context_mut().write_c_str(bytes, ptr, size) } /// Helper function to write an OsStr as a 0x0000-terminated u16-sequence, which is what @@ -140,25 +129,7 @@ fn os_str_to_u16vec<'tcx>(os_str: &OsStr) -> InterpResult<'tcx, Vec> { } let u16_vec = os_str_to_u16vec(os_str)?; - // If `size` is smaller or equal than `bytes.len()`, writing `bytes` plus the required - // 0x0000 terminator to memory would cause an out-of-bounds access. - let string_length = u64::try_from(u16_vec.len()).unwrap(); - let string_length = string_length.checked_add(1).unwrap(); - if size < string_length { - return Ok((false, string_length)); - } - - // Store the UTF-16 string. - let size2 = Size::from_bytes(2); - let this = self.eval_context_mut(); - let mut alloc = this - .get_ptr_alloc_mut(ptr, size2 * string_length, Align::from_bytes(2).unwrap())? - .unwrap(); // not a ZST, so we will get a result - for (offset, wchar) in u16_vec.into_iter().chain(iter::once(0x0000)).enumerate() { - let offset = u64::try_from(offset).unwrap(); - alloc.write_scalar(alloc_range(size2 * offset, size2), Scalar::from_u16(wchar))?; - } - Ok((true, string_length)) + self.eval_context_mut().write_wide_str(&u16_vec, ptr, size) } /// Allocate enough memory to store the given `OsStr` as a null-terminated sequence of bytes. diff --git a/src/tools/miri/src/shims/unix/linux/foreign_items.rs b/src/tools/miri/src/shims/unix/linux/foreign_items.rs index dd382fff029..c004e2292a9 100644 --- a/src/tools/miri/src/shims/unix/linux/foreign_items.rs +++ b/src/tools/miri/src/shims/unix/linux/foreign_items.rs @@ -72,6 +72,16 @@ fn emulate_foreign_item_by_name( this.pthread_setname_np(this.read_scalar(thread)?, this.read_scalar(name)?)?; this.write_scalar(res, dest)?; } + "pthread_getname_np" => { + let [thread, name, len] = + this.check_shim(abi, Abi::C { unwind: false }, link_name, args)?; + let res = this.pthread_getname_np( + this.read_scalar(thread)?, + this.read_scalar(name)?, + this.read_scalar(len)?, + )?; + this.write_scalar(res, dest)?; + } // Dynamically invoked syscalls "syscall" => { diff --git a/src/tools/miri/src/shims/unix/macos/foreign_items.rs b/src/tools/miri/src/shims/unix/macos/foreign_items.rs index 38d791fba98..0e931023e6c 100644 --- a/src/tools/miri/src/shims/unix/macos/foreign_items.rs +++ b/src/tools/miri/src/shims/unix/macos/foreign_items.rs @@ -178,6 +178,16 @@ fn emulate_foreign_item_by_name( let thread = this.pthread_self()?; this.pthread_setname_np(thread, this.read_scalar(name)?)?; } + "pthread_getname_np" => { + let [thread, name, len] = + this.check_shim(abi, Abi::C { unwind: false }, link_name, args)?; + let res = this.pthread_getname_np( + this.read_scalar(thread)?, + this.read_scalar(name)?, + this.read_scalar(len)?, + )?; + this.write_scalar(res, dest)?; + } // Incomplete shims that we "stub out" just to get pre-main initialization code to work. // These shims are enabled only when the caller is in the standard library. diff --git a/src/tools/miri/src/shims/unix/thread.rs b/src/tools/miri/src/shims/unix/thread.rs index 59474d8d10a..4320ecd389e 100644 --- a/src/tools/miri/src/shims/unix/thread.rs +++ b/src/tools/miri/src/shims/unix/thread.rs @@ -78,11 +78,35 @@ fn pthread_setname_np( let name = name.to_pointer(this)?; let name = this.read_c_str(name)?.to_owned(); + + if name.len() > 15 { + // Thread names are limited to 16 characaters, including the null terminator. + return this.eval_libc("ERANGE"); + } + this.set_thread_name(thread, name); Ok(Scalar::from_u32(0)) } + fn pthread_getname_np( + &mut self, + thread: Scalar, + name_out: Scalar, + len: Scalar, + ) -> InterpResult<'tcx, Scalar> { + let this = self.eval_context_mut(); + + let thread = ThreadId::try_from(thread.to_machine_usize(this)?).unwrap(); + let name_out = name_out.to_pointer(this)?; + let len = len.to_machine_usize(this)?; + + let name = this.get_thread_name(thread).to_owned(); + let (success, _written) = this.write_c_str(&name, name_out, len)?; + + if success { Ok(Scalar::from_u32(0)) } else { this.eval_libc("ERANGE") } + } + fn sched_yield(&mut self) -> InterpResult<'tcx, i32> { let this = self.eval_context_mut(); diff --git a/src/tools/miri/tests/pass-dep/shims/pthreads.rs b/src/tools/miri/tests/pass-dep/shims/pthreads.rs index d062eda7e90..bbddca74754 100644 --- a/src/tools/miri/tests/pass-dep/shims/pthreads.rs +++ b/src/tools/miri/tests/pass-dep/shims/pthreads.rs @@ -1,10 +1,14 @@ //@ignore-target-windows: No libc on Windows +#![feature(cstr_from_bytes_until_nul)] +use std::ffi::CStr; +use std::thread; fn main() { test_mutex_libc_init_recursive(); test_mutex_libc_init_normal(); test_mutex_libc_init_errorcheck(); test_rwlock_libc_static_initializer(); + test_named_thread_truncation(); #[cfg(any(target_os = "linux"))] test_mutex_libc_static_initializer_recursive(); @@ -125,3 +129,24 @@ fn test_rwlock_libc_static_initializer() { assert_eq!(libc::pthread_rwlock_destroy(rw.get()), 0); } } + +fn test_named_thread_truncation() { + let long_name = std::iter::once("test_named_thread_truncation") + .chain(std::iter::repeat(" yada").take(100)) + .collect::(); + + let result = thread::Builder::new().name(long_name.clone()).spawn(move || { + // Rust remembers the full thread name itself. + assert_eq!(thread::current().name(), Some(long_name.as_str())); + + // But the system is limited -- make sure we successfully set a truncation. + let mut buf = vec![0u8; long_name.len() + 1]; + unsafe { + libc::pthread_getname_np(libc::pthread_self(), buf.as_mut_ptr().cast(), buf.len()); + } + let cstr = CStr::from_bytes_until_nul(&buf).unwrap(); + assert!(cstr.to_bytes().len() >= 15); // POSIX seems to promise at least 15 chars + assert!(long_name.as_bytes().starts_with(cstr.to_bytes())); + }); + result.unwrap().join().unwrap(); +}