70 lines
1.4 KiB
Rust
70 lines
1.4 KiB
Rust
|
// build-pass
|
||
|
|
||
|
#![allow(incomplete_features)]
|
||
|
#![feature(generic_const_exprs)]
|
||
|
|
||
|
use std::{marker::PhantomData, ops::Mul};
|
||
|
|
||
|
/// The empty index list: the base case that `Cons` lists are built on.
///
/// It has no variants, so no value of this type can ever exist; it is only used as a type.
pub enum Nil {}
|
||
|
/// A non-empty type-level list: a head `Head` followed by the rest of the list `Tail`.
///
/// It is a pure marker type. The `PhantomData` field only records the two
/// parameters and occupies no space.
pub struct Cons<Head, Tail> {
    _phantom: PhantomData<(Head, Tail)>,
}
|
||
|
|
||
|
/// A list of tensor axes parameterised by the per-axis extent `N`.
pub trait Indices<const N: usize> {
    /// Total number of elements addressed by the list (product of the axis extents).
    const NUM_ELEMS: usize;

    /// Number of axes in the list.
    const RANK: usize;
}
|
||
|
|
||
|
/// The empty list has no axes. As an empty product of extents, it addresses
/// exactly one element, whatever `N` is.
impl<const N: usize> Indices<N> for Nil {
    const NUM_ELEMS: usize = 1;
    const RANK: usize = 0;
}
|
||
|
|
||
|
/// Prepending a head adds one axis of extent `N`. The element count therefore
/// grows by a factor of `N`. The head type itself does not affect either constant.
impl<Head, Tail, const N: usize> Indices<N> for Cons<Head, Tail>
where
    Tail: Indices<N>,
{
    const RANK: usize = 1 + Tail::RANK;
    const NUM_ELEMS: usize = N * Tail::NUM_ELEMS;
}
|
||
|
|
||
|
/// Type-level list concatenation: `Self` followed by `Rhs`.
pub trait Concat<Rhs> {
    /// The list consisting of `Self`'s entries followed by `Rhs`.
    type Output;
}
|
||
|
|
||
|
/// Concatenating the empty list with `Rhs` yields `Rhs` unchanged.
impl<Rhs> Concat<Rhs> for Nil {
    type Output = Rhs;
}
|
||
|
|
||
|
impl<T, I, J> Concat<J> for Cons<T, I>
|
||
|
where
|
||
|
I: Concat<J>,
|
||
|
{
|
||
|
type Output = Cons<T, <I as Concat<J>>::Output>;
|
||
|
}
|
||
|
|
||
|
pub struct Tensor<I: Indices<N>, const N: usize>
|
||
|
where
|
||
|
[u8; I::NUM_ELEMS]: Sized,
|
||
|
{
|
||
|
pub data: [u8; I::NUM_ELEMS],
|
||
|
_phantom: PhantomData<I>,
|
||
|
}
|
||
|
|
||
|
impl<I: Indices<N>, J: Indices<N>, const N: usize> Mul<Tensor<J, N>> for Tensor<I, N>
|
||
|
where
|
||
|
I: Concat<J>,
|
||
|
<I as Concat<J>>::Output: Indices<N>,
|
||
|
[u8; I::NUM_ELEMS]: Sized,
|
||
|
[u8; J::NUM_ELEMS]: Sized,
|
||
|
[u8; <I as Concat<J>>::Output::NUM_ELEMS]: Sized,
|
||
|
{
|
||
|
type Output = Tensor<<I as Concat<J>>::Output, N>;
|
||
|
|
||
|
fn mul(self, _rhs: Tensor<J, N>) -> Self::Output {
|
||
|
Tensor {
|
||
|
data: [0u8; <I as Concat<J>>::Output::NUM_ELEMS],
|
||
|
_phantom: PhantomData,
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// Nothing to run: the file is marked `// build-pass`, so what matters is that the
// definitions above compile.
fn main() {}
|