Feat/speed up dataloader (#84)
L-M-Sherlock authored Feb 25, 2024
1 parent 1750cd3 commit be89d5c
Showing 2 changed files with 68 additions and 92 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "FSRS-Optimizer"
-version = "4.24.2"
+version = "4.25.0"
 readme = "README.md"
 dependencies = [
     "matplotlib>=3.7.0",
158 changes: 67 additions & 91 deletions src/fsrs_optimizer/fsrs_optimizer.py
@@ -14,8 +14,8 @@
 import torch
 from torch import nn
 from torch import Tensor
-from torch.utils.data import Dataset, DataLoader, Sampler
-from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence
+from torch.utils.data import Dataset
+from torch.nn.utils.rnn import pad_sequence
 from sklearn.model_selection import StratifiedGroupKFold, TimeSeriesSplit
 from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
 from scipy.optimize import minimize
@@ -178,80 +178,64 @@ def lineToTensor(line: str) -> Tensor:
     return tensor
 
 
-class RevlogDataset(Dataset):
-    def __init__(self, dataframe: pd.DataFrame):
+class BatchDataset(Dataset):
+    def __init__(self, dataframe: pd.DataFrame, batch_size: int = 0):
         if dataframe.empty:
             raise ValueError("Training data is inadequate.")
-        padded = pad_sequence(
+        if batch_size > 0:
+            dataframe = dataframe.sort_values(by=["i"])
+        self.x_train = pad_sequence(
             dataframe["tensor"].to_list(), batch_first=True, padding_value=0
         )
-        self.x_train = padded.int()
         self.t_train = torch.tensor(dataframe["delta_t"].values, dtype=torch.int)
         self.y_train = torch.tensor(dataframe["y"].values, dtype=torch.float)
         self.seq_len = torch.tensor(
             dataframe["tensor"].map(len).values, dtype=torch.long
         )
+        length = len(dataframe)
+        batch_num, remainder = divmod(length, max(1, batch_size))
+        self.batch_num = batch_num + 1 if remainder > 0 else batch_num
+        self.batches = [None] * self.batch_num
+        if batch_size > 0:
+            for i in range(self.batch_num):
+                start_index = i * batch_size
+                end_index = min((i + 1) * batch_size, length)
+                sequences = self.x_train[start_index:end_index]
+                seq_lens = self.seq_len[start_index:end_index]
+                max_len = max(seq_lens)
+                sequences_truncated = sequences[:, :max_len]
+                self.batches[i] = (
+                    sequences_truncated.transpose(0, 1),
+                    self.t_train[start_index:end_index],
+                    self.y_train[start_index:end_index],
+                    seq_lens,
+                )
 
     def __getitem__(self, idx):
-        return (
-            self.x_train[idx],
-            self.t_train[idx],
-            self.y_train[idx],
-            self.seq_len[idx],
-        )
+        return self.batches[idx]
 
     def __len__(self):
-        return len(self.y_train)
+        return self.batch_num


-class RevlogSampler(Sampler[List[int]]):
-    def __init__(self, data_source: RevlogDataset, batch_size: int):
-        self.data_source = data_source
-        self.batch_size = batch_size
-        lengths = np.array(data_source.seq_len)
-        indices = np.argsort(lengths)
-        full_batches, remainder = divmod(indices.size, self.batch_size)
-        if full_batches > 0:
-            if remainder == 0:
-                self.batch_indices = np.split(indices, full_batches)
-            else:
-                self.batch_indices = np.split(indices[:-remainder], full_batches)
-        else:
-            self.batch_indices = []
-        if remainder > 0:
-            self.batch_indices.append(indices[-remainder:])
-        self.batch_nums = len(self.batch_indices)
-        # seed = int(torch.empty((), dtype=torch.int64).random_().item())
+class BatchLoader:
+    def __init__(self, dataset: BatchDataset):
+        self.dataset = dataset
+        self.batch_nums = len(dataset.batches)
         seed = 2023
         self.generator = torch.Generator()
         self.generator.manual_seed(seed)
 
     def __iter__(self):
         yield from (
-            self.batch_indices[idx]
+            self.dataset[idx]
             for idx in torch.randperm(
                 self.batch_nums, generator=self.generator
             ).tolist()
         )
 
     def __len__(self):
-        return len(self.data_source)
-
-
-def collate_fn(batch):
-    sequences, delta_ts, labels, seq_lens = zip(*batch)
-    sequences_packed = pack_padded_sequence(
-        torch.stack(sequences, dim=1),
-        lengths=torch.stack(seq_lens),
-        batch_first=False,
-        enforce_sorted=False,
-    )
-    sequences_padded, length = pad_packed_sequence(sequences_packed, batch_first=False)
-    sequences_padded = torch.as_tensor(sequences_padded)
-    seq_lens = torch.as_tensor(length)
-    delta_ts = torch.as_tensor(delta_ts)
-    labels = torch.as_tensor(labels)
-    return sequences_padded, delta_ts, labels, seq_lens
+        return self.batch_nums


 class Trainer:
@@ -270,7 +254,7 @@ def __init__(
         self.batch_size = batch_size
         self.build_dataset(train_set, test_set)
         self.n_epoch = n_epoch
-        self.batch_nums = self.next_train_data_loader.batch_sampler.batch_nums
+        self.batch_nums = self.next_train_data_loader.batch_nums
         self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
             self.optimizer, T_max=self.batch_nums * n_epoch
         )
@@ -280,36 +264,24 @@ def __init__(

     def build_dataset(self, train_set: pd.DataFrame, test_set: pd.DataFrame):
         pre_train_set = train_set[train_set["i"] == 2]
-        self.pre_train_set = RevlogDataset(pre_train_set)
-        sampler = RevlogSampler(self.pre_train_set, batch_size=self.batch_size)
-        self.pre_train_data_loader = DataLoader(
-            self.pre_train_set, batch_sampler=sampler, collate_fn=collate_fn
-        )
+        self.pre_train_set = BatchDataset(pre_train_set, batch_size=self.batch_size)
+        self.pre_train_data_loader = BatchLoader(self.pre_train_set)
 
         next_train_set = train_set[train_set["i"] > 2]
-        self.next_train_set = RevlogDataset(next_train_set)
-        sampler = RevlogSampler(self.next_train_set, batch_size=self.batch_size)
-        self.next_train_data_loader = DataLoader(
-            self.next_train_set, batch_sampler=sampler, collate_fn=collate_fn
-        )
+        self.next_train_set = BatchDataset(next_train_set, batch_size=self.batch_size)
+        self.next_train_data_loader = BatchLoader(self.next_train_set)
 
-        self.train_set = RevlogDataset(train_set)
-        sampler = RevlogSampler(self.train_set, batch_size=self.batch_size)
-        self.train_data_loader = DataLoader(
-            self.train_set, batch_sampler=sampler, collate_fn=collate_fn
-        )
+        self.train_set = BatchDataset(train_set, batch_size=self.batch_size)
+        self.train_data_loader = BatchLoader(self.train_set)
 
-        self.test_set = RevlogDataset(test_set)
-        sampler = RevlogSampler(self.test_set, batch_size=self.batch_size)
-        self.test_data_loader = DataLoader(
-            self.test_set, batch_sampler=sampler, collate_fn=collate_fn
-        )
+        self.test_set = BatchDataset(test_set, batch_size=self.batch_size)
+        self.test_data_loader = BatchLoader(self.test_set)
         tqdm.write("dataset built")
 
     def train(self, verbose: bool = True):
         self.verbose = verbose
         best_loss = np.inf
-        epoch_len = len(self.next_train_data_loader)
+        epoch_len = len(self.next_train_set.y_train)
         if verbose:
             pbar = tqdm(desc="train", colour="red", total=epoch_len * self.n_epoch)
         print_len = max(self.batch_nums * self.n_epoch // 10, 1)
@@ -423,7 +395,7 @@ def predict(self, t_history: str, r_history: str):
         return output_t[-1][0]
 
     def batch_predict(self, dataset):
-        fast_dataset = RevlogDataset(dataset)
+        fast_dataset = BatchDataset(dataset)
         with torch.no_grad():
             outputs, _ = self.model(fast_dataset.x_train.transpose(0, 1))
             stabilities, difficulties = outputs[
@@ -1120,11 +1092,13 @@ def preview(self, requestRetention: float, verbose=False):
             "interval history: "
             + ",".join(
                 [
-                    f"{ivl}d"
-                    if ivl < 30
-                    else f"{ivl / 30:.1f}m"
-                    if ivl < 365
-                    else f"{ivl / 365:.1f}y"
+                    (
+                        f"{ivl}d"
+                        if ivl < 30
+                        else (
+                            f"{ivl / 30:.1f}m" if ivl < 365 else f"{ivl / 365:.1f}y"
+                        )
+                    )
                     for ivl in map(int, t_history.split(","))
                 ]
             )
@@ -1135,9 +1109,11 @@ def preview(self, requestRetention: float, verbose=False):
             + ",".join(
                 ["0.0"]
                 + [
-                    f"{float(ivl) / float(pre_ivl):.2f}"
-                    if pre_ivl != "0"
-                    else "0.0"
+                    (
+                        f"{float(ivl) / float(pre_ivl):.2f}"
+                        if pre_ivl != "0"
+                        else "0.0"
+                    )
                     for ivl, pre_ivl in zip(
                         t_history.split(",")[1:],
                         t_history.split(",")[:-1],
@@ -1166,11 +1142,11 @@ def preview_sequence(self, test_rating_sequence: str, requestRetention: float):
             "interval history: "
             + ",".join(
                 [
-                    f"{ivl}d"
-                    if ivl < 30
-                    else f"{ivl / 30:.1f}m"
-                    if ivl < 365
-                    else f"{ivl / 365:.1f}y"
+                    (
+                        f"{ivl}d"
+                        if ivl < 30
+                        else f"{ivl / 30:.1f}m" if ivl < 365 else f"{ivl / 365:.1f}y"
+                    )
                     for ivl in map(int, t_history.split(","))
                 ]
             )
@@ -1217,9 +1193,9 @@ def predict_memory_states(self):
         self.difficulty_distribution_padding = np.zeros(10)
         for i in range(10):
             if i + 1 in self.difficulty_distribution.index:
-                self.difficulty_distribution_padding[
-                    i
-                ] = self.difficulty_distribution.loc[i + 1]
+                self.difficulty_distribution_padding[i] = (
+                    self.difficulty_distribution.loc[i + 1]
+                )
         return self.difficulty_distribution
 
     def find_optimal_retention(
@@ -1632,9 +1608,9 @@ def compare_with_sm2(self):
             np.log(0.9) * self.dataset["delta_t"] / self.dataset["sm2_ivl"]
         )
         self.dataset["log_loss"] = self.dataset.apply(
-            lambda row: -np.log(row["sm2_p"])
-            if row["y"] == 1
-            else -np.log(1 - row["sm2_p"]),
+            lambda row: (
+                -np.log(row["sm2_p"]) if row["y"] == 1 else -np.log(1 - row["sm2_p"])
+            ),
             axis=1,
         )
         tqdm.write(f"Loss of SM-2: {self.dataset['log_loss'].mean():.4f}")

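How the new pieces fit together: BatchDataset sorts the reviews by "i" so that histories of similar length sit next to each other, pads everything once, and pre-builds length-truncated batch tuples in __init__; BatchLoader then only reshuffles the order of those ready-made batches each epoch. Below is a minimal usage sketch, not part of the commit: the toy dataframe and its values are illustrative (the real "tensor", "delta_t", "y", and "i" columns come from the optimizer's preprocessing), and the import path is assumed from the file location shown in this diff.

# Minimal sketch of driving the new BatchDataset / BatchLoader pair.
import pandas as pd
import torch

from fsrs_optimizer.fsrs_optimizer import BatchDataset, BatchLoader  # path assumed

# Each "tensor" row is one review history of shape (seq_len, features);
# the values here are placeholders, not realistic review data.
df = pd.DataFrame(
    {
        "tensor": [torch.ones(n, 2) for n in (1, 1, 2, 2)],
        "delta_t": [1, 2, 3, 5],
        "y": [1, 1, 0, 1],
        "i": [2, 2, 3, 3],  # sorting by "i" groups similar-length histories
    }
)

dataset = BatchDataset(df, batch_size=2)  # batches are built once, up front
loader = BatchLoader(dataset)  # only shuffles batch order per epoch

for sequences, delta_ts, labels, seq_lens in loader:
    # sequences: (max_len_in_batch, batch_size, features), truncated and padded
    print(sequences.shape, delta_ts, labels, seq_lens)

Compared with the removed collate_fn, which called pack_padded_sequence and pad_packed_sequence on every batch of every epoch, all padding and truncation now happens once at dataset construction, which is where the dataloader speedup comes from.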