From 8534d90f455d4f19cf476b3ec7f0e5805b0f7338 Mon Sep 17 00:00:00 2001 From: Rutherther Date: Sat, 1 Nov 2025 19:40:55 +0100 Subject: [PATCH] feat(tsp): add minimal spanning tree and filter_edges --- codes/tsp_hw01/src/graph.rs | 107 ++++++++++++++++++++++++++++++- codes/tsp_hw01/src/union_find.rs | 45 +++++++++++++ 2 files changed, 150 insertions(+), 2 deletions(-) create mode 100644 codes/tsp_hw01/src/union_find.rs diff --git a/codes/tsp_hw01/src/graph.rs b/codes/tsp_hw01/src/graph.rs index 8ca45f98a822f34270b69ca58e248bbdadbff694..f32248d8ebd70d39ae327e43909a225ebe885075 100644 --- a/codes/tsp_hw01/src/graph.rs +++ b/codes/tsp_hw01/src/graph.rs @@ -1,4 +1,6 @@ -use std::collections::VecDeque; +use std::{cmp::Ordering, collections::VecDeque}; + +use crate::union_find::UnionFind; pub type Distance = usize; @@ -269,10 +271,25 @@ impl GenericGraph self.node_neighbors[edge.to_node()].push(edge.from_node()); self.edges.push(edge); - idx } + pub fn filter_edges(self, filter: impl Fn(usize, &TEdge) -> bool) -> Self { + let keep_edges = self.edges + .into_iter() + .enumerate() + .filter(|(i, e)| filter(*i, e)) + .map(|(_, e)| e); + + let mut new_graph = Self::new(self.nodes, self.adjacency_matrix.is_some()); + + for edge in keep_edges { + new_graph.add_generic_edge(edge); + } + + new_graph + } + // NOTE: it's expected the edges will not reconnect, only the type will change. // from_node() and to_node() should stay the same! pub fn map_edges(self, map: impl Fn(TEdge) -> TNewEdge) -> GenericGraph { @@ -526,3 +543,89 @@ where distances } + + +/// A generic representation of a minimum spanning +/// tree for cases where the graph might not be +/// fully connected, and thus it is possible the minimum +/// spanning tree will not be a single component. +#[derive(Clone, Debug, PartialEq)] +pub struct MinimumSpanningTree { + // What component does node i belong to? + pub components: UnionFind, + // Edge indices in the tree + pub edges: Vec, +} + +impl MinimumSpanningTree { + pub fn nodes_count(&self) -> usize { + self.components.len() + } + + pub fn components_count(&self) -> usize { + self.nodes_count() - self.edges.len() + } +} + +/// Use the kruskal algorithm for finding the minimum spanning tree. +/// Take only edges filtered by selector as candidates for the spanning tree. +/// Initial minimum spanning tree can be passed, that should usually be a result +/// of prior run of this function with a different selector. +pub fn minimal_spanning_tree_kruskal<'a, TNode, TWeight, TEdge, TGraph>( + graph: &TGraph, + initial: Option, + selector: impl Fn(&TEdge) -> bool +) -> MinimumSpanningTree +where + TWeight: PartialOrd, + TEdge: Edge + WeightedEdge, + TGraph: Graph +{ + // let separate_new_edges = initial.is_some(); + let nodes = graph.nodes().count(); + let mut initial_edge_selected = vec![false; graph.edges().count()]; + let mut current = if let Some(initial) = initial { + for edge in &initial.edges { + initial_edge_selected[*edge] = true; + } + + initial + } else { + MinimumSpanningTree { + components: UnionFind::make_set(nodes), + edges: Vec::with_capacity(nodes - 1), + } + }; + + let mut remaining_edges = graph.edges() + .enumerate() + .filter(|(i, e)| !initial_edge_selected[*i] && selector(e)) + .collect::>(); + + remaining_edges.sort_by( + |a, b| a.1.cost().partial_cmp(&b.1.cost()).unwrap_or(Ordering::Less) + ); + + for (i, edge) in remaining_edges { + // 1. does the edge connect two components? + let (root_a, root_b) = { + if current.components_count() == 1 { + break; + } + + let root_a = current.components.find(edge.from_node()); + let root_b = current.components.find(edge.to_node()); + + if root_a == root_b { + continue; + } + + (root_a, root_b) + }; + // 2. if so, use it and connect them + current.components.union(root_a, root_b); + current.edges.push(i); + } + + current +} diff --git a/codes/tsp_hw01/src/union_find.rs b/codes/tsp_hw01/src/union_find.rs new file mode 100644 index 0000000000000000000000000000000000000000..6065a467029b902d03ef442fd539bdba2dced1c0 --- /dev/null +++ b/codes/tsp_hw01/src/union_find.rs @@ -0,0 +1,45 @@ +#[derive(Clone, Debug, PartialEq)] +pub struct UnionFindElement { + idx: usize, + parent: usize, +} + +/// A structure for representing components with nodes. +#[derive(Clone, Debug, PartialEq)] +pub struct UnionFind { + elements: Vec +} + +impl UnionFind { + pub fn make_set(len: usize) -> UnionFind { + UnionFind { + elements: (0..len).map(|idx| + UnionFindElement { + idx, + parent: idx + }).collect() + } + } + + pub fn len(&self) -> usize { + self.elements.len() + } + + /// Find component of a node. + pub fn find(&self, x: usize) -> usize { + let mut current = &self.elements[x]; + while current.parent != current.idx { + current = &self.elements[current.parent]; + } + + current.idx + } + + /// Connect two components. + pub fn union(&mut self, x: usize, y: usize) { + let x = self.find(x); + let y = self.find(y); + + self.elements[y].parent = x; + } +}