Make internal mask implementation safe
This commit is contained in:
parent
11c3eefa35
commit
20fa4b7623
@ -1,7 +1,7 @@
|
||||
#![allow(unused_imports)]
|
||||
use super::MaskElement;
|
||||
use crate::simd::intrinsics;
|
||||
use crate::simd::{LaneCount, Simd, SupportedLaneCount};
|
||||
use crate::simd::{LaneCount, Simd, SupportedLaneCount, ToBitMask};
|
||||
use core::marker::PhantomData;
|
||||
|
||||
/// A mask where each lane is represented by a single bit.
|
||||
@ -116,13 +116,20 @@ where
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub unsafe fn to_bitmask_integer<U>(self) -> U {
|
||||
pub fn to_bitmask_integer<U>(self) -> U
|
||||
where
|
||||
super::Mask<T, LANES>: ToBitMask<BitMask = U>,
|
||||
{
|
||||
// Safety: these are the same types
|
||||
unsafe { core::mem::transmute_copy(&self.0) }
|
||||
}
|
||||
|
||||
// Safety: U must be the integer with the exact number of bits required to hold the bitmask for
|
||||
#[inline]
|
||||
pub unsafe fn from_bitmask_integer<U>(bitmask: U) -> Self {
|
||||
pub fn from_bitmask_integer<U>(bitmask: U) -> Self
|
||||
where
|
||||
super::Mask<T, LANES>: ToBitMask<BitMask = U>,
|
||||
{
|
||||
// Safety: these are the same types
|
||||
unsafe { Self(core::mem::transmute_copy(&bitmask), PhantomData) }
|
||||
}
|
||||
|
||||
|
@ -2,7 +2,7 @@
|
||||
|
||||
use super::MaskElement;
|
||||
use crate::simd::intrinsics;
|
||||
use crate::simd::{LaneCount, Simd, SupportedLaneCount};
|
||||
use crate::simd::{LaneCount, Simd, SupportedLaneCount, ToBitMask};
|
||||
|
||||
#[repr(transparent)]
|
||||
pub struct Mask<T, const LANES: usize>(Simd<T, LANES>)
|
||||
@ -66,6 +66,23 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
// Used for bitmask bit order workaround
|
||||
pub(crate) trait ReverseBits {
|
||||
fn reverse_bits(self) -> Self;
|
||||
}
|
||||
|
||||
macro_rules! impl_reverse_bits {
|
||||
{ $($int:ty),* } => {
|
||||
$(
|
||||
impl ReverseBits for $int {
|
||||
fn reverse_bits(self) -> Self { <$int>::reverse_bits(self) }
|
||||
}
|
||||
)*
|
||||
}
|
||||
}
|
||||
|
||||
impl_reverse_bits! { u8, u16, u32, u64 }
|
||||
|
||||
impl<T, const LANES: usize> Mask<T, LANES>
|
||||
where
|
||||
T: MaskElement,
|
||||
@ -110,16 +127,34 @@ where
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub unsafe fn to_bitmask_integer<U>(self) -> U {
|
||||
// Safety: caller must only return bitmask types
|
||||
unsafe { intrinsics::simd_bitmask(self.0) }
|
||||
pub(crate) fn to_bitmask_integer<U: ReverseBits>(self) -> U
|
||||
where
|
||||
super::Mask<T, LANES>: ToBitMask<BitMask = U>,
|
||||
{
|
||||
// Safety: U is required to be the appropriate bitmask type
|
||||
let bitmask: U = unsafe { intrinsics::simd_bitmask(self.0) };
|
||||
|
||||
// LLVM assumes bit order should match endianness
|
||||
if cfg!(target_endian = "big") {
|
||||
bitmask.reverse_bits()
|
||||
} else {
|
||||
bitmask
|
||||
}
|
||||
}
|
||||
|
||||
// Safety: U must be the integer with the exact number of bits required to hold the bitmask for
|
||||
// this mask
|
||||
#[inline]
|
||||
pub unsafe fn from_bitmask_integer<U>(bitmask: U) -> Self {
|
||||
// Safety: caller must only pass bitmask types
|
||||
pub(crate) fn from_bitmask_integer<U: ReverseBits>(bitmask: U) -> Self
|
||||
where
|
||||
super::Mask<T, LANES>: ToBitMask<BitMask = U>,
|
||||
{
|
||||
// LLVM assumes bit order should match endianness
|
||||
let bitmask = if cfg!(target_endian = "big") {
|
||||
bitmask.reverse_bits()
|
||||
} else {
|
||||
bitmask
|
||||
};
|
||||
|
||||
// Safety: U is required to be the appropriate bitmask type
|
||||
unsafe {
|
||||
Self::from_int_unchecked(intrinsics::simd_select_bitmask(
|
||||
bitmask,
|
||||
|
@ -1,9 +1,26 @@
|
||||
use super::{mask_impl, Mask, MaskElement};
|
||||
use crate::simd::{LaneCount, SupportedLaneCount};
|
||||
|
||||
mod sealed {
|
||||
pub trait Sealed {}
|
||||
}
|
||||
pub use sealed::Sealed;
|
||||
|
||||
impl<T, const LANES: usize> Sealed for Mask<T, LANES>
|
||||
where
|
||||
T: MaskElement,
|
||||
LaneCount<LANES>: SupportedLaneCount,
|
||||
{
|
||||
}
|
||||
|
||||
/// Converts masks to and from integer bitmasks.
|
||||
///
|
||||
/// Each bit of the bitmask corresponds to a mask lane, starting with the LSB.
|
||||
pub trait ToBitMask {
|
||||
///
|
||||
/// # Safety
|
||||
/// This trait is `unsafe` and sealed, since the `BitMask` type must match the number of lanes in
|
||||
/// the mask.
|
||||
pub unsafe trait ToBitMask: Sealed {
|
||||
/// The integer bitmask type.
|
||||
type BitMask;
|
||||
|
||||
@ -14,32 +31,18 @@ pub trait ToBitMask {
|
||||
fn from_bitmask(bitmask: Self::BitMask) -> Self;
|
||||
}
|
||||
|
||||
/// Converts masks to and from byte array bitmasks.
|
||||
///
|
||||
/// Each bit of the bitmask corresponds to a mask lane, starting with the LSB of the first byte.
|
||||
pub trait ToBitMaskArray {
|
||||
/// The length of the bitmask array.
|
||||
const BYTES: usize;
|
||||
|
||||
/// Converts a mask to a bitmask.
|
||||
fn to_bitmask_array(self) -> [u8; Self::BYTES];
|
||||
|
||||
/// Converts a bitmask to a mask.
|
||||
fn from_bitmask_array(bitmask: [u8; Self::BYTES]) -> Self;
|
||||
}
|
||||
|
||||
macro_rules! impl_integer_intrinsic {
|
||||
{ $(unsafe impl ToBitMask<BitMask=$int:ty> for Mask<_, $lanes:literal>)* } => {
|
||||
$(
|
||||
impl<T: MaskElement> ToBitMask for Mask<T, $lanes> {
|
||||
unsafe impl<T: MaskElement> ToBitMask for Mask<T, $lanes> {
|
||||
type BitMask = $int;
|
||||
|
||||
fn to_bitmask(self) -> $int {
|
||||
unsafe { self.0.to_bitmask_integer() }
|
||||
self.0.to_bitmask_integer()
|
||||
}
|
||||
|
||||
fn from_bitmask(bitmask: $int) -> Self {
|
||||
unsafe { Self(mask_impl::Mask::from_bitmask_integer(bitmask)) }
|
||||
Self(mask_impl::Mask::from_bitmask_integer(bitmask))
|
||||
}
|
||||
}
|
||||
)*
|
||||
|
Loading…
x
Reference in New Issue
Block a user