use nalgebra::{allocator::Allocator, DefaultAllocator, Dim, OVector, U1};
use eoa_lib::{binary_string::BinaryString, fitness::FitnessFunction};
use thiserror::Error;
use crate::tsp::{NodePermutation, TSPInstance};
/// Scores a binary string by decoding it into a node permutation and
/// evaluating that permutation with the wrapped [`TSPInstance`].
impl<'a, DIn: Dim, DOut: Dim> FitnessFunction for TSPBinaryStringWrapper<'a, DIn, DOut>
where
    DefaultAllocator: Allocator<DIn>,
    DefaultAllocator: Allocator<DOut>,
    DefaultAllocator: Allocator<DOut, DOut>,
{
    type In = BinaryString<DIn>;
    type Out = f64;
    type Err = DimensionMismatch;

    /// Decodes `inp` with [`TSPBinaryStringWrapper::to_permutation`] and returns
    /// the instance's fitness for the resulting permutation.
    ///
    /// # Errors
    /// Returns [`DimensionMismatch::Mismatch`] when `inp` has the wrong length.
    ///
    /// # Panics
    /// Panics if the wrapped instance rejects the decoded permutation.
    /// NOTE(review): this is presumably only possible when the wrapper's
    /// `dim_out` disagrees with the instance's node count, which `new` does
    /// not check — confirm against `TSPInstance::fit`.
    fn fit(&self, inp: &Self::In) -> Result<Self::Out, Self::Err> {
        let permutation = self.to_permutation(inp)?;
        Ok(self.instance.fit(&permutation).expect(
            "TSP instance rejected the decoded permutation; does dim_out match the instance size?",
        ))
    }
}
/// Adapter that lets a [`TSPInstance`] be evaluated on binary strings.
///
/// Each input bit encodes the relative order of one unordered node pair
/// `(i, j)` with `i < j`; see [`TSPBinaryStringWrapper::to_permutation`] for
/// the exact decoding.
pub struct TSPBinaryStringWrapper<'a, DIn: Dim, DOut: Dim>
where
    DefaultAllocator: Allocator<DOut, DOut>,
{
    /// Borrowed instance used to score decoded permutations.
    instance: &'a TSPInstance<DOut>,
    /// Length of the binary input (one bit per node pair).
    dim_in: DIn,
    /// Number of nodes in the decoded permutation.
    dim_out: DOut,
}
impl<'a, DIn: Dim, DOut: Dim> TSPBinaryStringWrapper<'a, DIn, DOut>
where
    DefaultAllocator: Allocator<DOut, DOut>,
    DefaultAllocator: Allocator<DIn>,
    DefaultAllocator: Allocator<DOut>,
{
    /// Number of unordered node pairs `(i, j)` with `i < j` among `nodes`
    /// nodes, i.e. `nodes * (nodes - 1) / 2`.
    ///
    /// `saturating_sub` makes `nodes == 0` yield 0 instead of underflowing
    /// (which would panic in debug builds).
    fn pair_count(nodes: usize) -> usize {
        nodes * nodes.saturating_sub(1) / 2
    }

    /// Creates a wrapper around `instance`.
    ///
    /// `dim_in` is the length of the binary strings to decode and `dim_out`
    /// is the number of nodes `N` in a decoded permutation.
    ///
    /// NOTE(review): `dim_out` is not checked against the node count of
    /// `instance`; callers must make sure they agree.
    ///
    /// # Errors
    /// Returns [`DimensionMismatch::Mismatch`] unless `dim_in == N * (N - 1) / 2`.
    pub fn new(
        instance: &'a TSPInstance<DOut>,
        dim_in: DIn,
        dim_out: DOut,
    ) -> Result<Self, DimensionMismatch> {
        if Self::pair_count(dim_out.value()) != dim_in.value() {
            return Err(DimensionMismatch::Mismatch);
        }
        Ok(Self {
            instance,
            dim_in,
            dim_out,
        })
    }

    /// Decodes a binary string into a node permutation.
    ///
    /// Bits are consumed in the pair order `(0, 1), (0, 2), …, (0, N-1),
    /// (1, 2), …, (N-2, N-1)`. A bit equal to `1` means node `i` comes before
    /// node `j`; any other value means `j` comes before `i`. Each node is then
    /// ranked by how many nodes precede it, and the permutation lists nodes in
    /// ascending order of that count. The sort is stable, so ties (possible
    /// only when the pairwise orders are not transitive) keep ascending
    /// node-index order.
    ///
    /// # Errors
    /// Returns [`DimensionMismatch::Mismatch`] when `inp` does not contain
    /// exactly `dim_in` bits, or when that is not `N * (N - 1) / 2`.
    pub fn to_permutation(
        &self,
        inp: &BinaryString<DIn>,
    ) -> Result<NodePermutation<DOut>, DimensionMismatch> {
        let nodes = self.dim_out.value();
        let bits = &inp.vec;

        // One upfront check replaces per-bit bounds checks: with the length
        // fixed, the pair iterator and the bit iterator have equal lengths.
        if bits.len() != self.dim_in.value() || bits.len() != Self::pair_count(nodes) {
            return Err(DimensionMismatch::Mismatch);
        }

        // precedence[k] = number of nodes that come before node k.
        let mut precedence = OVector::<usize, DOut>::zeros_generic(self.dim_out, U1);
        let pairs = (0..nodes).flat_map(|i| ((i + 1)..nodes).map(move |j| (i, j)));
        for ((i, j), bit) in pairs.zip(bits.iter()) {
            if *bit == 1 {
                // i comes before j, so j gains a predecessor.
                precedence[j] += 1;
            } else {
                // j comes before i, so i gains a predecessor.
                precedence[i] += 1;
            }
        }

        let mut order = OVector::from_iterator_generic(self.dim_out, U1, 0..nodes);
        // Stable sort on purpose: tie-breaking by node index keeps decoding deterministic.
        order
            .as_mut_slice()
            .sort_by_key(|&node| precedence[node]);
        Ok(NodePermutation { permutation: order })
    }
}
#[derive(Error, Debug)]
pub enum DimensionMismatch {
#[error("The input dimension should be equal to half matrix NxN where the output is N")]
Mismatch
}