Skip to content

Commit

Permalink
torchdrive: use bfloat16 grid_samples w/ pytorch nightly
Browse files Browse the repository at this point in the history
  • Loading branch information
d4l3k committed Nov 3, 2023
1 parent 976c15f commit dffa338
Show file tree
Hide file tree
Showing 6 changed files with 31 additions and 31 deletions.
2 changes: 1 addition & 1 deletion configs/simplebev3d_multi_pose.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
autolabel_path="/mnt/ext3/autolabel",
mask_path="n/a", # only used for rice dataset
num_workers=4,
- batch_size=2,
+ batch_size=4,
# tasks
det=True,
ae=False,
Expand Down
2 changes: 1 addition & 1 deletion torchdrive/models/simple_bev.py
Original file line number Diff line number Diff line change
Expand Up @@ -998,7 +998,7 @@ def forward(
x = self.upsample(x)

x4_coarse = x4.view_as(x4)
- x4 = x4.flatten(1, 2)
+ x4 = x4_coarse.flatten(1, 2)
x4 = self.coarse_project(x4)

return x, x4, {"coarse": x4_coarse, "skip": x4_skip}
4 changes: 2 additions & 2 deletions torchdrive/render/volume_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def forward(
# run the grid sampler on the volumes densities
rays_densities = torch.nn.functional.grid_sample(
volumes_densities,
- rays_points_local_flat,
+ rays_points_local_flat.to(volumes_densities.dtype),
align_corners=True,
mode=self._sample_mode,
padding_mode=self._padding_mode,
Expand All @@ -159,7 +159,7 @@ def forward(
else:
rays_features = torch.nn.functional.grid_sample(
volumes_features,
- rays_points_local_flat,
+ rays_points_local_flat.to(volumes_features.dtype),
align_corners=True,
mode=self._sample_mode,
padding_mode=self._padding_mode,
Expand Down
2 changes: 1 addition & 1 deletion torchdrive/tasks/bev.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ def forward(
if len(backbone_out) >= 3:
x4_intermediates = backbone_out[2]

- for tag, x in x4_intermediates:
+ for tag, x in x4_intermediates.items():
register_log_grad_norm(
t=x,
writer=writer,
Expand Down
50 changes: 25 additions & 25 deletions torchdrive/tasks/voxel.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def __init__(
scale: int = 3,
z_offset: float = 0.5,
semantic: Optional[List[str]] = None,
- render_batch_size: int = 5,
+ render_batch_size: int = 1000,
n_pts_per_ray: int = 216,
compile_fn: Callable[[nn.Module], nn.Module] = lambda x: x,
start_offsets: Tuple[int, ...] = (0,),
Expand Down Expand Up @@ -239,13 +239,11 @@ def forward(

with autocast():
embedding = self.decoder(bev) # .unflatten(1, (self.num_elem, self.height))
- # convert back to float so sigmoid works
- embedding = embedding.float()

# log grad norms on full embedding to make norms comparable
- grid = ctx.log_grad_norm(embedding, "grad/norm/grid_embedding", "grid")[
- :, :1
- ].sigmoid()
+ grid = ctx.log_grad_norm(embedding, "grad/norm/grid_embedding", "grid")[:, :1]
+ # convert back to float so sigmoid works
+ grid = grid.float().sigmoid()
feat_grid = ctx.log_grad_norm(
embedding, "grad/norm/grid_embedding", "feat_grid"
)[:, 1:]
Expand Down Expand Up @@ -457,15 +455,15 @@ def _losses_frame(
for cam in self.cameras:
primary_color = batch.color[cam][:, frame]
primary_color = F.interpolate(
- primary_color.float(),
+ primary_color,
[h // 2, w // 2],
mode="bilinear",
align_corners=False,
)
primary_colors[cam] = primary_color
primary_mask = batch.mask[cam]
primary_mask = F.interpolate(
- primary_mask.float(),
+ primary_mask,
[h // 2, w // 2],
mode="bilinear",
align_corners=False,
Expand All @@ -484,7 +482,7 @@ def _losses_frame(
# pyre-fixme[16]: `Optional` has no attribute `__getitem__`.
semantic_target = batch.sem_seg[cam][:, frame]
semantic_target = F.interpolate(
- semantic_target.float(),
+ semantic_target,
[h // 2, w // 2],
mode="bilinear",
align_corners=False,
Expand Down Expand Up @@ -539,9 +537,11 @@ def _losses_frame(
for cam in self.cameras:
dynamic_masks[cam] = torch.zeros(BS, 1, h // 2, w // 2, device=device)

+ dtype = torch.bfloat16 if grid.is_cuda else torch.float32
+
volumes = Volumes(
- densities=grid.permute(0, 1, 4, 2, 3),
- features=feat_grid.permute(0, 1, 4, 2, 3).float()
+ densities=grid.permute(0, 1, 4, 2, 3).to(dtype),
+ features=feat_grid.permute(0, 1, 4, 2, 3).to(dtype)
if feat_grid is not None
else None,
voxel_size=1 / self.scale,
Expand Down Expand Up @@ -638,7 +638,7 @@ def _losses_frame(
semantic_img, "grad/semantic_img", "semantic_vel"
)[:, self.classes_elem :]
semantic_vel = F.interpolate(
- semantic_vel.float(),
+ semantic_vel,
[h // 2, w // 2],
mode="bilinear",
align_corners=False,
Expand All @@ -647,7 +647,7 @@ def _losses_frame(
semantic_img, "grad/semantic_img", "semantic_classes"
)[:, : self.classes_elem].sigmoid()
semantic_classes = F.interpolate(
- semantic_classes.float(),
+ semantic_classes,
[h // 2, w // 2],
mode="bilinear",
align_corners=False,
Expand Down Expand Up @@ -768,14 +768,14 @@ def _losses_frame(
cam_disp, cam_vel, cam_sem = self.depth_decoder(ctx.cam_feats[cam])

cam_vel = F.interpolate(
- cam_vel.float(),
+ cam_vel,
[h // 2, w // 2],
mode="bilinear",
align_corners=False,
)
if self.semantic:
cam_sem = F.interpolate(
- cam_sem.float(),
+ cam_sem,
[h // 2, w // 2],
mode="bilinear",
align_corners=False,
Expand Down Expand Up @@ -876,13 +876,13 @@ def _sfm_loss(
multiple times.
"""
depth = F.interpolate(
- depth.float().unsqueeze(1),
+ depth.unsqueeze(1),
[h // 2, w // 2],
mode="bilinear",
align_corners=False,
)
disp = F.interpolate(
- disp.float().unsqueeze(1),
+ disp.unsqueeze(1),
[h // 2, w // 2],
mode="bilinear",
align_corners=False,
Expand Down Expand Up @@ -938,7 +938,7 @@ def _sfm_loss(

src_color = batch.color[cam][:, src_frame]
src_color = F.interpolate(
- src_color.float(),
+ src_color,
[h // 2, w // 2],
mode="bilinear",
align_corners=False,
Expand Down Expand Up @@ -1023,10 +1023,10 @@ def _stereoscopic_loss(
"""
frame = ctx.start_frame
primary_mask = cam_masks[primary_cam]
- primary_features = cam_features[primary_cam].float()
+ primary_features = cam_features[primary_cam]

primary_depth = F.interpolate(
- primary_depth.float().unsqueeze(1),
+ primary_depth.unsqueeze(1),
[h // 2, w // 2],
mode="bilinear",
align_corners=False,
Expand All @@ -1036,7 +1036,7 @@ def _stereoscopic_loss(
target_cam_to_world = batch.cam_to_world(primary_cam, frame)
world_to_src_cam = batch.world_to_cam(src_cam, frame)
src_mask = cam_masks[src_cam]
- src_features = cam_features[src_cam].float()
+ src_features = cam_features[src_cam]

proj_features, proj_mask = self._project(
batch=batch,
Expand Down Expand Up @@ -1104,8 +1104,8 @@ def _semantic_loss(
)

sem_loss = self.projection_loss(
- semantic_classes.float(),
- semantic_target.float(),
+ semantic_classes,
+ semantic_target,
scales=3,
mask=mask,
)
Expand Down Expand Up @@ -1185,14 +1185,14 @@ def _project(

color = F.grid_sample(
src_color,
- pix_coords,
+ pix_coords.to(src_color.dtype),
mode="bilinear",
padding_mode="border",
align_corners=False,
)
mask = F.grid_sample(
src_mask,
- pix_coords,
+ pix_coords.to(src_mask.dtype),
mode="nearest",
padding_mode="zeros",
align_corners=False,
Expand Down
2 changes: 1 addition & 1 deletion torchdrive/transforms/simple_bev.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def lift_cam_to_voxel(
pix_coords = pix_coords.permute(0, 2, 1).unsqueeze(1)
pix_coords = (pix_coords - 0.5) * 2
# features = features.permute(0, 1, 3, 2)
- values = F.grid_sample(features.float(), pix_coords, align_corners=False)
+ values = F.grid_sample(features, pix_coords.to(features.dtype), align_corners=False)

# restore to grid shape
values = values.squeeze(2).unflatten(-1, grid_shape)
Expand Down

0 comments on commit dffa338

Please sign in to comment.