~ruther/ctu-fee-eoa

b5a1a3a5c0d569ee3d8a87353fa2fc52ab634f88 — Rutherther a month ago d8580bd
feat: add crossover
1 files changed, 161 insertions(+), 0 deletions(-)

A env/src/crossover.rs
A env/src/crossover.rs => env/src/crossover.rs +161 -0
@@ 0,0 1,161 @@
use std::marker::PhantomData;

use nalgebra::{allocator::Allocator, DefaultAllocator, Dim, OVector, Scalar, U1};
use rand::{Rng, RngCore};

use crate::{binary_string::BinaryString, pairing::ParentPairing, replacement::{EvaluatedPopulation, Population}};

pub trait Crossover {
    type Chromosome;
    type Out;

    fn crossover(
        &mut self,
        parents: &EvaluatedPopulation<Self::Chromosome, Self::Out>,
        pairs: impl Iterator<Item = ParentPairing>
    ) -> Population<Self::Chromosome>;
}

pub struct BinaryOnePointCrossover<D: Dim, TOutput> {
    _phantom1: PhantomData<D>,
    _phantom2: PhantomData<TOutput>,
    rng: Box<dyn RngCore>
}

impl<D: Dim, TOutput> BinaryOnePointCrossover<D, TOutput>
where
    DefaultAllocator: Allocator<D>
{
    pub fn new() -> Self {
        Self {
            rng: Box::new(rand::rng()),
            _phantom1: PhantomData,
            _phantom2: PhantomData,
        }
    }

    fn find_cross_point(&mut self, chromosome: &BinaryString<D>) -> usize {
        let (min, max) = (0, chromosome.vec.len());
        self.rng.random_range(min..max)
    }
}

// TODO: make common functions for ovector that will be used from both BinaryOnePointCrossover and OVectorOnePointCrossover
// for not repeating the code.

impl<D, TOutput> Crossover for BinaryOnePointCrossover<D, TOutput>
where
    D: Dim,
    DefaultAllocator: Allocator<D>
{
    type Chromosome = BinaryString<D>;
    type Out = TOutput;

    fn crossover(
        &mut self,
        population: &EvaluatedPopulation<Self::Chromosome, Self::Out>,
        pairs: impl Iterator<Item = ParentPairing>
    ) -> Population<Self::Chromosome> {

        let chromosome = &population.population[0].chromosome.vec;
        let len = population.population[0].chromosome.vec.len();
        let indices = OVector::<usize, D>::from_iterator_generic(chromosome.shape_generic().0, U1, 0..len);

        let mut offsprings = Vec::new();
        for pair in pairs {
            let (
                parent1,
                parent2
            ) = (&population.population[pair.0], &population.population[pair.1]);

            let (
                chromosome1,
                chromosome2
            ) = (&parent1.chromosome, &parent2.chromosome);

            let cross_point = self.find_cross_point(&population.population[0].chromosome);

            offsprings.push(BinaryString::from_ovector(
                chromosome1.vec.zip_zip_map(
                    &chromosome2.vec,
                    &indices,
                    |first, second, i| if i <= cross_point { first } else { second }
                )));
        }

        Population::from_vec(offsprings)
    }
}


pub struct OVectorOnePointCrossover<D: Dim, T: Scalar, TOutput> {
    _phantom1: PhantomData<D>,
    _phantom2: PhantomData<TOutput>,
    _phantom3: PhantomData<T>,
    rng: Box<dyn RngCore>
}

impl<D: Dim, T, TOutput> OVectorOnePointCrossover<D, T, TOutput>
where
    T: Scalar,
    DefaultAllocator: Allocator<D>
{
    pub fn new() -> Self {
        Self {
            rng: Box::new(rand::rng()),
            _phantom1: PhantomData,
            _phantom2: PhantomData,
            _phantom3: PhantomData,
        }
    }

    fn find_cross_point(&mut self, chromosome: &OVector<T, D>) -> usize {
        let (min, max) = (0, chromosome.len());
        self.rng.random_range(min..max)
    }
}

impl<D, T, TOutput> Crossover for OVectorOnePointCrossover<D, T, TOutput>
where
    T: Scalar,
    D: Dim,
    DefaultAllocator: Allocator<D>
{
    type Chromosome = OVector<T, D>;
    type Out = TOutput;

    fn crossover(
        &mut self,
        population: &EvaluatedPopulation<Self::Chromosome, Self::Out>,
        pairs: impl Iterator<Item = ParentPairing>
    ) -> Population<Self::Chromosome> {

        let chromosome = &population.population[0].chromosome;
        let len = population.population[0].chromosome.len();
        let indices = OVector::<usize, D>::from_iterator_generic(chromosome.shape_generic().0, U1, 0..len);

        let mut offsprings = Vec::new();
        for pair in pairs {
            let (
                parent1,
                parent2
            ) = (&population.population[pair.0], &population.population[pair.1]);

            let (
                chromosome1,
                chromosome2
            ) = (&parent1.chromosome, &parent2.chromosome);

            let cross_point = self.find_cross_point(&population.population[0].chromosome);

            offsprings.push(
                chromosome1.zip_zip_map(
                    &chromosome2,
                    &indices,
                    |first, second, i| if i <= cross_point { first } else { second }
                ));
        }

        Population::from_vec(offsprings)
    }
}