autolabel,nuscenes: align frames within sequence
d4l3k committed Nov 16, 2023
1 parent 16dabcf commit 0a5dbf6
Showing 3 changed files with 111 additions and 39 deletions.
7 changes: 3 additions & 4 deletions autolabel.py
@@ -29,7 +29,6 @@
 args: argparse.Namespace = parser.parse_args()


-# pyre-fixme[5]: Global expression must be annotated.
 config: TrainConfig = args.config

 # overrides
@@ -185,8 +184,8 @@ def get_task_path(task: nn.Module) -> str:
         token = batch.token[i][0]
         assert len(token) > 5
         token_path = os.path.join(task_path, f"{token}.safetensors.zstd")
-        token_paths.append(token_path)
         if not os.path.exists(token_path):
+            token_paths.append(token_path)
             idxs.append(i)

     if len(idxs) == 0:
@@ -198,10 +197,10 @@ def get_task_path(task: nn.Module) -> str:
         with torch.no_grad():
             cam_data[cam] = task(squashed)

-    for j, i in enumerate(idxs):
+    for i in range(len(idxs)):
         frame_data = {}
         for cam, pred in cam_data.items():
-            frame_data[cam] = pred[j]
+            frame_data[cam] = pred[i]

         path = token_paths[i]
         handles.append(
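The autolabel change above fixes an index-space mismatch: `token_paths` used to hold one entry per batch element, while the predictions only covered the uncached elements, so `pred[j]` and `token_paths[i]` were indexed from different lists. Appending to both lists in the same branch keeps them aligned. A runnable toy sketch of the pattern (all names hypothetical, not the repository's API):

```python
import os
import tempfile

# Hypothetical stand-ins for batch tokens and per-frame predictions.
tokens = ["tok_a", "tok_b", "tok_c"]
preds_by_token = {"tok_a": 1.0, "tok_b": 2.0, "tok_c": 3.0}

cache_dir = tempfile.mkdtemp()
# Pre-populate one cached entry so the skip branch is exercised.
open(os.path.join(cache_dir, "tok_b.safetensors.zstd"), "w").close()

token_paths = []  # output path per *uncached* frame
idxs = []         # batch index per *uncached* frame
for i, token in enumerate(tokens):
    path = os.path.join(cache_dir, f"{token}.safetensors.zstd")
    if not os.path.exists(path):
        # Appending both in the same branch keeps the lists index-aligned.
        token_paths.append(path)
        idxs.append(i)

# One "prediction" per uncached frame, in the same order as idxs.
preds = [preds_by_token[tokens[i]] for i in idxs]
for i in range(len(idxs)):
    with open(token_paths[i], "w") as f:  # index i now means the same frame everywhere
        f.write(str(preds[i]))
```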
141 changes: 107 additions & 34 deletions torchdrive/datasets/nuscenes_dataset.py
@@ -1,8 +1,10 @@
+import bisect
 import os
 import time
-from bisect import bisect_left
+from bisect import bisect_left, bisect_right
+from collections import defaultdict
 from datetime import timedelta
-from typing import Dict, List, Optional, Tuple, TypedDict
+from typing import Dict, Iterable, List, Optional, Tuple, TypedDict, TypeVar

 import orjson
 import pandas as pd
@@ -14,7 +16,7 @@
 from PIL import Image
 from pytorch3d.transforms import quaternion_to_matrix
 from strenum import StrEnum
-from torch.utils.data import ConcatDataset, DataLoader, Dataset as TorchDataset
+from torch.utils.data import DataLoader, Dataset as TorchDataset
 from tqdm import tqdm

 from torchdrive.datasets.dataset import Dataset, Datasets
@@ -186,6 +188,7 @@ def calculate_nearest_data_within_epsilon(
     sorted_timestamps: List[int],
 ) -> Dict[int, Tuple[SampleData, Dict[str, int]]]:
     nearest_data_within_epsilon = {}
+    unmatched = 0
     for cam_front_timestamp in sorted_timestamps:
         nearest_data = {}
         nearest_data_idxs = {}
@@ -197,20 +200,27 @@
                 continue

             timestamp_range_start = bisect_left(sorted_timestamps, min_timestamp)
-            timestamp_range_end = bisect_left(sorted_timestamps, max_timestamp)
+            timestamp_range_end = bisect_right(sorted_timestamps, max_timestamp)

+            smallest_diff = 1000000000000000  # large number to start
             for i in range(timestamp_range_start, timestamp_range_end):
                 timestamp = sorted_timestamps[i]
                 sample_data = timestamp_index[timestamp]
-                if cam in sample_data:
+                diff = abs(timestamp - cam_front_timestamp)
+                if cam in sample_data and diff < smallest_diff:
+                    smallest_diff = diff
                     nearest_data[cam], nearest_data_idxs[cam] = sample_data[cam]
-                    break

-        nearest_data_within_epsilon[cam_front_timestamp] = (
-            (nearest_data, nearest_data_idxs)
-            if len(nearest_data) == len(cam_samples) - 1
-            else (None, None)
-        )
+        if len(nearest_data) == (len(cam_samples) - 1):
+            nearest_data_within_epsilon[cam_front_timestamp] = (
+                nearest_data,
+                nearest_data_idxs,
+            )
+        else:
+            nearest_data_within_epsilon[cam_front_timestamp] = (None, None)
+            unmatched += 1
+
+    print(f"failed to match {unmatched}/{len(sorted_timestamps)} frames")

     return nearest_data_within_epsilon

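Two behavioral changes land in the matcher above: the window's upper bound becomes inclusive (`bisect_right` instead of `bisect_left`), and rather than taking the first sample inside the window, the loop now keeps the sample with the smallest timestamp difference. A self-contained sketch of that logic, reduced to a single camera with made-up timestamps:

```python
from bisect import bisect_left, bisect_right

def nearest_within_epsilon(sorted_ts, target, epsilon):
    """Return the timestamp in sorted_ts nearest to target within +/- epsilon, else None."""
    lo = bisect_left(sorted_ts, target - epsilon)   # first index >= lower bound
    hi = bisect_right(sorted_ts, target + epsilon)  # first index > upper bound, so the bound is inclusive
    best = None
    best_diff = None
    for i in range(lo, hi):
        diff = abs(sorted_ts[i] - target)
        if best_diff is None or diff < best_diff:
            best, best_diff = sorted_ts[i], diff
    return best

# Timestamps in microseconds, 100 ms epsilon: 80_000 is closer to 90_000 than 160_000 is.
assert nearest_within_epsilon([0, 80_000, 160_000], 90_000, 100_000) == 80_000
```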
@@ -240,7 +250,66 @@ def get_nearest_data_within_epsilon(

         cam_front_timestamp = as_num(cam_front_samples[idx]["timestamp"])
         # pyre-fixme[6]: For 1st argument expected `int` but got `str`.
-        return self.nearest_data_within_epsilon[cam_front_timestamp]
+        out = self.nearest_data_within_epsilon[cam_front_timestamp]
+        return out
+
+
+T_co = TypeVar("T_co", covariant=True)
+
+
+class ConcatDataset(TorchDataset[T_co]):
+    r"""Dataset as a concatenation of multiple datasets.
+
+    This class is useful to assemble different existing datasets. Based on the
+    PyTorch ConcatDataset but supports multiple indices.
+
+    Args:
+        datasets (sequence): List of datasets to be concatenated
+    """
+
+    datasets: List[TorchDataset[T_co]]
+    cumulative_sizes: List[int]
+
+    @staticmethod
+    def cumsum(sequence):
+        r, s = [], 0
+        for e in sequence:
+            l = len(e)
+            r.append(l + s)
+            s += l
+        return r
+
+    def __init__(self, datasets: Iterable[TorchDataset]) -> None:
+        super().__init__()
+        self.datasets = list(datasets)
+        assert len(self.datasets) > 0, "datasets should not be an empty iterable"  # type: ignore[arg-type]
+        self.cumulative_sizes = self.cumsum(self.datasets)
+
+    def __len__(self):
+        return self.cumulative_sizes[-1]
+
+    def __getitem__(self, idxs: List[int]):
+        idx = idxs[0]
+        if idx < 0:
+            if -idx > len(self):
+                raise ValueError(
+                    "absolute value of index should not exceed dataset length"
+                )
+            idx = len(self) + idx
+        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
+        if dataset_idx == 0:
+            sample_idxs = idxs
+        else:
+            sample_idxs = [idx - self.cumulative_sizes[dataset_idx - 1] for idx in idxs]
+        return self.datasets[dataset_idx][sample_idxs]
+
+    @property
+    def cummulative_sizes(self):
+        warnings.warn(
+            "cummulative_sizes attribute is renamed to " "cumulative_sizes",
+            DeprecationWarning,
+            stacklevel=2,
+        )
+        return self.cumulative_sizes


 def get_ego_T(nusc: NuScenes, sample_data: SampleData) -> torch.Tensor:
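The vendored `ConcatDataset` replaces the stock PyTorch import so that `__getitem__` can accept a list of indices; note that all indices are routed to the dataset containing `idxs[0]`, so a window is assumed not to straddle a dataset boundary. A toy usage sketch, assuming the `ConcatDataset` above is in scope (the `ToyDataset` class is hypothetical):

```python
class ToyDataset:
    """Minimal multi-index dataset: returns the values at the requested indices."""

    def __init__(self, start: int, n: int) -> None:
        self.vals = list(range(start, start + n))

    def __len__(self) -> int:
        return len(self.vals)

    def __getitem__(self, idxs):
        return [self.vals[i] for i in idxs]

ds = ConcatDataset([ToyDataset(0, 5), ToyDataset(100, 5)])
assert len(ds) == 10
assert ds[[1, 2, 3]] == [1, 2, 3]        # served by the first dataset
assert ds[[6, 7, 8]] == [101, 102, 103]  # rebased into the second dataset
```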
@@ -301,7 +370,7 @@ def get_sensor_calibration_T(nusc: NuScenes, sample_data: SampleData) -> torch.T
     return trans_T.matmul(rot_T)


-class CameraDataset(TorchDataset):
+class CameraDataset:
     """A "scene" is all the sample data from first (the one with no prev) to last (the one with no next) for a single camera."""

     def __init__(
@@ -319,7 +388,7 @@ def __init__(
         self.sensor = sensor

     def __len__(self) -> int:
-        return len(self.samples) - self.num_frames - 1
+        return len(self.samples)  # - self.num_frames - 1

     def _getitem(self, sample_data: SampleData) -> Dict[str, object]:
         cam_T = get_ego_T(self.nusc, sample_data)
@@ -364,9 +433,9 @@ def _getitem(self, sample_data: SampleData) -> Dict[str, object]:
             "token": sample_data["token"],
         }

-    def __getitem__(self, idx: int) -> Dict[str, object]:
+    def __getitem__(self, idxs: List[int]) -> Dict[str, object]:
         frame_dicts = []
-        for i in range(idx, idx + self.num_frames):
+        for i in idxs:
             sample = self.samples[i]
             frame_dict = self._getitem(sample)
             frame_dicts.append(frame_dict)
@@ -397,11 +466,9 @@ def __getitem__(self, idx: int) -> Dict[str, object]:
         frame_time = torch.tensor(
             [fd["frame_time"] for fd in frame_dicts], dtype=torch.int64
         )
-        frame_time = frame_time - frame_time[0]
-        frame_time = frame_time.float() / 1e6

         long_cam_Ts = [
-            get_ego_T(self.nusc, sample_data) for sample_data in self.samples[idx:]
+            get_ego_T(self.nusc, sample_data) for sample_data in self.samples[idxs[0] :]
         ]

         mask_transform = transforms.Compose(
@@ -440,7 +507,7 @@ def __getitem__(self, idx: int) -> Dict[str, object]:
         }


-class LidarDataset(TorchDataset):
+class LidarDataset:
     def __init__(
         self,
         dataroot: str,
@@ -455,13 +522,12 @@ def __init__(
         self.num_frames = num_frames

     def __len__(self) -> int:
-        return len(self.samples) - self.num_frames - 1
+        return len(self.samples)  # - self.num_frames - 1

-    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
-        # Account for FPS difference between cameras (15fps) and lidar
-        # (20fps)
-        idx += self.num_frames // 2 * 15 // 20
-        sample_data = self.samples[idx]
+    def __getitem__(self, idxs: List[int]) -> Tuple[torch.Tensor, torch.Tensor]:
+        sample_data = self.samples[idxs[2]]  # get 3rd lidar frame

         calibrated_sensor_token = sample_data["calibrated_sensor_token"]
         calibrated_sensor = self.nusc.get("calibrated_sensor", calibrated_sensor_token)
@@ -526,7 +592,7 @@ def __init__(

         # Create a timestamp matcher to match the timestamps of each camera (addresses the issue of cameras being out of sync)
         self.timestamp_matcher = TimestampMatcher(
-            self.cam_samples, epsilon=timedelta(milliseconds=51)
+            self.cam_samples, epsilon=timedelta(milliseconds=100)
         )

     def _cam2scenes(
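Widening the matcher's epsilon from 51 ms to 100 ms matters because the nuScenes cameras capture at roughly 12 Hz (per the nuScenes documentation), i.e. about 83 ms between consecutive frames of a single camera, and the six cameras fire staggered within each sweep. The old 51 ms window was narrower than one camera period, so a neighboring camera's nearest frame could legitimately fall outside it; 100 ms covers a full period. A quick sanity check of the arithmetic:

```python
from datetime import timedelta

# ~83.3 ms between frames, assuming the ~12 Hz camera rate from the nuScenes docs.
cam_period_ms = timedelta(seconds=1 / 12).total_seconds() * 1000
assert 51 < cam_period_ms <= 100  # old epsilon undershoots one period; new one covers it
```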
@@ -582,21 +648,26 @@ def __len__(self) -> int:
     def _getitem(self, idx: int) -> Optional[Batch]:
         """Returns one row of a Batch of data for the given index."""
         # Do timestamp matching for this idx
-        sample_data, idxs = self.timestamp_matcher.get_nearest_data_within_epsilon(
-            idx
-        )  # Returns { cam: sample_data, ... } for all cams except CAM_FRONT
-        if sample_data is None:
-            return None
-
+        front_idxs = list(range(idx, idx + self.num_frames))
+        cam_idxs = defaultdict(lambda: [])
+        for i in front_idxs:
+            _, idxs = self.timestamp_matcher.get_nearest_data_within_epsilon(i)
+            if idxs is None:
+                print(f"failed to find idx for {i}, {idx}")
+                return None
+            for cam, cam_idx in idxs.items():
+                cam_idxs[cam].append(cam_idx)

         # Now get processed sample data using CameraDatasets from the cam_scenes
         data = {}
-        for cam, adj_idx in idxs.items():
+        for cam, idxs in cam_idxs.items():
             cam_scene = self.cam_scenes[cam]
-            if adj_idx < 0 or adj_idx >= len(cam_scene):
+            if min(idxs) < 0 or max(idxs) >= len(cam_scene):
                 # TODO: figure out why index is invalid
                 return None
-            data[cam] = cam_scene[adj_idx]
-        data[CamTypes.CAM_FRONT] = self.cam_scenes[CamTypes.CAM_FRONT][idx]
+            data[cam] = cam_scene[idxs]
+        data[CamTypes.CAM_FRONT] = self.cam_scenes[CamTypes.CAM_FRONT][front_idxs]

         token: Tensor = data[CamTypes.CAM_FRONT]["token"]
         weight: Tensor = data[CamTypes.CAM_FRONT]["weight"]
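`_getitem` now aligns whole windows rather than single frames: for every CAM_FRONT index in the window it asks the matcher for the nearest index of each other camera and accumulates them, so every camera ends up with `num_frames` indices that are aligned frame-by-frame. A self-contained sketch with a hypothetical matcher that skews each camera by a fixed offset:

```python
from collections import defaultdict

# Hypothetical nearest-index relation: other_cam_index = front_index + skew.
SKEW = {"CAM_FRONT_LEFT": 1, "CAM_FRONT_RIGHT": -1}

def nearest_idxs(front_i):
    return {cam: front_i + s for cam, s in SKEW.items()}

num_frames = 5
idx = 10
front_idxs = list(range(idx, idx + num_frames))  # CAM_FRONT window
cam_idxs = defaultdict(list)
for i in front_idxs:
    for cam, cam_idx in nearest_idxs(i).items():
        cam_idxs[cam].append(cam_idx)

assert cam_idxs["CAM_FRONT_LEFT"] == [11, 12, 13, 14, 15]  # aligned per frame
assert cam_idxs["CAM_FRONT_RIGHT"] == [9, 10, 11, 12, 13]
```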
@@ -605,6 +676,8 @@ def _getitem(self, idx: int) -> Optional[Batch]:
         long_cam_Ts: Tensor = data[CamTypes.CAM_FRONT]["long_cam_T"]
         frame_Ts: Tensor = data[CamTypes.CAM_FRONT]["frame_T"]
         frame_times: Tensor = data[CamTypes.CAM_FRONT]["frame_time"]
+        frame_times = frame_times - frame_times[0]
+        frame_times = frame_times.float() / 1e6
         Ks: Dict[str, Tensor] = {}
         Ts: Dict[str, Tensor] = {}
         colors: Dict[str, Tensor] = {}
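The per-window time normalization deleted from `CameraDataset.__getitem__` reappears here, applied once to the CAM_FRONT timestamps so that all cameras share one time base: zero the window at its first frame, then convert microseconds to float seconds. A small standalone illustration with made-up capture times:

```python
import torch

# Hypothetical raw capture times in microseconds for a 3-frame window.
frame_times = torch.tensor([1_700_000_000_000, 1_700_000_083_000, 1_700_000_167_000])
frame_times = frame_times - frame_times[0]  # zero-base the window at its first frame
frame_times = frame_times.float() / 1e6     # microseconds -> seconds
print(frame_times)  # tensor([0.0000, 0.0830, 0.1670])
```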
2 changes: 1 addition & 1 deletion torchdrive/notebook.py
@@ -29,7 +29,7 @@ def display_bboxes(img: torch.Tensor, bboxes: object, threshold: float = 0.5) ->
     """
     tboxes = []
     labels = []
-    for i, box in enumerate(bboxes[0]):
+    for i, box in enumerate(bboxes):
         if len(box) == 0:
             continue
         if not isinstance(box, torch.Tensor):
