From 2f97eb68a0f77d3829151bc57855d42535465a6d Mon Sep 17 00:00:00 2001 From: Ralf Jung Date: Sun, 6 Mar 2022 15:26:15 -0500 Subject: [PATCH] implement simd_fmax/fmin --- src/shims/intrinsics.rs | 139 ++++++++++++++++++++------------ tests/run-pass/portable-simd.rs | 23 ++++-- 2 files changed, 103 insertions(+), 59 deletions(-) diff --git a/src/shims/intrinsics.rs b/src/shims/intrinsics.rs index 6f168536980..7dc4a000d1e 100644 --- a/src/shims/intrinsics.rs +++ b/src/shims/intrinsics.rs @@ -371,7 +371,9 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriEvalContextExt<'mir, 'tcx | "simd_lt" | "simd_le" | "simd_gt" - | "simd_ge" => { + | "simd_ge" + | "simd_fmax" + | "simd_fmin" => { use mir::BinOp; let &[ref left, ref right] = check_arg_count(args)?; @@ -382,23 +384,30 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriEvalContextExt<'mir, 'tcx assert_eq!(dest_len, left_len); assert_eq!(dest_len, right_len); - let mir_op = match intrinsic_name { - "simd_add" => BinOp::Add, - "simd_sub" => BinOp::Sub, - "simd_mul" => BinOp::Mul, - "simd_div" => BinOp::Div, - "simd_rem" => BinOp::Rem, - "simd_shl" => BinOp::Shl, - "simd_shr" => BinOp::Shr, - "simd_and" => BinOp::BitAnd, - "simd_or" => BinOp::BitOr, - "simd_xor" => BinOp::BitXor, - "simd_eq" => BinOp::Eq, - "simd_ne" => BinOp::Ne, - "simd_lt" => BinOp::Lt, - "simd_le" => BinOp::Le, - "simd_gt" => BinOp::Gt, - "simd_ge" => BinOp::Ge, + enum Op { + MirOp(BinOp), + FMax, + FMin, + } + let which = match intrinsic_name { + "simd_add" => Op::MirOp(BinOp::Add), + "simd_sub" => Op::MirOp(BinOp::Sub), + "simd_mul" => Op::MirOp(BinOp::Mul), + "simd_div" => Op::MirOp(BinOp::Div), + "simd_rem" => Op::MirOp(BinOp::Rem), + "simd_shl" => Op::MirOp(BinOp::Shl), + "simd_shr" => Op::MirOp(BinOp::Shr), + "simd_and" => Op::MirOp(BinOp::BitAnd), + "simd_or" => Op::MirOp(BinOp::BitOr), + "simd_xor" => Op::MirOp(BinOp::BitXor), + "simd_eq" => Op::MirOp(BinOp::Eq), + "simd_ne" => Op::MirOp(BinOp::Ne), + "simd_lt" => 
Op::MirOp(BinOp::Lt), + "simd_le" => Op::MirOp(BinOp::Le), + "simd_gt" => Op::MirOp(BinOp::Gt), + "simd_ge" => Op::MirOp(BinOp::Ge), + "simd_fmax" => Op::FMax, + "simd_fmin" => Op::FMin, _ => unreachable!(), }; @@ -406,26 +415,38 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriEvalContextExt<'mir, 'tcx let left = this.read_immediate(&this.mplace_index(&left, i)?.into())?; let right = this.read_immediate(&this.mplace_index(&right, i)?.into())?; let dest = this.mplace_index(&dest, i)?; - let (val, overflowed, ty) = this.overflowing_binary_op(mir_op, &left, &right)?; - if matches!(mir_op, BinOp::Shl | BinOp::Shr) { - // Shifts have extra UB as SIMD operations that the MIR binop does not have. - // See <https://github.com/rust-lang/rust/issues/91237>. - if overflowed { - let r_val = right.to_scalar()?.to_bits(right.layout.size)?; - throw_ub_format!("overflowing shift by {} in `{}` in SIMD lane {}", r_val, intrinsic_name, i); + let val = match which { + Op::MirOp(mir_op) => { + let (val, overflowed, ty) = this.overflowing_binary_op(mir_op, &left, &right)?; + if matches!(mir_op, BinOp::Shl | BinOp::Shr) { + // Shifts have extra UB as SIMD operations that the MIR binop does not have. + // See <https://github.com/rust-lang/rust/issues/91237>. 
+ if overflowed { + let r_val = right.to_scalar()?.to_bits(right.layout.size)?; + throw_ub_format!("overflowing shift by {} in `{}` in SIMD lane {}", r_val, intrinsic_name, i); + } + } + if matches!(mir_op, BinOp::Eq | BinOp::Ne | BinOp::Lt | BinOp::Le | BinOp::Gt | BinOp::Ge) { + // Special handling for boolean-returning operations + assert_eq!(ty, this.tcx.types.bool); + let val = val.to_bool().unwrap(); + bool_to_simd_element(val, dest.layout.size) + } else { + assert_ne!(ty, this.tcx.types.bool); + assert_eq!(ty, dest.layout.ty); + val + } } - } - if matches!(mir_op, BinOp::Eq | BinOp::Ne | BinOp::Lt | BinOp::Le | BinOp::Gt | BinOp::Ge) { - // Special handling for boolean-returning operations - assert_eq!(ty, this.tcx.types.bool); - let val = val.to_bool().unwrap(); - let val = bool_to_simd_element(val, dest.layout.size); - this.write_scalar(val, &dest.into())?; - } else { - assert_ne!(ty, this.tcx.types.bool); - assert_eq!(ty, dest.layout.ty); - this.write_scalar(val, &dest.into())?; - } + Op::FMax => { + assert!(matches!(dest.layout.ty.kind(), ty::Float(_))); + this.max_op(&left, &right)?.to_scalar()? + } + Op::FMin => { + assert!(matches!(dest.layout.ty.kind(), ty::Float(_))); + this.min_op(&left, &right)?.to_scalar()? + } + }; + this.write_scalar(val, &dest.into())?; } } #[rustfmt::skip] @@ -478,24 +499,10 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriEvalContextExt<'mir, 'tcx this.binary_op(mir_op, &res, &op)? } Op::Max => { - // if `op > res`... - if this.binary_op(BinOp::Gt, &op, &res)?.to_scalar()?.to_bool()? { - // update accumulator - op - } else { - // no change - res - } + this.max_op(&res, &op)? } Op::Min => { - // if `op < res`... - if this.binary_op(BinOp::Lt, &op, &res)?.to_scalar()?.to_bool()? { - // update accumulator - op - } else { - // no change - res - } + this.min_op(&res, &op)? 
} }; } @@ -1071,4 +1078,30 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriEvalContextExt<'mir, 'tcx _ => bug!("`float_to_int_unchecked` called with non-int output type {:?}", dest_ty), }) } + + fn max_op( + &self, + left: &ImmTy<'tcx, Tag>, + right: &ImmTy<'tcx, Tag>, + ) -> InterpResult<'tcx, ImmTy<'tcx, Tag>> { + let this = self.eval_context_ref(); + Ok(if this.binary_op(BinOp::Gt, left, right)?.to_scalar()?.to_bool()? { + *left + } else { + *right + }) + } + + fn min_op( + &self, + left: &ImmTy<'tcx, Tag>, + right: &ImmTy<'tcx, Tag>, + ) -> InterpResult<'tcx, ImmTy<'tcx, Tag>> { + let this = self.eval_context_ref(); + Ok(if this.binary_op(BinOp::Lt, left, right)?.to_scalar()?.to_bool()? { + *left + } else { + *right + }) + } } diff --git a/tests/run-pass/portable-simd.rs b/tests/run-pass/portable-simd.rs index ccedf61a381..817d18a45d4 100644 --- a/tests/run-pass/portable-simd.rs +++ b/tests/run-pass/portable-simd.rs @@ -12,6 +12,8 @@ fn simd_ops_f32() { assert_eq!(a / f32x4::splat(2.0), f32x4::splat(5.0)); assert_eq!(a % b, f32x4::from_array([0.0, 0.0, 1.0, 2.0])); assert_eq!(b.abs(), f32x4::from_array([1.0, 2.0, 3.0, 4.0])); + assert_eq!(a.max(b * f32x4::splat(4.0)), f32x4::from_array([10.0, 10.0, 12.0, 10.0])); + assert_eq!(a.min(b * f32x4::splat(4.0)), f32x4::from_array([4.0, 8.0, 10.0, -16.0])); assert_eq!(a.lanes_eq(f32x4::splat(5.0) * b), Mask::from_array([false, true, false, false])); assert_eq!(a.lanes_ne(f32x4::splat(5.0) * b), Mask::from_array([true, false, true, true])); @@ -41,6 +43,8 @@ fn simd_ops_f64() { assert_eq!(a / f64x4::splat(2.0), f64x4::splat(5.0)); assert_eq!(a % b, f64x4::from_array([0.0, 0.0, 1.0, 2.0])); assert_eq!(b.abs(), f64x4::from_array([1.0, 2.0, 3.0, 4.0])); + assert_eq!(a.max(b * f64x4::splat(4.0)), f64x4::from_array([10.0, 10.0, 12.0, 10.0])); + assert_eq!(a.min(b * f64x4::splat(4.0)), f64x4::from_array([4.0, 8.0, 10.0, -16.0])); assert_eq!(a.lanes_eq(f64x4::splat(5.0) * b), Mask::from_array([false, true, false, 
false])); assert_eq!(a.lanes_ne(f64x4::splat(5.0) * b), Mask::from_array([true, false, true, true])); @@ -71,6 +75,12 @@ fn simd_ops_i32() { assert_eq!(i32x2::splat(i32::MIN) / i32x2::splat(-1), i32x2::splat(i32::MIN)); assert_eq!(a % b, i32x4::from_array([0, 0, 1, 2])); assert_eq!(i32x2::splat(i32::MIN) % i32x2::splat(-1), i32x2::splat(0)); + assert_eq!(b.abs(), i32x4::from_array([1, 2, 3, 4])); + // FIXME not a per-lane method (https://github.com/rust-lang/rust/issues/94682) + // assert_eq!(a.max(b * i32x4::splat(4)), i32x4::from_array([10, 10, 12, 10])); + // assert_eq!(a.min(b * i32x4::splat(4)), i32x4::from_array([4, 8, 10, -16])); + + assert_eq!(!b, i32x4::from_array([!1, !2, !3, !-4])); assert_eq!(b << i32x4::splat(2), i32x4::from_array([4, 8, 12, -16])); assert_eq!(b >> i32x4::splat(1), i32x4::from_array([0, 1, 1, -2])); assert_eq!(b & i32x4::splat(2), i32x4::from_array([0, 2, 2, 0])); @@ -84,12 +94,6 @@ fn simd_ops_i32() { assert_eq!(a.lanes_ge(i32x4::splat(5) * b), Mask::from_array([true, true, false, true])); assert_eq!(a.lanes_gt(i32x4::splat(5) * b), Mask::from_array([true, false, false, true])); - assert_eq!(a.horizontal_and(), 10); - assert_eq!(b.horizontal_and(), 0); - assert_eq!(a.horizontal_or(), 10); - assert_eq!(b.horizontal_or(), -1); - assert_eq!(a.horizontal_xor(), 0); - assert_eq!(b.horizontal_xor(), -4); assert_eq!(a.horizontal_sum(), 40); assert_eq!(b.horizontal_sum(), 2); assert_eq!(a.horizontal_product(), 100 * 100); @@ -98,6 +102,13 @@ fn simd_ops_i32() { assert_eq!(b.horizontal_max(), 3); assert_eq!(a.horizontal_min(), 10); assert_eq!(b.horizontal_min(), -4); + + assert_eq!(a.horizontal_and(), 10); + assert_eq!(b.horizontal_and(), 0); + assert_eq!(a.horizontal_or(), 10); + assert_eq!(b.horizontal_or(), -1); + assert_eq!(a.horizontal_xor(), 0); + assert_eq!(b.horizontal_xor(), -4); } fn simd_mask() {