use std::convert::Infallible;
use eoa_lib::fitness::FitnessFunction;
use itertools::Itertools;
use nalgebra::{allocator::Allocator, distance, Const, DefaultAllocator, Dim, Dyn, OMatrix, OVector, Point};
use plotters::prelude::*;
use crate::graph::{minimal_spanning_tree_kruskal, Edge, GenericGraph, Graph, WeightedEdge};
/// A single city of a TSP instance, located in the 2D plane.
#[derive(PartialEq, Clone, Debug)]
pub struct TSPCity {
    /// Position of the city as an `(x, y)` point.
    point: Point<f64, 2>,
}
/// An edge between two cities, identified by their indices, together with
/// the distance separating them.
///
/// Derives the same `PartialEq`/`Clone`/`Debug` set as the other TSP types in
/// this file so that edges (and containers holding them) can be cloned and
/// compared.
#[derive(PartialEq, Clone, Debug)]
pub struct TSPEdge {
    /// Index of the city the edge starts at.
    from: usize,
    /// Index of the city the edge ends at.
    to: usize,
    /// Length of the edge; reported as the edge cost.
    distance: f64,
}
/// Exposes the two endpoints of a [`TSPEdge`] through the generic `Edge` interface.
impl Edge for TSPEdge {
    /// Index of the node this edge starts at.
    fn from_node(&self) -> usize {
        self.from
    }

    /// Index of the node this edge ends at.
    fn to_node(&self) -> usize {
        self.to
    }
}
/// Uses the stored distance as the weight of a [`TSPEdge`].
impl WeightedEdge for TSPEdge {
    type Cost = f64;

