graph_spanning/kruskal.rs
1use crate::DisjointSet;
2use graph_core::{Edge, Graph, NodeId};
3
4/// The result of a minimum spanning tree computation.
5#[derive(Debug, Clone)]
6pub struct SpanningTree {
7 /// The edges that form the MST, in the order they were added.
8 pub edges: Vec<Edge<f64>>,
9 /// Sum of all edge weights in the MST.
10 pub total_weight: f64,
11}
12
13/// Computes a **Minimum Spanning Tree** using Kruskal's algorithm.
14///
15/// Kruskal's sorts all edges by weight then greedily adds the cheapest edge
16/// that connects two previously disconnected components, using a
17/// [`DisjointSet`] to detect cycles in O(α(n)) per edge.
18///
19/// Works on both directed and undirected graphs. For directed graphs, edges
20/// are treated as undirected (the MST is of the underlying undirected graph).
21///
22/// # Returns
23///
24/// `Some(SpanningTree)` if the graph is connected (a spanning tree exists),
25/// or `None` if the graph is disconnected or empty.
26///
27/// # Complexity
28///
29/// O(E log E) dominated by sorting.
30///
31/// # Examples
32///
33/// ```
34/// use graph_core::{AdjacencyList, Graph};
35/// use graph_spanning::kruskal;
36///
37/// // 1 3
38/// // A --- B ----- C
39/// // \ /
40/// // ----2----
41/// let mut g: AdjacencyList<&str> = AdjacencyList::undirected();
42/// let a = g.add_node("A");
43/// let b = g.add_node("B");
44/// let c = g.add_node("C");
45/// g.add_edge(a, b, 1.0).unwrap();
46/// g.add_edge(b, c, 3.0).unwrap();
47/// g.add_edge(a, c, 2.0).unwrap();
48///
49/// let mst = kruskal(&g).unwrap();
50/// assert_eq!(mst.edges.len(), 2); // V-1 edges
51/// assert_eq!(mst.total_weight, 3.0); // cheapest: A-B(1) + A-C(2)
52/// ```
53pub fn kruskal<G>(graph: &G) -> Option<SpanningTree>
54where
55 G: Graph<Weight = f64>,
56{
57 let n = graph.node_count();
58 if n == 0 {
59 return None;
60 }
61
62 // Collect and sort all edges by weight.
63 let mut edges = graph.all_edges();
64 edges.sort_by(|a, b| {
65 a.weight
66 .partial_cmp(&b.weight)
67 .unwrap_or(std::cmp::Ordering::Equal)
68 });
69
70 // Map NodeId → contiguous index for DisjointSet.
71 let node_index = node_index_map(graph);
72
73 let mut ds = DisjointSet::new(n);
74 let mut mst_edges: Vec<Edge<f64>> = Vec::with_capacity(n - 1);
75 let mut total_weight = 0.0f64;
76
77 for edge in edges {
78 let u = node_index[&edge.source];
79 let v = node_index[&edge.target];
80
81 // Skip self-loops and edges within the same component.
82 if u == v || !ds.union(u, v) {
83 continue;
84 }
85
86 total_weight += edge.weight;
87 mst_edges.push(edge);
88
89 // A spanning tree has exactly V-1 edges.
90 if mst_edges.len() == n - 1 {
91 break;
92 }
93 }
94
95 // If we didn't collect V-1 edges the graph is disconnected.
96 if mst_edges.len() < n - 1 {
97 return None;
98 }
99
100 Some(SpanningTree {
101 edges: mst_edges,
102 total_weight,
103 })
104}
105
106/// Builds a map from [`NodeId`] to a contiguous `0..n` index for use with
107/// [`DisjointSet`].
108pub(crate) fn node_index_map<G: Graph>(graph: &G) -> std::collections::HashMap<NodeId, usize> {
109 graph.nodes().enumerate().map(|(i, id)| (id, i)).collect()
110}