implement simd_reduce_min/max

This commit is contained in:
Ralf Jung 2022-03-06 14:31:45 -05:00
parent db06d4998f
commit 9851b743c1
2 changed files with 52 additions and 10 deletions

View File

@ -433,7 +433,9 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriEvalContextExt<'mir, 'tcx
| "simd_reduce_or"
| "simd_reduce_xor"
| "simd_reduce_any"
| "simd_reduce_all" => {
| "simd_reduce_all"
| "simd_reduce_max"
| "simd_reduce_min" => {
use mir::BinOp;
let &[ref op] = check_arg_count(args)?;
@ -445,19 +447,27 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriEvalContextExt<'mir, 'tcx
enum Op {
MirOp(BinOp),
MirOpBool(BinOp),
Max,
Min,
}
// The initial value is the neutral element.
let (which, init) = match intrinsic_name {
"simd_reduce_and" => (Op::MirOp(BinOp::BitAnd), ImmTy::from_int(-1, dest.layout)),
"simd_reduce_or" => (Op::MirOp(BinOp::BitOr), ImmTy::from_int(0, dest.layout)),
"simd_reduce_xor" => (Op::MirOp(BinOp::BitXor), ImmTy::from_int(0, dest.layout)),
"simd_reduce_any" => (Op::MirOpBool(BinOp::BitOr), imm_from_bool(false)),
"simd_reduce_all" => (Op::MirOpBool(BinOp::BitAnd), imm_from_bool(true)),
let which = match intrinsic_name {
"simd_reduce_and" => Op::MirOp(BinOp::BitAnd),
"simd_reduce_or" => Op::MirOp(BinOp::BitOr),
"simd_reduce_xor" => Op::MirOp(BinOp::BitXor),
"simd_reduce_any" => Op::MirOpBool(BinOp::BitOr),
"simd_reduce_all" => Op::MirOpBool(BinOp::BitAnd),
"simd_reduce_max" => Op::Max,
"simd_reduce_min" => Op::Min,
_ => unreachable!(),
};
let mut res = init;
for i in 0..op_len {
// Initialize with first lane, then proceed with the rest.
let mut res = this.read_immediate(&this.mplace_index(&op, 0)?.into())?;
if matches!(which, Op::MirOpBool(_)) {
// Convert to `bool` scalar.
res = imm_from_bool(simd_element_to_bool(res)?);
}
for i in 1..op_len {
let op = this.read_immediate(&this.mplace_index(&op, i)?.into())?;
res = match which {
Op::MirOp(mir_op) => {
@ -467,6 +477,26 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriEvalContextExt<'mir, 'tcx
let op = imm_from_bool(simd_element_to_bool(op)?);
this.binary_op(mir_op, &res, &op)?
}
Op::Max => {
// if `op > res`...
if this.binary_op(BinOp::Gt, &op, &res)?.to_scalar()?.to_bool()? {
// update accumulator
op
} else {
// no change
res
}
}
Op::Min => {
// if `op < res`...
if this.binary_op(BinOp::Lt, &op, &res)?.to_scalar()?.to_bool()? {
// update accumulator
op
} else {
// no change
res
}
}
};
}
this.write_immediate(*res, dest)?;

View File

@ -24,6 +24,10 @@ fn simd_ops_f32() {
assert_eq!(b.horizontal_sum(), 2.0);
assert_eq!(a.horizontal_product(), 100.0 * 100.0);
assert_eq!(b.horizontal_product(), -24.0);
assert_eq!(a.horizontal_max(), 10.0);
assert_eq!(b.horizontal_max(), 3.0);
assert_eq!(a.horizontal_min(), 10.0);
assert_eq!(b.horizontal_min(), -4.0);
}
fn simd_ops_f64() {
@ -49,6 +53,10 @@ fn simd_ops_f64() {
assert_eq!(b.horizontal_sum(), 2.0);
assert_eq!(a.horizontal_product(), 100.0 * 100.0);
assert_eq!(b.horizontal_product(), -24.0);
assert_eq!(a.horizontal_max(), 10.0);
assert_eq!(b.horizontal_max(), 3.0);
assert_eq!(a.horizontal_min(), 10.0);
assert_eq!(b.horizontal_min(), -4.0);
}
fn simd_ops_i32() {
@ -86,6 +94,10 @@ fn simd_ops_i32() {
assert_eq!(b.horizontal_sum(), 2);
assert_eq!(a.horizontal_product(), 100 * 100);
assert_eq!(b.horizontal_product(), -24);
assert_eq!(a.horizontal_max(), 10);
assert_eq!(b.horizontal_max(), 3);
assert_eq!(a.horizontal_min(), 10);
assert_eq!(b.horizontal_min(), -4);
}
fn simd_mask() {