2021-12-08 16:25:52 -06:00
|
|
|
#![feature(generic_const_exprs)]
|
|
|
|
#![allow(incomplete_features)]
|
|
|
|
|
|
|
|
/// A type with a compile-time number of dimensions.
///
/// `DIM` is the length of the index arrays (`[usize; Self::DIM]`) used by
/// `TensorSize` and `Broadcastable`. Docs for `DIM` live here rather than on the
/// item so the `//~^ ERROR` annotation stays directly below `const DIM`.
trait TensorDimension {
    const DIM: usize;
    //~^ ERROR cycle detected when resolving instance
    // FIXME Given the current state of the compiler it's expected that we cycle here,
    // but the cycle is still wrong.

    /// `true` exactly when the type has zero dimensions.
    const ISSCALAR: bool = Self::DIM == 0;

    /// Returns [`Self::ISSCALAR`] for the implementing type.
    fn is_scalar(&self) -> bool {
        Self::ISSCALAR
    }
}
|
|
|
|
|
2022-03-22 04:38:46 -05:00
|
|
|
trait TensorSize: TensorDimension {
|
|
|
|
fn size(&self) -> [usize; Self::DIM];
|
|
|
|
fn inbounds(&self, index: [usize; Self::DIM]) -> bool {
|
|
|
|
index.iter().zip(self.size().iter()).all(|(i, s)| i < s)
|
2021-12-08 16:25:52 -06:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
trait Broadcastable: TensorSize + Sized {
|
|
|
|
type Element;
|
2022-03-22 04:38:46 -05:00
|
|
|
fn bget(&self, index: [usize; Self::DIM]) -> Option<Self::Element>;
|
|
|
|
fn lazy_updim<const NEWDIM: usize>(
|
|
|
|
&self,
|
|
|
|
size: [usize; NEWDIM],
|
|
|
|
) -> LazyUpdim<Self, { Self::DIM }, NEWDIM> {
|
|
|
|
assert!(
|
|
|
|
NEWDIM >= Self::DIM,
|
|
|
|
"Updimmed tensor cannot have fewer indices than the initial one."
|
|
|
|
);
|
|
|
|
LazyUpdim { size, reference: &self }
|
2021-12-08 16:25:52 -06:00
|
|
|
}
|
2022-03-22 04:38:46 -05:00
|
|
|
fn bmap<T, F: Fn(Self::Element) -> T>(&self, foo: F) -> BMap<T, Self, F, { Self::DIM }> {
|
|
|
|
BMap { reference: self, closure: foo }
|
2021-12-08 16:25:52 -06:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2022-03-22 04:38:46 -05:00
|
|
|
struct LazyUpdim<'a, T: Broadcastable, const OLDDIM: usize, const DIM: usize> {
|
|
|
|
size: [usize; DIM],
|
|
|
|
reference: &'a T,
|
2021-12-08 16:25:52 -06:00
|
|
|
}
|
|
|
|
|
2022-03-22 04:38:46 -05:00
|
|
|
/// The view's dimension count is its `DIM` parameter, not the wrapped tensor's.
impl<'a, T: Broadcastable, const DIM: usize> TensorDimension for LazyUpdim<'a, T, { T::DIM }, DIM> {
    const DIM: usize = DIM;
}
|
|
|
|
|
2022-03-22 04:38:46 -05:00
|
|
|
/// Reports the `size` array stored in the view.
impl<'a, T: Broadcastable, const DIM: usize> TensorSize for LazyUpdim<'a, T, { T::DIM }, DIM> {
    // NOTE(review): the trait declares `[usize; Self::DIM]`; this spells it `[usize; DIM]`,
    // relying on `Self::DIM` being defined as `DIM` for this type.
    fn size(&self) -> [usize; DIM] {
        self.size
    }
}
|
|
|
|
|
2022-03-22 04:38:46 -05:00
|
|
|
/// Element access for the dimension-raised view.
impl<'a, T: Broadcastable, const DIM: usize> Broadcastable for LazyUpdim<'a, T, { T::DIM }, DIM> {
    /// Elements are those of the wrapped tensor.
    type Element = T::Element;

    /// Returns `None` when `index` is outside this view's `size`; otherwise forwards a
    /// lookup to the wrapped tensor.
    ///
    /// # Panics
    ///
    /// Panics if `DIM < T::DIM`.
    fn bget(&self, index: [usize; DIM]) -> Option<Self::Element> {
        assert!(DIM >= T::DIM);
        if !self.inbounds(index) {
            return None;
        }
        // NOTE(review): `size` is never read and `newindex` does not depend on `index`,
        // so every in-bounds lookup reads the wrapped tensor at all-zero coordinates; the
        // index mapping looks unfinished. std implements `Default` for arrays only at
        // concrete lengths (0..=32), so this likely fails to type-check for a generic
        // `T::DIM`. Kept verbatim because this file is a compiler regression fixture.
        let size = self.size();
        let newindex: [usize; T::DIM] = Default::default();
        self.reference.bget(newindex)
    }
}
|
|
|
|
|
2022-03-22 04:38:46 -05:00
|
|
|
struct BMap<'a, R, T: Broadcastable, F: Fn(T::Element) -> R, const DIM: usize> {
|
|
|
|
reference: &'a T,
|
|
|
|
closure: F,
|
2021-12-08 16:25:52 -06:00
|
|
|
}
|
|
|
|
|
2022-03-22 04:38:46 -05:00
|
|
|
impl<'a, R, T: Broadcastable, F: Fn(T::Element) -> R, const DIM: usize> TensorDimension
|
|
|
|
for BMap<'a, R, T, F, DIM>
|
|
|
|
{
|
|
|
|
const DIM: usize = DIM;
|
2021-12-08 16:25:52 -06:00
|
|
|
}
|
2022-03-22 04:38:46 -05:00
|
|
|
impl<'a, R, T: Broadcastable, F: Fn(T::Element) -> R, const DIM: usize> TensorSize
|
|
|
|
for BMap<'a, R, T, F, DIM>
|
|
|
|
{
|
|
|
|
fn size(&self) -> [usize; DIM] {
|
|
|
|
self.reference.size()
|
|
|
|
}
|
2021-12-08 16:25:52 -06:00
|
|
|
}
|
|
|
|
|
2022-03-22 04:38:46 -05:00
|
|
|
impl<'a, R, T: Broadcastable, F: Fn(T::Element) -> R, const DIM: usize> Broadcastable
|
|
|
|
for BMap<'a, R, T, F, DIM>
|
|
|
|
{
|
2021-12-08 16:25:52 -06:00
|
|
|
type Element = R;
|
2022-03-22 04:38:46 -05:00
|
|
|
fn bget(&self, index: [usize; DIM]) -> Option<Self::Element> {
|
2021-12-08 16:25:52 -06:00
|
|
|
self.reference.bget(index).map(&self.closure)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
impl<T> TensorDimension for Vec<T> {
|
2022-03-22 04:38:46 -05:00
|
|
|
const DIM: usize = 1;
|
2021-12-08 16:25:52 -06:00
|
|
|
}
|
|
|
|
impl<T> TensorSize for Vec<T> {
|
2022-03-22 04:38:46 -05:00
|
|
|
fn size(&self) -> [usize; 1] {
|
|
|
|
[self.len()]
|
|
|
|
}
|
2021-12-08 16:25:52 -06:00
|
|
|
}
|
|
|
|
impl<T: Clone> Broadcastable for Vec<T> {
|
|
|
|
type Element = T;
|
2022-03-22 04:38:46 -05:00
|
|
|
fn bget(&self, index: [usize; 1]) -> Option<T> {
|
2021-12-08 16:25:52 -06:00
|
|
|
self.get(index[0]).cloned()
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
fn main() {
|
2022-03-22 04:38:46 -05:00
|
|
|
let v = vec![1, 2, 3];
|
|
|
|
let bv = v.lazy_updim([3, 4]);
|
|
|
|
let bbv = bv.bmap(|x| x * x);
|
2021-12-08 16:25:52 -06:00
|
|
|
|
2022-03-22 04:38:46 -05:00
|
|
|
println!("The size of v is {:?}", bbv.bget([0, 2]).expect("Out of bounds."));
|
2021-12-08 16:25:52 -06:00
|
|
|
}
|