use early return for race_detecting() logic

This commit is contained in:
Ralf Jung 2024-09-15 13:25:26 +02:00
parent a888905226
commit 339f68bd6c

View File

@@ -1048,7 +1048,9 @@ pub fn read<'tcx>(
     ) -> InterpResult<'tcx> {
         let current_span = machine.current_span();
         let global = machine.data_race.as_ref().unwrap();
-        if global.race_detecting() {
-            let (index, mut thread_clocks) = global.active_thread_state_mut(&machine.threads);
-            let mut alloc_ranges = self.alloc_ranges.borrow_mut();
-            for (mem_clocks_range, mem_clocks) in
+        if !global.race_detecting() {
+            return Ok(());
+        }
+        let (index, mut thread_clocks) = global.active_thread_state_mut(&machine.threads);
+        let mut alloc_ranges = self.alloc_ranges.borrow_mut();
+        for (mem_clocks_range, mem_clocks) in
@@ -1071,9 +1073,6 @@ pub fn read<'tcx>(
-                }
-            }
-            Ok(())
-        } else {
-            Ok(())
-        }
+            }
+        }
+        Ok(())
     }

     /// Detect data-races for an unsynchronized write operation. It will not perform
@@ -1091,17 +1090,16 @@ pub fn write<'tcx>(
     ) -> InterpResult<'tcx> {
         let current_span = machine.current_span();
         let global = machine.data_race.as_mut().unwrap();
-        if global.race_detecting() {
-            let (index, mut thread_clocks) = global.active_thread_state_mut(&machine.threads);
-            for (mem_clocks_range, mem_clocks) in
-                self.alloc_ranges.get_mut().iter_mut(access_range.start, access_range.size)
-            {
-                if let Err(DataRace) = mem_clocks.write_race_detect(
-                    &mut thread_clocks,
-                    index,
-                    write_type,
-                    current_span,
-                ) {
-                    drop(thread_clocks);
-                    // Report data-race
-                    return Self::report_data_race(
+        if !global.race_detecting() {
+            return Ok(());
+        }
+        let (index, mut thread_clocks) = global.active_thread_state_mut(&machine.threads);
+        for (mem_clocks_range, mem_clocks) in
+            self.alloc_ranges.get_mut().iter_mut(access_range.start, access_range.size)
+        {
+            if let Err(DataRace) =
+                mem_clocks.write_race_detect(&mut thread_clocks, index, write_type, current_span)
+            {
+                drop(thread_clocks);
+                // Report data-race
+                return Self::report_data_race(
@@ -1116,9 +1114,6 @@ pub fn write<'tcx>(
-                }
-            }
-            Ok(())
-        } else {
-            Ok(())
-        }
+            }
+        }
+        Ok(())
     }
 }
