~ruther/ctu-fee-eoa

06f3259b7694559e224c43dac7a4fcf200762721 — Rutherther 2 months ago fcdc8b3
feat: add BoundedPerturbation

applies bounds on real numbers when performing the perturbation.
1 files changed, 76 insertions(+), 0 deletions(-)

M env/src/perturbation/mod.rs
M env/src/perturbation/mod.rs => env/src/perturbation/mod.rs +76 -0
@@ 54,6 54,82 @@ impl<TRng: Rng, TDistribution: Distribution<f64>, const LEN: usize> Perturbation
    }
}

pub enum BoundedPerturbationStrategy {
    /// Trims the value to get a value within bounds
    Trim,
    /// Retries calling the underlying perturbation until
    /// value within bounds is returned. If argument is given,
    /// this is the maximum number of retries to do and then
    /// fall back to trimming. Zero means retry indefinitely.
    Retry(usize)
}

pub struct BoundedPerturbation<const LEN: usize, T: PerturbationOperator<Chromosome = SVector<f64, LEN>>> {
    min: SVector<f64, LEN>,
    max: SVector<f64, LEN>,
    min_max: SVector<(f64, f64), LEN>,
    strategy: BoundedPerturbationStrategy,
    perturbation: T,
}

impl<const LEN: usize, T: PerturbationOperator<Chromosome = SVector<f64, LEN>>> BoundedPerturbation<LEN, T> {
    pub fn new(
        perturbation: T,
        min: SVector<f64, LEN>,
        max: SVector<f64, LEN>,
        strategy: BoundedPerturbationStrategy
    ) -> Self {
        let min_max = min.zip_map(&max, |min, max| (min, max));
        Self {
            min,
            max,
            min_max,
            strategy,
            perturbation
        }
    }

    fn within_bounds(&self, chromosome: &SVector<f64, LEN>) -> bool {
        chromosome.iter()
            .zip(self.min_max.iter())
            .all(|(&c, &(min, max))| c <= max && c >= min)
    }

    fn bound(&self, mut chromosome: SVector<f64, LEN>) -> SVector<f64, LEN> {
        chromosome
            .zip_apply(&self.min_max, |c, (min, max)| *c = c.clamp(min, max));

        chromosome
    }

    fn retry_perturb(self: &mut Self, chromosome: &SVector<f64, LEN>, retries: Option<usize>) -> SVector<f64, LEN> {
        let perturbed = self.perturbation.perturb(chromosome);

        if self.within_bounds(&perturbed) {
            return perturbed;
        }

        match retries {
            Some(0) | None => self.bound(perturbed),
            Some(retries) => self.retry_perturb(chromosome, Some(retries - 1))
        }
    }
}

impl<const LEN: usize, T> PerturbationOperator for BoundedPerturbation<LEN, T>
where
    T: PerturbationOperator<Chromosome = SVector<f64, LEN>>
{
    type Chromosome = SVector<f64, LEN>;

    fn perturb(self: &mut Self, chromosome: &Self::Chromosome) -> Self::Chromosome {
        match self.strategy {
            BoundedPerturbationStrategy::Trim => self.retry_perturb(chromosome, None),
            BoundedPerturbationStrategy::Retry(retries) => self.retry_perturb(chromosome, Some(retries))
        }
    }
}

#[cfg(test)]
pub mod tests {
    use crate::binary_string::BinaryString;