From 7adc5812e51c320898762ce9611540bd5f79b997 Mon Sep 17 00:00:00 2001 From: Rutherther Date: Tue, 28 Oct 2025 10:21:57 +0100 Subject: [PATCH] refactor: pass rng as argument Instead of having the Rng stored inside the structs, pass it through the functions. This means it's no longer necessary to pass perturbations etc. as mutable. --- codes/eoa_lib/src/crossover.rs | 37 +++++++++---------- codes/eoa_lib/src/evolution.rs | 43 ++++++++++++---------- codes/eoa_lib/src/initializer/mod.rs | 24 +++++------- codes/eoa_lib/src/local_search/mod.rs | 30 ++++++++++++--- codes/eoa_lib/src/perturbation/mod.rs | 53 +++++++++++---------------- codes/eoa_lib/src/replacement.rs | 39 ++++++++++---------- codes/eoa_lib/src/selection.rs | 28 +++++++------- codes/tsp_hw01/src/tsp.rs | 28 ++++++-------- 8 files changed, 143 insertions(+), 139 deletions(-) diff --git a/codes/eoa_lib/src/crossover.rs b/codes/eoa_lib/src/crossover.rs index 717c88304cdea7fbaf1bad8863d96cb549d8da11..2a08c7bf1c2bdc696e7921acfb68e95af2be96bc 100644 --- a/codes/eoa_lib/src/crossover.rs +++ b/codes/eoa_lib/src/crossover.rs @@ -10,16 +10,16 @@ pub trait Crossover { type Out; fn crossover( - &mut self, + &self, parents: &EvaluatedPopulation, - pairs: impl Iterator + pairs: impl Iterator, + rng: &mut dyn RngCore ) -> Population; } pub struct BinaryOnePointCrossover { _phantom1: PhantomData, - _phantom2: PhantomData, - rng: Box + _phantom2: PhantomData } impl BinaryOnePointCrossover @@ -28,15 +28,14 @@ where { pub fn new() -> Self { Self { - rng: Box::new(rand::rng()), _phantom1: PhantomData, _phantom2: PhantomData, } } - fn find_cross_point(&mut self, chromosome: &BinaryString) -> usize { + fn find_cross_point(&self, chromosome: &BinaryString, rng: &mut dyn RngCore) -> usize { let (min, max) = (0, chromosome.vec.len()); - self.rng.random_range(min..max) + rng.random_range(min..max) } } @@ -52,9 +51,10 @@ where type Out = TOutput; fn crossover( - &mut self, + &self, population: &EvaluatedPopulation, - pairs: impl Iterator + pairs: impl Iterator, + rng: &mut dyn RngCore ) -> Population { let chromosome = &population.population[0].chromosome.vec; @@ -73,7 +73,7 @@ where chromosome2 ) = (&parent1.chromosome, &parent2.chromosome); - let cross_point = self.find_cross_point(&population.population[0].chromosome); + let cross_point = self.find_cross_point(&population.population[0].chromosome, rng); offsprings.push(BinaryString::from_ovector( chromosome1.vec.zip_zip_map( @@ -91,8 +91,7 @@ where pub struct OVectorOnePointCrossover { _phantom1: PhantomData, _phantom2: PhantomData, - _phantom3: PhantomData, - rng: Box + _phantom3: PhantomData } impl OVectorOnePointCrossover @@ -102,16 +101,15 @@ where { pub fn new() -> Self { Self { - rng: Box::new(rand::rng()), _phantom1: PhantomData, _phantom2: PhantomData, _phantom3: PhantomData, } } - fn find_cross_point(&mut self, chromosome: &OVector) -> usize { + fn find_cross_point(&self, chromosome: &OVector, rng: &mut dyn RngCore) -> usize { let (min, max) = (0, chromosome.len()); - self.rng.random_range(min..max) + rng.random_range(min..max) } } @@ -125,9 +123,10 @@ where type Out = TOutput; fn crossover( - &mut self, + &self, population: &EvaluatedPopulation, - pairs: impl Iterator + pairs: impl Iterator, + rng: &mut dyn RngCore ) -> Population { let chromosome = &population.population[0].chromosome; @@ -146,7 +145,7 @@ where chromosome2 ) = (&parent1.chromosome, &parent2.chromosome); - let cross_point = self.find_cross_point(&population.population[0].chromosome); + let cross_point = self.find_cross_point(&population.population[0].chromosome, rng); offsprings.push( chromosome1.zip_zip_map( @@ -158,4 +157,4 @@ where Population::from_vec(offsprings) } -} +} \ No newline at end of file diff --git a/codes/eoa_lib/src/evolution.rs b/codes/eoa_lib/src/evolution.rs index 7a54e14585b7e417a6d96ed9f01c09a28a31f3ac..303a3918ca985bc3ab99e21d3fac35b87eae64e2 100644 --- a/codes/eoa_lib/src/evolution.rs +++ b/codes/eoa_lib/src/evolution.rs @@ -1,4 +1,5 @@ use std::error::Error; +use rand::RngCore; use crate::{comparison::BetterThanOperator, crossover::Crossover, fitness::FitnessFunction, pairing::Pairing, perturbation::PerturbationOperator, replacement::{EvaluatedChromosome, EvaluatedPopulation, Population, Replacement}, selection::Selection}; @@ -31,14 +32,15 @@ pub fn evolution_algorithm( initial_population: Population, parents_count: usize, fitness: &impl FitnessFunction, - selection: &mut impl Selection, + selection: &impl Selection, pairing: &mut impl Pairing, - crossover: &mut impl Crossover, - perturbation: &mut impl PerturbationOperator, - replacement: &mut impl Replacement, + crossover: &impl Crossover, + perturbation: &impl PerturbationOperator, + replacement: &impl Replacement, better_than: &impl BetterThanOperator, // TODO: termination condition - iterations: usize + iterations: usize, + rng: &mut dyn RngCore, ) -> Result, Box> { let mut current_population = initial_population.evaluate(fitness)?; @@ -67,21 +69,21 @@ pub fn evolution_algorithm( } // Selection - let parents = selection.select(parents_count, ¤t_population, better_than); - let parent_pairings = pairing.pair(¤t_population, parents); + let parents = selection.select(parents_count, ¤t_population, better_than, rng).collect::>(); + let parent_pairings = pairing.pair(¤t_population, parents.into_iter()); // Crossover - let mut offsprings = crossover.crossover(¤t_population, parent_pairings); + let mut offsprings = crossover.crossover(¤t_population, parent_pairings, rng); // Mutation for offspring in offsprings.iter_mut() { - perturbation.perturb(offspring); + perturbation.perturb(offspring, rng); } let evaluated_offsprings = offsprings.evaluate(fitness)?; // Replace - current_population = replacement.replace(current_population, evaluated_offsprings, better_than); + current_population = replacement.replace(current_population, evaluated_offsprings, better_than, rng); } let best_candidate = last_best_candidate.evaluated_chromosome.clone(); @@ -99,7 +101,7 @@ pub fn evolution_algorithm( pub mod tests { use nalgebra::Const; - use crate::{binary_string::BinaryString, comparison::MinimizingOperator, crossover::BinaryOnePointCrossover, fitness::one_max::OneMax, initializer::{Initializer, RandomInitializer}, pairing::AdjacentPairing, perturbation::{BinaryStringBitPerturbation, MutationPerturbation}, replacement::{BestReplacement, Population, TournamentReplacement}, selection::{BestSelection, TournamentSelection}}; + use crate::{binary_string::BinaryString, comparison::MinimizingOperator, crossover::BinaryOnePointCrossover, fitness::one_max::OneMax, initializer::{Initializer, RandomInitializer}, pairing::AdjacentPairing, perturbation::{BinaryStringBitPerturbation, MutationPerturbation}, replacement::{BestReplacement, Population}, selection::TournamentSelection}; use super::evolution_algorithm; @@ -109,28 +111,31 @@ pub mod tests { let optimum = BinaryString::>::new(vec![0; D]); let one_max = OneMax::>::new(); - let mut initializer = RandomInitializer::, BinaryString::>>::new_binary(); + let initializer = RandomInitializer::, BinaryString::>>::new_binary(); let population_size = 512; - let population = Population::from_iterator( - initializer.initialize(Const::, population_size) + let mut rng_init = rand::rng(); + let population = Population::from_vec( + initializer.initialize(Const::, population_size, &mut rng_init) ); + let mut rng = rand::rng(); let result = evolution_algorithm( population, population_size / 4, &one_max, // TODO: tournament should somehow accept sorting? // TODO: deterministic and nondeterministic tournament ordering - &mut TournamentSelection::new(3, 0.8), + &TournamentSelection::new(3, 0.8), &mut AdjacentPairing::new(), - &mut BinaryOnePointCrossover::new(), - &mut MutationPerturbation::new( + &BinaryOnePointCrossover::new(), + &MutationPerturbation::new( Box::new(BinaryStringBitPerturbation::new(0.05)), 0.1), - &mut BestReplacement::new(), + &BestReplacement::new(), &MinimizingOperator, - 1000 + 1000, + &mut rng ).unwrap(); println!("{:?}", result.stats.best_candidates diff --git a/codes/eoa_lib/src/initializer/mod.rs b/codes/eoa_lib/src/initializer/mod.rs index 023c5ce574fb0c5ad4f423a641516f5932bf3068..b3df614f823f4c0b9b80221a0fce6397be74fd6e 100644 --- a/codes/eoa_lib/src/initializer/mod.rs +++ b/codes/eoa_lib/src/initializer/mod.rs @@ -4,10 +4,9 @@ use rand::RngCore; use crate::{binary_string::BinaryString, bounded::{Bounded, BoundedBinaryString}}; pub trait Initializer { - fn initialize_single(&mut self, size: D) -> T; - fn initialize(&mut self, size: D, count: usize) -> impl Iterator { - let size = size; - (0..count).map(move |_| self.initialize_single(size)) + fn initialize_single(&self, size: D, rng: &mut dyn RngCore) -> T; + fn initialize(&self, size: D, count: usize, rng: &mut dyn RngCore) -> Vec { + (0..count).map(|_| self.initialize_single(size, rng)).collect() } } @@ -26,9 +25,9 @@ where D: Dim, DefaultAllocator: Allocator { - fn initialize_single(&mut self, size: D) -> BinaryString { + fn initialize_single(&self, size: D, rng: &mut dyn RngCore) -> BinaryString { BinaryString::::from_ovector( - >>::initialize_single(self, size) + >>::initialize_single(self, size, rng) ) } } @@ -39,20 +38,18 @@ where D: Dim, DefaultAllocator: Allocator { - fn initialize_single(&mut self, size: D) -> OVector { + fn initialize_single(&self, size: D, _rng: &mut dyn RngCore) -> OVector { OVector::::from_element_generic(size, U1, Default::default()) } } pub struct RandomInitializer { - rng: Box, bounded: Box> } impl RandomInitializer { pub fn new(bounded: Box>) -> Self { Self { - rng: Box::new(rand::rng()), bounded } } @@ -65,7 +62,6 @@ where { pub fn new_binary() -> Self { Self { - rng: Box::new(rand::rng()), bounded: Box::new(BoundedBinaryString::unbounded()) } } @@ -76,8 +72,8 @@ where D: Dim, DefaultAllocator: Allocator { - fn initialize_single(&mut self, size: D) -> BinaryString { - self.bounded.next_random(size, &mut self.rng) + fn initialize_single(&self, size: D, rng: &mut dyn RngCore) -> BinaryString { + self.bounded.next_random(size, rng) } } @@ -87,7 +83,7 @@ where D: Dim, DefaultAllocator: Allocator { - fn initialize_single(&mut self, size: D) -> OVector { - self.bounded.next_random(size, &mut self.rng) + fn initialize_single(&self, size: D, rng: &mut dyn RngCore) -> OVector { + self.bounded.next_random(size, rng) } } diff --git a/codes/eoa_lib/src/local_search/mod.rs b/codes/eoa_lib/src/local_search/mod.rs index 49905b6fbf24c5281f54e0ca3e147c613f422919..b5ff6f52b672de15c78137525ab9ccdc71681dae 100644 --- a/codes/eoa_lib/src/local_search/mod.rs +++ b/codes/eoa_lib/src/local_search/mod.rs @@ -1,5 +1,6 @@ use std::error::Error; use std::fmt::Debug; +use rand::RngCore; use crate::binary_string::{BinaryString, BinaryStringConversionError}; use crate::evolutionary_strategy::{EvolutionaryStrategy, IdentityStrategy}; use crate::fitness::FitnessFunction; @@ -103,7 +104,8 @@ pub fn local_search_first_improving< terminating_condition: &mut TTerminatingCondition, perturbation_operator: &mut TPerturbationOperator, better_than_operator: &TBetterThanOperator, - initial: &TInput + initial: &TInput, + rng: &mut dyn RngCore ) -> Result, Box> where TResult: Clone, @@ -119,7 +121,8 @@ where perturbation_operator, better_than_operator, &mut IdentityStrategy, - initial + initial, + rng ) } @@ -130,7 +133,8 @@ pub fn local_search_first_improving_evolving< perturbation_operator: &mut TPerturbationOperator, better_than_operator: &TBetterThanOperator, evolutionary_strategy: &mut TEvolutionaryStrategy, - initial: &TInput + initial: &TInput, + rng: &mut dyn RngCore ) -> Result, Box> where TResult: Clone, @@ -152,7 +156,7 @@ where while !terminating_condition.should_terminate(&best_candidate, &stats, cycle) { let mut perturbed = best_candidate.pos.clone(); - perturbation_operator.perturb(&mut perturbed); + perturbation_operator.perturb(&mut perturbed, rng); let perturbed_fit = fit.fit(&perturbed)?; // Minimize @@ -303,6 +307,7 @@ pub mod tests { let sphere = Sphere::new(optimum_real); let sphere_wrapped = BinaryFitnessWrapper::new(sphere, min.clone(), max.clone()); + let mut rng = rand::rng(); let result = local_search_first_improving( &sphere_wrapped, &mut @@ -315,6 +320,7 @@ pub mod tests { &mut BinaryStringBitPerturbation::new(0.3), &MinimizingOperator::new(), &BinaryString::new(vec![1; 10]), + &mut rng, ).unwrap(); println!("{:?}", result); @@ -337,6 +343,7 @@ pub mod tests { let optimum = SVector::::repeat(4.0); let sphere = Sphere::new(optimum); + let mut rng = rand::rng(); let result = local_search_first_improving_evolving( &sphere, &mut @@ -350,6 +357,7 @@ pub mod tests { &MinimizingOperator::new(), &mut IdentityStrategy, &SVector::::repeat(-5.0), + &mut rng, ).unwrap(); println!("{:?}", result); @@ -370,6 +378,7 @@ pub mod tests { let one_max = OneMax::::new(); let optimum = BinaryString::::new(vec![0; 10]); + let mut rng = rand::rng(); let result = local_search_first_improving( &one_max, &mut @@ -382,6 +391,7 @@ pub mod tests { &mut BinaryStringBitPerturbation::new(0.3), &MinimizingOperator::new(), &BinaryString::::new(vec![1; 10]), + &mut rng, ).unwrap(); println!("{:?}", result); @@ -438,6 +448,7 @@ pub mod tests { let max = SVector::::from_element(15.0); let rosenbrock_wrapped = BinaryFitnessWrapper::new(rosenbrock, min, max); + let mut rng = rand::rng(); let result = local_search_first_improving( &rosenbrock_wrapped, &mut @@ -450,6 +461,7 @@ pub mod tests { &mut BinaryStringBitPerturbation::new(0.1), &MinimizingOperator::new(), &BinaryString::new(vec![0; 10]), + &mut rng, ).unwrap(); println!("{:?}", result); @@ -473,11 +485,12 @@ pub mod tests { let max = SVector::::from_vec(vec![10.0, 10.0]); let min = -SVector::::from_vec(vec![10.0, 10.0]); - let mut initializer = + let initializer = RandomInitializer::>::new(Box::new(BoundedOVector::::new(min, max))); let linear = Linear::new(7.0, SVector::::from_vec(vec![0.2, -0.5])); + let mut rng = rand::rng(); let result = local_search_first_improving( &linear, &mut @@ -493,7 +506,8 @@ pub mod tests { max, BoundedPerturbationStrategy::Retry(10)), &MinimizingOperator::new(), - &initializer.initialize_single(U2), + &initializer.initialize_single(U2, &mut rng), + &mut rng, ).unwrap(); println!("{:?}", result); @@ -519,6 +533,7 @@ pub mod tests { let linear = Linear::new(7.0, SVector::::from_vec(vec![0.2, -0.5])); + let mut rng = rand::rng(); let result = local_search_first_improving( &linear, &mut @@ -535,6 +550,7 @@ pub mod tests { BoundedPerturbationStrategy::Retry(10)), &MinimizingOperator::new(), &SVector::::zeros(), + &mut rng, ).unwrap(); println!("{:?}", result); @@ -560,6 +576,7 @@ pub mod tests { let linear = Linear::new(7.0, SVector::::from_vec(vec![0.2, -0.5])); + let mut rng = rand::rng(); let result = local_search_first_improving_evolving( &linear, &mut @@ -577,6 +594,7 @@ pub mod tests { &MinimizingOperator::new(), &mut OneToFiveStrategy, &SVector::::zeros(), + &mut rng, ).unwrap(); println!("{:?}", result); diff --git a/codes/eoa_lib/src/perturbation/mod.rs b/codes/eoa_lib/src/perturbation/mod.rs index f3a8d608eeecf35af59dbf1d8fe06fbe168189f4..492a3ec272ed93ab3fea929ae78b74f3ce74a36f 100644 --- a/codes/eoa_lib/src/perturbation/mod.rs +++ b/codes/eoa_lib/src/perturbation/mod.rs @@ -9,11 +9,10 @@ use crate::binary_string::BinaryString; pub trait PerturbationOperator { type Chromosome; - fn perturb(&mut self, chromosome: &mut Self::Chromosome); + fn perturb(&self, chromosome: &mut Self::Chromosome, rng: &mut dyn RngCore); } pub struct BinaryStringBitPerturbation { - rng: Box, p: f64, _phantom: PhantomData } @@ -21,7 +20,6 @@ pub struct BinaryStringBitPerturbation { impl BinaryStringBitPerturbation { pub fn new(p: f64) -> Self { Self { - rng: Box::new(rand::rng()), p, _phantom: PhantomData } @@ -35,14 +33,13 @@ where { type Chromosome = BinaryString; - fn perturb(&mut self, chromosome: &mut Self::Chromosome) { - chromosome.perturb(&mut self.rng, self.p); + fn perturb(&self, chromosome: &mut Self::Chromosome, rng: &mut dyn RngCore) { + chromosome.perturb(rng, self.p); } } pub struct RandomDistributionPerturbation> { distribution: TDistribution, - rng: Box, parameter: f64 } @@ -50,7 +47,6 @@ impl RandomDistributionPerturbation> { pub fn normal(std_dev: f64) -> Result { Ok(Self { distribution: Normal::new(0.0, std_dev)?, - rng: Box::new(rand::rng()), parameter: std_dev }) } @@ -70,7 +66,6 @@ impl RandomDistributionPerturbation> { pub fn uniform(range: f64) -> Result { Ok(Self { distribution: Uniform::new(-range/2.0, range/2.0)?, - rng: Box::new(rand::rng()), parameter: range, }) } @@ -89,21 +84,19 @@ impl RandomDistributionPerturbation> { impl, const LEN: usize> PerturbationOperator for RandomDistributionPerturbation { type Chromosome = SVector; - fn perturb(&mut self, chromosome: &mut Self::Chromosome) { - *chromosome += Self::Chromosome::zeros().map(|_| self.distribution.sample(&mut self.rng)); + fn perturb(&self, chromosome: &mut Self::Chromosome, rng: &mut dyn RngCore) { + *chromosome += Self::Chromosome::zeros().map(|_| self.distribution.sample(rng)); } } pub struct PatternPerturbation { - d: f64, - rng: Box + d: f64 } impl PatternPerturbation { pub fn new(d: f64) -> Self { Self { - d, - rng: Box::new(rand::rng()) + d } } } @@ -111,11 +104,11 @@ impl PatternPerturbation { impl PerturbationOperator for PatternPerturbation { type Chromosome = SVector::; - fn perturb(&mut self, chromosome: &mut Self::Chromosome) { + fn perturb(&self, chromosome: &mut Self::Chromosome, rng: &mut dyn RngCore) { // 1. Choose dimension - let idx = self.rng.random_range(0..LEN); + let idx = rng.random_range(0..LEN); // 2. Direction - let d = if self.rng.random_bool(0.5) { + let d = if rng.random_bool(0.5) { self.d } else { -self.d @@ -178,9 +171,9 @@ impl>> chromosome } - fn retry_perturb(&mut self, chromosome: &mut SVector, retries: Option) { + fn retry_perturb(&self, chromosome: &mut SVector, retries: Option, rng: &mut dyn RngCore) { let mut perturbed = chromosome.clone(); - self.perturbation.perturb(&mut perturbed); + self.perturbation.perturb(&mut perturbed, rng); if self.within_bounds(&perturbed) { *chromosome = perturbed; @@ -191,7 +184,7 @@ impl>> Some(0) | None => *chromosome = self.bound(perturbed), Some(retries) => { *chromosome = perturbed; - self.retry_perturb(chromosome, Some(retries - 1)); + self.retry_perturb(chromosome, Some(retries - 1), rng); } } } @@ -203,10 +196,10 @@ where { type Chromosome = SVector; - fn perturb(&mut self, chromosome: &mut Self::Chromosome) { + fn perturb(&self, chromosome: &mut Self::Chromosome, rng: &mut dyn RngCore) { match self.strategy { - BoundedPerturbationStrategy::Trim => self.retry_perturb(chromosome, None), - BoundedPerturbationStrategy::Retry(retries) => self.retry_perturb(chromosome, Some(retries)) + BoundedPerturbationStrategy::Trim => self.retry_perturb(chromosome, None, rng), + BoundedPerturbationStrategy::Retry(retries) => self.retry_perturb(chromosome, Some(retries), rng) } } } @@ -214,7 +207,6 @@ where /// Perform given perturbation only with given probability pub struct MutationPerturbation { perturbation: Box>, - rng: Box, probability: f64 } @@ -222,7 +214,6 @@ impl MutationPerturbation { pub fn new(perturbation: Box>, probability: f64) -> Self { Self { perturbation, - rng: Box::new(rand::rng()), probability } } @@ -231,9 +222,9 @@ impl MutationPerturbation { impl PerturbationOperator for MutationPerturbation { type Chromosome = T; - fn perturb(&mut self, chromosome: &mut Self::Chromosome) { - if self.rng.random_bool(self.probability) { - self.perturbation.perturb(chromosome); + fn perturb(&self, chromosome: &mut Self::Chromosome, rng: &mut dyn RngCore) { + if rng.random_bool(self.probability) { + self.perturbation.perturb(chromosome, rng); } } } @@ -253,9 +244,9 @@ impl CombinedPerturbation { impl PerturbationOperator for CombinedPerturbation { type Chromosome = T; - fn perturb(&mut self, chromosome: &mut Self::Chromosome) { - for perturbation in self.perturbations.iter_mut() { - perturbation.perturb(chromosome); + fn perturb(&self, chromosome: &mut Self::Chromosome, rng: &mut dyn RngCore) { + for perturbation in self.perturbations.iter() { + perturbation.perturb(chromosome, rng); } } } diff --git a/codes/eoa_lib/src/replacement.rs b/codes/eoa_lib/src/replacement.rs index bef817b54b1e6b8eb723052485d26324c92830ea..2061e0a12ae8e51926d62e3044dc8129ac534340 100644 --- a/codes/eoa_lib/src/replacement.rs +++ b/codes/eoa_lib/src/replacement.rs @@ -134,10 +134,11 @@ impl EvaluatedPopulation { pub trait Replacement { fn replace( - &mut self, + &self, parents_evaluations: EvaluatedPopulation, offsprings_evaluations: EvaluatedPopulation, - better_than: &dyn BetterThanOperator + better_than: &dyn BetterThanOperator, + rng: &mut dyn RngCore ) -> EvaluatedPopulation; } @@ -150,10 +151,11 @@ impl BestReplacement { impl Replacement for BestReplacement { fn replace( - &mut self, + &self, parents_evaluations: EvaluatedPopulation, offsprings_evaluations: EvaluatedPopulation, - better_than: &dyn BetterThanOperator + better_than: &dyn BetterThanOperator, + _rng: &mut dyn RngCore ) -> EvaluatedPopulation { let count = parents_evaluations.population.len(); let mut population = parents_evaluations; @@ -177,10 +179,11 @@ impl Replacement for B pub struct GenerationalReplacement; impl Replacement for GenerationalReplacement { fn replace( - &mut self, + &self, parents: EvaluatedPopulation, mut offsprings: EvaluatedPopulation, - _: &dyn BetterThanOperator + _: &dyn BetterThanOperator, + _rng: &mut dyn RngCore ) -> EvaluatedPopulation { let count = parents.population.len(); if count == offsprings.population.len() { @@ -198,24 +201,21 @@ impl Replacement for GenerationalReplacement { } } -pub struct RandomReplacement { - rng: Box -} +pub struct RandomReplacement; impl RandomReplacement { pub fn new() -> Self { - Self { - rng: Box::new(rand::rng()) - } + Self } } impl Replacement for RandomReplacement { fn replace( - &mut self, + &self, parents: EvaluatedPopulation, offsprings: EvaluatedPopulation, - _: &dyn BetterThanOperator + _: &dyn BetterThanOperator, + rng: &mut dyn RngCore ) -> EvaluatedPopulation { let count = parents.population.len(); @@ -223,7 +223,7 @@ impl Replacement for RandomReplacement { parents.deconstruct() .into_iter() .chain(offsprings.deconstruct().into_iter()) - .choose_multiple(&mut self.rng, count)) + .choose_multiple(rng, count)) } } @@ -246,19 +246,18 @@ impl TournamentReplacement { impl Replacement for TournamentReplacement { fn replace( - &mut self, + &self, parents: EvaluatedPopulation, offsprings: EvaluatedPopulation, - better_than: &dyn BetterThanOperator + better_than: &dyn BetterThanOperator, + rng: &mut dyn RngCore ) -> EvaluatedPopulation { let count = parents.population.len(); let mut population = parents; population.join(offsprings); - self.evaluation_pool.clear(); - // TODO: use a pool instead of allocating vector every run of this function - let selected = self.selection.select(count, &population, better_than) + let selected = self.selection.select(count, &population, better_than, rng) .collect::>(); let population = population.deconstruct(); diff --git a/codes/eoa_lib/src/selection.rs b/codes/eoa_lib/src/selection.rs index 79899b055cd58d6930c8b0c2c6bca52044068624..0c497f445f2e07f5bbc476e3ab80c3ea2fa68c70 100644 --- a/codes/eoa_lib/src/selection.rs +++ b/codes/eoa_lib/src/selection.rs @@ -4,10 +4,11 @@ use std::fmt::Debug; use crate::{comparison::BetterThanOperator, replacement::EvaluatedPopulation}; pub trait Selection { - fn select(&mut self, + fn select(&self, count: usize, evaluations: &EvaluatedPopulation, - better_than: &dyn BetterThanOperator + better_than: &dyn BetterThanOperator, + rng: &mut dyn RngCore ) -> impl Iterator; } @@ -19,10 +20,11 @@ impl BestSelection { } impl Selection for BestSelection { - fn select(&mut self, + fn select(&self, count: usize, evaluations: &EvaluatedPopulation, - better_than: &dyn BetterThanOperator + better_than: &dyn BetterThanOperator, + _rng: &mut dyn RngCore ) -> impl Iterator { let mut idxs = (0..evaluations.population.len()) .collect::>(); @@ -36,7 +38,6 @@ impl Selection for BestSelecti } pub struct TournamentSelection { - rng: Box, p: f64, k: usize } @@ -47,24 +48,24 @@ impl TournamentSelection { assert!(k > 0); Self { - rng: Box::new(rand::rng()), p, k } } fn tournament( - &mut self, + &self, idxs: &mut Vec, evaluations: &EvaluatedPopulation, - better_than: &dyn BetterThanOperator + better_than: &dyn BetterThanOperator, + rng: &mut dyn RngCore ) -> usize { idxs.sort_unstable_by(|&i, &j| better_than.ordering( &evaluations.population[i].evaluation, &evaluations.population[j].evaluation) ); - let mut p_selector = self.rng.random_range(0.0..=1.0f64); + let mut p_selector = rng.random_range(0.0..=1.0f64); let p = self.p; let k = self.k; @@ -90,18 +91,19 @@ impl TournamentSelection { impl Selection for TournamentSelection { fn select( - &mut self, + &self, count: usize, evaluations: &EvaluatedPopulation, - better_than: &dyn BetterThanOperator + better_than: &dyn BetterThanOperator, + rng: &mut dyn RngCore ) -> impl Iterator { // Let's reuse a single vector for the indices let mut k_selected_idxs = vec![0; self.k]; (0..count).map(move |_| { // Choose k. Do not care if already selected previously. - (0..evaluations.population.len()).choose_multiple_fill(&mut self.rng, &mut k_selected_idxs); + (0..evaluations.population.len()).choose_multiple_fill(rng, &mut k_selected_idxs); // Tournament between the k - let index = self.tournament(&mut k_selected_idxs, evaluations, better_than); + let index = self.tournament(&mut k_selected_idxs, evaluations, better_than, rng); index }) } diff --git a/codes/tsp_hw01/src/tsp.rs b/codes/tsp_hw01/src/tsp.rs index 5932f0b5e4779ff545c16496eca63a1d238510e8..b59a164db8bbc9a8f6f636e7d97f30365711f52a 100644 --- a/codes/tsp_hw01/src/tsp.rs +++ b/codes/tsp_hw01/src/tsp.rs @@ -139,8 +139,7 @@ where D: Dim, DefaultAllocator: Allocator, { - _phantom: PhantomData, - rng: Box + _phantom: PhantomData } impl Initializer> for TSPRandomInitializer @@ -149,18 +148,17 @@ where DefaultAllocator: Allocator, DefaultAllocator: Allocator, { - fn initialize_single(&mut self, size: D) -> NodePermutation { + fn initialize_single(&self, size: D, rng: &mut dyn RngCore) -> NodePermutation { let len = size.value(); let mut indices = OVector::::from_iterator_generic(size, U1, 0..len); - indices.as_mut_slice().shuffle(&mut self.rng); + indices.as_mut_slice().shuffle(rng); NodePermutation { permutation: indices } } } pub struct SwapPerturbation { - _phantom: PhantomData, - rng: Box, + _phantom: PhantomData } impl PerturbationOperator for SwapPerturbation @@ -171,20 +169,16 @@ where { type Chromosome = NodePermutation; - fn perturb(self: &mut Self, chromosome: &Self::Chromosome) -> Self::Chromosome { - let first = self.rng.random_range(0..=chromosome.permutation.len()); - let second = self.rng.random_range(0..=chromosome.permutation.len()); - - let mut new = chromosome.clone(); + fn perturb(&self, chromosome: &mut Self::Chromosome, rng: &mut dyn RngCore) { + let first = rng.random_range(0..=chromosome.permutation.len()); + let second = rng.random_range(0..=chromosome.permutation.len()); ( - new.permutation[first], - new.permutation[second] + chromosome.permutation[first], + chromosome.permutation[second] ) = ( - new.permutation[second], - new.permutation[first] + chromosome.permutation[second], + chromosome.permutation[first] ); - - new } }