implement SIMD sqrt and fma
This commit is contained in:
parent
a9a0d0e5e7
commit
4fd5dca27c
@ -329,7 +329,8 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriEvalContextExt<'mir, 'tcx
|
||||
| "simd_ceil"
|
||||
| "simd_floor"
|
||||
| "simd_round"
|
||||
| "simd_trunc" => {
|
||||
| "simd_trunc"
|
||||
| "simd_fsqrt" => {
|
||||
let &[ref op] = check_arg_count(args)?;
|
||||
let (op, op_len) = this.operand_to_simd(op)?;
|
||||
let (dest, dest_len) = this.place_to_simd(dest)?;
|
||||
@ -342,6 +343,7 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriEvalContextExt<'mir, 'tcx
|
||||
Floor,
|
||||
Round,
|
||||
Trunc,
|
||||
Sqrt,
|
||||
}
|
||||
#[derive(Copy, Clone)]
|
||||
enum Op {
|
||||
@ -356,6 +358,7 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriEvalContextExt<'mir, 'tcx
|
||||
"simd_floor" => Op::HostOp(HostFloatOp::Floor),
|
||||
"simd_round" => Op::HostOp(HostFloatOp::Round),
|
||||
"simd_trunc" => Op::HostOp(HostFloatOp::Trunc),
|
||||
"simd_fsqrt" => Op::HostOp(HostFloatOp::Sqrt),
|
||||
_ => unreachable!(),
|
||||
};
|
||||
|
||||
@ -388,6 +391,7 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriEvalContextExt<'mir, 'tcx
|
||||
HostFloatOp::Floor => f.floor(),
|
||||
HostFloatOp::Round => f.round(),
|
||||
HostFloatOp::Trunc => f.trunc(),
|
||||
HostFloatOp::Sqrt => f.sqrt(),
|
||||
};
|
||||
Scalar::from_u32(res.to_bits())
|
||||
}
|
||||
@ -398,6 +402,7 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriEvalContextExt<'mir, 'tcx
|
||||
HostFloatOp::Floor => f.floor(),
|
||||
HostFloatOp::Round => f.round(),
|
||||
HostFloatOp::Trunc => f.trunc(),
|
||||
HostFloatOp::Sqrt => f.sqrt(),
|
||||
};
|
||||
Scalar::from_u64(res.to_bits())
|
||||
}
|
||||
@ -508,6 +513,36 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriEvalContextExt<'mir, 'tcx
|
||||
this.write_scalar(val, &dest.into())?;
|
||||
}
|
||||
}
|
||||
// Shim for the `simd_fma` intrinsic: for every lane `i`, computes
// `a[i].mul_add(b[i], c[i])` and stores the result in lane `i` of `dest`.
"simd_fma" => {
|
||||
// The intrinsic takes exactly three operands; `check_arg_count` rejects any other arity.
let &[ref a, ref b, ref c] = check_arg_count(args)?;
|
||||
// Turn each vector operand into an indexable place together with its lane count.
let (a, a_len) = this.operand_to_simd(a)?;
|
||||
let (b, b_len) = this.operand_to_simd(b)?;
|
||||
let (c, c_len) = this.operand_to_simd(c)?;
|
||||
let (dest, dest_len) = this.place_to_simd(dest)?;
|
||||
|
||||
// All three inputs must have the same number of lanes as the destination;
// a mismatch is treated as an internal invariant violation (panic), not a user error.
assert_eq!(dest_len, a_len);
|
||||
assert_eq!(dest_len, b_len);
|
||||
assert_eq!(dest_len, c_len);
|
||||
|
||||
// Process the vectors one lane at a time.
for i in 0..dest_len {
|
||||
// Inside the loop `a`, `b`, `c` and `dest` are shadowed by the i-th lane of each vector.
let a = this.read_immediate(&this.mplace_index(&a, i)?.into())?.to_scalar()?;
|
||||
let b = this.read_immediate(&this.mplace_index(&b, i)?.into())?.to_scalar()?;
|
||||
let c = this.read_immediate(&this.mplace_index(&c, i)?.into())?.to_scalar()?;
|
||||
let dest = this.mplace_index(&dest, i)?;
|
||||
|
||||
// Works for f32 and f64.
// The destination lane's type decides which float width is used below;
// any non-float lane type is reported through `bug!`.
|
||||
let ty::Float(float_ty) = dest.layout.ty.kind() else {
|
||||
bug!("{} operand is not a float", intrinsic_name)
|
||||
};
|
||||
// Compute `a * b + c` via `mul_add` at the selected width.
// NOTE(review): `.value` presumably unwraps a status-carrying result
// (as with rustc_apfloat's `StatusAnd`) — confirm.
let val = match float_ty {
|
||||
FloatTy::F32 =>
|
||||
Scalar::from_f32(a.to_f32()?.mul_add(b.to_f32()?, c.to_f32()?).value),
|
||||
FloatTy::F64 =>
|
||||
Scalar::from_f64(a.to_f64()?.mul_add(b.to_f64()?, c.to_f64()?).value),
|
||||
};
|
||||
// Store the computed lane into the destination vector.
this.write_scalar(val, &dest.into())?;
|
||||
}
|
||||
}
|
||||
#[rustfmt::skip]
|
||||
| "simd_reduce_and"
|
||||
| "simd_reduce_or"
|
||||
|
@ -15,6 +15,11 @@ fn simd_ops_f32() {
|
||||
assert_eq!(a.max(b * f32x4::splat(4.0)), f32x4::from_array([10.0, 10.0, 12.0, 10.0]));
|
||||
assert_eq!(a.min(b * f32x4::splat(4.0)), f32x4::from_array([4.0, 8.0, 10.0, -16.0]));
|
||||
|
||||
assert_eq!(a.mul_add(b, a), (a*b)+a);
|
||||
assert_eq!(b.mul_add(b, a), (b*b)+a);
|
||||
assert_eq!((a*a).sqrt(), a);
|
||||
assert_eq!((b*b).sqrt(), b.abs());
|
||||
|
||||
assert_eq!(a.lanes_eq(f32x4::splat(5.0) * b), Mask::from_array([false, true, false, false]));
|
||||
assert_eq!(a.lanes_ne(f32x4::splat(5.0) * b), Mask::from_array([true, false, true, true]));
|
||||
assert_eq!(a.lanes_le(f32x4::splat(5.0) * b), Mask::from_array([false, true, true, false]));
|
||||
@ -59,6 +64,11 @@ fn simd_ops_f64() {
|
||||
assert_eq!(a.max(b * f64x4::splat(4.0)), f64x4::from_array([10.0, 10.0, 12.0, 10.0]));
|
||||
assert_eq!(a.min(b * f64x4::splat(4.0)), f64x4::from_array([4.0, 8.0, 10.0, -16.0]));
|
||||
|
||||
assert_eq!(a.mul_add(b, a), (a*b)+a);
|
||||
assert_eq!(b.mul_add(b, a), (b*b)+a);
|
||||
assert_eq!((a*a).sqrt(), a);
|
||||
assert_eq!((b*b).sqrt(), b.abs());
|
||||
|
||||
assert_eq!(a.lanes_eq(f64x4::splat(5.0) * b), Mask::from_array([false, true, false, false]));
|
||||
assert_eq!(a.lanes_ne(f64x4::splat(5.0) * b), Mask::from_array([true, false, true, true]));
|
||||
assert_eq!(a.lanes_le(f64x4::splat(5.0) * b), Mask::from_array([false, true, true, false]));
|
||||
|
Loading…
x
Reference in New Issue
Block a user