Update bitmask API
This commit is contained in:
parent
da42aa5403
commit
eec42808aa
@ -76,6 +76,9 @@
|
||||
pub(crate) fn simd_reduce_and<T, U>(x: T) -> U;
|
||||
pub(crate) fn simd_reduce_or<T, U>(x: T) -> U;
|
||||
pub(crate) fn simd_reduce_xor<T, U>(x: T) -> U;
|
||||
|
||||
// truncate integer vector to bitmask
|
||||
pub(crate) fn simd_bitmask<T, U>(x: T) -> U;
|
||||
}
|
||||
|
||||
#[cfg(feature = "std")]
|
||||
|
@ -1,14 +1,38 @@
|
||||
/// Implemented for vectors that are supported by the implementation.
|
||||
pub trait LanesAtMost32 {}
|
||||
pub trait LanesAtMost32: sealed::Sealed {
|
||||
#[doc(hidden)]
|
||||
type BitMask: Into<u64>;
|
||||
}
|
||||
|
||||
mod sealed {
|
||||
pub trait Sealed {}
|
||||
}
|
||||
|
||||
macro_rules! impl_for {
|
||||
{ $name:ident } => {
|
||||
impl LanesAtMost32 for $name<1> {}
|
||||
impl LanesAtMost32 for $name<2> {}
|
||||
impl LanesAtMost32 for $name<4> {}
|
||||
impl LanesAtMost32 for $name<8> {}
|
||||
impl LanesAtMost32 for $name<16> {}
|
||||
impl LanesAtMost32 for $name<32> {}
|
||||
impl<const LANES: usize> sealed::Sealed for $name<LANES>
|
||||
where
|
||||
$name<LANES>: LanesAtMost32,
|
||||
{}
|
||||
|
||||
impl LanesAtMost32 for $name<1> {
|
||||
type BitMask = u8;
|
||||
}
|
||||
impl LanesAtMost32 for $name<2> {
|
||||
type BitMask = u8;
|
||||
}
|
||||
impl LanesAtMost32 for $name<4> {
|
||||
type BitMask = u8;
|
||||
}
|
||||
impl LanesAtMost32 for $name<8> {
|
||||
type BitMask = u8;
|
||||
}
|
||||
impl LanesAtMost32 for $name<16> {
|
||||
type BitMask = u16;
|
||||
}
|
||||
impl LanesAtMost32 for $name<32> {
|
||||
type BitMask = u32;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1,13 +1,9 @@
|
||||
use crate::LanesAtMost32;
|
||||
|
||||
/// A mask where each lane is represented by a single bit.
|
||||
#[derive(Copy, Clone, Debug, PartialOrd, PartialEq, Ord, Eq, Hash)]
|
||||
#[repr(transparent)]
|
||||
pub struct BitMask<const LANES: usize>(u64)
|
||||
pub struct BitMask<const LANES: usize>(u64);
|
||||
|
||||
impl<const LANES: usize> BitMask<LANES>
|
||||
where
|
||||
Self: LanesAtMost32,
|
||||
{
|
||||
#[inline]
|
||||
pub fn splat(value: bool) -> Self {
|
||||
@ -25,13 +21,50 @@ pub unsafe fn test_unchecked(&self, lane: usize) -> bool {
|
||||
|
||||
#[inline]
|
||||
pub unsafe fn set_unchecked(&mut self, lane: usize, value: bool) {
|
||||
self.0 ^= ((value ^ self.test(lane)) as u64) << lane
|
||||
self.0 ^= ((value ^ self.test_unchecked(lane)) as u64) << lane
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn to_int<V, T>(self) -> V
|
||||
where
|
||||
V: Default + AsMut<[T; LANES]>,
|
||||
T: From<i8>,
|
||||
{
|
||||
// TODO this should be an intrinsic sign-extension
|
||||
let mut v = V::default();
|
||||
for i in 0..LANES {
|
||||
let lane = unsafe { self.test_unchecked(i) };
|
||||
v.as_mut()[i] = (-(lane as i8)).into();
|
||||
}
|
||||
v
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub unsafe fn from_int_unchecked<V>(value: V) -> Self
|
||||
where
|
||||
V: crate::LanesAtMost32,
|
||||
{
|
||||
let mask: V::BitMask = crate::intrinsics::simd_bitmask(value);
|
||||
Self(mask.into())
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn to_bitmask(self) -> u64 {
|
||||
self.0
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn any(self) -> bool {
|
||||
self != Self::splat(false)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn all(self) -> bool {
|
||||
self == Self::splat(true)
|
||||
}
|
||||
}
|
||||
|
||||
impl<const LANES: usize> core::ops::BitAnd for BitMask<LANES>
|
||||
where
|
||||
Self: LanesAtMost32,
|
||||
{
|
||||
type Output = Self;
|
||||
#[inline]
|
||||
@ -41,8 +74,6 @@ fn bitand(self, rhs: Self) -> Self {
|
||||
}
|
||||
|
||||
impl<const LANES: usize> core::ops::BitAnd<bool> for BitMask<LANES>
|
||||
where
|
||||
Self: LanesAtMost32,
|
||||
{
|
||||
type Output = Self;
|
||||
#[inline]
|
||||
@ -52,8 +83,6 @@ fn bitand(self, rhs: bool) -> Self {
|
||||
}
|
||||
|
||||
impl<const LANES: usize> core::ops::BitAnd<BitMask<LANES>> for bool
|
||||
where
|
||||
BitMask<LANES>: LanesAtMost32,
|
||||
{
|
||||
type Output = BitMask<LANES>;
|
||||
#[inline]
|
||||
@ -63,8 +92,6 @@ fn bitand(self, rhs: BitMask<LANES>) -> BitMask<LANES> {
|
||||
}
|
||||
|
||||
impl<const LANES: usize> core::ops::BitOr for BitMask<LANES>
|
||||
where
|
||||
Self: LanesAtMost32,
|
||||
{
|
||||
type Output = Self;
|
||||
#[inline]
|
||||
@ -73,31 +100,7 @@ fn bitor(self, rhs: Self) -> Self {
|
||||
}
|
||||
}
|
||||
|
||||
impl<const LANES: usize> core::ops::BitOr<bool> for BitMask<LANES>
|
||||
where
|
||||
Self: LanesAtMost32,
|
||||
{
|
||||
type Output = Self;
|
||||
#[inline]
|
||||
fn bitor(self, rhs: bool) -> Self {
|
||||
self | Self::splat(rhs)
|
||||
}
|
||||
}
|
||||
|
||||
impl<const LANES: usize> core::ops::BitOr<BitMask<LANES>> for bool
|
||||
where
|
||||
BitMask<LANES>: LanesAtMost32,
|
||||
{
|
||||
type Output = BitMask<LANES>;
|
||||
#[inline]
|
||||
fn bitor(self, rhs: BitMask<LANES>) -> BitMask<LANES> {
|
||||
BitMask::<LANES>::splat(self) | rhs
|
||||
}
|
||||
}
|
||||
|
||||
impl<const LANES: usize> core::ops::BitXor for BitMask<LANES>
|
||||
where
|
||||
Self: LanesAtMost32,
|
||||
{
|
||||
type Output = Self;
|
||||
#[inline]
|
||||
@ -106,42 +109,16 @@ fn bitxor(self, rhs: Self) -> Self::Output {
|
||||
}
|
||||
}
|
||||
|
||||
impl<const LANES: usize> core::ops::BitXor<bool> for BitMask<LANES>
|
||||
where
|
||||
Self: LanesAtMost32,
|
||||
{
|
||||
type Output = Self;
|
||||
#[inline]
|
||||
fn bitxor(self, rhs: bool) -> Self::Output {
|
||||
self ^ Self::splat(rhs)
|
||||
}
|
||||
}
|
||||
|
||||
impl<const LANES: usize> core::ops::BitXor<BitMask<LANES>> for bool
|
||||
where
|
||||
BitMask<LANES>: LanesAtMost32,
|
||||
{
|
||||
type Output = BitMask<LANES>;
|
||||
#[inline]
|
||||
fn bitxor(self, rhs: BitMask<LANES>) -> Self::Output {
|
||||
BitMask::<LANES>::splat(self) ^ rhs
|
||||
}
|
||||
}
|
||||
|
||||
impl<const LANES: usize> core::ops::Not for BitMask<LANES>
|
||||
where
|
||||
Self: LanesAtMost32,
|
||||
{
|
||||
type Output = BitMask<LANES>;
|
||||
#[inline]
|
||||
fn not(self) -> Self::Output {
|
||||
Self(!self.0)
|
||||
Self(!self.0) & Self::splat(true)
|
||||
}
|
||||
}
|
||||
|
||||
impl<const LANES: usize> core::ops::BitAndAssign for BitMask<LANES>
|
||||
where
|
||||
Self: LanesAtMost32,
|
||||
{
|
||||
#[inline]
|
||||
fn bitand_assign(&mut self, rhs: Self) {
|
||||
@ -149,19 +126,7 @@ fn bitand_assign(&mut self, rhs: Self) {
|
||||
}
|
||||
}
|
||||
|
||||
impl<const LANES: usize> core::ops::BitAndAssign<bool> for BitMask<LANES>
|
||||
where
|
||||
Self: LanesAtMost32,
|
||||
{
|
||||
#[inline]
|
||||
fn bitand_assign(&mut self, rhs: bool) {
|
||||
*self &= Self::splat(rhs);
|
||||
}
|
||||
}
|
||||
|
||||
impl<const LANES: usize> core::ops::BitOrAssign for BitMask<LANES>
|
||||
where
|
||||
Self: LanesAtMost32,
|
||||
{
|
||||
#[inline]
|
||||
fn bitor_assign(&mut self, rhs: Self) {
|
||||
@ -169,19 +134,7 @@ fn bitor_assign(&mut self, rhs: Self) {
|
||||
}
|
||||
}
|
||||
|
||||
impl<const LANES: usize> core::ops::BitOrAssign<bool> for BitMask<LANES>
|
||||
where
|
||||
Self: LanesAtMost32,
|
||||
{
|
||||
#[inline]
|
||||
fn bitor_assign(&mut self, rhs: bool) {
|
||||
*self |= Self::splat(rhs);
|
||||
}
|
||||
}
|
||||
|
||||
impl<const LANES: usize> core::ops::BitXorAssign for BitMask<LANES>
|
||||
where
|
||||
Self: LanesAtMost32,
|
||||
{
|
||||
#[inline]
|
||||
fn bitxor_assign(&mut self, rhs: Self) {
|
||||
@ -189,12 +142,9 @@ fn bitxor_assign(&mut self, rhs: Self) {
|
||||
}
|
||||
}
|
||||
|
||||
impl<const LANES: usize> core::ops::BitXorAssign<bool> for BitMask<LANES>
|
||||
where
|
||||
Self: LanesAtMost32,
|
||||
{
|
||||
#[inline]
|
||||
fn bitxor_assign(&mut self, rhs: bool) {
|
||||
*self ^= Self::splat(rhs);
|
||||
}
|
||||
}
|
||||
pub type Mask8<const LANES: usize> = BitMask<LANES>;
|
||||
pub type Mask16<const LANES: usize> = BitMask<LANES>;
|
||||
pub type Mask32<const LANES: usize> = BitMask<LANES>;
|
||||
pub type Mask64<const LANES: usize> = BitMask<LANES>;
|
||||
pub type Mask128<const LANES: usize> = BitMask<LANES>;
|
||||
pub type MaskSize<const LANES: usize> = BitMask<LANES>;
|
||||
|
@ -46,14 +46,12 @@ pub fn splat(value: bool) -> Self {
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn test(&self, lane: usize) -> bool {
|
||||
assert!(lane < LANES, "lane index out of range");
|
||||
pub unsafe fn test_unchecked(&self, lane: usize) -> bool {
|
||||
self.0[lane] == -1
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn set(&mut self, lane: usize, value: bool) {
|
||||
assert!(lane < LANES, "lane index out of range");
|
||||
pub unsafe fn set_unchecked(&mut self, lane: usize, value: bool) {
|
||||
self.0[lane] = if value {
|
||||
-1
|
||||
} else {
|
||||
@ -70,6 +68,12 @@ pub fn to_int(self) -> crate::$type<LANES> {
|
||||
pub unsafe fn from_int_unchecked(value: crate::$type<LANES>) -> Self {
|
||||
Self(value)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn to_bitmask(self) -> u64 {
|
||||
let mask: <crate::$type<LANES> as crate::LanesAtMost32>::BitMask = unsafe { crate::intrinsics::simd_bitmask(self.0) };
|
||||
mask.into()
|
||||
}
|
||||
}
|
||||
|
||||
impl<const LANES: usize> core::convert::From<$name<LANES>> for crate::$type<LANES>
|
||||
@ -81,53 +85,6 @@ fn from(value: $name<LANES>) -> Self {
|
||||
}
|
||||
}
|
||||
|
||||
impl<const LANES: usize> core::fmt::Debug for $name<LANES>
|
||||
where
|
||||
crate::$type<LANES>: crate::LanesAtMost32,
|
||||
{
|
||||
fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
|
||||
f.debug_list()
|
||||
.entries((0..LANES).map(|lane| self.test(lane)))
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl<const LANES: usize> core::fmt::Binary for $name<LANES>
|
||||
where
|
||||
crate::$type<LANES>: crate::LanesAtMost32,
|
||||
{
|
||||
fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
|
||||
core::fmt::Binary::fmt(&self.0, f)
|
||||
}
|
||||
}
|
||||
|
||||
impl<const LANES: usize> core::fmt::Octal for $name<LANES>
|
||||
where
|
||||
crate::$type<LANES>: crate::LanesAtMost32,
|
||||
{
|
||||
fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
|
||||
core::fmt::Octal::fmt(&self.0, f)
|
||||
}
|
||||
}
|
||||
|
||||
impl<const LANES: usize> core::fmt::LowerHex for $name<LANES>
|
||||
where
|
||||
crate::$type<LANES>: crate::LanesAtMost32,
|
||||
{
|
||||
fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
|
||||
core::fmt::LowerHex::fmt(&self.0, f)
|
||||
}
|
||||
}
|
||||
|
||||
impl<const LANES: usize> core::fmt::UpperHex for $name<LANES>
|
||||
where
|
||||
crate::$type<LANES>: crate::LanesAtMost32,
|
||||
{
|
||||
fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
|
||||
core::fmt::UpperHex::fmt(&self.0, f)
|
||||
}
|
||||
}
|
||||
|
||||
impl<const LANES: usize> core::ops::BitAnd for $name<LANES>
|
||||
where
|
||||
crate::$type<LANES>: crate::LanesAtMost32,
|
||||
|
@ -8,6 +8,12 @@
|
||||
|
||||
use crate::{LanesAtMost32, SimdI16, SimdI32, SimdI64, SimdI8, SimdIsize};
|
||||
|
||||
/// Converts masks to bitmasks, with one bit set for each lane.
|
||||
pub trait ToBitMask {
|
||||
/// Converts this mask to a bitmask.
|
||||
fn to_bitmask(self) -> u64;
|
||||
}
|
||||
|
||||
macro_rules! define_opaque_mask {
|
||||
{
|
||||
$(#[$attr:meta])*
|
||||
@ -61,13 +67,53 @@ pub unsafe fn from_int_unchecked(value: $bits_ty<LANES>) -> Self {
|
||||
Self(<$inner_ty>::from_int_unchecked(value))
|
||||
}
|
||||
|
||||
/// Converts a vector of integers to a mask, where 0 represents `false` and -1
|
||||
/// represents `true`.
|
||||
///
|
||||
/// # Panics
|
||||
/// Panics if any lane is not 0 or -1.
|
||||
#[inline]
|
||||
pub fn from_int(value: $bits_ty<LANES>) -> Self {
|
||||
assert!(
|
||||
(value.lanes_eq($bits_ty::splat(0)) | value.lanes_eq($bits_ty::splat(-1))).all(),
|
||||
"all values must be either 0 or -1",
|
||||
);
|
||||
unsafe { Self::from_int_unchecked(value) }
|
||||
}
|
||||
|
||||
/// Converts the mask to a vector of integers, where 0 represents `false` and -1
|
||||
/// represents `true`.
|
||||
#[inline]
|
||||
pub fn to_int(self) -> $bits_ty<LANES> {
|
||||
self.0.to_int()
|
||||
}
|
||||
|
||||
/// Tests the value of the specified lane.
|
||||
///
|
||||
/// # Safety
|
||||
/// `lane` must be less than `LANES`.
|
||||
#[inline]
|
||||
pub unsafe fn test_unchecked(&self, lane: usize) -> bool {
|
||||
self.0.test_unchecked(lane)
|
||||
}
|
||||
|
||||
/// Tests the value of the specified lane.
|
||||
///
|
||||
/// # Panics
|
||||
/// Panics if `lane` is greater than or equal to the number of lanes in the vector.
|
||||
#[inline]
|
||||
pub fn test(&self, lane: usize) -> bool {
|
||||
self.0.test(lane)
|
||||
assert!(lane < LANES, "lane index out of range");
|
||||
unsafe { self.test_unchecked(lane) }
|
||||
}
|
||||
|
||||
/// Sets the value of the specified lane.
|
||||
///
|
||||
/// # Safety
|
||||
/// `lane` must be less than `LANES`.
|
||||
#[inline]
|
||||
pub unsafe fn set_unchecked(&mut self, lane: usize, value: bool) {
|
||||
self.0.set_unchecked(lane, value);
|
||||
}
|
||||
|
||||
/// Sets the value of the specified lane.
|
||||
@ -76,7 +122,44 @@ pub fn test(&self, lane: usize) -> bool {
|
||||
/// Panics if `lane` is greater than or equal to the number of lanes in the vector.
|
||||
#[inline]
|
||||
pub fn set(&mut self, lane: usize, value: bool) {
|
||||
self.0.set(lane, value);
|
||||
assert!(lane < LANES, "lane index out of range");
|
||||
unsafe { self.set_unchecked(lane, value); }
|
||||
}
|
||||
}
|
||||
|
||||
impl ToBitMask for $name<1> {
|
||||
fn to_bitmask(self) -> u64 {
|
||||
self.0.to_bitmask()
|
||||
}
|
||||
}
|
||||
|
||||
impl ToBitMask for $name<2> {
|
||||
fn to_bitmask(self) -> u64 {
|
||||
self.0.to_bitmask()
|
||||
}
|
||||
}
|
||||
|
||||
impl ToBitMask for $name<4> {
|
||||
fn to_bitmask(self) -> u64 {
|
||||
self.0.to_bitmask()
|
||||
}
|
||||
}
|
||||
|
||||
impl ToBitMask for $name<8> {
|
||||
fn to_bitmask(self) -> u64 {
|
||||
self.0.to_bitmask()
|
||||
}
|
||||
}
|
||||
|
||||
impl ToBitMask for $name<16> {
|
||||
fn to_bitmask(self) -> u64 {
|
||||
self.0.to_bitmask()
|
||||
}
|
||||
}
|
||||
|
||||
impl ToBitMask for $name<32> {
|
||||
fn to_bitmask(self) -> u64 {
|
||||
self.0.to_bitmask()
|
||||
}
|
||||
}
|
||||
|
||||
@ -147,10 +230,12 @@ fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
|
||||
|
||||
impl<const LANES: usize> core::fmt::Debug for $name<LANES>
|
||||
where
|
||||
$bits_ty<LANES>: LanesAtMost32,
|
||||
$bits_ty<LANES>: crate::LanesAtMost32,
|
||||
{
|
||||
fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
|
||||
core::fmt::Debug::fmt(&self.0, f)
|
||||
f.debug_list()
|
||||
.entries((0..LANES).map(|lane| self.test(lane)))
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -56,6 +56,23 @@ fn all() {
|
||||
v.set(2, true);
|
||||
assert!(!v.all());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn roundtrip_int_conversion() {
|
||||
let values = [true, false, false, true, false, false, true, false];
|
||||
let mask = core_simd::$name::<8>::from_array(values);
|
||||
let int = mask.to_int();
|
||||
assert_eq!(int.to_array(), [-1, 0, 0, -1, 0, 0, -1, 0]);
|
||||
assert_eq!(core_simd::$name::<8>::from_int(int), mask);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn to_bitmask() {
|
||||
use core_simd::ToBitMask;
|
||||
let values = [true, false, false, true, false, false, true, false];
|
||||
let mask = core_simd::$name::<8>::from_array(values);
|
||||
assert_eq!(mask.to_bitmask(), 0b01001001);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user