@@ -1149,7 +1144,9 @@ impl FrameState {
     pub fn local_write(&self, local: mir::Local, storage_live: bool, machine: &MiriMachine<'_>) {
         let current_span = machine.current_span();
         let global = machine.data_race.as_ref().unwrap();
-        if global.race_detecting() {
-            let (index, mut thread_clocks) = global.active_thread_state_mut(&machine.threads);
-            // This should do the same things as `MemoryCellClocks::write_race_detect`.
-            if !current_span.is_dummy() {
+        if !global.race_detecting() {
+            return;
+        }
+        let (index, mut thread_clocks) = global.active_thread_state_mut(&machine.threads);
+        // This should do the same things as `MemoryCellClocks::write_race_detect`.
+        if !current_span.is_dummy() {
@@ -1173,12 +1170,13 @@ pub fn local_write(&self, local: mir::Local, storage_live: bool, machine: &MiriMachine<'_>) {
-                clocks.write_type = NaWriteType::Write;
-            }
-        }
-    }
+            clocks.write_type = NaWriteType::Write;
+        }
+    }

     pub fn local_read(&self, local: mir::Local, machine: &MiriMachine<'_>) {
         let current_span = machine.current_span();
         let global = machine.data_race.as_ref().unwrap();
-        if global.race_detecting() {
-            let (index, mut thread_clocks) = global.active_thread_state_mut(&machine.threads);
-            // This should do the same things as `MemoryCellClocks::read_race_detect`.
-            if !current_span.is_dummy() {
+        if !global.race_detecting() {
+            return;
+        }
+        let (index, mut thread_clocks) = global.active_thread_state_mut(&machine.threads);
+        // This should do the same things as `MemoryCellClocks::read_race_detect`.
+        if !current_span.is_dummy() {
@@ -1191,7 +1189,6 @@ pub fn local_read(&self, local: mir::Local, machine: &MiriMachine<'_>) {
-            let clocks = clocks.entry(local).or_insert_with(Default::default);
-            clocks.read = thread_clocks.clock[index];
-        }
-    }
+        let clocks = clocks.entry(local).or_insert_with(Default::default);
+        clocks.read = thread_clocks.clock[index];
+    }

     pub fn local_moved_to_memory(
         &self,
@@ -1200,7 +1197,9 @@ pub fn local_moved_to_memory(
         machine: &MiriMachine<'_>,
     ) {
         let global = machine.data_race.as_ref().unwrap();
-        if global.race_detecting() {
-            let (index, _thread_clocks) = global.active_thread_state_mut(&machine.threads);
-            // Get the time the last write actually happened. This can fail to exist if
-            // `race_detecting` was false when the write occurred, in that case we can backdate this
+        if !global.race_detecting() {
+            return;
+        }
+        let (index, _thread_clocks) = global.active_thread_state_mut(&machine.threads);
+        // Get the time the last write actually happened. This can fail to exist if
+        // `race_detecting` was false when the write occurred, in that case we can backdate this
@@ -1217,7 +1216,6 @@ pub fn local_moved_to_memory(
-                }
-            }
-        }
-    }
+            }
+        }
+    }
 }

 impl<'tcx> EvalContextPrivExt<'tcx> for MiriInterpCx<'tcx> {}
 trait EvalContextPrivExt<'tcx>: MiriInterpCxExt<'tcx> {
@@ -1403,8 +1401,10 @@ fn validate_atomic_op<A: Debug + Copy>(
     ) -> InterpResult<'tcx> {
         let this = self.eval_context_ref();
         assert!(access.is_atomic());
-        if let Some(data_race) = &this.machine.data_race {
-            if data_race.race_detecting() {
-                let size = place.layout.size;
-                let (alloc_id, base_offset, _prov) = this.ptr_get_alloc_id(place.ptr(), 0)?;
-                // Load and log the atomic operation.
+        let Some(data_race) = &this.machine.data_race else { return Ok(()) };
+        if !data_race.race_detecting() {
+            return Ok(());
+        }
+        let size = place.layout.size;
+        let (alloc_id, base_offset, _prov) = this.ptr_get_alloc_id(place.ptr(), 0)?;
+        // Load and log the atomic operation.
@@ -1427,8 +1427,7 @@ fn validate_atomic_op<A: Debug + Copy>(
-                for (mem_clocks_range, mem_clocks) in
-                    alloc_meta.alloc_ranges.borrow_mut().iter_mut(base_offset, size)
-                {
-                    if let Err(DataRace) = op(mem_clocks, &mut thread_clocks, index, atomic)
-                    {
-                        mem::drop(thread_clocks);
-                        return VClockAlloc::report_data_race(
-                            data_race,
+        for (mem_clocks_range, mem_clocks) in
+            alloc_meta.alloc_ranges.borrow_mut().iter_mut(base_offset, size)
+        {
+            if let Err(DataRace) = op(mem_clocks, &mut thread_clocks, index, atomic) {
+                mem::drop(thread_clocks);
+                return VClockAlloc::report_data_race(
+                    data_race,
@@ -1453,9 +1452,7 @@ fn validate_atomic_op<A: Debug + Copy>(
-                // Log changes to atomic memory.
-                if tracing::enabled!(tracing::Level::TRACE) {
-                    for (_offset, mem_clocks) in
-                        alloc_meta.alloc_ranges.borrow().iter(base_offset, size)
-                    {
-                        trace!(
-                            "Updated atomic memory({:?}, size={}) to {:#?}",
-                            place.ptr(),
+        // Log changes to atomic memory.
+        if tracing::enabled!(tracing::Level::TRACE) {
+            for (_offset, mem_clocks) in alloc_meta.alloc_ranges.borrow().iter(base_offset, size) {
+                trace!(
+                    "Updated atomic memory({:?}, size={}) to {:#?}",
+                    place.ptr(),
@@ -1464,8 +1461,7 @@ fn validate_atomic_op<A: Debug + Copy>(
-                        );
-                    }
-                }
-            }
-        }
+                );
+            }
+        }
         Ok(())
     }
 }