Skip to content

Commit

Permalink
torchworld/fpn,grid: make compilable
Browse files Browse the repository at this point in the history
  • Loading branch information
d4l3k committed Oct 29, 2023
1 parent 2ebd964 commit 39d7a0f
Show file tree
Hide file tree
Showing 5 changed files with 30 additions and 31 deletions.
17 changes: 8 additions & 9 deletions torchworld/models/fpn.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from dataclasses import replace
from typing import Tuple

import torch
Expand Down Expand Up @@ -73,10 +72,10 @@ def forward(
x1 = self.up1_skip(x2, skip_x["1"])

return (
replace(grid, data=x1),
replace(grid, data=x2),
replace(grid, data=x3),
replace(grid, data=x4),
grid.replace(data=x1),
grid.replace(data=x2),
grid.replace(data=x3),
grid.replace(data=x4),
)


Expand Down Expand Up @@ -152,10 +151,10 @@ def forward(
x1 = self.up1_skip(x2, skip_x["1"])

return (
replace(grid, data=x1),
replace(grid, data=x2),
replace(grid, data=x3),
replace(grid, data=x4),
grid.replace(data=x1),
grid.replace(data=x2),
grid.replace(data=x3),
grid.replace(data=x4),
)


Expand Down
6 changes: 4 additions & 2 deletions torchworld/models/test_fpn.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import unittest

import torch
from pytorch3d.renderer.cameras import PerspectiveCameras
from pytorch3d.transforms import Transform3d

from torchworld.models.fpn import Resnet18FPN3d, Resnet18FPNImage

from torchworld.structures.cameras import PerspectiveCameras
from torchworld.structures.grid import Grid3d, GridImage
from torchworld.transforms.transform3d import Transform3d


class TestFPN(unittest.TestCase):
Expand All @@ -17,6 +17,7 @@ def test_resnet18_fpn_3d(self) -> None:
time=torch.rand(2),
)
m = Resnet18FPN3d(in_channels=3)
m = torch.compile(m, fullgraph=True, backend="eager")
out = m(grid)
for grid in out:
self.assertIsInstance(grid, Grid3d)
Expand All @@ -28,6 +29,7 @@ def test_resnet18_fpn_image(self) -> None:
time=torch.rand(2),
)
m = Resnet18FPNImage(in_channels=3)
m = torch.compile(m, fullgraph=True, backend="eager")
out = m(grid)
self.assertEqual(len(out), 4)
for grid in out:
Expand Down
14 changes: 13 additions & 1 deletion torchworld/structures/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
from typing import Optional, Tuple, TypeVar, Union

import torch
from pytorch3d.renderer.cameras import CamerasBase
from pytorch3d.structures.volumes import VolumeLocator

from torchworld.structures.cameras import CamerasBase
from torchworld.transforms.transform3d import Transform3d

T = TypeVar("T")
Expand Down Expand Up @@ -189,3 +189,15 @@ def to(self, target: Union[torch.device, str]) -> "GridImage":

def grid_shape(self) -> Tuple[int, int]:
return self.data.shape[2:4]

def replace(
self,
data: Optional[torch.Tensor] = None,
camera: Optional[CamerasBase] = None,
time: Optional[torch.Tensor] = None,
) -> "Grid3d":
return GridImage(
data=data if data is not None else self.data,
camera=camera if camera is not None else self.camera,
time=time if time is not None else self.time,
)
3 changes: 3 additions & 0 deletions torchworld/structures/test_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ def test_grid_3d(self) -> None:

grid = grid.to("cpu")
grid = grid.cpu()
grid = grid.replace()

self.assertEqual(len(grid), 2)
self.assertEqual(grid.device, grid.data.device)
Expand All @@ -32,6 +33,8 @@ def test_grid_image(self) -> None:

grid = grid.to("cpu")
grid = grid.cpu()
grid = grid.replace()

self.assertEqual(grid.grid_shape(), (4, 5))

def test_grid_3d_from_volume(self) -> None:
Expand Down
21 changes: 2 additions & 19 deletions torchworld/transforms/test_simplebev.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,24 +24,7 @@ def test_lift_image_to_3d(self) -> None:
time=torch.rand(2, device=device),
)

out, mask = lift_image_to_3d(src, dst)
compiled_lift = torch.compile(lift_image_to_3d, fullgraph=True, backend="eager")
out, mask = compiled_lift(src, dst)
self.assertEqual(out.data.shape, (2, 3, 1, 2, 3))
self.assertEqual(mask.data.shape, (2, 1, 1, 2, 3))

def test_lift_image_to_3d_compile(self) -> None:
device = torch.device("cpu")
dtype = torch.float
dst = Grid3d(
data=torch.rand(0, 3, 1, 2, 3, device=device, dtype=dtype),
local_to_world=Transform3d(device=device),
time=torch.rand(2, device=device),
)
src = GridImage(
data=torch.rand(2, 3, 1, 2, device=device, dtype=dtype),
camera=PerspectiveCameras(device=device),
time=torch.rand(2, device=device),
)

compiled_lift = torch.compile(lift_image_to_3d, fullgraph=True)
compiled_lift(src, dst)
compiled_lift(src, dst)

0 comments on commit 39d7a0f

Please sign in to comment.