Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fine-tune the sample size for CMRR.py #136

Merged
merged 16 commits into from
Sep 9, 2024
40 changes: 26 additions & 14 deletions src/fsrs_optimizer/fsrs_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,12 +263,19 @@ def sample(
workload_only=False,
):
results = []
if learn_span < 100:
SAMPLE_SIZE = 16
elif learn_span < 365:
SAMPLE_SIZE = 8
else:
SAMPLE_SIZE = 4

def best_sample_size(days_to_simulate):
    """Return the number of simulation runs for a span of the given length.

    Spans of 30 days or fewer use 45 runs and spans of 365 days or more
    use 4. In between, a base size of 4 is divided by a quadratic in the
    number of days and the quotient is truncated to an integer.
    """
    if days_to_simulate <= 30:
        return 45
    if days_to_simulate >= 365:
        return 4

    # Quadratic, linear and constant coefficients of the divisor.
    quad, lin, const = 8.20e-07, 2.41e-03, 1.30e-02
    base_size = 4
    divisor = quad * np.power(days_to_simulate, 2) + lin * days_to_simulate + const
    return int(base_size / divisor)

SAMPLE_SIZE = best_sample_size(learn_span)

for i in range(SAMPLE_SIZE):
_, _, _, memorized_cnt_per_day, cost_per_day = simulate(
Expand Down Expand Up @@ -422,6 +429,7 @@ def workload_graph(default_params, sampling_size=30):
default_params["deck_size"] / default_params["learn_span"]
)
default_params["review_limit_perday"] = math.inf
default_params["loss_aversion"] = 1
workload = [sample(r=r, workload_only=True, **default_params) for r in R]

# this is for testing
Expand Down Expand Up @@ -513,10 +521,14 @@ def workload_graph(default_params, sampling_size=30):
ax.xaxis.set_tick_params(labelsize=14)
ax.set_xlim(0.7, 0.99)

if max_w >= 4.5 * min_w:
lim = 4.5 * min_w
elif max_w >= 3.5 * min_w:
if max_w >= 3.5 * min_w:
lim = 3.5 * min_w
elif max_w >= 3 * min_w:
lim = 3 * min_w
elif max_w >= 2.5 * min_w:
lim = 2.5 * min_w
elif max_w >= 2 * min_w:
lim = 2 * min_w
else:
lim = 1.1 * max_w

Expand All @@ -527,13 +539,13 @@ def workload_graph(default_params, sampling_size=30):
ax.text(
0.701,
min_w,
"min. workload",
"minimum workload",
ha="left",
va="bottom",
color="black",
fontsize=12,
)
if max_w >= 1.8 * min_w:
if lim >= 1.8 * min_w:
ax.axhline(y=1.5 * min_w, color="black", alpha=0.75, ls="--")
ax.text(
0.701,
Expand All @@ -544,7 +556,7 @@ def workload_graph(default_params, sampling_size=30):
color="black",
fontsize=12,
)
if max_w >= 2.3 * min_w:
if lim >= 2.3 * min_w:
ax.axhline(y=2 * min_w, color="black", alpha=0.75, ls="--")
ax.text(
0.701,
Expand All @@ -555,7 +567,7 @@ def workload_graph(default_params, sampling_size=30):
color="black",
fontsize=12,
)
if max_w >= 2.8 * min_w:
if lim >= 2.8 * min_w:
ax.axhline(y=2.5 * min_w, color="black", alpha=0.75, ls="--")
ax.text(
0.701,
Expand All @@ -566,7 +578,7 @@ def workload_graph(default_params, sampling_size=30):
color="black",
fontsize=12,
)
if max_w >= 3.3 * min_w:
if lim >= 3.3 * min_w:
ax.axhline(y=3 * min_w, color="black", alpha=0.75, ls="--")
ax.text(
0.701,
Expand Down