// word2vec/vocab.rs
//! Vocabulary construction with frequency counting, subsampling, and
//! the unigram noise distribution table for negative sampling.
//!
//! # Subsampling
//!
//! Frequent words are stochastically discarded using the formula from
//! Mikolov's reference word2vec implementation:
//!
//! `P(keep) = min(1, sqrt(t / f) + t / f)`
//!
//! where `t` is [`Config::subsample_threshold`] and `f` is the word's
//! relative corpus frequency (so `P(discard) = 1 - P(keep)`).
//!
//! # Negative Sampling Table
//!
//! A flat array of `TABLE_SIZE` word indices drawn from `freq^0.75`,
//! which downweights very frequent words as negatives.
17
18use rand::rngs::SmallRng;
19use rand::Rng;
20#[allow(unused_imports)]
21use rand::SeedableRng;
22use serde::{Deserialize, Serialize};
23use std::collections::HashMap;
24
25use crate::config::Config;
26use crate::error::{Result, Word2VecError};
27
/// Size of the unigram noise table. A larger table approximates the
/// `freq^0.75` distribution more closely at the cost of memory
/// (4 MB here, as entries are `u32`).
const TABLE_SIZE: usize = 1_000_000;
30
31/// Maps tokens ↔ integer indices and stores frequency statistics.
/// Maps tokens ↔ integer indices and stores frequency statistics.
///
/// Indices are assigned in descending frequency order with alphabetical
/// tie-breaking (see [`Vocabulary::build`]), so index 0 is the most
/// frequent word.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Vocabulary {
    /// word → index
    pub word2idx: HashMap<String, usize>,
    /// index → word (inverse of `word2idx`)
    pub idx2word: Vec<String>,
    /// Raw corpus frequency per index (parallel to `idx2word`)
    pub counts: Vec<u64>,
    /// Flat noise table of word indices for O(1) negative sampling
    pub noise_table: Vec<u32>,
    /// Total token count (after min_count filter)
    pub total_tokens: u64,
}
45
46impl Vocabulary {
47    /// Build vocabulary from a tokenised corpus.
48    ///
49    /// Steps:
50    /// 1. Count every token
51    /// 2. Drop tokens below `config.min_count`
52    /// 3. Sort by descending frequency (stable index order)
53    /// 4. Build unigram noise table
54    ///
55    /// ```rust
56    /// use word2vec::{Config, vocab::Vocabulary};
57    ///
58    /// let corpus = vec!["the cat sat on the mat".to_string()];
59    /// let tokens: Vec<Vec<String>> = corpus.iter()
60    ///     .map(|s| s.split_whitespace().map(str::to_string).collect())
61    ///     .collect();
62    ///
63    /// let vocab = Vocabulary::build(&tokens, &Config::default()).unwrap();
64    /// assert!(vocab.word2idx.contains_key("the"));
65    /// assert_eq!(vocab.count("the"), 2);
66    /// ```
67    pub fn build(sentences: &[Vec<String>], config: &Config) -> Result<Self> {
68        let mut raw_counts: HashMap<String, u64> = HashMap::new();
69        for sentence in sentences {
70            for token in sentence {
71                *raw_counts.entry(token.clone()).or_insert(0) += 1;
72            }
73        }
74
75        // Apply min_count filter
76        let mut filtered: Vec<(String, u64)> = raw_counts
77            .into_iter()
78            .filter(|(_, c)| *c >= config.min_count as u64)
79            .collect();
80
81        if filtered.is_empty() {
82            return Err(Word2VecError::EmptyVocabulary);
83        }
84        if filtered.len() < 2 {
85            return Err(Word2VecError::CorpusTooSmall(filtered.len()));
86        }
87
88        // Stable sort: descending frequency, then alphabetical for tie-breaking
89        filtered.sort_unstable_by(|a, b| b.1.cmp(&a.1).then(a.0.cmp(&b.0)));
90
91        let total_tokens: u64 = filtered.iter().map(|(_, c)| c).sum();
92        let mut word2idx = HashMap::with_capacity(filtered.len());
93        let mut idx2word = Vec::with_capacity(filtered.len());
94        let mut counts = Vec::with_capacity(filtered.len());
95
96        for (idx, (word, count)) in filtered.into_iter().enumerate() {
97            word2idx.insert(word.clone(), idx);
98            idx2word.push(word);
99            counts.push(count);
100        }
101
102        let noise_table = Self::build_noise_table(&counts);
103
104        Ok(Self {
105            word2idx,
106            idx2word,
107            counts,
108            noise_table,
109            total_tokens,
110        })
111    }
112
113    /// Number of unique tokens in vocabulary.
114    #[inline]
115    pub fn len(&self) -> usize {
116        self.idx2word.len()
117    }
118
119    /// Returns `true` if the vocabulary contains no words.
120    #[inline]
121    pub fn is_empty(&self) -> bool {
122        self.idx2word.is_empty()
123    }
124
125    /// Frequency of a word (0 if not in vocab).
126    ///
127    /// ```rust
128    /// use word2vec::{Config, vocab::Vocabulary};
129    /// let corpus = vec!["a a b".to_string()];
130    /// let tokens: Vec<Vec<String>> = corpus.iter()
131    ///     .map(|s| s.split_whitespace().map(str::to_string).collect())
132    ///     .collect();
133    /// let vocab = Vocabulary::build(&tokens, &Config::default()).unwrap();
134    /// assert_eq!(vocab.count("a"), 2);
135    /// assert_eq!(vocab.count("z"), 0);
136    /// ```
137    pub fn count(&self, word: &str) -> u64 {
138        self.word2idx
139            .get(word)
140            .map(|&i| self.counts[i])
141            .unwrap_or(0)
142    }
143
144    /// Returns `true` if this word should be subsampled (discarded) given
145    /// a uniformly random `dice` in [0, 1).
146    ///
147    /// Uses Mikolov's formula: `P(keep) = min(1, sqrt(t/f) + t/f)`.
148    pub fn should_subsample(&self, idx: usize, threshold: f64, dice: f64) -> bool {
149        let freq = self.counts[idx] as f64 / self.total_tokens as f64;
150        let keep_prob = ((threshold / freq).sqrt() + threshold / freq).min(1.0);
151        dice >= keep_prob
152    }
153
154    /// Draw a negative sample index from the noise distribution.
155    ///
156    /// Uses the precomputed unigram table for O(1) lookup.
157    pub fn negative_sample(&self, rng: &mut SmallRng) -> usize {
158        let idx = rng.gen_range(0..TABLE_SIZE);
159        self.noise_table[idx] as usize
160    }
161
162    /// Build flat unigram noise table from `freq^0.75`.
163    fn build_noise_table(counts: &[u64]) -> Vec<u32> {
164        let powered: Vec<f64> = counts.iter().map(|&c| (c as f64).powf(0.75)).collect();
165        let total: f64 = powered.iter().sum();
166
167        let mut table = Vec::with_capacity(TABLE_SIZE);
168        let mut cumulative = 0.0_f64;
169        let mut word_idx = 0usize;
170
171        for i in 0..TABLE_SIZE {
172            let threshold = (i as f64 + 1.0) / TABLE_SIZE as f64;
173            while cumulative / total < threshold && word_idx < powered.len() - 1 {
174                cumulative += powered[word_idx];
175                word_idx += 1;
176            }
177            table.push(word_idx as u32);
178        }
179
180        table
181    }
182
183    /// Tokenise and subsample a sentence, returning word indices.
184    pub fn tokenise_and_subsample(
185        &self,
186        sentence: &[String],
187        threshold: f64,
188        rng: &mut SmallRng,
189    ) -> Vec<usize> {
190        sentence
191            .iter()
192            .filter_map(|w| self.word2idx.get(w).copied())
193            .filter(|&idx| !self.should_subsample(idx, threshold, rng.gen()))
194            .collect()
195    }
196}
197
#[cfg(test)]
mod tests {
    use super::*;

