1use serde::{Deserialize, Serialize};
26
27#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
28pub enum ModelType {
29 SkipGram,
31 Cbow,
33}
34
35impl std::fmt::Display for ModelType {
36 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
37 match self {
38 ModelType::SkipGram => write!(f, "Skip-gram"),
39 ModelType::Cbow => write!(f, "CBOW"),
40 }
41 }
42}
43
44#[derive(Debug, Clone, Serialize, Deserialize)]
48pub struct Config {
49 pub embedding_dim: usize,
51 pub window_size: usize,
53 pub negative_samples: usize,
55 pub epochs: usize,
57 pub learning_rate: f32,
59 pub min_learning_rate: f32,
61 pub min_count: usize,
63 pub subsample_threshold: f64,
65 pub model: ModelType,
67 pub num_threads: usize,
69 pub seed: u64,
71}
72
73impl Default for Config {
74 fn default() -> Self {
76 Self {
77 embedding_dim: 100,
78 window_size: 5,
79 negative_samples: 5,
80 epochs: 5,
81 learning_rate: 0.025,
82 min_learning_rate: 0.0001,
83 min_count: 1,
84 subsample_threshold: 1e-3,
85 model: ModelType::SkipGram,
86 num_threads: 0,
87 seed: 42,
88 }
89 }
90}
91
92impl Config {
93 pub fn validate(&self) -> Result<(), String> {
101 if self.embedding_dim == 0 {
102 return Err(format!(
103 "embedding_dim must be > 0, got {}",
104 self.embedding_dim
105 ));
106 }
107 if self.window_size == 0 {
108 return Err(format!("window_size must be > 0, got {}", self.window_size));
109 }
110 if self.negative_samples == 0 {
111 return Err(format!(
112 "negative_samples must be > 0, got {}",
113 self.negative_samples
114 ));
115 }
116 if self.epochs == 0 {
117 return Err("epochs must be > 0".to_string());
118 }
119 if self.learning_rate <= 0.0 {
120 return Err(format!(
121 "learning_rate must be > 0, got {}",
122 self.learning_rate
123 ));
124 }
125 Ok(())
126 }
127}
128
// Unit tests for `Config::validate` and the `ModelType` display labels.
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs `validate` on `cfg` and returns the rejection message.
    fn rejection(cfg: Config) -> String {
        cfg.validate()
            .expect_err("configuration should have been rejected")
    }

    #[test]
    fn default_config_is_valid() {
        assert!(Config::default().validate().is_ok());
    }

    #[test]
    fn zero_dim_is_invalid() {
        let cfg = Config {
            embedding_dim: 0,
            ..Config::default()
        };
        assert!(cfg.validate().is_err());
    }

    #[test]
    fn zero_window_is_invalid() {
        let cfg = Config {
            window_size: 0,
            ..Config::default()
        };
        assert!(cfg.validate().is_err());
    }

    #[test]
    fn zero_negative_samples_is_invalid() {
        let cfg = Config {
            negative_samples: 0,
            ..Config::default()
        };
        // The message names the offending field, proving the right branch fired.
        assert!(rejection(cfg).contains("negative_samples"));
    }

    #[test]
    fn zero_epochs_is_invalid() {
        let cfg = Config {
            epochs: 0,
            ..Config::default()
        };
        assert!(rejection(cfg).contains("epochs"));
    }

    #[test]
    fn non_positive_learning_rate_is_invalid() {
        for rate in [0.0_f32, -0.1] {
            let cfg = Config {
                learning_rate: rate,
                ..Config::default()
            };
            assert!(rejection(cfg).contains("learning_rate"));
        }
    }

    #[test]
    fn model_type_display() {
        assert_eq!(ModelType::SkipGram.to_string(), "Skip-gram");
        assert_eq!(ModelType::Cbow.to_string(), "CBOW");
    }
}