Skip to content

Commit

Permalink
add configs for rice+simplebev and made everything use voxel intermed…
Browse files Browse the repository at this point in the history
…iate representation
  • Loading branch information
d4l3k committed Oct 6, 2023
1 parent 5e455bd commit 1d4ecbd
Show file tree
Hide file tree
Showing 14 changed files with 181 additions and 67 deletions.
40 changes: 40 additions & 0 deletions configs/rice.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from torchdrive.train_config import Datasets, TrainConfig


CONFIG = TrainConfig(
# backbone settings
cameras=[
"CAM_FRONT",
"CAM_FRONT_LEFT",
"CAM_FRONT_RIGHT",
"CAM_BACK",
"CAM_BACK_LEFT",
"CAM_BACK_RIGHT",
],
dim=256,
cam_dim=96,
hr_dim=384,
backbone="rice",
cam_encoder="regnet",
num_encode_frames=3,
cam_shape=(480, 640),
num_upsamples=4,
grid_shape=(256, 256, 16),
# optimizer settings
epochs=20,
lr=1e-4,
grad_clip=1.0,
step_size=1000,
# dataset
dataset=Datasets.NUSCENES,
dataset_path="/mnt/ext3/nuscenes",
mask_path="n/a", # only used for rice dataset
num_workers=6,
batch_size=2,
# tasks
det=False,
ae=False,
voxel=True,
voxelsem=True,
path=False,
)
40 changes: 40 additions & 0 deletions configs/simplebev.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from torchdrive.train_config import Datasets, TrainConfig


