implement simd_fmax/fmin

This commit is contained in:
Ralf Jung 2022-03-06 15:26:15 -05:00
parent 9851b743c1
commit 2f97eb68a0
2 changed files with 103 additions and 59 deletions

View File

@ -371,7 +371,9 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriEvalContextExt<'mir, 'tcx
| "simd_lt"
| "simd_le"
| "simd_gt"
| "simd_ge" => {
| "simd_ge"
| "simd_fmax"
| "simd_fmin" => {
use mir::BinOp;
let &[ref left, ref right] = check_arg_count(args)?;
@ -382,23 +384,30 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriEvalContextExt<'mir, 'tcx
assert_eq!(dest_len, left_len);
assert_eq!(dest_len, right_len);
let mir_op = match intrinsic_name {
"simd_add" => BinOp::Add,
"simd_sub" => BinOp::Sub,
"simd_mul" => BinOp::Mul,
"simd_div" => BinOp::Div,
"simd_rem" => BinOp::Rem,
"simd_shl" => BinOp::Shl,
"simd_shr" => BinOp::Shr,
"simd_and" => BinOp::BitAnd,
"simd_or" => BinOp::BitOr,
"simd_xor" => BinOp::BitXor,
"simd_eq" => BinOp::Eq,
"simd_ne" => BinOp::Ne,
"simd_lt" => BinOp::Lt,
"simd_le" => BinOp::Le,
"simd_gt" => BinOp::Gt,
"simd_ge" => BinOp::Ge,
enum Op {
MirOp(BinOp),
FMax,
FMin,
}
let which = match intrinsic_name {
"simd_add" => Op::MirOp(BinOp::Add),
"simd_sub" => Op::MirOp(BinOp::Sub),
"simd_mul" => Op::MirOp(BinOp::Mul),
"simd_div" => Op::MirOp(BinOp::Div),
"simd_rem" => Op::MirOp(BinOp::Rem),
"simd_shl" => Op::MirOp(BinOp::Shl),
"simd_shr" => Op::MirOp(BinOp::Shr),
"simd_and" => Op::MirOp(BinOp::BitAnd),
"simd_or" => Op::MirOp(BinOp::BitOr),
"simd_xor" => Op::MirOp(BinOp::BitXor),
"simd_eq" => Op::MirOp(BinOp::Eq),
"simd_ne" => Op::MirOp(BinOp::Ne),
"simd_lt" => Op::MirOp(BinOp::Lt),
"simd_le" => Op::MirOp(BinOp::Le),
"simd_gt" => Op::MirOp(BinOp::Gt),
"simd_ge" => Op::MirOp(BinOp::Ge),
"simd_fmax" => Op::FMax,
"simd_fmin" => Op::FMin,
_ => unreachable!(),
};
@ -406,26 +415,38 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriEvalContextExt<'mir, 'tcx
let left = this.read_immediate(&this.mplace_index(&left, i)?.into())?;
let right = this.read_immediate(&this.mplace_index(&right, i)?.into())?;
let dest = this.mplace_index(&dest, i)?;
let (val, overflowed, ty) = this.overflowing_binary_op(mir_op, &left, &right)?;
if matches!(mir_op, BinOp::Shl | BinOp::Shr) {
// Shifts have extra UB as SIMD operations that the MIR binop does not have.
// See <https://github.com/rust-lang/rust/issues/91237>.
if overflowed {
let r_val = right.to_scalar()?.to_bits(right.layout.size)?;
throw_ub_format!("overflowing shift by {} in `{}` in SIMD lane {}", r_val, intrinsic_name, i);
let val = match which {
Op::MirOp(mir_op) => {
let (val, overflowed, ty) = this.overflowing_binary_op(mir_op, &left, &right)?;
if matches!(mir_op, BinOp::Shl | BinOp::Shr) {
// Shifts have extra UB as SIMD operations that the MIR binop does not have.
// See <https://github.com/rust-lang/rust/issues/91237>.
if overflowed {
let r_val = right.to_scalar()?.to_bits(right.layout.size)?;
throw_ub_format!("overflowing shift by {} in `{}` in SIMD lane {}", r_val, intrinsic_name, i);
}
}
if matches!(mir_op, BinOp::Eq | BinOp::Ne | BinOp::Lt | BinOp::Le | BinOp::Gt | BinOp::Ge) {
// Special handling for boolean-returning operations
assert_eq!(ty, this.tcx.types.bool);
let val = val.to_bool().unwrap();
bool_to_simd_element(val, dest.layout.size)
} else {
assert_ne!(ty, this.tcx.types.bool);
assert_eq!(ty, dest.layout.ty);
val
}
}
}
if matches!(mir_op, BinOp::Eq | BinOp::Ne | BinOp::Lt | BinOp::Le | BinOp::Gt | BinOp::Ge) {
// Special handling for boolean-returning operations
assert_eq!(ty, this.tcx.types.bool);
let val = val.to_bool().unwrap();
let val = bool_to_simd_element(val, dest.layout.size);
this.write_scalar(val, &dest.into())?;
} else {
assert_ne!(ty, this.tcx.types.bool);
assert_eq!(ty, dest.layout.ty);
this.write_scalar(val, &dest.into())?;
}
Op::FMax => {
assert!(matches!(dest.layout.ty.kind(), ty::Float(_)));
this.max_op(&left, &right)?.to_scalar()?
}
Op::FMin => {
assert!(matches!(dest.layout.ty.kind(), ty::Float(_)));
this.min_op(&left, &right)?.to_scalar()?
}
};
this.write_scalar(val, &dest.into())?;
}
}
#[rustfmt::skip]
@ -478,24 +499,10 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriEvalContextExt<'mir, 'tcx
this.binary_op(mir_op, &res, &op)?
}
Op::Max => {
// if `op > res`...
if this.binary_op(BinOp::Gt, &op, &res)?.to_scalar()?.to_bool()? {
// update accumulator
op
} else {
// no change
res
}
this.max_op(&res, &op)?
}
Op::Min => {
// if `op < res`...
if this.binary_op(BinOp::Lt, &op, &res)?.to_scalar()?.to_bool()? {
// update accumulator
op
} else {
// no change
res
}
this.min_op(&res, &op)?
}
};
}
@ -1071,4 +1078,30 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriEvalContextExt<'mir, 'tcx
_ => bug!("`float_to_int_unchecked` called with non-int output type {:?}", dest_ty),
})
}
/// Returns the larger of `left` and `right`, as decided by the interpreter's `BinOp::Gt`.
///
/// A NaN operand never compares greater than (or equal to) anything, itself included, so a
/// bare `left > right` test would hand back `right` even when `right` is NaN and `left` is
/// an ordinary number. `simd_fmax` is documented to follow IEEE-754 maxNum semantics (if
/// exactly one operand is NaN, the other operand is the result), so a NaN `right` is
/// detected with `right != right` and `left` is returned instead.
///
/// For operands that compare equal to themselves the extra check is false and the result
/// is the same as before: `left` if `left > right`, otherwise `right`.
///
/// Any error raised while evaluating the comparisons is propagated to the caller.
fn max_op(
    &self,
    left: &ImmTy<'tcx, Tag>,
    right: &ImmTy<'tcx, Tag>,
) -> InterpResult<'tcx, ImmTy<'tcx, Tag>> {
    let this = self.eval_context_ref();
    let left_is_greater = this.binary_op(BinOp::Gt, left, right)?.to_scalar()?.to_bool()?;
    if left_is_greater {
        return Ok(*left);
    }
    // `x != x` holds only for NaN. If `right` is NaN, prefer `left`: it is either the
    // non-NaN operand, or also NaN (in which case NaN is the correct result anyway).
    let right_is_nan = this.binary_op(BinOp::Ne, right, right)?.to_scalar()?.to_bool()?;
    Ok(if right_is_nan { *left } else { *right })
}
/// Returns the smaller of `left` and `right`, as decided by the interpreter's `BinOp::Lt`.
///
/// A NaN operand never compares less than (or equal to) anything, itself included, so a
/// bare `left < right` test would hand back `right` even when `right` is NaN and `left` is
/// an ordinary number. `simd_fmin` is documented to follow IEEE-754 minNum semantics (if
/// exactly one operand is NaN, the other operand is the result), so a NaN `right` is
/// detected with `right != right` and `left` is returned instead.
///
/// For operands that compare equal to themselves the extra check is false and the result
/// is the same as before: `left` if `left < right`, otherwise `right`.
///
/// Any error raised while evaluating the comparisons is propagated to the caller.
fn min_op(
    &self,
    left: &ImmTy<'tcx, Tag>,
    right: &ImmTy<'tcx, Tag>,
) -> InterpResult<'tcx, ImmTy<'tcx, Tag>> {
    let this = self.eval_context_ref();
    let left_is_less = this.binary_op(BinOp::Lt, left, right)?.to_scalar()?.to_bool()?;
    if left_is_less {
        return Ok(*left);
    }
    // `x != x` holds only for NaN. If `right` is NaN, prefer `left`: it is either the
    // non-NaN operand, or also NaN (in which case NaN is the correct result anyway).
    let right_is_nan = this.binary_op(BinOp::Ne, right, right)?.to_scalar()?.to_bool()?;
    Ok(if right_is_nan { *left } else { *right })
}
}

