1use indicatif::{ProgressBar, ProgressStyle};
26use log::info;
27use rand::rngs::SmallRng;
28use rand::seq::SliceRandom;
29use rand::SeedableRng;
30use std::path::Path;
31
32use crate::config::Config;
33use crate::embeddings::Embeddings;
34use crate::error::{Result, Word2VecError};
35use crate::model::{sentence_to_pairs, Model};
36use crate::vocab::Vocabulary;
37
/// Per-epoch training metrics recorded by [`Trainer::train`].
#[derive(Debug, Clone)]
pub struct EpochStats {
    /// 1-based epoch index.
    pub epoch: usize,
    /// Mean loss over all training pairs processed this epoch (0.0 if none).
    pub avg_loss: f64,
    /// Learning rate at the end of this epoch, per the linear decay schedule.
    pub learning_rate: f32,
    /// Number of (center, context) pairs the model was updated on this epoch.
    pub pairs_processed: u64,
    /// Wall-clock duration of the epoch in seconds.
    pub elapsed_secs: f64,
}
47
/// Drives word2vec training: builds the vocabulary from a corpus, runs the
/// epoch loop, and records per-epoch statistics.
pub struct Trainer {
    config: Config,
    /// Stats for each completed epoch in order; appended to by `train`.
    pub history: Vec<EpochStats>,
}
54
impl Trainer {
    /// Creates a trainer for `config` with an empty training history.
    pub fn new(config: Config) -> Self {
        Self {
            config,
            history: Vec::new(),
        }
    }

    /// Splits each corpus line on whitespace into a token list, discarding
    /// lines that are empty or whitespace-only.
    fn tokenise(corpus: &[String]) -> Vec<Vec<String>> {
        corpus
            .iter()
            .filter(|s| !s.trim().is_empty())
            .map(|s| s.split_whitespace().map(str::to_string).collect())
            .collect()
    }

    /// Runs the full training loop over `corpus` and returns the learned
    /// embeddings.
    ///
    /// Per epoch: shuffle sentence order, subsample each sentence against
    /// the vocabulary, generate (center, context) pairs, and apply one model
    /// update per pair with negative sampling and a linearly decaying
    /// learning rate. One [`EpochStats`] is pushed onto `self.history` per
    /// epoch.
    ///
    /// # Errors
    ///
    /// Returns an error if the configuration fails validation or the
    /// vocabulary cannot be built (e.g. the corpus is empty).
    pub fn train(&mut self, corpus: &[String]) -> Result<Embeddings> {
        // NOTE(review): the concrete validation error is discarded here and
        // reported as `EmptyVocabulary`, which can mislead callers about the
        // real cause — consider a dedicated config-error variant.
        self.config
            .validate()
            .map_err(|_e| Word2VecError::EmptyVocabulary)?;

        let sentences = Self::tokenise(corpus);
        let vocab = Vocabulary::build(&sentences, &self.config)?;

        info!(
            "Vocabulary: {} unique tokens ({} total), model: {}",
            vocab.len(),
            vocab.total_tokens,
            self.config.model
        );

        let mut model = Model::new(vocab.len(), self.config.embedding_dim, self.config.seed);
        // One seeded RNG shared by shuffling, subsampling, pair generation,
        // and negative sampling, so runs are reproducible for a fixed seed.
        // (The exact draw order below is part of that reproducibility.)
        let mut rng = SmallRng::seed_from_u64(self.config.seed);

        // Rough pair count, used only to scale the learning-rate decay.
        let total_pairs_estimate = self.estimate_pairs(&sentences, &vocab);
        let lr_start = self.config.learning_rate;
        let lr_min = self.config.min_learning_rate;
        let epochs = self.config.epochs;

        let total_steps = total_pairs_estimate * epochs as u64;
        let mut global_step: u64 = 0;

        for epoch in 0..epochs {
            let t_start = std::time::Instant::now();
            let mut epoch_loss = 0.0_f64;
            let mut epoch_pairs: u64 = 0;

            // Visit sentences in a fresh random order each epoch.
            let mut indices: Vec<usize> = (0..sentences.len()).collect();
            indices.shuffle(&mut rng);

            let pb = self.make_progress_bar(epoch, sentences.len());

            for &sent_idx in &indices {
                // Maps words to vocab indices and randomly drops tokens —
                // presumably frequency-based word2vec subsampling; confirm
                // against `Vocabulary::tokenise_and_subsample`.
                let tokens = vocab.tokenise_and_subsample(
                    &sentences[sent_idx],
                    self.config.subsample_threshold,
                    &mut rng,
                );

                // Need at least two surviving tokens to form any pair;
                // still tick the progress bar for the skipped sentence.
                if tokens.len() < 2 {
                    pb.inc(1);
                    continue;
                }

                let pairs = sentence_to_pairs(&tokens, self.config.window_size, &mut rng);

                for (center, context_words) in pairs {
                    // Linear decay from lr_start to lr_min across the
                    // estimated total steps; clamped at lr_min because the
                    // estimate may undershoot the true pair count.
                    let progress = global_step as f32 / total_steps.max(1) as f32;
                    let lr = (lr_start - (lr_start - lr_min) * progress).max(lr_min);

                    // Draw fresh negative samples for every pair.
                    let negatives: Vec<usize> = (0..self.config.negative_samples)
                        .map(|_| vocab.negative_sample(&mut rng))
                        .collect();

                    let loss =
                        model.update(self.config.model, center, &context_words, &negatives, lr);

                    epoch_loss += loss as f64;
                    epoch_pairs += 1;
                    global_step += 1;
                }

                pb.inc(1);
            }

            pb.finish_and_clear();

            // Epoch-boundary LR recomputed from the epoch fraction, for
            // reporting only. NOTE(review): this can differ slightly from
            // the last per-pair `lr`, which follows `global_step` against
            // the pair-count estimate instead.
            let lr_now =
                (lr_start - (lr_start - lr_min) * (epoch + 1) as f32 / epochs as f32).max(lr_min);
            let avg_loss = if epoch_pairs > 0 {
                epoch_loss / epoch_pairs as f64
            } else {
                // No pairs at all (e.g. every sentence subsampled away).
                0.0
            };
            let elapsed = t_start.elapsed().as_secs_f64();

            let stats = EpochStats {
                epoch: epoch + 1,
                avg_loss,
                learning_rate: lr_now,
                pairs_processed: epoch_pairs,
                elapsed_secs: elapsed,
            };

            info!(
                "Epoch {}/{} | loss: {:.4} | lr: {:.5} | pairs: {} | {:.1}s",
                stats.epoch,
                epochs,
                stats.avg_loss,
                stats.learning_rate,
                stats.pairs_processed,
                stats.elapsed_secs
            );

            self.history.push(stats);
        }

        Ok(Embeddings::new(model, vocab, self.config.clone()))
    }

