1use serde::{Deserialize, Serialize};
20use std::path::Path;
21
22use crate::config::Config;
23use crate::error::{Result, Word2VecError};
24use crate::model::Model;
25use crate::vocab::Vocabulary;
26
/// A trained set of word embeddings: the learned model weights together with
/// the vocabulary and configuration they were trained under.
///
/// Serde-serializable so the whole bundle round-trips through
/// `save`/`load` as JSON.
#[derive(Debug, Serialize, Deserialize)]
pub struct Embeddings {
    // Learned weights; per-word vectors are read via `model.input_vec(i)`.
    model: Model,
    // word <-> index mappings (`word2idx`, `idx2word`).
    vocab: Vocabulary,
    // Training hyperparameters, including `embedding_dim`.
    config: Config,
}
34
35impl Embeddings {
36 pub(crate) fn new(model: Model, vocab: Vocabulary, config: Config) -> Self {
38 Self {
39 model,
40 vocab,
41 config,
42 }
43 }
44
45 pub fn vocab_size(&self) -> usize {
55 self.vocab.len()
56 }
57
58 pub fn embedding_dim(&self) -> usize {
60 self.config.embedding_dim
61 }
62
63 pub fn get_vector(&self, word: &str) -> Option<&[f32]> {
76 self.vocab
77 .word2idx
78 .get(word)
79 .map(|&i| self.model.input_vec(i))
80 }
81
82 pub fn similarity(&self, word_a: &str, word_b: &str) -> Result<f32> {
96 let va = self
97 .get_vector(word_a)
98 .ok_or_else(|| Word2VecError::UnknownWord(word_a.to_string()))?;
99 let vb = self
100 .get_vector(word_b)
101 .ok_or_else(|| Word2VecError::UnknownWord(word_b.to_string()))?;
102 Ok(cosine_similarity(va, vb))
103 }
104
105 pub fn most_similar(&self, query: &str, top_k: usize) -> Vec<(String, f32)> {
119 let query_vec = match self.get_vector(query) {
120 Some(v) => v,
121 None => return vec![],
122 };
123
124 let mut scores: Vec<(usize, f32)> = (0..self.vocab.len())
125 .filter(|&i| self.vocab.idx2word[i] != query)
126 .map(|i| (i, cosine_similarity(query_vec, self.model.input_vec(i))))
127 .collect();
128
129 scores.sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
130 scores.truncate(top_k);
131
132 scores
133 .into_iter()
134 .map(|(i, sim)| (self.vocab.idx2word[i].clone(), sim))
135 .collect()
136 }
137
138 pub fn analogy(
147 &self,
148 pos_a: &str,
149 neg_a: &str,
150 pos_b: &str,
151 top_k: usize,
152 ) -> Result<Vec<(String, f32)>> {
153 let va = self
154 .get_vector(pos_a)
155 .ok_or_else(|| Word2VecError::UnknownWord(pos_a.to_string()))?;
156 let vna = self
157 .get_vector(neg_a)
158 .ok_or_else(|| Word2VecError::UnknownWord(neg_a.to_string()))?;
159 let vb = self
160 .get_vector(pos_b)
161 .ok_or_else(|| Word2VecError::UnknownWord(pos_b.to_string()))?;
162
163 let dim = self.embedding_dim();
164 let query: Vec<f32> = (0..dim).map(|i| va[i] - vna[i] + vb[i]).collect();
165 let query_norm = normalize_vec(&query);
166
167 let exclude = [pos_a, neg_a, pos_b];
168 let mut scores: Vec<(usize, f32)> = (0..self.vocab.len())
169 .filter(|&i| !exclude.contains(&self.vocab.idx2word[i].as_str()))
170 .map(|i| (i, cosine_similarity(&query_norm, self.model.input_vec(i))))
171 .collect();
172
173 scores.sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
174 scores.truncate(top_k);
175
176 Ok(scores
177 .into_iter()
178 .map(|(i, s)| (self.vocab.idx2word[i].clone(), s))
179 .collect())
180 }
181
182 pub fn words(&self) -> Vec<&str> {
184 let mut words: Vec<&str> = self.vocab.idx2word.iter().map(|s| s.as_str()).collect();
185 words.sort_unstable();
186 words
187 }
188
189 pub fn save(&self, path: impl AsRef<Path>) -> Result<()> {
194 let json = serde_json::to_string(self)?;
195 std::fs::write(path, json)?;
196 Ok(())
197 }
198
199 pub fn load(path: impl AsRef<Path>) -> Result<Self> {
201 let json = std::fs::read_to_string(path)?;
202 let emb: Self = serde_json::from_str(&json)?;
203 Ok(emb)
204 }
205
206 pub fn save_text_format(&self, path: impl AsRef<Path>) -> Result<()> {
211 use std::io::Write;
212 let mut f = std::fs::File::create(path)?;
213 writeln!(f, "{} {}", self.vocab.len(), self.embedding_dim())?;
214 for (i, word) in self.vocab.idx2word.iter().enumerate() {
215 let vec = self.model.input_vec(i);
216 let vec_str: Vec<String> = vec.iter().map(|v| format!("{:.6}", v)).collect();
217 writeln!(f, "{} {}", word, vec_str.join(" "))?;
218 }
219 Ok(())
220 }
221
222 pub fn vocab(&self) -> &Vocabulary {
224 &self.vocab
225 }
226}
227
/// Cosine similarity of `a` and `b`, clamped to [-1.0, 1.0].
///
/// Returns 0.0 when either vector is numerically zero (norm < 1e-8) so
/// callers never divide by zero. The dot product pairs elements up to the
/// shorter slice; each norm is taken over its full slice.
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    let norm = |v: &[f32]| v.iter().map(|x| x * x).sum::<f32>().sqrt();
    let (norm_a, norm_b) = (norm(a), norm(b));
    // Degenerate (near-zero) vectors have no meaningful direction.
    if norm_a < 1e-8 || norm_b < 1e-8 {
        return 0.0;
    }
    let dot = a.iter().zip(b).fold(0.0_f32, |acc, (x, y)| acc + x * y);
    // Clamp guards against tiny float drift pushing the ratio past ±1.
    (dot / (norm_a * norm_b)).clamp(-1.0, 1.0)
}
238
/// Returns a copy of `v` scaled to unit L2 norm.
///
/// A numerically zero vector (norm < 1e-8) is returned unchanged rather
/// than divided by ~zero.
pub fn normalize_vec(v: &[f32]) -> Vec<f32> {
    let magnitude = v.iter().map(|x| x * x).sum::<f32>().sqrt();
    if magnitude < 1e-8 {
        v.to_vec()
    } else {
        v.iter().map(|x| x / magnitude).collect()
    }
}
247
#[cfg(test)]
mod tests {
    use super::*;

