miri: reduce code duplication in SSE/SSE2 bin_op_* functions

This commit is contained in:
Eduardo Sánchez Muñoz 2023-09-13 20:17:59 +02:00
parent 68687e2535
commit 21418429fd
6 changed files with 194 additions and 310 deletions

View File

@ -503,16 +503,20 @@ pub fn to_target_isize(self, cx: &impl HasDataLayout) -> InterpResult<'tcx, i64>
Ok(i64::try_from(b).unwrap())
}
#[inline]
pub fn to_float<F: Float>(self) -> InterpResult<'tcx, F> {
// Going through `to_uint` to check size and truncation.
Ok(F::from_bits(self.to_uint(Size::from_bits(F::BITS))?))
}
#[inline]
pub fn to_f32(self) -> InterpResult<'tcx, Single> {
// Going through `u32` to check size and truncation.
Ok(Single::from_bits(self.to_u32()?.into()))
self.to_float()
}
#[inline]
pub fn to_f64(self) -> InterpResult<'tcx, Double> {
// Going through `u64` to check size and truncation.
Ok(Double::from_bits(self.to_u64()?.into()))
self.to_float()
}
}

View File

@ -1152,3 +1152,20 @@ pub fn get_local_crates(tcx: TyCtxt<'_>) -> Vec<CrateNum> {
pub fn target_os_is_unix(target_os: &str) -> bool {
matches!(target_os, "linux" | "macos" | "freebsd" | "android")
}
pub(crate) fn bool_to_simd_element(b: bool, size: Size) -> Scalar<Provenance> {
// SIMD uses all-1 as pattern for "true". In two's complement,
// -1 has all its bits set to one and `from_int` will truncate or
// sign-extend it to `size` as required.
let val = if b { -1 } else { 0 };
Scalar::from_int(val, size)
}
pub(crate) fn simd_element_to_bool(elem: ImmTy<'_, Provenance>) -> InterpResult<'_, bool> {
let val = elem.to_scalar().to_int(elem.layout.size)?;
Ok(match val {
0 => false,
-1 => true,
_ => throw_ub_format!("each element of a SIMD mask must be all-0-bits or all-1-bits"),
})
}

View File

@ -1,10 +1,10 @@
use rustc_apfloat::{Float, Round};
use rustc_middle::ty::layout::{HasParamEnv, LayoutOf};
use rustc_middle::{mir, ty, ty::FloatTy};
use rustc_target::abi::{Endian, HasDataLayout, Size};
use rustc_target::abi::{Endian, HasDataLayout};
use crate::*;
use helpers::check_arg_count;
use helpers::{bool_to_simd_element, check_arg_count, simd_element_to_bool};
impl<'mir, 'tcx: 'mir> EvalContextExt<'mir, 'tcx> for crate::MiriInterpCx<'mir, 'tcx> {}
pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
@ -612,21 +612,6 @@ enum Op {
}
}
fn bool_to_simd_element(b: bool, size: Size) -> Scalar<Provenance> {
// SIMD uses all-1 as pattern for "true"
let val = if b { -1 } else { 0 };
Scalar::from_int(val, size)
}
fn simd_element_to_bool(elem: ImmTy<'_, Provenance>) -> InterpResult<'_, bool> {
let val = elem.to_scalar().to_int(elem.layout.size)?;
Ok(match val {
0 => false,
-1 => true,
_ => throw_ub_format!("each element of a SIMD mask must be all-0-bits or all-1-bits"),
})
}
fn simd_bitmask_index(idx: u32, vec_len: u32, endianness: Endian) -> u32 {
assert!(idx < vec_len);
match endianness {

View File

@ -1,4 +1,8 @@
use crate::InterpResult;
use rustc_middle::mir;
use rustc_target::abi::Size;
use crate::*;
use helpers::bool_to_simd_element;
pub(super) mod sse;
pub(super) mod sse2;
@ -43,3 +47,155 @@ fn from_intrinsic_imm(imm: i8, intrinsic: &str) -> InterpResult<'_, Self> {
}
}
}
#[derive(Copy, Clone)]
enum FloatBinOp {
/// Arithmetic operation
Arith(mir::BinOp),
/// Comparison
Cmp(FloatCmpOp),
/// Minimum value (with SSE semantics)
///
/// <https://www.felixcloutier.com/x86/minss>
/// <https://www.felixcloutier.com/x86/minps>
/// <https://www.felixcloutier.com/x86/minsd>
/// <https://www.felixcloutier.com/x86/minpd>
Min,
/// Maximum value (with SSE semantics)
///
/// <https://www.felixcloutier.com/x86/maxss>
/// <https://www.felixcloutier.com/x86/maxps>
/// <https://www.felixcloutier.com/x86/maxsd>
/// <https://www.felixcloutier.com/x86/maxpd>
Max,
}
/// Performs `which` scalar operation on `left` and `right` and returns
/// the result.
fn bin_op_float<'tcx, F: rustc_apfloat::Float>(
this: &crate::MiriInterpCx<'_, 'tcx>,
which: FloatBinOp,
left: &ImmTy<'tcx, Provenance>,
right: &ImmTy<'tcx, Provenance>,
) -> InterpResult<'tcx, Scalar<Provenance>> {
match which {
FloatBinOp::Arith(which) => {
let (res, _overflow, _ty) = this.overflowing_binary_op(which, left, right)?;
Ok(res)
}
FloatBinOp::Cmp(which) => {
let left = left.to_scalar().to_float::<F>()?;
let right = right.to_scalar().to_float::<F>()?;
// FIXME: Make sure that these operations match the semantics
// of cmpps/cmpss/cmppd/cmpsd
let res = match which {
FloatCmpOp::Eq => left == right,
FloatCmpOp::Lt => left < right,
FloatCmpOp::Le => left <= right,
FloatCmpOp::Unord => left.is_nan() || right.is_nan(),
FloatCmpOp::Neq => left != right,
FloatCmpOp::Nlt => !(left < right),
FloatCmpOp::Nle => !(left <= right),
FloatCmpOp::Ord => !left.is_nan() && !right.is_nan(),
};
Ok(bool_to_simd_element(res, Size::from_bits(F::BITS)))
}
FloatBinOp::Min => {
let left_scalar = left.to_scalar();
let left = left_scalar.to_float::<F>()?;
let right_scalar = right.to_scalar();
let right = right_scalar.to_float::<F>()?;
// SSE semantics to handle zero and NaN. Note that `x == F::ZERO`
// is true when `x` is either +0 or -0.
if (left == F::ZERO && right == F::ZERO)
|| left.is_nan()
|| right.is_nan()
|| left >= right
{
Ok(right_scalar)
} else {
Ok(left_scalar)
}
}
FloatBinOp::Max => {
let left_scalar = left.to_scalar();
let left = left_scalar.to_float::<F>()?;
let right_scalar = right.to_scalar();
let right = right_scalar.to_float::<F>()?;
// SSE semantics to handle zero and NaN. Note that `x == F::ZERO`
// is true when `x` is either +0 or -0.
if (left == F::ZERO && right == F::ZERO)
|| left.is_nan()
|| right.is_nan()
|| left <= right
{
Ok(right_scalar)
} else {
Ok(left_scalar)
}
}
}
}
/// Performs `which` operation on the first component of `left` and `right`
/// and copies the other components from `left`. The result is stored in `dest`.
fn bin_op_simd_float_first<'tcx, F: rustc_apfloat::Float>(
this: &mut crate::MiriInterpCx<'_, 'tcx>,
which: FloatBinOp,
left: &OpTy<'tcx, Provenance>,
right: &OpTy<'tcx, Provenance>,
dest: &PlaceTy<'tcx, Provenance>,
) -> InterpResult<'tcx, ()> {
let (left, left_len) = this.operand_to_simd(left)?;
let (right, right_len) = this.operand_to_simd(right)?;
let (dest, dest_len) = this.place_to_simd(dest)?;
assert_eq!(dest_len, left_len);
assert_eq!(dest_len, right_len);
let res0 = bin_op_float::<F>(
this,
which,
&this.read_immediate(&this.project_index(&left, 0)?)?,
&this.read_immediate(&this.project_index(&right, 0)?)?,
)?;
this.write_scalar(res0, &this.project_index(&dest, 0)?)?;
for i in 1..dest_len {
this.copy_op(
&this.project_index(&left, i)?,
&this.project_index(&dest, i)?,
/*allow_transmute*/ false,
)?;
}
Ok(())
}
/// Performs `which` operation on each component of `left` and
/// `right`, storing the result is stored in `dest`.
fn bin_op_simd_float_all<'tcx, F: rustc_apfloat::Float>(
this: &mut crate::MiriInterpCx<'_, 'tcx>,
which: FloatBinOp,
left: &OpTy<'tcx, Provenance>,
right: &OpTy<'tcx, Provenance>,
dest: &PlaceTy<'tcx, Provenance>,
) -> InterpResult<'tcx, ()> {
let (left, left_len) = this.operand_to_simd(left)?;
let (right, right_len) = this.operand_to_simd(right)?;
let (dest, dest_len) = this.place_to_simd(dest)?;
assert_eq!(dest_len, left_len);
assert_eq!(dest_len, right_len);
for i in 0..dest_len {
let left = this.read_immediate(&this.project_index(&left, i)?)?;
let right = this.read_immediate(&this.project_index(&right, i)?)?;
let dest = this.project_index(&dest, i)?;
let res = bin_op_float::<F>(this, which, &left, &right)?;
this.write_scalar(res, &dest)?;
}
Ok(())
}

