fix handling of NaNs in simd max/min

This commit is contained in:
Ralf Jung 2022-03-06 16:54:51 -05:00
parent 2f97eb68a0
commit b87a9c90e1
3 changed files with 81 additions and 34 deletions

View File

@ -1 +1 @@
8876ca3dd46b99fe7e6ad937f11493d37996231e
297273c45b205820a4c055082c71677197a40b55

View File

@ -345,7 +345,6 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriEvalContextExt<'mir, 'tcx
bug!("simd_fabs operand is not a float")
};
let op = op.to_scalar()?;
// FIXME: Using host floats.
match float_ty {
FloatTy::F32 => Scalar::from_f32(op.to_f32()?.abs()),
FloatTy::F64 => Scalar::from_f64(op.to_f64()?.abs()),
@ -438,12 +437,10 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriEvalContextExt<'mir, 'tcx
}
}
Op::FMax => {
assert!(matches!(dest.layout.ty.kind(), ty::Float(_)));
this.max_op(&left, &right)?.to_scalar()?
fmax_op(&left, &right)?
}
Op::FMin => {
assert!(matches!(dest.layout.ty.kind(), ty::Float(_)));
this.min_op(&left, &right)?.to_scalar()?
fmin_op(&left, &right)?
}
};
this.write_scalar(val, &dest.into())?;
@ -499,10 +496,28 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriEvalContextExt<'mir, 'tcx
this.binary_op(mir_op, &res, &op)?
}
Op::Max => {
this.max_op(&res, &op)?
if matches!(res.layout.ty.kind(), ty::Float(_)) {
ImmTy::from_scalar(fmax_op(&res, &op)?, res.layout)
} else {
// Just boring integers, so NaNs to worry about
if this.binary_op(BinOp::Ge, &res, &op)?.to_scalar()?.to_bool()? {
res
} else {
op
}
}
}
Op::Min => {
this.min_op(&res, &op)?
if matches!(res.layout.ty.kind(), ty::Float(_)) {
ImmTy::from_scalar(fmin_op(&res, &op)?, res.layout)
} else {
// Just boring integers, so NaNs to worry about
if this.binary_op(BinOp::Le, &res, &op)?.to_scalar()?.to_bool()? {
res
} else {
op
}
}
}
};
}
@ -1078,30 +1093,36 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriEvalContextExt<'mir, 'tcx
_ => bug!("`float_to_int_unchecked` called with non-int output type {:?}", dest_ty),
})
}
fn max_op(
&self,
left: &ImmTy<'tcx, Tag>,
right: &ImmTy<'tcx, Tag>,
) -> InterpResult<'tcx, ImmTy<'tcx, Tag>> {
let this = self.eval_context_ref();
Ok(if this.binary_op(BinOp::Gt, left, right)?.to_scalar()?.to_bool()? {
*left
} else {
*right
})
}
fn min_op(
&self,
left: &ImmTy<'tcx, Tag>,
right: &ImmTy<'tcx, Tag>,
) -> InterpResult<'tcx, ImmTy<'tcx, Tag>> {
let this = self.eval_context_ref();
Ok(if this.binary_op(BinOp::Lt, left, right)?.to_scalar()?.to_bool()? {
*left
} else {
*right
})
}
}
fn fmax_op<'tcx>(
left: &ImmTy<'tcx, Tag>,
right: &ImmTy<'tcx, Tag>,
) -> InterpResult<'tcx, Scalar<Tag>> {
assert_eq!(left.layout.ty, right.layout.ty);
let ty::Float(float_ty) = left.layout.ty.kind() else {
bug!("fmax operand is not a float")
};
let left = left.to_scalar()?;
let right = right.to_scalar()?;
Ok(match float_ty {
FloatTy::F32 => Scalar::from_f32(left.to_f32()?.max(right.to_f32()?)),
FloatTy::F64 => Scalar::from_f64(left.to_f64()?.max(right.to_f64()?)),
})
}
fn fmin_op<'tcx>(
left: &ImmTy<'tcx, Tag>,
right: &ImmTy<'tcx, Tag>,
) -> InterpResult<'tcx, Scalar<Tag>> {
assert_eq!(left.layout.ty, right.layout.ty);
let ty::Float(float_ty) = left.layout.ty.kind() else {
bug!("fmin operand is not a float")
};
let left = left.to_scalar()?;
let right = right.to_scalar()?;
Ok(match float_ty {
FloatTy::F32 => Scalar::from_f32(left.to_f32()?.min(right.to_f32()?)),
FloatTy::F64 => Scalar::from_f64(left.to_f64()?.min(right.to_f64()?)),
})
}

View File

@ -30,6 +30,19 @@ fn simd_ops_f32() {
assert_eq!(b.horizontal_max(), 3.0);
assert_eq!(a.horizontal_min(), 10.0);
assert_eq!(b.horizontal_min(), -4.0);
assert_eq!(
f32x2::from_array([0.0, f32::NAN]).max(f32x2::from_array([f32::NAN, 0.0])),
f32x2::from_array([0.0, 0.0])
);
assert_eq!(f32x2::from_array([0.0, f32::NAN]).horizontal_max(), 0.0);
assert_eq!(f32x2::from_array([f32::NAN, 0.0]).horizontal_max(), 0.0);
assert_eq!(
f32x2::from_array([0.0, f32::NAN]).min(f32x2::from_array([f32::NAN, 0.0])),
f32x2::from_array([0.0, 0.0])
);
assert_eq!(f32x2::from_array([0.0, f32::NAN]).horizontal_min(), 0.0);
assert_eq!(f32x2::from_array([f32::NAN, 0.0]).horizontal_min(), 0.0);
}
fn simd_ops_f64() {
@ -61,6 +74,19 @@ fn simd_ops_f64() {
assert_eq!(b.horizontal_max(), 3.0);
assert_eq!(a.horizontal_min(), 10.0);
assert_eq!(b.horizontal_min(), -4.0);
assert_eq!(
f64x2::from_array([0.0, f64::NAN]).max(f64x2::from_array([f64::NAN, 0.0])),
f64x2::from_array([0.0, 0.0])
);
assert_eq!(f64x2::from_array([0.0, f64::NAN]).horizontal_max(), 0.0);
assert_eq!(f64x2::from_array([f64::NAN, 0.0]).horizontal_max(), 0.0);
assert_eq!(
f64x2::from_array([0.0, f64::NAN]).min(f64x2::from_array([f64::NAN, 0.0])),
f64x2::from_array([0.0, 0.0])
);
assert_eq!(f64x2::from_array([0.0, f64::NAN]).horizontal_min(), 0.0);
assert_eq!(f64x2::from_array([f64::NAN, 0.0]).horizontal_min(), 0.0);
}
fn simd_ops_i32() {