diff --git a/src/cosine_annealing.rs b/src/cosine_annealing.rs
new file mode 100644
index 0000000..c130d4a
--- /dev/null
+++ b/src/cosine_annealing.rs
@@ -0,0 +1,92 @@
+use burn::{lr_scheduler::LRScheduler, LearningRate};
+#[derive(Clone, Debug)]
+pub struct CosineAnnealingLR {
+    t_max: f64,
+    eta_min: f64,
+    init_lr: LearningRate,
+    step_count: f64,
+    current_lr: LearningRate,
+}
+
+impl CosineAnnealingLR {
+    pub fn init(t_max: f64, init_lr: LearningRate) -> CosineAnnealingLR {
+        CosineAnnealingLR {
+            t_max,
+            eta_min: 0.0,
+            init_lr,
+            step_count: 0.0,
+            current_lr: init_lr,
+        }
+    }
+}
+
+impl LRScheduler for CosineAnnealingLR {
+    type Record = usize;
+
+    fn step(&mut self) -> LearningRate {
+        self.step_count += 1.0;
+        use std::f64::consts::PI;
+        fn cosine_annealing_lr(
+            init_lr: LearningRate,
+            lr: LearningRate,
+            step_count: f64,
+            t_max: f64,
+            eta_min: f64,
+        ) -> LearningRate {
+            let cosine_arg = PI * step_count / t_max;
+            if (step_count - 1.0 - t_max) % (2.0 * t_max) == 0.0 {
+                (init_lr - eta_min) * (1.0 - f64::cos(PI / t_max)) / 2.0
+            } else {
+                (1.0 + f64::cos(cosine_arg)) / (1.0 + f64::cos(PI * (step_count - 1.0) / t_max))
+                    * (lr - eta_min)
+                    + eta_min
+            }
+        }
+        self.current_lr = cosine_annealing_lr(
+            self.init_lr,
+            self.current_lr,
+            self.step_count,
+            self.t_max,
+            self.eta_min,
+        );
+        self.current_lr
+    }
+
+    fn to_record(&self) -> Self::Record {
+        self.step_count as usize
+    }
+
+    fn load_record(mut self, record: Self::Record) -> Self {
+        self.step_count = record as LearningRate;
+        self
+    }
+}
+
+#[test]
+fn test_lr_scheduler() {
+    let mut lr_scheduler = CosineAnnealingLR::init(100000.0, 1.0e-1);
+    let mut lrs = vec![];
+    for i in 0..200000 {
+        if i % 20000 == 0 {
+            lrs.push(lr_scheduler.current_lr);
+        }
+        lr_scheduler.step();
+    }
+    lrs.push(lr_scheduler.current_lr);
+    assert_eq!(
+        lrs,
+        vec![
+            0.1,
+            0.09045084971874785,
+            0.06545084971874875,
+            0.034549150281253875,
+            0.009549150281252989,
+            0.0,
+            0.009549150281252692,
+            0.03454915028125239,
+            0.06545084971874746,
+            0.09045084971874952,
+            0.10000000000000353
+        ]
+    )
+}
diff --git a/src/dataset.rs b/src/dataset.rs
index c5d2f22..1ec77c9 100644
--- a/src/dataset.rs
+++ b/src/dataset.rs
@@ -139,6 +139,14 @@ impl FSRSDataset {
         Self::new()
     }
 
+    pub fn len(&self) -> usize {
+        self.dataset.len()
+    }
+
+    pub fn is_empty(&self) -> bool {
+        self.dataset.is_empty()
+    }
+
     fn new() -> Self {
         let dataset = InMemDataset::<FSRSItem>::new(anki_to_fsrs());
         Self { dataset }
diff --git a/src/lib.rs b/src/lib.rs
index 9b4d539..f9319eb 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -1,4 +1,5 @@
 pub mod convertor;
+mod cosine_annealing;
 pub mod dataset;
 pub mod model;
 pub mod training;
diff --git a/src/training.rs b/src/training.rs
index 8fd6d00..dd9b8bf 100644
--- a/src/training.rs
+++ b/src/training.rs
@@ -1,3 +1,4 @@
+use crate::cosine_annealing::CosineAnnealingLR;
 use crate::dataset::{FSRSBatch, FSRSBatcher, FSRSDataset};
 use crate::model::{Model, ModelConfig};
 use crate::weight_clipper::weight_clipper;
@@ -127,6 +128,11 @@ pub fn train<B: ADBackend>(
         .num_workers(config.num_workers)
         .build(FSRSDataset::test());
 
+    let lr_scheduler = CosineAnnealingLR::init(
+        (FSRSDataset::train().len() * config.num_epochs) as f64,
+        config.learning_rate,
+    );
+
     let learner = LearnerBuilder::new(artifact_dir)
         // .metric_train_plot(AccuracyMetric::new())
         // .metric_valid_plot(AccuracyMetric::new())
@@ -138,7 +144,7 @@
         .build(
             config.model.init::<B>(),
             config.optimizer.init(),
-            config.learning_rate,
+            lr_scheduler,
        );
     let mut model_trained =
         learner.fit(dataloader_train, dataloader_test);
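Note on the schedule: `step()` uses the same recursive update as PyTorch's `torch.optim.lr_scheduler.CosineAnnealingLR`, which for `eta_min = 0` tracks the SGDR closed form `lr(t) = eta_min + (init_lr - eta_min) * (1 + cos(PI * t / t_max)) / 2`. The sketch below is not part of the diff (the `closed_form` helper and `main` are illustrative only); it reproduces the checkpoint values asserted in `test_lr_scheduler` from the closed form:

```rust
use std::f64::consts::PI;

/// Closed-form cosine annealing schedule (SGDR, Loshchilov & Hutter 2017).
fn closed_form(init_lr: f64, eta_min: f64, t: f64, t_max: f64) -> f64 {
    eta_min + (init_lr - eta_min) * (1.0 + f64::cos(PI * t / t_max)) / 2.0
}

fn main() {
    let (init_lr, eta_min, t_max) = (0.1, 0.0, 100_000.0);
    // Same checkpoints as the test: every 20_000 steps across two half-periods.
    for t in (0..=200_000).step_by(20_000) {
        println!("step {t:>6}: lr = {}", closed_form(init_lr, eta_min, t as f64, t_max));
    }
}
```

The printed values match the expected vector in the test up to accumulated floating-point error: the recursive form multiplies a long chain of cosine ratios, which is where drift like `0.10000000000000353` at step 200000 comes from.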
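One caveat on checkpointing: `Record` is only the step counter, so `current_lr` does not round-trip through `to_record`/`load_record`. A minimal sketch of what that means in practice, assuming it sits next to the scheduler in `src/cosine_annealing.rs` (this `resume_from_record` test is hypothetical, not part of the diff):

```rust
#[test]
fn resume_from_record() {
    let mut scheduler = CosineAnnealingLR::init(100_000.0, 0.1);
    for _ in 0..50_000 {
        scheduler.step();
    }
    // Persists step_count only; the lr itself (~0.05 at the half-period) is dropped.
    let record = scheduler.to_record();

    // After restoring, step_count is back to 50_000 but current_lr is init_lr,
    // so the next step() recurses from 0.1 rather than from ~0.05.
    let resumed = CosineAnnealingLR::init(100_000.0, 0.1).load_record(record);
    assert_eq!(resumed.to_record(), 50_000);
}
```

If exact resume mattered, `current_lr` could be recomputed from the closed form inside `load_record`; the assertion above only demonstrates what the record currently preserves.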