Add add/sub methods that only panic with debug assertions to rustc

This mitigates the perf impact of enabling overflow checks on rustc.
The change to use overflow checks will be done in a later PR.
This commit is contained in:
Nilstrieb 2024-03-28 20:27:12 +01:00
parent c3b05c6e5b
commit 5039160c5b
6 changed files with 109 additions and 27 deletions

View File

@ -1,5 +1,8 @@
//! This is a copy of `core::hash::sip` adapted to providing 128 bit hashes. //! This is a copy of `core::hash::sip` adapted to providing 128 bit hashes.
// This code is very hot and uses lots of arithmetic, avoid overflow checks for performance.
// See https://github.com/rust-lang/rust/pull/119440#issuecomment-1874255727
use rustc_serialize::int_overflow::{DebugStrictAdd, DebugStrictSub};
use std::hash::Hasher; use std::hash::Hasher;
use std::mem::{self, MaybeUninit}; use std::mem::{self, MaybeUninit};
use std::ptr; use std::ptr;
@ -103,19 +106,19 @@ unsafe fn copy_nonoverlapping_small(src: *const u8, dst: *mut u8, count: usize)
} }
let mut i = 0; let mut i = 0;
if i + 3 < count { if i.debug_strict_add(3) < count {
ptr::copy_nonoverlapping(src.add(i), dst.add(i), 4); ptr::copy_nonoverlapping(src.add(i), dst.add(i), 4);
i += 4; i = i.debug_strict_add(4);
} }
if i + 1 < count { if i.debug_strict_add(1) < count {
ptr::copy_nonoverlapping(src.add(i), dst.add(i), 2); ptr::copy_nonoverlapping(src.add(i), dst.add(i), 2);
i += 2 i = i.debug_strict_add(2)
} }
if i < count { if i < count {
*dst.add(i) = *src.add(i); *dst.add(i) = *src.add(i);
i += 1; i = i.debug_strict_add(1);
} }
debug_assert_eq!(i, count); debug_assert_eq!(i, count);
@ -211,14 +214,14 @@ pub fn new_with_keys(key0: u64, key1: u64) -> SipHasher128 {
debug_assert!(nbuf < BUFFER_SIZE); debug_assert!(nbuf < BUFFER_SIZE);
debug_assert!(nbuf + LEN < BUFFER_WITH_SPILL_SIZE); debug_assert!(nbuf + LEN < BUFFER_WITH_SPILL_SIZE);
if nbuf + LEN < BUFFER_SIZE { if nbuf.debug_strict_add(LEN) < BUFFER_SIZE {
unsafe { unsafe {
// The memcpy call is optimized away because the size is known. // The memcpy call is optimized away because the size is known.
let dst = (self.buf.as_mut_ptr() as *mut u8).add(nbuf); let dst = (self.buf.as_mut_ptr() as *mut u8).add(nbuf);
ptr::copy_nonoverlapping(bytes.as_ptr(), dst, LEN); ptr::copy_nonoverlapping(bytes.as_ptr(), dst, LEN);
} }
self.nbuf = nbuf + LEN; self.nbuf = nbuf.debug_strict_add(LEN);
return; return;
} }
@ -265,8 +268,9 @@ pub fn new_with_keys(key0: u64, key1: u64) -> SipHasher128 {
// This function should only be called when the write fills the buffer. // This function should only be called when the write fills the buffer.
// Therefore, when LEN == 1, the new `self.nbuf` must be zero. // Therefore, when LEN == 1, the new `self.nbuf` must be zero.
// LEN is statically known, so the branch is optimized away. // LEN is statically known, so the branch is optimized away.
self.nbuf = if LEN == 1 { 0 } else { nbuf + LEN - BUFFER_SIZE }; self.nbuf =
self.processed += BUFFER_SIZE; if LEN == 1 { 0 } else { nbuf.debug_strict_add(LEN).debug_strict_sub(BUFFER_SIZE) };
self.processed = self.processed.debug_strict_add(BUFFER_SIZE);
} }
} }
@ -277,7 +281,7 @@ fn slice_write(&mut self, msg: &[u8]) {
let nbuf = self.nbuf; let nbuf = self.nbuf;
debug_assert!(nbuf < BUFFER_SIZE); debug_assert!(nbuf < BUFFER_SIZE);
if nbuf + length < BUFFER_SIZE { if nbuf.debug_strict_add(length) < BUFFER_SIZE {
unsafe { unsafe {
let dst = (self.buf.as_mut_ptr() as *mut u8).add(nbuf); let dst = (self.buf.as_mut_ptr() as *mut u8).add(nbuf);
@ -289,7 +293,7 @@ fn slice_write(&mut self, msg: &[u8]) {
} }
} }
self.nbuf = nbuf + length; self.nbuf = nbuf.debug_strict_add(length);
return; return;
} }
@ -315,7 +319,7 @@ unsafe fn slice_write_process_buffer(&mut self, msg: &[u8]) {
// This function should only be called when the write fills the buffer, // This function should only be called when the write fills the buffer,
// so we know that there is enough input to fill the current element. // so we know that there is enough input to fill the current element.
let valid_in_elem = nbuf % ELEM_SIZE; let valid_in_elem = nbuf % ELEM_SIZE;
let needed_in_elem = ELEM_SIZE - valid_in_elem; let needed_in_elem = ELEM_SIZE.debug_strict_sub(valid_in_elem);
let src = msg.as_ptr(); let src = msg.as_ptr();
let dst = (self.buf.as_mut_ptr() as *mut u8).add(nbuf); let dst = (self.buf.as_mut_ptr() as *mut u8).add(nbuf);
@ -327,7 +331,7 @@ unsafe fn slice_write_process_buffer(&mut self, msg: &[u8]) {
// ELEM_SIZE` to show the compiler that this loop's upper bound is > 0. // ELEM_SIZE` to show the compiler that this loop's upper bound is > 0.
// We know that is true, because last step ensured we have a full // We know that is true, because last step ensured we have a full
// element in the buffer. // element in the buffer.
let last = nbuf / ELEM_SIZE + 1; let last = (nbuf / ELEM_SIZE).debug_strict_add(1);
for i in 0..last { for i in 0..last {
let elem = self.buf.get_unchecked(i).assume_init().to_le(); let elem = self.buf.get_unchecked(i).assume_init().to_le();
@ -338,7 +342,7 @@ unsafe fn slice_write_process_buffer(&mut self, msg: &[u8]) {
// Process the remaining element-sized chunks of input. // Process the remaining element-sized chunks of input.
let mut processed = needed_in_elem; let mut processed = needed_in_elem;
let input_left = length - processed; let input_left = length.debug_strict_sub(processed);
let elems_left = input_left / ELEM_SIZE; let elems_left = input_left / ELEM_SIZE;
let extra_bytes_left = input_left % ELEM_SIZE; let extra_bytes_left = input_left % ELEM_SIZE;
@ -347,7 +351,7 @@ unsafe fn slice_write_process_buffer(&mut self, msg: &[u8]) {
self.state.v3 ^= elem; self.state.v3 ^= elem;
Sip13Rounds::c_rounds(&mut self.state); Sip13Rounds::c_rounds(&mut self.state);
self.state.v0 ^= elem; self.state.v0 ^= elem;
processed += ELEM_SIZE; processed = processed.debug_strict_add(ELEM_SIZE);
} }
// Copy remaining input into start of buffer. // Copy remaining input into start of buffer.
@ -356,7 +360,7 @@ unsafe fn slice_write_process_buffer(&mut self, msg: &[u8]) {
copy_nonoverlapping_small(src, dst, extra_bytes_left); copy_nonoverlapping_small(src, dst, extra_bytes_left);
self.nbuf = extra_bytes_left; self.nbuf = extra_bytes_left;
self.processed += nbuf + processed; self.processed = self.processed.debug_strict_add(nbuf.debug_strict_add(processed));
} }
} }
@ -394,7 +398,7 @@ pub fn finish128(mut self) -> (u64, u64) {
}; };
// Finalize the hash. // Finalize the hash.
let length = self.processed + self.nbuf; let length = self.processed.debug_strict_add(self.nbuf);
let b: u64 = ((length as u64 & 0xff) << 56) | elem; let b: u64 = ((length as u64 & 0xff) << 56) | elem;
state.v3 ^= b; state.v3 ^= b;

