use nalgebra::{allocator::Allocator, DefaultAllocator, Dim, OVector, U1};
use eoa_lib::{binary_string::BinaryString, fitness::FitnessFunction};
use thiserror::Error;
use crate::tsp::{NodePermutation, TSPInstance};
use eoa_lib::population::EvaluatedChromosome;
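// Evaluating a binary-string chromosome means decoding it into a node
// permutation and scoring that tour on the wrapped `TSPInstance`.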
impl<'a, DIn: Dim, DOut: Dim> FitnessFunction for TSPBinaryStringWrapper<'a, DIn, DOut>
where
DefaultAllocator: Allocator<DIn>,
DefaultAllocator: Allocator<DOut>,
DefaultAllocator: Allocator<DOut, DOut>,
{
type In = BinaryString<DIn>;
type Out = f64;
type Err = DimensionMismatch;
    fn fit(&self, inp: &Self::In) -> Result<Self::Out, Self::Err> {
Ok(self.instance.fit(&self.to_permutation(inp)?).unwrap())
}
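    // Batch evaluation: the permutation and precedence-count buffers are
    // allocated once and reused for every chromosome in the population.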
fn fit_population(&self, inp: Vec<Self::In>) -> Result<Vec<EvaluatedChromosome<Self::In, Self::Out>>, Self::Err> {
let nodes = self.dim_out.value();
// Count how many nodes each node comes after (precedence count)
let mut precedence_count = OVector::<usize, DOut>::zeros_generic(self.dim_out, U1);
let result = OVector::from_iterator_generic(
self.dim_out,
U1,
0..nodes
);
let mut permutation = NodePermutation { permutation: result };
inp
.into_iter()
.map(|chromosome| {
// Reset
precedence_count
.apply(|c| *c = 0);
                // NOTE: the buffer always holds a valid permutation of the node
                // indices, so it does not need resetting; re-sorting with the
                // fresh precedence counts overwrites the previous order (ties
                // are broken arbitrarily by the unstable sort).
self.to_permutation_buff(
&chromosome,
&mut permutation,
&mut precedence_count
)?;
Ok(EvaluatedChromosome {
evaluation: self.instance.fit(&permutation).unwrap(),
chromosome,
})
})
.collect::<Result<Vec<_>, _>>()
}
}
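/// Adapts a [`TSPInstance`] so it can be evaluated on [`BinaryString`] chromosomes.
///
/// A tour over `N` nodes is encoded with one bit per unordered node pair
/// `(i, j)`, `i < j`, laid out row by row over the strict upper triangle of an
/// `N x N` precedence matrix, so `DIn = N * (N - 1) / 2`. A bit of `1` means
/// node `i` precedes node `j`; `0` means `j` precedes `i`. The permutation is
/// recovered by sorting the nodes by their number of predecessors.
///
/// A minimal usage sketch, mirroring the test cases below:
///
/// ```ignore
/// let tsp = TSPInstance::new_const(vec![(0.0, 0.0); 6]);
/// let wrapper = TSPBinaryStringWrapper::new(&tsp, U15, U6).unwrap();
/// // The all-ones string orders every pair (i, j) with i before j,
/// // so it decodes to the identity permutation [0, 1, 2, 3, 4, 5].
/// let perm = wrapper.to_permutation(&BinaryString::<U15>::new(vec![1; 15])).unwrap();
/// ```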
pub struct TSPBinaryStringWrapper<'a, DIn: Dim, DOut: Dim>
where
    DefaultAllocator: Allocator<DOut, DOut>,
{
instance: &'a TSPInstance<DOut>,
dim_in: DIn,
dim_out: DOut,
}
impl<'a, DIn: Dim, DOut: Dim> TSPBinaryStringWrapper<'a, DIn, DOut>
where
    DefaultAllocator: Allocator<DIn>,
    DefaultAllocator: Allocator<DOut>,
    DefaultAllocator: Allocator<DOut, DOut>,
{
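    /// Builds a wrapper around `instance`, checking that the binary-string
    /// dimension covers exactly one bit per node pair:
    /// `dim_in == dim_out * (dim_out - 1) / 2`.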
pub fn new(
instance: &'a TSPInstance<DOut>,
dim_in: DIn,
dim_out: DOut
) -> Result<Self, DimensionMismatch> {
        // One bit per unordered node pair: the input dimension must be
        // `dim_out * (dim_out - 1) / 2` (saturating to avoid underflow for a
        // zero-node instance).
        if dim_out.value() * dim_out.value().saturating_sub(1) / 2 != dim_in.value() {
            return Err(DimensionMismatch::Mismatch);
        }
        Ok(Self { instance, dim_in, dim_out })
}
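    /// Decodes `inp` into `permutation` without allocating.
    ///
    /// Each bit fixes the order of one node pair; predecessor counts are
    /// accumulated in `precedence_count` (which the caller must zero
    /// beforehand) and the permutation buffer, which must already hold a valid
    /// permutation of the node indices, is sorted by those counts in place.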
pub fn to_permutation_buff(
&self,
inp: &BinaryString<DIn>,
permutation: &mut NodePermutation<DOut>,
precedence_count: &mut OVector<usize, DOut>
) -> Result<(), DimensionMismatch> {
let nodes = self.dim_out.value();
if inp.vec().shape_generic().0.value() != self.dim_in.value() {
return Err(DimensionMismatch::Mismatch);
}
let mut in_index = 0usize;
        for i in 0..nodes {
for j in i+1..nodes {
if in_index >= inp.vec.len() {
return Err(DimensionMismatch::Mismatch);
}
if inp.vec[in_index] == 1 {
// i comes before j, so j has one more predecessor
precedence_count[j] += 1;
} else {
// j comes before i, so i has one more predecessor
precedence_count[i] += 1;
}
in_index += 1;
}
}
if in_index != inp.vec.len() {
return Err(DimensionMismatch::Mismatch);
}
permutation.permutation
.as_mut_slice()
.sort_unstable_by_key(|&node| precedence_count[node]);
Ok(())
}
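    /// Allocating convenience wrapper around [`Self::to_permutation_buff`]:
    /// builds fresh buffers and decodes `inp` into a new [`NodePermutation`].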
pub fn to_permutation(&self, inp: &BinaryString<DIn>) -> Result<NodePermutation<DOut>, DimensionMismatch> {
let nodes = self.dim_out.value();
// Count how many nodes each node comes after (precedence count)
let mut precedence_count = OVector::<usize, DOut>::zeros_generic(self.dim_out, U1);
        let result = OVector::from_iterator_generic(
self.dim_out,
U1,
0..nodes
);
let mut permutation = NodePermutation { permutation: result };
self.to_permutation_buff(inp, &mut permutation, &mut precedence_count)?;
Ok(permutation)
}
}
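/// Error returned when a binary string or the wrapper's dimensions do not
/// match the expected `N * (N - 1) / 2` bit layout.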
#[derive(Error, Debug)]
pub enum DimensionMismatch {
    #[error("the binary string dimension must equal N * (N - 1) / 2, where N is the number of TSP nodes")]
Mismatch
}
#[cfg(test)]
mod tests {
use super::*;
    use nalgebra::{U15, U6};
use eoa_lib::binary_string::BinaryString;
#[test]
fn test_binary_string_representation() {
        // Precedence matrix: entry (i, j) = 1 means node i precedes node j.
        // For the all-zero binary string every node j > i precedes node i, so
        // the decoded order is the reverse permutation [5, 4, 3, 2, 1, 0]
        // (the all-ones string decodes to the identity [0, 1, 2, 3, 4, 5]):
        // x 0 1 2 3 4 5
        // 0 0 0 0 0 0 0
        // 1 1 0 0 0 0 0
        // 2 1 1 0 0 0 0
        // 3 1 1 1 0 0 0
        // 4 1 1 1 1 0 0
        // 5 1 1 1 1 1 0
        // The strict upper triangle, read row by row, is the binary string:
        // x  1 2 3 4 5
        // 0  0 0 0 0 0
        // 1    0 0 0 0
        // 2      0 0 0
        // 3        0 0
        // 4          0
        // 6 nodes, so the binary string has 5 + 4 + 3 + 2 + 1 = 15 bits.
let tsp = TSPInstance::new_const(
vec![
(0.0, 0.0),
(0.0, 0.0),
(0.0, 0.0),
(0.0, 0.0),
(0.0, 0.0),
(0.0, 0.0),
]
);
let converter = TSPBinaryStringWrapper::new(
&tsp,
U15,
U6
).unwrap();
let binary_string_ordering = BinaryString::<U15>::new(vec![1; 15]);
let mut expected_permutation = vec![0, 1, 2, 3, 4, 5];
let mut permutation = converter.to_permutation(&binary_string_ordering)
.unwrap();
assert_eq!(
expected_permutation,
permutation.permutation.as_mut_slice().to_vec()
);
let binary_string_ordering = BinaryString::<U15>::new(vec![0; 15]);
expected_permutation.reverse();
let mut permutation = converter.to_permutation(&binary_string_ordering)
.unwrap();
assert_eq!(
expected_permutation,
permutation.permutation.as_mut_slice().to_vec()
)
}
#[test]
fn test_nontrivial_binary_string_representation() {
        // Precedence matrix: entry (i, j) = 1 means node i precedes node j.
        // x 0 1 2 3 4 5
        // 0 0 0 0 0 0 0
        // 1 1 0 0 0 0 0
        // 2 1 1 0 1 1 1
        // 3 1 1 0 0 0 0
        // 4 1 1 0 1 0 1
        // 5 1 1 0 1 0 0
        // The strict upper triangle, read row by row, is the binary string
        // (bits 9, 10, 11 and 14 are set):
        // x  1 2 3 4 5
        // 0  0 0 0 0 0
        // 1    0 0 0 0
        // 2      1 1 1
        // 3        0 0
        // 4          1
        // 6 nodes, so the binary string has 5 + 4 + 3 + 2 + 1 = 15 bits.
        // Predecessor counts are [5, 4, 0, 3, 1, 2], giving the decoded order
        // [2, 4, 5, 3, 1, 0].
let tsp = TSPInstance::new_const(
vec![
(0.0, 0.0),
(0.0, 0.0),
(0.0, 0.0),
(0.0, 0.0),
(0.0, 0.0),
(0.0, 0.0),
]
);
let converter = TSPBinaryStringWrapper::new(
&tsp,
U15,
U6
).unwrap();
let mut binary_string_ordering = BinaryString::<U15>::new(vec![0; 15]);
binary_string_ordering.vec[9] = 1;
binary_string_ordering.vec[10] = 1;
binary_string_ordering.vec[11] = 1;
binary_string_ordering.vec[14] = 1;
let expected_permutation = vec![2, 4, 5, 3, 1, 0];
let mut permutation = converter.to_permutation(&binary_string_ordering)
.unwrap();
assert_eq!(
expected_permutation,
permutation.permutation.as_mut_slice().to_vec()
);
}
}