~ruther/ctu-fee-eoa

ref: 49d82c7ba9f4df79e1133a4971b230f63a908b4a ctu-fee-eoa/env/src/selection.rs -rw-r--r-- 2.7 KiB
49d82c7b — Rutherther feat: add tournament selection a month ago
                                                                                
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
// pub struct EvaluatedChromosome<TInput, TResult> {
//     chromosome: TInput,
//     evaluation: TResult,
// }

use rand::{Rng, RngCore};

use crate::comparison::BetterThanOperator;

/// A selection strategy operating on an evaluated population.
///
/// Implementors choose individuals from `evaluations`, using `better_than`
/// to decide which of two evaluations is preferable.
pub trait Selection<T> {
    /// Performs `count` selections and returns the chosen indices lazily.
    ///
    /// NOTE(review): the yielded `usize`s are presumably indices into
    /// `evaluations`, and presumably exactly `count` of them are yielded
    /// (the `TournamentSelection` implementation in this file yields one
    /// item per `0..count` step). Confirm before relying on it for other
    /// implementors.
    fn select(
        &mut self,
        count: usize,
        evaluations: &Vec<T>,
        better_than: &dyn BetterThanOperator<T>,
    ) -> impl Iterator<Item = usize>;
}

/// Tournament selection: each selection draws `k` random contestants and
/// picks a winner from their ranking, where the best-ranked contestant wins
/// with probability `p`, otherwise the next one with probability `p`, and so on.
pub struct TournamentSelection {
    /// Number of contestants drawn for every tournament (always > 0).
    k: usize,
    /// Per-rank winning probability, within `[0, 1]`.
    p: f64,
    /// Random source used both to draw contestants and to pick the winner.
    rng: Box<dyn RngCore>,
}

impl TournamentSelection {
    pub fn new(k: usize, p: f64) -> Self {
        assert!(0.0 <= p && p <= 1.0);
        assert!(k > 0);

        Self {
            rng: Box::new(rand::rng()),
            p,
            k
        }
    }

    fn tournament<T: PartialOrd>(&mut self, idxs: &mut Vec<usize>, evaluations: &Vec<T>, better_than: &dyn BetterThanOperator<T>) -> usize {
        idxs.sort_by(|&i, &j| better_than.ordering(&evaluations[i], &evaluations[j]));

        let mut p_selector = self.rng.random_range(0.0..=1.0f64);
        let p = self.p;
        let k = self.k;

        let mut selected = idxs[k - 1];
        // let's say p = 0.7
        // the best has probability 0.7 of being selected
        // if the best is not selected, the second has 0.7 probability of being selected... (that's 0.7 * 0.3 without conditions)
        // and so on. The last element has the remaining probability.
        for i in 0..k-1 {
            if p_selector <= p {
                selected = i;
                break;
            }

            p_selector -= p;
            // 'Expand' the rest to '100%' again
            p_selector /= 1.0 - p;
        }

        selected
    }
}

impl<T: Ord> Selection<T> for TournamentSelection {
    /// Runs `count` independent tournaments lazily and yields, for each one,
    /// the value returned by `TournamentSelection::tournament`.
    ///
    /// Each tournament draws `k` contestants uniformly from
    /// `0..evaluations.len()`, independently, so the same individual may
    /// appear more than once. No randomness is consumed until the returned
    /// iterator is advanced. If `evaluations` is empty, advancing the
    /// iterator panics, because `random_range` rejects an empty range.
    fn select(
        &mut self,
        count: usize,
        evaluations: &Vec<T>,
        better_than: &dyn BetterThanOperator<T>,
    ) -> impl Iterator<Item = usize> {
        let population = evaluations.len();
        // One scratch buffer reused by every tournament.
        let mut contestants = vec![0usize; self.k];

        (0..count).map(move |_| {
            contestants
                .iter_mut()
                .for_each(|slot| *slot = self.rng.random_range(0..population));

            self.tournament(&mut contestants, evaluations, better_than)
        })
    }
}