~ruther/ctu-fee-eoa

5f409074518fb0a5700af0bf1cd07a727bf8df17 — Rutherther a month ago 0420a0c
refactor(tsp): Do not allocate in binary string fitness function

Currently the fit function has allocated for each individual!
That slows it down terribly!

This reimplements it to implement the fit_population, reusing
the same node permutation and same precedence_count
2 files changed, 77 insertions(+), 23 deletions(-)

M codes/tsp_hw01/src/binary_string_representation.rs
M codes/tsp_hw01/src/main.rs
M codes/tsp_hw01/src/binary_string_representation.rs => codes/tsp_hw01/src/binary_string_representation.rs +62 -8
@@ 2,6 2,7 @@ use nalgebra::{allocator::Allocator, DefaultAllocator, Dim, OVector, U1};
use eoa_lib::{binary_string::BinaryString, fitness::FitnessFunction};
use thiserror::Error;
use crate::tsp::{NodePermutation, TSPInstance};
use eoa_lib::population::EvaluatedChromosome;

impl<'a, DIn: Dim, DOut: Dim> FitnessFunction for TSPBinaryStringWrapper<'a, DIn, DOut>
where


@@ 16,6 17,44 @@ where
    fn fit(self: &Self, inp: &Self::In) -> Result<Self::Out, Self::Err> {
        Ok(self.instance.fit(&self.to_permutation(inp)?).unwrap())
    }

    fn fit_population(&self, inp: Vec<Self::In>) -> Result<Vec<EvaluatedChromosome<Self::In, Self::Out>>, Self::Err> {
        let nodes = self.dim_out.value();

        // Count how many nodes each node comes after (precedence count)
        let mut precedence_count = OVector::<usize, DOut>::zeros_generic(self.dim_out, U1);

        let result = OVector::from_iterator_generic(
            self.dim_out,
            U1,
            0..nodes
        );

        let mut permutation = NodePermutation { permutation: result };

        inp
            .into_iter()
            .map(|chromosome| {
                // Reset
                precedence_count
                    .apply(|c| *c = 0);

                // NOTE no need to reset the permutation
                // as it's always sorted

                self.to_permutation_buff(
                    &chromosome,
                    &mut permutation,
                    &mut precedence_count
                )?;

                Ok(EvaluatedChromosome {
                    evaluation: self.instance.fit(&permutation).unwrap(),
                    chromosome,
                })
            })
            .collect::<Result<Vec<_>, _>>()
    }
}

pub struct TSPBinaryStringWrapper<'a, DIn: Dim, DOut: Dim>


