Skip to content

Commit

Permalink
voxel: fixed grid render + refactors to handle multipose
Browse files Browse the repository at this point in the history
  • Loading branch information
d4l3k committed Oct 13, 2023
1 parent e80e212 commit 6cc5911
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 17 deletions.
2 changes: 1 addition & 1 deletion torchdrive/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ def dummy_item() -> Batch:
mask={cam: torch.rand(1, 48, 64) for cam in cams},
lidar_T=torch.rand(4, 4),
lidar=torch.rand(4, random.randint(6, 10)),
sem_seg={cam: torch.rand(19, 24, 32) for cam in cams},
sem_seg={cam: torch.rand(N, 19, 24, 32) for cam in cams},
)


Expand Down
13 changes: 12 additions & 1 deletion torchdrive/tasks/context.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dataclasses import dataclass, field
from typing import Dict, Mapping, Optional, Union
from typing import Dict, Mapping, Optional, Union, Set

import torch
from torch.cuda import amp
Expand Down Expand Up @@ -32,12 +32,15 @@ class Context:

name: str = "<unknown>"

_logged: Set[str] = field(default_factory=set)

def backward(self, losses: Dict[str, torch.Tensor]) -> None:
losses_backward(losses, scaler=self.scaler, weights=self.weights)

def add_scalars(self, name: str, scalars: Dict[str, torch.Tensor]) -> None:
if self.writer:
assert self.log_text
self._check_key(name)
self.writer.add_scalars(
f"{self.name}-{name}",
{k: _cpu_float(v) for k, v in scalars.items()},
Expand All @@ -49,20 +52,23 @@ def add_scalar(
) -> None:
if self.writer:
assert self.log_text
self._check_key(name)
self.writer.add_scalar(
f"{self.name}-{name}", _cpu_float(scalar), global_step=self.global_step
)

def add_image(self, name: str, img: torch.Tensor) -> None:
if self.writer:
assert self.log_img
self._check_key(name)
self.writer.add_image(
f"{self.name}-{name}", img, global_step=self.global_step
)

def add_figure(self, name: str, figure: object) -> None:
if self.writer:
assert self.log_img
self._check_key(name)
self.writer.add_figure(
f"{self.name}-{name}", figure, global_step=self.global_step
)
Expand All @@ -73,3 +79,8 @@ def log_grad_norm(self, tensor: torch.Tensor, key: str, tag: str) -> torch.Tenso
return log_grad_norm(
tensor, self.writer, f"{self.name}-{key}", tag, self.global_step
)

def _check_key(self, key: str) -> None:
if key in self._logged:
raise RuntimeError(f"already logged {key}")
self._logged.add(key)
2 changes: 1 addition & 1 deletion torchdrive/tasks/test_voxel.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def test_voxel_task(self) -> None:
cam_dim=6,
height=12,
device=device,
render_batch_size=1,
render_batch_size=5,
n_pts_per_ray=10,
offsets=(-1, 0, 1),
).to(device)
Expand Down
59 changes: 45 additions & 14 deletions torchdrive/tasks/voxel.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@
from typing import Callable, Dict, List, Mapping, Optional, Tuple

import numpy as np
from matplotlib import cm

import torch
import torch.nn.functional as F
import torchmetrics
from pytorch3d.renderer import ImplicitRenderer, NDCMultinomialRaysampler
from pytorch3d.structures import Volumes
from pytorch3d.structures.volumes import VolumeLocator
from torch import nn