View File

@ -5,7 +5,7 @@
use rand::Rng as _;
use super::FloatCmpOp;
use super::{bin_op_simd_float_all, bin_op_simd_float_first, FloatBinOp, FloatCmpOp};
use crate::*;
use shims::foreign_items::EmulateByNameResult;
@ -45,7 +45,7 @@ fn emulate_x86_sse_intrinsic(
_ => unreachable!(),
};
bin_op_ss(this, which, left, right, dest)?;
bin_op_simd_float_first::<Single>(this, which, left, right, dest)?;
}
// Used to implement _mm_min_ps and _mm_max_ps functions.
// Note that the semantics are a bit different from Rust simd_min
@ -62,7 +62,7 @@ fn emulate_x86_sse_intrinsic(
_ => unreachable!(),
};
bin_op_ps(this, which, left, right, dest)?;
bin_op_simd_float_all::<Single>(this, which, left, right, dest)?;
}
// Used to implement _mm_{sqrt,rcp,rsqrt}_ss functions.
// Performs the operations on the first component of `op` and
@ -106,7 +106,7 @@ fn emulate_x86_sse_intrinsic(
"llvm.x86.sse.cmp.ss",
)?);
bin_op_ss(this, which, left, right, dest)?;
bin_op_simd_float_first::<Single>(this, which, left, right, dest)?;
}
// Used to implement the _mm_cmp_ps function.
// Performs a comparison operation on each component of `left`
@ -121,7 +121,7 @@ fn emulate_x86_sse_intrinsic(
"llvm.x86.sse.cmp.ps",
)?);
bin_op_ps(this, which, left, right, dest)?;
bin_op_simd_float_all::<Single>(this, which, left, right, dest)?;
}
// Used to implement _mm_{,u}comi{eq,lt,le,gt,ge,neq}_ss functions.
// Compares the first component of `left` and `right` and returns
@ -281,148 +281,6 @@ fn emulate_x86_sse_intrinsic(
}
}
#[derive(Copy, Clone)]
enum FloatBinOp {
/// Arithmetic operation
Arith(mir::BinOp),
/// Comparison
Cmp(FloatCmpOp),
/// Minimum value (with SSE semantics)
///
/// <https://www.felixcloutier.com/x86/minss>
/// <https://www.felixcloutier.com/x86/minps>
Min,
/// Maximum value (with SSE semantics)
///
/// <https://www.felixcloutier.com/x86/maxss>
/// <https://www.felixcloutier.com/x86/maxps>
Max,
}
/// Performs `which` scalar operation on `left` and `right` and returns
/// the result.
fn bin_op_f32<'tcx>(
this: &crate::MiriInterpCx<'_, 'tcx>,
which: FloatBinOp,
left: &ImmTy<'tcx, Provenance>,
right: &ImmTy<'tcx, Provenance>,
) -> InterpResult<'tcx, Scalar<Provenance>> {
match which {
FloatBinOp::Arith(which) => {
let (res, _, _) = this.overflowing_binary_op(which, left, right)?;
Ok(res)
}
FloatBinOp::Cmp(which) => {
let left = left.to_scalar().to_f32()?;
let right = right.to_scalar().to_f32()?;
// FIXME: Make sure that these operations match the semantics of cmpps
let res = match which {
FloatCmpOp::Eq => left == right,
FloatCmpOp::Lt => left < right,
FloatCmpOp::Le => left <= right,
FloatCmpOp::Unord => left.is_nan() || right.is_nan(),
FloatCmpOp::Neq => left != right,
FloatCmpOp::Nlt => !(left < right),
FloatCmpOp::Nle => !(left <= right),
FloatCmpOp::Ord => !left.is_nan() && !right.is_nan(),
};
Ok(Scalar::from_u32(if res { u32::MAX } else { 0 }))
}
FloatBinOp::Min => {
let left = left.to_scalar().to_f32()?;
let right = right.to_scalar().to_f32()?;
// SSE semantics to handle zero and NaN. Note that `x == Single::ZERO`
// is true when `x` is either +0 or -0.
if (left == Single::ZERO && right == Single::ZERO)
|| left.is_nan()
|| right.is_nan()
|| left >= right
{
Ok(Scalar::from_f32(right))
} else {
Ok(Scalar::from_f32(left))
}
}
FloatBinOp::Max => {
let left = left.to_scalar().to_f32()?;
let right = right.to_scalar().to_f32()?;
// SSE semantics to handle zero and NaN. Note that `x == Single::ZERO`
// is true when `x` is either +0 or -0.
if (left == Single::ZERO && right == Single::ZERO)
|| left.is_nan()
|| right.is_nan()
|| left <= right
{
Ok(Scalar::from_f32(right))
} else {
Ok(Scalar::from_f32(left))
}
}
}
}
/// Performs `which` operation on the first component of `left` and `right`
/// and copies the other components from `left`. The result is stored in `dest`.
fn bin_op_ss<'tcx>(
this: &mut crate::MiriInterpCx<'_, 'tcx>,
which: FloatBinOp,
left: &OpTy<'tcx, Provenance>,
right: &OpTy<'tcx, Provenance>,
dest: &PlaceTy<'tcx, Provenance>,
) -> InterpResult<'tcx, ()> {
let (left, left_len) = this.operand_to_simd(left)?;
let (right, right_len) = this.operand_to_simd(right)?;
let (dest, dest_len) = this.place_to_simd(dest)?;
assert_eq!(dest_len, left_len);
assert_eq!(dest_len, right_len);
let res0 = bin_op_f32(
this,
which,
&this.read_immediate(&this.project_index(&left, 0)?)?,
&this.read_immediate(&this.project_index(&right, 0)?)?,
)?;
this.write_scalar(res0, &this.project_index(&dest, 0)?)?;
for i in 1..dest_len {
let left = this.read_immediate(&this.project_index(&left, i)?)?;
let dest = this.project_index(&dest, i)?;
this.write_immediate(*left, &dest)?;
}
Ok(())
}
/// Performs `which` operation on each component of `left` and
/// `right`, storing the result is stored in `dest`.
fn bin_op_ps<'tcx>(
this: &mut crate::MiriInterpCx<'_, 'tcx>,
which: FloatBinOp,
left: &OpTy<'tcx, Provenance>,
right: &OpTy<'tcx, Provenance>,
dest: &PlaceTy<'tcx, Provenance>,
) -> InterpResult<'tcx, ()> {
let (left, left_len) = this.operand_to_simd(left)?;
let (right, right_len) = this.operand_to_simd(right)?;
let (dest, dest_len) = this.place_to_simd(dest)?;
assert_eq!(dest_len, left_len);
assert_eq!(dest_len, right_len);
for i in 0..dest_len {
let left = this.read_immediate(&this.project_index(&left, i)?)?;
let right = this.read_immediate(&this.project_index(&right, i)?)?;
let dest = this.project_index(&dest, i)?;
let res = bin_op_f32(this, which, &left, &right)?;
this.write_scalar(res, &dest)?;
}
Ok(())
}
#[derive(Copy, Clone)]
enum FloatUnaryOp {
/// sqrt(x)

View File

@ -7,7 +7,7 @@
use rustc_span::Symbol;
use rustc_target::spec::abi::Abi;
use super::FloatCmpOp;
use super::{bin_op_simd_float_all, bin_op_simd_float_first, FloatBinOp, FloatCmpOp};
use crate::*;
use shims::foreign_items::EmulateByNameResult;
@ -513,7 +513,7 @@ enum ShiftOp {
_ => unreachable!(),
};
bin_op_sd(this, which, left, right, dest)?;
bin_op_simd_float_first::<Double>(this, which, left, right, dest)?;
}
// Used to implement _mm_min_pd and _mm_max_pd functions.
// Note that the semantics are a bit different from Rust simd_min
@ -530,7 +530,7 @@ enum ShiftOp {
_ => unreachable!(),
};
bin_op_pd(this, which, left, right, dest)?;
bin_op_simd_float_all::<Double>(this, which, left, right, dest)?;
}
// Used to implement _mm_sqrt_sd functions.
// Performs the operations on the first component of `op` and
@ -589,7 +589,7 @@ enum ShiftOp {
"llvm.x86.sse2.cmp.sd",
)?);
bin_op_sd(this, which, left, right, dest)?;
bin_op_simd_float_first::<Double>(this, which, left, right, dest)?;
}
// Used to implement the _mm_cmp*_pd functions.
// Performs a comparison operation on each component of `left`
@ -604,7 +604,7 @@ enum ShiftOp {
"llvm.x86.sse2.cmp.pd",
)?);
bin_op_pd(this, which, left, right, dest)?;
bin_op_simd_float_all::<Double>(this, which, left, right, dest)?;
}
// Used to implement _mm_{,u}comi{eq,lt,le,gt,ge,neq}_sd functions.
// Compares the first component of `left` and `right` and returns
@ -840,139 +840,3 @@ fn extract_first_u64<'tcx>(
// Get the first u64 from the array
this.read_scalar(&this.project_index(&op, 0)?)?.to_u64()
}
#[derive(Copy, Clone)]
enum FloatBinOp {
/// Comparison
Cmp(FloatCmpOp),
/// Minimum value (with SSE semantics)
///
/// <https://www.felixcloutier.com/x86/minsd>
/// <https://www.felixcloutier.com/x86/minpd>
Min,
/// Maximum value (with SSE semantics)
///
/// <https://www.felixcloutier.com/x86/maxsd>
/// <https://www.felixcloutier.com/x86/maxpd>
Max,
}
/// Performs `which` scalar operation on `left` and `right` and returns
/// the result.
// FIXME make this generic over apfloat type to reduce code duplicaton with bin_op_f32
fn bin_op_f64<'tcx>(
which: FloatBinOp,
left: &ImmTy<'tcx, Provenance>,
right: &ImmTy<'tcx, Provenance>,
) -> InterpResult<'tcx, Scalar<Provenance>> {
match which {
FloatBinOp::Cmp(which) => {
let left = left.to_scalar().to_f64()?;
let right = right.to_scalar().to_f64()?;
// FIXME: Make sure that these operations match the semantics of cmppd
let res = match which {
FloatCmpOp::Eq => left == right,
FloatCmpOp::Lt => left < right,
FloatCmpOp::Le => left <= right,
FloatCmpOp::Unord => left.is_nan() || right.is_nan(),
FloatCmpOp::Neq => left != right,
FloatCmpOp::Nlt => !(left < right),
FloatCmpOp::Nle => !(left <= right),
FloatCmpOp::Ord => !left.is_nan() && !right.is_nan(),
};
Ok(Scalar::from_u64(if res { u64::MAX } else { 0 }))
}
FloatBinOp::Min => {
let left = left.to_scalar().to_f64()?;
let right = right.to_scalar().to_f64()?;
// SSE semantics to handle zero and NaN. Note that `x == Single::ZERO`
// is true when `x` is either +0 or -0.
if (left == Double::ZERO && right == Double::ZERO)
|| left.is_nan()
|| right.is_nan()
|| left >= right
{
Ok(Scalar::from_f64(right))
} else {
Ok(Scalar::from_f64(left))
}
}
FloatBinOp::Max => {
let left = left.to_scalar().to_f64()?;
let right = right.to_scalar().to_f64()?;
// SSE semantics to handle zero and NaN. Note that `x == Single::ZERO`
// is true when `x` is either +0 or -0.
if (left == Double::ZERO && right == Double::ZERO)
|| left.is_nan()
|| right.is_nan()
|| left <= right
{
Ok(Scalar::from_f64(right))
} else {
Ok(Scalar::from_f64(left))
}
}
}
}
/// Performs `which` operation on the first component of `left` and `right`
/// and copies the other components from `left`. The result is stored in `dest`.
fn bin_op_sd<'tcx>(
this: &mut crate::MiriInterpCx<'_, 'tcx>,
which: FloatBinOp,
left: &OpTy<'tcx, Provenance>,
right: &OpTy<'tcx, Provenance>,
dest: &PlaceTy<'tcx, Provenance>,
) -> InterpResult<'tcx, ()> {
let (left, left_len) = this.operand_to_simd(left)?;
let (right, right_len) = this.operand_to_simd(right)?;
let (dest, dest_len) = this.place_to_simd(dest)?;
assert_eq!(dest_len, left_len);
assert_eq!(dest_len, right_len);
let res0 = bin_op_f64(
which,
&this.read_immediate(&this.project_index(&left, 0)?)?,
&this.read_immediate(&this.project_index(&right, 0)?)?,
)?;
this.write_scalar(res0, &this.project_index(&dest, 0)?)?;
for i in 1..dest_len {
this.copy_op(
&this.project_index(&left, i)?,
&this.project_index(&dest, i)?,
/*allow_transmute*/ false,
)?;
}
Ok(())
}
/// Performs `which` operation on each component of `left` and
/// `right`, storing the result is stored in `dest`.
fn bin_op_pd<'tcx>(
this: &mut crate::MiriInterpCx<'_, 'tcx>,
which: FloatBinOp,
left: &OpTy<'tcx, Provenance>,
right: &OpTy<'tcx, Provenance>,
dest: &PlaceTy<'tcx, Provenance>,
) -> InterpResult<'tcx, ()> {
let (left, left_len) = this.operand_to_simd(left)?;
let (right, right_len) = this.operand_to_simd(right)?;
let (dest, dest_len) = this.place_to_simd(dest)?;
assert_eq!(dest_len, left_len);
assert_eq!(dest_len, right_len);
for i in 0..dest_len {
let left = this.read_immediate(&this.project_index(&left, i)?)?;
let right = this.read_immediate(&this.project_index(&right, i)?)?;
let dest = this.project_index(&dest, i)?;
let res = bin_op_f64(which, &left, &right)?;
this.write_scalar(res, &dest)?;
}
Ok(())
}