Skip to content

Commit

Permalink
nuscenes: lidar loading fixes + metrics fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
d4l3k committed Sep 15, 2023
1 parent 6f0b2c3 commit a2dcff8
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 4 deletions.
13 changes: 9 additions & 4 deletions torchdrive/datasets/nuscenes_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,7 @@ def __getitem__(self, idx: int) -> object:
pcl = LidarPointCloud.from_file(
os.path.join(self.dataroot, sample_data["filename"])
)
return pcl.points
return torch.from_numpy(pcl.points)


class NuscenesDataset(Dataset):
Expand All @@ -292,7 +292,9 @@ class NuscenesDataset(Dataset):
CamTypes.CAM_BACK_RIGHT: [CamTypes.CAM_BACK, CamTypes.CAM_FRONT_RIGHT],
}

def __init__(self, data_dir: str, version: str = "v1.0-trainval") -> None:
def __init__(
self, data_dir: str, version: str = "v1.0-trainval", lidar: bool = False
) -> None:
self.data_dir = data_dir
self.version = version
self.nusc = NuScenes(version=version, dataroot=data_dir, verbose=True)
Expand All @@ -304,7 +306,10 @@ def __init__(self, data_dir: str, version: str = "v1.0-trainval") -> None:
CamTypes.CAM_BACK_LEFT,
CamTypes.CAM_BACK_RIGHT,
]
self.sensor_types = self.cam_types + [SensorTypes.LIDAR_TOP]
self.sensor_types = self.cam_types
if lidar:
self.sensor_types = self.sensor_types + [SensorTypes.LIDAR_TOP]
self.lidar = lidar
self.cameras: List[str] = list(self.CAMERA_OVERLAP.keys())

# Organize all the sample_data into scenes by camera type
Expand Down Expand Up @@ -399,7 +404,7 @@ def _getitem(self, idx: int) -> Optional[Batch]:
colors[cam] = torch.stack(cam_colors, dim=0)
masks[cam] = sample_dict["mask"]

lidar = data[SensorTypes.LIDAR_TOP]
lidar = data[SensorTypes.LIDAR_TOP] if self.lidar else None

return Batch(
weight=weight.float(),
Expand Down
2 changes: 2 additions & 0 deletions torchdrive/tasks/voxel.py
Original file line number Diff line number Diff line change
Expand Up @@ -596,10 +596,12 @@ def _losses(
add_text=False,
)
ctx.add_figure("semantic/confusion_matrix", fig)
self.semantic_confusion_matrix.reset()
if ctx.log_text:
ctx.add_scalar(
"semantic/accuracy", self.semantic_accuracy.compute()
)
self.semantic_accuracy.reset()

del voxel_depth
del semantic_vel
Expand Down

0 comments on commit a2dcff8

Please sign in to comment.