Refactor ops.rs with a recursive macro

This takes a slightly different approach to reducing macro nesting.
Instead of just flattening details, make one macro apply another.
This allows specifying all details up-front in the first macro
invocation, making it easier to audit and refactor in the future.
This commit is contained in:
Jubilee Young 2021-12-21 18:28:57 -08:00
parent 5dcd397f47
commit bc326a2bbc

View File

@ -31,27 +31,10 @@ where
}
}
// NOTE(review): this is the pre-refactor (removed) side of the diff. The old
// `unsafe_base_op` expanded one complete `impl $op for Simd<$scalar, LANES>`
// per repetition, calling the named intrinsic in the method body. Its final
// closing brace is cut off by the diff view, so this span is not valid Rust
// as shown — recover the full text from the parent commit if needed.
macro_rules! unsafe_base_op {
($(impl<const LANES: usize> $op:ident for Simd<$scalar:ty, LANES> {
fn $call:ident(self, rhs: Self) -> Self::Output {
unsafe{ $simd_call:ident }
}
})*) => {
$(impl<const LANES: usize> $op for Simd<$scalar, LANES>
where
$scalar: SimdElement,
LaneCount<LANES>: SupportedLaneCount,
{
type Output = Self;
#[inline]
#[must_use = "operator returns a new vector without mutating the inputs"]
fn $call(self, rhs: Self) -> Self::Output {
unsafe { $crate::intrinsics::$simd_call(self, rhs) }
}
}
)*
}
/// Expands to a bare intrinsic call `$crate::intrinsics::$simd_call(a, b)`
/// wrapped in `unsafe`. Any trailing tokens (e.g. the scalar type other
/// expanders need) are matched and discarded, so every op-impl macro can be
/// invoked with the same argument shape.
macro_rules! unsafe_base {
    ($a:ident, $b:ident, { $simd_call:ident }, $($unused:tt)*) => {
        unsafe { $crate::intrinsics::$simd_call($a, $b) }
    };
}
/// SAFETY: This macro should not be used for anything except Shl or Shr, and passed the appropriate shift intrinsic.
@ -64,388 +47,191 @@ macro_rules! unsafe_base_op {
// FIXME: Consider implementing this in cg_llvm instead?
// cg_clif defaults to this, and scalar MIR shifts also default to wrapping
// NOTE(review): interleaved diff residue — the removed trait-impl-style
// matcher and the added expression-style matcher
// `($lhs, $rhs, {$simd_call}, $int)` are mixed together below, so this span
// is not valid Rust as shown. Both versions mask the shift amount with
// `BITS - 1` so shifts >= <Int>::BITS wrap instead of being UB.
macro_rules! wrap_bitshift {
($(impl<const LANES: usize> $op:ident for Simd<$int:ty, LANES> {
fn $call:ident(self, rhs: Self) -> Self::Output {
unsafe { $simd_call:ident }
($lhs:ident, $rhs:ident, {$simd_call:ident}, $int:ident) => {
unsafe {
$crate::intrinsics::$simd_call($lhs, $rhs.bitand(Simd::splat(<$int>::BITS as $int - 1)))
}
})*) => {
$(impl<const LANES: usize> $op for Simd<$int, LANES>
where
$int: SimdElement,
LaneCount<LANES>: SupportedLaneCount,
{
type Output = Self;
#[inline]
#[must_use = "operator returns a new vector without mutating the inputs"]
fn $call(self, rhs: Self) -> Self::Output {
unsafe {
$crate::intrinsics::$simd_call(self, rhs.bitand(Simd::splat(<$int>::BITS as $int - 1)))
}
}
})*
};
}
// NOTE(review): pre-refactor (removed) helper. For each listed integer type,
// expands BitAnd/BitOr/BitXor impls via `unsafe_base_op!` (plain intrinsic
// calls) and Shl/Shr impls via `wrap_bitshift!` (shift amount masked to
// avoid UB on oversized shifts). The new code replaces this with
// `for_base_ops!`.
macro_rules! bitops {
($(impl<const LANES: usize> BitOps for Simd<$int:ty, LANES> {
fn bitand(self, rhs: Self) -> Self::Output;
fn bitor(self, rhs: Self) -> Self::Output;
fn bitxor(self, rhs: Self) -> Self::Output;
fn shl(self, rhs: Self) -> Self::Output;
fn shr(self, rhs: Self) -> Self::Output;
})*) => {
$(
unsafe_base_op!{
impl<const LANES: usize> BitAnd for Simd<$int, LANES> {
fn bitand(self, rhs: Self) -> Self::Output {
unsafe { simd_and }
}
}
impl<const LANES: usize> BitOr for Simd<$int, LANES> {
fn bitor(self, rhs: Self) -> Self::Output {
unsafe { simd_or }
}
}
impl<const LANES: usize> BitXor for Simd<$int, LANES> {
fn bitxor(self, rhs: Self) -> Self::Output {
unsafe { simd_xor }
}
}
}
wrap_bitshift! {
impl<const LANES: usize> Shl for Simd<$int, LANES> {
fn shl(self, rhs: Self) -> Self::Output {
unsafe { simd_shl }
}
}
impl<const LANES: usize> Shr for Simd<$int, LANES> {
fn shr(self, rhs: Self) -> Self::Output {
// This automatically monomorphizes to lshr or ashr, depending,
// so it's fine to use it for both UInts and SInts.
unsafe { simd_shr }
}
}
}
)*
};
}
// Integers can always accept bitand, bitor, and bitxor.
// The only question is how to handle shifts >= <Int>::BITS?
// Our current solution uses wrapping logic.
// NOTE(review): pre-refactor (removed) invocation — instantiates the bitwise
// and shift operators for every fixed-width and pointer-sized integer type.
// Superseded by the integer `for_base_ops!` invocation in the new code.
bitops! {
impl<const LANES: usize> BitOps for Simd<i8, LANES> {
fn bitand(self, rhs: Self) -> Self::Output;
fn bitor(self, rhs: Self) -> Self::Output;
fn bitxor(self, rhs: Self) -> Self::Output;
fn shl(self, rhs: Self) -> Self::Output;
fn shr(self, rhs: Self) -> Self::Output;
}
impl<const LANES: usize> BitOps for Simd<i16, LANES> {
fn bitand(self, rhs: Self) -> Self::Output;
fn bitor(self, rhs: Self) -> Self::Output;
fn bitxor(self, rhs: Self) -> Self::Output;
fn shl(self, rhs: Self) -> Self::Output;
fn shr(self, rhs: Self) -> Self::Output;
}
impl<const LANES: usize> BitOps for Simd<i32, LANES> {
fn bitand(self, rhs: Self) -> Self::Output;
fn bitor(self, rhs: Self) -> Self::Output;
fn bitxor(self, rhs: Self) -> Self::Output;
fn shl(self, rhs: Self) -> Self::Output;
fn shr(self, rhs: Self) -> Self::Output;
}
impl<const LANES: usize> BitOps for Simd<i64, LANES> {
fn bitand(self, rhs: Self) -> Self::Output;
fn bitor(self, rhs: Self) -> Self::Output;
fn bitxor(self, rhs: Self) -> Self::Output;
fn shl(self, rhs: Self) -> Self::Output;
fn shr(self, rhs: Self) -> Self::Output;
}
impl<const LANES: usize> BitOps for Simd<isize, LANES> {
fn bitand(self, rhs: Self) -> Self::Output;
fn bitor(self, rhs: Self) -> Self::Output;
fn bitxor(self, rhs: Self) -> Self::Output;
fn shl(self, rhs: Self) -> Self::Output;
fn shr(self, rhs: Self) -> Self::Output;
}
impl<const LANES: usize> BitOps for Simd<u8, LANES> {
fn bitand(self, rhs: Self) -> Self::Output;
fn bitor(self, rhs: Self) -> Self::Output;
fn bitxor(self, rhs: Self) -> Self::Output;
fn shl(self, rhs: Self) -> Self::Output;
fn shr(self, rhs: Self) -> Self::Output;
}
impl<const LANES: usize> BitOps for Simd<u16, LANES> {
fn bitand(self, rhs: Self) -> Self::Output;
fn bitor(self, rhs: Self) -> Self::Output;
fn bitxor(self, rhs: Self) -> Self::Output;
fn shl(self, rhs: Self) -> Self::Output;
fn shr(self, rhs: Self) -> Self::Output;
}
impl<const LANES: usize> BitOps for Simd<u32, LANES> {
fn bitand(self, rhs: Self) -> Self::Output;
fn bitor(self, rhs: Self) -> Self::Output;
fn bitxor(self, rhs: Self) -> Self::Output;
fn shl(self, rhs: Self) -> Self::Output;
fn shr(self, rhs: Self) -> Self::Output;
}
impl<const LANES: usize> BitOps for Simd<u64, LANES> {
fn bitand(self, rhs: Self) -> Self::Output;
fn bitor(self, rhs: Self) -> Self::Output;
fn bitxor(self, rhs: Self) -> Self::Output;
fn shl(self, rhs: Self) -> Self::Output;
fn shr(self, rhs: Self) -> Self::Output;
}
impl<const LANES: usize> BitOps for Simd<usize, LANES> {
fn bitand(self, rhs: Self) -> Self::Output;
fn bitor(self, rhs: Self) -> Self::Output;
fn bitxor(self, rhs: Self) -> Self::Output;
fn shl(self, rhs: Self) -> Self::Output;
fn shr(self, rhs: Self) -> Self::Output;
}
}
// NOTE(review): pre-refactor (removed) helper. For each listed float type,
// expands Add/Mul/Sub/Div/Rem impls via `unsafe_base_op!` — plain intrinsic
// calls with no guards, since float arithmetic never traps (it produces NaN
// instead). Superseded by the float `for_base_ops!` invocation.
macro_rules! float_arith {
($(impl<const LANES: usize> FloatArith for Simd<$float:ty, LANES> {
fn add(self, rhs: Self) -> Self::Output;
fn mul(self, rhs: Self) -> Self::Output;
fn sub(self, rhs: Self) -> Self::Output;
fn div(self, rhs: Self) -> Self::Output;
fn rem(self, rhs: Self) -> Self::Output;
})*) => {
$(
unsafe_base_op!{
impl<const LANES: usize> Add for Simd<$float, LANES> {
fn add(self, rhs: Self) -> Self::Output {
unsafe { simd_add }
}
}
impl<const LANES: usize> Mul for Simd<$float, LANES> {
fn mul(self, rhs: Self) -> Self::Output {
unsafe { simd_mul }
}
}
impl<const LANES: usize> Sub for Simd<$float, LANES> {
fn sub(self, rhs: Self) -> Self::Output {
unsafe { simd_sub }
}
}
impl<const LANES: usize> Div for Simd<$float, LANES> {
fn div(self, rhs: Self) -> Self::Output {
unsafe { simd_div }
}
}
impl<const LANES: usize> Rem for Simd<$float, LANES> {
fn rem(self, rhs: Self) -> Self::Output {
unsafe { simd_rem }
}
}
}
)*
};
}
// We don't need any special precautions here:
// Floats always accept arithmetic ops, but may become NaN.
// NOTE(review): pre-refactor (removed) invocation — instantiates the five
// arithmetic operators for f32 and f64 vectors.
float_arith! {
impl<const LANES: usize> FloatArith for Simd<f32, LANES> {
fn add(self, rhs: Self) -> Self::Output;
fn mul(self, rhs: Self) -> Self::Output;
fn sub(self, rhs: Self) -> Self::Output;
fn div(self, rhs: Self) -> Self::Output;
fn rem(self, rhs: Self) -> Self::Output;
}
impl<const LANES: usize> FloatArith for Simd<f64, LANES> {
fn add(self, rhs: Self) -> Self::Output;
fn mul(self, rhs: Self) -> Self::Output;
fn sub(self, rhs: Self) -> Self::Output;
fn div(self, rhs: Self) -> Self::Output;
fn rem(self, rhs: Self) -> Self::Output;
}
}
// Division by zero is poison, according to LLVM.
// So is dividing the MIN value of a signed integer by -1,
// since that would return MAX + 1.
// FIXME: Rust allows <SInt>::MIN / -1,
// so we should probably figure out how to make that safe.
// NOTE(review): interleaved diff residue — the removed trait-impl-style arm
// and the added expression-style arm `($lhs, $rhs, { const PANIC_ZERO ...;
// const PANIC_OVERFLOW ...; $simd_call }, $int)` are mixed together, so this
// span is not valid Rust as shown. Both versions panic on any zero divisor
// lane, and on MIN / -1 for signed types (`<$int>::MIN != 0` is the
// signedness test), before calling the raw intrinsic. The new arm takes the
// panic messages as literals; the old one hard-coded them.
macro_rules! int_divrem_guard {
($(impl<const LANES: usize> $op:ident for Simd<$sint:ty, LANES> {
const PANIC_ZERO: &'static str = $zero:literal;
const PANIC_OVERFLOW: &'static str = $overflow:literal;
fn $call:ident {
unsafe { $simd_call:ident }
}
})*) => {
$(impl<const LANES: usize> $op for Simd<$sint, LANES>
where
$sint: SimdElement,
LaneCount<LANES>: SupportedLaneCount,
( $lhs:ident,
$rhs:ident,
{ const PANIC_ZERO: &'static str = $zero:literal;
const PANIC_OVERFLOW: &'static str = $overflow:literal;
$simd_call:ident
},
$int:ident ) => {
if $rhs.lanes_eq(Simd::splat(0)).any() {
panic!($zero);
} else if <$int>::MIN != 0
&& $lhs.lanes_eq(Simd::splat(<$int>::MIN)) & $rhs.lanes_eq(Simd::splat(-1 as _))
!= Mask::splat(false)
{
type Output = Self;
#[inline]
#[must_use = "operator returns a new vector without mutating the inputs"]
fn $call(self, rhs: Self) -> Self::Output {
if rhs.lanes_eq(Simd::splat(0)).any() {
panic!("attempt to calculate the remainder with a divisor of zero");
} else if <$sint>::MIN != 0 && self.lanes_eq(Simd::splat(<$sint>::MIN)) & rhs.lanes_eq(Simd::splat(-1 as _))
!= Mask::splat(false)
{
panic!("attempt to calculate the remainder with overflow");
} else {
unsafe { $crate::intrinsics::$simd_call(self, rhs) }
}
}
})*
panic!($overflow);
} else {
unsafe { $crate::intrinsics::$simd_call($lhs, $rhs) }
}
};
}
// NOTE(review): interleaved diff residue — the removed `int_arith` macro and
// the added `for_base_types` macro are woven together line-by-line, so this
// span is not valid Rust as shown. The new `for_base_types` expands one
// `impl $op<Self> for Simd<$scalar, N>` per scalar, delegating the method
// body to `$macro_impl!(self, rhs, $inner, $scalar)` — i.e. one macro
// applying another, per the commit message.
macro_rules! int_arith {
($(impl<const LANES: usize> IntArith for Simd<$sint:ty, LANES> {
fn add(self, rhs: Self) -> Self::Output;
fn mul(self, rhs: Self) -> Self::Output;
fn sub(self, rhs: Self) -> Self::Output;
fn div(self, rhs: Self) -> Self::Output;
fn rem(self, rhs: Self) -> Self::Output;
})*) => {
$(
unsafe_base_op!{
impl<const LANES: usize> Add for Simd<$sint, LANES> {
fn add(self, rhs: Self) -> Self::Output {
unsafe { simd_add }
}
}
macro_rules! for_base_types {
( T = ($($scalar:ident),*);
type Lhs = Simd<T, N>;
type Rhs = Simd<T, N>;
type Output = $out:ty;
impl<const LANES: usize> Mul for Simd<$sint, LANES> {
fn mul(self, rhs: Self) -> Self::Output {
unsafe { simd_mul }
}
}
impl $op:ident::$call:ident {
$macro_impl:ident $inner:tt
}) => {
$(
impl<const N: usize> $op<Self> for Simd<$scalar, N>
where
$scalar: SimdElement,
LaneCount<N>: SupportedLaneCount,
{
type Output = $out;
impl<const LANES: usize> Sub for Simd<$sint, LANES> {
fn sub(self, rhs: Self) -> Self::Output {
unsafe { simd_sub }
}
}
#[inline]
#[must_use = "operator returns a new vector without mutating the inputs"]
fn $call(self, rhs: Self) -> Self::Output {
$macro_impl!(self, rhs, $inner, $scalar)
}
})*
}
}
// A "TokenTree muncher": takes a set of scalar types `T = {};`
// type parameters for the ops it implements, `Op::fn` names,
// and a macro that expands into an expr, substituting in an intrinsic.
// It passes that to for_base_types, which expands an impl for the types,
// using the expanded expr in the function, and recurses with itself.
//
// tl;dr impls a set of ops::{Traits} for a set of types
// NOTE(review): interleaved diff residue — the added `for_base_ops` muncher
// has the removed `int_divrem_guard!` invocation (from old `int_arith`)
// spliced into its middle, so this span is not valid Rust as shown. The new
// macro peels off one `impl $op::$call $inner` item, hands it to
// `for_base_types!`, then recurses on `$($rest)*`; the empty-match arm at
// the bottom terminates the recursion.
macro_rules! for_base_ops {
(
T = $types:tt;
type Lhs = Simd<T, N>;
type Rhs = Simd<T, N>;
type Output = $out:ident;
impl $op:ident::$call:ident
$inner:tt
$($rest:tt)*
) => {
for_base_types! {
T = $types;
type Lhs = Simd<T, N>;
type Rhs = Simd<T, N>;
type Output = $out;
impl $op::$call
$inner
}
int_divrem_guard!{
impl<const LANES: usize> Div for Simd<$sint, LANES> {
const PANIC_ZERO: &'static str = "attempt to divide by zero";
const PANIC_OVERFLOW: &'static str = "attempt to divide with overflow";
fn div {
unsafe { simd_div }
}
}
impl<const LANES: usize> Rem for Simd<$sint, LANES> {
const PANIC_ZERO: &'static str = "attempt to calculate the remainder with a divisor of zero";
const PANIC_OVERFLOW: &'static str = "attempt to calculate the remainder with overflow";
fn rem {
unsafe { simd_rem }
}
}
})*
for_base_ops! {
T = $types;
type Lhs = Simd<T, N>;
type Rhs = Simd<T, N>;
type Output = $out;
$($rest)*
}
};
($($done:tt)*) => {
// Done.
}
}
// NOTE(review): interleaved diff residue — the removed `int_arith!`
// invocation (one IntArith stanza per integer type) and the added integer
// `for_base_ops!` invocation are woven together, so this span is not valid
// Rust as shown. The new invocation covers all ten integer types at once:
// Add/Mul/Sub/BitAnd/BitOr/BitXor via `unsafe_base`, Div/Rem via
// `int_divrem_guard` (with the panic messages supplied as literals), and
// Shl/Shr via `wrap_bitshift`.
int_arith! {
impl<const LANES: usize> IntArith for Simd<i8, LANES> {
fn add(self, rhs: Self) -> Self::Output;
fn mul(self, rhs: Self) -> Self::Output;
fn sub(self, rhs: Self) -> Self::Output;
fn div(self, rhs: Self) -> Self::Output;
fn rem(self, rhs: Self) -> Self::Output;
// Integers can always accept add, mul, sub, bitand, bitor, and bitxor.
// For all of these operations, simd_* intrinsics apply wrapping logic.
for_base_ops! {
T = (i8, i16, i32, i64, isize, u8, u16, u32, u64, usize);
type Lhs = Simd<T, N>;
type Rhs = Simd<T, N>;
type Output = Self;
impl Add::add {
unsafe_base { simd_add }
}
impl<const LANES: usize> IntArith for Simd<i16, LANES> {
fn add(self, rhs: Self) -> Self::Output;
fn mul(self, rhs: Self) -> Self::Output;
fn sub(self, rhs: Self) -> Self::Output;
fn div(self, rhs: Self) -> Self::Output;
fn rem(self, rhs: Self) -> Self::Output;
impl Mul::mul {
unsafe_base { simd_mul }
}
impl<const LANES: usize> IntArith for Simd<i32, LANES> {
fn add(self, rhs: Self) -> Self::Output;
fn mul(self, rhs: Self) -> Self::Output;
fn sub(self, rhs: Self) -> Self::Output;
fn div(self, rhs: Self) -> Self::Output;
fn rem(self, rhs: Self) -> Self::Output;
impl Sub::sub {
unsafe_base { simd_sub }
}
impl<const LANES: usize> IntArith for Simd<i64, LANES> {
fn add(self, rhs: Self) -> Self::Output;
fn mul(self, rhs: Self) -> Self::Output;
fn sub(self, rhs: Self) -> Self::Output;
fn div(self, rhs: Self) -> Self::Output;
fn rem(self, rhs: Self) -> Self::Output;
impl BitAnd::bitand {
unsafe_base { simd_and }
}
impl<const LANES: usize> IntArith for Simd<isize, LANES> {
fn add(self, rhs: Self) -> Self::Output;
fn mul(self, rhs: Self) -> Self::Output;
fn sub(self, rhs: Self) -> Self::Output;
fn div(self, rhs: Self) -> Self::Output;
fn rem(self, rhs: Self) -> Self::Output;
impl BitOr::bitor {
unsafe_base { simd_or }
}
impl<const LANES: usize> IntArith for Simd<u8, LANES> {
fn add(self, rhs: Self) -> Self::Output;
fn mul(self, rhs: Self) -> Self::Output;
fn sub(self, rhs: Self) -> Self::Output;
fn div(self, rhs: Self) -> Self::Output;
fn rem(self, rhs: Self) -> Self::Output;
impl BitXor::bitxor {
unsafe_base { simd_xor }
}
impl<const LANES: usize> IntArith for Simd<u16, LANES> {
fn add(self, rhs: Self) -> Self::Output;
fn mul(self, rhs: Self) -> Self::Output;
fn sub(self, rhs: Self) -> Self::Output;
fn div(self, rhs: Self) -> Self::Output;
fn rem(self, rhs: Self) -> Self::Output;
impl Div::div {
int_divrem_guard {
const PANIC_ZERO: &'static str = "attempt to divide by zero";
const PANIC_OVERFLOW: &'static str = "attempt to divide with overflow";
simd_div
}
}
impl<const LANES: usize> IntArith for Simd<u32, LANES> {
fn add(self, rhs: Self) -> Self::Output;
fn mul(self, rhs: Self) -> Self::Output;
fn sub(self, rhs: Self) -> Self::Output;
fn div(self, rhs: Self) -> Self::Output;
fn rem(self, rhs: Self) -> Self::Output;
impl Rem::rem {
int_divrem_guard {
const PANIC_ZERO: &'static str = "attempt to calculate the remainder with a divisor of zero";
const PANIC_OVERFLOW: &'static str = "attempt to calculate the remainder with overflow";
simd_rem
}
}
impl<const LANES: usize> IntArith for Simd<u64, LANES> {
fn add(self, rhs: Self) -> Self::Output;
fn mul(self, rhs: Self) -> Self::Output;
fn sub(self, rhs: Self) -> Self::Output;
fn div(self, rhs: Self) -> Self::Output;
fn rem(self, rhs: Self) -> Self::Output;
// The only question is how to handle shifts >= <Int>::BITS?
// Our current solution uses wrapping logic.
impl Shl::shl {
wrap_bitshift { simd_shl }
}
impl<const LANES: usize> IntArith for Simd<usize, LANES> {
fn add(self, rhs: Self) -> Self::Output;
fn mul(self, rhs: Self) -> Self::Output;
fn sub(self, rhs: Self) -> Self::Output;
fn div(self, rhs: Self) -> Self::Output;
fn rem(self, rhs: Self) -> Self::Output;
impl Shr::shr {
wrap_bitshift {
// This automatically monomorphizes to lshr or ashr, depending,
// so it's fine to use it for both UInts and SInts.
simd_shr
}
}
}
// We don't need any special precautions here:
// Floats always accept arithmetic ops, but may become NaN.
// Post-refactor (added) invocation: implements Add/Mul/Sub/Div/Rem for
// Simd<f32, N> and Simd<f64, N>. Each op forwards straight to the matching
// simd_* intrinsic via the `unsafe_base` expander — no guards are needed
// because float division/remainder never panic (they yield NaN/inf instead).
for_base_ops! {
T = (f32, f64);
type Lhs = Simd<T, N>;
type Rhs = Simd<T, N>;
type Output = Self;
impl Add::add {
unsafe_base { simd_add }
}
impl Mul::mul {
unsafe_base { simd_mul }
}
impl Sub::sub {
unsafe_base { simd_sub }
}
impl Div::div {
unsafe_base { simd_div }
}
impl Rem::rem {
unsafe_base { simd_rem }
}
}