1use plotters::prelude::*;
17
18use crate::embeddings::Embeddings;
19use crate::error::{Result, Word2VecError};
20use crate::trainer::EpochStats;
21
22pub fn plot_loss_curve(history: &[EpochStats], output_path: &str) -> Result<()> {
29 if history.is_empty() {
30 return Err(Word2VecError::Plot("history is empty".to_string()));
31 }
32
33 let root = BitMapBackend::new(output_path, (900, 500)).into_drawing_area();
34 root.fill(&WHITE)
35 .map_err(|e| Word2VecError::Plot(e.to_string()))?;
36
37 let losses: Vec<f64> = history.iter().map(|s| s.avg_loss).collect();
38 let max_loss = losses.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
39 let min_loss = losses.iter().cloned().fold(f64::INFINITY, f64::min);
40 let padding = (max_loss - min_loss).max(0.1) * 0.1;
41
42 let mut chart = ChartBuilder::on(&root)
43 .caption("Word2Vec Training Loss", ("sans-serif", 28).into_font())
44 .margin(30)
45 .x_label_area_size(50)
46 .y_label_area_size(70)
47 .build_cartesian_2d(
48 1usize..history.len(),
49 (min_loss - padding)..(max_loss + padding),
50 )
51 .map_err(|e| Word2VecError::Plot(e.to_string()))?;
52
53 chart
54 .configure_mesh()
55 .x_desc("Epoch")
56 .y_desc("Average Loss")
57 .axis_desc_style(("sans-serif", 16))
58 .draw()
59 .map_err(|e| Word2VecError::Plot(e.to_string()))?;
60
61 chart
63 .draw_series(LineSeries::new(
64 history.iter().enumerate().map(|(i, s)| (i + 1, s.avg_loss)),
65 &BLUE,
66 ))
67 .map_err(|e| Word2VecError::Plot(e.to_string()))?
68 .label("avg loss")
69 .legend(|(x, y)| PathElement::new(vec![(x, y), (x + 20, y)], BLUE));
70
71 chart
73 .draw_series(
74 history
75 .iter()
76 .enumerate()
77 .map(|(i, s)| Circle::new((i + 1, s.avg_loss), 4, BLUE.filled())),
78 )
79 .map_err(|e| Word2VecError::Plot(e.to_string()))?;
80
81 chart
82 .configure_series_labels()
83 .border_style(BLACK)
84 .draw()
85 .map_err(|e| Word2VecError::Plot(e.to_string()))?;
86
87 root.present()
88 .map_err(|e| Word2VecError::Plot(e.to_string()))?;
89 Ok(())
90}
91
92pub fn plot_word_vectors_pca(emb: &Embeddings, n_words: usize, output_path: &str) -> Result<()> {
96 let n = n_words.min(emb.vocab_size());
97 if n < 2 {
98 return Err(Word2VecError::Plot(
99 "need at least 2 words to plot".to_string(),
100 ));
101 }
102
103 let words: Vec<&str> = emb
105 .vocab()
106 .idx2word
107 .iter()
108 .take(n)
109 .map(|s| s.as_str())
110 .collect();
111 let vectors: Vec<&[f32]> = words.iter().filter_map(|w| emb.get_vector(w)).collect();
112
113 let dim = vectors[0].len();
114 let count = vectors.len();
115
116 let mean: Vec<f64> = (0..dim)
118 .map(|d| vectors.iter().map(|v| v[d] as f64).sum::<f64>() / count as f64)
119 .collect();
120
121 let centered: Vec<Vec<f64>> = vectors
122 .iter()
123 .map(|v| (0..dim).map(|d| v[d] as f64 - mean[d]).collect())
124 .collect();
125
126 let pc1 = power_iteration(¢ered, dim, 30, 0);
128 let pc2 = power_iteration_deflated(¢ered, dim, 30, &pc1);
129
130 let projected: Vec<(f64, f64)> = centered
132 .iter()
133 .map(|v| (dot_f64(v, &pc1), dot_f64(v, &pc2)))
134 .collect();
135
136 let x_min = projected.iter().map(|p| p.0).fold(f64::INFINITY, f64::min);
137 let x_max = projected
138 .iter()
139 .map(|p| p.0)
140 .fold(f64::NEG_INFINITY, f64::max);
141 let y_min = projected.iter().map(|p| p.1).fold(f64::INFINITY, f64::min);
142 let y_max = projected
143 .iter()
144 .map(|p| p.1)
145 .fold(f64::NEG_INFINITY, f64::max);
146 let xpad = (x_max - x_min).max(0.1) * 0.15;
147 let ypad = (y_max - y_min).max(0.1) * 0.15;
148
149 let root = BitMapBackend::new(output_path, (1100, 700)).into_drawing_area();
150 root.fill(&WHITE)
151 .map_err(|e| Word2VecError::Plot(e.to_string()))?;
152
153 let mut chart = ChartBuilder::on(&root)
154 .caption(
155 "Word Vectors — PCA Projection",
156 ("sans-serif", 24).into_font(),
157 )
158 .margin(40)
159 .x_label_area_size(40)
160 .y_label_area_size(50)
161 .build_cartesian_2d(
162 (x_min - xpad)..(x_max + xpad),
163 (y_min - ypad)..(y_max + ypad),
164 )
165 .map_err(|e| Word2VecError::Plot(e.to_string()))?;
166
167 chart
168 .configure_mesh()
169 .x_desc("PC1")
170 .y_desc("PC2")
171 .draw()
172 .map_err(|e| Word2VecError::Plot(e.to_string()))?;
173
174 for (i, (&word, &(x, y))) in words.iter().zip(projected.iter()).enumerate() {
175 let color = Palette99::pick(i % 99);
176
177 chart
178 .draw_series(std::iter::once(Circle::new((x, y), 5, color.filled())))
179 .map_err(|e| Word2VecError::Plot(e.to_string()))?;
180
181 chart
182 .draw_series(std::iter::once(Text::new(
183 word.to_string(),
184 (x + xpad * 0.05, y),
185 ("sans-serif", 12).into_font(),
186 )))
187 .map_err(|e| Word2VecError::Plot(e.to_string()))?;
188 }
189
190 root.present()
191 .map_err(|e| Word2VecError::Plot(e.to_string()))?;
192 Ok(())
193}
194
/// Dot product of two slices; pairs beyond the shorter length are ignored.
fn dot_f64(a: &[f64], b: &[f64]) -> f64 {
    let mut acc = 0.0;
    for (x, y) in a.iter().zip(b.iter()) {
        acc += x * y;
    }
    acc
}
198
/// Euclidean (L2) norm of `v`.
fn norm_f64(v: &[f64]) -> f64 {
    v.iter().map(|x| x * x).sum::<f64>().sqrt()
}
202
/// Scale `v` in place to unit length; near-zero vectors (norm <= 1e-10)
/// are left untouched to avoid division blow-up.
fn normalize_f64(v: &mut [f64]) {
    let len = v.iter().map(|x| x * x).sum::<f64>().sqrt();
    if len > 1e-10 {
        v.iter_mut().for_each(|x| *x /= len);
    }
}
211
/// Dominant eigenvector of `data^T * data` via power iteration.
///
/// The starting direction is the row at `seed % data.len()`; the result is
/// unit length (or all-zero if every iterate had near-zero norm).
fn power_iteration(data: &[Vec<f64>], dim: usize, iters: usize, seed: usize) -> Vec<f64> {
    // Unit-normalize in place; skips near-zero vectors (norm <= 1e-10).
    let unit = |vec: &mut Vec<f64>| {
        let len = vec.iter().map(|x| x * x).sum::<f64>().sqrt();
        if len > 1e-10 {
            for x in vec.iter_mut() {
                *x /= len;
            }
        }
    };

    let start = &data[seed % data.len()];
    let mut direction: Vec<f64> = (0..dim).map(|d| start[d]).collect();
    unit(&mut direction);

    for _ in 0..iters {
        // proj[i] = row_i . direction
        let proj: Vec<f64> = data
            .iter()
            .map(|row| row.iter().zip(direction.iter()).map(|(a, b)| a * b).sum::<f64>())
            .collect();
        // next = data^T * proj, accumulated row by row (same order as before).
        let mut next = vec![0.0f64; dim];
        for (row, p) in data.iter().zip(proj.iter()) {
            for (acc, x) in next.iter_mut().zip(row.iter()) {
                *acc += p * x;
            }
        }
        unit(&mut next);
        direction = next;
    }
    direction
}
230
231fn power_iteration_deflated(data: &[Vec<f64>], dim: usize, iters: usize, pc1: &[f64]) -> Vec<f64> {
233 let deflated: Vec<Vec<f64>> = data
234 .iter()
235 .map(|row| {
236 let proj = dot_f64(row, pc1);
237 (0..dim).map(|d| row[d] - proj * pc1[d]).collect()
238 })
239 .collect();
240
241 power_iteration(&deflated, dim, iters, 1)
242}
243
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn plot_loss_curve_empty_errors() {
        // An empty history must be rejected, not plotted.
        assert!(plot_loss_curve(&[], "/tmp/test_empty.png").is_err());
    }

    #[test]
    fn plot_loss_curve_creates_file() {
        // Three epochs of decreasing loss; the chart should land on disk.
        let history = [
            EpochStats {
                epoch: 1,
                avg_loss: 2.5,
                learning_rate: 0.025,
                pairs_processed: 100,
                elapsed_secs: 0.5,
            },
            EpochStats {
                epoch: 2,
                avg_loss: 1.8,
                learning_rate: 0.020,
                pairs_processed: 100,
                elapsed_secs: 0.5,
            },
            EpochStats {
                epoch: 3,
                avg_loss: 1.2,
                learning_rate: 0.015,
                pairs_processed: 100,
                elapsed_secs: 0.5,
            },
        ];
        let out = "/tmp/word2vec_test_loss.png";
        plot_loss_curve(&history, out).unwrap();
        assert!(std::path::Path::new(out).exists());
    }

    #[test]
    fn pca_plot_creates_file() {
        use crate::{Config, Trainer};
        // Tiny synthetic corpus: 8 distinct tokens in overlapping windows.
        let corpus: Vec<String> = (0..50)
            .map(|i| format!("w{} w{} w{} w{}", i % 8, (i + 1) % 8, (i + 2) % 8, (i + 3) % 8))
            .collect();
        let mut trainer = Trainer::new(Config {
            epochs: 2,
            embedding_dim: 20,
            ..Config::default()
        });
        let emb = trainer.train(&corpus).unwrap();
        let out = "/tmp/word2vec_test_pca.png";
        plot_word_vectors_pca(&emb, 8, out).unwrap();
        assert!(std::path::Path::new(out).exists());
    }
}