From b5a1a3a5c0d569ee3d8a87353fa2fc52ab634f88 Mon Sep 17 00:00:00 2001 From: Rutherther Date: Sun, 26 Oct 2025 21:42:37 +0100 Subject: [PATCH] feat: add crossover --- env/src/crossover.rs | 161 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 161 insertions(+) create mode 100644 env/src/crossover.rs diff --git a/env/src/crossover.rs b/env/src/crossover.rs new file mode 100644 index 0000000000000000000000000000000000000000..717c88304cdea7fbaf1bad8863d96cb549d8da11 --- /dev/null +++ b/env/src/crossover.rs @@ -0,0 +1,161 @@ +use std::marker::PhantomData; + +use nalgebra::{allocator::Allocator, DefaultAllocator, Dim, OVector, Scalar, U1}; +use rand::{Rng, RngCore}; + +use crate::{binary_string::BinaryString, pairing::ParentPairing, replacement::{EvaluatedPopulation, Population}}; + +pub trait Crossover { + type Chromosome; + type Out; + + fn crossover( + &mut self, + parents: &EvaluatedPopulation, + pairs: impl Iterator + ) -> Population; +} + +pub struct BinaryOnePointCrossover { + _phantom1: PhantomData, + _phantom2: PhantomData, + rng: Box +} + +impl BinaryOnePointCrossover +where + DefaultAllocator: Allocator +{ + pub fn new() -> Self { + Self { + rng: Box::new(rand::rng()), + _phantom1: PhantomData, + _phantom2: PhantomData, + } + } + + fn find_cross_point(&mut self, chromosome: &BinaryString) -> usize { + let (min, max) = (0, chromosome.vec.len()); + self.rng.random_range(min..max) + } +} + +// TODO: make common functions for ovector that will be used from both BinaryOnePointCrossover and OVectorOnePointCrossover +// for not repeating the code. + +impl Crossover for BinaryOnePointCrossover +where + D: Dim, + DefaultAllocator: Allocator +{ + type Chromosome = BinaryString; + type Out = TOutput; + + fn crossover( + &mut self, + population: &EvaluatedPopulation, + pairs: impl Iterator + ) -> Population { + + let chromosome = &population.population[0].chromosome.vec; + let len = population.population[0].chromosome.vec.len(); + let indices = OVector::::from_iterator_generic(chromosome.shape_generic().0, U1, 0..len); + + let mut offsprings = Vec::new(); + for pair in pairs { + let ( + parent1, + parent2 + ) = (&population.population[pair.0], &population.population[pair.1]); + + let ( + chromosome1, + chromosome2 + ) = (&parent1.chromosome, &parent2.chromosome); + + let cross_point = self.find_cross_point(&population.population[0].chromosome); + + offsprings.push(BinaryString::from_ovector( + chromosome1.vec.zip_zip_map( + &chromosome2.vec, + &indices, + |first, second, i| if i <= cross_point { first } else { second } + ))); + } + + Population::from_vec(offsprings) + } +} + + +pub struct OVectorOnePointCrossover { + _phantom1: PhantomData, + _phantom2: PhantomData, + _phantom3: PhantomData, + rng: Box +} + +impl OVectorOnePointCrossover +where + T: Scalar, + DefaultAllocator: Allocator +{ + pub fn new() -> Self { + Self { + rng: Box::new(rand::rng()), + _phantom1: PhantomData, + _phantom2: PhantomData, + _phantom3: PhantomData, + } + } + + fn find_cross_point(&mut self, chromosome: &OVector) -> usize { + let (min, max) = (0, chromosome.len()); + self.rng.random_range(min..max) + } +} + +impl Crossover for OVectorOnePointCrossover +where + T: Scalar, + D: Dim, + DefaultAllocator: Allocator +{ + type Chromosome = OVector; + type Out = TOutput; + + fn crossover( + &mut self, + population: &EvaluatedPopulation, + pairs: impl Iterator + ) -> Population { + + let chromosome = &population.population[0].chromosome; + let len = population.population[0].chromosome.len(); + let indices = OVector::::from_iterator_generic(chromosome.shape_generic().0, U1, 0..len); + + let mut offsprings = Vec::new(); + for pair in pairs { + let ( + parent1, + parent2 + ) = (&population.population[pair.0], &population.population[pair.1]); + + let ( + chromosome1, + chromosome2 + ) = (&parent1.chromosome, &parent2.chromosome); + + let cross_point = self.find_cross_point(&population.population[0].chromosome); + + offsprings.push( + chromosome1.zip_zip_map( + &chromosome2, + &indices, + |first, second, i| if i <= cross_point { first } else { second } + )); + } + + Population::from_vec(offsprings) + } +}