Update bitmask API

This commit is contained in:
Caleb Zulawski 2021-04-19 04:31:43 +00:00
parent da42aa5403
commit eec42808aa
6 changed files with 196 additions and 160 deletions

View File

@ -76,6 +76,9 @@
pub(crate) fn simd_reduce_and<T, U>(x: T) -> U;
pub(crate) fn simd_reduce_or<T, U>(x: T) -> U;
pub(crate) fn simd_reduce_xor<T, U>(x: T) -> U;
// truncate integer vector to bitmask
pub(crate) fn simd_bitmask<T, U>(x: T) -> U;
}
#[cfg(feature = "std")]

View File

@ -1,14 +1,38 @@
/// Implemented for vectors that are supported by the implementation.
pub trait LanesAtMost32 {}
pub trait LanesAtMost32: sealed::Sealed {
#[doc(hidden)]
type BitMask: Into<u64>;
}
mod sealed {
pub trait Sealed {}
}
macro_rules! impl_for {
{ $name:ident } => {
impl LanesAtMost32 for $name<1> {}
impl LanesAtMost32 for $name<2> {}
impl LanesAtMost32 for $name<4> {}
impl LanesAtMost32 for $name<8> {}
impl LanesAtMost32 for $name<16> {}
impl LanesAtMost32 for $name<32> {}
impl<const LANES: usize> sealed::Sealed for $name<LANES>
where
$name<LANES>: LanesAtMost32,
{}
impl LanesAtMost32 for $name<1> {
type BitMask = u8;
}
impl LanesAtMost32 for $name<2> {
type BitMask = u8;
}
impl LanesAtMost32 for $name<4> {
type BitMask = u8;
}
impl LanesAtMost32 for $name<8> {
type BitMask = u8;
}
impl LanesAtMost32 for $name<16> {
type BitMask = u16;
}
impl LanesAtMost32 for $name<32> {
type BitMask = u32;
}
}
}

View File

