Skip to content

Commit

Permalink
Feat/not sqrt in weights for pretrain (#106)
Browse files Browse the repository at this point in the history
  • Loading branch information
L-M-Sherlock authored Apr 20, 2024
1 parent 6aa6fd0 commit 37ece5e
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 43 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "FSRS-Optimizer"
version = "4.27.7"
version = "4.28.0"
readme = "README.md"
dependencies = [
"matplotlib>=3.7.0",
Expand Down
77 changes: 35 additions & 42 deletions src/fsrs_optimizer/fsrs_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,23 +48,23 @@
Relearning = 3

DEFAULT_WEIGHT = [
0.5701,
1.4436,
4.1386,
10.9355,
5.1443,
1.2006,
0.8627,
0.0362,
1.629,
0.1342,
1.0166,
2.1174,
0.0839,
0.3204,
1.4676,
0.219,
2.8237,
0.4872,
1.4003,
3.7145,
13.8206,
5.1618,
1.2298,
0.8975,
0.031,
1.6474,
0.1367,
1.0461,
2.1072,
0.0793,
0.3246,
1.587,
0.2272,
2.8755,
]

S_MIN = 0.01
Expand Down Expand Up @@ -443,6 +443,18 @@ def remove_non_continuous_rows(group):
return group.loc[: first_non_continuous_index - 1]


def fit_stability(delta_t, retention, size):
    """Fit a single memory-stability value to observed retention data.

    Minimizes the count-weighted negative log-likelihood (binary
    cross-entropy) between the retention predicted by the power
    forgetting curve and the observed retention, treating `size` as
    the number of observations behind each data point.

    Parameters
    ----------
    delta_t : array-like
        Elapsed time since the last review for each data point
        (presumably in days — units follow `power_forgetting_curve`).
    retention : array-like
        Observed recall ratio for each data point, expected in (0, 1)
        so that the log terms stay finite.
    size : array-like
        Review count per data point; used as the likelihood weight so
        that larger groups contribute proportionally more.

    Returns
    -------
    float
        The fitted stability, constrained to [S_MIN, 36500].
    """

    # Renamed from `loss` to avoid shadowing: the original inner
    # variable `loss` reused the nested function's own name.
    def neg_log_likelihood(stability):
        y_pred = power_forgetting_curve(delta_t, stability)
        # Weighted binary cross-entropy; np.sum keeps the reduction
        # vectorized instead of iterating with the builtin `sum`.
        return np.sum(
            -(retention * np.log(y_pred) + (1 - retention) * np.log(1 - y_pred))
            * size
        )

    # x0=1 day is a neutral starting point; the upper bound (~100
    # years) matches the rest of the optimizer's stability range.
    res = minimize(neg_log_likelihood, x0=1, bounds=[(S_MIN, 36500)])
    return res.x[0]


class Optimizer:
def __init__(self) -> None:
tqdm.pandas()
Expand Down Expand Up @@ -707,12 +719,9 @@ def cal_stability(group: pd.DataFrame) -> pd.DataFrame:
group["group_cnt"] = group_cnt
if group["i"].values[0] > 1:
group["stability"] = round(
curve_fit(
power_forgetting_curve,
group["delta_t"],
group["retention"],
sigma=1 / np.sqrt(group["total_cnt"]),
)[0][0],
fit_stability(
group["delta_t"], group["retention"], group["total_cnt"]
),
1,
)
else:
Expand Down Expand Up @@ -820,15 +829,14 @@ def pretrain(self, dataset=None, verbose=True):
group["y"]["count"] + 1
)
count = group["y"]["count"]
weight = np.sqrt(count)

init_s0 = r_s0_default[first_rating]

def loss(stability):
y_pred = power_forgetting_curve(delta_t, stability)
logloss = sum(
-(recall * np.log(y_pred) + (1 - recall) * np.log(1 - y_pred))
* weight
* count
)
l1 = np.abs(stability - init_s0) / 16
return logloss + l1
Expand All @@ -837,7 +845,7 @@ def loss(stability):
loss,
x0=init_s0,
bounds=((S_MIN, 100),),
options={"maxiter": int(sum(weight))},
options={"maxiter": int(sum(count))},
)
params = res.x
stability = params[0]
Expand Down Expand Up @@ -1494,22 +1502,7 @@ def cal_stability(tmp):
count = tmp["y_count"]
total_count = sum(count)

def loss(stability):
y_pred = power_forgetting_curve(delta_t, stability)
logloss = sum(
-(
recall * np.log(y_pred)
+ (1 - recall) * np.log(1 - y_pred)
)
* np.sqrt(count)
)
return logloss

res = minimize(loss, 1, bounds=((S_MIN, 3650),))
if res.success:
tmp["true_s"] = res.x[0]
else:
tmp["true_s"] = np.nan
tmp["true_s"] = fit_stability(delta_t, recall, count)
tmp["predicted_s"] = np.average(
tmp["stability_mean"], weights=count
)
Expand Down

0 comments on commit 37ece5e

Please sign in to comment.