Skip to content

Commit

Permalink
torchworld/structures/grid: added device and len methods
Browse files Browse the repository at this point in the history
  • Loading branch information
d4l3k committed Oct 27, 2023
1 parent b4665c8 commit 5d3a7da
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 0 deletions.
10 changes: 10 additions & 0 deletions torchworld/structures/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@

@dataclass
class BaseGrid(ABC):
data: torch.Tensor
time: torch.Tensor

@abstractmethod
def to(self: T, target: Union[torch.device, str]) -> T:
...
Expand All @@ -22,6 +25,13 @@ def cuda(self) -> T:
def cpu(self) -> T:
return self.to(torch.device("cpu"))

@property
def device(self) -> torch.device:
return self.data.device

def __len__(self) -> int:
return len(self.data)


@dataclass
class Grid3d(BaseGrid):
Expand Down
3 changes: 3 additions & 0 deletions torchworld/structures/test_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ def test_grid_3d(self) -> None:
grid = grid.to("cpu")
grid = grid.cpu()

self.assertEqual(len(grid), 2)
self.assertEqual(grid.device, grid.data.device)

def test_grid_image(self) -> None:
grid = GridImage(
data=torch.rand(2, 3, 4, 5),
Expand Down

0 comments on commit 5d3a7da

Please sign in to comment.