1use rand::rngs::SmallRng;
19use rand::Rng;
20#[allow(unused_imports)]
21use rand::SeedableRng;
22use serde::{Deserialize, Serialize};
23use std::collections::HashMap;
24
25use crate::config::Config;
26use crate::error::{Result, Word2VecError};
27
/// Number of entries in the precomputed negative-sampling (noise) table.
/// Slot granularity is 1 / TABLE_SIZE of the total probability mass.
const TABLE_SIZE: usize = 1_000_000;

/// Word vocabulary with per-word frequency counts and a precomputed
/// noise table for negative sampling.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Vocabulary {
    /// Maps each word to its id (index into `idx2word` / `counts`).
    pub word2idx: HashMap<String, usize>,
    /// Words ordered by descending frequency; position == word id.
    pub idx2word: Vec<String>,
    /// Corpus occurrence count for each word id.
    pub counts: Vec<u64>,
    /// `TABLE_SIZE` word ids laid out for O(1) negative sampling.
    pub noise_table: Vec<u32>,
    /// Total occurrences of all in-vocabulary words (post `min_count` filter).
    pub total_tokens: u64,
}
45
46impl Vocabulary {
47 pub fn build(sentences: &[Vec<String>], config: &Config) -> Result<Self> {
68 let mut raw_counts: HashMap<String, u64> = HashMap::new();
69 for sentence in sentences {
70 for token in sentence {
71 *raw_counts.entry(token.clone()).or_insert(0) += 1;
72 }
73 }
74
75 let mut filtered: Vec<(String, u64)> = raw_counts
77 .into_iter()
78 .filter(|(_, c)| *c >= config.min_count as u64)
79 .collect();
80
81 if filtered.is_empty() {
82 return Err(Word2VecError::EmptyVocabulary);
83 }
84 if filtered.len() < 2 {
85 return Err(Word2VecError::CorpusTooSmall(filtered.len()));
86 }
87
88 filtered.sort_unstable_by(|a, b| b.1.cmp(&a.1).then(a.0.cmp(&b.0)));
90
91 let total_tokens: u64 = filtered.iter().map(|(_, c)| c).sum();
92 let mut word2idx = HashMap::with_capacity(filtered.len());
93 let mut idx2word = Vec::with_capacity(filtered.len());
94 let mut counts = Vec::with_capacity(filtered.len());
95
96 for (idx, (word, count)) in filtered.into_iter().enumerate() {
97 word2idx.insert(word.clone(), idx);
98 idx2word.push(word);
99 counts.push(count);
100 }
101
102 let noise_table = Self::build_noise_table(&counts);
103
104 Ok(Self {
105 word2idx,
106 idx2word,
107 counts,
108 noise_table,
109 total_tokens,
110 })
111 }
112
113 #[inline]
115 pub fn len(&self) -> usize {
116 self.idx2word.len()
117 }
118
119 #[inline]
121 pub fn is_empty(&self) -> bool {
122 self.idx2word.is_empty()
123 }
124
125 pub fn count(&self, word: &str) -> u64 {
138 self.word2idx
139 .get(word)
140 .map(|&i| self.counts[i])
141 .unwrap_or(0)
142 }
143
144 pub fn should_subsample(&self, idx: usize, threshold: f64, dice: f64) -> bool {
149 let freq = self.counts[idx] as f64 / self.total_tokens as f64;
150 let keep_prob = ((threshold / freq).sqrt() + threshold / freq).min(1.0);
151 dice >= keep_prob
152 }
153
154 pub fn negative_sample(&self, rng: &mut SmallRng) -> usize {
158 let idx = rng.gen_range(0..TABLE_SIZE);
159 self.noise_table[idx] as usize
160 }
161
162 fn build_noise_table(counts: &[u64]) -> Vec<u32> {
164 let powered: Vec<f64> = counts.iter().map(|&c| (c as f64).powf(0.75)).collect();
165 let total: f64 = powered.iter().sum();
166
167 let mut table = Vec::with_capacity(TABLE_SIZE);
168 let mut cumulative = 0.0_f64;
169 let mut word_idx = 0usize;
170
171 for i in 0..TABLE_SIZE {
172 let threshold = (i as f64 + 1.0) / TABLE_SIZE as f64;
173 while cumulative / total < threshold && word_idx < powered.len() - 1 {
174 cumulative += powered[word_idx];
175 word_idx += 1;
176 }
177 table.push(word_idx as u32);
178 }
179
180 table
181 }
182
183 pub fn tokenise_and_subsample(
185 &self,
186 sentence: &[String],
187 threshold: f64,
188 rng: &mut SmallRng,
189 ) -> Vec<usize> {
190 sentence
191 .iter()
192 .filter_map(|w| self.word2idx.get(w).copied())
193 .filter(|&idx| !self.should_subsample(idx, threshold, rng.gen()))
194 .collect()
195 }
196}
197
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a vocabulary from a single whitespace-separated sentence
    /// using the default configuration.
    fn make_vocab(text: &str) -> Vocabulary {
        let sentence: Vec<String> = text.split_whitespace().map(String::from).collect();
        Vocabulary::build(&[sentence], &Config::default()).unwrap()
    }

    #[test]
    fn vocab_word_counts() {
        let vocab = make_vocab("a a a b b c");
        for (word, expected) in [("a", 3), ("b", 2), ("c", 1), ("z", 0)] {
            assert_eq!(vocab.count(word), expected);
        }
    }

    #[test]
    fn vocab_sorted_by_frequency() {
        let vocab = make_vocab("a a a b b c");
        assert_eq!(&vocab.idx2word[..2], &["a".to_string(), "b".to_string()]);
    }

    #[test]
    fn vocab_len() {
        assert_eq!(make_vocab("hello world hello").len(), 2);
    }

    #[test]
    fn min_count_filters() {
        let cfg = Config {
            min_count: 2,
            ..Config::default()
        };
        let as_corpus = |text: &str| -> Vec<Vec<String>> {
            vec![text.split_whitespace().map(String::from).collect()]
        };

        // All of a, b, c clear the threshold.
        let _vocab = Vocabulary::build(&as_corpus("a a a b b c c"), &cfg).unwrap();

        // Only "a" survives -> too few words.
        assert!(Vocabulary::build(&as_corpus("a a a b"), &cfg).is_err());

        // Exactly two survivors is accepted.
        let vocab3 = Vocabulary::build(&as_corpus("a a b b"), &cfg).unwrap();
        assert!(!vocab3.word2idx.contains_key("x"));
        assert!(vocab3.word2idx.contains_key("a"));
        assert!(vocab3.word2idx.contains_key("b"));
    }

    #[test]
    fn empty_corpus_errors() {
        let corpus: Vec<Vec<String>> = vec![Vec::new()];
        assert!(Vocabulary::build(&corpus, &Config::default()).is_err());
    }

    #[test]
    fn noise_table_has_correct_size() {
        let vocab = make_vocab("a a b c d e");
        assert_eq!(vocab.noise_table.len(), TABLE_SIZE);
    }

    #[test]
    fn negative_sample_in_range() {
        let vocab = make_vocab("the cat sat on the mat");
        let mut rng = SmallRng::seed_from_u64(0);
        assert!((0..1000).all(|_| vocab.negative_sample(&mut rng) < vocab.len()));
    }
}