graph_spanning/
kruskal.rs

1use crate::DisjointSet;
2use graph_core::{Edge, Graph, NodeId};
3
4/// The result of a minimum spanning tree computation.
5#[derive(Debug, Clone)]
6pub struct SpanningTree {
7    /// The edges that form the MST, in the order they were added.
8    pub edges: Vec<Edge<f64>>,
9    /// Sum of all edge weights in the MST.
10    pub total_weight: f64,
11}
12
13/// Computes a **Minimum Spanning Tree** using Kruskal's algorithm.
14///
15/// Kruskal's sorts all edges by weight then greedily adds the cheapest edge
16/// that connects two previously disconnected components, using a
17/// [`DisjointSet`] to detect cycles in O(α(n)) per edge.
18///
19/// Works on both directed and undirected graphs. For directed graphs, edges
20/// are treated as undirected (the MST is of the underlying undirected graph).
21///
22/// # Returns
23///
24/// `Some(SpanningTree)` if the graph is connected (a spanning tree exists),
25/// or `None` if the graph is disconnected or empty.
26///
27/// # Complexity
28///
29/// O(E log E) dominated by sorting.
30///
31/// # Examples
32///
33/// ```
34/// use graph_core::{AdjacencyList, Graph};
35/// use graph_spanning::kruskal;
36///
37/// //   1       3
38/// // A --- B ----- C
39/// //  \         /
40/// //   ----2----
41/// let mut g: AdjacencyList<&str> = AdjacencyList::undirected();
42/// let a = g.add_node("A");
43/// let b = g.add_node("B");
44/// let c = g.add_node("C");
45/// g.add_edge(a, b, 1.0).unwrap();
46/// g.add_edge(b, c, 3.0).unwrap();
47/// g.add_edge(a, c, 2.0).unwrap();
48///
49/// let mst = kruskal(&g).unwrap();
50/// assert_eq!(mst.edges.len(), 2);    // V-1 edges
51/// assert_eq!(mst.total_weight, 3.0); // cheapest: A-B(1) + A-C(2)
52/// ```
53pub fn kruskal<G>(graph: &G) -> Option<SpanningTree>
54where
55    G: Graph<Weight = f64>,
56{
57    let n = graph.node_count();
58    if n == 0 {
59        return None;
60    }
61
62    // Collect and sort all edges by weight.
63    let mut edges = graph.all_edges();
64    edges.sort_by(|a, b| {
65        a.weight
66            .partial_cmp(&b.weight)
67            .unwrap_or(std::cmp::Ordering::Equal)
68    });
69
70    // Map NodeId → contiguous index for DisjointSet.
71    let node_index = node_index_map(graph);
72
73    let mut ds = DisjointSet::new(n);
74    let mut mst_edges: Vec<Edge<f64>> = Vec::with_capacity(n - 1);
75    let mut total_weight = 0.0f64;
76
77    for edge in edges {
78        let u = node_index[&edge.source];
79        let v = node_index[&edge.target];
80
81        // Skip self-loops and edges within the same component.
82        if u == v || !ds.union(u, v) {
83            continue;
84        }
85
86        total_weight += edge.weight;
87        mst_edges.push(edge);
88
89        // A spanning tree has exactly V-1 edges.
90        if mst_edges.len() == n - 1 {
91            break;
92        }
93    }
94
95    // If we didn't collect V-1 edges the graph is disconnected.
96    if mst_edges.len() < n - 1 {
97        return None;
98    }
99
100    Some(SpanningTree {
101        edges: mst_edges,
102        total_weight,
103    })
104}
105
106/// Builds a map from [`NodeId`] to a contiguous `0..n` index for use with
107/// [`DisjointSet`].
108pub(crate) fn node_index_map<G: Graph>(graph: &G) -> std::collections::HashMap<NodeId, usize> {
109    graph.nodes().enumerate().map(|(i, id)| (id, i)).collect()
110}