fix weak memory bug in TLS on Windows

This commit is contained in:
Ralf Jung 2024-04-23 08:22:29 +02:00
parent c8d19a92aa
commit d5d714bb34

View File

@ -141,9 +141,15 @@ unsafe fn init(&'static self) -> Key {
panic!("out of TLS indexes"); panic!("out of TLS indexes");
} }
self.key.store(key + 1, Release);
register_dtor(self); register_dtor(self);
// Release-storing the key needs to be the last thing we do.
// This is because in `fn key()`, other threads will do an acquire load of the key,
// and if that sees this write then it will entirely bypass the `InitOnce`. We thus
// need to establish synchronization through `key`. In particular that acquire load
// must happen-after the register_dtor above, to ensure the dtor actually runs!
self.key.store(key + 1, Release);
let r = c::InitOnceComplete(self.once.get(), 0, ptr::null_mut()); let r = c::InitOnceComplete(self.once.get(), 0, ptr::null_mut());
debug_assert_eq!(r, c::TRUE); debug_assert_eq!(r, c::TRUE);
@ -313,8 +319,22 @@ unsafe fn run_dtors() {
// Use acquire ordering to observe key initialization. // Use acquire ordering to observe key initialization.
let mut cur = DTORS.load(Acquire); let mut cur = DTORS.load(Acquire);
while !cur.is_null() { while !cur.is_null() {
let key = (*cur).key.load(Relaxed) - 1; let pre_key = (*cur).key.load(Acquire);
let dtor = (*cur).dtor.unwrap(); let dtor = (*cur).dtor.unwrap();
cur = (*cur).next.load(Relaxed);
// In StaticKey::init, we register the dtor before setting `key`.
// So if one thread's `run_dtors` races with another thread executing `init` on the same
// `StaticKey`, we can encounter a key of 0 here. That means this key was never
// initialized in this thread so we can safely skip it.
if pre_key == 0 {
continue;
}
// If this is non-zero, then via the `Acquire` load above we synchronized with
// everything relevant for this key. (It's not clear that this is needed, since the
// release-acquire pair on DTORS also establishes synchronization, but better safe than
// sorry.)
let key = pre_key - 1;
let ptr = c::TlsGetValue(key); let ptr = c::TlsGetValue(key);
if !ptr.is_null() { if !ptr.is_null() {
@ -322,8 +342,6 @@ unsafe fn run_dtors() {
dtor(ptr as *mut _); dtor(ptr as *mut _);
any_run = true; any_run = true;
} }
cur = (*cur).next.load(Relaxed);
} }
if !any_run { if !any_run {