Skip to main content

word2vec/
embeddings.rs

1//! Post-training embedding access: similarity, analogy, save/load.
2//!
3//! # Usage
4//!
5//! ```rust,no_run
6//! use word2vec::{Config, Trainer};
7//!
//! let corpus = vec!["the cat sat on the mat ".repeat(50)];
9//! let mut trainer = Trainer::new(Config { epochs: 3, ..Config::default() });
10//! let emb = trainer.train(&corpus).unwrap();
11//!
12//! let similar = emb.most_similar("cat", 3);
13//! // [("mat", 0.92), ("sat", 0.87), ("on", 0.81)] (values illustrative)
14//!
15//! // king - man + woman ≈ queen
16//! // let queen = emb.analogy("king", "man", "woman", 3);
17//! ```
18
19use serde::{Deserialize, Serialize};
20use std::path::Path;
21
22use crate::config::Config;
23use crate::error::{Result, Word2VecError};
24use crate::model::Model;
25use crate::vocab::Vocabulary;
26
27/// Trained embeddings with vocabulary — the primary inference interface.
/// Trained embeddings with vocabulary — the primary inference interface.
///
/// Bundles the trained [`Model`] weights, the [`Vocabulary`] mapping words
/// to row indices, and the [`Config`] used for training. The whole struct
/// is serializable so save/load round-trips everything needed for inference.
#[derive(Debug, Serialize, Deserialize)]
pub struct Embeddings {
    // Trained network weights; input vectors serve as the word embeddings.
    model: Model,
    // Word <-> index mappings (word2idx / idx2word).
    vocab: Vocabulary,
    // Training configuration; `embedding_dim` is read at inference time.
    config: Config,
}
34
35impl Embeddings {
36    /// Wrap a trained model and vocabulary.
37    pub(crate) fn new(model: Model, vocab: Vocabulary, config: Config) -> Self {
38        Self {
39            model,
40            vocab,
41            config,
42        }
43    }
44
45    /// Number of words in the vocabulary.
46    ///
47    /// ```rust,no_run
48    /// # use word2vec::{Config, Trainer};
49    /// # let corpus = vec!["hello world hello".to_string()];
50    /// # let mut t = Trainer::new(Config { epochs: 1, ..Config::default() });
51    /// # let emb = t.train(&corpus).unwrap();
52    /// assert!(emb.vocab_size() >= 2);
53    /// ```
54    pub fn vocab_size(&self) -> usize {
55        self.vocab.len()
56    }
57
58    /// Embedding dimension.
59    pub fn embedding_dim(&self) -> usize {
60        self.config.embedding_dim
61    }
62
63    /// Get the raw embedding vector for a word.
64    ///
65    /// Returns `None` if the word is not in the vocabulary.
66    ///
67    /// ```rust,no_run
68    /// # use word2vec::{Config, Trainer};
69    /// # let corpus = vec!["hello world".to_string()];
70    /// # let mut t = Trainer::new(Config { epochs: 1, ..Config::default() });
71    /// # let emb = t.train(&corpus).unwrap();
72    /// assert!(emb.get_vector("hello").is_some());
73    /// assert!(emb.get_vector("nonexistent").is_none());
74    /// ```
75    pub fn get_vector(&self, word: &str) -> Option<&[f32]> {
76        self.vocab
77            .word2idx
78            .get(word)
79            .map(|&i| self.model.input_vec(i))
80    }
81
82    /// Cosine similarity between two words.
83    ///
84    /// Returns a value in `[-1.0, 1.0]`, or an error if either word
85    /// is not in the vocabulary.
86    ///
87    /// ```rust,no_run
88    /// # use word2vec::{Config, Trainer};
89    /// # let corpus = vec!["cat mat cat mat bat".to_string()];
90    /// # let mut t = Trainer::new(Config { epochs: 1, ..Config::default() });
91    /// # let emb = t.train(&corpus).unwrap();
92    /// let sim = emb.similarity("cat", "mat").unwrap();
93    /// assert!(sim >= -1.0 && sim <= 1.0);
94    /// ```
95    pub fn similarity(&self, word_a: &str, word_b: &str) -> Result<f32> {
96        let va = self
97            .get_vector(word_a)
98            .ok_or_else(|| Word2VecError::UnknownWord(word_a.to_string()))?;
99        let vb = self
100            .get_vector(word_b)
101            .ok_or_else(|| Word2VecError::UnknownWord(word_b.to_string()))?;
102        Ok(cosine_similarity(va, vb))
103    }
104
105    /// Find the `top_k` most similar words to `query`.
106    ///
107    /// Returns `(word, cosine_similarity)` pairs sorted descending.
108    /// The query word itself is excluded from results.
109    ///
110    /// ```rust,no_run
111    /// # use word2vec::{Config, Trainer};
112    /// # let corpus = (0..100).map(|_| "cat mat bat rat sat".to_string()).collect::<Vec<_>>();
113    /// # let mut t = Trainer::new(Config { epochs: 3, ..Config::default() });
114    /// # let emb = t.train(&corpus).unwrap();
115    /// let nearest = emb.most_similar("cat", 3);
116    /// assert!(nearest.len() <= 3);
117    /// ```
118    pub fn most_similar(&self, query: &str, top_k: usize) -> Vec<(String, f32)> {
119        let query_vec = match self.get_vector(query) {
120            Some(v) => v,
121            None => return vec![],
122        };
123
124        let mut scores: Vec<(usize, f32)> = (0..self.vocab.len())
125            .filter(|&i| self.vocab.idx2word[i] != query)
126            .map(|i| (i, cosine_similarity(query_vec, self.model.input_vec(i))))
127            .collect();
128
129        scores.sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
130        scores.truncate(top_k);
131
132        scores
133            .into_iter()
134            .map(|(i, sim)| (self.vocab.idx2word[i].clone(), sim))
135            .collect()
136    }
137
138    /// Solve an analogy: `pos_a - neg_a + pos_b ≈ result`.
139    ///
140    /// Classic example: `king - man + woman ≈ queen`.
141    ///
142    /// Excludes `pos_a`, `neg_a`, and `pos_b` from the candidate list.
143    ///
144    /// Returns `(word, score)` pairs sorted by cosine similarity to
145    /// the query vector.
146    pub fn analogy(
147        &self,
148        pos_a: &str,
149        neg_a: &str,
150        pos_b: &str,
151        top_k: usize,
152    ) -> Result<Vec<(String, f32)>> {
153        let va = self
154            .get_vector(pos_a)
155            .ok_or_else(|| Word2VecError::UnknownWord(pos_a.to_string()))?;
156        let vna = self
157            .get_vector(neg_a)
158            .ok_or_else(|| Word2VecError::UnknownWord(neg_a.to_string()))?;
159        let vb = self
160            .get_vector(pos_b)
161            .ok_or_else(|| Word2VecError::UnknownWord(pos_b.to_string()))?;
162
163        let dim = self.embedding_dim();
164        let query: Vec<f32> = (0..dim).map(|i| va[i] - vna[i] + vb[i]).collect();
165        let query_norm = normalize_vec(&query);
166
167        let exclude = [pos_a, neg_a, pos_b];
168        let mut scores: Vec<(usize, f32)> = (0..self.vocab.len())
169            .filter(|&i| !exclude.contains(&self.vocab.idx2word[i].as_str()))
170            .map(|i| (i, cosine_similarity(&query_norm, self.model.input_vec(i))))
171            .collect();
172
173        scores.sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
174        scores.truncate(top_k);
175
176        Ok(scores
177            .into_iter()
178            .map(|(i, s)| (self.vocab.idx2word[i].clone(), s))
179            .collect())
180    }
181
182    /// Return all words in vocabulary sorted alphabetically.
183    pub fn words(&self) -> Vec<&str> {
184        let mut words: Vec<&str> = self.vocab.idx2word.iter().map(|s| s.as_str()).collect();
185        words.sort_unstable();
186        words
187    }
188
189    /// Save embeddings to JSON.
190    ///
191    /// The file contains both the model weights and vocabulary so it can
192    /// be loaded independently.
193    pub fn save(&self, path: impl AsRef<Path>) -> Result<()> {
194        let json = serde_json::to_string(self)?;
195        std::fs::write(path, json)?;
196        Ok(())
197    }
198
199    /// Load embeddings from a JSON file produced by [`save`](Self::save).
200    pub fn load(path: impl AsRef<Path>) -> Result<Self> {
201        let json = std::fs::read_to_string(path)?;
202        let emb: Self = serde_json::from_str(&json)?;
203        Ok(emb)
204    }
205
206    /// Export just the word vectors as a plain-text file (word2vec format).
207    ///
208    /// Format: first line is `<vocab_size> <dim>`, then one word per line
209    /// followed by space-separated floats.
210    pub fn save_text_format(&self, path: impl AsRef<Path>) -> Result<()> {
211        use std::io::Write;
212        let mut f = std::fs::File::create(path)?;
213        writeln!(f, "{} {}", self.vocab.len(), self.embedding_dim())?;
214        for (i, word) in self.vocab.idx2word.iter().enumerate() {
215            let vec = self.model.input_vec(i);
216            let vec_str: Vec<String> = vec.iter().map(|v| format!("{:.6}", v)).collect();
217            writeln!(f, "{} {}", word, vec_str.join(" "))?;
218        }
219        Ok(())
220    }
221
222    /// Get a reference to the vocabulary.
223    pub fn vocab(&self) -> &Vocabulary {
224        &self.vocab
225    }
226}
227
228/// Cosine similarity between two vectors (handles zero-norm gracefully).
/// Cosine similarity between two vectors (handles zero-norm gracefully).
///
/// Returns `0.0` when either vector has (near-)zero magnitude; otherwise
/// the result is clamped into `[-1.0, 1.0]` to absorb rounding error.
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    // Squared-magnitude helper; each norm is taken over its full slice.
    let norm = |v: &[f32]| v.iter().fold(0.0f32, |acc, x| acc + x * x).sqrt();
    let (norm_a, norm_b) = (norm(a), norm(b));
    if norm_a < 1e-8 || norm_b < 1e-8 {
        return 0.0;
    }
    let dot = a.iter().zip(b.iter()).fold(0.0f32, |acc, (x, y)| acc + x * y);
    (dot / (norm_a * norm_b)).clamp(-1.0, 1.0)
}
238
239/// Return a L2-normalised copy of a vector.
/// Return a L2-normalised copy of a vector.
///
/// Vectors with (near-)zero magnitude are returned unchanged, avoiding a
/// division by zero.
pub fn normalize_vec(v: &[f32]) -> Vec<f32> {
    let magnitude = v.iter().fold(0.0f32, |acc, x| acc + x * x).sqrt();
    if magnitude >= 1e-8 {
        v.iter().map(|x| x / magnitude).collect()
    } else {
        v.to_vec()
    }
}
247
#[cfg(test)]
mod tests {
    use super::*;

