Skip to content

Commit

Permalink
autolabel,nuscenes: fixed autolabeling token
Browse files Browse the repository at this point in the history
  • Loading branch information
d4l3k committed Nov 15, 2023
1 parent 36413e9 commit 16dabcf
Show file tree
Hide file tree
Showing 5 changed files with 44 additions and 11 deletions.
9 changes: 5 additions & 4 deletions autolabel.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import argparse
import importlib
import os
from multiprocessing.pool import ThreadPool
from typing import Dict
Expand All @@ -19,7 +18,7 @@
from torch.utils.data.distributed import DistributedSampler

from torchdrive.data import Batch, TransferCollator
from torchdrive.datasets.autolabeler import LabelType, save_tensors
from torchdrive.datasets.autolabeler import AutoLabeler, LabelType, save_tensors
from torchdrive.datasets.dataset import Dataset
from torchdrive.train_config import create_parser, TrainConfig

Expand All @@ -31,8 +30,7 @@


# pyre-fixme[5]: Global expression must be annotated.
config_module = importlib.import_module("configs." + args.config)
config: TrainConfig = config_module.CONFIG
config: TrainConfig = args.config

# overrides
config.num_frames = 1
Expand All @@ -56,6 +54,9 @@

dataset: Dataset = config.create_dataset(smoke=args.smoke)

if isinstance(dataset, AutoLabeler):
dataset = dataset.dataset

sampler: DistributedSampler[Dataset] = DistributedSampler(
dataset,
num_replicas=WORLD_SIZE,
Expand Down
2 changes: 1 addition & 1 deletion torchdrive/datasets/nuscenes_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,7 @@ def _getitem(self, sample_data: SampleData) -> Dict[str, object]:
"color": img,
"mask": None,
# pyre-fixme[27]: TypedDict `SampleData` has no key `sample_token`.
"token": sample_data["sample_token"],
"token": sample_data["token"],
}

def __getitem__(self, idx: int) -> Dict[str, object]:
Expand Down
26 changes: 26 additions & 0 deletions torchdrive/notebook.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import torch
from IPython.display import display
from torchvision.transforms.functional import to_pil_image
from torchvision.utils import draw_bounding_boxes

from torchworld.transforms.img import normalize_img, render_color

Expand All @@ -20,3 +21,28 @@ def display_color(x: torch.Tensor) -> None:
Renders a [w, h] grid into colors and displays it.
"""
display(to_pil_image(render_color(x)))


def display_bboxes(img: torch.Tensor, bboxes: object, threshold: float = 0.5) -> None:
"""
Displays bounding boxes in mmcv format on the provided image.
"""
tboxes = []
labels = []
for i, box in enumerate(bboxes[0]):
if len(box) == 0:
continue
if not isinstance(box, torch.Tensor):
box = torch.from_numpy(box)
p = box[:, 4]
valid = p > threshold
box = box[valid]
tboxes.append(box[:, :4])
labels += [str(i)] * len(box)

tboxes = torch.cat(tboxes, dim=0)

img = normalize_img(img.float())
img = (img.clamp(min=0, max=1) * 255).byte()

display(to_pil_image(draw_bounding_boxes(image=img, boxes=tboxes, labels=labels)))
7 changes: 5 additions & 2 deletions torchworld/structures/cameras.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import math
import warnings
from typing import Any, Dict, List, Optional, Self, Sequence, Tuple, Union
from typing import Any, Dict, List, Optional, Sequence, Tuple, TYPE_CHECKING, Union

import numpy as np
import torch
Expand All @@ -16,6 +16,9 @@

from torchworld.transforms.transform3d import Rotate, Transform3d, Translate

if TYPE_CHECKING:
from typing import Self


# Default values for rotation and translation matrices.
_R: torch.Tensor = torch.eye(3)[None] # (1, 3, 3)
Expand Down Expand Up @@ -366,7 +369,7 @@ def transform_points_screen(

# pyre-fixme[14]: `clone` overrides method defined in `TensorProperties`
# inconsistently.
def clone(self) -> Self:
def clone(self) -> "Self":
"""
Returns a copy of `self`.
"""
Expand Down
11 changes: 7 additions & 4 deletions torchworld/structures/grid.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Optional, Self, Tuple, TypeVar, Union
from typing import Optional, Tuple, TYPE_CHECKING, TypeVar, Union

import torch
from pytorch3d.structures.volumes import VolumeLocator

from torchworld.structures.cameras import CamerasBase
from torchworld.transforms.transform3d import Transform3d

if TYPE_CHECKING:
from typing import Self

T = TypeVar("T")


Expand All @@ -17,13 +20,13 @@ class BaseGrid(ABC):
time: torch.Tensor

@abstractmethod
def to(self, target: Union[torch.device, str]) -> Self:
def to(self, target: Union[torch.device, str]) -> "Self":
...

def cuda(self) -> Self:
def cuda(self) -> "Self":
return self.to(torch.device("cuda"))

def cpu(self) -> Self:
def cpu(self) -> "Self":
return self.to(torch.device("cpu"))

@property
Expand Down

0 comments on commit 16dabcf

Please sign in to comment.