Skip to content

Commit

Permalink
voxel: added multi_pose support
Browse files Browse the repository at this point in the history
  • Loading branch information
d4l3k committed Oct 13, 2023
1 parent 6cc5911 commit de09cbf
Show file tree
Hide file tree
Showing 5 changed files with 170 additions and 44 deletions.
43 changes: 43 additions & 0 deletions configs/simplebev3d_multi_pose.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from torchdrive.train_config import Datasets, TrainConfig


# Training configuration: simple_bev3d backbone on the nuScenes dataset with
# only the voxel and voxelsem tasks switched on (det / ae / path are off).
# NOTE(review): start_offsets carries two entries, presumably the "multi pose"
# setting this config file is named after — confirm against TrainConfig.
CONFIG = TrainConfig(
    # ---- backbone settings ----
    cameras=[
        "CAM_FRONT",
        "CAM_FRONT_LEFT",
        "CAM_FRONT_RIGHT",
        "CAM_BACK",
        "CAM_BACK_LEFT",
        "CAM_BACK_RIGHT",
    ],
    dim=256,
    cam_dim=96,
    hr_dim=384,
    backbone="simple_bev3d",
    cam_encoder="simple_regnet",
    num_frames=9,
    num_encode_frames=3,
    start_offsets=(0, 4),
    cam_shape=(480, 640),
    num_upsamples=1,
    grid_shape=(256, 256, 16),
    # ---- optimizer settings ----
    epochs=20,
    lr=1e-4,
    grad_clip=1.0,
    step_size=1000,
    # ---- dataset ----
    dataset=Datasets.NUSCENES,
    dataset_path="/mnt/ext3/nuscenes",
    autolabel_path="/mnt/ext3/autolabel",
    mask_path="n/a",  # only used for rice dataset
    num_workers=6,
    batch_size=2,
    # ---- tasks ----
    det=False,
    ae=False,
    voxel=True,
    voxelsem=True,
    path=False,
)
5 changes: 5 additions & 0 deletions torchdrive/models/semantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,11 @@ class BDD100KSemSeg:
18,
)

# Dynamic-object class indexes as they appear once the classes have been
# filtered by NON_SKY: every DYNAMIC index shifted down by one.
# NOTE(review): the fixed -1 offset presumably relies on exactly one dropped
# class (sky) sitting below every DYNAMIC index — confirm against NON_SKY.
DYNAMIC_NON_SKY = tuple(idx - 1 for idx in DYNAMIC)

