Work around simd_bitmask limitations

This commit is contained in:
Caleb Zulawski 2023-11-17 10:15:12 -05:00
parent 4ca9f04db5
commit 082e3c8a5d
3 changed files with 90 additions and 25 deletions

View File

@ -207,40 +207,108 @@ where
} }
#[inline] #[inline]
pub(crate) fn to_bitmask_integer(self) -> u64 { unsafe fn to_bitmask_impl<U: ReverseBits, const M: usize>(self) -> U
let resized = self.to_int().extend::<64>(T::FALSE); where
LaneCount<M>: SupportedLaneCount,
{
let resized = self.to_int().resize::<M>(T::FALSE);
// SAFETY: `resized` is an integer vector with length 64 // Safety: `resized` is an integer vector with length M, which must match T
let bitmask: u64 = unsafe { intrinsics::simd_bitmask(resized) }; let bitmask: U = unsafe { intrinsics::simd_bitmask(resized) };
// LLVM assumes bit order should match endianness // LLVM assumes bit order should match endianness
if cfg!(target_endian = "big") { if cfg!(target_endian = "big") {
bitmask.reverse_bits() bitmask.reverse_bits(M)
} else { } else {
bitmask bitmask
} }
} }
#[inline] #[inline]
pub(crate) fn from_bitmask_integer(bitmask: u64) -> Self { unsafe fn from_bitmask_impl<U: ReverseBits, const M: usize>(bitmask: U) -> Self
where
LaneCount<M>: SupportedLaneCount,
{
// LLVM assumes bit order should match endianness // LLVM assumes bit order should match endianness
let bitmask = if cfg!(target_endian = "big") { let bitmask = if cfg!(target_endian = "big") {
bitmask.reverse_bits() bitmask.reverse_bits(M)
} else { } else {
bitmask bitmask
}; };
// SAFETY: `mask` is the correct bitmask type for a u64 bitmask // SAFETY: `mask` is the correct bitmask type for a u64 bitmask
let mask: Simd<T, 64> = unsafe { let mask: Simd<T, M> = unsafe {
intrinsics::simd_select_bitmask( intrinsics::simd_select_bitmask(
bitmask, bitmask,
Simd::<T, 64>::splat(T::TRUE), Simd::<T, M>::splat(T::TRUE),
Simd::<T, 64>::splat(T::FALSE), Simd::<T, M>::splat(T::FALSE),
) )
}; };
// SAFETY: `mask` only contains `T::TRUE` or `T::FALSE` // SAFETY: `mask` only contains `T::TRUE` or `T::FALSE`
unsafe { Self::from_int_unchecked(mask.extend::<N>(T::FALSE)) } unsafe { Self::from_int_unchecked(mask.resize::<N>(T::FALSE)) }
}
#[inline]
pub(crate) fn to_bitmask_integer(self) -> u64 {
// TODO modify simd_bitmask to zero-extend output, making this unnecessary
macro_rules! bitmask {
{ $($ty:ty: $($len:literal),*;)* } => {
match N {
$($(
// Safety: bitmask matches length
$len => unsafe { self.to_bitmask_impl::<$ty, $len>() as u64 },
)*)*
// Safety: bitmask matches length
_ => unsafe { self.to_bitmask_impl::<u64, 64>() },
}
}
}
#[cfg(all_lane_counts)]
bitmask! {
u8: 1, 2, 3, 4, 5, 6, 7, 8;
u16: 9, 10, 11, 12, 13, 14, 15, 16;
u32: 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32;
u64: 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64;
}
#[cfg(not(all_lane_counts))]
bitmask! {
u8: 1, 2, 4, 8;
u16: 16;
u32: 32;
u64: 64;
}
}
/// Converts a `u64` bitmask into `Self`, dispatching on the lane count `N`.
///
/// For each listed length the input is cast with `as` to the integer type on that
/// row (truncating when that type is narrower than `u64`) and handed to
/// `from_bitmask_impl`.
#[inline]
pub(crate) fn from_bitmask_integer(bitmask: u64) -> Self {
// TODO modify simd_select_bitmask to truncate input, making this unnecessary
//
// Expands to a `match` on `N` with one arm per listed lane count, plus a
// catch-all arm that passes the `u64` through unchanged with 64 lanes.
// NOTE(review): the precise safety contract of `from_bitmask_impl` is not
// documented here — presumably the integer type must be the one the intrinsic
// expects for that lane count; confirm against its definition.
macro_rules! bitmask {
{ $($ty:ty: $($len:literal),*;)* } => {
match N {
$($(
// SAFETY: `$ty` is the bitmask integer this table pairs with `$len` lanes
$len => unsafe { Self::from_bitmask_impl::<$ty, $len>(bitmask as $ty) },
)*)*
// SAFETY: a `u64` bitmask is paired with 64 lanes
_ => unsafe { Self::from_bitmask_impl::<u64, 64>(bitmask) },
}
}
}
// With `all_lane_counts`, every length from 1 through 64 gets its own arm.
#[cfg(all_lane_counts)]
bitmask! {
u8: 1, 2, 3, 4, 5, 6, 7, 8;
u16: 9, 10, 11, 12, 13, 14, 15, 16;
u32: 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32;
u64: 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64;
}
// Otherwise only the power-of-two lengths are listed.
#[cfg(not(all_lane_counts))]
bitmask! {
u8: 1, 2, 4, 8;
u16: 16;
u32: 32;
u64: 64;
}
} }
#[inline] #[inline]

