1use rand::rngs::SmallRng;
20use serde::{Deserialize, Serialize};
21
22use crate::config::ModelType;
23
#[inline(always)]
fn sigmoid(x: f32) -> f32 {
    // Logistic function: maps any real input into the open interval (0, 1).
    let neg_exp = (-x).exp();
    (1.0 + neg_exp).recip()
}
29
#[inline]
fn xavier_range(dim: usize) -> f32 {
    // Half-width of the uniform Xavier/Glorot init interval: sqrt(6 / dim).
    let ratio = 6.0_f32 / dim as f32;
    ratio.sqrt()
}
35
/// A two-layer embedding model (word2vec-style): an input embedding matrix
/// and an output ("context") matrix, each stored as a flat row-major
/// `Vec<f32>` of length `vocab_size * embedding_dim`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Model {
    /// Number of rows in each weight matrix (one row per vocabulary entry).
    pub vocab_size: usize,
    /// Length of each embedding row.
    pub embedding_dim: usize,
    /// Input embeddings; row `i` occupies `[i * embedding_dim, (i + 1) * embedding_dim)`.
    pub input_weights: Vec<f32>,
    /// Output-layer weights, same flat layout as `input_weights`.
    pub output_weights: Vec<f32>,
}
47
impl Model {
    /// Builds a model with deterministic, seeded initialization.
    ///
    /// Input weights are drawn uniformly from `[-r, r)` where
    /// `r = xavier_range(embedding_dim)`; output weights start at zero
    /// (randomizing only the input layer is the usual word2vec convention).
    pub fn new(vocab_size: usize, embedding_dim: usize, seed: u64) -> Self {
        use rand::Rng;
        use rand::SeedableRng;
        let mut rng = SmallRng::seed_from_u64(seed);
        let r = xavier_range(embedding_dim);
        let n = vocab_size * embedding_dim;
        let input_weights: Vec<f32> = (0..n).map(|_| rng.gen_range(-r..r)).collect();
        let output_weights = vec![0.0_f32; n];
        Self {
            vocab_size,
            embedding_dim,
            input_weights,
            output_weights,
        }
    }

    /// Borrows the input embedding row for token `idx`.
    /// Panics (slice bounds) if `idx >= vocab_size`.
    #[inline]
    pub fn input_vec(&self, idx: usize) -> &[f32] {
        let start = idx * self.embedding_dim;
        &self.input_weights[start..start + self.embedding_dim]
    }

    /// Mutably borrows the input embedding row for token `idx`.
    #[inline]
    pub fn input_vec_mut(&mut self, idx: usize) -> &mut [f32] {
        let start = idx * self.embedding_dim;
        &mut self.input_weights[start..start + self.embedding_dim]
    }

    /// Mutably borrows the output-layer row for token `idx`.
    #[inline]
    pub fn output_vec_mut(&mut self, idx: usize) -> &mut [f32] {
        let start = idx * self.embedding_dim;
        &mut self.output_weights[start..start + self.embedding_dim]
    }

    /// Dot product of two slices; `zip` truncates to the shorter slice
    /// if the lengths ever differ.
    fn dot(a: &[f32], b: &[f32]) -> f32 {
        a.iter().zip(b.iter()).map(|(x, y)| x * y).sum()
    }

