Skip to main content

word2vec/
plot.rs

//! Training visualisation: loss curves and 2D PCA word projection plots.
//!
//! Uses [`plotters`] to render PNG files.
//!
//! # Example
//!
//! ```rust,no_run
//! use word2vec::plot::{plot_loss_curve, plot_word_vectors_pca};
//! use word2vec::trainer::EpochStats;
//!
//! // After training:
//! // plot_loss_curve(&trainer.history, "loss.png").unwrap();
//! // plot_word_vectors_pca(&embeddings, 50, "pca.png").unwrap();
//! ```
15
16use plotters::prelude::*;
17
18use crate::embeddings::Embeddings;
19use crate::error::{Result, Word2VecError};
20use crate::trainer::EpochStats;
21
22/// Render a loss-vs-epoch line chart to a PNG file.
23///
24/// # Arguments
25///
26/// * `history` — Slice of per-epoch statistics from [`Trainer::history`].
27/// * `output_path` — Destination PNG path (created or overwritten).
28pub fn plot_loss_curve(history: &[EpochStats], output_path: &str) -> Result<()> {
29    if history.is_empty() {
30        return Err(Word2VecError::Plot("history is empty".to_string()));
31    }
32
33    let root = BitMapBackend::new(output_path, (900, 500)).into_drawing_area();
34    root.fill(&WHITE)
35        .map_err(|e| Word2VecError::Plot(e.to_string()))?;
36
37    let losses: Vec<f64> = history.iter().map(|s| s.avg_loss).collect();
38    let max_loss = losses.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
39    let min_loss = losses.iter().cloned().fold(f64::INFINITY, f64::min);
40    let padding = (max_loss - min_loss).max(0.1) * 0.1;
41
42    let mut chart = ChartBuilder::on(&root)
43        .caption("Word2Vec Training Loss", ("sans-serif", 28).into_font())
44        .margin(30)
45        .x_label_area_size(50)
46        .y_label_area_size(70)
47        .build_cartesian_2d(
48            1usize..history.len(),
49            (min_loss - padding)..(max_loss + padding),
50        )
51        .map_err(|e| Word2VecError::Plot(e.to_string()))?;
52
53    chart
54        .configure_mesh()
55        .x_desc("Epoch")
56        .y_desc("Average Loss")
57        .axis_desc_style(("sans-serif", 16))
58        .draw()
59        .map_err(|e| Word2VecError::Plot(e.to_string()))?;
60
61    // Line
62    chart
63        .draw_series(LineSeries::new(
64            history.iter().enumerate().map(|(i, s)| (i + 1, s.avg_loss)),
65            &BLUE,
66        ))
67        .map_err(|e| Word2VecError::Plot(e.to_string()))?
68        .label("avg loss")
69        .legend(|(x, y)| PathElement::new(vec![(x, y), (x + 20, y)], BLUE));
70
71    // Dots at each epoch
72    chart
73        .draw_series(
74            history
75                .iter()
76                .enumerate()
77                .map(|(i, s)| Circle::new((i + 1, s.avg_loss), 4, BLUE.filled())),
78        )
79        .map_err(|e| Word2VecError::Plot(e.to_string()))?;
80
81    chart
82        .configure_series_labels()
83        .border_style(BLACK)
84        .draw()
85        .map_err(|e| Word2VecError::Plot(e.to_string()))?;
86
87    root.present()
88        .map_err(|e| Word2VecError::Plot(e.to_string()))?;
89    Ok(())
90}
91
92/// Project the top-`n_words` most frequent words into 2D using PCA
93/// (covariance-free: uses power iteration on two principal components)
94/// and render a scatter plot with word labels.
95pub fn plot_word_vectors_pca(emb: &Embeddings, n_words: usize, output_path: &str) -> Result<()> {
96    let n = n_words.min(emb.vocab_size());
97    if n < 2 {
98        return Err(Word2VecError::Plot(
99            "need at least 2 words to plot".to_string(),
100        ));
101    }
102
103    // Collect vectors (top-n by vocab order = most frequent due to sorted vocab)
104    let words: Vec<&str> = emb
105        .vocab()
106        .idx2word
107        .iter()
108        .take(n)
109        .map(|s| s.as_str())
110        .collect();
111    let vectors: Vec<&[f32]> = words.iter().filter_map(|w| emb.get_vector(w)).collect();
112
113    let dim = vectors[0].len();
114    let count = vectors.len();
115
116    // Center the data
117    let mean: Vec<f64> = (0..dim)
118        .map(|d| vectors.iter().map(|v| v[d] as f64).sum::<f64>() / count as f64)
119        .collect();
120
121    let centered: Vec<Vec<f64>> = vectors
122        .iter()
123        .map(|v| (0..dim).map(|d| v[d] as f64 - mean[d]).collect())
124        .collect();
125
126    // Power iteration for top-2 PCs
127    let pc1 = power_iteration(&centered, dim, 30, 0);
128    let pc2 = power_iteration_deflated(&centered, dim, 30, &pc1);
129
130    // Project
131    let projected: Vec<(f64, f64)> = centered
132        .iter()
133        .map(|v| (dot_f64(v, &pc1), dot_f64(v, &pc2)))
134        .collect();
135
136    let x_min = projected.iter().map(|p| p.0).fold(f64::INFINITY, f64::min);
137    let x_max = projected
138        .iter()
139        .map(|p| p.0)
140        .fold(f64::NEG_INFINITY, f64::max);
141    let y_min = projected.iter().map(|p| p.1).fold(f64::INFINITY, f64::min);
142    let y_max = projected
143        .iter()
144        .map(|p| p.1)
145        .fold(f64::NEG_INFINITY, f64::max);
146    let xpad = (x_max - x_min).max(0.1) * 0.15;
147    let ypad = (y_max - y_min).max(0.1) * 0.15;
148
149    let root = BitMapBackend::new(output_path, (1100, 700)).into_drawing_area();
150    root.fill(&WHITE)
151        .map_err(|e| Word2VecError::Plot(e.to_string()))?;
152
153    let mut chart = ChartBuilder::on(&root)
154        .caption(
155            "Word Vectors — PCA Projection",
156            ("sans-serif", 24).into_font(),
157        )
158        .margin(40)
159        .x_label_area_size(40)
160        .y_label_area_size(50)
161        .build_cartesian_2d(
162            (x_min - xpad)..(x_max + xpad),
163            (y_min - ypad)..(y_max + ypad),
164        )
165        .map_err(|e| Word2VecError::Plot(e.to_string()))?;
166
167    chart
168        .configure_mesh()
169        .x_desc("PC1")
170        .y_desc("PC2")
171        .draw()
172        .map_err(|e| Word2VecError::Plot(e.to_string()))?;
173
174    for (i, (&word, &(x, y))) in words.iter().zip(projected.iter()).enumerate() {
175        let color = Palette99::pick(i % 99);
176
177        chart
178            .draw_series(std::iter::once(Circle::new((x, y), 5, color.filled())))
179            .map_err(|e| Word2VecError::Plot(e.to_string()))?;
180
181        chart
182            .draw_series(std::iter::once(Text::new(
183                word.to_string(),
184                (x + xpad * 0.05, y),
185                ("sans-serif", 12).into_font(),
186            )))
187            .map_err(|e| Word2VecError::Plot(e.to_string()))?;
188    }
189
190    root.present()
191        .map_err(|e| Word2VecError::Plot(e.to_string()))?;
192    Ok(())
193}
194
/// Dot product of two `f64` slices (truncates to the shorter length).
fn dot_f64(a: &[f64], b: &[f64]) -> f64 {
    let mut acc = 0.0;
    for (&x, &y) in a.iter().zip(b) {
        acc += x * y;
    }
    acc
}
198
/// Euclidean (L2) norm of a slice: sqrt of the sum of squares.
fn norm_f64(v: &[f64]) -> f64 {
    v.iter().map(|x| x * x).sum::<f64>().sqrt()
}
202
/// Scale `v` in place to unit length; near-zero vectors are left untouched
/// to avoid dividing by (almost) zero.
fn normalize_f64(v: &mut [f64]) {
    let n = v.iter().map(|x| x * x).sum::<f64>().sqrt();
    if n > 1e-10 {
        v.iter_mut().for_each(|x| *x /= n);
    }
}
211
/// Power iteration to find the dominant eigenvector of X^T X.
///
/// `data` holds the rows of the (centered) matrix X. The starting
/// direction is taken from row `seed % data.len()`, so the result is
/// fully deterministic.
fn power_iteration(data: &[Vec<f64>], dim: usize, iters: usize, seed: usize) -> Vec<f64> {
    // Unit-normalize in place; near-zero vectors are left untouched.
    fn unit(v: &mut [f64]) {
        let n = v.iter().map(|x| x * x).sum::<f64>().sqrt();
        if n > 1e-10 {
            for x in v.iter_mut() {
                *x /= n;
            }
        }
    }

    // Deterministic start: the first `dim` components of a data row.
    let mut v: Vec<f64> = data[seed % data.len()][..dim].to_vec();
    unit(&mut v);

    for _ in 0..iters {
        // w = X^T (X v), accumulated row by row.
        let mut w = vec![0.0f64; dim];
        for row in data {
            let proj: f64 = row.iter().zip(&v).map(|(r, x)| r * x).sum();
            for (wd, &rd) in w.iter_mut().zip(row.iter()) {
                *wd += proj * rd;
            }
        }
        unit(&mut w);
        v = w;
    }
    v
}
230
231/// Find second PC by deflating the first.
232fn power_iteration_deflated(data: &[Vec<f64>], dim: usize, iters: usize, pc1: &[f64]) -> Vec<f64> {
233    let deflated: Vec<Vec<f64>> = data
234        .iter()
235        .map(|row| {
236            let proj = dot_f64(row, pc1);
237            (0..dim).map(|d| row[d] - proj * pc1[d]).collect()
238        })
239        .collect();
240
241    power_iteration(&deflated, dim, iters, 1)
242}
243
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn plot_loss_curve_empty_errors() {
        assert!(plot_loss_curve(&[], "/tmp/test_empty.png").is_err());
    }

    #[test]
    fn plot_loss_curve_creates_file() {
        // (epoch, avg_loss, learning_rate) triples; the remaining fields
        // are identical across epochs.
        let history: Vec<EpochStats> = vec![(1, 2.5, 0.025), (2, 1.8, 0.020), (3, 1.2, 0.015)]
            .into_iter()
            .map(|(epoch, avg_loss, learning_rate)| EpochStats {
                epoch,
                avg_loss,
                learning_rate,
                pairs_processed: 100,
                elapsed_secs: 0.5,
            })
            .collect();
        let path = "/tmp/word2vec_test_loss.png";
        plot_loss_curve(&history, path).unwrap();
        assert!(std::path::Path::new(path).exists());
    }

    #[test]
    fn pca_plot_creates_file() {
        use crate::{Config, Trainer};
        // Tiny synthetic corpus over an 8-word rotating vocabulary.
        let mut corpus: Vec<String> = Vec::with_capacity(50);
        for i in 0..50 {
            corpus.push(format!(
                "w{} w{} w{} w{}",
                i % 8,
                (i + 1) % 8,
                (i + 2) % 8,
                (i + 3) % 8
            ));
        }
        let config = Config {
            epochs: 2,
            embedding_dim: 20,
            ..Config::default()
        };
        let mut trainer = Trainer::new(config);
        let emb = trainer.train(&corpus).unwrap();
        let path = "/tmp/word2vec_test_pca.png";
        plot_word_vectors_pca(&emb, 8, path).unwrap();
        assert!(std::path::Path::new(path).exists());
    }
}
308}