diff --git a/.github/workflows/check.sh b/.github/workflows/check.sh index f157768..2dfe298 100755 --- a/.github/workflows/check.sh +++ b/.github/workflows/check.sh @@ -1,7 +1,11 @@ #!/bin/bash +set -e + cargo fmt --check || ( echo echo "Please run 'cargo fmt' to format the code." exit 1 ) + +cargo clippy -- -Dwarnings diff --git a/src/convertor.rs b/src/convertor.rs index 817c701..6d60d0c 100644 --- a/src/convertor.rs +++ b/src/convertor.rs @@ -9,12 +9,8 @@ use crate::dataset::{FSRSItem, Review}; struct RevlogEntry { id: i64, cid: i64, - usn: i64, button_chosen: i32, - interval: i64, - last_interval: i64, ease_factor: i64, - taken_millis: i64, review_kind: i64, delta_t: i32, i: usize, @@ -26,13 +22,9 @@ fn row_to_revlog_entry(row: &Row) -> Result { Ok(RevlogEntry { id: row.get(0)?, cid: row.get(1)?, - usn: row.get(2)?, - button_chosen: row.get(3)?, - interval: row.get(4)?, - last_interval: row.get(5)?, - ease_factor: row.get(6)?, - taken_millis: row.get(7).unwrap_or_default(), - review_kind: row.get(8).unwrap_or_default(), + button_chosen: row.get(2)?, + ease_factor: row.get(3)?, + review_kind: row.get(4).unwrap_or_default(), delta_t: 0, i: 0, r_history: vec![], @@ -66,7 +58,7 @@ fn read_collection() -> Vec { let current_timestamp = Utc::now().timestamp() * 1000; let query = format!( - "SELECT * + "SELECT id, cid, ease, factor, type FROM revlog WHERE (type != 4 OR ivl <= 0) AND id < {} @@ -100,7 +92,7 @@ fn group_by_cid(revlogs: Vec) -> Vec> { .push(revlog); } - grouped.into_iter().map(|(_, v)| v).collect() + grouped.into_values().collect() } fn convert_to_date(timestamp: i64, next_day_starts_at: i64, timezone: Tz) -> chrono::NaiveDate { @@ -232,8 +224,8 @@ pub fn anki_to_fsrs() -> Vec { .collect(); let filtered_revlogs_per_card = remove_non_learning_first(extracted_revlogs_per_card); - let fsrs_items = convert_to_fsrs_items(filtered_revlogs_per_card); - fsrs_items + + convert_to_fsrs_items(filtered_revlogs_per_card) } #[test] diff --git a/src/dataset.rs b/src/dataset.rs index c577b90..9fd1793 100644 --- a/src/dataset.rs +++ b/src/dataset.rs @@ -142,7 +142,7 @@ fn test_from_json() { use burn_ndarray::NdArrayDevice; let device = NdArrayDevice::Cpu; type Backend = NdArrayBackend; - let batcher = FSRSBatcher::::new(device.clone()); + let batcher = FSRSBatcher::::new(device); let dataloader = DataLoaderBuilder::new(batcher) .batch_size(1) .shuffle(42) @@ -167,7 +167,7 @@ fn test_from_anki() { let device = NdArrayDevice::Cpu; use burn_ndarray::NdArrayBackend; type Backend = NdArrayBackend; - let batcher = FSRSBatcher::::new(device.clone()); + let batcher = FSRSBatcher::::new(device); use burn::data::dataloader::DataLoaderBuilder; let dataloader = DataLoaderBuilder::new(batcher) .batch_size(1) diff --git a/src/model.rs b/src/model.rs index cf09454..0d9660d 100644 --- a/src/model.rs +++ b/src/model.rs @@ -10,6 +10,7 @@ pub struct Model { } impl> Model { + #[allow(clippy::new_without_default)] pub fn new() -> Self { Self { w: Param::from(Tensor::from_floats([ @@ -27,8 +28,7 @@ impl> Model { } pub fn power_forgetting_curve(&self, t: Tensor, s: Tensor) -> Tensor { - let retrievability = (t / (s * 9) + 1).powf(-1.0); - retrievability + (t / (s * 9) + 1).powf(-1.0) } fn stability_after_success( @@ -41,17 +41,17 @@ impl> Model { let batch_size = rating.dims()[0]; let hard_penalty = Tensor::ones([batch_size]) .mask_where(rating.clone().equal_elem(2), self.w().slice([15..16])); - let easy_bonus = Tensor::ones([batch_size]) - .mask_where(rating.clone().equal_elem(4), self.w().slice([16..17])); - let new_s = last_s.clone() + let easy_bonus = + Tensor::ones([batch_size]).mask_where(rating.equal_elem(4), self.w().slice([16..17])); + + last_s.clone() * (self.w().slice([8..9]).exp() * (-new_d + 11) * (-self.w().slice([9..10]) * last_s.log()).exp() * (((-r + 1) * self.w().slice([10..11])).exp() - 1) * hard_penalty * easy_bonus - + 1); - new_s + + 1) } fn stability_after_failure( @@ -60,11 +60,10 @@ impl> Model { new_d: Tensor, r: Tensor, ) -> Tensor { - let new_s = self.w().slice([11..12]) + self.w().slice([11..12]) * (-self.w().slice([12..13]) * new_d.log()).exp() * ((self.w().slice([13..14]) * (last_s + 1).log()).exp() - 1) - * ((-r + 1) * self.w().slice([14..15])).exp(); - new_s + * ((-r + 1) * self.w().slice([14..15])).exp() } fn step( @@ -82,7 +81,7 @@ impl> Model { } else { let r = self.power_forgetting_curve(delta_t, stability.clone()); // dbg!(&r); - let new_d = difficulty.clone() - self.w().slice([6..7]) * (rating.clone() - 3); + let new_d = difficulty - self.w().slice([6..7]) * (rating.clone() - 3); let new_d = new_d.clamp(1.0, 10.0); // dbg!(&new_d); let s_recall = self.stability_after_success( diff --git a/src/training.rs b/src/training.rs index 2fb2120..bea993f 100644 --- a/src/training.rs +++ b/src/training.rs @@ -152,6 +152,6 @@ fn test() { train::( artifact_dir, TrainingConfig::new(ModelConfig::new(), AdamConfig::new()), - device.clone(), + device, ); } diff --git a/src/weight_clipper.rs b/src/weight_clipper.rs index 1eac7e9..51bc62c 100644 --- a/src/weight_clipper.rs +++ b/src/weight_clipper.rs @@ -21,7 +21,7 @@ pub fn weight_clipper>(weights: Tensor) -> Ten let val: &mut Vec = &mut weights.to_data().value; for (i, w) in val.iter_mut().skip(4).enumerate() { - *w = w.clamp(CLAMPS[i].0.into(), CLAMPS[i].1.into()); + *w = w.clamp(CLAMPS[i].0, CLAMPS[i].1); } Tensor::from_data(Data::new(val.clone(), weights.shape()))