use nalgebra::{allocator::Allocator, DefaultAllocator, Dim, OVector, U1};
use eoa_lib::{binary_string::BinaryString, fitness::FitnessFunction};
use thiserror::Error;
use crate::tsp::{NodePermutation, TSPInstance};
impl<'a, DIn: Dim, DOut: Dim> FitnessFunction for TSPBinaryStringWrapper<'a, DIn, DOut>
where
    DefaultAllocator: Allocator<DIn>,
    DefaultAllocator: Allocator<DOut>,
    DefaultAllocator: Allocator<DOut, DOut>,
{
    type In = BinaryString<DIn>;
    type Out = f64;
    type Err = DimensionMismatch;

    /// Decodes `inp` into a tour with [`TSPBinaryStringWrapper::to_permutation`]
    /// and evaluates that tour on the wrapped TSP instance.
    ///
    /// # Errors
    /// Returns [`DimensionMismatch::Mismatch`] if `inp` cannot be decoded.
    ///
    /// # Panics
    /// Panics if the wrapped instance's own `fit` rejects the decoded permutation.
    fn fit(&self, inp: &Self::In) -> Result<Self::Out, Self::Err> {
        let permutation = self.to_permutation(inp)?;
        // NOTE(review): assumes `TSPInstance::fit` cannot fail for a permutation of
        // exactly `DOut` nodes, which is what `to_permutation` produces; confirm.
        let fitness = self
            .instance
            .fit(&permutation)
            .expect("TSP instance rejected a permutation decoded from a valid binary string");
        Ok(fitness)
    }
}
/// Adapter that lets binary-string genomes be evaluated on a TSP instance.
///
/// A genome of length `DIn` encodes the pairwise ordering of the `DOut` nodes
/// of `instance` and is decoded into a tour by
/// [`TSPBinaryStringWrapper::to_permutation`]. Construct it with
/// [`TSPBinaryStringWrapper::new`], which checks that the two dimensions agree.
pub struct TSPBinaryStringWrapper<'a, DIn: Dim, DOut: Dim>
where
    DefaultAllocator: Allocator<DOut, DOut>,
{
    /// Instance whose nodes are being ordered.
    instance: &'a TSPInstance<DOut>,
    /// Length of the binary genome.
    dim_in: DIn,
    /// Number of nodes in a decoded permutation (the dimension of `instance`).
    dim_out: DOut,
}
impl<'a, DIn: Dim, DOut: Dim> TSPBinaryStringWrapper<'a, DIn, DOut>
where
    DefaultAllocator: Allocator<DOut, DOut>,
    DefaultAllocator: Allocator<DIn>,
    DefaultAllocator: Allocator<DOut>,
{
    /// Number of bits needed to encode the relative order of `nodes` nodes:
    /// one bit per pair `(i, j)` with `i < j`, i.e. `nodes * (nodes - 1) / 2`.
    ///
    /// Returns `None` when the product overflows `usize`. `nodes == 0` yields
    /// `Some(0)` instead of underflowing.
    fn pair_count(nodes: usize) -> Option<usize> {
        nodes
            .checked_mul(nodes.saturating_sub(1))
            .map(|twice_pairs| twice_pairs / 2)
    }

    /// Creates a wrapper that decodes `dim_in`-bit strings into permutations of
    /// the `dim_out` nodes of `instance`.
    ///
    /// # Errors
    /// Returns [`DimensionMismatch::Mismatch`] unless
    /// `dim_in == dim_out * (dim_out - 1) / 2`.
    pub fn new(
        instance: &'a TSPInstance<DOut>,
        dim_in: DIn,
        dim_out: DOut,
    ) -> Result<Self, DimensionMismatch> {
        if Self::pair_count(dim_out.value()) != Some(dim_in.value()) {
            return Err(DimensionMismatch::Mismatch);
        }
        Ok(Self {
            instance,
            dim_in,
            dim_out,
        })
    }

    /// Decodes a binary string into a permutation of the `DOut` nodes.
    ///
    /// Bit `k` of `inp` describes the `k`-th pair `(i, j)` with `i < j`. Pairs are
    /// enumerated row by row: `(0, 1), (0, 2), …, (0, n-1), (1, 2), …`. A bit equal
    /// to `1` means node `i` comes before node `j`. Any other value means `j` comes
    /// before `i`.
    ///
    /// Each node is scored by how many nodes come before it, and nodes are sorted
    /// by ascending score. The sort is stable, so ties are broken by ascending node
    /// index. Ties can only arise when the encoded pairwise relation is not
    /// transitive.
    ///
    /// # Errors
    /// Returns [`DimensionMismatch::Mismatch`] if the length of `inp` differs from
    /// `dim_in` or from `n * (n - 1) / 2`, where `n` is the number of nodes.
    pub fn to_permutation(
        &self,
        inp: &BinaryString<DIn>,
    ) -> Result<NodePermutation<DOut>, DimensionMismatch> {
        let nodes = self.dim_out.value();
        if inp.vec().shape_generic().0.value() != self.dim_in.value() {
            return Err(DimensionMismatch::Mismatch);
        }
        // One up-front check replaces per-bit bounds checks: afterwards there is
        // exactly one bit for every pair enumerated below.
        if Self::pair_count(nodes) != Some(inp.vec.len()) {
            return Err(DimensionMismatch::Mismatch);
        }

        // precedence_count[v] = number of nodes that come before node v.
        let mut precedence_count = OVector::<usize, DOut>::zeros_generic(self.dim_out, U1);
        let pairs = (0..nodes).flat_map(|i| (i + 1..nodes).map(move |j| (i, j)));
        for ((i, j), bit) in pairs.zip(inp.vec.iter()) {
            if *bit == 1 {
                // i comes before j, so j has one more predecessor.
                precedence_count[j] += 1;
            } else {
                // j comes before i, so i has one more predecessor.
                precedence_count[i] += 1;
            }
        }

        let mut result = OVector::from_iterator_generic(self.dim_out, U1, 0..nodes);
        // Stable sort on purpose: equal scores keep ascending node order.
        result
            .as_mut_slice()
            .sort_by_key(|&node| precedence_count[node]);
        Ok(NodePermutation { permutation: result })
    }
}
#[derive(Error, Debug)]
pub enum DimensionMismatch {
#[error("The input dimension should be equal to half matrix NxN where the output is N")]
Mismatch
}
#[cfg(test)]
mod tests {
    use super::*;
    use eoa_lib::binary_string::BinaryString;
    use nalgebra::{U14, U15, U6};

