use nalgebra::{allocator::Allocator, DefaultAllocator, Dim, OVector, U1}; use eoa_lib::{binary_string::BinaryString, fitness::FitnessFunction}; use thiserror::Error; use crate::tsp::{NodePermutation, TSPInstance}; use eoa_lib::population::EvaluatedChromosome; impl<'a, DIn: Dim, DOut: Dim> FitnessFunction for TSPBinaryStringWrapper<'a, DIn, DOut> where DefaultAllocator: Allocator, DefaultAllocator: Allocator, DefaultAllocator: Allocator, { type In = BinaryString; type Out = f64; type Err = DimensionMismatch; fn fit(self: &Self, inp: &Self::In) -> Result { Ok(self.instance.fit(&self.to_permutation(inp)?).unwrap()) } fn fit_population(&self, inp: Vec) -> Result>, Self::Err> { let nodes = self.dim_out.value(); // Count how many nodes each node comes after (precedence count) let mut precedence_count = OVector::::zeros_generic(self.dim_out, U1); let result = OVector::from_iterator_generic( self.dim_out, U1, 0..nodes ); let mut permutation = NodePermutation { permutation: result }; inp .into_iter() .map(|chromosome| { // Reset precedence_count .apply(|c| *c = 0); // NOTE no need to reset the permutation // as it's always sorted self.to_permutation_buff( &chromosome, &mut permutation, &mut precedence_count )?; Ok(EvaluatedChromosome { evaluation: self.instance.fit(&permutation).unwrap(), chromosome, }) }) .collect::, _>>() } } pub struct TSPBinaryStringWrapper<'a, DIn: Dim, DOut: Dim> where DOut: Dim, DefaultAllocator: Allocator { instance: &'a TSPInstance, dim_in: DIn, dim_out: DOut, } impl<'a, DIn: Dim, DOut: Dim> TSPBinaryStringWrapper<'a, DIn, DOut> where DOut: Dim, DefaultAllocator: Allocator, DefaultAllocator: Allocator, DefaultAllocator: Allocator, { pub fn new( instance: &'a TSPInstance, dim_in: DIn, dim_out: DOut ) -> Result { let res = Self { instance, dim_in, dim_out }; if dim_out.value() * (dim_out.value() - 1) / 2 != dim_in.value() { Err(DimensionMismatch::Mismatch) } else { Ok(res) } } pub fn to_permutation_buff( &self, inp: &BinaryString, permutation: &mut NodePermutation, precedence_count: &mut OVector ) -> Result<(), DimensionMismatch> { let nodes = self.dim_out.value(); if inp.vec().shape_generic().0.value() != self.dim_in.value() { return Err(DimensionMismatch::Mismatch); } let mut in_index = 0usize; for i in 0..self.dim_out.value() { for j in i+1..nodes { if in_index >= inp.vec.len() { return Err(DimensionMismatch::Mismatch); } if inp.vec[in_index] == 1 { // i comes before j, so j has one more predecessor precedence_count[j] += 1; } else { // j comes before i, so i has one more predecessor precedence_count[i] += 1; } in_index += 1; } } if in_index != inp.vec.len() { return Err(DimensionMismatch::Mismatch); } permutation.permutation .as_mut_slice() .sort_unstable_by_key(|&node| precedence_count[node]); Ok(()) } pub fn to_permutation(&self, inp: &BinaryString) -> Result, DimensionMismatch> { let nodes = self.dim_out.value(); // Count how many nodes each node comes after (precedence count) let mut precedence_count = OVector::::zeros_generic(self.dim_out, U1); let mut result = OVector::from_iterator_generic( self.dim_out, U1, 0..nodes ); let mut permutation = NodePermutation { permutation: result }; self.to_permutation_buff(inp, &mut permutation, &mut precedence_count)?; Ok(permutation) } } #[derive(Error, Debug)] pub enum DimensionMismatch { #[error("The input dimension should be equal to half matrix NxN where the output is N")] Mismatch } #[cfg(test)] mod tests { use super::*; use nalgebra::{Const, SVector, U15, U6}; use eoa_lib::binary_string::BinaryString; #[test] fn test_binary_string_representation() { // x 0 1 2 3 4 5 // 0 0 0 0 0 0 0 // 1 1 0 0 0 0 0 // 2 1 1 0 0 0 0 // 3 1 1 1 0 0 0 // 4 1 1 1 1 0 0 // 5 1 1 1 1 1 0 // x 0 1 2 3 4 5 // 0 0 0 0 0 0 // 1 0 0 0 0 // 2 0 0 0 // 3 0 0 // 4 0 // 5 // 6 nodes // length of binary string: 5 + 4 + 3 + 2 + 1 = 15 let tsp = TSPInstance::new_const( vec![ (0.0, 0.0), (0.0, 0.0), (0.0, 0.0), (0.0, 0.0), (0.0, 0.0), (0.0, 0.0), ] ); let converter = TSPBinaryStringWrapper::new( &tsp, U15, U6 ).unwrap(); let binary_string_ordering = BinaryString::::new(vec![1; 15]); let mut expected_permutation = vec![0, 1, 2, 3, 4, 5]; let mut permutation = converter.to_permutation(&binary_string_ordering) .unwrap(); assert_eq!( expected_permutation, permutation.permutation.as_mut_slice().to_vec() ); let binary_string_ordering = BinaryString::::new(vec![0; 15]); expected_permutation.reverse(); let mut permutation = converter.to_permutation(&binary_string_ordering) .unwrap(); assert_eq!( expected_permutation, permutation.permutation.as_mut_slice().to_vec() ) } #[test] fn test_nontrivial_binary_string_representation() { // x 0 1 2 3 4 5 // 0 0 1 0 0 0 0 // 1 0 0 0 0 0 0 // 2 1 1 0 0 0 1 // 3 1 1 1 0 0 0 // 4 1 1 1 1 0 0 // 5 1 1 0 1 1 0 // x 0 1 2 3 4 5 // 0 0 0 0 0 0 // 1 0 0 0 0 // 2 1 1 1 // 3 0 0 // 4 1 // 5 // 6 nodes // length of binary string: 5 + 4 + 3 + 2 + 1 = 15 let tsp = TSPInstance::new_const( vec![ (0.0, 0.0), (0.0, 0.0), (0.0, 0.0), (0.0, 0.0), (0.0, 0.0), (0.0, 0.0), ] ); let converter = TSPBinaryStringWrapper::new( &tsp, U15, U6 ).unwrap(); let mut binary_string_ordering = BinaryString::::new(vec![0; 15]); binary_string_ordering.vec[9] = 1; binary_string_ordering.vec[10] = 1; binary_string_ordering.vec[11] = 1; binary_string_ordering.vec[14] = 1; let expected_permutation = vec![2, 4, 5, 3, 1, 0]; let mut permutation = converter.to_permutation(&binary_string_ordering) .unwrap(); assert_eq!( expected_permutation, permutation.permutation.as_mut_slice().to_vec() ); } }