View File

@ -12,6 +12,8 @@ fn simd_ops_f32() {
assert_eq!(a / f32x4::splat(2.0), f32x4::splat(5.0));
assert_eq!(a % b, f32x4::from_array([0.0, 0.0, 1.0, 2.0]));
assert_eq!(b.abs(), f32x4::from_array([1.0, 2.0, 3.0, 4.0]));
assert_eq!(a.max(b * f32x4::splat(4.0)), f32x4::from_array([10.0, 10.0, 12.0, 10.0]));
assert_eq!(a.min(b * f32x4::splat(4.0)), f32x4::from_array([4.0, 8.0, 10.0, -16.0]));
assert_eq!(a.lanes_eq(f32x4::splat(5.0) * b), Mask::from_array([false, true, false, false]));
assert_eq!(a.lanes_ne(f32x4::splat(5.0) * b), Mask::from_array([true, false, true, true]));
@ -41,6 +43,8 @@ fn simd_ops_f64() {
assert_eq!(a / f64x4::splat(2.0), f64x4::splat(5.0));
assert_eq!(a % b, f64x4::from_array([0.0, 0.0, 1.0, 2.0]));
assert_eq!(b.abs(), f64x4::from_array([1.0, 2.0, 3.0, 4.0]));
assert_eq!(a.max(b * f64x4::splat(4.0)), f64x4::from_array([10.0, 10.0, 12.0, 10.0]));
assert_eq!(a.min(b * f64x4::splat(4.0)), f64x4::from_array([4.0, 8.0, 10.0, -16.0]));
assert_eq!(a.lanes_eq(f64x4::splat(5.0) * b), Mask::from_array([false, true, false, false]));
assert_eq!(a.lanes_ne(f64x4::splat(5.0) * b), Mask::from_array([true, false, true, true]));
@ -71,6 +75,12 @@ fn simd_ops_i32() {
assert_eq!(i32x2::splat(i32::MIN) / i32x2::splat(-1), i32x2::splat(i32::MIN));
assert_eq!(a % b, i32x4::from_array([0, 0, 1, 2]));
assert_eq!(i32x2::splat(i32::MIN) % i32x2::splat(-1), i32x2::splat(0));
assert_eq!(b.abs(), i32x4::from_array([1, 2, 3, 4]));
// FIXME not a per-lane method (https://github.com/rust-lang/rust/issues/94682)
// assert_eq!(a.max(b * i32x4::splat(4)), i32x4::from_array([10, 10, 12, 10]));
// assert_eq!(a.min(b * i32x4::splat(4)), i32x4::from_array([4, 8, 10, -16]));
assert_eq!(!b, i32x4::from_array([!1, !2, !3, !-4]));
assert_eq!(b << i32x4::splat(2), i32x4::from_array([4, 8, 12, -16]));
assert_eq!(b >> i32x4::splat(1), i32x4::from_array([0, 1, 1, -2]));
assert_eq!(b & i32x4::splat(2), i32x4::from_array([0, 2, 2, 0]));
@ -84,12 +94,6 @@ fn simd_ops_i32() {
assert_eq!(a.lanes_ge(i32x4::splat(5) * b), Mask::from_array([true, true, false, true]));
assert_eq!(a.lanes_gt(i32x4::splat(5) * b), Mask::from_array([true, false, false, true]));
assert_eq!(a.horizontal_and(), 10);
assert_eq!(b.horizontal_and(), 0);
assert_eq!(a.horizontal_or(), 10);
assert_eq!(b.horizontal_or(), -1);
assert_eq!(a.horizontal_xor(), 0);
assert_eq!(b.horizontal_xor(), -4);
assert_eq!(a.horizontal_sum(), 40);
assert_eq!(b.horizontal_sum(), 2);
assert_eq!(a.horizontal_product(), 100 * 100);
@ -98,6 +102,13 @@ fn simd_ops_i32() {
assert_eq!(b.horizontal_max(), 3);
assert_eq!(a.horizontal_min(), 10);
assert_eq!(b.horizontal_min(), -4);
assert_eq!(a.horizontal_and(), 10);
assert_eq!(b.horizontal_and(), 0);
assert_eq!(a.horizontal_or(), 10);
assert_eq!(b.horizontal_or(), -1);
assert_eq!(a.horizontal_xor(), 0);
assert_eq!(b.horizontal_xor(), -4);
}
fn simd_mask() {