View File

@ -350,9 +350,9 @@ where
) )
} }
/// Extend a vector. /// Resize a vector.
/// ///
/// Extends the length of a vector, setting the new elements to `value`. /// If `M` > `N`, extends the length of a vector, setting the new elements to `value`.
/// If `M` < `N`, truncates the vector to the first `M` elements. /// If `M` < `N`, truncates the vector to the first `M` elements.
/// ///
/// ``` /// ```
@ -361,17 +361,17 @@ where
/// # #[cfg(not(feature = "as_crate"))] use core::simd; /// # #[cfg(not(feature = "as_crate"))] use core::simd;
/// # use simd::u32x4; /// # use simd::u32x4;
/// let x = u32x4::from_array([0, 1, 2, 3]); /// let x = u32x4::from_array([0, 1, 2, 3]);
/// assert_eq!(x.extend::<8>(9).to_array(), [0, 1, 2, 3, 9, 9, 9, 9]); /// assert_eq!(x.resize::<8>(9).to_array(), [0, 1, 2, 3, 9, 9, 9, 9]);
/// assert_eq!(x.extend::<2>(9).to_array(), [0, 1]); /// assert_eq!(x.resize::<2>(9).to_array(), [0, 1]);
/// ``` /// ```
#[inline] #[inline]
#[must_use = "method returns a new vector and does not mutate the original inputs"] #[must_use = "method returns a new vector and does not mutate the original inputs"]
pub fn extend<const M: usize>(self, value: T) -> Simd<T, M> pub fn resize<const M: usize>(self, value: T) -> Simd<T, M>
where where
LaneCount<M>: SupportedLaneCount, LaneCount<M>: SupportedLaneCount,
{ {
struct Extend<const N: usize>; struct Resize<const N: usize>;
impl<const N: usize, const M: usize> Swizzle<M> for Extend<N> { impl<const N: usize, const M: usize> Swizzle<M> for Resize<N> {
const INDEX: [usize; M] = const { const INDEX: [usize; M] = const {
let mut index = [0; M]; let mut index = [0; M];
let mut i = 0; let mut i = 0;
@ -382,6 +382,6 @@ where
index index
}; };
} }
Extend::<N>::concat_swizzle(self, Simd::splat(value)) Resize::<N>::concat_swizzle(self, Simd::splat(value))
} }
} }

View File

@ -13,7 +13,7 @@ macro_rules! test_mask_api {
#[cfg(target_arch = "wasm32")] #[cfg(target_arch = "wasm32")]
use wasm_bindgen_test::*; use wasm_bindgen_test::*;
use core_simd::simd::{Mask, Simd}; use core_simd::simd::Mask;
#[test] #[test]
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)] #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
@ -124,17 +124,14 @@ macro_rules! test_mask_api {
#[test] #[test]
fn roundtrip_bitmask_vector_conversion() { fn roundtrip_bitmask_vector_conversion() {
use core_simd::simd::ToBytes;
let values = [ let values = [
true, false, false, true, false, false, true, false, true, false, false, true, false, false, true, false,
true, true, false, false, false, false, false, true, true, true, false, false, false, false, false, true,
]; ];
let mask = Mask::<$type, 16>::from_array(values); let mask = Mask::<$type, 16>::from_array(values);
let bitmask = mask.to_bitmask_vector(); let bitmask = mask.to_bitmask_vector();
if core::mem::size_of::<$type>() == 1 { assert_eq!(bitmask.resize::<2>(0).to_ne_bytes()[..2], [0b01001001, 0b10000011]);
assert_eq!(bitmask, Simd::from_array([0b01001001 as _, 0b10000011 as _, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ]));
} else {
assert_eq!(bitmask, Simd::from_array([0b1000001101001001 as _, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]));
}
assert_eq!(Mask::<$type, 16>::from_bitmask_vector(bitmask), mask); assert_eq!(Mask::<$type, 16>::from_bitmask_vector(bitmask), mask);
} }
} }