From 8eaf8145d149098a0ef70811ce8ea9c98447eb11 Mon Sep 17 00:00:00 2001 From: Rutherther Date: Sat, 1 Nov 2025 10:23:59 +0100 Subject: [PATCH] fix(tsp): properly implement binary string -> node permutation fitness --- codes/tsp_hw01/src/tsp.rs | 123 ++++++++++++++++++++++++-------------- 1 file changed, 77 insertions(+), 46 deletions(-) diff --git a/codes/tsp_hw01/src/tsp.rs b/codes/tsp_hw01/src/tsp.rs index d95714b4ffb9e22d4a439712ff29df113b3c6325..a4c6d39a31f54ad072ed979d58384e18bbd9f2c4 100644 --- a/codes/tsp_hw01/src/tsp.rs +++ b/codes/tsp_hw01/src/tsp.rs @@ -452,89 +452,120 @@ where } } -pub struct BinaryStringToNodePermutation { +pub struct TSPBinaryStringWrapper<'a, DIn: Dim, DOut: Dim> +where + DOut: Dim, + DefaultAllocator: Allocator +{ + instance: &'a TSPInstance, dim_in: DIn, dim_out: DOut, } -impl BinaryStringToNodePermutation { - pub fn new(dim_in: DIn, dim_out: DOut) -> Self { - Self { +impl<'a, DIn: Dim, DOut: Dim> TSPBinaryStringWrapper<'a, DIn, DOut> +where + DOut: Dim, + DefaultAllocator: Allocator, + DefaultAllocator: Allocator, + DefaultAllocator: Allocator, +{ + pub fn new( + instance: &'a TSPInstance, + dim_in: DIn, + dim_out: DOut + ) -> Result { + let res = Self { + instance, dim_in, dim_out + }; + + if dim_out.value() * (dim_out.value() - 1) / 2 != dim_in.value() { + Err(DimensionMismatch::Mismatch) + } else { + Ok(res) } } -} -#[derive(Error, Debug)] -pub enum DimensionMismatch { - #[error("The input dimension should be equal to half matrix NxN where the output is N")] - Mismatch -} + pub fn to_permutation(&self, inp: &BinaryString) -> Result, DimensionMismatch> { + let nodes = self.dim_out.value(); -impl FitnessFunction for BinaryStringToNodePermutation -where - DefaultAllocator: Allocator, - DefaultAllocator: Allocator, - DefaultAllocator: Allocator, -{ - type In = BinaryString; - type Out = NodePermutation; - type Err = DimensionMismatch; + if inp.vec().shape_generic().0.value() != self.dim_in.value() { + return Err(DimensionMismatch::Mismatch); + } - fn fit(self: &Self, inp: &Self::In) -> Result { - let nodes = self.dim_out.value(); - let mut orderings = - OMatrix::::from_element_generic(self.dim_out, self.dim_out, Ordering::Equal); + // Count how many nodes each node comes after (precedence count) + let mut precedence_count = OVector::::zeros_generic(self.dim_out, U1); let mut in_index = 0usize; for i in 0..self.dim_out.value() { - for j in i+1..self.dim_out.value() { - orderings[(i, j)] = if inp.vec[in_index] == 1 { - Ordering::Greater - } else { - Ordering::Less - }; + for j in i+1..nodes { + if in_index >= inp.vec.len() { + return Err(DimensionMismatch::Mismatch); + } - orderings[(j, i)] = if inp.vec[in_index] == 1 { - Ordering::Less + if inp.vec[in_index] == 1 { + // i comes before j, so j has one more predecessor + precedence_count[j] += 1; } else { - Ordering::Greater - }; + // j comes before i, so i has one more predecessor + precedence_count[i] += 1; + } in_index += 1; } } - let mut result = - OVector::from_iterator_generic(self.dim_out, U1, 0..nodes); + if in_index != inp.vec.len() { + return Err(DimensionMismatch::Mismatch); + } - for i in 0..nodes { - for j in i+1..nodes { - let node1 = result[i]; - let node2 = result[j]; + let mut result = OVector::from_iterator_generic( + self.dim_out, + U1, + 0..nodes + ); - if orderings[(node1, node2)] == Ordering::Greater { - result.swap_rows(i, j); - } - } - } + result + .as_mut_slice() + .sort_by_key(|&node| precedence_count[node]); Ok(NodePermutation { permutation: result }) } } +#[derive(Error, Debug)] +pub enum DimensionMismatch { + #[error("The input dimension should be equal to half matrix NxN where the output is N")] + Mismatch +} + +impl<'a, DIn: Dim, DOut: Dim> FitnessFunction for TSPBinaryStringWrapper<'a, DIn, DOut> +where + DefaultAllocator: Allocator, + DefaultAllocator: Allocator, + DefaultAllocator: Allocator, +{ + type In = BinaryString; + type Out = f64; + type Err = DimensionMismatch; + + fn fit(self: &Self, inp: &Self::In) -> Result { + Ok(self.instance.fit(&self.to_permutation(inp)?).unwrap()) + } +} + #[cfg(test)] mod tests { use std::convert::Infallible; use eoa_lib::{binary_string::BinaryString, crossover::Crossover, fitness::FitnessFunction, initializer::Initializer, pairing::{AdjacentPairing, Pairing}, replacement::Population}; use nalgebra::{Const, SVector, U15, U6}; - use rand::{rngs::StdRng, RngCore, SeedableRng}; + use rand::{rngs::StdRng, seq::SliceRandom, RngCore, SeedableRng}; use crate::tsp::TSPInstance; - use super::{BinaryStringToNodePermutation, EdgeRecombinationCrossover, NodePermutation, ReverseSubsequencePerturbation, TSPRandomInitializer}; + use super::{TSPBinaryStringWrapper, EdgeRecombinationCrossover, NodePermutation, ReverseSubsequencePerturbation, TSPRandomInitializer}; use eoa_lib::perturbation::PerturbationOperator; struct MockRng;