Skip to content

Commit

Permalink
datasets: added more robust testing
Browse files Browse the repository at this point in the history
  • Loading branch information
d4l3k committed Oct 12, 2023
1 parent 4ce2c2e commit b51c623
Show file tree
Hide file tree
Showing 5 changed files with 39 additions and 5 deletions.
16 changes: 14 additions & 2 deletions torchdrive/datasets/autolabeler.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import zstd

from torchdrive.data import Batch
from torchdrive.datasets.dataset import Dataset
from torchdrive.datasets.dataset import Dataset, Datasets

ZSTD_COMPRESS_LEVEL = 3
ZSTD_THREADS = 1
Expand All @@ -33,7 +33,7 @@ def load_tensors(path: str) -> object:
return safetensors.torch.load(buf)


class AutoLabeler:
class AutoLabeler(Dataset):
"""
Autolabeler takes in a dataset and a cache location and automatically loads
the autolabeled data based off of the batch tokens.
Expand Down Expand Up @@ -63,3 +63,15 @@ def __getitem__(self, idx: int) -> Batch:
return dataclasses.replace(
batch, sem_seg={cam: torch.stack(frames) for cam, frames in out.items()}
)

@property
def NAME(self) -> Datasets:
return self.dataset.NAME

@property
def cameras(self) -> Datasets:
return self.dataset.cameras

@property
def CAMERA_OVERLAP(self) -> Datasets:
return self.dataset.CAMERA_OVERLAP
17 changes: 14 additions & 3 deletions torchdrive/datasets/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,20 @@ class Dataset(TorchDataset, ABC):
Base class for datasets used by TorchDrive.
"""

NAME: Datasets
cameras: List[str]
CAMERA_OVERLAP: Dict[str, List[str]]
@property
@abstractmethod
def NAME(self) -> Datasets:
...

@property
@abstractmethod
def cameras(self) -> List[str]:
...

@property
@abstractmethod
def CAMERA_OVERLAP(self) -> Dict[str, List[str]]:
...

@abstractmethod
def __getitem__(self, idx: int) -> Optional[Batch]:
Expand Down
1 change: 1 addition & 0 deletions torchdrive/datasets/dummy.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
class DummyDataset(Dataset):
NAME = Datasets.DUMMY
cameras = ["left", "right"]
CAMERA_OVERLAP = {}

def __len__(self) -> int:
return 10
Expand Down
9 changes: 9 additions & 0 deletions torchdrive/datasets/test_autolabeler.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,13 @@ def test_save_load(self) -> None:

self.assertEqual(out, want)

def test_props(self) -> None:
dataset = DummyDataset()
labeler = AutoLabeler(dataset, "/nonexistant")
self.assertEqual(labeler.NAME, dataset.NAME)
self.assertEqual(labeler.cameras, dataset.cameras)
self.assertEqual(labeler.CAMERA_OVERLAP, dataset.CAMERA_OVERLAP)

def test_autolabeler(self) -> None:
dataset = DummyDataset()
with tempfile.TemporaryDirectory() as path:
Expand Down Expand Up @@ -53,3 +60,5 @@ def test_autolabeler(self) -> None:

out = labeler[0]
self.assertIsNotNone(out.sem_seg)
self.assertEqual(out.sem_seg["left"].shape, (3,))
self.assertEqual(out.sem_seg["right"].shape, (3,))
1 change: 1 addition & 0 deletions torchdrive/train_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ class TrainConfig:
mask_path: str
batch_size: int
num_workers: int
autolabel_path: str

# tasks
det: bool
Expand Down

0 comments on commit b51c623

Please sign in to comment.