@ -1,13 +1,9 @@
use crate::LanesAtMost32;
/// A mask where each lane is represented by a single bit.
#[derive(Copy, Clone, Debug, PartialOrd, PartialEq, Ord, Eq, Hash)]
#[repr(transparent)]
pub struct BitMask<const LANES: usize>(u64)
pub struct BitMask<const LANES: usize>(u64);
impl<const LANES: usize> BitMask<LANES>
where
Self: LanesAtMost32,
{
#[inline]
pub fn splat(value: bool) -> Self {
@ -25,13 +21,50 @@ pub unsafe fn test_unchecked(&self, lane: usize) -> bool {
#[inline]
pub unsafe fn set_unchecked(&mut self, lane: usize, value: bool) {
self.0 ^= ((value ^ self.test(lane)) as u64) << lane
self.0 ^= ((value ^ self.test_unchecked(lane)) as u64) << lane
}
#[inline]
pub fn to_int<V, T>(self) -> V
where
V: Default + AsMut<[T; LANES]>,
T: From<i8>,
{
// TODO this should be an intrinsic sign-extension
let mut v = V::default();
for i in 0..LANES {
let lane = unsafe { self.test_unchecked(i) };
v.as_mut()[i] = (-(lane as i8)).into();
}
v
}
#[inline]
pub unsafe fn from_int_unchecked<V>(value: V) -> Self
where
V: crate::LanesAtMost32,
{
let mask: V::BitMask = crate::intrinsics::simd_bitmask(value);
Self(mask.into())
}
#[inline]
pub fn to_bitmask(self) -> u64 {
self.0
}
#[inline]
pub fn any(self) -> bool {
self != Self::splat(false)
}
#[inline]
pub fn all(self) -> bool {
self == Self::splat(true)
}
}
impl<const LANES: usize> core::ops::BitAnd for BitMask<LANES>
where
Self: LanesAtMost32,
{
type Output = Self;
#[inline]
@ -41,8 +74,6 @@ fn bitand(self, rhs: Self) -> Self {
}
impl<const LANES: usize> core::ops::BitAnd<bool> for BitMask<LANES>
where
Self: LanesAtMost32,
{
type Output = Self;
#[inline]
@ -52,8 +83,6 @@ fn bitand(self, rhs: bool) -> Self {
}
impl<const LANES: usize> core::ops::BitAnd<BitMask<LANES>> for bool
where
BitMask<LANES>: LanesAtMost32,
{
type Output = BitMask<LANES>;
#[inline]
@ -63,8 +92,6 @@ fn bitand(self, rhs: BitMask<LANES>) -> BitMask<LANES> {
}
impl<const LANES: usize> core::ops::BitOr for BitMask<LANES>
where
Self: LanesAtMost32,
{
type Output = Self;
#[inline]
@ -73,31 +100,7 @@ fn bitor(self, rhs: Self) -> Self {
}
}
impl<const LANES: usize> core::ops::BitOr<bool> for BitMask<LANES>
where
Self: LanesAtMost32,
{
type Output = Self;
#[inline]
fn bitor(self, rhs: bool) -> Self {
self | Self::splat(rhs)
}
}
impl<const LANES: usize> core::ops::BitOr<BitMask<LANES>> for bool
where
BitMask<LANES>: LanesAtMost32,
{
type Output = BitMask<LANES>;
#[inline]
fn bitor(self, rhs: BitMask<LANES>) -> BitMask<LANES> {
BitMask::<LANES>::splat(self) | rhs
}
}
impl<const LANES: usize> core::ops::BitXor for BitMask<LANES>
where
Self: LanesAtMost32,
{
type Output = Self;
#[inline]
@ -106,42 +109,16 @@ fn bitxor(self, rhs: Self) -> Self::Output {
}
}
impl<const LANES: usize> core::ops::BitXor<bool> for BitMask<LANES>
where
Self: LanesAtMost32,
{
type Output = Self;
#[inline]
fn bitxor(self, rhs: bool) -> Self::Output {
self ^ Self::splat(rhs)
}
}
impl<const LANES: usize> core::ops::BitXor<BitMask<LANES>> for bool
where
BitMask<LANES>: LanesAtMost32,
{
type Output = BitMask<LANES>;
#[inline]
fn bitxor(self, rhs: BitMask<LANES>) -> Self::Output {
BitMask::<LANES>::splat(self) ^ rhs
}
}
impl<const LANES: usize> core::ops::Not for BitMask<LANES>
where
Self: LanesAtMost32,
{
type Output = BitMask<LANES>;
#[inline]
fn not(self) -> Self::Output {
Self(!self.0)
Self(!self.0) & Self::splat(true)
}
}
impl<const LANES: usize> core::ops::BitAndAssign for BitMask<LANES>
where
Self: LanesAtMost32,
{
#[inline]
fn bitand_assign(&mut self, rhs: Self) {
@ -149,19 +126,7 @@ fn bitand_assign(&mut self, rhs: Self) {
}
}
impl<const LANES: usize> core::ops::BitAndAssign<bool> for BitMask<LANES>
where
Self: LanesAtMost32,
{
#[inline]
fn bitand_assign(&mut self, rhs: bool) {
*self &= Self::splat(rhs);
}
}
impl<const LANES: usize> core::ops::BitOrAssign for BitMask<LANES>
where
Self: LanesAtMost32,
{
#[inline]
fn bitor_assign(&mut self, rhs: Self) {
@ -169,19 +134,7 @@ fn bitor_assign(&mut self, rhs: Self) {
}
}
impl<const LANES: usize> core::ops::BitOrAssign<bool> for BitMask<LANES>
where
Self: LanesAtMost32,
{
#[inline]
fn bitor_assign(&mut self, rhs: bool) {
*self |= Self::splat(rhs);
}
}
impl<const LANES: usize> core::ops::BitXorAssign for BitMask<LANES>
where
Self: LanesAtMost32,
{
#[inline]
fn bitxor_assign(&mut self, rhs: Self) {
@ -189,12 +142,9 @@ fn bitxor_assign(&mut self, rhs: Self) {
}
}
impl<const LANES: usize> core::ops::BitXorAssign<bool> for BitMask<LANES>
where
Self: LanesAtMost32,
{
#[inline]
fn bitxor_assign(&mut self, rhs: bool) {
*self ^= Self::splat(rhs);
}
}
pub type Mask8<const LANES: usize> = BitMask<LANES>;
pub type Mask16<const LANES: usize> = BitMask<LANES>;
pub type Mask32<const LANES: usize> = BitMask<LANES>;
pub type Mask64<const LANES: usize> = BitMask<LANES>;
pub type Mask128<const LANES: usize> = BitMask<LANES>;
pub type MaskSize<const LANES: usize> = BitMask<LANES>;

View File

@ -46,14 +46,12 @@ pub fn splat(value: bool) -> Self {
}
#[inline]
pub fn test(&self, lane: usize) -> bool {
assert!(lane < LANES, "lane index out of range");
pub unsafe fn test_unchecked(&self, lane: usize) -> bool {
self.0[lane] == -1
}
#[inline]
pub fn set(&mut self, lane: usize, value: bool) {
assert!(lane < LANES, "lane index out of range");
pub unsafe fn set_unchecked(&mut self, lane: usize, value: bool) {
self.0[lane] = if value {
-1
} else {
@ -70,6 +68,12 @@ pub fn to_int(self) -> crate::$type<LANES> {
pub unsafe fn from_int_unchecked(value: crate::$type<LANES>) -> Self {
Self(value)
}
#[inline]
pub fn to_bitmask(self) -> u64 {
let mask: <crate::$type<LANES> as crate::LanesAtMost32>::BitMask = unsafe { crate::intrinsics::simd_bitmask(self.0) };
mask.into()
}
}
impl<const LANES: usize> core::convert::From<$name<LANES>> for crate::$type<LANES>
@ -81,53 +85,6 @@ fn from(value: $name<LANES>) -> Self {
}
}
impl<const LANES: usize> core::fmt::Debug for $name<LANES>
where
crate::$type<LANES>: crate::LanesAtMost32,
{
fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
f.debug_list()
.entries((0..LANES).map(|lane| self.test(lane)))
.finish()
}
}
impl<const LANES: usize> core::fmt::Binary for $name<LANES>
where
crate::$type<LANES>: crate::LanesAtMost32,
{
fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
core::fmt::Binary::fmt(&self.0, f)
}
}
impl<const LANES: usize> core::fmt::Octal for $name<LANES>
where
crate::$type<LANES>: crate::LanesAtMost32,
{
fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
core::fmt::Octal::fmt(&self.0, f)
}
}
impl<const LANES: usize> core::fmt::LowerHex for $name<LANES>
where
crate::$type<LANES>: crate::LanesAtMost32,
{
fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
core::fmt::LowerHex::fmt(&self.0, f)
}
}
impl<const LANES: usize> core::fmt::UpperHex for $name<LANES>
where
crate::$type<LANES>: crate::LanesAtMost32,
{
fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
core::fmt::UpperHex::fmt(&self.0, f)
}
}
impl<const LANES: usize> core::ops::BitAnd for $name<LANES>
where
crate::$type<LANES>: crate::LanesAtMost32,

View File

@ -8,6 +8,12 @@
use crate::{LanesAtMost32, SimdI16, SimdI32, SimdI64, SimdI8, SimdIsize};
/// Converts masks to bitmasks, with one bit set for each lane.
pub trait ToBitMask {
/// Converts this mask to a bitmask.
fn to_bitmask(self) -> u64;
}
macro_rules! define_opaque_mask {
{
$(#[$attr:meta])*
@ -61,13 +67,53 @@ pub unsafe fn from_int_unchecked(value: $bits_ty<LANES>) -> Self {
Self(<$inner_ty>::from_int_unchecked(value))
}
/// Converts a vector of integers to a mask, where 0 represents `false` and -1
/// represents `true`.
///
/// # Panics
/// Panics if any lane is not 0 or -1.
#[inline]
pub fn from_int(value: $bits_ty<LANES>) -> Self {
assert!(
(value.lanes_eq($bits_ty::splat(0)) | value.lanes_eq($bits_ty::splat(-1))).all(),
"all values must be either 0 or -1",
);
unsafe { Self::from_int_unchecked(value) }
}
/// Converts the mask to a vector of integers, where 0 represents `false` and -1
/// represents `true`.
#[inline]
pub fn to_int(self) -> $bits_ty<LANES> {
self.0.to_int()
}
/// Tests the value of the specified lane.
///
/// # Safety
/// `lane` must be less than `LANES`.
#[inline]
pub unsafe fn test_unchecked(&self, lane: usize) -> bool {
self.0.test_unchecked(lane)
}
/// Tests the value of the specified lane.
///
/// # Panics
/// Panics if `lane` is greater than or equal to the number of lanes in the vector.
#[inline]
pub fn test(&self, lane: usize) -> bool {
self.0.test(lane)
assert!(lane < LANES, "lane index out of range");
unsafe { self.test_unchecked(lane) }
}
/// Sets the value of the specified lane.
///
/// # Safety
/// `lane` must be less than `LANES`.
#[inline]
pub unsafe fn set_unchecked(&mut self, lane: usize, value: bool) {
self.0.set_unchecked(lane, value);
}
/// Sets the value of the specified lane.
@ -76,7 +122,44 @@ pub fn test(&self, lane: usize) -> bool {
/// Panics if `lane` is greater than or equal to the number of lanes in the vector.
#[inline]
pub fn set(&mut self, lane: usize, value: bool) {
self.0.set(lane, value);
assert!(lane < LANES, "lane index out of range");
unsafe { self.set_unchecked(lane, value); }
}
}
impl ToBitMask for $name<1> {
fn to_bitmask(self) -> u64 {
self.0.to_bitmask()
}
}
impl ToBitMask for $name<2> {
fn to_bitmask(self) -> u64 {
self.0.to_bitmask()
}
}
impl ToBitMask for $name<4> {
fn to_bitmask(self) -> u64 {
self.0.to_bitmask()
}
}
impl ToBitMask for $name<8> {
fn to_bitmask(self) -> u64 {
self.0.to_bitmask()
}
}
impl ToBitMask for $name<16> {
fn to_bitmask(self) -> u64 {
self.0.to_bitmask()
}
}
impl ToBitMask for $name<32> {
fn to_bitmask(self) -> u64 {
self.0.to_bitmask()
}
}
@ -147,10 +230,12 @@ fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
impl<const LANES: usize> core::fmt::Debug for $name<LANES>
where
$bits_ty<LANES>: LanesAtMost32,
$bits_ty<LANES>: crate::LanesAtMost32,
{
fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
core::fmt::Debug::fmt(&self.0, f)
f.debug_list()
.entries((0..LANES).map(|lane| self.test(lane)))
.finish()
}
}

View File

@ -56,6 +56,23 @@ fn all() {
v.set(2, true);
assert!(!v.all());
}
#[test]
fn roundtrip_int_conversion() {
let values = [true, false, false, true, false, false, true, false];
let mask = core_simd::$name::<8>::from_array(values);
let int = mask.to_int();
assert_eq!(int.to_array(), [-1, 0, 0, -1, 0, 0, -1, 0]);
assert_eq!(core_simd::$name::<8>::from_int(int), mask);
}
#[test]
fn to_bitmask() {
use core_simd::ToBitMask;
let values = [true, false, false, true, false, false, true, false];
let mask = core_simd::$name::<8>::from_array(values);
assert_eq!(mask.to_bitmask(), 0b01001001);
}
}
}
}