    /// One SGD step of skip-gram with negative sampling for a single
    /// (center, context) pair. Returns this step's negative-sampling
    /// loss (non-negative, thanks to the log clamp below).
    ///
    /// Negative indices equal to `context` are skipped so the true pair
    /// is never pushed down. The `.to_vec()` snapshots let gradients be
    /// computed against pre-update weights while `self` is mutated.
    pub fn skipgram_update(
        &mut self,
        center: usize,
        context: usize,
        negatives: &[usize],
        lr: f32,
    ) -> f32 {
        let dim = self.embedding_dim;
        // Gradient w.r.t. the center's input row: accumulated over the
        // positive pair and all negatives, applied once at the end.
        let mut grad_input = vec![0.0f32; dim];
        let mut loss = 0.0f32;

        {
            // Positive example: push sigmoid(center . context) toward 1.
            let score = Self::dot(self.input_vec(center), self.output_vec(context));
            let sig = sigmoid(score);
            let err = sig - 1.0;
            // -ln(sig) >= 0; the -30 clamp keeps the loss finite if sig
            // underflows to 0 (ln would be -inf).
            loss -= sig.ln().max(-30.0);

            // Snapshot both rows before mutating output_weights so the
            // update uses consistent pre-step values.
            let center_vec: Vec<f32> = self.input_vec(center).to_vec();
            let out_vec: Vec<f32> = self.output_vec(context).to_vec();

            for i in 0..dim {
                grad_input[i] += err * out_vec[i];
                self.output_weights[context * dim + i] -= lr * err * center_vec[i];
            }
        }

        // Negative examples: push sigmoid(center . neg) toward 0.
        for &neg in negatives {
            if neg == context {
                continue;
            }
            let score = Self::dot(self.input_vec(center), self.output_vec(neg));
            let sig = sigmoid(score);
            // -ln(1 - sig) >= 0, clamped like the positive term.
            loss -= (1.0 - sig).ln().max(-30.0);

            let center_vec: Vec<f32> = self.input_vec(center).to_vec();
            let neg_out: Vec<f32> = self.output_vec(neg).to_vec();

            for i in 0..dim {
                grad_input[i] += sig * neg_out[i];
                self.output_weights[neg * dim + i] -= lr * sig * center_vec[i];
            }
        }

        // Apply the accumulated input-row gradient in one pass.
        for (i, &grad) in grad_input.iter().enumerate() {
            self.input_weights[center * dim + i] -= lr * grad;
        }

        loss
    }

    /// One SGD step of CBOW with negative sampling: the averaged context
    /// embeddings predict `center`. Returns the step's loss, or 0.0 when
    /// `context_words` is empty.
    ///
    /// Negative indices equal to `center` are skipped.
    pub fn cbow_update(
        &mut self,
        center: usize,
        context_words: &[usize],
        negatives: &[usize],
        lr: f32,
    ) -> f32 {
        if context_words.is_empty() {
            return 0.0;
        }
        let dim = self.embedding_dim;
        let scale = 1.0 / context_words.len() as f32;

        // Mean of the context words' input rows.
        let mut ctx_avg = vec![0.0f32; dim];
        for &cidx in context_words {
            let v = self.input_vec(cidx);
            for i in 0..dim {
                ctx_avg[i] += v[i] * scale;
            }
        }

        // Gradient w.r.t. the averaged context vector; distributed back
        // to the individual context rows at the end.
        let mut grad_ctx = vec![0.0f32; dim];
        let mut loss = 0.0f32;

        {
            // Positive example: push sigmoid(ctx_avg . center) toward 1.
            let score: f32 = ctx_avg
                .iter()
                .zip(self.output_vec(center))
                .map(|(a, b)| a * b)
                .sum();
            let sig = sigmoid(score);
            let err = sig - 1.0;
            // Clamp keeps the loss finite if sig underflows to 0.
            loss -= sig.ln().max(-30.0);
            let out_center: Vec<f32> = self.output_vec(center).to_vec();
            for i in 0..dim {
                grad_ctx[i] += err * out_center[i];
                self.output_weights[center * dim + i] -= lr * err * ctx_avg[i];
            }
        }

        // Negative examples: push sigmoid(ctx_avg . neg) toward 0.
        for &neg in negatives {
            if neg == center {
                continue;
            }
            let score: f32 = ctx_avg
                .iter()
                .zip(self.output_vec(neg))
                .map(|(a, b)| a * b)
                .sum();
            let sig = sigmoid(score);
            loss -= (1.0 - sig).ln().max(-30.0);
            let out_neg: Vec<f32> = self.output_vec(neg).to_vec();
            for i in 0..dim {
                grad_ctx[i] += sig * out_neg[i];
                self.output_weights[neg * dim + i] -= lr * sig * ctx_avg[i];
            }
        }

        // Each context row entered the average with weight `scale`, so by
        // the chain rule it receives the shared gradient times `scale`.
        for &cidx in context_words {
            for (i, &grad) in grad_ctx.iter().enumerate() {
                self.input_weights[cidx * dim + i] -= lr * grad * scale;
            }
        }

        loss
    }

