Skip to content

Commit

Permalink
Add clippy to CI checks (#12)
Browse files Browse the repository at this point in the history
* Remove unused fields from RevlogEntry

* Apply clippy suggestions

* Suppress a clippy warning

* Add clippy to CI checks
  • Loading branch information
dae authored Aug 21, 2023
1 parent 306e924 commit 7e9fb86
Show file tree
Hide file tree
Showing 6 changed files with 25 additions and 30 deletions.
4 changes: 4 additions & 0 deletions .github/workflows/check.sh
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
#!/bin/bash

set -e

cargo fmt --check || (
echo
echo "Please run 'cargo fmt' to format the code."
exit 1
)

cargo clippy -- -Dwarnings
22 changes: 7 additions & 15 deletions src/convertor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,8 @@ use crate::dataset::{FSRSItem, Review};
struct RevlogEntry {
id: i64,
cid: i64,
usn: i64,
button_chosen: i32,
interval: i64,
last_interval: i64,
ease_factor: i64,
taken_millis: i64,
review_kind: i64,
delta_t: i32,
i: usize,
Expand All @@ -26,13 +22,9 @@ fn row_to_revlog_entry(row: &Row) -> Result<RevlogEntry> {
Ok(RevlogEntry {
id: row.get(0)?,
cid: row.get(1)?,
usn: row.get(2)?,
button_chosen: row.get(3)?,
interval: row.get(4)?,
last_interval: row.get(5)?,
ease_factor: row.get(6)?,
taken_millis: row.get(7).unwrap_or_default(),
review_kind: row.get(8).unwrap_or_default(),
button_chosen: row.get(2)?,
ease_factor: row.get(3)?,
review_kind: row.get(4).unwrap_or_default(),
delta_t: 0,
i: 0,
r_history: vec![],
Expand Down Expand Up @@ -66,7 +58,7 @@ fn read_collection() -> Vec<RevlogEntry> {
let current_timestamp = Utc::now().timestamp() * 1000;

let query = format!(
"SELECT *
"SELECT id, cid, ease, factor, type
FROM revlog
WHERE (type != 4 OR ivl <= 0)
AND id < {}
Expand Down Expand Up @@ -100,7 +92,7 @@ fn group_by_cid(revlogs: Vec<RevlogEntry>) -> Vec<Vec<RevlogEntry>> {
.push(revlog);
}

grouped.into_iter().map(|(_, v)| v).collect()
grouped.into_values().collect()
}

fn convert_to_date(timestamp: i64, next_day_starts_at: i64, timezone: Tz) -> chrono::NaiveDate {
Expand Down Expand Up @@ -232,8 +224,8 @@ pub fn anki_to_fsrs() -> Vec<FSRSItem> {
.collect();

let filtered_revlogs_per_card = remove_non_learning_first(extracted_revlogs_per_card);
let fsrs_items = convert_to_fsrs_items(filtered_revlogs_per_card);
fsrs_items

convert_to_fsrs_items(filtered_revlogs_per_card)
}

#[test]
Expand Down
4 changes: 2 additions & 2 deletions src/dataset.rs
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ fn test_from_json() {
use burn_ndarray::NdArrayDevice;
let device = NdArrayDevice::Cpu;
type Backend = NdArrayBackend<f32>;
let batcher = FSRSBatcher::<Backend>::new(device.clone());
let batcher = FSRSBatcher::<Backend>::new(device);
let dataloader = DataLoaderBuilder::new(batcher)
.batch_size(1)
.shuffle(42)
Expand All @@ -167,7 +167,7 @@ fn test_from_anki() {
let device = NdArrayDevice::Cpu;
use burn_ndarray::NdArrayBackend;
type Backend = NdArrayBackend<f32>;
let batcher = FSRSBatcher::<Backend>::new(device.clone());
let batcher = FSRSBatcher::<Backend>::new(device);
use burn::data::dataloader::DataLoaderBuilder;
let dataloader = DataLoaderBuilder::new(batcher)
.batch_size(1)
Expand Down
21 changes: 10 additions & 11 deletions src/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ pub struct Model<B: Backend> {
}

impl<B: Backend<FloatElem = f32>> Model<B> {
#[allow(clippy::new_without_default)]
pub fn new() -> Self {
Self {
w: Param::from(Tensor::from_floats([
Expand All @@ -27,8 +28,7 @@ impl<B: Backend<FloatElem = f32>> Model<B> {
}

pub fn power_forgetting_curve(&self, t: Tensor<B, 1>, s: Tensor<B, 1>) -> Tensor<B, 1> {
let retrievability = (t / (s * 9) + 1).powf(-1.0);
retrievability
(t / (s * 9) + 1).powf(-1.0)
}

fn stability_after_success(
Expand All @@ -41,17 +41,17 @@ impl<B: Backend<FloatElem = f32>> Model<B> {
let batch_size = rating.dims()[0];
let hard_penalty = Tensor::ones([batch_size])
.mask_where(rating.clone().equal_elem(2), self.w().slice([15..16]));
let easy_bonus = Tensor::ones([batch_size])
.mask_where(rating.clone().equal_elem(4), self.w().slice([16..17]));
let new_s = last_s.clone()
let easy_bonus =
Tensor::ones([batch_size]).mask_where(rating.equal_elem(4), self.w().slice([16..17]));

last_s.clone()
* (self.w().slice([8..9]).exp()
* (-new_d + 11)
* (-self.w().slice([9..10]) * last_s.log()).exp()
* (((-r + 1) * self.w().slice([10..11])).exp() - 1)
* hard_penalty
* easy_bonus
+ 1);
new_s
+ 1)
}

fn stability_after_failure(
Expand All @@ -60,11 +60,10 @@ impl<B: Backend<FloatElem = f32>> Model<B> {
new_d: Tensor<B, 1>,
r: Tensor<B, 1>,
) -> Tensor<B, 1> {
let new_s = self.w().slice([11..12])
self.w().slice([11..12])
* (-self.w().slice([12..13]) * new_d.log()).exp()
* ((self.w().slice([13..14]) * (last_s + 1).log()).exp() - 1)
* ((-r + 1) * self.w().slice([14..15])).exp();
new_s
* ((-r + 1) * self.w().slice([14..15])).exp()
}

fn step(
Expand All @@ -82,7 +81,7 @@ impl<B: Backend<FloatElem = f32>> Model<B> {
} else {
let r = self.power_forgetting_curve(delta_t, stability.clone());
// dbg!(&r);
let new_d = difficulty.clone() - self.w().slice([6..7]) * (rating.clone() - 3);
let new_d = difficulty - self.w().slice([6..7]) * (rating.clone() - 3);
let new_d = new_d.clamp(1.0, 10.0);
// dbg!(&new_d);
let s_recall = self.stability_after_success(
Expand Down
2 changes: 1 addition & 1 deletion src/training.rs
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,6 @@ fn test() {
train::<AutodiffBackend>(
artifact_dir,
TrainingConfig::new(ModelConfig::new(), AdamConfig::new()),
device.clone(),
device,
);
}
2 changes: 1 addition & 1 deletion src/weight_clipper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ pub fn weight_clipper<B: Backend<FloatElem = f32>>(weights: Tensor<B, 1>) -> Ten
let val: &mut Vec<f32> = &mut weights.to_data().value;

for (i, w) in val.iter_mut().skip(4).enumerate() {
*w = w.clamp(CLAMPS[i].0.into(), CLAMPS[i].1.into());
*w = w.clamp(CLAMPS[i].0, CLAMPS[i].1);
}

Tensor::from_data(Data::new(val.clone(), weights.shape()))
Expand Down

0 comments on commit 7e9fb86

Please sign in to comment.