Auto merge of #3124 - eduardosm:fix-sse41-round, r=RalfJung

Fix rounding mode check in SSE4.1 round functions

Now it masks out the correct bit and adds some explanatory comments. Also extends the tests.
This commit is contained in:
bors 2023-10-17 15:24:26 +00:00
commit 2366a90d3f
2 changed files with 64 additions and 5 deletions

View File

@ -283,11 +283,20 @@ fn round_first<'tcx, F: rustc_apfloat::Float>(
assert_eq!(dest_len, left_len); assert_eq!(dest_len, left_len);
assert_eq!(dest_len, right_len); assert_eq!(dest_len, right_len);
let rounding = match this.read_scalar(rounding)?.to_i32()? & !0x80 { // The fourth bit of `rounding` only affects the SSE status
0x00 => rustc_apfloat::Round::NearestTiesToEven, // register, which cannot be accessed from Miri (or from Rust,
0x01 => rustc_apfloat::Round::TowardNegative, // for that matter), so we can ignore it.
0x02 => rustc_apfloat::Round::TowardPositive, let rounding = match this.read_scalar(rounding)?.to_i32()? & !0b1000 {
0x03 => rustc_apfloat::Round::TowardZero, // When the third bit is 0, the rounding mode is determined by the
// first two bits.
0b000 => rustc_apfloat::Round::NearestTiesToEven,
0b001 => rustc_apfloat::Round::TowardNegative,
0b010 => rustc_apfloat::Round::TowardPositive,
0b011 => rustc_apfloat::Round::TowardZero,
// When the third bit is 1, the rounding mode is determined by the
// SSE status register. Since we do not support modifying it from
// Miri (or Rust), we assume it to be at its default mode (round-to-nearest).
0b100..=0b111 => rustc_apfloat::Round::NearestTiesToEven,
rounding => throw_unsup_format!("unsupported rounding mode 0x{rounding:02x}"), rounding => throw_unsup_format!("unsupported rounding mode 0x{rounding:02x}"),
}; };

View File

@ -119,6 +119,31 @@ unsafe fn test_mm_round_sd() {
let r = _mm_round_sd::<_MM_FROUND_TO_NEAREST_INT>(a, b); let r = _mm_round_sd::<_MM_FROUND_TO_NEAREST_INT>(a, b);
let e = _mm_setr_pd(-2.0, 3.5); let e = _mm_setr_pd(-2.0, 3.5);
assert_eq_m128d(r, e); assert_eq_m128d(r, e);
let a = _mm_setr_pd(1.5, 3.5);
let b = _mm_setr_pd(-2.5, -4.5);
let r = _mm_round_sd::<_MM_FROUND_TO_NEG_INF>(a, b);
let e = _mm_setr_pd(-3.0, 3.5);
assert_eq_m128d(r, e);
let a = _mm_setr_pd(1.5, 3.5);
let b = _mm_setr_pd(-2.5, -4.5);
let r = _mm_round_sd::<_MM_FROUND_TO_POS_INF>(a, b);
let e = _mm_setr_pd(-2.0, 3.5);
assert_eq_m128d(r, e);
let a = _mm_setr_pd(1.5, 3.5);
let b = _mm_setr_pd(-2.5, -4.5);
let r = _mm_round_sd::<_MM_FROUND_TO_ZERO>(a, b);
let e = _mm_setr_pd(-2.0, 3.5);
assert_eq_m128d(r, e);
// Assume round-to-nearest by default
let a = _mm_setr_pd(1.5, 3.5);
let b = _mm_setr_pd(-2.5, -4.5);
let r = _mm_round_sd::<_MM_FROUND_CUR_DIRECTION>(a, b);
let e = _mm_setr_pd(-2.0, 3.5);
assert_eq_m128d(r, e);
} }
test_mm_round_sd(); test_mm_round_sd();
@ -129,6 +154,31 @@ unsafe fn test_mm_round_ss() {
let r = _mm_round_ss::<_MM_FROUND_TO_NEAREST_INT>(a, b); let r = _mm_round_ss::<_MM_FROUND_TO_NEAREST_INT>(a, b);
let e = _mm_setr_ps(-2.0, 3.5, 7.5, 15.5); let e = _mm_setr_ps(-2.0, 3.5, 7.5, 15.5);
assert_eq_m128(r, e); assert_eq_m128(r, e);
let a = _mm_setr_ps(1.5, 3.5, 7.5, 15.5);
let b = _mm_setr_ps(-1.75, -4.5, -8.5, -16.5);
let r = _mm_round_ss::<_MM_FROUND_TO_NEG_INF>(a, b);
let e = _mm_setr_ps(-2.0, 3.5, 7.5, 15.5);
assert_eq_m128(r, e);
let a = _mm_setr_ps(1.5, 3.5, 7.5, 15.5);
let b = _mm_setr_ps(-1.75, -4.5, -8.5, -16.5);
let r = _mm_round_ss::<_MM_FROUND_TO_POS_INF>(a, b);
let e = _mm_setr_ps(-1.0, 3.5, 7.5, 15.5);
assert_eq_m128(r, e);
let a = _mm_setr_ps(1.5, 3.5, 7.5, 15.5);
let b = _mm_setr_ps(-1.75, -4.5, -8.5, -16.5);
let r = _mm_round_ss::<_MM_FROUND_TO_ZERO>(a, b);
let e = _mm_setr_ps(-1.0, 3.5, 7.5, 15.5);
assert_eq_m128(r, e);
// Assume round-to-nearest by default
let a = _mm_setr_ps(1.5, 3.5, 7.5, 15.5);
let b = _mm_setr_ps(-1.75, -4.5, -8.5, -16.5);
let r = _mm_round_ss::<_MM_FROUND_CUR_DIRECTION>(a, b);
let e = _mm_setr_ps(-2.0, 3.5, 7.5, 15.5);
assert_eq_m128(r, e);
} }
test_mm_round_ss(); test_mm_round_ss();