Skip to content

Commit

Permalink
Set mean reversion target to D0(4) (#124)
Browse files Browse the repository at this point in the history
* Fix/set mean reversion target to D0(4)

Supersedes and closes #123

* Update fsrs_optimizer.py

* fix X not defined

* Update fsrs_simulator.py

* decouple init_d_with_short_term

* bump version

---------

Co-authored-by: Jarrett Ye <[email protected]>
  • Loading branch information
user1823 and L-M-Sherlock authored Jul 27, 2024
1 parent 86eeb69 commit 66d4695
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 7 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "FSRS-Optimizer"
version = "5.0.2"
version = "5.0.3"
readme = "README.md"
dependencies = [
"matplotlib>=3.7.0",
Expand Down
8 changes: 6 additions & 2 deletions src/fsrs_optimizer/fsrs_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,9 +95,13 @@ def stability_short_term(self, state: Tensor, rating: Tensor) -> Tensor:
new_s = state[:, 0] * torch.exp(self.w[17] * (rating - 3 + self.w[18]))
return new_s

def init_d(self, rating: Tensor) -> Tensor:
new_d = self.w[4] - torch.exp(self.w[5] * (rating - 1)) + 1
return new_d

def next_d(self, state: Tensor, rating: Tensor) -> Tensor:
new_d = state[:, 1] - self.w[6] * (rating - 3)
new_d = self.mean_reversion(self.w[4], new_d)
new_d = self.mean_reversion(self.init_d(4), new_d)
return new_d

def step(self, X: Tensor, state: Tensor) -> Tensor:
Expand All @@ -113,7 +117,7 @@ def step(self, X: Tensor, state: Tensor) -> Tensor:
# first learn, init memory states
new_s = torch.ones_like(state[:, 0])
new_s[index[0]] = self.w[index[1]]
new_d = self.w[4] - torch.exp(self.w[5] * (X[:, 1] - 1)) + 1
new_d = self.init_d(X[:, 1])
new_d = new_d.clamp(1, 10)
else:
r = power_forgetting_curve(X[:, 0], state[:, 0])
Expand Down
10 changes: 6 additions & 4 deletions src/fsrs_optimizer/fsrs_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,14 +117,16 @@ def stability_short_term(s, init_rating=None):
return new_s

def init_d(rating):
new_d = w[4] - np.exp(w[5] * (rating - 1)) + 1
return w[4] - np.exp(w[5] * (rating - 1)) + 1

def init_d_with_short_term(rating):
rating_offset = np.choose(rating - 1, first_rating_offset)
new_d -= w[6] * rating_offset
new_d = init_d(rating) - w[6] * rating_offset
return np.clip(new_d, 1, 10)

def next_d(d, rating):
new_d = d - w[6] * (rating - 3)
new_d = mean_reversion(w[4], new_d)
new_d = mean_reversion(init_d(4), new_d)
return np.clip(new_d, 1, 10)

def mean_reversion(init, current):
Expand Down Expand Up @@ -202,7 +204,7 @@ def mean_reversion(init, current):
card_table[col["stability"]][true_learn],
init_rating=card_table[col["rating"]][true_learn].astype(int),
)
card_table[col["difficulty"]][true_learn] = init_d(
card_table[col["difficulty"]][true_learn] = init_d_with_short_term(
card_table[col["rating"]][true_learn].astype(int)
)

Expand Down

0 comments on commit 66d4695

Please sign in to comment.