Implement select generically

This commit is contained in:
Caleb Zulawski 2021-08-07 05:19:06 +00:00
parent de13b20b27
commit ea0280539c
2 changed files with 77 additions and 79 deletions

View File

@ -1,3 +1,5 @@
use crate::{LaneCount, Mask, MaskElement, Simd, SimdElement, SupportedLaneCount};
mod sealed {
pub trait Sealed {}
}
@ -9,79 +11,75 @@ pub trait Select<Mask>: Sealed {
fn select(mask: Mask, true_values: Self, false_values: Self) -> Self;
}
macro_rules! impl_select {
{
$mask:ident ($bits_ty:ident): $($type:ident),*
} => {
$(
impl<const LANES: usize> Sealed for crate::$type<LANES> where crate::LaneCount<LANES>: crate::SupportedLaneCount {}
impl<const LANES: usize> Select<crate::$mask<LANES>> for crate::$type<LANES>
where
crate::LaneCount<LANES>: crate::SupportedLaneCount,
{
#[doc(hidden)]
#[inline]
fn select(mask: crate::$mask<LANES>, true_values: Self, false_values: Self) -> Self {
unsafe { crate::intrinsics::simd_select(mask.to_int(), true_values, false_values) }
}
}
)*
impl<Element, const LANES: usize> Sealed for Simd<Element, LANES>
where
Element: SimdElement,
LaneCount<LANES>: SupportedLaneCount,
{
}
impl<const LANES: usize> Sealed for crate::$mask<LANES>
where
crate::LaneCount<LANES>: crate::SupportedLaneCount,
{}
impl<const LANES: usize> Select<Self> for crate::$mask<LANES>
where
crate::LaneCount<LANES>: crate::SupportedLaneCount,
{
#[doc(hidden)]
#[inline]
fn select(mask: Self, true_values: Self, false_values: Self) -> Self {
mask & true_values | !mask & false_values
}
}
impl<const LANES: usize> crate::$mask<LANES>
where
crate::LaneCount<LANES>: crate::SupportedLaneCount,
{
/// Choose lanes from two vectors.
///
/// For each lane in the mask, choose the corresponding lane from `true_values` if
/// that lane mask is true, and `false_values` if that lane mask is false.
///
/// ```
/// # #![feature(portable_simd)]
/// # use core_simd::{Mask32, SimdI32};
/// let a = SimdI32::from_array([0, 1, 2, 3]);
/// let b = SimdI32::from_array([4, 5, 6, 7]);
/// let mask = Mask32::from_array([true, false, false, true]);
/// let c = mask.select(a, b);
/// assert_eq!(c.to_array(), [0, 5, 6, 3]);
/// ```
///
/// `select` can also be used on masks:
/// ```
/// # #![feature(portable_simd)]
/// # use core_simd::Mask32;
/// let a = Mask32::from_array([true, true, false, false]);
/// let b = Mask32::from_array([false, false, true, true]);
/// let mask = Mask32::from_array([true, false, false, true]);
/// let c = mask.select(a, b);
/// assert_eq!(c.to_array(), [true, false, true, false]);
/// ```
#[inline]
pub fn select<S: Select<Self>>(self, true_values: S, false_values: S) -> S {
S::select(self, true_values, false_values)
}
}
impl<Element, const LANES: usize> Select<Mask<Element::Mask, LANES>> for Simd<Element, LANES>
where
Element: SimdElement,
LaneCount<LANES>: SupportedLaneCount,
{
#[inline]
fn select(mask: Mask<Element::Mask, LANES>, true_values: Self, false_values: Self) -> Self {
unsafe { crate::intrinsics::simd_select(mask.to_int(), true_values, false_values) }
}
}
impl_select! { Mask8 (SimdI8): SimdU8, SimdI8 }
impl_select! { Mask16 (SimdI16): SimdU16, SimdI16 }
impl_select! { Mask32 (SimdI32): SimdU32, SimdI32, SimdF32}
impl_select! { Mask64 (SimdI64): SimdU64, SimdI64, SimdF64}
impl_select! { MaskSize (SimdIsize): SimdUsize, SimdIsize }
impl<Element, const LANES: usize> Sealed for Mask<Element, LANES>
where
Element: MaskElement,
LaneCount<LANES>: SupportedLaneCount,
{
}
impl<Element, const LANES: usize> Select<Self> for Mask<Element, LANES>
where
Element: MaskElement,
LaneCount<LANES>: SupportedLaneCount,
{
#[doc(hidden)]
#[inline]
fn select(mask: Self, true_values: Self, false_values: Self) -> Self {
mask & true_values | !mask & false_values
}
}
impl<Element, const LANES: usize> Mask<Element, LANES>
where
Element: MaskElement,
LaneCount<LANES>: SupportedLaneCount,
{
/// Choose lanes from two vectors.
///
/// For each lane in the mask, choose the corresponding lane from `true_values` if
/// that lane mask is true, and `false_values` if that lane mask is false.
///
/// ```
/// # #![feature(portable_simd)]
/// # use core_simd::{Mask32, SimdI32};
/// let a = SimdI32::from_array([0, 1, 2, 3]);
/// let b = SimdI32::from_array([4, 5, 6, 7]);
/// let mask = Mask32::from_array([true, false, false, true]);
/// let c = mask.select(a, b);
/// assert_eq!(c.to_array(), [0, 5, 6, 3]);
/// ```
///
/// `select` can also be used on masks:
/// ```
/// # #![feature(portable_simd)]
/// # use core_simd::Mask32;
/// let a = Mask32::from_array([true, true, false, false]);
/// let b = Mask32::from_array([false, false, true, true]);
/// let mask = Mask32::from_array([true, false, false, true]);
/// let c = mask.select(a, b);
/// assert_eq!(c.to_array(), [true, false, true, false]);
/// ```
#[inline]
pub fn select<S: Select<Self>>(self, true_values: S, false_values: S) -> S {
S::select(self, true_values, false_values)
}
}

View File

@ -9,7 +9,7 @@ pub use uint::*;
// Vectors of pointers are not for public use at the current time.
pub(crate) mod ptr;
use crate::{LaneCount, SupportedLaneCount};
use crate::{LaneCount, MaskElement, SupportedLaneCount};
/// A SIMD vector of `LANES` elements of type `Element`.
#[repr(simd)]
@ -338,32 +338,32 @@ use sealed::Sealed;
/// Marker trait for types that may be used as SIMD vector elements.
pub unsafe trait SimdElement: Sealed + Copy {
/// The mask element type corresponding to this element type.
type Mask: SimdElement;
type Mask: MaskElement;
}
impl Sealed for u8 {}
unsafe impl SimdElement for u8 {
type Mask = u8;
type Mask = i8;
}
impl Sealed for u16 {}
unsafe impl SimdElement for u16 {
type Mask = u16;
type Mask = i16;
}
impl Sealed for u32 {}
unsafe impl SimdElement for u32 {
type Mask = u32;
type Mask = i32;
}
impl Sealed for u64 {}
unsafe impl SimdElement for u64 {
type Mask = u64;
type Mask = i64;
}
impl Sealed for usize {}
unsafe impl SimdElement for usize {
type Mask = usize;
type Mask = isize;
}
impl Sealed for i8 {}