@@ 1,4 1,6 @@
-use std::collections::VecDeque;
+use std::{cmp::Ordering, collections::VecDeque};
+
+use crate::union_find::UnionFind;
pub type Distance = usize;
@@ 269,10 271,25 @@ impl<T, TEdge: Edge> GenericGraph<T, TEdge>
self.node_neighbors[edge.to_node()].push(edge.from_node());
self.edges.push(edge);
-
idx
}
+ pub fn filter_edges(self, filter: impl Fn(usize, &TEdge) -> bool) -> Self {
+ let keep_edges = self.edges
+ .into_iter()
+ .enumerate()
+ .filter(|(i, e)| filter(*i, e))
+ .map(|(_, e)| e);
+
+ let mut new_graph = Self::new(self.nodes, self.adjacency_matrix.is_some());
+
+ for edge in keep_edges {
+ new_graph.add_generic_edge(edge);
+ }
+
+ new_graph
+ }
+
// NOTE: it's expected the edges will not reconnect, only the type will change.
// from_node() and to_node() should stay the same!
pub fn map_edges<TNewEdge: Edge>(self, map: impl Fn(TEdge) -> TNewEdge) -> GenericGraph<T, TNewEdge> {
@@ 526,3 543,89 @@ where
distances
}
+
+
+/// A generic representation of a minimum spanning
+/// tree for cases where the graph might not be
+/// fully connected, and thus it is possible the minimum
+/// spanning tree will not be a single component.
+#[derive(Clone, Debug, PartialEq)]
+pub struct MinimumSpanningTree {
+ // What component does node i belong to?
+ pub components: UnionFind,
+ // Edge indices in the tree
+ pub edges: Vec<usize>,
+}
+
+impl MinimumSpanningTree {
+ pub fn nodes_count(&self) -> usize {
+ self.components.len()
+ }
+
+ pub fn components_count(&self) -> usize {
+ self.nodes_count() - self.edges.len()
+ }
+}
+
+/// Use the kruskal algorithm for finding the minimum spanning tree.
+/// Take only edges filtered by selector as candidates for the spanning tree.
+/// Initial minimum spanning tree can be passed, that should usually be a result
+/// of prior run of this function with a different selector.
+pub fn minimal_spanning_tree_kruskal<'a, TNode, TWeight, TEdge, TGraph>(
+ graph: &TGraph,
+ initial: Option<MinimumSpanningTree>,
+ selector: impl Fn(&TEdge) -> bool
+) -> MinimumSpanningTree
+where
+ TWeight: PartialOrd,
+ TEdge: Edge + WeightedEdge<Cost = TWeight>,
+ TGraph: Graph<Node = TNode, Edge = TEdge>
+{
+ // let separate_new_edges = initial.is_some();
+ let nodes = graph.nodes().count();
+ let mut initial_edge_selected = vec![false; graph.edges().count()];
+ let mut current = if let Some(initial) = initial {
+ for edge in &initial.edges {
+ initial_edge_selected[*edge] = true;
+ }
+
+ initial
+ } else {
+ MinimumSpanningTree {
+ components: UnionFind::make_set(nodes),
+ edges: Vec::with_capacity(nodes - 1),
+ }
+ };
+
+ let mut remaining_edges = graph.edges()
+ .enumerate()
+ .filter(|(i, e)| !initial_edge_selected[*i] && selector(e))
+ .collect::<Vec<_>>();
+
+ remaining_edges.sort_by(
+ |a, b| a.1.cost().partial_cmp(&b.1.cost()).unwrap_or(Ordering::Less)
+ );
+
+ for (i, edge) in remaining_edges {
+ // 1. does the edge connect two components?
+ let (root_a, root_b) = {
+ if current.components_count() == 1 {
+ break;
+ }
+
+ let root_a = current.components.find(edge.from_node());
+ let root_b = current.components.find(edge.to_node());
+
+ if root_a == root_b {
+ continue;
+ }
+
+ (root_a, root_b)
+ };
+ // 2. if so, use it and connect them
+ current.components.union(root_a, root_b);
+ current.edges.push(i);
+ }
+
+ current
+}
@@ 0,0 1,45 @@
+#[derive(Clone, Debug, PartialEq)]
+pub struct UnionFindElement {
+ idx: usize,
+ parent: usize,
+}
+
+/// A structure for representing components with nodes.
+#[derive(Clone, Debug, PartialEq)]
+pub struct UnionFind {
+ elements: Vec<UnionFindElement>
+}
+
+impl UnionFind {
+ pub fn make_set(len: usize) -> UnionFind {
+ UnionFind {
+ elements: (0..len).map(|idx|
+ UnionFindElement {
+ idx,
+ parent: idx
+ }).collect()
+ }
+ }
+
+ pub fn len(&self) -> usize {
+ self.elements.len()
+ }
+
+ /// Find component of a node.
+ pub fn find(&self, x: usize) -> usize {
+ let mut current = &self.elements[x];
+ while current.parent != current.idx {
+ current = &self.elements[current.parent];
+ }
+
+ current.idx
+ }
+
+ /// Connect two components.
+ pub fn union(&mut self, x: usize, y: usize) {
+ let x = self.find(x);
+ let y = self.find(y);
+
+ self.elements[y].parent = x;
+ }
+}