    /// Writes the epoch history to `path` as a pretty-printed JSON array,
    /// one object per epoch.
    ///
    /// # Errors
    ///
    /// Returns an error if JSON serialization or the file write fails.
    pub fn save_history(&self, path: impl AsRef<Path>) -> Result<()> {
        let records: Vec<serde_json::Value> = self
            .history
            .iter()
            .map(|s| {
                serde_json::json!({
                    "epoch": s.epoch,
                    "avg_loss": s.avg_loss,
                    "learning_rate": s.learning_rate,
                    "pairs_processed": s.pairs_processed,
                    "elapsed_secs": s.elapsed_secs,
                })
            })
            .collect();

        let json = serde_json::to_string_pretty(&records)?;
        std::fs::write(path, json)?;
        Ok(())
    }

    /// Upper-bound estimate of training pairs per epoch: each sentence
    /// contributes (in-vocab tokens - 1) * window_size * 2. It ignores
    /// subsampling and any window shrinking, so it overestimates; used only
    /// to scale the learning-rate schedule in `train`.
    fn estimate_pairs(&self, sentences: &[Vec<String>], vocab: &Vocabulary) -> u64 {
        sentences
            .iter()
            .map(|s| {
                let len = s.iter().filter(|w| vocab.word2idx.contains_key(*w)).count() as u64;
                len.saturating_sub(1) * self.config.window_size as u64 * 2
            })
            .sum()
    }

    /// Builds a per-epoch progress bar over `total` sentences, prefixed with
    /// the 1-based epoch number out of the configured epoch count.
    fn make_progress_bar(&self, epoch: usize, total: usize) -> ProgressBar {
        let pb = ProgressBar::new(total as u64);
        pb.set_style(
            ProgressStyle::with_template(
                "{prefix:.bold} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} {msg}",
            )
            // Template is a compile-time constant known to be valid,
            // so this unwrap cannot fire at runtime.
            .unwrap()
            .progress_chars("=>-"),
        );
        pb.set_prefix(format!("Epoch {:>3}/{}", epoch + 1, self.config.epochs));
        pb
    }
}
269
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::ModelType;

    /// Four short sentences with overlapping vocabulary — enough for a
    /// tiny but non-degenerate training run.
    fn tiny_corpus() -> Vec<String> {
        [
            "the quick brown fox",
            "the lazy dog sleeps",
            "fox and dog are animals",
            "quick animals run fast",
        ]
        .iter()
        .map(|s| s.to_string())
        .collect()
    }

    #[test]
    fn training_runs_without_panic() {
        let cfg = Config {
            epochs: 2,
            embedding_dim: 10,
            ..Config::default()
        };
        let result = Trainer::new(cfg).train(&tiny_corpus());
        assert!(result.is_ok(), "{:?}", result);
    }

    #[test]
    fn history_records_all_epochs() {
        let cfg = Config {
            epochs: 3,
            embedding_dim: 8,
            ..Config::default()
        };
        let mut trainer = Trainer::new(cfg);
        trainer.train(&tiny_corpus()).unwrap();
        assert_eq!(trainer.history.len(), 3);
    }

    #[test]
    fn loss_is_finite() {
        let cfg = Config {
            epochs: 2,
            embedding_dim: 8,
            ..Config::default()
        };
        let mut trainer = Trainer::new(cfg);
        trainer.train(&tiny_corpus()).unwrap();
        for s in trainer.history.iter() {
            assert!(
                s.avg_loss.is_finite(),
                "epoch {} loss={}",
                s.epoch,
                s.avg_loss
            );
        }
    }

    #[test]
    fn cbow_training_runs() {
        let cfg = Config {
            model: ModelType::Cbow,
            epochs: 2,
            embedding_dim: 8,
            ..Config::default()
        };
        assert!(Trainer::new(cfg).train(&tiny_corpus()).is_ok());
    }

    #[test]
    fn empty_corpus_returns_error() {
        let outcome = Trainer::new(Config::default()).train(&[]);
        assert!(outcome.is_err());
    }
}