View File

@ -0,0 +1,65 @@
// This would belong to `rustc_data_structures`, but `rustc_serialize` needs it too.
/// Addition, but only overflow checked when `cfg(debug_assertions)` is set
/// instead of respecting `-Coverflow-checks`.
///
/// This exists for performance reasons, as we ship rustc with overflow checks.
/// While overflow checks are perf neutral in almost all of the compiler, there
/// are a few particularly hot areas where we don't want overflow checks in our
/// dist builds. Overflow is still a bug there, so we want overflow check for
/// builds with debug assertions.
///
/// That's a long way to say that this should be used in areas where overflow
/// is a bug but overflow checking is too slow.
pub trait DebugStrictAdd {
/// See [`DebugStrictAdd`].
fn debug_strict_add(self, other: Self) -> Self;
}
macro_rules! impl_debug_strict_add {
($( $ty:ty )*) => {
$(
impl DebugStrictAdd for $ty {
fn debug_strict_add(self, other: Self) -> Self {
if cfg!(debug_assertions) {
self + other
} else {
self.wrapping_add(other)
}
}
}
)*
};
}
/// See [`DebugStrictAdd`].
pub trait DebugStrictSub {
/// See [`DebugStrictAdd`].
fn debug_strict_sub(self, other: Self) -> Self;
}
macro_rules! impl_debug_strict_sub {
($( $ty:ty )*) => {
$(
impl DebugStrictSub for $ty {
fn debug_strict_sub(self, other: Self) -> Self {
if cfg!(debug_assertions) {
self - other
} else {
self.wrapping_sub(other)
}
}
}
)*
};
}
impl_debug_strict_add! {
u8 u16 u32 u64 u128 usize
i8 i16 i32 i64 i128 isize
}
impl_debug_strict_sub! {
u8 u16 u32 u64 u128 usize
i8 i16 i32 i64 i128 isize
}

