~ruther/ctu-fee-eoa

8b73ac0a6cd9d2d1ea618f9325fb59a74c5628ac — Rutherther 11 days ago e024536
feat: add bounded crossover similar to bounded perturbation
1 files changed, 112 insertions(+), 0 deletions(-)

M codes/eoa_lib/src/crossover.rs
M codes/eoa_lib/src/crossover.rs => codes/eoa_lib/src/crossover.rs +112 -0
@@ 455,3 455,115 @@ where
        Population::from_vec(offsprings)
    }
}

pub enum BoundedCrossoverStrategy {
    /// Trims offspring values to get values within bounds
    Trim,
    /// Retries calling the underlying crossover until offspring within bounds is returned.
    /// If argument is given, this is the maximum number of retries to do and then
    /// fall back to trimming. Zero means retry indefinitely.
    Retry(usize)
}

pub struct BoundedCrossover<D: Dim, const DParents: usize, T: Crossover<DParents, Chromosome = OVector<f64, D>>>
where
    DefaultAllocator: Allocator<D>
{
    min_max: OVector<(f64, f64), D>,
    strategy: BoundedCrossoverStrategy,
    crossover: T,
    _phantom: PhantomData<fn() -> [(); DParents]>
}

impl<D: Dim, const DParents: usize, T: Crossover<DParents, Chromosome = OVector<f64, D>>> BoundedCrossover<D, DParents, T>
where
    DefaultAllocator: Allocator<D>
{
    pub fn new(
        crossover: T,
        min: OVector<f64, D>,
        max: OVector<f64, D>,
        strategy: BoundedCrossoverStrategy
    ) -> Self {
        let min_max = min.zip_map(&max, |min, max| (min, max));
        Self {
            min_max,
            strategy,
            crossover,
            _phantom: PhantomData
        }
    }

    fn within_bounds(&self, chromosome: &OVector<f64, D>) -> bool {
        chromosome.iter()
            .zip(self.min_max.iter())
            .all(|(&c, &(min, max))| c <= max && c >= min)
    }

    fn bound(&self, mut chromosome: OVector<f64, D>) -> OVector<f64, D> {
        chromosome
            .zip_apply(&self.min_max, |c, (min, max)| *c = c.clamp(min, max));
        chromosome
    }

    fn bound_population(&self, mut population: Population<OVector<f64, D>>) -> Population<OVector<f64, D>> {
        for chromosome in population.iter_mut() {
            *chromosome = self.bound(chromosome.clone());
        }
        population
    }

    fn all_within_bounds(&self, population: &Population<OVector<f64, D>>) -> bool {
        population.iter().all(|chromosome| self.within_bounds(chromosome))
    }

    fn retry_crossover<TOut>(
        &self,
        population: &EvaluatedPopulation<OVector<f64, D>, TOut>,
        pairs: Vec<ParentPairing<DParents>>,
        retries: Option<usize>,
        rng: &mut dyn RngCore
    ) -> Population<OVector<f64, D>>
    where
        T: Crossover<DParents, Chromosome = OVector<f64, D>, Out = TOut>
    {
        let offspring = self.crossover.crossover(population, pairs.clone().into_iter(), rng);

        if self.all_within_bounds(&offspring) {
            return offspring;
        }

        match retries {
            Some(0) | None => self.bound_population(offspring),
            Some(retries) => {
                self.retry_crossover(population, pairs, Some(retries - 1), rng)
            }
        }
    }
}

impl<D: Dim, const DParents: usize, T, TOut> Crossover<DParents> for BoundedCrossover<D, DParents, T>
where
    T: Crossover<DParents, Chromosome = OVector<f64, D>, Out = TOut>,
    DefaultAllocator: Allocator<D>
{
    type Chromosome = OVector<f64, D>;
    type Out = TOut;

    fn crossover(
        &self,
        population: &EvaluatedPopulation<Self::Chromosome, Self::Out>,
        pairs: impl Iterator<Item = ParentPairing<DParents>>,
        rng: &mut dyn RngCore
    ) -> Population<Self::Chromosome> {
        match self.strategy {
            BoundedCrossoverStrategy::Trim => {
                self.bound_population(self.crossover.crossover(population, pairs, rng))
            },
            BoundedCrossoverStrategy::Retry(retries) => {
                let pairs_vec: Vec<_> = pairs.collect();
                self.retry_crossover(population, pairs_vec, Some(retries), rng)
            }
        }
    }
}