From 16dabcf62774bae23ddac3f0c16eef322324c3e7 Mon Sep 17 00:00:00 2001 From: Tristan Rice Date: Wed, 15 Nov 2023 00:34:28 -0800 Subject: [PATCH] autolabel,nuscenes: fixed autolabeling token --- autolabel.py | 9 +++++---- torchdrive/datasets/nuscenes_dataset.py | 2 +- torchdrive/notebook.py | 26 +++++++++++++++++++++++++ torchworld/structures/cameras.py | 7 +++++-- torchworld/structures/grid.py | 11 +++++++---- 5 files changed, 44 insertions(+), 11 deletions(-) diff --git a/autolabel.py b/autolabel.py index 51e8419..b5fac69 100644 --- a/autolabel.py +++ b/autolabel.py @@ -1,5 +1,4 @@ import argparse -import importlib import os from multiprocessing.pool import ThreadPool from typing import Dict @@ -19,7 +18,7 @@ from torch.utils.data.distributed import DistributedSampler from torchdrive.data import Batch, TransferCollator -from torchdrive.datasets.autolabeler import LabelType, save_tensors +from torchdrive.datasets.autolabeler import AutoLabeler, LabelType, save_tensors from torchdrive.datasets.dataset import Dataset from torchdrive.train_config import create_parser, TrainConfig @@ -31,8 +30,7 @@ # pyre-fixme[5]: Global expression must be annotated. -config_module = importlib.import_module("configs." + args.config) -config: TrainConfig = config_module.CONFIG +config: TrainConfig = args.config # overrides config.num_frames = 1 @@ -56,6 +54,9 @@ dataset: Dataset = config.create_dataset(smoke=args.smoke) +if isinstance(dataset, AutoLabeler): + dataset = dataset.dataset + sampler: DistributedSampler[Dataset] = DistributedSampler( dataset, num_replicas=WORLD_SIZE, diff --git a/torchdrive/datasets/nuscenes_dataset.py b/torchdrive/datasets/nuscenes_dataset.py index 363228a..7b3477f 100644 --- a/torchdrive/datasets/nuscenes_dataset.py +++ b/torchdrive/datasets/nuscenes_dataset.py @@ -361,7 +361,7 @@ def _getitem(self, sample_data: SampleData) -> Dict[str, object]: "color": img, "mask": None, # pyre-fixme[27]: TypedDict `SampleData` has no key `sample_token`. - "token": sample_data["sample_token"], + "token": sample_data["token"], } def __getitem__(self, idx: int) -> Dict[str, object]: diff --git a/torchdrive/notebook.py b/torchdrive/notebook.py index aecfb82..544ff9a 100644 --- a/torchdrive/notebook.py +++ b/torchdrive/notebook.py @@ -4,6 +4,7 @@ import torch from IPython.display import display from torchvision.transforms.functional import to_pil_image +from torchvision.utils import draw_bounding_boxes from torchworld.transforms.img import normalize_img, render_color @@ -20,3 +21,28 @@ def display_color(x: torch.Tensor) -> None: Renders a [w, h] grid into colors and displays it. """ display(to_pil_image(render_color(x))) + + +def display_bboxes(img: torch.Tensor, bboxes: object, threshold: float = 0.5) -> None: + """ + Displays bounding boxes in mmcv format on the provided image. + """ + tboxes = [] + labels = [] + for i, box in enumerate(bboxes[0]): + if len(box) == 0: + continue + if not isinstance(box, torch.Tensor): + box = torch.from_numpy(box) + p = box[:, 4] + valid = p > threshold + box = box[valid] + tboxes.append(box[:, :4]) + labels += [str(i)] * len(box) + + tboxes = torch.cat(tboxes, dim=0) + + img = normalize_img(img.float()) + img = (img.clamp(min=0, max=1) * 255).byte() + + display(to_pil_image(draw_bounding_boxes(image=img, boxes=tboxes, labels=labels))) diff --git a/torchworld/structures/cameras.py b/torchworld/structures/cameras.py index 862d66e..eab0932 100644 --- a/torchworld/structures/cameras.py +++ b/torchworld/structures/cameras.py @@ -6,7 +6,7 @@ import math import warnings -from typing import Any, Dict, List, Optional, Self, Sequence, Tuple, Union +from typing import Any, Dict, List, Optional, Sequence, Tuple, TYPE_CHECKING, Union import numpy as np import torch @@ -16,6 +16,9 @@ from torchworld.transforms.transform3d import Rotate, Transform3d, Translate +if TYPE_CHECKING: + from typing import Self + # Default values for rotation and translation matrices. _R: torch.Tensor = torch.eye(3)[None] # (1, 3, 3) @@ -366,7 +369,7 @@ def transform_points_screen( # pyre-fixme[14]: `clone` overrides method defined in `TensorProperties` # inconsistently. - def clone(self) -> Self: + def clone(self) -> "Self": """ Returns a copy of `self`. """ diff --git a/torchworld/structures/grid.py b/torchworld/structures/grid.py index 9f8994e..01f6ba3 100644 --- a/torchworld/structures/grid.py +++ b/torchworld/structures/grid.py @@ -1,6 +1,6 @@ from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import Optional, Self, Tuple, TypeVar, Union +from typing import Optional, Tuple, TYPE_CHECKING, TypeVar, Union import torch from pytorch3d.structures.volumes import VolumeLocator @@ -8,6 +8,9 @@ from torchworld.structures.cameras import CamerasBase from torchworld.transforms.transform3d import Transform3d +if TYPE_CHECKING: + from typing import Self + T = TypeVar("T") @@ -17,13 +20,13 @@ class BaseGrid(ABC): time: torch.Tensor @abstractmethod - def to(self, target: Union[torch.device, str]) -> Self: + def to(self, target: Union[torch.device, str]) -> "Self": ... - def cuda(self) -> Self: + def cuda(self) -> "Self": return self.to(torch.device("cuda")) - def cpu(self) -> Self: + def cpu(self) -> "Self": return self.to(torch.device("cpu")) @property