graph_spanning/
prim.rs

1use crate::kruskal::SpanningTree;
2use graph_core::{Edge, Graph, NodeId};
3use ordered_float::OrderedFloat;
4use std::cmp::Reverse;
5use std::collections::{BinaryHeap, HashMap, HashSet};
6
7/// Computes a **Minimum Spanning Tree** using Prim's algorithm.
8///
9/// Prim's grows the MST from an arbitrary starting node, always adding the
10/// minimum-weight edge that connects the current tree to a new node. It uses a
11/// binary min-heap (lazy deletion variant) similar to Dijkstra's algorithm.
12///
13/// Works on both directed and undirected graphs; for directed graphs the
14/// algorithm finds the MST of the underlying undirected graph by considering
15/// edges in both directions.
16///
17/// # Returns
18///
19/// `Some(SpanningTree)` if the graph is connected, or `None` if the graph is
20/// disconnected or has no nodes.
21///
22/// # Complexity
23///
24/// O((V + E) log V) with a binary heap.
25///
26/// # Examples
27///
28/// ```
29/// use graph_core::{AdjacencyList, Graph};
30/// use graph_spanning::prim;
31///
32/// let mut g: AdjacencyList<&str> = AdjacencyList::undirected();
33/// let a = g.add_node("A");
34/// let b = g.add_node("B");
35/// let c = g.add_node("C");
36/// g.add_edge(a, b, 1.0).unwrap();
37/// g.add_edge(b, c, 3.0).unwrap();
38/// g.add_edge(a, c, 2.0).unwrap();
39///
40/// let mst = prim(&g).unwrap();
41/// assert_eq!(mst.edges.len(), 2);
42/// assert_eq!(mst.total_weight, 3.0); // A-B(1) + A-C(2)
43/// ```
44pub fn prim<G>(graph: &G) -> Option<SpanningTree>
45where
46    G: Graph<Weight = f64>,
47{
48    let n = graph.node_count();
49    if n == 0 {
50        return None;
51    }
52
53    // Start from the first node in the graph.
54    let start = graph.nodes().next()?;
55
56    // key[node] = minimum edge weight connecting node to the current tree.
57    let mut key: HashMap<NodeId, f64> = graph.nodes().map(|n| (n, f64::INFINITY)).collect();
58    // parent_edge[node] = the edge (source, target, weight) that connects it.
59    let mut parent_edge: HashMap<NodeId, (NodeId, f64)> = HashMap::new();
60    let mut in_mst: HashSet<NodeId> = HashSet::new();
61
62    key.insert(start, 0.0);
63
64    // Min-heap: Reverse((key_value, node)).
65    let mut heap: BinaryHeap<Reverse<(OrderedFloat<f64>, NodeId)>> = BinaryHeap::new();
66    heap.push(Reverse((OrderedFloat(0.0), start)));
67
68    let mut mst_edges: Vec<Edge<f64>> = Vec::with_capacity(n - 1);
69    let mut total_weight = 0.0f64;
70
71    while let Some(Reverse((OrderedFloat(w), node))) = heap.pop() {
72        // Already committed this node to the MST.
73        if in_mst.contains(&node) {
74            continue;
75        }
76        // Lazy deletion: skip stale heap entries.
77        if w > *key.get(&node).unwrap_or(&f64::INFINITY) {
78            continue;
79        }
80
81        in_mst.insert(node);
82
83        // Record the edge that pulled this node into the MST (skip the root).
84        if let Some((parent, edge_w)) = parent_edge.get(&node) {
85            mst_edges.push(Edge::new(*parent, node, *edge_w));
86            total_weight += edge_w;
87        }
88
89        // Relax outgoing edges.
90        for (neighbour, &weight) in graph.neighbors(node) {
91            if !in_mst.contains(&neighbour)
92                && weight < *key.get(&neighbour).unwrap_or(&f64::INFINITY)
93            {
94                key.insert(neighbour, weight);
95                parent_edge.insert(neighbour, (node, weight));
96                heap.push(Reverse((OrderedFloat(weight), neighbour)));
97            }
98        }
99    }
100
101    // For undirected graphs all edges are visited; for directed we may miss
102    // some nodes. Check connectivity.
103    if in_mst.len() < n {
104        return None;
105    }
106
107    Some(SpanningTree {
108        edges: mst_edges,
109        total_weight,
110    })
111}