    // Parallel (identical) vectors have cosine similarity exactly 1.
    #[test]
    fn cosine_identical_vectors() {
        let v = vec![1.0, 2.0, 3.0];
        assert!((cosine_similarity(&v, &v) - 1.0).abs() < 1e-5);
    }

    // Anti-parallel vectors have cosine similarity -1.
    #[test]
    fn cosine_opposite_vectors() {
        let a = vec![1.0, 0.0];
        let b = vec![-1.0, 0.0];
        assert!((cosine_similarity(&a, &b) + 1.0).abs() < 1e-5);
    }

    // Orthogonal vectors have cosine similarity 0.
    #[test]
    fn cosine_orthogonal_vectors() {
        let a = vec![1.0, 0.0];
        let b = vec![0.0, 1.0];
        assert!(cosine_similarity(&a, &b).abs() < 1e-5);
    }

    // Zero-norm input hits the documented graceful path: exact 0.0, no NaN.
    #[test]
    fn cosine_zero_vector() {
        let a = vec![0.0, 0.0];
        let b = vec![1.0, 0.0];
        assert_eq!(cosine_similarity(&a, &b), 0.0);
    }

    // normalize_vec output must have unit L2 norm (3-4-5 triangle input).
    #[test]
    fn normalize_unit_vector() {
        let v = vec![3.0, 4.0];
        let n = normalize_vec(&v);
        let norm: f32 = n.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 1e-5);
    }

    // Helper: train a tiny model (5 distinct words, dim 10, 2 epochs) so the
    // Embeddings-level tests below have a real instance to exercise.
    fn make_embeddings() -> Embeddings {
        use crate::{Config, Trainer};
        let corpus: Vec<String> = (0..30)
            .map(|i| format!("word{} word{} word{}", i % 5, (i + 1) % 5, (i + 2) % 5))
            .collect();
        let mut trainer = Trainer::new(Config {
            epochs: 2,
            embedding_dim: 10,
            ..Config::default()
        });
        trainer.train(&corpus).unwrap()
    }

    // A word seen during training must have a vector.
    #[test]
    fn get_vector_known_word() {
        let emb = make_embeddings();
        assert!(emb.get_vector("word0").is_some());
    }

    // An out-of-vocabulary word must yield None, not panic.
    #[test]
    fn get_vector_unknown_word() {
        let emb = make_embeddings();
        assert!(emb.get_vector("unknown_xyz").is_none());
    }

    // most_similar contract: results are in descending similarity order.
    #[test]
    fn most_similar_returns_sorted_results() {
        let emb = make_embeddings();
        let results = emb.most_similar("word0", 3);
        for window in results.windows(2) {
            assert!(window[0].1 >= window[1].1, "results not sorted");
        }
    }

    // similarity(w, w) should be ~1.0 (loose 1e-4 tolerance for f32 rounding).
    #[test]
    fn similarity_self_is_one() {
        let emb = make_embeddings();
        let sim = emb.similarity("word0", "word0").unwrap();
        assert!((sim - 1.0).abs() < 1e-4, "self-similarity={sim}");
    }

    // save -> load must reconstruct an equivalent vocabulary (JSON round-trip).
    #[test]
    fn save_load_roundtrip() {
        let emb = make_embeddings();
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("emb.json");
        emb.save(&path).unwrap();
        let loaded = Embeddings::load(&path).unwrap();
        assert_eq!(loaded.vocab_size(), emb.vocab_size());
    }
}