CONFIG = TrainConfig(
# backbone settings
cameras=[
"CAM_FRONT",
"CAM_FRONT_LEFT",
"CAM_FRONT_RIGHT",
"CAM_BACK",
"CAM_BACK_LEFT",
"CAM_BACK_RIGHT",
],
dim=256,
cam_dim=96,
hr_dim=384,
backbone="simple_bev",
cam_encoder="simple_regnet",
num_encode_frames=3,
cam_shape=(480, 640),
num_upsamples=1,
grid_shape=(256, 256, 16),
# optimizer settings
epochs=20,
lr=1e-4,
grad_clip=1.0,
step_size=1000,
# dataset
dataset=Datasets.NUSCENES,
dataset_path="/mnt/ext3/nuscenes",
mask_path="n/a", # only used for rice dataset
num_workers=6,
batch_size=2,
# tasks
det=False,
ae=False,
voxel=True,
voxelsem=True,
path=False,
)
3 changes: 2 additions & 1 deletion configs/simplebev3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@
cam_encoder="simple_regnet",
num_encode_frames=3,
cam_shape=(480, 640),
bev_shape=(16, 16),
num_upsamples=1,
grid_shape=(256, 256, 16),
# optimizer settings
epochs=20,
lr=1e-4,
Expand Down
9 changes: 7 additions & 2 deletions torchdrive/models/bev.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ def __init__(
dim: int,
hr_dim: int,
cam_dim: int,
bev_shape: Tuple[int, int],
grid_shape: Tuple[int, int, int], # [x, y, z]
input_shape: Tuple[int, int],
num_frames: int,
cameras: List[str],
Expand All @@ -229,6 +229,9 @@ def __init__(
super().__init__()

self.num_frames = num_frames
bev_shape = grid_shape[:2] # [x, y]
self.out_Z = grid_shape[2] * 2**num_upsamples
self.voxel_dim = max(hr_dim // self.out_Z, 1)

self.cam_transformers = nn.ModuleDict(
{
Expand All @@ -254,11 +257,11 @@ def __init__(
dim=dim,
output_dim=hr_dim,
)
self.project_voxel = nn.Conv2d(hr_dim, self.voxel_dim * self.out_Z, 1)

def forward(
self, camera_features: Mapping[str, List[torch.Tensor]], batch: Batch
) -> Tuple[torch.Tensor, torch.Tensor]:

with autocast():
bev_grids = []

Expand All @@ -272,5 +275,7 @@ def forward(
bev = self.frame_merger(bev_grids)

hr_bev = self.upsample(bev)
hr_bev = self.project_voxel(hr_bev)
hr_bev = hr_bev.unflatten(1, (self.voxel_dim, self.out_Z))

return hr_bev, bev
43 changes: 30 additions & 13 deletions torchdrive/models/simple_bev.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@

import torch
import torchvision
from pytorch3d.transforms.transform3d import Transform3d
from pytorch3d.structures.volumes import VolumeLocator
from torch import nn
from torchvision import transforms
Expand Down Expand Up @@ -784,6 +783,10 @@ def __init__(
output_dim=hr_dim,
)
)

self.out_Z = grid_shape[2] * 2**num_upsamples
self.voxel_dim = max(hr_dim // self.out_Z, 1)
self.project_voxel = nn.Conv2d(hr_dim, self.voxel_dim * self.out_Z, 1)
# pyre-fixme[6]: invalid parameter type
self.lift_cam_to_voxel_mean: nn.Module = compile_fn(lift_cam_to_voxel_mean)

Expand All @@ -794,10 +797,16 @@ def forward(
S = len(camera_features) * self.num_frames
device = batch.device()

Z = self.grid_shape[2]
self.volume_locator = VolumeLocator(
batch_size=BS,
grid_sizes=self.grid_shape[::-1], # [z, y, x]
voxel_size=1 / self.scale,
volume_translation=(0, 0, -Z * 0.4 / self.scale), # -self.center,
device=device,
)
voxel_to_world = (
Transform3d(device=device)
.translate(*self.center)
.scale(1 / self.scale)
self.volume_locator.get_local_to_world_coords_transform()
.get_matrix()
.permute(0, 2, 1)
)
Expand Down Expand Up @@ -830,6 +839,8 @@ def forward(
# run through FPN
x, x4 = self.fpn(feat_mem)
x = self.upsample(x)
x = self.project_voxel(x)
x = x.unflatten(1, (self.voxel_dim, self.out_Z))
return x, x4


Expand Down Expand Up @@ -910,8 +921,7 @@ def __init__(
per_voxel_dim = max(hr_dim // (Z * 2), 1)
assert num_upsamples == 1, "only one upsample supported"
self.upsample: nn.Module = compile_fn(Upsample3DBlock(cam_dim, per_voxel_dim))
#self.final_project = nn.Conv2d(per_voxel_dim * Z * 2, hr_dim, 1)
#self.hr_project = nn.Conv2d(dim // HR_Z * HR_Z, dim, 1)
self.coarse_project = nn.Conv2d(dim // HR_Z * HR_Z, dim, 1)

# pyre-fixme[6]: invalid parameter type
self.lift_cam_to_voxel_mean: nn.Module = compile_fn(lift_cam_to_voxel_mean)
Expand All @@ -923,15 +933,19 @@ def forward(
S = len(camera_features) * self.num_frames
device = batch.device()

Z = self.grid_shape[2]
self.volume_locator = VolumeLocator(
batch_size=BS,
grid_sizes=self.grid_shape[::-1], # [z, y, x]
voxel_size=1/self.scale,
volume_translation=(0, 0, -8*0.4/self.scale), #-self.center,
grid_sizes=self.grid_shape[::-1], # [z, y, x]
voxel_size=1 / self.scale,
volume_translation=(0, 0, -Z * 0.4 / self.scale), # -self.center,
device=device,
)

voxel_to_world = self.volume_locator.get_local_to_world_coords_transform().get_matrix().permute(0, 2, 1)
voxel_to_world = (
self.volume_locator.get_local_to_world_coords_transform()
.get_matrix()
.permute(0, 2, 1)
)

features = []
Ks = []
Expand All @@ -953,12 +967,15 @@ def forward(
T = torch.stack(Ts, dim=1)
feat_mem = self.lift_cam_to_voxel_mean(feature, K, T, self.grid_shape)


with autocast():
# run through FPN
x = feat_mem
_x, x4 = self.fpn(x)
x, x4 = self.fpn(x)
assert x.shape == feat_mem.shape

x = self.upsample(x)

x4 = x4.flatten(1, 2)
x4 = self.coarse_project(x4)

return x, x4
15 changes: 12 additions & 3 deletions torchdrive/models/test_bev.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

import torch

from torchdrive.data import dummy_batch

from torchdrive.models.bev import (
BEVMerger,
BEVUpsampler,
Expand Down Expand Up @@ -74,12 +76,17 @@ def test_bev_upsampler(self) -> None:
self.assertEqual(out.shape, (2, 3, 16, 16))

def test_rice_backbone(self) -> None:
batch = dummy_batch()
cameras = ["left", "right"]
num_frames = 2
latent_dim = 16
X = 8
Y = 16
Z = 24
m = RiceBackbone(
cam_dim=15,
dim=16,
bev_shape=(4, 4),
grid_shape=(X, Y, Z),
input_shape=(4, 6),
hr_dim=4,
num_upsamples=1,
Expand All @@ -89,5 +96,7 @@ def test_rice_backbone(self) -> None:
x, x4 = m(
{cam: [torch.rand(2, 15, 4, 6)] * num_frames for cam in cameras}, None
)
self.assertEqual(x.shape, (2, 4, 8, 8))
self.assertEqual(x4.shape, (2, 16, 4, 4))
# self.assertEqual(x.shape, (2, 4, 8, 8))
# self.assertEqual(x4.shape, (2, 16, 4, 4))
self.assertEqual(x.shape, (batch.batch_size(), 1, Z * 2, X * 2, Y * 2))
self.assertEqual(x4.shape, (batch.batch_size(), latent_dim, X, Y))
4 changes: 2 additions & 2 deletions torchdrive/models/test_simple_bev.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def test_segnet_backbone(self) -> None:
for feat in feats:
feat.requires_grad = True
x, x4 = m(camera_features, batch)
self.assertEqual(x.shape, (batch.batch_size(), hr_dim, X * 2, Y * 2))
self.assertEqual(x.shape, (batch.batch_size(), 1, Z * 2, X * 2, Y * 2))
self.assertEqual(x4.shape, (batch.batch_size(), latent_dim, X // 8, Y // 8))
(x.mean() + x4.mean()).backward()

Expand Down Expand Up @@ -169,7 +169,7 @@ def test_segnet_3d_backbone(self) -> None:
for feat in feats:
feat.requires_grad = True
x, x4 = m(camera_features, batch)
self.assertEqual(x.shape, (batch.batch_size(), hr_dim, X * 2, Y * 2))
self.assertEqual(x.shape, (batch.batch_size(), 1, Z * 2, X * 2, Y * 2))
self.assertEqual(x4.shape, (batch.batch_size(), latent_dim, X // 8, Y // 8))
(x.mean() + x4.mean()).backward()

Expand Down
5 changes: 1 addition & 4 deletions torchdrive/tasks/bev.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,6 @@ def __init__(
cam_encoder: Callable[[], nn.Module],
tasks: Dict[str, BEVTask],
hr_tasks: Dict[str, BEVTask],
cam_shape: Tuple[int, int],
bev_shape: Tuple[int, int],
cameras: List[str],
dim: int,
hr_dim: int,
Expand Down Expand Up @@ -72,7 +70,6 @@ def __init__(
{cam: compile_fn(cam_encoder()) for cam in cameras}
)

self.cam_shape = cam_shape
self.tasks = nn.ModuleDict(tasks)

assert (len(tasks) + len(hr_tasks)) > 0, "no tasks specified"
Expand Down Expand Up @@ -170,7 +167,7 @@ def forward(

if log_img and writer:
writer.add_image(
"bev/bev", render_color(bev[0].sum(dim=(0, 1))), global_step=global_step
"bev/bev", render_color(bev[0].sum(dim=0)), global_step=global_step
)
writer.add_image(
"bev/hr_bev",
Expand Down
5 changes: 2 additions & 3 deletions torchdrive/tasks/test_bev.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,13 @@ class TestBEV(unittest.TestCase):
def test_bev_task_van(self) -> None:
cam_shape = (48, 64)
bev_shape = (4, 4)
grid_shape = (4, 4, 4)
cameras = ["left", "right"]
dim = 8
hr_dim = 1
m = BEVTaskVan(
tasks={"dummy": DummyBEVTask()},
hr_tasks={"hr_dummy": DummyBEVTask()},
cam_shape=cam_shape,
bev_shape=bev_shape,
cameras=cameras,
dim=dim,
hr_dim=hr_dim,
Expand All @@ -58,7 +57,7 @@ def test_bev_task_van(self) -> None:
dim=dim,
cam_dim=dim,
hr_dim=hr_dim,
bev_shape=bev_shape,
grid_shape=grid_shape,
input_shape=(48 // 16, 64 // 16),
num_frames=2,
cameras=cameras,
Expand Down
6 changes: 3 additions & 3 deletions torchdrive/tasks/test_voxel.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def test_voxel_task(self) -> None:
weights=batch.weight,
cam_feats={cam: torch.rand(2, 6, 320 // 16, 240 // 16) for cam in cameras},
)
bev = torch.rand(2, 5, 4, 4, device=device)
bev = torch.rand(2, 1, 5, 4, 4, device=device)
losses = m(ctx, batch, bev)
ctx.backward(losses)
self.assertCountEqual(losses.keys(), VOXEL_LOSSES)
Expand Down Expand Up @@ -114,7 +114,7 @@ def test_semantic_voxel_task(self) -> None:
weights=batch.weight,
cam_feats={cam: torch.rand(2, 4, 320 // 16, 240 // 16) for cam in cameras},
)
bev = torch.rand(2, 5, 4, 4)
bev = torch.rand(2, 1, 5, 4, 4)
losses = m(ctx, batch, bev)
ctx.backward(losses)
self.assertCountEqual(
Expand Down Expand Up @@ -155,7 +155,7 @@ def test_stereoscopic_voxel_task(self) -> None:
weights=batch.weight,
cam_feats={cam: torch.rand(2, 4, 320 // 16, 240 // 16) for cam in cameras},
)
bev = torch.rand(2, 5, 4, 4)
bev = torch.rand(2, 1, 5, 4, 4)
losses = m(ctx, batch, bev)
ctx.backward(losses)
self.assertCountEqual(
Expand Down
Loading

0 comments on commit 1d4ecbd

Please sign in to comment.