From a8aa303cf0107afaabaf551a6c7c00835244e70f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eduardo=20S=C3=A1nchez=20Mu=C3=B1oz?= Date: Mon, 16 Oct 2023 18:40:22 +0200 Subject: [PATCH] Fix rounding mode check in SSE4.1 round functions Now it masks out the correct bit and adds some explanatory comments. Also extends the tests. --- src/tools/miri/src/shims/x86/sse41.rs | 19 +++++-- .../miri/tests/pass/intrinsics-x86-sse41.rs | 50 +++++++++++++++++++ 2 files changed, 64 insertions(+), 5 deletions(-) diff --git a/src/tools/miri/src/shims/x86/sse41.rs b/src/tools/miri/src/shims/x86/sse41.rs index 1c8100ecc65..cfa06ded6e6 100644 --- a/src/tools/miri/src/shims/x86/sse41.rs +++ b/src/tools/miri/src/shims/x86/sse41.rs @@ -283,11 +283,20 @@ fn round_first<'tcx, F: rustc_apfloat::Float>( assert_eq!(dest_len, left_len); assert_eq!(dest_len, right_len); - let rounding = match this.read_scalar(rounding)?.to_i32()? & !0x80 { - 0x00 => rustc_apfloat::Round::NearestTiesToEven, - 0x01 => rustc_apfloat::Round::TowardNegative, - 0x02 => rustc_apfloat::Round::TowardPositive, - 0x03 => rustc_apfloat::Round::TowardZero, + // The fourth bit of `rounding` only affects the SSE status + // register, which cannot be accessed from Miri (or from Rust, + // for that matter), so we can ignore it. + let rounding = match this.read_scalar(rounding)?.to_i32()? & !0b1000 { + // When the third bit is 0, the rounding mode is determined by the + // first two bits. + 0b000 => rustc_apfloat::Round::NearestTiesToEven, + 0b001 => rustc_apfloat::Round::TowardNegative, + 0b010 => rustc_apfloat::Round::TowardPositive, + 0b011 => rustc_apfloat::Round::TowardZero, + // When the third bit is 1, the rounding mode is determined by the + // SSE status register. Since we do not support modifying it from + // Miri (or Rust), we assume it to be at its default mode (round-to-nearest). + 0b100..=0b111 => rustc_apfloat::Round::NearestTiesToEven, rounding => throw_unsup_format!("unsupported rounding mode 0x{rounding:02x}"), }; diff --git a/src/tools/miri/tests/pass/intrinsics-x86-sse41.rs b/src/tools/miri/tests/pass/intrinsics-x86-sse41.rs index 01b915f1810..d5489ffaf4b 100644 --- a/src/tools/miri/tests/pass/intrinsics-x86-sse41.rs +++ b/src/tools/miri/tests/pass/intrinsics-x86-sse41.rs @@ -119,6 +119,31 @@ unsafe fn test_mm_round_sd() { let r = _mm_round_sd::<_MM_FROUND_TO_NEAREST_INT>(a, b); let e = _mm_setr_pd(-2.0, 3.5); assert_eq_m128d(r, e); + + let a = _mm_setr_pd(1.5, 3.5); + let b = _mm_setr_pd(-2.5, -4.5); + let r = _mm_round_sd::<_MM_FROUND_TO_NEG_INF>(a, b); + let e = _mm_setr_pd(-3.0, 3.5); + assert_eq_m128d(r, e); + + let a = _mm_setr_pd(1.5, 3.5); + let b = _mm_setr_pd(-2.5, -4.5); + let r = _mm_round_sd::<_MM_FROUND_TO_POS_INF>(a, b); + let e = _mm_setr_pd(-2.0, 3.5); + assert_eq_m128d(r, e); + + let a = _mm_setr_pd(1.5, 3.5); + let b = _mm_setr_pd(-2.5, -4.5); + let r = _mm_round_sd::<_MM_FROUND_TO_ZERO>(a, b); + let e = _mm_setr_pd(-2.0, 3.5); + assert_eq_m128d(r, e); + + // Assume round-to-nearest by default + let a = _mm_setr_pd(1.5, 3.5); + let b = _mm_setr_pd(-2.5, -4.5); + let r = _mm_round_sd::<_MM_FROUND_CUR_DIRECTION>(a, b); + let e = _mm_setr_pd(-2.0, 3.5); + assert_eq_m128d(r, e); } test_mm_round_sd(); @@ -129,6 +154,31 @@ unsafe fn test_mm_round_ss() { let r = _mm_round_ss::<_MM_FROUND_TO_NEAREST_INT>(a, b); let e = _mm_setr_ps(-2.0, 3.5, 7.5, 15.5); assert_eq_m128(r, e); + + let a = _mm_setr_ps(1.5, 3.5, 7.5, 15.5); + let b = _mm_setr_ps(-1.75, -4.5, -8.5, -16.5); + let r = _mm_round_ss::<_MM_FROUND_TO_NEG_INF>(a, b); + let e = _mm_setr_ps(-2.0, 3.5, 7.5, 15.5); + assert_eq_m128(r, e); + + let a = _mm_setr_ps(1.5, 3.5, 7.5, 15.5); + let b = _mm_setr_ps(-1.75, -4.5, -8.5, -16.5); + let r = _mm_round_ss::<_MM_FROUND_TO_POS_INF>(a, b); + let e = _mm_setr_ps(-1.0, 3.5, 7.5, 15.5); + assert_eq_m128(r, e); + + let a = _mm_setr_ps(1.5, 3.5, 7.5, 15.5); + let b = _mm_setr_ps(-1.75, -4.5, -8.5, -16.5); + let r = _mm_round_ss::<_MM_FROUND_TO_ZERO>(a, b); + let e = _mm_setr_ps(-1.0, 3.5, 7.5, 15.5); + assert_eq_m128(r, e); + + // Assume round-to-nearest by default + let a = _mm_setr_ps(1.5, 3.5, 7.5, 15.5); + let b = _mm_setr_ps(-1.75, -4.5, -8.5, -16.5); + let r = _mm_round_ss::<_MM_FROUND_CUR_DIRECTION>(a, b); + let e = _mm_setr_ps(-2.0, 3.5, 7.5, 15.5); + assert_eq_m128(r, e); } test_mm_round_ss();