use nalgebra::{allocator::Allocator, DefaultAllocator, Dim, OVector, Scalar, U1};
use rand::RngCore;
use crate::{binary_string::BinaryString, bounded::{Bounded, BoundedBinaryString}};
/// Creates initial values of type `T` whose dimension is described by `D`.
pub trait Initializer<D: Dim, T> {
    /// Creates a single value of dimension `size`, drawing any randomness it needs from `rng`.
    fn initialize_single(&self, size: D, rng: &mut dyn RngCore) -> T;

    /// Creates `count` values by invoking [`Initializer::initialize_single`] `count` times in
    /// sequence, all sharing the same `rng`.
    ///
    /// Returns an empty `Vec` when `count` is `0`.
    fn initialize(&self, size: D, count: usize, rng: &mut dyn RngCore) -> Vec<T> {
        std::iter::repeat_with(|| self.initialize_single(size, rng))
            .take(count)
            .collect()
    }
}
/// Initializer that ignores the RNG and always produces default-filled values.
///
/// Vectors are filled with `T::default()` (zero for the primitive numeric types), and
/// binary strings are built from an all-`0` `i8` vector.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct ZeroInitializer;

impl ZeroInitializer {
    /// Creates a new `ZeroInitializer`; equivalent to `ZeroInitializer::default()`.
    pub fn new() -> Self {
        Self
    }
}
impl<D> Initializer<D, BinaryString<D>> for ZeroInitializer
where
    D: Dim,
    DefaultAllocator: Allocator<D>,
{
    /// Builds a binary string from a `size`-element vector of zeros.
    ///
    /// Delegates to the `OVector<i8, D>` impl (which fills with `i8::default()`, i.e. `0`)
    /// and wraps the result. `rng` is only forwarded; that impl never reads it.
    fn initialize_single(&self, size: D, rng: &mut dyn RngCore) -> BinaryString<D> {
        let zeros: OVector<i8, D> =
            <ZeroInitializer as Initializer<D, OVector<i8, D>>>::initialize_single(self, size, rng);
        BinaryString::<D>::from_ovector(zeros)
    }
}
impl<D, T> Initializer<D, OVector<T, D>> for ZeroInitializer
where
    D: Dim,
    T: Scalar + Default,
    DefaultAllocator: Allocator<D>,
{
    /// Returns a column vector with `size` rows, every entry set to `T::default()`.
    ///
    /// The RNG is deliberately unused, so the result is deterministic.
    fn initialize_single(&self, size: D, _rng: &mut dyn RngCore) -> OVector<T, D> {
        let fill: T = T::default();
        OVector::<T, D>::from_element_generic(size, U1, fill)
    }
}
/// Initializer that draws every value from a boxed [`Bounded`] sampler.
///
/// `T` is the item type the sampler yields.
pub struct RandomInitializer<D, T>
where
    D: Dim,
{
    // Sampler queried via `next_random` by the `Initializer` impls for this type.
    bounded: Box<dyn Bounded<D, Item = T>>,
}
impl<D: Dim, T> RandomInitializer<D, T> {
    /// Wraps `bounded`; every value this initializer produces is sampled from it.
    pub fn new(bounded: Box<dyn Bounded<D, Item = T>>) -> Self {
        Self { bounded }
    }
}
// `D: Dim` was previously declared both inline and in the where clause
// (clippy::multiple_bound_locations); it now appears only once.
impl<D> RandomInitializer<D, BinaryString<D>>
where
    D: Dim,
    DefaultAllocator: Allocator<D>,
{
    /// Creates a random binary-string initializer whose sampler is
    /// `BoundedBinaryString::unbounded()`.
    pub fn new_binary() -> Self {
        Self {
            bounded: Box::new(BoundedBinaryString::unbounded()),
        }
    }
}
impl<D> Initializer<D, BinaryString<D>> for RandomInitializer<D, BinaryString<D>>
where
    D: Dim,
    DefaultAllocator: Allocator<D>,
{
    /// Samples one binary string of dimension `size` from the wrapped sampler using `rng`.
    fn initialize_single(&self, size: D, rng: &mut dyn RngCore) -> BinaryString<D> {
        let Self { bounded } = self;
        bounded.next_random(size, rng)
    }
}
impl<D, T> Initializer<D, OVector<T, D>> for RandomInitializer<D, OVector<T, D>>
where
    D: Dim,
    T: Scalar + Default,
    DefaultAllocator: Allocator<D>,
{
    /// Samples one vector of dimension `size` from the wrapped sampler using `rng`.
    fn initialize_single(&self, size: D, rng: &mut dyn RngCore) -> OVector<T, D> {
        let sampler = &self.bounded;
        sampler.next_random(size, rng)
    }
}