diff --git a/rust-version b/rust-version index 109a8080b07..a769188204f 100644 --- a/rust-version +++ b/rust-version @@ -1 +1 @@ -8876ca3dd46b99fe7e6ad937f11493d37996231e +297273c45b205820a4c055082c71677197a40b55 diff --git a/src/shims/intrinsics.rs b/src/shims/intrinsics.rs index 7dc4a000d1e..897ebe4ae79 100644 --- a/src/shims/intrinsics.rs +++ b/src/shims/intrinsics.rs @@ -345,7 +345,6 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriEvalContextExt<'mir, 'tcx bug!("simd_fabs operand is not a float") }; let op = op.to_scalar()?; - // FIXME: Using host floats. match float_ty { FloatTy::F32 => Scalar::from_f32(op.to_f32()?.abs()), FloatTy::F64 => Scalar::from_f64(op.to_f64()?.abs()), @@ -438,12 +437,10 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriEvalContextExt<'mir, 'tcx } } Op::FMax => { - assert!(matches!(dest.layout.ty.kind(), ty::Float(_))); - this.max_op(&left, &right)?.to_scalar()? + fmax_op(&left, &right)? } Op::FMin => { - assert!(matches!(dest.layout.ty.kind(), ty::Float(_))); - this.min_op(&left, &right)?.to_scalar()? + fmin_op(&left, &right)? } }; this.write_scalar(val, &dest.into())?; @@ -499,10 +496,28 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriEvalContextExt<'mir, 'tcx this.binary_op(mir_op, &res, &op)? } Op::Max => { - this.max_op(&res, &op)? + if matches!(res.layout.ty.kind(), ty::Float(_)) { + ImmTy::from_scalar(fmax_op(&res, &op)?, res.layout) + } else { + // Just boring integers, so NaNs to worry about + if this.binary_op(BinOp::Ge, &res, &op)?.to_scalar()?.to_bool()? { + res + } else { + op + } + } } Op::Min => { - this.min_op(&res, &op)? + if matches!(res.layout.ty.kind(), ty::Float(_)) { + ImmTy::from_scalar(fmin_op(&res, &op)?, res.layout) + } else { + // Just boring integers, so NaNs to worry about + if this.binary_op(BinOp::Le, &res, &op)?.to_scalar()?.to_bool()? { + res + } else { + op + } + } } }; } @@ -1078,30 +1093,36 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriEvalContextExt<'mir, 'tcx _ => bug!("`float_to_int_unchecked` called with non-int output type {:?}", dest_ty), }) } - - fn max_op( - &self, - left: &ImmTy<'tcx, Tag>, - right: &ImmTy<'tcx, Tag>, - ) -> InterpResult<'tcx, ImmTy<'tcx, Tag>> { - let this = self.eval_context_ref(); - Ok(if this.binary_op(BinOp::Gt, left, right)?.to_scalar()?.to_bool()? { - *left - } else { - *right - }) - } - - fn min_op( - &self, - left: &ImmTy<'tcx, Tag>, - right: &ImmTy<'tcx, Tag>, - ) -> InterpResult<'tcx, ImmTy<'tcx, Tag>> { - let this = self.eval_context_ref(); - Ok(if this.binary_op(BinOp::Lt, left, right)?.to_scalar()?.to_bool()? { - *left - } else { - *right - }) - } +} + +fn fmax_op<'tcx>( + left: &ImmTy<'tcx, Tag>, + right: &ImmTy<'tcx, Tag>, +) -> InterpResult<'tcx, Scalar> { + assert_eq!(left.layout.ty, right.layout.ty); + let ty::Float(float_ty) = left.layout.ty.kind() else { + bug!("fmax operand is not a float") + }; + let left = left.to_scalar()?; + let right = right.to_scalar()?; + Ok(match float_ty { + FloatTy::F32 => Scalar::from_f32(left.to_f32()?.max(right.to_f32()?)), + FloatTy::F64 => Scalar::from_f64(left.to_f64()?.max(right.to_f64()?)), + }) +} + +fn fmin_op<'tcx>( + left: &ImmTy<'tcx, Tag>, + right: &ImmTy<'tcx, Tag>, +) -> InterpResult<'tcx, Scalar> { + assert_eq!(left.layout.ty, right.layout.ty); + let ty::Float(float_ty) = left.layout.ty.kind() else { + bug!("fmin operand is not a float") + }; + let left = left.to_scalar()?; + let right = right.to_scalar()?; + Ok(match float_ty { + FloatTy::F32 => Scalar::from_f32(left.to_f32()?.min(right.to_f32()?)), + FloatTy::F64 => Scalar::from_f64(left.to_f64()?.min(right.to_f64()?)), + }) } diff --git a/tests/run-pass/portable-simd.rs b/tests/run-pass/portable-simd.rs index 817d18a45d4..48297ee4e69 100644 --- a/tests/run-pass/portable-simd.rs +++ b/tests/run-pass/portable-simd.rs @@ -30,6 +30,19 @@ fn simd_ops_f32() { assert_eq!(b.horizontal_max(), 3.0); assert_eq!(a.horizontal_min(), 10.0); assert_eq!(b.horizontal_min(), -4.0); + + assert_eq!( + f32x2::from_array([0.0, f32::NAN]).max(f32x2::from_array([f32::NAN, 0.0])), + f32x2::from_array([0.0, 0.0]) + ); + assert_eq!(f32x2::from_array([0.0, f32::NAN]).horizontal_max(), 0.0); + assert_eq!(f32x2::from_array([f32::NAN, 0.0]).horizontal_max(), 0.0); + assert_eq!( + f32x2::from_array([0.0, f32::NAN]).min(f32x2::from_array([f32::NAN, 0.0])), + f32x2::from_array([0.0, 0.0]) + ); + assert_eq!(f32x2::from_array([0.0, f32::NAN]).horizontal_min(), 0.0); + assert_eq!(f32x2::from_array([f32::NAN, 0.0]).horizontal_min(), 0.0); } fn simd_ops_f64() { @@ -61,6 +74,19 @@ fn simd_ops_f64() { assert_eq!(b.horizontal_max(), 3.0); assert_eq!(a.horizontal_min(), 10.0); assert_eq!(b.horizontal_min(), -4.0); + + assert_eq!( + f64x2::from_array([0.0, f64::NAN]).max(f64x2::from_array([f64::NAN, 0.0])), + f64x2::from_array([0.0, 0.0]) + ); + assert_eq!(f64x2::from_array([0.0, f64::NAN]).horizontal_max(), 0.0); + assert_eq!(f64x2::from_array([f64::NAN, 0.0]).horizontal_max(), 0.0); + assert_eq!( + f64x2::from_array([0.0, f64::NAN]).min(f64x2::from_array([f64::NAN, 0.0])), + f64x2::from_array([0.0, 0.0]) + ); + assert_eq!(f64x2::from_array([0.0, f64::NAN]).horizontal_min(), 0.0); + assert_eq!(f64x2::from_array([f64::NAN, 0.0]).horizontal_min(), 0.0); } fn simd_ops_i32() {