diff --git a/src/tools/miri/src/shims/intrinsics/simd.rs b/src/tools/miri/src/shims/intrinsics/simd.rs index 200f37efa27..70f90aac2c2 100644 --- a/src/tools/miri/src/shims/intrinsics/simd.rs +++ b/src/tools/miri/src/shims/intrinsics/simd.rs @@ -32,28 +32,21 @@ fn emulate_simd_intrinsic( assert_eq!(dest_len, op_len); - #[derive(Copy, Clone)] - enum HostFloatOp { - Ceil, - Floor, - Round, - Trunc, - Sqrt, - } #[derive(Copy, Clone)] enum Op { MirOp(mir::UnOp), Abs, - HostOp(HostFloatOp), + Sqrt, + Round(rustc_apfloat::Round), } let which = match intrinsic_name { "neg" => Op::MirOp(mir::UnOp::Neg), "fabs" => Op::Abs, - "ceil" => Op::HostOp(HostFloatOp::Ceil), - "floor" => Op::HostOp(HostFloatOp::Floor), - "round" => Op::HostOp(HostFloatOp::Round), - "trunc" => Op::HostOp(HostFloatOp::Trunc), - "fsqrt" => Op::HostOp(HostFloatOp::Sqrt), + "fsqrt" => Op::Sqrt, + "ceil" => Op::Round(rustc_apfloat::Round::TowardPositive), + "floor" => Op::Round(rustc_apfloat::Round::TowardNegative), + "round" => Op::Round(rustc_apfloat::Round::NearestTiesToAway), + "trunc" => Op::Round(rustc_apfloat::Round::TowardZero), _ => unreachable!(), }; @@ -73,7 +66,7 @@ enum Op { FloatTy::F64 => Scalar::from_f64(op.to_f64()?.abs()), } } - Op::HostOp(host_op) => { + Op::Sqrt => { let ty::Float(float_ty) = op.layout.ty.kind() else { span_bug!(this.cur_span(), "{} operand is not a float", intrinsic_name) }; @@ -81,28 +74,32 @@ enum Op { match float_ty { FloatTy::F32 => { let f = f32::from_bits(op.to_scalar().to_u32()?); - let res = match host_op { - HostFloatOp::Ceil => f.ceil(), - HostFloatOp::Floor => f.floor(), - HostFloatOp::Round => f.round(), - HostFloatOp::Trunc => f.trunc(), - HostFloatOp::Sqrt => f.sqrt(), - }; + let res = f.sqrt(); Scalar::from_u32(res.to_bits()) } FloatTy::F64 => { let f = f64::from_bits(op.to_scalar().to_u64()?); - let res = match host_op { - HostFloatOp::Ceil => f.ceil(), - HostFloatOp::Floor => f.floor(), - HostFloatOp::Round => f.round(), - HostFloatOp::Trunc => f.trunc(), - HostFloatOp::Sqrt => f.sqrt(), - }; + let res = f.sqrt(); Scalar::from_u64(res.to_bits()) } } - + } + Op::Round(rounding) => { + let ty::Float(float_ty) = op.layout.ty.kind() else { + span_bug!(this.cur_span(), "{} operand is not a float", intrinsic_name) + }; + match float_ty { + FloatTy::F32 => { + let f = op.to_scalar().to_f32()?; + let res = f.round_to_integral(rounding).value; + Scalar::from_f32(res) + } + FloatTy::F64 => { + let f = op.to_scalar().to_f64()?; + let res = f.round_to_integral(rounding).value; + Scalar::from_f64(res) + } + } } }; this.write_scalar(val, &dest)?;