Skip to content

Commit

Permalink
fix: use to_torch instead of to_numpy
Browse files Browse the repository at this point in the history
This properly creates writeable numpy arrays.
  • Loading branch information
lars-reimann committed May 29, 2024
1 parent 82322a9 commit 578cb9d
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions src/safeds/data/labeled/containers/_tabular_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,8 +196,8 @@ def _into_dataloader_with_classes(self, batch_size: int, num_of_classes: int) ->
if num_of_classes <= 2:
return DataLoader(
dataset=_create_dataset(
torch.Tensor(self.features._data_frame.to_numpy()).to(_get_device()),
torch.Tensor(self.target._series.to_numpy()).to(_get_device()).unsqueeze(dim=-1),
self.features._data_frame.to_torch().to(_get_device()),
self.target._series.to_torch().to(_get_device()).unsqueeze(dim=-1),
),
batch_size=batch_size,
shuffle=True,
Expand All @@ -206,9 +206,9 @@ def _into_dataloader_with_classes(self, batch_size: int, num_of_classes: int) ->
else:
return DataLoader(
dataset=_create_dataset(
torch.Tensor(self.features._data_frame.to_numpy()).to(_get_device()),
self.features._data_frame.to_torch().to(_get_device()),
torch.nn.functional.one_hot(
torch.LongTensor(self.target._series.to_numpy()).to(_get_device()),
self.target._series.to_torch().to(_get_device()),
num_classes=num_of_classes,
),
),
Expand Down

0 comments on commit 578cb9d

Please sign in to comment.