graph_spanning/prim.rs
1use crate::kruskal::SpanningTree;
2use graph_core::{Edge, Graph, NodeId};
3use ordered_float::OrderedFloat;
4use std::cmp::Reverse;
5use std::collections::{BinaryHeap, HashMap, HashSet};
6
7/// Computes a **Minimum Spanning Tree** using Prim's algorithm.
8///
9/// Prim's grows the MST from an arbitrary starting node, always adding the
10/// minimum-weight edge that connects the current tree to a new node. It uses a
11/// binary min-heap (lazy deletion variant) similar to Dijkstra's algorithm.
12///
13/// Works on both directed and undirected graphs; for directed graphs the
14/// algorithm finds the MST of the underlying undirected graph by considering
15/// edges in both directions.
16///
17/// # Returns
18///
19/// `Some(SpanningTree)` if the graph is connected, or `None` if the graph is
20/// disconnected or has no nodes.
21///
22/// # Complexity
23///
24/// O((V + E) log V) with a binary heap.
25///
26/// # Examples
27///
28/// ```
29/// use graph_core::{AdjacencyList, Graph};
30/// use graph_spanning::prim;
31///
32/// let mut g: AdjacencyList<&str> = AdjacencyList::undirected();
33/// let a = g.add_node("A");
34/// let b = g.add_node("B");
35/// let c = g.add_node("C");
36/// g.add_edge(a, b, 1.0).unwrap();
37/// g.add_edge(b, c, 3.0).unwrap();
38/// g.add_edge(a, c, 2.0).unwrap();
39///
40/// let mst = prim(&g).unwrap();
41/// assert_eq!(mst.edges.len(), 2);
42/// assert_eq!(mst.total_weight, 3.0); // A-B(1) + A-C(2)
43/// ```
44pub fn prim<G>(graph: &G) -> Option<SpanningTree>
45where
46 G: Graph<Weight = f64>,
47{
48 let n = graph.node_count();
49 if n == 0 {
50 return None;
51 }
52
53 // Start from the first node in the graph.
54 let start = graph.nodes().next()?;
55
56 // key[node] = minimum edge weight connecting node to the current tree.
57 let mut key: HashMap<NodeId, f64> = graph.nodes().map(|n| (n, f64::INFINITY)).collect();
58 // parent_edge[node] = the edge (source, target, weight) that connects it.
59 let mut parent_edge: HashMap<NodeId, (NodeId, f64)> = HashMap::new();
60 let mut in_mst: HashSet<NodeId> = HashSet::new();
61
62 key.insert(start, 0.0);
63
64 // Min-heap: Reverse((key_value, node)).
65 let mut heap: BinaryHeap<Reverse<(OrderedFloat<f64>, NodeId)>> = BinaryHeap::new();
66 heap.push(Reverse((OrderedFloat(0.0), start)));
67
68 let mut mst_edges: Vec<Edge<f64>> = Vec::with_capacity(n - 1);
69 let mut total_weight = 0.0f64;
70
71 while let Some(Reverse((OrderedFloat(w), node))) = heap.pop() {
72 // Already committed this node to the MST.
73 if in_mst.contains(&node) {
74 continue;
75 }
76 // Lazy deletion: skip stale heap entries.
77 if w > *key.get(&node).unwrap_or(&f64::INFINITY) {
78 continue;
79 }
80
81 in_mst.insert(node);
82
83 // Record the edge that pulled this node into the MST (skip the root).
84 if let Some((parent, edge_w)) = parent_edge.get(&node) {
85 mst_edges.push(Edge::new(*parent, node, *edge_w));
86 total_weight += edge_w;
87 }
88
89 // Relax outgoing edges.
90 for (neighbour, &weight) in graph.neighbors(node) {
91 if !in_mst.contains(&neighbour)
92 && weight < *key.get(&neighbour).unwrap_or(&f64::INFINITY)
93 {
94 key.insert(neighbour, weight);
95 parent_edge.insert(neighbour, (node, weight));
96 heap.push(Reverse((OrderedFloat(weight), neighbour)));
97 }
98 }
99 }
100
101 // For undirected graphs all edges are visited; for directed we may miss
102 // some nodes. Check connectivity.
103 if in_mst.len() < n {
104 return None;
105 }
106
107 Some(SpanningTree {
108 edges: mst_edges,
109 total_weight,
110 })
111}