    /// Cost of traversing the edge, i.e. its stored `distance`.
    fn cost(&self) -> f64 {
        self.distance
    }
}
/// A candidate tour, stored as a vector of node indices.
///
/// The `D` parameter is the length of the vector (either a compile-time
/// `Const<N>` or a run-time `Dyn`).
#[derive(PartialEq, Clone, Debug)]
pub struct NodePermutation<D: Dim>
where
    DefaultAllocator: Allocator<D>,
{
    /// Node indices in visiting order.
    pub permutation: OVector<usize, D>,
}
/// A travelling-salesman problem instance: a fully connected graph whose
/// nodes are cities.
///
/// The `D` parameter is the number of cities, i.e. the dimension of the
/// square `distances` matrix.
#[derive(PartialEq, Clone, Debug)]
pub struct TSPInstance<D>
where
    D: Dim,
    DefaultAllocator: Allocator<D, D>,
{
    /// The cities; a city's position in this vector is its node index.
    pub cities: Vec<TSPCity>,
    /// Pairwise distances; entry `(i, j)` is the distance between city `i`
    /// and city `j` (filled in by [`TSPInstance::new`]).
    pub distances: OMatrix<f64, D, D>,
}
impl TSPInstance<Dyn> {
    /// Builds a dynamically sized instance from a list of `(x, y)` city
    /// coordinates; the number of cities is `cities.len()`.
    pub fn new_dyn(cities: Vec<(f64, f64)>) -> Self {
        // Lay the coordinates out as an N x 2 matrix: column 0 holds x, column 1 holds y.
        let coordinates = OMatrix::<f64, Dyn, Const<2>>::from_fn_generic(
            Dyn(cities.len()),
            Const::<2>,
            |row, column| {
                let (x, y) = cities[row];
                match column {
                    0 => x,
                    _ => y,
                }
            },
        );
        Self::new(coordinates)
    }
}
impl<const D: usize> TSPInstance<Const<D>> {
    /// Builds a statically sized instance with exactly `D` cities from a list
    /// of `(x, y)` coordinates.
    ///
    /// # Panics
    ///
    /// Panics (index out of bounds) if `cities` holds fewer than `D` entries.
    /// Only rows `0..D` are read, so any entries beyond the first `D` are
    /// ignored.
    pub fn new_const(cities: Vec<(f64, f64)>) -> Self {
        // Lay the coordinates out as a D x 2 matrix: column 0 holds x, column 1 holds y.
        let coordinates = OMatrix::<f64, Const<D>, Const<2>>::from_fn(|row, column| {
            let (x, y) = cities[row];
            match column {
                0 => x,
                _ => y,
            }
        });
        Self::new(coordinates)
    }
}
impl<D> TSPInstance<D>
where
    D: Dim,
    DefaultAllocator: Allocator<D, D>,
    DefaultAllocator: Allocator<D>,
    DefaultAllocator: Allocator<D, Const<2>>,
{
    /// Builds an instance from a `D x 2` coordinate matrix, one city per row
    /// (column 0 is x, column 1 is y), and precomputes the full matrix of
    /// pairwise distances between the cities.
    pub fn new(cities: OMatrix<f64, D, Const<2>>) -> Self {
        let (city_count, _) = cities.shape_generic();

        let cities: Vec<TSPCity> = cities
            .row_iter()
            .map(|row| TSPCity {
                point: Point::<f64, 2>::new(row[0], row[1]),
            })
            .collect();

        // Every (row, column) pair is evaluated, so the matrix is symmetric
        // with a zero diagonal.
        let distances = OMatrix::from_fn_generic(city_count, city_count, |a, b| {
            distance(&cities[a].point, &cities[b].point)
        });

        Self { cities, distances }
    }

    /// Consumes the instance and converts it into a graph whose nodes are the
    /// cities and which contains one edge per unordered pair of distinct
    /// cities, weighted by their distance.
    pub fn to_graph(self) -> GenericGraph<TSPCity, TSPEdge> {
        let Self { cities, distances } = self;
        let city_count = cities.len();

        // NOTE(review): the `false` flag presumably selects an undirected
        // graph -- confirm against `GenericGraph::new`.
        let mut graph = GenericGraph::new(cities, false);

        // Visit each pair (from, to) with from < to exactly once, in
        // row-major order.
        let pairs = (0..city_count)
            .flat_map(|from| (from + 1..city_count).map(move |to| (from, to)));
        for (from, to) in pairs {
            graph.add_generic_edge(TSPEdge {
                from,
                to,
                distance: distances[(from, to)],
            });
        }

        graph
    }
}
impl<D> TSPInstance<D>
where
    D: Dim,
    DefaultAllocator: Allocator<D, D>,
    DefaultAllocator: Allocator<D>,
{
    /// Returns the number of cities as the generic dimension `D`, read from
    /// the row count of the distance matrix.
    pub fn dimension(&self) -> D {
        self.distances.shape_generic().0
    }

    /// Checks that `solution` is a valid permutation of `0..len`, where `len`
    /// is the length of the permutation itself.
    ///
    /// Returns `false` if any entry is `>= len` or if any index occurs more
    /// than once, `true` otherwise. The check does not look at any instance,
    /// so it does not compare `len` against a particular number of cities.
    pub fn verify_solution(solution: &NodePermutation<D>) -> bool {
        let mut seen_vertices = OVector::from_element_generic(
            solution.permutation.shape_generic().0,
            solution.permutation.shape_generic().1,
            false
        );

        for &vertex in solution.permutation.iter() {
            // This vertex index is out of bounds
            if vertex >= solution.permutation.len() {
                return false;
            }

            // A node is repeating
            if seen_vertices[vertex] {
                return false;
            }

            seen_vertices[vertex] = true;
        }

        true
    }

    /// Returns the length of the closed tour described by `solution`: the sum
    /// of distances between consecutive nodes, including the edge from the
    /// last node back to the first.
    ///
    /// # Panics
    ///
    /// Panics if the permutation contains an index outside the distance matrix.
    pub fn solution_cost(&self, solution: &NodePermutation<D>) -> f64 {
        solution.permutation
            .iter()
            // Circular windows also yield the closing (last, first) pair.
            .circular_tuple_windows()
            .map(|(&node1, &node2): (&usize, &usize)| self.distances(node1, node2))
            .sum()
    }

    /// Returns the stored distance between `city_a` and `city_b`.
    ///
    /// # Panics
    ///
    /// Panics if either index is outside the distance matrix.
    pub fn distances(&self, city_a: usize, city_b: usize) -> f64 {
        self.distances[(city_a, city_b)]
    }

    /// Renders the cities (and, if given, the tour `solution`) into an
    /// 800x600 bitmap written to `filename`.
    ///
    /// The axes cover the bounding box of the cities plus 10 % padding on
    /// each side. A degenerate bounding box (no cities, a single city, or all
    /// cities sharing a coordinate) is widened to a non-empty finite range so
    /// the chart is still drawable.
    ///
    /// # Errors
    ///
    /// Propagates any error reported by the drawing backend.
    ///
    /// # Panics
    ///
    /// Panics if `solution` refers to a city index that does not exist.
    fn plot_internal(&self, solution: Option<&NodePermutation<D>>, filename: &str) -> Result<(), Box<dyn std::error::Error>> {
        // Pads `[min, max]` by 10 % of its span on both sides. When the span is
        // zero the padding is derived from the coordinate magnitude instead, and
        // when the bounds are not finite (e.g. the fold seeds of an empty city
        // list) a unit range is used.
        fn padded_axis_range(min: f64, max: f64) -> std::ops::Range<f64> {
            let span = max - min;
            if !span.is_finite() {
                return 0.0..1.0;
            }

            let padding = if span > 0.0 {
                span * 0.1
            } else {
                min.abs().max(1.0) * 0.1
            };

            (min - padding)..(max + padding)
        }

        let root = BitMapBackend::new(filename, (800, 600)).into_drawing_area();
        root.fill(&WHITE)?;

        // Bounding box of all cities.
        let x_min = self.cities.iter().map(|city| city.point.x).fold(f64::INFINITY, f64::min);
        let x_max = self.cities.iter().map(|city| city.point.x).fold(f64::NEG_INFINITY, f64::max);
        let y_min = self.cities.iter().map(|city| city.point.y).fold(f64::INFINITY, f64::min);
        let y_max = self.cities.iter().map(|city| city.point.y).fold(f64::NEG_INFINITY, f64::max);

        let x_range = padded_axis_range(x_min, x_max);
        let y_range = padded_axis_range(y_min, y_max);

        let title = if let Some(sol) = solution {
            format!("TSP Solution (Cost: {:.2})", self.solution_cost(sol))
        } else {
            "TSP Instance".to_string()
        };

        let mut chart = ChartBuilder::on(&root)
            .caption(&title, ("sans-serif", 40))
            .margin(10)
            .x_label_area_size(40)
            .y_label_area_size(40)
            .build_cartesian_2d(x_range, y_range)?;

        chart.configure_mesh().draw()?;

        // Tour edges are drawn first so the city markers end up on top of them.
        if let Some(sol) = solution {
            chart.draw_series(
                sol.permutation.iter().circular_tuple_windows().map(|(&city1_idx, &city2_idx)| {
                    let city1 = &self.cities[city1_idx];
                    let city2 = &self.cities[city2_idx];
                    PathElement::new(vec![(city1.point.x, city1.point.y), (city2.point.x, city2.point.y)], BLUE)
                })
            )?;
        }

        chart.draw_series(
            self.cities.iter().map(|city| {
                Circle::new((city.point.x, city.point.y), 5, RED.filled())
            })
        )?;

        root.present()?;
        Ok(())
    }

    /// Plots only the cities of this instance into the image file `filename`.
    ///
    /// # Errors
    ///
    /// Propagates any error reported by the drawing backend.
    pub fn plot(&self, filename: &str) -> Result<(), Box<dyn std::error::Error>> {
        self.plot_internal(None, filename)
    }

    /// Plots the cities together with the closed tour `solution` into the
    /// image file `filename`; the tour cost is shown in the caption.
    ///
    /// # Errors
    ///
    /// Propagates any error reported by the drawing backend.
    pub fn draw_solution(&self, solution: &NodePermutation<D>, filename: &str) -> Result<(), Box<dyn std::error::Error>> {
        self.plot_internal(Some(solution), filename)
    }
}
/// Uses the tour length as the fitness of a node permutation.
impl<D> FitnessFunction for TSPInstance<D>
where
    D: Dim,
    DefaultAllocator: Allocator<D, D>,
    DefaultAllocator: Allocator<D>,
{
    type In = NodePermutation<D>;
    type Out = f64;
    type Err = Infallible;

    /// Returns the cost of the closed tour `inp`.
    ///
    /// # Panics
    ///
    /// Panics if `inp` does not have one entry per city or is not a valid
    /// permutation (see `verify_solution`).
    fn fit(&self, inp: &Self::In) -> Result<Self::Out, Self::Err> {
        assert_eq!(inp.permutation.len(), self.cities.len());
        assert!(Self::verify_solution(inp));

        let tour_length = self.solution_cost(inp);
        Ok(tour_length)
    }
}
#[cfg(test)]
mod tests {
    use nalgebra::{Const, SVector};
    use rand::seq::SliceRandom;
    use super::{NodePermutation, TSPInstance};

    /// `verify_solution` accepts every shuffle of the identity permutation
    /// and rejects out-of-range or repeated node indices.
    #[test]
    fn test_verify_solution() {
        let mut generator = rand::rng();

        let mut tour = NodePermutation::<Const<6>> {
            permutation: SVector::from_vec(vec![0, 1, 2, 3, 4, 5]),
        };

        // Shuffling keeps the tour a valid permutation.
        for _ in 0..100 {
            tour.permutation.as_mut_slice().shuffle(&mut generator);
            assert!(TSPInstance::verify_solution(&tour));
        }

        // Out of bounds
        for invalid_vertex in [6, 7, 8] {
            tour.permutation[0] = invalid_vertex;
            assert!(!TSPInstance::verify_solution(&tour));
        }

        // Repeating
        tour.permutation[0] = 5;
        tour.permutation[1] = 5;
        assert!(!TSPInstance::verify_solution(&tour));

        let repeated = NodePermutation::<Const<6>> {
            permutation: SVector::from_vec(vec![0, 1, 2, 3, 1, 5]),
        };
        assert!(!TSPInstance::verify_solution(&repeated));
    }
}