~ruther/ctu-fee-eoa

8534d90f455d4f19cf476b3ec7f0e5805b0f7338 — Rutherther a month ago 944bef5
feat(tsp): add minimal spanning tree and filter_edges
2 files changed, 150 insertions(+), 2 deletions(-)

M codes/tsp_hw01/src/graph.rs
A codes/tsp_hw01/src/union_find.rs
M codes/tsp_hw01/src/graph.rs => codes/tsp_hw01/src/graph.rs +105 -2
@@ 1,4 1,6 @@
use std::collections::VecDeque;
use std::{cmp::Ordering, collections::VecDeque};

use crate::union_find::UnionFind;

pub type Distance = usize;



@@ 269,10 271,25 @@ impl<T, TEdge: Edge> GenericGraph<T, TEdge>
        self.node_neighbors[edge.to_node()].push(edge.from_node());
        self.edges.push(edge);


        idx
    }

    pub fn filter_edges(self, filter: impl Fn(usize, &TEdge) -> bool) -> Self {
        let keep_edges = self.edges
            .into_iter()
            .enumerate()
            .filter(|(i, e)| filter(*i, e))
            .map(|(_, e)| e);

        let mut new_graph = Self::new(self.nodes, self.adjacency_matrix.is_some());

        for edge in keep_edges {
            new_graph.add_generic_edge(edge);
        }

        new_graph
    }

    // NOTE: it's expected the edges will not reconnect, only the type will change.
    // from_node() and to_node() should stay the same!
    pub fn map_edges<TNewEdge: Edge>(self, map: impl Fn(TEdge) -> TNewEdge) -> GenericGraph<T, TNewEdge> {


@@ 526,3 543,89 @@ where

    distances
}


/// A generic representation of a minimum spanning
/// tree for cases where the graph might not be
/// fully connected, and thus it is possible the minimum
/// spanning tree will not be a single component.
#[derive(Clone, Debug, PartialEq)]
pub struct MinimumSpanningTree {
    // What component does node i belong to?
    pub components: UnionFind,
    // Edge indices in the tree
    pub edges: Vec<usize>,
}

impl MinimumSpanningTree {
    pub fn nodes_count(&self) -> usize {
        self.components.len()
    }

    pub fn components_count(&self) -> usize {
        self.nodes_count() - self.edges.len()
    }
}

/// Use the kruskal algorithm for finding the minimum spanning tree.
/// Take only edges filtered by selector as candidates for the spanning tree.
/// Initial minimum spanning tree can be passed, that should usually be a result
/// of prior run of this function with a different selector.
pub fn minimal_spanning_tree_kruskal<'a, TNode, TWeight, TEdge, TGraph>(
    graph: &TGraph,
    initial: Option<MinimumSpanningTree>,
    selector: impl Fn(&TEdge) -> bool
) -> MinimumSpanningTree
where
    TWeight: PartialOrd,
    TEdge: Edge + WeightedEdge<Cost = TWeight>,
    TGraph: Graph<Node = TNode, Edge = TEdge>
{
    // let separate_new_edges = initial.is_some();
    let nodes = graph.nodes().count();
    let mut initial_edge_selected = vec![false; graph.edges().count()];
    let mut current = if let Some(initial) = initial {
        for edge in &initial.edges {
            initial_edge_selected[*edge] = true;
        }

        initial
    } else {
        MinimumSpanningTree {
            components: UnionFind::make_set(nodes),
            edges: Vec::with_capacity(nodes - 1),
        }
    };

    let mut remaining_edges = graph.edges()
        .enumerate()
        .filter(|(i, e)| !initial_edge_selected[*i] && selector(e))
        .collect::<Vec<_>>();

    remaining_edges.sort_by(
        |a, b| a.1.cost().partial_cmp(&b.1.cost()).unwrap_or(Ordering::Less)
    );

    for (i, edge) in remaining_edges {
        // 1. does the edge connect two components?
        let (root_a, root_b) = {
            if current.components_count() == 1 {
                break;
            }

            let root_a = current.components.find(edge.from_node());
            let root_b = current.components.find(edge.to_node());

            if root_a == root_b {
                continue;
            }

            (root_a, root_b)
        };
        // 2. if so, use it and connect them
        current.components.union(root_a, root_b);
        current.edges.push(i);
    }

    current
}

A codes/tsp_hw01/src/union_find.rs => codes/tsp_hw01/src/union_find.rs +45 -0
@@ 0,0 1,45 @@
#[derive(Clone, Debug, PartialEq)]
pub struct UnionFindElement {
    idx: usize,
    parent: usize,
}

/// A structure for representing components with nodes.
#[derive(Clone, Debug, PartialEq)]
pub struct UnionFind {
    elements: Vec<UnionFindElement>
}

impl UnionFind {
    pub fn make_set(len: usize) -> UnionFind {
        UnionFind {
            elements: (0..len).map(|idx|
                                   UnionFindElement {
                                       idx,
                                       parent: idx
                                   }).collect()
        }
    }

    pub fn len(&self) -> usize {
        self.elements.len()
    }

    /// Find component of a node.
    pub fn find(&self, x: usize) -> usize {
        let mut current = &self.elements[x];
        while current.parent != current.idx {
            current = &self.elements[current.parent];
        }

        current.idx
    }

    /// Connect two components.
    pub fn union(&mut self, x: usize, y: usize) {
        let x = self.find(x);
        let y = self.find(y);

        self.elements[y].parent = x;
    }
}