diff --git a/crates/core_simd/src/masks.rs b/crates/core_simd/src/masks.rs index dcec336cfaf..e65548a3287 100644 --- a/crates/core_simd/src/masks.rs +++ b/crates/core_simd/src/masks.rs @@ -13,7 +13,10 @@ mod mask_impl; mod to_bitmask; -pub use to_bitmask::ToBitMask; +pub use to_bitmask::{ToBitMask, ToBitMaskArray}; + +#[cfg(feature = "generic_const_exprs")] +pub use to_bitmask::bitmask_len; use crate::simd::{intrinsics, LaneCount, Simd, SimdElement, SimdPartialEq, SupportedLaneCount}; use core::cmp::Ordering; diff --git a/crates/core_simd/src/masks/bitmask.rs b/crates/core_simd/src/masks/bitmask.rs index ec4dd357ee9..2e2c0a45c51 100644 --- a/crates/core_simd/src/masks/bitmask.rs +++ b/crates/core_simd/src/masks/bitmask.rs @@ -1,7 +1,7 @@ #![allow(unused_imports)] use super::MaskElement; use crate::simd::intrinsics; -use crate::simd::{LaneCount, Simd, SupportedLaneCount, ToBitMask}; +use crate::simd::{LaneCount, Simd, SupportedLaneCount, ToBitMask, ToBitMaskArray}; use core::marker::PhantomData; /// A mask where each lane is represented by a single bit. @@ -115,6 +115,24 @@ where unsafe { Self(intrinsics::simd_bitmask(value), PhantomData) } } + #[inline] + #[must_use = "method returns a new array and does not mutate the original value"] + pub fn to_bitmask_array(self) -> [u8; N] { + assert!(core::mem::size_of::() == N); + + // Safety: converting an integer to an array of bytes of the same size is safe + unsafe { core::mem::transmute_copy(&self.0) } + } + + #[inline] + #[must_use = "method returns a new mask and does not mutate the original value"] + pub fn from_bitmask_array(bitmask: [u8; N]) -> Self { + assert!(core::mem::size_of::() == N); + + // Safety: converting an array of bytes to an integer of the same size is safe + Self(unsafe { core::mem::transmute_copy(&bitmask) }, PhantomData) + } + #[inline] pub fn to_bitmask_integer(self) -> U where diff --git a/crates/core_simd/src/masks/full_masks.rs b/crates/core_simd/src/masks/full_masks.rs index efa688b128f..b1c3b2b88ad 100644 --- a/crates/core_simd/src/masks/full_masks.rs +++ b/crates/core_simd/src/masks/full_masks.rs @@ -2,7 +2,7 @@ use super::MaskElement; use crate::simd::intrinsics; -use crate::simd::{LaneCount, Simd, SupportedLaneCount, ToBitMask}; +use crate::simd::{LaneCount, Simd, SupportedLaneCount, ToBitMask, ToBitMaskArray}; #[repr(transparent)] pub struct Mask(Simd) @@ -139,6 +139,72 @@ where unsafe { Mask(intrinsics::simd_cast(self.0)) } } + #[inline] + #[must_use = "method returns a new array and does not mutate the original value"] + pub fn to_bitmask_array(self) -> [u8; N] + where + super::Mask: ToBitMaskArray, + [(); as ToBitMaskArray>::BYTES]: Sized, + { + assert_eq!( as ToBitMaskArray>::BYTES, N); + + // Safety: N is the correct bitmask size + // + // The transmute below allows this function to be marked safe, since it will prevent + // monomorphization errors in the case of an incorrect size. + unsafe { + // Compute the bitmask + let bitmask: [u8; as ToBitMaskArray>::BYTES] = + intrinsics::simd_bitmask(self.0); + + // Transmute to the return type, previously asserted to be the same size + let mut bitmask: [u8; N] = core::mem::transmute_copy(&bitmask); + + // LLVM assumes bit order should match endianness + if cfg!(target_endian = "big") { + for x in bitmask.as_mut() { + *x = x.reverse_bits(); + } + }; + + bitmask + } + } + + #[inline] + #[must_use = "method returns a new mask and does not mutate the original value"] + pub fn from_bitmask_array(mut bitmask: [u8; N]) -> Self + where + super::Mask: ToBitMaskArray, + [(); as ToBitMaskArray>::BYTES]: Sized, + { + assert_eq!( as ToBitMaskArray>::BYTES, N); + + // Safety: N is the correct bitmask size + // + // The transmute below allows this function to be marked safe, since it will prevent + // monomorphization errors in the case of an incorrect size. + unsafe { + // LLVM assumes bit order should match endianness + if cfg!(target_endian = "big") { + for x in bitmask.as_mut() { + *x = x.reverse_bits(); + } + } + + // Transmute to the bitmask type, previously asserted to be the same size + let bitmask: [u8; as ToBitMaskArray>::BYTES] = + core::mem::transmute_copy(&bitmask); + + // Compute the regular mask + Self::from_int_unchecked(intrinsics::simd_select_bitmask( + bitmask, + Self::splat(true).to_int(), + Self::splat(false).to_int(), + )) + } + } + #[inline] pub(crate) fn to_bitmask_integer(self) -> U where diff --git a/crates/core_simd/src/masks/to_bitmask.rs b/crates/core_simd/src/masks/to_bitmask.rs index c263f6a4eec..ee229fc7a44 100644 --- a/crates/core_simd/src/masks/to_bitmask.rs +++ b/crates/core_simd/src/masks/to_bitmask.rs @@ -31,6 +31,24 @@ pub unsafe trait ToBitMask: Sealed { fn from_bitmask(bitmask: Self::BitMask) -> Self; } +/// Converts masks to and from byte array bitmasks. +/// +/// Each bit of the bitmask corresponds to a mask lane, starting with the LSB of the first byte. +/// +/// # Safety +/// This trait is `unsafe` and sealed, since the `BYTES` value must match the number of lanes in +/// the mask. +pub unsafe trait ToBitMaskArray: Sealed { + /// The length of the bitmask array. + const BYTES: usize; + + /// Converts a mask to a bitmask. + fn to_bitmask_array(self) -> [u8; Self::BYTES]; + + /// Converts a bitmask to a mask. + fn from_bitmask_array(bitmask: [u8; Self::BYTES]) -> Self; +} + macro_rules! impl_integer_intrinsic { { $(unsafe impl ToBitMask for Mask<_, $lanes:literal>)* } => { $( @@ -58,3 +76,23 @@ impl_integer_intrinsic! { unsafe impl ToBitMask for Mask<_, 32> unsafe impl ToBitMask for Mask<_, 64> } + +/// Returns the minimum numnber of bytes in a bitmask with `lanes` lanes. +pub const fn bitmask_len(lanes: usize) -> usize { + (lanes + 7) / 8 +} + +unsafe impl ToBitMaskArray for Mask +where + LaneCount: SupportedLaneCount, +{ + const BYTES: usize = bitmask_len(LANES); + + fn to_bitmask_array(self) -> [u8; Self::BYTES] { + self.0.to_bitmask_array() + } + + fn from_bitmask_array(bitmask: [u8; Self::BYTES]) -> Self { + Mask(mask_impl::Mask::from_bitmask_array(bitmask)) + } +} diff --git a/crates/core_simd/tests/masks.rs b/crates/core_simd/tests/masks.rs index 3a0493d4ee6..6150124b8ca 100644 --- a/crates/core_simd/tests/masks.rs +++ b/crates/core_simd/tests/masks.rs @@ -122,6 +122,19 @@ macro_rules! test_mask_api { cast_impl::(); cast_impl::(); } + + #[test] + fn roundtrip_bitmask_array_conversion() { + use core_simd::ToBitMaskArray; + let values = [ + true, false, false, true, false, false, true, false, + true, true, false, false, false, false, false, true, + ]; + let mask = core_simd::Mask::<$type, 16>::from_array(values); + let bitmask = mask.to_bitmask_array(); + assert_eq!(bitmask, [0b01001001, 0b10000011]); + assert_eq!(core_simd::Mask::<$type, 16>::from_bitmask_array(bitmask), mask); + } } } }