NON_SKY = (
0,
1,
Expand Down
94 changes: 72 additions & 22 deletions torchdrive/tasks/test_voxel.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,33 +11,33 @@

# Loss keys the voxel tests expect from VoxelTask.
# test_multi_pose_semantic_voxel_task walks SEMANTIC_LOSSES + VOXEL_LOSSES,
# filters/expands the keys, and compares the result with losses.keys().
# NOTE(review): every key shows up twice here, once with a bare camera name
# ("left"/"right") and once with a numeric suffix ("left1"/"right1"). This
# looks like the before/after lines of a diff rendered together — confirm
# which set is the live one before relying on this list.
VOXEL_LOSSES = [
"tvl1",
"lossproj-voxel/right/o1",
"lossproj-voxel/left/o1",
"lossproj-voxel/right/o0",
"lossproj-voxel/left/o0",
"lossproj-voxel/right/o-1",
"lossproj-voxel/left/o-1",
"lossproj-cam/right/o1",
"lossproj-cam/left/o1",
"lossproj-cam/right/o0",
"lossproj-cam/left/o0",
"lossproj-cam/right/o-1",
"lossproj-cam/left/o-1",
"losssmooth-voxel-disp/left",
"losssmooth-voxel-disp/right",
"losssmooth-cam-disp/left",
"losssmooth-cam-disp/right",
"visible-probs/left",
"visible-probs/right",
"depth-probs/left",
"depth-probs/right",
"lossproj-voxel/right1/o1",
"lossproj-voxel/left1/o1",
"lossproj-voxel/right1/o0",
"lossproj-voxel/left1/o0",
"lossproj-voxel/right1/o-1",
"lossproj-voxel/left1/o-1",
"lossproj-cam/right1/o1",
"lossproj-cam/left1/o1",
"lossproj-cam/right1/o0",
"lossproj-cam/left1/o0",
"lossproj-cam/right1/o-1",
"lossproj-cam/left1/o-1",
"losssmooth-voxel-disp/left1",
"losssmooth-voxel-disp/right1",
"losssmooth-cam-disp/left1",
"losssmooth-cam-disp/right1",
"visible-probs/left1",
"visible-probs/right1",
"depth-probs/left1",
"depth-probs/right1",
# "losssmooth-cam-vel/left",
# "losssmooth-cam-vel/right",
]

SEMANTIC_LOSSES = [
"semantic-voxel/left",
"semantic-voxel/right",
"semantic-voxel/left1",
"semantic-voxel/right1",
"semantic-cam/left",
"semantic-cam/right",
# "losssmooth-voxel-vel/left",
Expand Down Expand Up @@ -69,6 +69,7 @@ def test_voxel_task(self) -> None:
device=device,
render_batch_size=5,
n_pts_per_ray=10,
start_offsets=(0,),
offsets=(-1, 0, 1),
).to(device)
batch = dummy_batch().to(device)
Expand Down Expand Up @@ -127,6 +128,55 @@ def test_semantic_voxel_task(self) -> None:
)
self._assert_loss_shapes(losses)

def test_multi_pose_semantic_voxel_task(self) -> None:
    # Builds a VoxelTask with two start offsets and semantic output for the
    # "left" camera, runs one forward/backward pass on a dummy batch, and
    # checks that exactly the expected set of loss keys comes back.
    cpu = torch.device("cpu")
    cameras = ["left", "right"]
    task = VoxelTask(
        cameras=cameras,
        cam_shape=(320, 240),
        cam_feats_shape=(320 // 16, 240 // 16),
        dim=4,
        cam_dim=4,
        hr_dim=5,
        height=12,
        device=cpu,
        semantic=["left"],
        start_offsets=(0, 1),
        offsets=(-1, 0),
    )
    batch = dummy_batch()
    ctx = Context(
        log_img=True,
        log_text=True,
        global_step=0,
        writer=MagicMock(),
        start_frame=1,
        scaler=None,
        name="det",
        output="",
        weights=batch.weight,
        cam_feats={
            cam: torch.rand(2, 4, 320 // 16, 240 // 16) for cam in cameras
        },
    )
    bev = torch.rand(2, 1, 5, 4, 4)
    losses = task(ctx, batch, bev)
    ctx.backward(losses)

    # Suffix rewrites applied to non-"cam" keys; the first match wins, which
    # mirrors an if/elif chain over "left1" then "right1".
    # NOTE(review): the "2"-suffixed keys presumably belong to the second
    # entry of start_offsets — confirm against VoxelTask.
    suffix_pairs = (("left1", "left2"), ("right1", "right2"))

    expected_keys = set()
    for key in SEMANTIC_LOSSES + VOXEL_LOSSES:
        # offsets=(-1, 0) above, so keys mentioning "o1" are not expected.
        if "o1" in key:
            continue
        expected_keys.add(key)
        if "cam" in key:
            continue
        for first, second in suffix_pairs:
            if first in key:
                expected_keys.add(key.replace(first, second))
                break

    self.assertCountEqual(losses.keys(), expected_keys)
    self._assert_loss_shapes(losses)

def test_stereoscopic_voxel_task(self) -> None:
device = torch.device("cpu")
cameras = ["left", "right"]
Expand Down
Loading

0 comments on commit de09cbf

Please sign in to comment.