From 11befbf6a156431429b2b66c42a90da63df6daf0 Mon Sep 17 00:00:00 2001 From: Rutherther Date: Sat, 1 Nov 2025 13:16:07 +0100 Subject: [PATCH] refactor(tsp): split to multiple files out of tsp.rs --- .../src/binary_string_representation.rs | 107 +++++ codes/tsp_hw01/src/crossovers.rs | 140 ++++++ codes/tsp_hw01/src/initializers.rs | 38 ++ codes/tsp_hw01/src/main.rs | 14 +- codes/tsp_hw01/src/perturbations.rs | 122 ++++++ codes/tsp_hw01/src/tsp.rs | 398 +----------------- 6 files changed, 422 insertions(+), 397 deletions(-) create mode 100644 codes/tsp_hw01/src/binary_string_representation.rs create mode 100644 codes/tsp_hw01/src/crossovers.rs create mode 100644 codes/tsp_hw01/src/initializers.rs create mode 100644 codes/tsp_hw01/src/perturbations.rs diff --git a/codes/tsp_hw01/src/binary_string_representation.rs b/codes/tsp_hw01/src/binary_string_representation.rs new file mode 100644 index 0000000000000000000000000000000000000000..d53c6a19a5355f657b5aeb15a61c2800cfba8500 --- /dev/null +++ b/codes/tsp_hw01/src/binary_string_representation.rs @@ -0,0 +1,107 @@ +use nalgebra::{allocator::Allocator, DefaultAllocator, Dim, OVector, U1}; +use eoa_lib::{binary_string::BinaryString, fitness::FitnessFunction}; +use thiserror::Error; +use crate::tsp::{NodePermutation, TSPInstance}; + +impl<'a, DIn: Dim, DOut: Dim> FitnessFunction for TSPBinaryStringWrapper<'a, DIn, DOut> +where + DefaultAllocator: Allocator, + DefaultAllocator: Allocator, + DefaultAllocator: Allocator, +{ + type In = BinaryString; + type Out = f64; + type Err = DimensionMismatch; + + fn fit(self: &Self, inp: &Self::In) -> Result { + Ok(self.instance.fit(&self.to_permutation(inp)?).unwrap()) + } +} + +pub struct TSPBinaryStringWrapper<'a, DIn: Dim, DOut: Dim> +where + DOut: Dim, + DefaultAllocator: Allocator +{ + instance: &'a TSPInstance, + dim_in: DIn, + dim_out: DOut, +} + +impl<'a, DIn: Dim, DOut: Dim> TSPBinaryStringWrapper<'a, DIn, DOut> +where + DOut: Dim, + DefaultAllocator: Allocator, + DefaultAllocator: Allocator, + DefaultAllocator: Allocator, +{ + pub fn new( + instance: &'a TSPInstance, + dim_in: DIn, + dim_out: DOut + ) -> Result { + let res = Self { + instance, + dim_in, + dim_out + }; + + if dim_out.value() * (dim_out.value() - 1) / 2 != dim_in.value() { + Err(DimensionMismatch::Mismatch) + } else { + Ok(res) + } + } + + pub fn to_permutation(&self, inp: &BinaryString) -> Result, DimensionMismatch> { + let nodes = self.dim_out.value(); + + if inp.vec().shape_generic().0.value() != self.dim_in.value() { + return Err(DimensionMismatch::Mismatch); + } + + // Count how many nodes each node comes after (precedence count) + let mut precedence_count = OVector::::zeros_generic(self.dim_out, U1); + + let mut in_index = 0usize; + for i in 0..self.dim_out.value() { + for j in i+1..nodes { + if in_index >= inp.vec.len() { + return Err(DimensionMismatch::Mismatch); + } + + if inp.vec[in_index] == 1 { + // i comes before j, so j has one more predecessor + precedence_count[j] += 1; + } else { + // j comes before i, so i has one more predecessor + precedence_count[i] += 1; + } + + in_index += 1; + } + } + + if in_index != inp.vec.len() { + return Err(DimensionMismatch::Mismatch); + } + + let mut result = OVector::from_iterator_generic( + self.dim_out, + U1, + 0..nodes + ); + + result + .as_mut_slice() + .sort_by_key(|&node| precedence_count[node]); + + Ok(NodePermutation { permutation: result }) + } +} + +#[derive(Error, Debug)] +pub enum DimensionMismatch { + #[error("The input dimension should be equal to half matrix NxN where the output is N")] + Mismatch +} diff --git a/codes/tsp_hw01/src/crossovers.rs b/codes/tsp_hw01/src/crossovers.rs new file mode 100644 index 0000000000000000000000000000000000000000..27f1ae423935b4e081aed3a49261a3fca04c4c39 --- /dev/null +++ b/codes/tsp_hw01/src/crossovers.rs @@ -0,0 +1,140 @@ +use std::marker::PhantomData; +use nalgebra::{allocator::Allocator, DefaultAllocator, Dim, Const, OMatrix, OVector}; +use rand::{prelude::IteratorRandom, Rng, RngCore}; +use eoa_lib::replacement::Population; +use itertools::Itertools; +use eoa_lib::crossover::Crossover; +use crate::tsp::NodePermutation; + +pub struct EdgeRecombinationCrossover { + _phantom: PhantomData +} + +impl EdgeRecombinationCrossover { + pub fn new() -> Self { + Self { _phantom: PhantomData } + } +} + +impl Crossover<2> for EdgeRecombinationCrossover +where + D: Dim, + DefaultAllocator: Allocator, + DefaultAllocator: Allocator, + DefaultAllocator: nalgebra::allocator::Allocator> +{ + type Chromosome = NodePermutation; + type Out = f64; + + fn crossover( + &self, + parents: &eoa_lib::replacement::EvaluatedPopulation, + pairs: impl Iterator>, + rng: &mut dyn RngCore + ) -> eoa_lib::replacement::Population { + let mut offsprings = vec![]; + + let permutation = &parents.population[0].chromosome.permutation; + let len = permutation.len(); + let mut adjacency_lists = OMatrix::from_element_generic( + permutation.shape_generic().0, + Const::<4>, + None); + let mut used_nodes = OVector::from_element_generic( + permutation.shape_generic().0, + Const::<1>, + false + ); + + let mut neighbors_count = OVector::from_element_generic( + permutation.shape_generic().0, + Const::<1>, + 2usize + ); + + for pair in pairs { + let parent1 = &parents.population[pair.x].chromosome; + let parent2 = &parents.population[pair.y].chromosome; + + used_nodes.apply(|n| *n = false); + + // 1. Populate adjacency lists + for (&c1, &n, &c2) in parent1.permutation.iter().circular_tuple_windows() { + adjacency_lists[(n, 0)] = Some(c1); + adjacency_lists[(n, 1)] = Some(c2); + neighbors_count[n] = 2; + } + + for (&c1, &n, &c2) in parent2.permutation.iter().circular_tuple_windows() { + // Not duplicit? + if adjacency_lists[(n, 0)].unwrap() != c1 && adjacency_lists[(n, 1)].unwrap() != c1 { + neighbors_count[n] += 1; + adjacency_lists[(n, 2)] = Some(c1); + } else { // Duplicit + adjacency_lists[(n, 2)] = None; + } + + // Not duplicit + if adjacency_lists[(n, 0)].unwrap() != c2 && adjacency_lists[(n, 1)].unwrap() != c2 { + neighbors_count[n] += 1; + adjacency_lists[(n, 3)] = Some(c2); + } else { // Duplicit + adjacency_lists[(n, 3)] = None; + } + } + + let chosen_parent = if rng.random_bool(0.5) { + &parent1 + } else { + &parent2 + }; + + let mut offspring = OVector::from_element_generic(permutation.shape_generic().0, Const::<1>, 0); + + let mut current_node = chosen_parent.permutation[0]; + + for i in 0..len-1 { + offspring[i] = current_node; + used_nodes[current_node] = true; + + for neighbor in adjacency_lists.row(current_node) { + if let Some(neighbor) = neighbor { + neighbors_count[*neighbor] -= 1; + } + } + + let min_neighbors = adjacency_lists.row(current_node) + .iter() + .flatten() + .filter(|&&neighbor| !used_nodes[neighbor]) + .map(|&neighbor| neighbors_count[neighbor]) + .min(); + + let neighbor = if let Some(min_neighbors) = min_neighbors { + adjacency_lists.row(current_node) + .iter() + .flatten() + .copied() + .filter(|&neighbor| !used_nodes[neighbor] && neighbors_count[neighbor] == min_neighbors) + .choose(rng) + } else { + None + }; + + current_node = if let Some(neighbor) = neighbor { + neighbor + } else { + (0..len).filter(|&node| !used_nodes[node]) + .choose(rng) + .unwrap() + }; + } + + offspring[len - 1] = current_node; + + offsprings.push(NodePermutation { permutation: offspring }); + } + + Population::from_vec(offsprings) + } +} diff --git a/codes/tsp_hw01/src/initializers.rs b/codes/tsp_hw01/src/initializers.rs new file mode 100644 index 0000000000000000000000000000000000000000..2875a2013a9312237ff62a3ef6729903a69ef8a3 --- /dev/null +++ b/codes/tsp_hw01/src/initializers.rs @@ -0,0 +1,38 @@ +use std::marker::PhantomData; +use nalgebra::{allocator::Allocator, DefaultAllocator, Dim, OVector, U1}; +use rand::{prelude::SliceRandom, RngCore}; +use eoa_lib::initializer::Initializer; +use crate::tsp::NodePermutation; + +pub struct TSPRandomInitializer +where + D: Dim, + DefaultAllocator: Allocator, +{ + _phantom: PhantomData +} + +impl TSPRandomInitializer +where + D: Dim, + DefaultAllocator: Allocator, +{ + pub fn new() -> Self { + Self { _phantom: PhantomData } + } +} + +impl Initializer> for TSPRandomInitializer +where + D: Dim, + DefaultAllocator: Allocator, + DefaultAllocator: Allocator, +{ + fn initialize_single(&self, size: D, rng: &mut dyn RngCore) -> NodePermutation { + let len = size.value(); + let mut indices = OVector::::from_iterator_generic(size, U1, 0..len); + indices.as_mut_slice().shuffle(rng); + + NodePermutation { permutation: indices } + } +} diff --git a/codes/tsp_hw01/src/main.rs b/codes/tsp_hw01/src/main.rs index 8892481cbc90dfec4a145c345b204817adb91cf4..1cf999fc41fe99913039874224bafd667d270a7c 100644 --- a/codes/tsp_hw01/src/main.rs +++ b/codes/tsp_hw01/src/main.rs @@ -1,10 +1,18 @@ pub mod tsp; +pub mod initializers; +pub mod crossovers; +pub mod binary_string_representation; +pub mod perturbations; pub mod graph; -use tsp::{EdgeRecombinationCrossover, MovePerturbation, NodePermutation, ReverseSubsequencePerturbation, SwapPerturbation, TSPBinaryStringWrapper, TSPInstance, TSPRandomInitializer}; +use tsp::{NodePermutation, TSPInstance}; +use initializers::TSPRandomInitializer; +use crossovers::EdgeRecombinationCrossover; +use perturbations::{MovePerturbation, ReverseSubsequencePerturbation, SwapPerturbation}; +use binary_string_representation::TSPBinaryStringWrapper; use nalgebra::{Dim, Dyn}; use eoa_lib::{ - binary_string::BinaryString, comparison::MinimizingOperator, crossover::BinaryNPointCrossover, evolution::{evolution_algorithm, EvolutionStats}, initializer::{Initializer, RandomInitializer}, local_search::{local_search_first_improving, LocalSearchStats}, pairing::AdjacentPairing, perturbation::{apply_to_perturbations, BinaryStringBitPerturbation, BinaryStringFlipNPerturbation, BinaryStringSingleBitPerturbation, CombinedPerturbation, MutationPerturbation}, replacement::{BestReplacement, TournamentReplacement}, selection::{BestSelection, TournamentSelection}, terminating::MaximumCyclesTerminatingCondition + binary_string::BinaryString, comparison::MinimizingOperator, crossover::BinaryNPointCrossover, evolution::{evolution_algorithm, EvolutionStats}, initializer::{Initializer, RandomInitializer}, local_search::{local_search_first_improving, LocalSearchStats}, pairing::AdjacentPairing, perturbation::{apply_to_perturbations, BinaryStringBitPerturbation, BinaryStringFlipNPerturbation, BinaryStringSingleBitPerturbation, CombinedPerturbation, MutationPerturbation}, replacement::{BestReplacement, TournamentReplacement}, selection::{BestSelection, RouletteWheelSelection}, terminating::MaximumCyclesTerminatingCondition }; use rand::rng; use std::env; @@ -224,7 +232,7 @@ fn run_evolution_algorithm(instance: &TSPInstance) -> Result { + _phantom: PhantomData +} + +impl SwapPerturbation { + pub fn new() -> Self { + Self { _phantom: PhantomData } + } +} + +impl PerturbationOperator for SwapPerturbation +where + D: Dim, + DefaultAllocator: Allocator, + DefaultAllocator: Allocator, +{ + type Chromosome = NodePermutation; + + fn perturb(&self, chromosome: &mut Self::Chromosome, rng: &mut dyn RngCore) { + let first = rng.random_range(0..chromosome.permutation.len()); + let second = rng.random_range(0..chromosome.permutation.len()); + chromosome.permutation.swap_rows(first, second); + } +} + +pub struct MovePerturbation { + _phantom: PhantomData +} + +impl MovePerturbation { + pub fn new() -> Self { + Self { _phantom: PhantomData } + } +} + +impl PerturbationOperator for MovePerturbation +where + D: Dim, + DefaultAllocator: Allocator, + DefaultAllocator: Allocator, +{ + type Chromosome = NodePermutation; + + fn perturb(&self, chromosome: &mut Self::Chromosome, rng: &mut dyn RngCore) { + let from = rng.random_range(0..chromosome.permutation.len()); + let to = rng.random_range(0..chromosome.permutation.len()); + + let element = chromosome.permutation[from]; + + if from < to { + for i in from..to { + chromosome.permutation[i] = chromosome.permutation[i + 1]; + } + } else { + for i in (to+1..=from).rev() { + chromosome.permutation[i] = chromosome.permutation[i - 1]; + } + } + + chromosome.permutation[to] = element; + } +} + +pub struct ReverseSubsequencePerturbation { + _phantom: PhantomData, + min_subsequence_len: usize, + max_subsequence_len: usize, +} + +impl ReverseSubsequencePerturbation { + pub fn new() -> Self { + Self { + _phantom: PhantomData, + max_subsequence_len: usize::MAX, + min_subsequence_len: 0, + } + } +} + +impl PerturbationOperator for ReverseSubsequencePerturbation +where + D: Dim, + DefaultAllocator: Allocator, + DefaultAllocator: Allocator, +{ + type Chromosome = NodePermutation; + + fn perturb(&self, chromosome: &mut Self::Chromosome, rng: &mut dyn RngCore) { + let len = chromosome.permutation.len(); + let index = rng.random_range(0..chromosome.permutation.len()); + let subsequence_len = rng.random_range( + self.min_subsequence_len..(chromosome.permutation.len().min(self.max_subsequence_len)) + ); + + // Reverse the subsequence between start and end (inclusive) + let mut left = index; + let mut right = (index + subsequence_len) % len; + + while left != right { + chromosome.permutation.swap_rows(left, right); + + left += 1; + left %= len; + + if left == right { + break; + } + + if right > 0 { + right -= 1; + } else { + right = len - 1; + } + } + } +} diff --git a/codes/tsp_hw01/src/tsp.rs b/codes/tsp_hw01/src/tsp.rs index d058081548676eab302aa6e70877826ceece7c38..a639f900c60774ab3c32615f5b187989f6dffb29 100644 --- a/codes/tsp_hw01/src/tsp.rs +++ b/codes/tsp_hw01/src/tsp.rs @@ -1,13 +1,9 @@ -use std::{cmp::Ordering, convert::Infallible, error::Error, marker::PhantomData}; +use std::convert::Infallible; -use eoa_lib::{binary_string::BinaryString, crossover::Crossover, fitness::FitnessFunction, initializer::Initializer, perturbation::PerturbationOperator, replacement::Population}; +use eoa_lib::fitness::FitnessFunction; use itertools::Itertools; -use nalgebra::{allocator::Allocator, distance, Const, DefaultAllocator, Dim, Dyn, OMatrix, OVector, Point, U1}; +use nalgebra::{allocator::Allocator, distance, Const, DefaultAllocator, Dim, Dyn, OMatrix, OVector, Point}; use plotters::prelude::*; -use rand::{seq::{IteratorRandom, SliceRandom}, Rng, RngCore}; -use thiserror::Error; - -use crate::graph::Edge; #[derive(PartialEq, Clone, Debug)] pub struct TSPCity { @@ -19,7 +15,7 @@ pub struct NodePermutation where DefaultAllocator: Allocator { - permutation: OVector + pub permutation: OVector } /// An instance of TSP, a fully connected graph @@ -207,392 +203,6 @@ where } } -pub struct TSPRandomInitializer -where - D: Dim, - DefaultAllocator: Allocator, -{ - _phantom: PhantomData -} - -impl TSPRandomInitializer -where - D: Dim, - DefaultAllocator: Allocator, -{ - pub fn new() -> Self { - Self { _phantom: PhantomData } - } -} - -impl Initializer> for TSPRandomInitializer -where - D: Dim, - DefaultAllocator: Allocator, - DefaultAllocator: Allocator, -{ - fn initialize_single(&self, size: D, rng: &mut dyn RngCore) -> NodePermutation { - let len = size.value(); - let mut indices = OVector::::from_iterator_generic(size, U1, 0..len); - indices.as_mut_slice().shuffle(rng); - - NodePermutation { permutation: indices } - } -} - -pub struct MovePerturbation { - _phantom: PhantomData -} - -impl MovePerturbation { - pub fn new() -> Self { - Self { _phantom: PhantomData } - } -} - -impl PerturbationOperator for MovePerturbation -where - D: Dim, - DefaultAllocator: Allocator, - DefaultAllocator: Allocator, -{ - type Chromosome = NodePermutation; - - fn perturb(&self, chromosome: &mut Self::Chromosome, rng: &mut dyn RngCore) { - let from = rng.random_range(0..chromosome.permutation.len()); - let to = rng.random_range(0..chromosome.permutation.len()); - - let element = chromosome.permutation[from]; - - if from < to { - for i in from..to { - chromosome.permutation[i] = chromosome.permutation[i + 1]; - } - } else { - for i in (to+1..=from).rev() { - chromosome.permutation[i] = chromosome.permutation[i - 1]; - } - } - - chromosome.permutation[to] = element; - } -} - -pub struct SwapPerturbation { - _phantom: PhantomData -} - -impl SwapPerturbation { - pub fn new() -> Self { - Self { _phantom: PhantomData } - } -} - -impl PerturbationOperator for SwapPerturbation -where - D: Dim, - DefaultAllocator: Allocator, - DefaultAllocator: Allocator, -{ - type Chromosome = NodePermutation; - - fn perturb(&self, chromosome: &mut Self::Chromosome, rng: &mut dyn RngCore) { - let first = rng.random_range(0..chromosome.permutation.len()); - let second = rng.random_range(0..chromosome.permutation.len()); - chromosome.permutation.swap_rows(first, second); - } -} - -pub struct ReverseSubsequencePerturbation { - _phantom: PhantomData, - min_subsequence_len: usize, - max_subsequence_len: usize, -} - -impl ReverseSubsequencePerturbation { - pub fn new() -> Self { - Self { - _phantom: PhantomData, - max_subsequence_len: usize::MAX, - min_subsequence_len: 0, - } - } -} - -impl PerturbationOperator for ReverseSubsequencePerturbation -where - D: Dim, - DefaultAllocator: Allocator, - DefaultAllocator: Allocator, -{ - type Chromosome = NodePermutation; - - fn perturb(&self, chromosome: &mut Self::Chromosome, rng: &mut dyn RngCore) { - let len = chromosome.permutation.len(); - let index = rng.random_range(0..chromosome.permutation.len()); - let subsequence_len = rng.random_range( - self.min_subsequence_len..(chromosome.permutation.len().min(self.max_subsequence_len)) - ); - - // Reverse the subsequence between start and end (inclusive) - let mut left = index; - let mut right = (index + subsequence_len) % len; - - while left != right { - chromosome.permutation.swap_rows(left, right); - - left += 1; - left %= len; - - if left == right { - break; - } - - if right > 0 { - right -= 1; - } else { - right = len - 1; - } - } - } -} - -pub struct EdgeRecombinationCrossover { - _phantom: PhantomData -} - -impl EdgeRecombinationCrossover { - pub fn new() -> Self { - Self { _phantom: PhantomData } - } -} - -impl Crossover<2> for EdgeRecombinationCrossover -where - D: Dim, - DefaultAllocator: Allocator, - DefaultAllocator: Allocator, - DefaultAllocator: nalgebra::allocator::Allocator> -{ - type Chromosome = NodePermutation; - type Out = f64; - - fn crossover( - &self, - parents: &eoa_lib::replacement::EvaluatedPopulation, - pairs: impl Iterator>, - rng: &mut dyn RngCore - ) -> eoa_lib::replacement::Population { - let mut offsprings = vec![]; - - let permutation = &parents.population[0].chromosome.permutation; - let len = permutation.len(); - let mut adjacency_lists = OMatrix::from_element_generic( - permutation.shape_generic().0, - Const::<4>, - None); - let mut used_nodes = OVector::from_element_generic( - permutation.shape_generic().0, - Const::<1>, - false - ); - - let mut neighbors_count = OVector::from_element_generic( - permutation.shape_generic().0, - Const::<1>, - 2usize - ); - - for pair in pairs { - let parent1 = &parents.population[pair.x].chromosome; - let parent2 = &parents.population[pair.y].chromosome; - - used_nodes.apply(|n| *n = false); - - // 1. Populate adjacency lists - for (&c1, &n, &c2) in parent1.permutation.iter().circular_tuple_windows() { - adjacency_lists[(n, 0)] = Some(c1); - adjacency_lists[(n, 1)] = Some(c2); - neighbors_count[n] = 2; - } - - for (&c1, &n, &c2) in parent2.permutation.iter().circular_tuple_windows() { - // Not duplicit? - if adjacency_lists[(n, 0)].unwrap() != c1 && adjacency_lists[(n, 1)].unwrap() != c1 { - neighbors_count[n] += 1; - adjacency_lists[(n, 2)] = Some(c1); - } else { // Duplicit - adjacency_lists[(n, 2)] = None; - } - - // Not duplicit - if adjacency_lists[(n, 0)].unwrap() != c2 && adjacency_lists[(n, 1)].unwrap() != c2 { - neighbors_count[n] += 1; - adjacency_lists[(n, 3)] = Some(c2); - } else { // Duplicit - adjacency_lists[(n, 3)] = None; - } - } - - let chosen_parent = if rng.random_bool(0.5) { - &parent1 - } else { - &parent2 - }; - - let mut offspring = OVector::from_element_generic(permutation.shape_generic().0, Const::<1>, 0); - - let mut current_node = chosen_parent.permutation[0]; - - for i in 0..len-1 { - offspring[i] = current_node; - used_nodes[current_node] = true; - - for neighbor in adjacency_lists.row(current_node) { - if let Some(neighbor) = neighbor { - neighbors_count[*neighbor] -= 1; - } - } - - let min_neighbors = adjacency_lists.row(current_node) - .iter() - .flatten() - .filter(|&&neighbor| !used_nodes[neighbor]) - .map(|&neighbor| neighbors_count[neighbor]) - .min(); - - let neighbor = if let Some(min_neighbors) = min_neighbors { - adjacency_lists.row(current_node) - .iter() - .flatten() - .copied() - .filter(|&neighbor| !used_nodes[neighbor] && neighbors_count[neighbor] == min_neighbors) - .choose(rng) - } else { - None - }; - - current_node = if let Some(neighbor) = neighbor { - neighbor - } else { - (0..len).filter(|&node| !used_nodes[node]) - .choose(rng) - .unwrap() - }; - } - - offspring[len - 1] = current_node; - - offsprings.push(NodePermutation { permutation: offspring }); - } - - Population::from_vec(offsprings) - } -} - -pub struct TSPBinaryStringWrapper<'a, DIn: Dim, DOut: Dim> -where - DOut: Dim, - DefaultAllocator: Allocator -{ - instance: &'a TSPInstance, - dim_in: DIn, - dim_out: DOut, -} - -impl<'a, DIn: Dim, DOut: Dim> TSPBinaryStringWrapper<'a, DIn, DOut> -where - DOut: Dim, - DefaultAllocator: Allocator, - DefaultAllocator: Allocator, - DefaultAllocator: Allocator, -{ - pub fn new( - instance: &'a TSPInstance, - dim_in: DIn, - dim_out: DOut - ) -> Result { - let res = Self { - instance, - dim_in, - dim_out - }; - - if dim_out.value() * (dim_out.value() - 1) / 2 != dim_in.value() { - Err(DimensionMismatch::Mismatch) - } else { - Ok(res) - } - } - - pub fn to_permutation(&self, inp: &BinaryString) -> Result, DimensionMismatch> { - let nodes = self.dim_out.value(); - - if inp.vec().shape_generic().0.value() != self.dim_in.value() { - return Err(DimensionMismatch::Mismatch); - } - - // Count how many nodes each node comes after (precedence count) - let mut precedence_count = OVector::::zeros_generic(self.dim_out, U1); - - let mut in_index = 0usize; - for i in 0..self.dim_out.value() { - for j in i+1..nodes { - if in_index >= inp.vec.len() { - return Err(DimensionMismatch::Mismatch); - } - - if inp.vec[in_index] == 1 { - // i comes before j, so j has one more predecessor - precedence_count[j] += 1; - } else { - // j comes before i, so i has one more predecessor - precedence_count[i] += 1; - } - - in_index += 1; - } - } - - if in_index != inp.vec.len() { - return Err(DimensionMismatch::Mismatch); - } - - let mut result = OVector::from_iterator_generic( - self.dim_out, - U1, - 0..nodes - ); - - result - .as_mut_slice() - .sort_by_key(|&node| precedence_count[node]); - - Ok(NodePermutation { permutation: result }) - } -} - -#[derive(Error, Debug)] -pub enum DimensionMismatch { - #[error("The input dimension should be equal to half matrix NxN where the output is N")] - Mismatch -} - -impl<'a, DIn: Dim, DOut: Dim> FitnessFunction for TSPBinaryStringWrapper<'a, DIn, DOut> -where - DefaultAllocator: Allocator, - DefaultAllocator: Allocator, - DefaultAllocator: Allocator, -{ - type In = BinaryString; - type Out = f64; - type Err = DimensionMismatch; - - fn fit(self: &Self, inp: &Self::In) -> Result { - Ok(self.instance.fit(&self.to_permutation(inp)?).unwrap()) - } -} - #[cfg(test)] mod tests { use std::convert::Infallible;