    /// Six-node instance. The coordinates are irrelevant because these tests
    /// only exercise permutation decoding.
    fn six_node_instance() -> TSPInstance<U6> {
        TSPInstance::new_const(vec![(0.0, 0.0); 6])
    }

    #[test]
    fn test_binary_string_representation() {
        // Bit layout for 6 nodes: one bit per pair (i, j) with i < j, row by row
        // over the strict upper triangle of a 6x6 matrix:
        // 5 + 4 + 3 + 2 + 1 = 15 bits.
        // A 1 means "i comes before j"; a 0 means "j comes before i".
        let tsp = six_node_instance();
        let converter = TSPBinaryStringWrapper::new(&tsp, U15, U6).unwrap();

        // All ones: every lower-indexed node precedes every higher-indexed one,
        // so the decoded permutation is the identity.
        let all_ones = BinaryString::<U15>::new(vec![1; 15]);
        let mut expected_permutation = vec![0, 1, 2, 3, 4, 5];
        let permutation = converter.to_permutation(&all_ones).unwrap();
        assert_eq!(
            expected_permutation,
            permutation.permutation.as_slice().to_vec()
        );

        // All zeros: every higher-indexed node precedes every lower-indexed one,
        // so the decoded permutation is the reversed identity.
        let all_zeros = BinaryString::<U15>::new(vec![0; 15]);
        expected_permutation.reverse();
        let permutation = converter.to_permutation(&all_zeros).unwrap();
        assert_eq!(
            expected_permutation,
            permutation.permutation.as_slice().to_vec()
        );
    }

    #[test]
    fn test_nontrivial_binary_string_representation() {
        // Bit index per pair (i, j), i < j:
        //   row 0: (0,1)=0  (0,2)=1  (0,3)=2  (0,4)=3  (0,5)=4
        //   row 1: (1,2)=5  (1,3)=6  (1,4)=7  (1,5)=8
        //   row 2: (2,3)=9  (2,4)=10 (2,5)=11
        //   row 3: (3,4)=12 (3,5)=13
        //   row 4: (4,5)=14
        // Set bits (i before j): 9, 10, 11, 14. Every other pair has j before i.
        // Predecessor counts: node 0 -> 5, 1 -> 4, 2 -> 0, 3 -> 3, 4 -> 1, 5 -> 2,
        // so sorting by count yields [2, 4, 5, 3, 1, 0].
        let tsp = six_node_instance();
        let converter = TSPBinaryStringWrapper::new(&tsp, U15, U6).unwrap();

        let mut binary_string_ordering = BinaryString::<U15>::new(vec![0; 15]);
        for bit in [9, 10, 11, 14] {
            binary_string_ordering.vec[bit] = 1;
        }

        let expected_permutation = vec![2, 4, 5, 3, 1, 0];
        let permutation = converter
            .to_permutation(&binary_string_ordering)
            .unwrap();
        assert_eq!(
            expected_permutation,
            permutation.permutation.as_slice().to_vec()
        );
    }

    #[test]
    fn test_dimension_mismatch_is_rejected() {
        let tsp = six_node_instance();
        // 6 nodes need 6 * 5 / 2 = 15 bits, so a 14-bit input must be rejected.
        assert!(TSPBinaryStringWrapper::new(&tsp, U14, U6).is_err());
    }
}