// word2vec/model.rs

//! Neural network weights and forward/backward pass for Skip-gram and CBOW
//! with Negative Sampling.
//!
//! # Weight Matrices
//!
//! - `input_weights` (`W_in`): shape `[vocab_size × embedding_dim]` — the
//!   "input" or center-word embedding matrix.
//! - `output_weights` (`W_out`): shape `[vocab_size × embedding_dim]` — the
//!   context/output embedding matrix used in the dot-product scoring.
//!
//! # Negative Sampling Loss
//!
//! For a positive pair (center `c`, context `o`) and `k` negatives `n_i`,
//! the loss minimised here is the negative log-likelihood:
//!
//! `L = -log σ(v_o · v_c) - Σ log σ(-v_{n_i} · v_c)`
//!
//! Its gradient w.r.t. `v_c` is `(σ(v_o · v_c) - 1) · v_o` from the positive
//! term plus `σ(v_{n_i} · v_c) · v_{n_i}` from each negative term; these are
//! the `err` and `sig` factors applied in-place via SGD below.

use rand::rngs::SmallRng;
use serde::{Deserialize, Serialize};

use crate::config::ModelType;

/// Sigmoid activation.
#[inline(always)]
fn sigmoid(x: f32) -> f32 {
    1.0 / (1.0 + (-x).exp())
}

/// Xavier-style uniform initialisation range, `sqrt(6 / dim)`, computed from
/// the embedding dimension alone.
#[inline]
fn xavier_range(dim: usize) -> f32 {
    (6.0_f32 / dim as f32).sqrt()
}

/// Core weight matrices for Word2Vec.
///
/// Both matrices are row-major flat `Vec<f32>` for cache efficiency.
/// Row `i` occupies elements `i * embedding_dim .. (i + 1) * embedding_dim`.
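///
/// A quick check of the layout (the sizes here are illustrative):
///
/// ```rust
/// use word2vec::model::Model;
/// let m = Model::new(4, 3, 0);
/// // Row 2 of `W_in` is the flat slice [2 * 3 .. 3 * 3].
/// assert_eq!(m.input_vec(2), &m.input_weights[6..9]);
/// ```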
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Model {
    pub vocab_size: usize,
    pub embedding_dim: usize,
    pub input_weights: Vec<f32>,
    pub output_weights: Vec<f32>,
}

impl Model {
    /// Initialise with Xavier uniform weights.
    ///
    /// ```rust
    /// use word2vec::model::Model;
    /// let m = Model::new(100, 50, 42);
    /// assert_eq!(m.input_weights.len(), 100 * 50);
    /// ```
    pub fn new(vocab_size: usize, embedding_dim: usize, seed: u64) -> Self {
        use rand::Rng;
        use rand::SeedableRng;
        let mut rng = SmallRng::seed_from_u64(seed);
        let r = xavier_range(embedding_dim);
        let n = vocab_size * embedding_dim;
        let input_weights: Vec<f32> = (0..n).map(|_| rng.gen_range(-r..r)).collect();
        let output_weights = vec![0.0_f32; n]; // output weights start at zero

        Self {
            vocab_size,
            embedding_dim,
            input_weights,
            output_weights,
        }
    }

    /// Get the embedding vector for word at `idx` (slice into `input_weights`).
    #[inline]
    pub fn input_vec(&self, idx: usize) -> &[f32] {
        let start = idx * self.embedding_dim;
        &self.input_weights[start..start + self.embedding_dim]
    }

    /// Mutable access to input embedding.
    #[inline]
    pub fn input_vec_mut(&mut self, idx: usize) -> &mut [f32] {
        let start = idx * self.embedding_dim;
        &mut self.input_weights[start..start + self.embedding_dim]
    }

    /// Mutable access to output embedding.
    #[inline]
    pub fn output_vec_mut(&mut self, idx: usize) -> &mut [f32] {
        let start = idx * self.embedding_dim;
        &mut self.output_weights[start..start + self.embedding_dim]
    }

    /// Dot product between two embedding rows.
    fn dot(a: &[f32], b: &[f32]) -> f32 {
        a.iter().zip(b.iter()).map(|(x, y)| x * y).sum()
    }

    /// Skip-gram update: given the center word, update for one context word
    /// and the given negative samples.
    ///
    /// Returns the binary cross-entropy loss contribution.
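    ///
    /// A small worked check (the indices are illustrative): output weights
    /// start at zero, so every score is `0`, `σ(0) = 0.5`, and the very first
    /// loss is exactly `(1 + k) · ln 2` for `k` negatives, as long as no
    /// negative collides with the context word:
    ///
    /// ```rust
    /// use word2vec::model::Model;
    /// let mut m = Model::new(10, 8, 0);
    /// // One positive pair (0, 1) and k = 2 negatives.
    /// let loss = m.skipgram_update(0, 1, &[2, 3], 0.025);
    /// assert!((loss - 3.0 * std::f32::consts::LN_2).abs() < 1e-5);
    /// ```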
    pub fn skipgram_update(
        &mut self,
        center: usize,
        context: usize,
        negatives: &[usize],
        lr: f32,
    ) -> f32 {
        let dim = self.embedding_dim;
        // Snapshot the center embedding once: `input_weights` is only written
        // back at the very end, so the copy stays valid for the whole update.
        let center_vec: Vec<f32> = self.input_vec(center).to_vec();
        let mut grad_input = vec![0.0f32; dim];
        let mut loss = 0.0f32;

        // Positive pair: push σ(v_o · v_c) towards 1.
        {
            let score = Self::dot(&center_vec, self.output_vec(context));
            let sig = sigmoid(score);
            let err = sig - 1.0;
            loss -= sig.ln().max(-30.0); // clamp to avoid -inf when σ ≈ 0

            for i in 0..dim {
                let idx = context * dim + i;
                // Read the output weight before overwriting it in place.
                grad_input[i] += err * self.output_weights[idx];
                self.output_weights[idx] -= lr * err * center_vec[i];
            }
        }

        // Negative pairs: push σ(v_n · v_c) towards 0.
        for &neg in negatives {
            if neg == context {
                continue;
            }
            let score = Self::dot(&center_vec, self.output_vec(neg));
            let sig = sigmoid(score);
            loss -= (1.0 - sig).ln().max(-30.0);

            for i in 0..dim {
                let idx = neg * dim + i;
                grad_input[i] += sig * self.output_weights[idx];
                self.output_weights[idx] -= lr * sig * center_vec[i];
            }
        }

        // Apply the accumulated gradient to the center word.
        for (i, &grad) in grad_input.iter().enumerate() {
            self.input_weights[center * dim + i] -= lr * grad;
        }

        loss
    }

    /// CBOW update: average context embeddings, predict center word.
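    ///
    /// Zero-initialised output weights make the first loss here exactly
    /// `(1 + k) · ln 2` as well, independent of the averaged context
    /// (indices below are illustrative):
    ///
    /// ```rust
    /// use word2vec::model::Model;
    /// let mut m = Model::new(10, 8, 0);
    /// let loss = m.cbow_update(0, &[1, 2, 3], &[4, 5], 0.025);
    /// assert!((loss - 3.0 * std::f32::consts::LN_2).abs() < 1e-5);
    /// ```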
    pub fn cbow_update(
        &mut self,
        center: usize,
        context_words: &[usize],
        negatives: &[usize],
        lr: f32,
    ) -> f32 {
        if context_words.is_empty() {
            return 0.0;
        }
        let dim = self.embedding_dim;
        let scale = 1.0 / context_words.len() as f32;

        // Average the context embeddings into a single "virtual" input vector.
        let mut ctx_avg = vec![0.0f32; dim];
        for &cidx in context_words {
            let v = self.input_vec(cidx);
            for i in 0..dim {
                ctx_avg[i] += v[i] * scale;
            }
        }

        let mut grad_ctx = vec![0.0f32; dim];
        let mut loss = 0.0f32;

        // Positive: the true center word.
        {
            let score = Self::dot(&ctx_avg, self.output_vec(center));
            let sig = sigmoid(score);
            let err = sig - 1.0;
            loss -= sig.ln().max(-30.0);
            for i in 0..dim {
                let idx = center * dim + i;
                // Read the output weight before overwriting it in place.
                grad_ctx[i] += err * self.output_weights[idx];
                self.output_weights[idx] -= lr * err * ctx_avg[i];
            }
        }

        // Negatives: sampled words that should score low against the average.
        for &neg in negatives {
            if neg == center {
                continue;
            }
            let score = Self::dot(&ctx_avg, self.output_vec(neg));
            let sig = sigmoid(score);
            loss -= (1.0 - sig).ln().max(-30.0);
            for i in 0..dim {
                let idx = neg * dim + i;
                grad_ctx[i] += sig * self.output_weights[idx];
                self.output_weights[idx] -= lr * sig * ctx_avg[i];
            }
        }

        // Distribute the gradient evenly back over the context words.
        for &cidx in context_words {
            for (i, &grad) in grad_ctx.iter().enumerate() {
                self.input_weights[cidx * dim + i] -= lr * grad * scale;
            }
        }

        loss
    }

    /// Convenience: output vector slice.
    fn output_vec(&self, idx: usize) -> &[f32] {
        let start = idx * self.embedding_dim;
        &self.output_weights[start..start + self.embedding_dim]
    }

    /// Run a full Skip-gram or CBOW update, dispatching on model type.
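    ///
    /// A usage sketch (this assumes `config` and `ModelType` are reachable
    /// from outside the crate, as the `crate::config` import suggests):
    ///
    /// ```rust
    /// use word2vec::config::ModelType;
    /// use word2vec::model::Model;
    /// let mut m = Model::new(10, 8, 0);
    /// // Skip-gram: one update per context word against the same negatives.
    /// let loss = m.update(ModelType::SkipGram, 0, &[1, 2], &[3, 4], 0.025);
    /// assert!(loss.is_finite());
    /// ```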
    pub fn update(
        &mut self,
        model_type: ModelType,
        center: usize,
        context_window: &[usize],
        negatives: &[usize],
        lr: f32,
    ) -> f32 {
        match model_type {
            ModelType::SkipGram => {
                let mut total_loss = 0.0;
                for &ctx in context_window {
                    total_loss += self.skipgram_update(center, ctx, negatives, lr);
                }
                total_loss
            }
            ModelType::Cbow => self.cbow_update(center, context_window, negatives, lr),
        }
    }
}

/// Generate training examples from a tokenised sentence.
///
/// Returns `(center_idx, context_indices)` pairs using a dynamic window
/// (window size drawn uniformly from `[1, window_size]`).
///
/// Panics if `window_size == 0`.
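///
/// A quick sketch with a fixed seed; with five tokens every position keeps a
/// non-empty context, so one pair per token comes back:
///
/// ```rust
/// use rand::{rngs::SmallRng, SeedableRng};
/// use word2vec::model::sentence_to_pairs;
/// let mut rng = SmallRng::seed_from_u64(0);
/// let pairs = sentence_to_pairs(&[10, 11, 12, 13, 14], 2, &mut rng);
/// assert_eq!(pairs.len(), 5);
/// ```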
pub fn sentence_to_pairs(
    tokens: &[usize],
    window_size: usize,
    rng: &mut SmallRng,
) -> Vec<(usize, Vec<usize>)> {
    use rand::Rng;
    let mut pairs = Vec::new();

    for (pos, &center) in tokens.iter().enumerate() {
        // Dynamic window (Mikolov et al. 2013 trick)
        let win: usize = rng.gen_range(1..=window_size);
        let start = pos.saturating_sub(win);
        let end = (pos + win + 1).min(tokens.len());

        let context: Vec<usize> = (start..end)
            .filter(|&i| i != pos)
            .map(|i| tokens[i])
            .collect();

        if !context.is_empty() {
            pairs.push((center, context));
        }
    }

    pairs
}

#[cfg(test)]
mod tests {
    use super::*;
    use rand::SeedableRng;

    #[test]
    fn model_weight_dimensions() {
        let m = Model::new(50, 10, 0);
        assert_eq!(m.input_weights.len(), 50 * 10);
        assert_eq!(m.output_weights.len(), 50 * 10);
    }

    #[test]
    fn skipgram_update_returns_finite_loss() {
        let mut m = Model::new(10, 8, 0);
        let loss = m.skipgram_update(0, 1, &[2, 3, 4], 0.025);
        assert!(loss.is_finite(), "loss={loss}");
        assert!(loss >= 0.0, "loss should be non-negative");
    }

    #[test]
    fn cbow_update_returns_finite_loss() {
        let mut m = Model::new(10, 8, 0);
        let loss = m.cbow_update(0, &[1, 2, 3], &[4, 5], 0.025);
        assert!(loss.is_finite());
    }

    #[test]
    fn loss_decreases_after_repeated_updates() {
        let mut m = Model::new(10, 16, 99);
        let first = m.skipgram_update(0, 1, &[2, 3], 0.1);
        let mut last = first;
        for _ in 0..200 {
            last = m.skipgram_update(0, 1, &[2, 3], 0.01);
        }
        assert!(
            last < first,
            "loss should decrease with repetition: {first} -> {last}"
        );
    }

    #[test]
    fn sentence_to_pairs_respects_window() {
        let mut rng = SmallRng::seed_from_u64(0);
        let tokens = vec![0, 1, 2, 3, 4];
        let pairs = sentence_to_pairs(&tokens, 2, &mut rng);
        assert!(!pairs.is_empty());
        for (_, ctx) in &pairs {
            assert!(!ctx.is_empty());
            // With window_size = 2, a context can never exceed 2 * 2 words.
            assert!(ctx.len() <= 4);
        }
    }

    #[test]
    fn sigmoid_bounds() {
        assert!((sigmoid(0.0) - 0.5).abs() < 1e-6);
        assert!(sigmoid(100.0) > 0.9999);
        assert!(sigmoid(-100.0) < 0.0001);
    }
}