miri: make NaN generation non-deterministic

Ralf Jung 2023-10-08 12:03:01 +02:00
parent d087c6fae2
commit 6796c5765d
6 changed files with 385 additions and 25 deletions


@@ -500,6 +500,9 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
b: &ImmTy<'tcx, M::Provenance>,
dest: &PlaceTy<'tcx, M::Provenance>,
) -> InterpResult<'tcx> {
assert_eq!(a.layout.ty, b.layout.ty);
assert!(matches!(a.layout.ty.kind(), ty::Int(..) | ty::Uint(..)));
// Performs an exact division, resulting in undefined behavior where
// `x % y != 0` or `y == 0` or `x == T::MIN && y == -1`.
// First, check x % y != 0 (or if that computation overflows).
@@ -522,7 +525,10 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
l: &ImmTy<'tcx, M::Provenance>,
r: &ImmTy<'tcx, M::Provenance>,
) -> InterpResult<'tcx, Scalar<M::Provenance>> {
assert_eq!(l.layout.ty, r.layout.ty);
assert!(matches!(l.layout.ty.kind(), ty::Int(..) | ty::Uint(..)));
assert!(matches!(mir_op, BinOp::Add | BinOp::Sub));
let (val, overflowed) = self.overflowing_binary_op(mir_op, l, r)?;
Ok(if overflowed {
let size = l.layout.size;


@@ -6,6 +6,7 @@ use std::borrow::{Borrow, Cow};
use std::fmt::Debug;
use std::hash::Hash;
use rustc_apfloat::Float;
use rustc_ast::{InlineAsmOptions, InlineAsmTemplatePiece};
use rustc_middle::mir;
use rustc_middle::ty::layout::TyAndLayout;
@@ -240,6 +241,13 @@ pub trait Machine<'mir, 'tcx: 'mir>: Sized {
right: &ImmTy<'tcx, Self::Provenance>,
) -> InterpResult<'tcx, (ImmTy<'tcx, Self::Provenance>, bool)>;
/// Generate the NaN returned by a float operation, given the list of inputs.
/// (This is all inputs, not just NaN inputs!)
fn generate_nan<F: Float>(_ecx: &InterpCx<'mir, 'tcx, Self>, _inputs: &[F]) -> F {
// By default we always return the preferred NaN.
F::NAN
}
/// Called before writing the specified `local` of the `frame`.
/// Since writing a ZST is not actually accessing memory or locals, this is never invoked
/// for ZST writes.
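For reference, the "preferred" NaN that `F::NAN` denotes is the positive quiet NaN with an all-zero payload. A minimal standalone sketch of what that means for the f32 bit pattern (illustration, not part of this commit):

fn preferred_f32_nan() -> f32 {
    // Sign 0, exponent all ones, quiet bit set, payload 0.
    f32::from_bits(0x7fc0_0000)
}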


@@ -113,6 +113,11 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
) -> (ImmTy<'tcx, M::Provenance>, bool) {
use rustc_middle::mir::BinOp::*;
// Performs appropriate non-deterministic adjustments of NaN results.
let adjust_nan = |f: F| -> F {
if f.is_nan() { M::generate_nan(self, &[l, r]) } else { f }
};
let val = match bin_op {
Eq => ImmTy::from_bool(l == r, *self.tcx),
Ne => ImmTy::from_bool(l != r, *self.tcx),
@@ -120,11 +125,11 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
Le => ImmTy::from_bool(l <= r, *self.tcx),
Gt => ImmTy::from_bool(l > r, *self.tcx),
Ge => ImmTy::from_bool(l >= r, *self.tcx),
Add => ImmTy::from_scalar((l + r).value.into(), layout),
Sub => ImmTy::from_scalar((l - r).value.into(), layout),
Mul => ImmTy::from_scalar((l * r).value.into(), layout),
Div => ImmTy::from_scalar((l / r).value.into(), layout),
Rem => ImmTy::from_scalar((l % r).value.into(), layout),
Add => ImmTy::from_scalar(adjust_nan((l + r).value).into(), layout),
Sub => ImmTy::from_scalar(adjust_nan((l - r).value).into(), layout),
Mul => ImmTy::from_scalar(adjust_nan((l * r).value).into(), layout),
Div => ImmTy::from_scalar(adjust_nan((l / r).value).into(), layout),
Rem => ImmTy::from_scalar(adjust_nan((l % r).value).into(), layout),
_ => span_bug!(self.cur_span(), "invalid float op: `{:?}`", bin_op),
};
(val, false)
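Only results that are actually NaN get adjusted; everything else passes through unchanged. For instance (illustration, not part of this commit):

let inf = 1.0_f32 / 0.0; // infinity, not NaN: returned as-is
let nan = 0.0_f32 / 0.0; // NaN: its bit pattern is now chosen by M::generate_nan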


@@ -1001,6 +1001,11 @@ impl<'mir, 'tcx> Machine<'mir, 'tcx> for MiriMachine<'mir, 'tcx> {
ecx.binary_ptr_op(bin_op, left, right)
}
#[inline(always)]
fn generate_nan<F: rustc_apfloat::Float>(ecx: &InterpCx<'mir, 'tcx, Self>, inputs: &[F]) -> F {
ecx.generate_nan(inputs)
}
fn thread_local_static_base_pointer(
ecx: &mut MiriInterpCx<'mir, 'tcx>,
def_id: DefId,


@@ -1,20 +1,16 @@
use std::iter;
use log::trace;
use rand::{seq::IteratorRandom, Rng};
use rustc_apfloat::Float;
use rustc_middle::mir;
use rustc_target::abi::Size;
use crate::*;
pub trait EvalContextExt<'tcx> {
fn binary_ptr_op(
&self,
bin_op: mir::BinOp,
left: &ImmTy<'tcx, Provenance>,
right: &ImmTy<'tcx, Provenance>,
) -> InterpResult<'tcx, (ImmTy<'tcx, Provenance>, bool)>;
}
impl<'mir, 'tcx> EvalContextExt<'tcx> for super::MiriInterpCx<'mir, 'tcx> {
impl<'mir, 'tcx: 'mir> EvalContextExt<'mir, 'tcx> for crate::MiriInterpCx<'mir, 'tcx> {}
pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
fn binary_ptr_op(
&self,
bin_op: mir::BinOp,
@@ -23,12 +19,13 @@ impl<'mir, 'tcx> EvalContextExt<'tcx> for super::MiriInterpCx<'mir, 'tcx> {
) -> InterpResult<'tcx, (ImmTy<'tcx, Provenance>, bool)> {
use rustc_middle::mir::BinOp::*;
let this = self.eval_context_ref();
trace!("ptr_op: {:?} {:?} {:?}", *left, bin_op, *right);
Ok(match bin_op {
Eq | Ne | Lt | Le | Gt | Ge => {
assert_eq!(left.layout.abi, right.layout.abi); // types can differ, e.g. fn ptrs with different `for`
let size = self.pointer_size();
let size = this.pointer_size();
// Just compare the bits. ScalarPairs are compared lexicographically.
// We thus always compare pairs and simply fill scalars up with 0.
let left = match **left {
@@ -50,7 +47,7 @@ impl<'mir, 'tcx> EvalContextExt<'tcx> for super::MiriInterpCx<'mir, 'tcx> {
Ge => left >= right,
_ => bug!(),
};
(ImmTy::from_bool(res, *self.tcx), false)
(ImmTy::from_bool(res, *this.tcx), false)
}
// Some more operations are possible with atomics.
@@ -58,26 +55,49 @@ impl<'mir, 'tcx> EvalContextExt<'tcx> for super::MiriInterpCx<'mir, 'tcx> {
Add | Sub | BitOr | BitAnd | BitXor => {
assert!(left.layout.ty.is_unsafe_ptr());
assert!(right.layout.ty.is_unsafe_ptr());
let ptr = left.to_scalar().to_pointer(self)?;
let ptr = left.to_scalar().to_pointer(this)?;
// We do the actual operation with usize-typed scalars.
let left = ImmTy::from_uint(ptr.addr().bytes(), self.machine.layouts.usize);
let left = ImmTy::from_uint(ptr.addr().bytes(), this.machine.layouts.usize);
let right = ImmTy::from_uint(
right.to_scalar().to_target_usize(self)?,
self.machine.layouts.usize,
right.to_scalar().to_target_usize(this)?,
this.machine.layouts.usize,
);
let (result, overflowing) = self.overflowing_binary_op(bin_op, &left, &right)?;
let (result, overflowing) = this.overflowing_binary_op(bin_op, &left, &right)?;
// Construct a new pointer with the provenance of `ptr` (the LHS).
let result_ptr = Pointer::new(
ptr.provenance,
Size::from_bytes(result.to_scalar().to_target_usize(self)?),
Size::from_bytes(result.to_scalar().to_target_usize(this)?),
);
(
ImmTy::from_scalar(Scalar::from_maybe_pointer(result_ptr, self), left.layout),
ImmTy::from_scalar(Scalar::from_maybe_pointer(result_ptr, this), left.layout),
overflowing,
)
}
_ => span_bug!(self.cur_span(), "Invalid operator on pointers: {:?}", bin_op),
_ => span_bug!(this.cur_span(), "Invalid operator on pointers: {:?}", bin_op),
})
}
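The LHS-provenance rule above matches the semantics of pointer arithmetic like `<*const T>::wrapping_add`: the address is recomputed, but the result keeps the provenance of the original pointer. A hypothetical illustration:

let x = 42u8;
let p: *const u8 = &x;
let q = p.wrapping_add(1); // new address, same provenance as `p`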
fn generate_nan<F: Float>(&self, inputs: &[F]) -> F {
let this = self.eval_context_ref();
let mut rand = this.machine.rng.borrow_mut();
// Assemble an iterator of possible NaNs: preferred, unchanged propagation, quieting propagation.
let preferred_nan = F::qnan(Some(0));
let nans = iter::once(preferred_nan)
.chain(inputs.iter().filter(|f| f.is_nan()).copied())
.chain(inputs.iter().filter(|f| f.is_signaling()).map(|f| {
// Make it quiet, by setting the bit. We assume that `preferred_nan`
// only has bits set that all quiet NaNs need to have set.
F::from_bits(f.to_bits() | preferred_nan.to_bits())
}));
// Pick one of the NaNs.
let nan = nans.choose(&mut *rand).unwrap();
// Non-deterministically flip the sign.
if rand.gen() {
// This will properly flip even for NaN.
-nan
} else {
nan
}
}
}
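As a standalone sketch of the candidate set this samples from, here is the same logic with plain f32 bit twiddling instead of rustc_apfloat (illustration only; quieting a quiet NaN is a no-op, so the signaling filter can be dropped here):

fn nan_candidates(inputs: &[f32]) -> Vec<f32> {
    let preferred = f32::from_bits(0x7fc0_0000); // quiet NaN, payload 0
    let mut out = vec![preferred];
    for &f in inputs.iter().filter(|f| f.is_nan()) {
        out.push(f); // unchanged propagation
        out.push(f32::from_bits(f.to_bits() | preferred.to_bits())); // quieted
    }
    // Each candidate may additionally have its sign flipped.
    out
}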


@@ -0,0 +1,316 @@
use std::collections::HashSet;
use std::fmt;
use std::hash::Hash;
use std::hint::black_box;
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum Sign {
Neg = 1,
Pos = 0,
}
use Sign::*;
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum NaNKind {
Quiet = 1,
Signaling = 0,
}
use NaNKind::*;
#[track_caller]
fn check_all_outcomes<T: Eq + Hash + fmt::Display>(expected: HashSet<T>, generate: impl Fn() -> T) {
let mut seen = HashSet::new();
// Let's give it 8x as many tries as we are expecting values.
let tries = expected.len() * 8;
for _ in 0..tries {
let val = generate();
assert!(expected.contains(&val), "got an unexpected value: {val}");
seen.insert(val);
}
// Let's see if we saw them all.
for val in expected {
if !seen.contains(&val) {
panic!("did not get value that should be possible: {val}");
}
}
}
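The 8x factor keeps flakiness unlikely: with n equally likely outcomes and 8n tries, a given outcome is missed with probability (1 - 1/n)^(8n) ≈ e^-8 ≈ 0.03% (assuming the generator is uniform, which holds for the tests below).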
// -- f32 support
#[repr(C)]
#[derive(Copy, Clone, Eq, PartialEq, Hash)]
struct F32(u32);
impl From<f32> for F32 {
fn from(x: f32) -> Self {
F32(x.to_bits())
}
}
/// Returns a value that is `ones` many 1-bits.
fn u32_ones(ones: u32) -> u32 {
assert!(ones <= 32);
if ones == 0 {
// `>>` by 32 is not allowed (shift amount equals the bit width), so handle this case separately.
return 0;
}
u32::MAX >> (32 - ones)
}
const F32_SIGN_BIT: u32 = 32 - 1; // position of the sign bit
const F32_EXP: u32 = 8; // 8 bits of exponent
const F32_MANTISSA: u32 = F32_SIGN_BIT - F32_EXP;
const F32_NAN_PAYLOAD: u32 = F32_MANTISSA - 1;
impl fmt::Display for F32 {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
// Always show raw bits.
write!(f, "0x{:08x} ", self.0)?;
// Also show nice version.
let val = self.0;
let sign = val >> F32_SIGN_BIT;
let val = val & u32_ones(F32_SIGN_BIT); // mask away sign
let exp = val >> F32_MANTISSA;
let mantissa = val & u32_ones(F32_MANTISSA);
if exp == u32_ones(F32_EXP) {
// A NaN! Special printing.
let sign = if sign != 0 { Neg } else { Pos };
let quiet = if (mantissa >> F32_NAN_PAYLOAD) != 0 { Quiet } else { Signaling };
let payload = mantissa & u32_ones(F32_NAN_PAYLOAD);
write!(f, "(NaN: {:?}, {:?}, payload = {:#x})", sign, quiet, payload)
} else {
// Normal float value.
write!(f, "({})", f32::from_bits(self.0))
}
}
}
impl F32 {
fn nan(sign: Sign, kind: NaNKind, payload: u32) -> Self {
// Either the quiet bit must be set or the payload must be non-0;
// otherwise this is not a NaN but an infinity.
assert!(kind == Quiet || payload != 0);
// Payload must fit in 22 bits.
assert!(payload < (1 << F32_NAN_PAYLOAD));
// Concatenate the bits (with a 22-bit payload).
// Pattern: [negative] ++ [1]^8 ++ [quiet] ++ [payload]
let val = ((sign as u32) << F32_SIGN_BIT)
| (u32_ones(F32_EXP) << F32_MANTISSA)
| ((kind as u32) << F32_NAN_PAYLOAD)
| payload;
// Sanity check.
assert!(f32::from_bits(val).is_nan());
// Done!
F32(val)
}
fn as_f32(self) -> f32 {
black_box(f32::from_bits(self.0))
}
}
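For example, the constructor yields (illustration):

// F32::nan(Pos, Quiet, 0)     == F32(0x7fc0_0000)  (the preferred NaN)
// F32::nan(Neg, Signaling, 1) == F32(0xff80_0001)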
// -- f64 support
#[repr(C)]
#[derive(Copy, Clone, Eq, PartialEq, Hash)]
struct F64(u64);
impl From<f64> for F64 {
fn from(x: f64) -> Self {
F64(x.to_bits())
}
}
/// Returns a value that is `ones` many 1-bits.
fn u64_ones(ones: u32) -> u64 {
assert!(ones <= 64);
if ones == 0 {
// `>>` by 64 is not allowed (shift amount equals the bit width), so handle this case separately.
return 0;
}
u64::MAX >> (64 - ones)
}
const F64_SIGN_BIT: u32 = 64 - 1; // position of the sign bit
const F64_EXP: u32 = 11; // 11 bits of exponent
const F64_MANTISSA: u32 = F64_SIGN_BIT - F64_EXP;
const F64_NAN_PAYLOAD: u32 = F64_MANTISSA - 1;
impl fmt::Display for F64 {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
// Always show raw bits.
write!(f, "0x{:016x} ", self.0)?;
// Also show nice version.
let val = self.0;
let sign = val >> F64_SIGN_BIT;
let val = val & u64_ones(F64_SIGN_BIT); // mask away sign
let exp = val >> F64_MANTISSA;
let mantissa = val & u64_ones(F64_MANTISSA);
if exp == u64_ones(F64_EXP) {
// A NaN! Special printing.
let sign = if sign != 0 { Neg } else { Pos };
let quiet = if (mantissa >> F64_NAN_PAYLOAD) != 0 { Quiet } else { Signaling };
let payload = mantissa & u64_ones(F64_NAN_PAYLOAD);
write!(f, "(NaN: {:?}, {:?}, payload = {:#x})", sign, quiet, payload)
} else {
// Normal float value.
write!(f, "({})", f64::from_bits(self.0))
}
}
}
impl F64 {
fn nan(sign: Sign, kind: NaNKind, payload: u64) -> Self {
// Either the quiet bit must be set or the payload must be non-0;
// otherwise this is not a NaN but an infinity.
assert!(kind == Quiet || payload != 0);
// Payload must fit in 51 bits.
assert!(payload < (1 << F64_NAN_PAYLOAD));
// Concatenate the bits (with a 51-bit payload).
// Pattern: [negative] ++ [1]^11 ++ [quiet] ++ [payload]
let val = ((sign as u64) << F64_SIGN_BIT)
| (u64_ones(F64_EXP) << F64_MANTISSA)
| ((kind as u64) << F64_NAN_PAYLOAD)
| payload;
// Sanity check.
assert!(f64::from_bits(val).is_nan());
// Done!
F64(val)
}
fn as_f64(self) -> f64 {
black_box(f64::from_bits(self.0))
}
}
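Analogously for f64 (illustration):

// F64::nan(Pos, Quiet, 0) == F64(0x7ff8_0000_0000_0000)  (the preferred NaN)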
// -- actual tests
fn test_f32() {
// Freshly generated NaNs can have either sign.
check_all_outcomes(
HashSet::from_iter([F32::nan(Pos, Quiet, 0), F32::nan(Neg, Quiet, 0)]),
|| F32::from(0.0 / black_box(0.0)),
);
// When there are NaN inputs, their payload can be propagated, with any sign.
let all1_payload = u32_ones(22);
let all1 = F32::nan(Pos, Quiet, all1_payload).as_f32();
check_all_outcomes(
HashSet::from_iter([
F32::nan(Pos, Quiet, 0),
F32::nan(Neg, Quiet, 0),
F32::nan(Pos, Quiet, all1_payload),
F32::nan(Neg, Quiet, all1_payload),
]),
|| F32::from(0.0 + all1),
);
// When there are two NaN inputs, the output can be either one, or the preferred NaN.
let just1 = F32::nan(Neg, Quiet, 1).as_f32();
check_all_outcomes(
HashSet::from_iter([
F32::nan(Pos, Quiet, 0),
F32::nan(Neg, Quiet, 0),
F32::nan(Pos, Quiet, 1),
F32::nan(Neg, Quiet, 1),
F32::nan(Pos, Quiet, all1_payload),
F32::nan(Neg, Quiet, all1_payload),
]),
|| F32::from(just1 - all1),
);
// When there are *signaling* NaN inputs, they might be quieted or not.
let all1_snan = F32::nan(Pos, Signaling, all1_payload).as_f32();
check_all_outcomes(
HashSet::from_iter([
F32::nan(Pos, Quiet, 0),
F32::nan(Neg, Quiet, 0),
F32::nan(Pos, Quiet, all1_payload),
F32::nan(Neg, Quiet, all1_payload),
F32::nan(Pos, Signaling, all1_payload),
F32::nan(Neg, Signaling, all1_payload),
]),
|| F32::from(0.0 * all1_snan),
);
// Mix signaling and non-signaling NaN.
check_all_outcomes(
HashSet::from_iter([
F32::nan(Pos, Quiet, 0),
F32::nan(Neg, Quiet, 0),
F32::nan(Pos, Quiet, 1),
F32::nan(Neg, Quiet, 1),
F32::nan(Pos, Quiet, all1_payload),
F32::nan(Neg, Quiet, all1_payload),
F32::nan(Pos, Signaling, all1_payload),
F32::nan(Neg, Signaling, all1_payload),
]),
|| F32::from(just1 % all1_snan),
);
}
fn test_f64() {
// Freshly generated NaNs can have either sign.
check_all_outcomes(
HashSet::from_iter([F64::nan(Pos, Quiet, 0), F64::nan(Neg, Quiet, 0)]),
|| F64::from(0.0 / black_box(0.0)),
);
// When there are NaN inputs, their payload can be propagated, with any sign.
let all1_payload = u64_ones(51);
let all1 = F64::nan(Pos, Quiet, all1_payload).as_f64();
check_all_outcomes(
HashSet::from_iter([
F64::nan(Pos, Quiet, 0),
F64::nan(Neg, Quiet, 0),
F64::nan(Pos, Quiet, all1_payload),
F64::nan(Neg, Quiet, all1_payload),
]),
|| F64::from(0.0 + all1),
);
// When there are two NaN inputs, the output can be either one, or the preferred NaN.
let just1 = F64::nan(Neg, Quiet, 1).as_f64();
check_all_outcomes(
HashSet::from_iter([
F64::nan(Pos, Quiet, 0),
F64::nan(Neg, Quiet, 0),
F64::nan(Pos, Quiet, 1),
F64::nan(Neg, Quiet, 1),
F64::nan(Pos, Quiet, all1_payload),
F64::nan(Neg, Quiet, all1_payload),
]),
|| F64::from(just1 - all1),
);
// When there are *signaling* NaN inputs, they might be quieted or not.
let all1_snan = F64::nan(Pos, Signaling, all1_payload).as_f64();
check_all_outcomes(
HashSet::from_iter([
F64::nan(Pos, Quiet, 0),
F64::nan(Neg, Quiet, 0),
F64::nan(Pos, Quiet, all1_payload),
F64::nan(Neg, Quiet, all1_payload),
F64::nan(Pos, Signaling, all1_payload),
F64::nan(Neg, Signaling, all1_payload),
]),
|| F64::from(0.0 * all1_snan),
);
// Mix signaling and non-signaling NaN.
check_all_outcomes(
HashSet::from_iter([
F64::nan(Pos, Quiet, 0),
F64::nan(Neg, Quiet, 0),
F64::nan(Pos, Quiet, 1),
F64::nan(Neg, Quiet, 1),
F64::nan(Pos, Quiet, all1_payload),
F64::nan(Neg, Quiet, all1_payload),
F64::nan(Pos, Signaling, all1_payload),
F64::nan(Neg, Signaling, all1_payload),
]),
|| F64::from(just1 % all1_snan),
);
}
fn main() {
// Check our constants against std, just to be sure.
// We add 1 since our numbers are the number of bits stored
// to represent the value, and std has the precision of the value,
// which is one more due to the implicit leading 1.
assert_eq!(F32_MANTISSA + 1, f32::MANTISSA_DIGITS);
assert_eq!(F64_MANTISSA + 1, f64::MANTISSA_DIGITS);
test_f32();
test_f64();
}
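Since all of this non-determinism is driven by Miri's seedable RNG, a particular set of outcomes can be reproduced by pinning the seed, e.g. with `MIRIFLAGS="-Zmiri-seed=42" cargo miri run` (assuming a typical cargo-miri setup).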