// word2vec/trainer.rs
//! Training loop with progress monitoring, learning rate decay, and
//! optional checkpointing.
//!
//! # Training Flow
//!
//! ```text
//! corpus ──► Vocabulary::build ──► Model::new
//!                                       │
//!              ┌────────────────────────┘
//!              ▼
//!         for each epoch:
//!           shuffle sentences
//!           for each sentence:
//!             subsample tokens
//!             for each (center, context) pair:
//!               sample negatives
//!               Model::update (SGD step)
//!               update LR (linear decay)
//!           record epoch loss
//!              │
//!              ▼
//!         Embeddings ──► .most_similar() / .analogy() / .save()
//! ```
24
use std::path::Path;

use indicatif::{ProgressBar, ProgressStyle};
use log::info;
use rand::rngs::SmallRng;
use rand::seq::SliceRandom;
use rand::SeedableRng;

use crate::config::Config;
use crate::embeddings::Embeddings;
use crate::error::{Result, Word2VecError};
use crate::model::{sentence_to_pairs, Model};
use crate::vocab::Vocabulary;
37
/// Recorded statistics for one epoch.
#[derive(Debug, Clone)]
pub struct EpochStats {
    // 1-based epoch index (the training loop pushes `epoch + 1`).
    pub epoch: usize,
    // Mean per-pair loss across the epoch; 0.0 when no pairs were processed.
    pub avg_loss: f64,
    // Learning rate reported at the end of the epoch.
    pub learning_rate: f32,
    // Number of (center, context) SGD updates performed this epoch.
    pub pairs_processed: u64,
    // Wall-clock duration of the epoch, in seconds.
    pub elapsed_secs: f64,
}
47
/// Manages the full training pipeline.
///
/// Owns the [`Config`] and accumulates one [`EpochStats`] entry per epoch
/// in `history` while `train` runs.
pub struct Trainer {
    config: Config,
    /// Loss history per epoch (populated during `train`).
    pub history: Vec<EpochStats>,
}
54
55impl Trainer {
56    /// Create a new trainer with the given configuration.
57    ///
58    /// ```rust
59    /// use word2vec::{Config, Trainer};
60    /// let trainer = Trainer::new(Config::default());
61    /// assert!(trainer.history.is_empty());
62    /// ```
63    pub fn new(config: Config) -> Self {
64        Self {
65            config,
66            history: Vec::new(),
67        }
68    }
69
70    /// Tokenise raw text into sentences (split on whitespace).
71    fn tokenise(corpus: &[String]) -> Vec<Vec<String>> {
72        corpus
73            .iter()
74            .filter(|s| !s.trim().is_empty())
75            .map(|s| s.split_whitespace().map(str::to_string).collect())
76            .collect()
77    }
78
79    /// Train on a corpus of sentences.
80    ///
81    /// Returns [`Embeddings`] which wraps the trained model and vocabulary
82    /// for inference queries.
83    ///
84    /// # Errors
85    ///
86    /// Returns [`Word2VecError`] if the config is invalid or the corpus
87    /// is too small to train.
88    ///
89    /// # Example
90    ///
91    /// ```rust,no_run
92    /// use word2vec::{Config, ModelType, Trainer};
93    ///
94    /// let corpus = vec![
95    ///     "paris is the capital of france".to_string(),
96    ///     "berlin is the capital of germany".to_string(),
97    ///     "tokyo is the capital of japan".to_string(),
98    /// ];
99    ///
100    /// let mut trainer = Trainer::new(Config {
101    ///     epochs: 3,
102    ///     embedding_dim: 50,
103    ///     ..Config::default()
104    /// });
105    ///
106    /// let embeddings = trainer.train(&corpus).unwrap();
107    /// println!("Vocab size: {}", embeddings.vocab_size());
108    /// ```
109    pub fn train(&mut self, corpus: &[String]) -> Result<Embeddings> {
110        self.config
111            .validate()
112            .map_err(|_e| Word2VecError::EmptyVocabulary)?;
113
114        let sentences = Self::tokenise(corpus);
115        let vocab = Vocabulary::build(&sentences, &self.config)?;
116
117        info!(
118            "Vocabulary: {} unique tokens ({} total), model: {}",
119            vocab.len(),
120            vocab.total_tokens,
121            self.config.model
122        );
123
124        let mut model = Model::new(vocab.len(), self.config.embedding_dim, self.config.seed);
125        let mut rng = SmallRng::seed_from_u64(self.config.seed);
126
127        let total_pairs_estimate = self.estimate_pairs(&sentences, &vocab);
128        let lr_start = self.config.learning_rate;
129        let lr_min = self.config.min_learning_rate;
130        let epochs = self.config.epochs;
131
132        // Grand total steps for LR decay
133        let total_steps = total_pairs_estimate * epochs as u64;
134        let mut global_step: u64 = 0;
135
136        for epoch in 0..epochs {
137            let t_start = std::time::Instant::now();
138            let mut epoch_loss = 0.0_f64;
139            let mut epoch_pairs: u64 = 0;
140
141            // Shuffle sentence order each epoch
142            let mut indices: Vec<usize> = (0..sentences.len()).collect();
143            indices.shuffle(&mut rng);
144
145            let pb = self.make_progress_bar(epoch, sentences.len());
146
147            for &sent_idx in &indices {
148                let tokens = vocab.tokenise_and_subsample(
149                    &sentences[sent_idx],
150                    self.config.subsample_threshold,
151                    &mut rng,
152                );
153
154                if tokens.len() < 2 {
155                    pb.inc(1);
156                    continue;
157                }
158
159                let pairs = sentence_to_pairs(&tokens, self.config.window_size, &mut rng);
160
161                for (center, context_words) in pairs {
162                    // Linear LR decay
163                    let progress = global_step as f32 / total_steps.max(1) as f32;
164                    let lr = (lr_start - (lr_start - lr_min) * progress).max(lr_min);
165
166                    // Sample negatives
167                    let negatives: Vec<usize> = (0..self.config.negative_samples)
168                        .map(|_| vocab.negative_sample(&mut rng))
169                        .collect();
170
171                    let loss =
172                        model.update(self.config.model, center, &context_words, &negatives, lr);
173
174                    epoch_loss += loss as f64;
175                    epoch_pairs += 1;
176                    global_step += 1;
177                }
178
179                pb.inc(1);
180            }
181
182            pb.finish_and_clear();
183
184            let lr_now =
185                (lr_start - (lr_start - lr_min) * (epoch + 1) as f32 / epochs as f32).max(lr_min);
186            let avg_loss = if epoch_pairs > 0 {
187                epoch_loss / epoch_pairs as f64
188            } else {
189                0.0
190            };
191            let elapsed = t_start.elapsed().as_secs_f64();
192
193            let stats = EpochStats {
194                epoch: epoch + 1,
195                avg_loss,
196                learning_rate: lr_now,
197                pairs_processed: epoch_pairs,
198                elapsed_secs: elapsed,
199            };
200
201            info!(
202                "Epoch {}/{} | loss: {:.4} | lr: {:.5} | pairs: {} | {:.1}s",
203                stats.epoch,
204                epochs,
205                stats.avg_loss,
206                stats.learning_rate,
207                stats.pairs_processed,
208                stats.elapsed_secs
209            );
210
211            self.history.push(stats);
212        }
213
214        Ok(Embeddings::new(model, vocab, self.config.clone()))
215    }
216
217    /// Save training history as JSON.
218    ///
219    /// ```rust,no_run
220    /// use word2vec::{Config, Trainer};
221    /// let mut trainer = Trainer::new(Config::default());
222    /// // trainer.train(...);
223    /// // trainer.save_history("history.json").unwrap();
224    /// ```
225    pub fn save_history(&self, path: impl AsRef<Path>) -> Result<()> {
226        let records: Vec<serde_json::Value> = self
227            .history
228            .iter()
229            .map(|s| {
230                serde_json::json!({
231                    "epoch": s.epoch,
232                    "avg_loss": s.avg_loss,
233                    "learning_rate": s.learning_rate,
234                    "pairs_processed": s.pairs_processed,
235                    "elapsed_secs": s.elapsed_secs,
236                })
237            })
238            .collect();
239
240        let json = serde_json::to_string_pretty(&records)?;
241        std::fs::write(path, json)?;
242        Ok(())
243    }
244
245    /// Rough pair count for LR scheduling (doesn't account for subsampling).
246    fn estimate_pairs(&self, sentences: &[Vec<String>], vocab: &Vocabulary) -> u64 {
247        sentences
248            .iter()
249            .map(|s| {
250                let len = s.iter().filter(|w| vocab.word2idx.contains_key(*w)).count() as u64;
251                len.saturating_sub(1) * self.config.window_size as u64 * 2
252            })
253            .sum()
254    }
255
256    fn make_progress_bar(&self, epoch: usize, total: usize) -> ProgressBar {
257        let pb = ProgressBar::new(total as u64);
258        pb.set_style(
259            ProgressStyle::with_template(
260                "{prefix:.bold} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} {msg}",
261            )
262            .unwrap()
263            .progress_chars("=>-"),
264        );
265        pb.set_prefix(format!("Epoch {:>3}/{}", epoch + 1, self.config.epochs));
266        pb
267    }
268}
269
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::ModelType;

    /// Four short sentences with overlapping vocabulary — enough shared
    /// tokens for the vocabulary builder and pair generator to produce work.
    fn tiny_corpus() -> Vec<String> {
        [
            "the quick brown fox",
            "the lazy dog sleeps",
            "fox and dog are animals",
            "quick animals run fast",
        ]
        .iter()
        .map(|s| s.to_string())
        .collect()
    }

    #[test]
    fn training_runs_without_panic() {
        let config = Config {
            epochs: 2,
            embedding_dim: 10,
            ..Config::default()
        };
        let outcome = Trainer::new(config).train(&tiny_corpus());
        assert!(outcome.is_ok(), "{:?}", outcome);
    }

    #[test]
    fn history_records_all_epochs() {
        let epochs = 3;
        let mut trainer = Trainer::new(Config {
            epochs,
            embedding_dim: 8,
            ..Config::default()
        });
        trainer.train(&tiny_corpus()).unwrap();
        assert_eq!(trainer.history.len(), epochs);
    }

    #[test]
    fn loss_is_finite() {
        let mut trainer = Trainer::new(Config {
            epochs: 2,
            embedding_dim: 8,
            ..Config::default()
        });
        trainer.train(&tiny_corpus()).unwrap();
        // Every recorded epoch loss must be a real number (no NaN/inf).
        for stats in &trainer.history {
            assert!(
                stats.avg_loss.is_finite(),
                "epoch {} loss={}",
                stats.epoch,
                stats.avg_loss
            );
        }
    }

    #[test]
    fn cbow_training_runs() {
        let config = Config {
            epochs: 2,
            embedding_dim: 8,
            model: ModelType::Cbow,
            ..Config::default()
        };
        assert!(Trainer::new(config).train(&tiny_corpus()).is_ok());
    }

    #[test]
    fn empty_corpus_returns_error() {
        let empty: Vec<String> = Vec::new();
        assert!(Trainer::new(Config::default()).train(&empty).is_err());
    }
}
341}