From 251fec13bbe6e7e9f307d2c80f69ba318c5d5b92 Mon Sep 17 00:00:00 2001 From: Jarrett Ye Date: Thu, 31 Oct 2024 18:14:14 +0800 Subject: [PATCH] Feat/support to set device for BatchDataset --- src/fsrs_optimizer/fsrs_optimizer.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/fsrs_optimizer/fsrs_optimizer.py b/src/fsrs_optimizer/fsrs_optimizer.py index 3685b4f..4c8640e 100644 --- a/src/fsrs_optimizer/fsrs_optimizer.py +++ b/src/fsrs_optimizer/fsrs_optimizer.py @@ -219,6 +219,7 @@ def __init__( batch_size: int = 0, sort_by_length: bool = True, max_seq_len: int = math.inf, + device: str = "cpu", ): if dataframe.empty: raise ValueError("Training data is inadequate.") @@ -248,10 +249,10 @@ def __init__( max_seq_len = max(seq_lens) sequences_truncated = sequences[:, :max_seq_len] self.batches[i] = ( - sequences_truncated.transpose(0, 1), - self.t_train[start_index:end_index], - self.y_train[start_index:end_index], - seq_lens, + sequences_truncated.transpose(0, 1).to(device), + self.t_train[start_index:end_index].to(device), + self.y_train[start_index:end_index].to(device), + seq_lens.to(device), ) def __getitem__(self, idx):