@@ 53,16 92,18 @@ where
        }
    }

    pub fn to_permutation(&self, inp: &BinaryString<DIn>) -> Result<NodePermutation<DOut>, DimensionMismatch> {
    pub fn to_permutation_buff(
        &self,
        inp: &BinaryString<DIn>,
        permutation: &mut NodePermutation<DOut>,
        precedence_count: &mut OVector<usize, DOut>
    ) -> Result<(), DimensionMismatch> {
        let nodes = self.dim_out.value();

        if inp.vec().shape_generic().0.value() != self.dim_in.value() {
            return Err(DimensionMismatch::Mismatch);
        }

        // Count how many nodes each node comes after (precedence count)
        let mut precedence_count = OVector::<usize, DOut>::zeros_generic(self.dim_out, U1);

        let mut in_index = 0usize;
        for i in 0..self.dim_out.value() {
            for j in i+1..nodes {


@@ 86,17 127,30 @@ where
            return Err(DimensionMismatch::Mismatch);
        }

        permutation.permutation
            .as_mut_slice()
            .sort_unstable_by_key(|&node| precedence_count[node]);

        Ok(())
    }

    pub fn to_permutation(&self, inp: &BinaryString<DIn>) -> Result<NodePermutation<DOut>, DimensionMismatch> {
        let nodes = self.dim_out.value();

        // Count how many nodes each node comes after (precedence count)
        let mut precedence_count = OVector::<usize, DOut>::zeros_generic(self.dim_out, U1);

        let mut result = OVector::from_iterator_generic(
            self.dim_out,
            U1,
            0..nodes
        );

        result
            .as_mut_slice()
            .sort_by_key(|&node| precedence_count[node]);
        let mut permutation = NodePermutation { permutation: result };

        self.to_permutation_buff(inp, &mut permutation, &mut precedence_count)?;

        Ok(NodePermutation { permutation: result })
        Ok(permutation)
    }
}


M codes/tsp_hw01/src/main.rs => codes/tsp_hw01/src/main.rs +15 -15
@@ 13,7 13,7 @@ use perturbations::{MovePerturbation, ReverseSubsequencePerturbation, SwapPertur
use binary_string_representation::TSPBinaryStringWrapper;
use nalgebra::{Dim, Dyn};
use eoa_lib::{
    binary_string::BinaryString, comparison::MinimizingOperator, crossover::BinaryNPointCrossover, evolution::{evolution_algorithm, EvolutionStats}, evolutionary_strategy::IdentityStrategy, initializer::{Initializer, RandomInitializer}, local_search::{local_search_first_improving, LocalSearchStats}, pairing::AdjacentPairing, perturbation::{apply_to_perturbations, BinaryStringBitPerturbation, BinaryStringFlipNPerturbation, BinaryStringSingleBitPerturbation, CombinedPerturbation, IdentityPerturbation, MutationPerturbation, OneOfPerturbation}, random_search::random_search, replacement::{BestReplacement, TournamentReplacement}, selection::{BestSelection, RouletteWheelSelection}, terminating::MaximumCyclesTerminatingCondition
    binary_string::BinaryString, comparison::MinimizingOperator, crossover::BinaryNPointCrossover, evolution::{evolution_algorithm, EvolutionStats}, evolutionary_strategy::IdentityStrategy, initializer::{Initializer, RandomInitializer}, local_search::{local_search_first_improving, LocalSearchStats}, pairing::AdjacentPairing, perturbation::{apply_to_perturbations, BinaryStringBitPerturbation, BinaryStringFlipNPerturbation, BinaryStringFlipPerturbation, BinaryStringSingleBitPerturbation, CombinedPerturbation, IdentityPerturbation, MutationPerturbation, OneOfPerturbation}, random_search::random_search, replacement::{BestReplacement, TournamentReplacement}, selection::{BestSelection, RouletteWheelSelection, TournamentSelection}, terminating::MaximumCyclesTerminatingCondition
};
use rand::rng;
use std::env;


@@ 615,10 615,10 @@ fn run_evolution_algorithm_binary(instance: &TSPInstance<Dyn>) -> Result<PlotDat
    let input_dimension = Dyn(output_dimension.value() * (output_dimension.value() - 1) / 2);

    // Create combined perturbation with two mutations wrapped in MutationPerturbation
    let bit_mutation = MutationPerturbation::new(Box::new(BinaryStringBitPerturbation::new(0.1)), 0.2);
    let bit_mutation = MutationPerturbation::new(Box::new(BinaryStringBitPerturbation::new(0.01)), 0.2);
    let single_bit_mutation = MutationPerturbation::new(Box::new(BinaryStringSingleBitPerturbation::new()), 0.4);
    let flip1_mutation = MutationPerturbation::new(Box::new(BinaryStringFlipNPerturbation::new(30)), 0.4);
    let flip2_mutation = MutationPerturbation::new(Box::new(BinaryStringFlipNPerturbation::new(20)), 0.4);
    let flip2_mutation = MutationPerturbation::new(Box::new(BinaryStringFlipPerturbation::new()), 0.4);
    let mut combined_perturbation = CombinedPerturbation::new(vec![
        Box::new(bit_mutation),
        Box::new(single_bit_mutation),


@@ 627,9 627,9 @@ fn run_evolution_algorithm_binary(instance: &TSPInstance<Dyn>) -> Result<PlotDat
    ]);

    // Set up other components
    let mut crossover = BinaryNPointCrossover::<10, _, _>::new();
    let mut crossover = BinaryNPointCrossover::<5, _, _>::new();
    let mut selection = BestSelection::new();
    let mut replacement = TournamentReplacement::new(5, 1.0);
    let mut replacement = TournamentReplacement::new(5, 0.9);
    let mut pairing = AdjacentPairing::new();
    let better_than_operator = MinimizingOperator::new();



@@ 658,15 658,15 @@ fn run_evolution_algorithm_binary(instance: &TSPInstance<Dyn>) -> Result<PlotDat
            let iters_till_end = EA_MAX_ITERATIONS - iteration + 1;
            let iters_since_better =
                iteration - stats.best_candidates.last().map(|c| c.iteration).unwrap_or(0);
            let mut found = false;
            apply_to_perturbations::<_, BinaryStringBitPerturbation<Dyn>>(
                perturbation,
                &mut |p| {
                    found = true;
                    p.p = (0.025 * (1.0 + (iters_since_better as f64 / iters_till_end as f64))).min(0.2);
                }
            );
            assert!(found);
            // let mut found = false;
            // apply_to_perturbations::<_, BinaryStringBitPerturbation<Dyn>>(
            //     perturbation,
            //     &mut |p| {
            //         found = true;
            //         p.p = (0.025 * (1.0 + (iters_since_better as f64 / iters_till_end as f64))).min(0.2);
            //     }
            // );
            // assert!(found);

            let mut found = 0;
            MutationPerturbation::apply_to_mutations(


@@ 674,7 674,7 @@ fn run_evolution_algorithm_binary(instance: &TSPInstance<Dyn>) -> Result<PlotDat
                &mut |p| {
                    // Do not touch multi bit mutation
                    if found > 0 {
                        p.probability = (0.5 * (1.0 + (iters_since_better as f64 / iters_till_end as f64))).min(1.0);
                        p.probability = (0.4 * (0.5 + (iters_since_better as f64 / EA_MAX_ITERATIONS as f64))).min(1.0);
                    }
                    found += 1;
                }