    // ---- cosine_similarity / normalize_vec (pure math helpers) ----

    // Parallel vectors have similarity exactly 1.
    #[test]
    fn cosine_identical_vectors() {
        let v = vec![1.0, 2.0, 3.0];
        assert!((cosine_similarity(&v, &v) - 1.0).abs() < 1e-5);
    }

    // Anti-parallel vectors have similarity exactly -1.
    #[test]
    fn cosine_opposite_vectors() {
        let a = vec![1.0, 0.0];
        let b = vec![-1.0, 0.0];
        assert!((cosine_similarity(&a, &b) + 1.0).abs() < 1e-5);
    }

    // Perpendicular vectors have similarity 0.
    #[test]
    fn cosine_orthogonal_vectors() {
        let a = vec![1.0, 0.0];
        let b = vec![0.0, 1.0];
        assert!(cosine_similarity(&a, &b).abs() < 1e-5);
    }

    // The zero vector short-circuits to 0.0 instead of dividing by zero.
    #[test]
    fn cosine_zero_vector() {
        let a = vec![0.0, 0.0];
        let b = vec![1.0, 0.0];
        assert_eq!(cosine_similarity(&a, &b), 0.0);
    }

    // 3-4-5 triangle: result must have unit L2 norm.
    #[test]
    fn normalize_unit_vector() {
        let v = vec![3.0, 4.0];
        let n = normalize_vec(&v);
        let norm: f32 = n.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 1e-5);
    }

    // ---- Embeddings API (needs a trained model) ----

    // Trains a tiny model (5 distinct words, 2 epochs, dim 10) on a cyclic
    // synthetic corpus — small enough to keep each test fast.
    fn make_embeddings() -> Embeddings {
        use crate::{Config, Trainer};
        let corpus: Vec<String> = (0..30)
            .map(|i| format!("word{} word{} word{}", i % 5, (i + 1) % 5, (i + 2) % 5))
            .collect();
        let mut trainer = Trainer::new(Config {
            epochs: 2,
            embedding_dim: 10,
            ..Config::default()
        });
        trainer.train(&corpus).unwrap()
    }

    #[test]
    fn get_vector_known_word() {
        let emb = make_embeddings();
        assert!(emb.get_vector("word0").is_some());
    }

    // Out-of-vocabulary lookups return None rather than panicking.
    #[test]
    fn get_vector_unknown_word() {
        let emb = make_embeddings();
        assert!(emb.get_vector("unknown_xyz").is_none());
    }

    // Results must be in non-increasing similarity order.
    #[test]
    fn most_similar_returns_sorted_results() {
        let emb = make_embeddings();
        let results = emb.most_similar("word0", 3);
        for window in results.windows(2) {
            assert!(window[0].1 >= window[1].1, "results not sorted");
        }
    }

    // A word compared with itself must score (numerically) 1.
    #[test]
    fn similarity_self_is_one() {
        let emb = make_embeddings();
        let sim = emb.similarity("word0", "word0").unwrap();
        assert!((sim - 1.0).abs() < 1e-4, "self-similarity={sim}");
    }

    // save -> load must preserve at least the vocabulary size.
    #[test]
    fn save_load_roundtrip() {
        let emb = make_embeddings();
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("emb.json");
        emb.save(&path).unwrap();
        let loaded = Embeddings::load(&path).unwrap();
        assert_eq!(loaded.vocab_size(), emb.vocab_size());
    }
}