~ruther/ctu-fee-eoa

ref: 0ef7b7f5f8bdcdf492b1fc419279e7a7ea3fe666 ctu-fee-eoa/codes/tsp_hw01/src/binary_string_representation.rs -rw-r--r-- 3.1 KiB
0ef7b7f5 — Rutherther feat(tsp): add TSPInstance convertor to GenericGraph a month ago
                                                                                
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
use nalgebra::{allocator::Allocator, DefaultAllocator, Dim, OVector, U1};
use eoa_lib::{binary_string::BinaryString, fitness::FitnessFunction};
use thiserror::Error;
use crate::tsp::{NodePermutation, TSPInstance};

/// Lets a [`TSPBinaryStringWrapper`] be used as a fitness function over
/// binary strings: the input is decoded into a node permutation, which is
/// then evaluated by the wrapped instance.
impl<'a, DIn: Dim, DOut: Dim> FitnessFunction for TSPBinaryStringWrapper<'a, DIn, DOut>
where
    DefaultAllocator: Allocator<DIn> + Allocator<DOut> + Allocator<DOut, DOut>,
{
    type In = BinaryString<DIn>;
    type Out = f64;
    type Err = DimensionMismatch;

    /// Decodes `inp` with `to_permutation` and returns the value the wrapped
    /// instance's `fit` assigns to the resulting permutation.
    ///
    /// # Errors
    ///
    /// Propagates the [`DimensionMismatch`] reported by `to_permutation`.
    ///
    /// # Panics
    ///
    /// Panics if the wrapped instance's `fit` returns an error, because that
    /// result is unwrapped.
    fn fit(&self, inp: &Self::In) -> Result<Self::Out, Self::Err> {
        let permutation = self.to_permutation(inp)?;
        let fitness = self.instance.fit(&permutation).unwrap();

        Ok(fitness)
    }
}

/// Adapter that evaluates a borrowed [`TSPInstance`] on binary strings
/// instead of node permutations.
///
/// `DIn` is the dimension of the binary string, `DOut` the dimension of the
/// produced permutation.
pub struct TSPBinaryStringWrapper<'a, DIn: Dim, DOut: Dim>
where
    DefaultAllocator: Allocator<DOut, DOut>,
{
    /// The wrapped instance that evaluates decoded permutations.
    instance: &'a TSPInstance<DOut>,
    /// Expected length of the input binary string.
    dim_in: DIn,
    /// Length of the decoded permutation (number of nodes).
    dim_out: DOut,
}

impl<'a, DIn: Dim, DOut: Dim> TSPBinaryStringWrapper<'a, DIn, DOut>
where
    DefaultAllocator: Allocator<DOut, DOut>,
    DefaultAllocator: Allocator<DIn>,
    DefaultAllocator: Allocator<DOut>,
{
    /// Creates a wrapper around `instance` that decodes binary strings of
    /// dimension `dim_in` into permutations of dimension `dim_out`.
    ///
    /// The binary string holds one bit per unordered pair of nodes, so with
    /// `n = dim_out.value()` the input dimension must be `n * (n - 1) / 2`.
    ///
    /// # Errors
    ///
    /// Returns [`DimensionMismatch::Mismatch`] when `dim_in` does not equal
    /// the number of node pairs implied by `dim_out`.
    pub fn new(
        instance: &'a TSPInstance<DOut>,
        dim_in: DIn,
        dim_out: DOut
    ) -> Result<Self, DimensionMismatch> {
        let nodes = dim_out.value();
        // `saturating_sub` keeps `nodes == 0` from underflowing (a panic in
        // debug builds); zero nodes simply require zero input bits.
        let expected_bits = nodes * nodes.saturating_sub(1) / 2;

        if expected_bits != dim_in.value() {
            return Err(DimensionMismatch::Mismatch);
        }

        Ok(Self {
            instance,
            dim_in,
            dim_out,
        })
    }

    /// Decodes a binary string into a node permutation.
    ///
    /// The bits are consumed in the order of the pairs `(i, j)` with `i < j`:
    /// `(0, 1), (0, 2), …, (0, n-1), (1, 2), …`. A bit equal to `1` means
    /// node `i` precedes node `j`; any other value means `j` precedes `i`.
    /// Each node's number of predecessors is counted and the nodes are
    /// ordered by that count, lowest first. Nodes with equal counts keep
    /// ascending index order.
    ///
    /// # Errors
    ///
    /// Returns [`DimensionMismatch::Mismatch`] when the length of `inp` does
    /// not match the configured input dimension, or when the number of bits
    /// does not match the number of node pairs.
    pub fn to_permutation(&self, inp: &BinaryString<DIn>) -> Result<NodePermutation<DOut>, DimensionMismatch> {
        let nodes = self.dim_out.value();
        let bits = inp.vec();

        if bits.shape_generic().0.value() != self.dim_in.value() {
            return Err(DimensionMismatch::Mismatch);
        }

        // Count how many nodes each node comes after (precedence count)
        let mut precedence_count = OVector::<usize, DOut>::zeros_generic(self.dim_out, U1);

        let mut remaining_bits = bits.iter();
        for i in 0..nodes {
            for j in i + 1..nodes {
                let bit = match remaining_bits.next() {
                    Some(bit) => *bit,
                    // Fewer bits than node pairs.
                    None => return Err(DimensionMismatch::Mismatch),
                };

                if bit == 1 {
                    // i comes before j, so j has one more predecessor
                    precedence_count[j] += 1;
                } else {
                    // j comes before i, so i has one more predecessor
                    precedence_count[i] += 1;
                }
            }
        }

        // More bits than node pairs.
        if remaining_bits.next().is_some() {
            return Err(DimensionMismatch::Mismatch);
        }

        let mut result = OVector::from_iterator_generic(
            self.dim_out,
            U1,
            0..nodes
        );

        // The sort must stay stable: inconsistent bit patterns can give
        // several nodes the same count, and ties are resolved by node index.
        result
            .as_mut_slice()
            .sort_by_key(|&node| precedence_count[node]);

        Ok(NodePermutation { permutation: result })
    }
}

/// Error signalling that the dimension of a binary string does not fit the
/// dimension of the permutation it should be decoded into.
#[derive(Error, Debug)]
pub enum DimensionMismatch {
    /// For an output of `N` nodes the input must hold one bit per unordered
    /// pair of nodes, i.e. `N * (N - 1) / 2` bits.
    #[error("The input dimension should be equal to half matrix NxN where the output is N")]
    Mismatch
}