// word2vec/config.rs
//! Hyperparameter configuration for Word2Vec training.
//!
//! # Example
//!
//! ```rust
//! use word2vec::{Config, ModelType};
//!
//! let cfg = Config {
//!     embedding_dim: 128,
//!     window_size: 5,
//!     negative_samples: 10,
//!     epochs: 10,
//!     learning_rate: 0.025,
//!     min_learning_rate: 0.0001,
//!     min_count: 5,
//!     subsample_threshold: 1e-3,
//!     model: ModelType::SkipGram,
//!     num_threads: 4,
//!     seed: 42,
//! };
//!
//! assert_eq!(cfg.embedding_dim, 128);
//! ```

use serde::{Deserialize, Serialize};

27#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
28pub enum ModelType {
29    /// Predict context words from the center word. Better for rare words.
30    SkipGram,
31    /// Predict center word from context words. Faster; better for frequent words.
32    Cbow,
33}
34
35impl std::fmt::Display for ModelType {
36    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
37        match self {
38            ModelType::SkipGram => write!(f, "Skip-gram"),
39            ModelType::Cbow => write!(f, "CBOW"),
40        }
41    }
42}
43
44/// Full training configuration.
45///
46/// All fields have sensible defaults via [`Config::default()`].
47#[derive(Debug, Clone, Serialize, Deserialize)]
48pub struct Config {
49    /// Dimensionality of word embeddings (typical: 50–300).
50    pub embedding_dim: usize,
51    /// Half-width of the context window (words on each side of target).
52    pub window_size: usize,
53    /// Number of negative samples per positive pair (typical: 5–20).
54    pub negative_samples: usize,
55    /// Full passes over the corpus.
56    pub epochs: usize,
57    /// Initial learning rate (decays linearly to `min_learning_rate`).
58    pub learning_rate: f32,
59    /// Floor for decayed learning rate.
60    pub min_learning_rate: f32,
61    /// Discard words appearing fewer than this many times.
62    pub min_count: usize,
63    /// Frequent-word subsampling threshold (Mikolov et al. suggest 1e-3 – 1e-5).
64    pub subsample_threshold: f64,
65    /// Architecture choice.
66    pub model: ModelType,
67    /// Rayon thread count (0 = use all logical cores).
68    pub num_threads: usize,
69    /// RNG seed for reproducibility.
70    pub seed: u64,
71}
72
73impl Default for Config {
74    /// Sensible defaults matching the original word2vec paper.
75    fn default() -> Self {
76        Self {
77            embedding_dim: 100,
78            window_size: 5,
79            negative_samples: 5,
80            epochs: 5,
81            learning_rate: 0.025,
82            min_learning_rate: 0.0001,
83            min_count: 1,
84            subsample_threshold: 1e-3,
85            model: ModelType::SkipGram,
86            num_threads: 0,
87            seed: 42,
88        }
89    }
90}
91
92impl Config {
93    /// Validate configuration, returning an error message if invalid.
94    ///
95    /// ```rust
96    /// use word2vec::Config;
97    /// let cfg = Config { embedding_dim: 0, ..Config::default() };
98    /// assert!(cfg.validate().is_err());
99    /// ```
100    pub fn validate(&self) -> Result<(), String> {
101        if self.embedding_dim == 0 {
102            return Err(format!(
103                "embedding_dim must be > 0, got {}",
104                self.embedding_dim
105            ));
106        }
107        if self.window_size == 0 {
108            return Err(format!("window_size must be > 0, got {}", self.window_size));
109        }
110        if self.negative_samples == 0 {
111            return Err(format!(
112                "negative_samples must be > 0, got {}",
113                self.negative_samples
114            ));
115        }
116        if self.epochs == 0 {
117            return Err("epochs must be > 0".to_string());
118        }
119        if self.learning_rate <= 0.0 {
120            return Err(format!(
121                "learning_rate must be > 0, got {}",
122                self.learning_rate
123            ));
124        }
125        Ok(())
126    }
127}
128
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn default_config_is_valid() {
        assert!(Config::default().validate().is_ok());
    }

    #[test]
    fn zero_dim_is_invalid() {
        let cfg = Config {
            embedding_dim: 0,
            ..Config::default()
        };
        assert!(cfg.validate().is_err());
    }

    #[test]
    fn zero_window_is_invalid() {
        let cfg = Config {
            window_size: 0,
            ..Config::default()
        };
        assert!(cfg.validate().is_err());
    }

    // Previously untested validate() branches below.

    #[test]
    fn zero_negative_samples_is_invalid() {
        let cfg = Config {
            negative_samples: 0,
            ..Config::default()
        };
        assert!(cfg.validate().is_err());
    }

    #[test]
    fn zero_epochs_is_invalid() {
        let cfg = Config {
            epochs: 0,
            ..Config::default()
        };
        assert!(cfg.validate().is_err());
    }

    #[test]
    fn nonpositive_learning_rate_is_invalid() {
        let zero = Config {
            learning_rate: 0.0,
            ..Config::default()
        };
        assert!(zero.validate().is_err());
        let negative = Config {
            learning_rate: -0.01,
            ..Config::default()
        };
        assert!(negative.validate().is_err());
    }

    #[test]
    fn model_type_display() {
        assert_eq!(ModelType::SkipGram.to_string(), "Skip-gram");
        assert_eq!(ModelType::Cbow.to_string(), "CBOW");
    }
}