tasks/bev: compute camera features more efficiently
d4l3k committed Oct 8, 2023
1 parent 8c199e1 commit 035c4de
Showing 2 changed files with 31 additions and 12 deletions.
40 changes: 28 additions & 12 deletions torchdrive/tasks/bev.py
@@ -118,24 +118,39 @@ def forward(
         drop_cameras = random.choices(self.cameras, k=self.num_drop_encode_cameras)
         dropout_cameras = set(self.cameras) - set(drop_cameras)

-        with autocast():
+        with autocast(), torch.autograd.profiler.record_function("camera_feats"):
             first_backprop_frame = max(
                 self.num_encode_frames - self.num_backprop_frames, 0
             )

             # individual frames
             camera_feats = {cam: [] for cam in dropout_cameras}
-            for frame in range(0, first_backprop_frame):
-                for cam in dropout_cameras:
-                    out = self.camera_encoders[cam](batch.color[cam][:, frame]).detach()
-                    assert not out.requires_grad
-                    camera_feats[cam].append(out)
-            for frame in range(first_backprop_frame, self.num_encode_frames):
-                pause = frame == (self.num_encode_frames - 1)
-                for cam in dropout_cameras:
-                    feat = self.camera_encoders[cam](batch.color[cam][:, frame])
+
+            # these first set of frames don't use backprop so set them to eval
+            # mode and don't collect gradient
+            with torch.no_grad():
+                for cam in dropout_cameras:
+                    # run frames in parallel
+                    encoder = self.camera_encoders[cam]
+                    encoder.eval()
+                    inp = batch.color[cam][:, 0:first_backprop_frame]
+                    feats = encoder(inp.flatten(0, 1)).unflatten(0, inp.shape[0:2])
+                    assert not feats.requires_grad
+                    for feat in feats:
+                        camera_feats[cam].append(feat)
+
+            # frames we want backprop for
+            for cam in dropout_cameras:
+                # run frames in parallel
+                encoder = self.camera_encoders[cam]
+                encoder.train()
+                inp = batch.color[cam][:, first_backprop_frame : self.num_encode_frames]
+                feats = encoder(inp.flatten(0, 1)).unflatten(0, inp.shape[0:2])
+
+                for i, feat in enumerate(feats):
                     # pause the last cam encoder backprop for tasks with image
                     # space losses
-                    if pause:
+                    if i == (len(feats) - 1):
                         feat = autograd_pause(feat)
                         last_cam_feats[cam] = feat
@@ -149,7 +164,8 @@ def forward(
                     )
                     camera_feats[cam].append(feat)

-        hr_bev, bev = self.backbone(camera_feats, batch)
+        with torch.autograd.profiler.record_function("backbone"):
+            hr_bev, bev = self.backbone(camera_feats, batch)

         last_cam_feats_resume = last_cam_feats
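
The efficiency win above comes from folding the frame dimension into the batch dimension, so each camera encoder runs once per camera on a stack of frames instead of once per camera per frame, with gradients disabled for the older frames. Below is a minimal, self-contained sketch of that pattern; the toy encoder, tensor shapes, and variable names are assumptions for illustration, not the torchdrive modules.

import torch
from torch import nn

# Stand-ins for a single camera encoder and a batch of camera frames.
encoder = nn.Conv2d(3, 8, kernel_size=3, padding=1)
color = torch.rand(2, 5, 3, 48, 64)  # (batch, frames, channels, height, width)
first_backprop_frame = 3
camera_feats = []

# Older frames: eval mode + no_grad, one batched forward pass over all of them.
with torch.no_grad():
    encoder.eval()
    inp = color[:, 0:first_backprop_frame]                           # (B, F0, C, H, W)
    feats = encoder(inp.flatten(0, 1)).unflatten(0, inp.shape[0:2])  # (B, F0, 8, H, W)
    assert not feats.requires_grad

# Recent frames: train mode, same batched call, gradients kept.
encoder.train()
inp = color[:, first_backprop_frame:]
bp_feats = encoder(inp.flatten(0, 1)).unflatten(0, inp.shape[0:2])   # (B, F1, 8, H, W)

# Per-frame features can then be peeled off the frame dimension.
for frame_feat in torch.cat([feats, bp_feats], dim=1).unbind(1):
    camera_feats.append(frame_feat)  # each entry is (B, 8, H, W)
print(len(camera_feats), camera_feats[0].shape)

Compared with looping over individual frames, this issues one larger encoder call per camera per phase, which is typically faster on GPU.
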
3 changes: 3 additions & 0 deletions torchdrive/tasks/test_bev.py
@@ -19,6 +19,9 @@ def forward(
         bev.mean().backward()
         ctx.add_scalar("test", bev.shape[-1])

+        for cam_feat in ctx.cam_feats.values():
+            cam_feat.mean().backward()
+
         # check that start position is at zero
         car_to_world = batch.car_to_world(ctx.start_frame)
         zero = (
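
The two added test lines backprop through ctx.cam_feats, which presumably exposes the autograd_pause'd per-camera features from the last frame. autograd_pause is a torchdrive helper, so the sketch below only illustrates the general pause/resume idea in plain PyTorch; the detach-based pattern is an assumption, not the actual implementation.

import torch

x = torch.rand(4, 8, requires_grad=True)
feat = (x * 2).relu()                      # upstream graph we want to pause

paused = feat.detach().requires_grad_()    # cut the graph, keep a handle
loss = paused.mean()                       # downstream / image-space style loss
loss.backward()                            # fills paused.grad only

# "resume": push the collected gradient back through the upstream graph
feat.backward(paused.grad)
print(x.grad.shape)                        # gradients now reach the encoder side
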

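
The record_function("camera_feats") and record_function("backbone") ranges added in bev.py label those phases in profiler traces, which makes it easy to see where forward time goes. A small sketch of how such labels show up in a PyTorch profiler summary; the toy tensors and module are assumptions, not the torchdrive classes.

import torch
from torch.profiler import ProfilerActivity, profile

conv = torch.nn.Conv2d(3, 8, kernel_size=3, padding=1)
x = torch.rand(4, 3, 32, 32)

with profile(activities=[ProfilerActivity.CPU]) as prof:
    with torch.autograd.profiler.record_function("camera_feats"):
        feats = conv(x)
    with torch.autograd.profiler.record_function("backbone"):
        out = feats.mean()

# The named ranges appear as rows in the profiler table.
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
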