    /// Split a whitespace-separated string into an owned token list.
    fn split(text: &str) -> Vec<String> {
        text.split_whitespace().map(str::to_string).collect()
    }

    /// Build a vocabulary from a single sentence with default config.
    fn make_vocab(text: &str) -> Vocabulary {
        Vocabulary::build(&[split(text)], &Config::default()).unwrap()
    }

    #[test]
    fn vocab_word_counts() {
        let vocab = make_vocab("a a a b b c");
        for (word, expected) in [("a", 3), ("b", 2), ("c", 1), ("z", 0)] {
            assert_eq!(vocab.count(word), expected);
        }
    }

    #[test]
    fn vocab_sorted_by_frequency() {
        let vocab = make_vocab("a a a b b c");
        // Most frequent word gets index 0, next gets index 1.
        assert_eq!(vocab.idx2word[0], "a");
        assert_eq!(vocab.idx2word[1], "b");
    }

    #[test]
    fn vocab_len() {
        assert_eq!(make_vocab("hello world hello").len(), 2);
    }

    #[test]
    fn min_count_filters() {
        let cfg = Config {
            min_count: 2,
            ..Config::default()
        };

        // Every word appears at least twice, so the build succeeds.
        let _vocab = Vocabulary::build(&[split("a a a b b c c")], &cfg).unwrap();

        // Only "a" survives the filter: too few words left.
        assert!(Vocabulary::build(&[split("a a a b")], &cfg).is_err());

        // Both words survive; unseen words stay absent.
        let vocab3 = Vocabulary::build(&[split("a a b b")], &cfg).unwrap();
        assert!(!vocab3.word2idx.contains_key("x"));
        assert!(vocab3.word2idx.contains_key("a"));
        assert!(vocab3.word2idx.contains_key("b"));
    }

    #[test]
    fn empty_corpus_errors() {
        let empty: Vec<Vec<String>> = vec![vec![]];
        assert!(Vocabulary::build(&empty, &Config::default()).is_err());
    }

    #[test]
    fn noise_table_has_correct_size() {
        assert_eq!(make_vocab("a a b c d e").noise_table.len(), TABLE_SIZE);
    }

    #[test]
    fn negative_sample_in_range() {
        let vocab = make_vocab("the cat sat on the mat");
        let mut rng = SmallRng::seed_from_u64(0);
        assert!((0..1000).all(|_| vocab.negative_sample(&mut rng) < vocab.len()));
    }
}