graph_shortest_path/astar.rs
1use graph_core::{Graph, GraphError, NodeId};
2use ordered_float::OrderedFloat;
3use std::cmp::Reverse;
4use std::collections::{BinaryHeap, HashMap};
5
/// One open-list entry for A\*: `(f_score, g_score, node)`.
///
/// Tuples compare field by field, so entries are ordered by `f_score` first
/// and `g_score` breaks ties. The [`Reverse`] wrapper inverts that ordering,
/// which makes a max-first [`BinaryHeap`] pop the smallest entry first.
type AStarHeapEntry = Reverse<(OrderedFloat<f64>, OrderedFloat<f64>, NodeId)>;
/// Binary heap of [`AStarHeapEntry`] values, used as the A\* open list.
type AStarHeap = BinaryHeap<AStarHeapEntry>;
8
9/// Finds the shortest path from `start` to `goal` using the A\* search
10/// algorithm with a caller-supplied heuristic function.
11///
12/// A\* extends Dijkstra by adding a heuristic `h(node)` that estimates the
13/// remaining cost to `goal`. The priority of a node is `g(node) + h(node)`
14/// where `g` is the known shortest distance from `start`. When the heuristic
15/// is **admissible** (never overestimates the true cost), A\* is guaranteed to
16/// find the optimal path.
17///
18/// # Parameters
19///
20/// - `graph` — any graph with `f64` weights.
21/// - `start` — source node.
22/// - `goal` — target node.
23/// - `h` — heuristic closure: `h(node) -> f64`. Must satisfy `h(node) ≤
24/// true_distance(node, goal)`. Pass `|_| 0.0` to degrade to Dijkstra.
25///
26/// # Returns
27///
28/// `Some((path, total_cost))` where `path[0] == start` and
29/// `path.last() == &goal`, or `None` if no path exists.
30///
31/// # Errors
32///
33/// Returns [`GraphError::NodeNotFound`] if `start` or `goal` is not in the
34/// graph.
35///
36/// # Complexity
37///
38/// O(E log V) with a good heuristic; degrades to O((V+E) log V) with `h=0`.
39///
40/// # Examples
41///
42/// ```
43/// use graph_core::{AdjacencyList, Graph, NodeId};
44/// use graph_shortest_path::astar;
45///
46/// // Simple grid: four nodes in a line.
47/// // 0 --1-- 1 --1-- 2 --1-- 3
48/// let mut g: AdjacencyList<u32> = AdjacencyList::directed();
49/// let n: Vec<_> = (0u32..4).map(|i| g.add_node(i)).collect();
50/// for i in 0..3 {
51/// g.add_edge(n[i], n[i + 1], 1.0).unwrap();
52/// }
53///
54/// // Heuristic: remaining index distance (admissible for unit-weight grid).
55/// let goal_idx = 3usize;
56/// let (path, cost) = astar(&g, n[0], n[3], |id| {
57/// (goal_idx as f64) - (id.index() as f64)
58/// })
59/// .unwrap()
60/// .unwrap();
61///
62/// assert_eq!(cost, 3.0);
63/// assert_eq!(path, n);
64/// ```
65pub fn astar<G, H>(
66 graph: &G,
67 start: NodeId,
68 goal: NodeId,
69 h: H,
70) -> Result<Option<(Vec<NodeId>, f64)>, GraphError>
71where
72 G: Graph<Weight = f64>,
73 H: Fn(NodeId) -> f64,
74{
75 if !graph.contains_node(start) {
76 return Err(GraphError::NodeNotFound(start));
77 }
78 if !graph.contains_node(goal) {
79 return Err(GraphError::NodeNotFound(goal));
80 }
81
82 // g_score[node] = known shortest distance from start.
83 let mut g_score: HashMap<NodeId, f64> = HashMap::new();
84 g_score.insert(start, 0.0);
85
86 // Parent map for path reconstruction.
87 let mut parents: HashMap<NodeId, NodeId> = HashMap::new();
88
89 // Min-heap ordered by f = g + h.
90 // Entries: Reverse((f_score, g_score, node)) — g_score is the tiebreaker.
91 let mut open: AStarHeap = AStarHeap::new();
92
93 let start_h = h(start);
94 open.push(Reverse((OrderedFloat(start_h), OrderedFloat(0.0), start)));
95
96 while let Some(Reverse((_, OrderedFloat(g), node))) = open.pop() {
97 // Goal reached.
98 if node == goal {
99 let path = rebuild_path(&parents, start, goal);
100 return Ok(Some((path, g)));
101 }
102
103 // Lazy deletion: skip if we have already settled this node with a
104 // lower g-score.
105 if let Some(&best_g) = g_score.get(&node) {
106 if g > best_g {
107 continue;
108 }
109 }
110
111 for (neighbour, &weight) in graph.neighbors(node) {
112 let tentative_g = g + weight;
113 let current_best = g_score.get(&neighbour).copied().unwrap_or(f64::INFINITY);
114
115 if tentative_g < current_best {
116 g_score.insert(neighbour, tentative_g);
117 parents.insert(neighbour, node);
118
119 let f = tentative_g + h(neighbour);
120 open.push(Reverse((
121 OrderedFloat(f),
122 OrderedFloat(tentative_g),
123 neighbour,
124 )));
125 }
126 }
127 }
128
129 // Open list exhausted without reaching goal.
130 Ok(None)
131}
132
133fn rebuild_path(parents: &HashMap<NodeId, NodeId>, start: NodeId, end: NodeId) -> Vec<NodeId> {
134 if start == end {
135 return vec![start];
136 }
137 let mut path = vec![end];
138 let mut current = end;
139 while let Some(&prev) = parents.get(¤t) {
140 path.push(prev);
141 if prev == start {
142 break;
143 }
144 current = prev;
145 }
146 path.reverse();
147 path
148}