~ruther/ctu-fee-eoa

ec2647b8696dad6844a448d29af485b862ec6841 — Rutherther a month ago 45b0d91
feat(lib): add roulette wheel selection
1 files changed, 49 insertions(+), 1 deletions(-)

M codes/eoa_lib/src/selection.rs
M codes/eoa_lib/src/selection.rs => codes/eoa_lib/src/selection.rs +49 -1
@@ 1,5 1,7 @@
use nalgebra::{Dyn, OVector, Scalar, U1};
use rand::{seq::IteratorRandom, Rng, RngCore};
use std::fmt::Debug;
use rand_distr::uniform::{SampleRange, SampleUniform};
use std::{cmp::Ordering, fmt::Debug, ops::{AddAssign, Sub}};

use crate::{comparison::BetterThanOperator, replacement::EvaluatedPopulation};



@@ 108,3 110,49 @@ impl<TChromosome, TResult: Copy + Debug> Selection<TChromosome, TResult> for Tou
        })
    }
}


pub struct RouletteWheelSelection;
impl RouletteWheelSelection {
    pub fn new() -> Self {
        Self
    }
}

impl<TChromosome,
     TResult: Scalar + Copy + Default + PartialOrd + SampleUniform + AddAssign + Sub<Output = TResult>>
    Selection<TChromosome, TResult> for RouletteWheelSelection
{
    fn select(
        &self,
        count: usize,
        evaluations: &EvaluatedPopulation<TChromosome, TResult>,
        _: &dyn BetterThanOperator<TResult>,
        rng: &mut dyn RngCore
    ) -> impl Iterator<Item = usize> {
        let zero: TResult = Default::default();
        let max = evaluations.iter()
            .map(|i| i.evaluation)
            .max_by(|a, b|
                    a.partial_cmp(b).unwrap_or(Ordering::Less))
            .unwrap();
        let summed = evaluations.iter().scan(
            zero,
            |acc, individual| {
                let subtracted: TResult = max - individual.evaluation;
                *acc += subtracted;
                Some(*acc)
            })
            .collect::<Vec<TResult>>();
        let max = summed.last().unwrap().clone();

        (0..count).map(move |_| {
            let rand = rng.random_range(zero..=max);

            (0..summed.len())
                .filter(|&i| summed[i] > rand)
                .next()
                .unwrap_or(summed.len() - 1)
        })
    }
}