from torchdrive.amp import autocast
Expand Down Expand Up @@ -265,27 +267,42 @@ def forward(

gz = render_color(grid[0, 0].sum(dim=2))

vtw = voxel_to_world(
center=(-bev_shape[0] // 2, -bev_shape[1] // 2, 0),
scale=self.scale,
grid_shape = grid.shape[2:] # [x, y, z]
volume_locator = VolumeLocator(
batch_size=1,
grid_sizes=grid_shape[::-1], # [z, y, x]
voxel_size=1 / self.scale,
volume_translation=self.volume_translation,
device=device,
)
voxel_to_world = (
volume_locator.get_local_to_world_coords_transform()
.get_matrix()
.permute(0, 2, 1)
)

start_color = torch.tensor((0, 1, 0))
end_color = torch.tensor((0, 0, 1))
def get_color(percent: float) -> torch.Tensor:
return start_color*(1-percent) + percent*end_color

zero_coord = torch.tensor([[0, 0, 0, 1]], device=device, dtype=torch.float)
for frame in range(0, frames):
# create car to voxel transform
T = batch.world_to_car(frame)
T = T.matmul(vtw)
T = T.pinverse()

T = T.matmul(voxel_to_world)
T = T.inverse()
cam_coords = T.matmul(zero_coord.T).squeeze(-1)
cam_coords /= cam_coords[:, 3:].clamp(min=1e-8)
coord = cam_coords[0, :3]

# convert from -1 to 1 to the grid range
coord = (coord+1)/2 * torch.tensor(grid_shape, device=device)
x, y, z = coord.int()
_, d, w = gz.shape
if x >= d or y >= w or x < 0 or y < 0:
continue
gz[:, x, y] = torch.tensor((0, 1, 0))
gz[:, y, x] = get_color(frame/frames)

ctx.add_image(
"grid/z",
Expand Down Expand Up @@ -379,18 +396,31 @@ def _losses(
"""
compute losses for the provided minibatch.
"""
losses = self._losses_frame(ctx=ctx, batch=batch, grid=grid, feat_grid=feat_grid, frame=ctx.start_frame)
return losses

def _losses_frame(
self,
ctx: Context,
batch: Batch,
grid: torch.Tensor,
feat_grid: Optional[torch.Tensor],
frame: int,
) -> Dict[str, torch.Tensor]:
"""
compute losses for the provided minibatch starting at the specific frame
"""
BS = len(batch.distances)
frames = batch.distances.shape[1]
start_frame = ctx.start_frame
frame_time = batch.frame_time - batch.frame_time[:, ctx.start_frame].unsqueeze(
frame_time = batch.frame_time - batch.frame_time[:, start_frame].unsqueeze(
1
)
device = grid.device

losses = {}

h, w = self.cam_shape
frame = ctx.start_frame

primary_colors: Dict[str, torch.Tensor] = {}
primary_masks: Dict[str, torch.Tensor] = {}
Expand Down Expand Up @@ -422,8 +452,9 @@ def _losses(
if self.semantic:
with torch.autograd.profiler.record_function("segment"):
for cam in self.cameras:
semantic_target = batch.sem_seg[cam][:, frame]
semantic_target = F.interpolate(
batch.sem_seg[cam].float(),
semantic_target.float(),
[h // 2, w // 2],
mode="bilinear",
align_corners=False,
Expand Down Expand Up @@ -489,7 +520,7 @@ def _losses(
padding_mode="border",
)

if batch.lidar is not None:
if batch.lidar is not None and False:
with torch.no_grad():
ray_bundle, distances = self.lidar_raysampler(batch)
rays_densities, rays_features = volumetric_function(
Expand Down Expand Up @@ -832,12 +863,12 @@ def _sfm_loss(

if ctx.log_img:
ctx.add_image(
f"depth{label}/{cam}",
f"depth{label}/{frame}/{cam}",
render_color(-depth[0][0]),
)
out_disp = disp[0, 0] * primary_mask[0, 0]
ctx.add_image(
f"disp{label}/{cam}",
f"disp{label}/{frame}/{cam}",
render_color(out_disp),
)

Expand All @@ -853,7 +884,7 @@ def _sfm_loss(
assert (
time_max > 0 and time_max < 60
), f"frame_time is bad {offset} {time}"
ctx.add_scalar(f"frame_time_max/{offset}", time_max)
ctx.add_scalar(f"frame_time_max/{label}/{cam}/{offset}", time_max)

src_color = batch.color[cam][:, src_frame]
src_color = F.interpolate(
Expand Down

0 comments on commit 6cc5911

Please sign in to comment.