View File

@ -1,6 +1,10 @@
use crate::opaque::MemDecoder; use crate::opaque::MemDecoder;
use crate::serialize::Decoder; use crate::serialize::Decoder;
// This code is very hot and uses lots of arithmetic, avoid overflow checks for performance.
// See https://github.com/rust-lang/rust/pull/119440#issuecomment-1874255727
use crate::int_overflow::DebugStrictAdd;
/// Returns the length of the longest LEB128 encoding for `T`, assuming `T` is an integer type /// Returns the length of the longest LEB128 encoding for `T`, assuming `T` is an integer type
pub const fn max_leb128_len<T>() -> usize { pub const fn max_leb128_len<T>() -> usize {
// The longest LEB128 encoding for an integer uses 7 bits per byte. // The longest LEB128 encoding for an integer uses 7 bits per byte.
@ -24,7 +28,7 @@ macro_rules! impl_write_unsigned_leb128 {
*out.get_unchecked_mut(i) = value as u8; *out.get_unchecked_mut(i) = value as u8;
} }
i += 1; i = i.debug_strict_add(1);
break; break;
} else { } else {
unsafe { unsafe {
@ -32,7 +36,7 @@ macro_rules! impl_write_unsigned_leb128 {
} }
value >>= 7; value >>= 7;
i += 1; i = i.debug_strict_add(1);
} }
} }
@ -69,7 +73,7 @@ pub fn $fn_name(decoder: &mut MemDecoder<'_>) -> $int_ty {
} else { } else {
result |= ((byte & 0x7F) as $int_ty) << shift; result |= ((byte & 0x7F) as $int_ty) << shift;
} }
shift += 7; shift = shift.debug_strict_add(7);
} }
} }
}; };
@ -101,7 +105,7 @@ macro_rules! impl_write_signed_leb128 {
*out.get_unchecked_mut(i) = byte; *out.get_unchecked_mut(i) = byte;
} }
i += 1; i = i.debug_strict_add(1);
if !more { if !more {
break; break;
@ -130,7 +134,7 @@ pub fn $fn_name(decoder: &mut MemDecoder<'_>) -> $int_ty {
loop { loop {
byte = decoder.read_u8(); byte = decoder.read_u8();
result |= <$int_ty>::from(byte & 0x7F) << shift; result |= <$int_ty>::from(byte & 0x7F) << shift;
shift += 7; shift = shift.debug_strict_add(7);
if (byte & 0x80) == 0 { if (byte & 0x80) == 0 {
break; break;

View File

@ -23,5 +23,6 @@
mod serialize; mod serialize;
pub mod int_overflow;
pub mod leb128; pub mod leb128;
pub mod opaque; pub mod opaque;

View File

@ -7,6 +7,10 @@
use std::path::Path; use std::path::Path;
use std::path::PathBuf; use std::path::PathBuf;
// This code is very hot and uses lots of arithmetic, avoid overflow checks for performance.
// See https://github.com/rust-lang/rust/pull/119440#issuecomment-1874255727
use crate::int_overflow::DebugStrictAdd;
// ----------------------------------------------------------------------------- // -----------------------------------------------------------------------------
// Encoder // Encoder
// ----------------------------------------------------------------------------- // -----------------------------------------------------------------------------
@ -65,7 +69,7 @@ pub fn position(&self) -> usize {
// Tracking position this way instead of having a `self.position` field // Tracking position this way instead of having a `self.position` field
// means that we only need to update `self.buffered` on a write call, // means that we only need to update `self.buffered` on a write call,
// as opposed to updating `self.position` and `self.buffered`. // as opposed to updating `self.position` and `self.buffered`.
self.flushed + self.buffered self.flushed.debug_strict_add(self.buffered)
} }
#[cold] #[cold]
@ -119,7 +123,7 @@ fn write_all(&mut self, buf: &[u8]) {
} }
if let Some(dest) = self.buffer_empty().get_mut(..buf.len()) { if let Some(dest) = self.buffer_empty().get_mut(..buf.len()) {
dest.copy_from_slice(buf); dest.copy_from_slice(buf);
self.buffered += buf.len(); self.buffered = self.buffered.debug_strict_add(buf.len());
} else { } else {
self.write_all_cold_path(buf); self.write_all_cold_path(buf);
} }
@ -158,7 +162,7 @@ fn write_all(&mut self, buf: &[u8]) {
if written > N { if written > N {
Self::panic_invalid_write::<N>(written); Self::panic_invalid_write::<N>(written);
} }
self.buffered += written; self.buffered = self.buffered.debug_strict_add(written);
} }
#[cold] #[cold]

View File

@ -5,6 +5,10 @@
use rustc_data_structures::fx::FxIndexSet; use rustc_data_structures::fx::FxIndexSet;
// This code is very hot and uses lots of arithmetic, avoid overflow checks for performance.
// See https://github.com/rust-lang/rust/pull/119440#issuecomment-1874255727
use rustc_serialize::int_overflow::DebugStrictAdd;
/// A compressed span. /// A compressed span.
/// ///
/// [`SpanData`] is 16 bytes, which is too big to stick everywhere. `Span` only /// [`SpanData`] is 16 bytes, which is too big to stick everywhere. `Span` only
@ -166,7 +170,7 @@ pub fn data_untracked(self) -> SpanData {
debug_assert!(len <= MAX_LEN); debug_assert!(len <= MAX_LEN);
SpanData { SpanData {
lo: BytePos(self.lo_or_index), lo: BytePos(self.lo_or_index),
hi: BytePos(self.lo_or_index + len), hi: BytePos(self.lo_or_index.debug_strict_add(len)),
ctxt: SyntaxContext::from_u32(self.ctxt_or_parent_or_marker as u32), ctxt: SyntaxContext::from_u32(self.ctxt_or_parent_or_marker as u32),
parent: None, parent: None,
} }
@ -179,7 +183,7 @@ pub fn data_untracked(self) -> SpanData {
}; };
SpanData { SpanData {
lo: BytePos(self.lo_or_index), lo: BytePos(self.lo_or_index),
hi: BytePos(self.lo_or_index + len), hi: BytePos(self.lo_or_index.debug_strict_add(len)),
ctxt: SyntaxContext::root(), ctxt: SyntaxContext::root(),
parent: Some(parent), parent: Some(parent),
} }