implement missing SIMD comparison operators, simd_xor, and simd_reduce_all

This commit is contained in:
Ralf Jung 2022-03-05 13:30:16 -05:00
parent a715171534
commit 90207a5484
2 changed files with 80 additions and 25 deletions

View File

@ -356,7 +356,15 @@ fn call_intrinsic(
| "simd_shr" | "simd_shr"
| "simd_and" | "simd_and"
| "simd_or" | "simd_or"
| "simd_eq" => { | "simd_xor"
| "simd_eq"
| "simd_ne"
| "simd_lt"
| "simd_le"
| "simd_gt"
| "simd_ge" => {
use mir::BinOp;
let &[ref left, ref right] = check_arg_count(args)?; let &[ref left, ref right] = check_arg_count(args)?;
let (left, left_len) = this.operand_to_simd(left)?; let (left, left_len) = this.operand_to_simd(left)?;
let (right, right_len) = this.operand_to_simd(right)?; let (right, right_len) = this.operand_to_simd(right)?;
@ -366,16 +374,22 @@ fn call_intrinsic(
assert_eq!(dest_len, right_len); assert_eq!(dest_len, right_len);
let op = match intrinsic_name { let op = match intrinsic_name {
"simd_add" => mir::BinOp::Add, "simd_add" => BinOp::Add,
"simd_sub" => mir::BinOp::Sub, "simd_sub" => BinOp::Sub,
"simd_mul" => mir::BinOp::Mul, "simd_mul" => BinOp::Mul,
"simd_div" => mir::BinOp::Div, "simd_div" => BinOp::Div,
"simd_rem" => mir::BinOp::Rem, "simd_rem" => BinOp::Rem,
"simd_shl" => mir::BinOp::Shl, "simd_shl" => BinOp::Shl,
"simd_shr" => mir::BinOp::Shr, "simd_shr" => BinOp::Shr,
"simd_and" => mir::BinOp::BitAnd, "simd_and" => BinOp::BitAnd,
"simd_or" => mir::BinOp::BitOr, "simd_or" => BinOp::BitOr,
"simd_eq" => mir::BinOp::Eq, "simd_xor" => BinOp::BitXor,
"simd_eq" => BinOp::Eq,
"simd_ne" => BinOp::Ne,
"simd_lt" => BinOp::Lt,
"simd_le" => BinOp::Le,
"simd_gt" => BinOp::Gt,
"simd_ge" => BinOp::Ge,
_ => unreachable!(), _ => unreachable!(),
}; };
@ -384,7 +398,7 @@ fn call_intrinsic(
let right = this.read_immediate(&this.mplace_index(&right, i)?.into())?; let right = this.read_immediate(&this.mplace_index(&right, i)?.into())?;
let dest = this.mplace_index(&dest, i)?; let dest = this.mplace_index(&dest, i)?;
let (val, overflowed, ty) = this.overflowing_binary_op(op, &left, &right)?; let (val, overflowed, ty) = this.overflowing_binary_op(op, &left, &right)?;
if matches!(op, mir::BinOp::Shl | mir::BinOp::Shr) { if matches!(op, BinOp::Shl | BinOp::Shr) {
// Shifts have extra UB as SIMD operations that the MIR binop does not have. // Shifts have extra UB as SIMD operations that the MIR binop does not have.
// See <https://github.com/rust-lang/rust/issues/91237>. // See <https://github.com/rust-lang/rust/issues/91237>.
if overflowed { if overflowed {
@ -392,27 +406,38 @@ fn call_intrinsic(
throw_ub_format!("overflowing shift by {} in `{}` in SIMD lane {}", r_val, intrinsic_name, i); throw_ub_format!("overflowing shift by {} in `{}` in SIMD lane {}", r_val, intrinsic_name, i);
} }
} }
if matches!(op, mir::BinOp::Eq) { if matches!(op, BinOp::Eq | BinOp::Ne | BinOp::Lt | BinOp::Le | BinOp::Gt | BinOp::Ge) {
// Special handling for boolean-returning operations // Special handling for boolean-returning operations
assert_eq!(ty, this.tcx.types.bool); assert_eq!(ty, this.tcx.types.bool);
let val = val.to_bool().unwrap(); let val = val.to_bool().unwrap();
let val = bool_to_simd_element(val, dest.layout.size); let val = bool_to_simd_element(val, dest.layout.size);
this.write_scalar(val, &dest.into())?; this.write_scalar(val, &dest.into())?;
} else { } else {
assert_ne!(ty, this.tcx.types.bool);
assert_eq!(ty, dest.layout.ty); assert_eq!(ty, dest.layout.ty);
this.write_scalar(val, &dest.into())?; this.write_scalar(val, &dest.into())?;
} }
} }
} }
"simd_reduce_any" => { "simd_reduce_any" | "simd_reduce_all" => {
let &[ref op] = check_arg_count(args)?; let &[ref op] = check_arg_count(args)?;
let (op, op_len) = this.operand_to_simd(op)?; let (op, op_len) = this.operand_to_simd(op)?;
let mut res = false; // the neutral element // the neutral element
let mut res = match intrinsic_name {
"simd_reduce_any" => false,
"simd_reduce_all" => true,
_ => bug!(),
};
for i in 0..op_len { for i in 0..op_len {
let op = this.read_immediate(&this.mplace_index(&op, i)?.into())?; let op = this.read_immediate(&this.mplace_index(&op, i)?.into())?;
let val = simd_element_to_bool(op)?; let val = simd_element_to_bool(op)?;
res = res | val; res = match intrinsic_name {
"simd_reduce_any" => res | val,
"simd_reduce_all" => res & val,
_ => bug!(),
};
} }
this.write_scalar(Scalar::from_bool(res), dest)?; this.write_scalar(Scalar::from_bool(res), dest)?;

View File

@ -12,6 +12,14 @@ fn simd_ops_f32() {
assert_eq!(a / f32x4::splat(2.0), f32x4::splat(5.0)); assert_eq!(a / f32x4::splat(2.0), f32x4::splat(5.0));
assert_eq!(a % b, f32x4::from_array([0.0, 0.0, 1.0, 2.0])); assert_eq!(a % b, f32x4::from_array([0.0, 0.0, 1.0, 2.0]));
assert_eq!(b.abs(), f32x4::from_array([1.0, 2.0, 3.0, 4.0])); assert_eq!(b.abs(), f32x4::from_array([1.0, 2.0, 3.0, 4.0]));
// FIXME use Mask::from_array once simd_cast is implemented.
assert_eq!(a.lanes_eq(f32x4::splat(5.0)*b), Mask::from_int(i32x4::from_array([0, -1, 0, 0])));
assert_eq!(a.lanes_ne(f32x4::splat(5.0)*b), Mask::from_int(i32x4::from_array([-1, 0, -1, -1])));
assert_eq!(a.lanes_le(f32x4::splat(5.0)*b), Mask::from_int(i32x4::from_array([0, -1, -1, 0])));
assert_eq!(a.lanes_lt(f32x4::splat(5.0)*b), Mask::from_int(i32x4::from_array([0, 0, -1, 0])));
assert_eq!(a.lanes_ge(f32x4::splat(5.0)*b), Mask::from_int(i32x4::from_array([-1, -1, 0, -1])));
assert_eq!(a.lanes_gt(f32x4::splat(5.0)*b), Mask::from_int(i32x4::from_array([-1, 0, 0, -1])));
} }
fn simd_ops_f64() { fn simd_ops_f64() {
@ -25,30 +33,48 @@ fn simd_ops_f64() {
assert_eq!(a / f64x4::splat(2.0), f64x4::splat(5.0)); assert_eq!(a / f64x4::splat(2.0), f64x4::splat(5.0));
assert_eq!(a % b, f64x4::from_array([0.0, 0.0, 1.0, 2.0])); assert_eq!(a % b, f64x4::from_array([0.0, 0.0, 1.0, 2.0]));
assert_eq!(b.abs(), f64x4::from_array([1.0, 2.0, 3.0, 4.0])); assert_eq!(b.abs(), f64x4::from_array([1.0, 2.0, 3.0, 4.0]));
// FIXME use Mask::from_array once simd_cast is implemented.
assert_eq!(a.lanes_eq(f64x4::splat(5.0)*b), Mask::from_int(i64x4::from_array([0, -1, 0, 0])));
assert_eq!(a.lanes_ne(f64x4::splat(5.0)*b), Mask::from_int(i64x4::from_array([-1, 0, -1, -1])));
assert_eq!(a.lanes_le(f64x4::splat(5.0)*b), Mask::from_int(i64x4::from_array([0, -1, -1, 0])));
assert_eq!(a.lanes_lt(f64x4::splat(5.0)*b), Mask::from_int(i64x4::from_array([0, 0, -1, 0])));
assert_eq!(a.lanes_ge(f64x4::splat(5.0)*b), Mask::from_int(i64x4::from_array([-1, -1, 0, -1])));
assert_eq!(a.lanes_gt(f64x4::splat(5.0)*b), Mask::from_int(i64x4::from_array([-1, 0, 0, -1])));
} }
fn simd_ops_i32() { fn simd_ops_i32() {
let a = i32x4::splat(10); let a = i32x4::splat(10);
let b = i32x4::from_array([1, 2, 3, 4]); let b = i32x4::from_array([1, 2, 3, -4]);
assert_eq!(-b, i32x4::from_array([-1, -2, -3, -4])); assert_eq!(-b, i32x4::from_array([-1, -2, -3, 4]));
assert_eq!(a + b, i32x4::from_array([11, 12, 13, 14])); assert_eq!(a + b, i32x4::from_array([11, 12, 13, 6]));
assert_eq!(a - b, i32x4::from_array([9, 8, 7, 6])); assert_eq!(a - b, i32x4::from_array([9, 8, 7, 14]));
assert_eq!(a * b, i32x4::from_array([10, 20, 30, 40])); assert_eq!(a * b, i32x4::from_array([10, 20, 30, -40]));
assert_eq!(a / b, i32x4::from_array([10, 5, 3, 2])); assert_eq!(a / b, i32x4::from_array([10, 5, 3, -2]));
assert_eq!(a / i32x4::splat(2), i32x4::splat(5)); assert_eq!(a / i32x4::splat(2), i32x4::splat(5));
assert_eq!(i32x2::splat(i32::MIN) / i32x2::splat(-1), i32x2::splat(i32::MIN)); assert_eq!(i32x2::splat(i32::MIN) / i32x2::splat(-1), i32x2::splat(i32::MIN));
assert_eq!(a % b, i32x4::from_array([0, 0, 1, 2])); assert_eq!(a % b, i32x4::from_array([0, 0, 1, 2]));
assert_eq!(i32x2::splat(i32::MIN) % i32x2::splat(-1), i32x2::splat(0)); assert_eq!(i32x2::splat(i32::MIN) % i32x2::splat(-1), i32x2::splat(0));
assert_eq!(b << i32x4::splat(2), i32x4::from_array([4, 8, 12, 16])); assert_eq!(b << i32x4::splat(2), i32x4::from_array([4, 8, 12, -16]));
assert_eq!(b >> i32x4::splat(1), i32x4::from_array([0, 1, 1, 2])); assert_eq!(b >> i32x4::splat(1), i32x4::from_array([0, 1, 1, -2]));
assert_eq!(b & i32x4::splat(2), i32x4::from_array([0, 2, 2, 0])); assert_eq!(b & i32x4::splat(2), i32x4::from_array([0, 2, 2, 0]));
assert_eq!(b | i32x4::splat(2), i32x4::from_array([3, 2, 3, 6])); assert_eq!(b | i32x4::splat(2), i32x4::from_array([3, 2, 3, -2]));
assert_eq!(b ^ i32x4::splat(2), i32x4::from_array([3, 0, 1, -2]));
// FIXME use Mask::from_array once simd_cast is implemented.
assert_eq!(a.lanes_eq(i32x4::splat(5)*b), Mask::from_int(i32x4::from_array([0, -1, 0, 0])));
assert_eq!(a.lanes_ne(i32x4::splat(5)*b), Mask::from_int(i32x4::from_array([-1, 0, -1, -1])));
assert_eq!(a.lanes_le(i32x4::splat(5)*b), Mask::from_int(i32x4::from_array([0, -1, -1, 0])));
assert_eq!(a.lanes_lt(i32x4::splat(5)*b), Mask::from_int(i32x4::from_array([0, 0, -1, 0])));
assert_eq!(a.lanes_ge(i32x4::splat(5)*b), Mask::from_int(i32x4::from_array([-1, -1, 0, -1])));
assert_eq!(a.lanes_gt(i32x4::splat(5)*b), Mask::from_int(i32x4::from_array([-1, 0, 0, -1])));
} }
fn simd_intrinsics() { fn simd_intrinsics() {
extern "platform-intrinsic" { extern "platform-intrinsic" {
fn simd_eq<T, U>(x: T, y: T) -> U; fn simd_eq<T, U>(x: T, y: T) -> U;
fn simd_reduce_any<T>(x: T) -> bool; fn simd_reduce_any<T>(x: T) -> bool;
fn simd_reduce_all<T>(x: T) -> bool;
fn simd_select<M, T>(m: M, yes: T, no: T) -> T; fn simd_select<M, T>(m: M, yes: T, no: T) -> T;
} }
unsafe { unsafe {
@ -60,6 +86,10 @@ fn simd_intrinsics() {
assert!(!simd_reduce_any(i32x4::splat(0))); assert!(!simd_reduce_any(i32x4::splat(0)));
assert!(simd_reduce_any(i32x4::splat(-1))); assert!(simd_reduce_any(i32x4::splat(-1)));
assert!(simd_reduce_any(i32x2::from_array([0, -1])));
assert!(!simd_reduce_all(i32x4::splat(0)));
assert!(simd_reduce_all(i32x4::splat(-1)));
assert!(!simd_reduce_all(i32x2::from_array([0, -1])));
assert_eq!(simd_select(i8x4::from_array([0, -1, -1, 0]), a, b), i32x4::from_array([1, 10, 10, 4])); assert_eq!(simd_select(i8x4::from_array([0, -1, -1, 0]), a, b), i32x4::from_array([1, 10, 10, 4]));
assert_eq!(simd_select(i8x4::from_array([0, -1, -1, 0]), b, a), i32x4::from_array([10, 2, 10, 10])); assert_eq!(simd_select(i8x4::from_array([0, -1, -1, 0]), b, a), i32x4::from_array([10, 2, 10, 10]));