    /// Read-only counterpart of `output_vec_mut`: borrows the
    /// output-layer row for token `idx`.
    fn output_vec(&self, idx: usize) -> &[f32] {
        let start = idx * self.embedding_dim;
        &self.output_weights[start..start + self.embedding_dim]
    }

    /// Dispatches one training step by model type and returns its loss.
    /// Skip-gram runs one update per context word (losses summed);
    /// CBOW runs a single update over the whole window.
    pub fn update(
        &mut self,
        model_type: ModelType,
        center: usize,
        context_window: &[usize],
        negatives: &[usize],
        lr: f32,
    ) -> f32 {
        match model_type {
            ModelType::SkipGram => {
                let mut total_loss = 0.0;
                for &ctx in context_window {
                    total_loss += self.skipgram_update(center, ctx, negatives, lr);
                }
                total_loss
            }
            ModelType::Cbow => self.cbow_update(center, context_window, negatives, lr),
        }
    }
}
251
252pub fn sentence_to_pairs(
257 tokens: &[usize],
258 window_size: usize,
259 rng: &mut SmallRng,
260) -> Vec<(usize, Vec<usize>)> {
261 use rand::Rng;
262 let mut pairs = Vec::new();
263
264 for (pos, ¢er) in tokens.iter().enumerate() {
265 let win: usize = rng.gen_range(1..=window_size);
267 let start = pos.saturating_sub(win);
268 let end = (pos + win + 1).min(tokens.len());
269
270 let context: Vec<usize> = (start..end)
271 .filter(|&i| i != pos)
272 .map(|i| tokens[i])
273 .collect();
274
275 if !context.is_empty() {
276 pairs.push((center, context));
277 }
278 }
279
280 pairs
281}
282
#[cfg(test)]
mod tests {
    use super::*;
    use rand::SeedableRng;

    #[test]
    fn model_weight_dimensions() {
        // Both weight matrices are flat vocab_size * embedding_dim buffers.
        let model = Model::new(50, 10, 0);
        let expected_len = 50 * 10;
        assert_eq!(model.input_weights.len(), expected_len);
        assert_eq!(model.output_weights.len(), expected_len);
    }

    #[test]
    fn skipgram_update_returns_finite_loss() {
        let mut model = Model::new(10, 8, 0);
        let negatives = [2, 3, 4];
        let loss = model.skipgram_update(0, 1, &negatives, 0.025);
        assert!(loss.is_finite(), "loss={loss}");
        assert!(loss >= 0.0, "loss should be non-negative");
    }

    #[test]
    fn cbow_update_returns_finite_loss() {
        let mut model = Model::new(10, 8, 0);
        assert!(model.cbow_update(0, &[1, 2, 3], &[4, 5], 0.025).is_finite());
    }

    #[test]
    fn loss_decreases_after_repeated_updates() {
        // Repeatedly training on the same pair must drive its loss down.
        let mut model = Model::new(10, 16, 99);
        let first = model.skipgram_update(0, 1, &[2, 3], 0.1);
        let last =
            (0..200).fold(first, |_, _| model.skipgram_update(0, 1, &[2, 3], 0.01));
        assert!(
            last < first,
            "loss should decrease with repetition: {first} -> {last}"
        );
    }

    #[test]
    fn sentence_to_pairs_respects_window() {
        let mut rng = SmallRng::seed_from_u64(0);
        let tokens: Vec<usize> = (0..5).collect();
        let pairs = sentence_to_pairs(&tokens, 2, &mut rng);
        assert!(!pairs.is_empty());
        assert!(pairs.iter().all(|(_, ctx)| !ctx.is_empty()));
    }

    #[test]
    fn sigmoid_bounds() {
        assert!((sigmoid(0.0) - 0.5).abs() < 1e-6);
        assert!(sigmoid(100.0) > 0.9999);
        assert!(sigmoid(-100.0) < 0.0001);
    }
}