autolabeler: added Autolabeler and tests for it
d4l3k committed Oct 12, 2023
1 parent f7ec683 commit 4ce2c2e
Showing 7 changed files with 139 additions and 16 deletions.
17 changes: 4 additions & 13 deletions autolabel.py
@@ -9,18 +9,16 @@
LOCAL_RANK = int(os.environ.get("LOCAL_RANK", 0))
os.environ.setdefault("CUDA_VISIBLE_DEVICES", str(LOCAL_RANK))

-import safetensors.torch

import torch
import torch.distributed as dist
import torch.nn.functional as F
-import zstd
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

from torchdrive.data import Batch, TransferCollator
+from torchdrive.datasets.autolabeler import LabelType, save_tensors
from torchdrive.datasets.dataset import Dataset
from torchdrive.models.semantic import BDD100KSemSeg

from torchdrive.train_config import create_parser

parser = create_parser()
@@ -111,20 +109,13 @@
"""

assert os.path.exists(args.output), "output dir must exist"
-sem_seg_path = os.path.join(args.output, config.dataset, "sem_seg")
+sem_seg_path = os.path.join(args.output, config.dataset, LabelType.SEM_SEG)
print(f"writing to {sem_seg_path}")
os.makedirs(sem_seg_path, exist_ok=True)

pool = ThreadPool(args.batch_size)


-def save(path, frame_data):
-    buf = safetensors.torch.save(frame_data)
-    compressed = zstd.ZSTD_compress(buf, 3, 1)  # level, threads
-    with open(path, "wb") as f:
-        f.write(compressed)


handles = []

for batch in tqdm(collator):
@@ -146,7 +137,7 @@ def save(path, frame_data):
path = os.path.join(sem_seg_path, f"{token}.safetensors.zstd")
handles.append(
pool.apply_async(
-save,
+save_tensors,
(
path,
frame_data,
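A minimal standalone sketch of the async-save pattern used above. save_tensors and the path scheme come from this commit; the work list, pool size, and the final wait loop are assumptions, since the diff elides the surrounding loop:

import os
from multiprocessing.pool import ThreadPool

import torch

from torchdrive.datasets.autolabeler import save_tensors

sem_seg_path = "/tmp/sem_seg"  # assumption: stands in for {args.output}/{dataset}/sem_seg
os.makedirs(sem_seg_path, exist_ok=True)

pool = ThreadPool(4)  # the script sizes its pool as args.batch_size
handles = []
work = [("token0", {"left": torch.rand(1)})]  # hypothetical (token, tensors) pairs
for token, frame_data in work:
    path = os.path.join(sem_seg_path, f"{token}.safetensors.zstd")
    # apply_async returns immediately; compression and file IO run on pool threads
    handles.append(pool.apply_async(save_tensors, (path, frame_data)))
for handle in handles:
    handle.get()  # re-raises any exception from a worker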
6 changes: 5 additions & 1 deletion torchdrive/data.py
@@ -22,7 +22,7 @@
@dataclass(frozen=True)
class Batch:
    # per frame unique token, must be unique across the entire dataset
-    token: List[List[object]]
+    token: List[List[str]]
    # example weight [BS]
    weight: torch.Tensor
    # per frame distance traveled in meters [BS, num_frames]
@@ -52,6 +52,10 @@ class Batch:
    # Lidar data [BS, 4, n], channel format is [x, y, z, intensity]
    lidar: Optional[torch.Tensor] = None

+    # AutoLabeler fields
+    # semantic segmentation for each camera across the frames
+    sem_seg: Optional[Dict[str, torch.Tensor]] = None

    global_batch_size: int = 1

    def batch_size(self) -> int:
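A tiny sketch of reading the new field downstream (hypothetical consumer; the commit only defines the field, and the AutoLabeler below populates it):

from typing import Dict, Optional

import torch

def sem_seg_shapes(sem_seg: Optional[Dict[str, torch.Tensor]]) -> Dict[str, torch.Size]:
    # sem_seg maps camera name -> per-frame tensors stacked along dim 0,
    # or is None when no autolabels were attached to the batch.
    if sem_seg is None:
        return {}
    return {cam: seg.shape for cam, seg in sem_seg.items()}

print(sem_seg_shapes({"left": torch.rand(2, 3)}))  # {'left': torch.Size([2, 3])}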
49 changes: 47 additions & 2 deletions torchdrive/datasets/autolabeler.py
@@ -1,8 +1,39 @@
import dataclasses
import os.path
from enum import Enum
from typing import Dict

import safetensors.torch
import torch

import zstd

from torchdrive.data import Batch
from torchdrive.datasets.dataset import Dataset

ZSTD_COMPRESS_LEVEL = 3
ZSTD_THREADS = 1


class LabelType(str, Enum):
    SEM_SEG = "sem_seg"


def save_tensors(path: str, data: Dict[str, torch.Tensor]) -> None:
    buf = safetensors.torch.save(data)
    buf = zstd.compress(buf, ZSTD_COMPRESS_LEVEL, ZSTD_THREADS)
    with open(path, "wb") as f:
        f.write(buf)


def load_tensors(path: str) -> object:
    with open(path, "rb") as f:
        buf = f.read()
    buf = zstd.uncompress(buf)
    return safetensors.torch.load(buf)


-class Autolabeler:
+class AutoLabeler:
    """
    Autolabeler takes in a dataset and a cache location and automatically loads
    the autolabeled data based off of the batch tokens.
@@ -17,4 +48,18 @@ def __len__(self) -> int:

    def __getitem__(self, idx: int) -> Batch:
        batch = self.dataset[idx]
-        tokens = batch.token
+        tokens = batch.token[0]
+        out = {cam: [] for cam in self.dataset.cameras}
+        for token in tokens:
+            path = os.path.join(
+                self.path,
+                self.dataset.NAME,
+                LabelType.SEM_SEG,
+                f"{token}.safetensors.zstd",
+            )
+            data = load_tensors(path)
+            for cam, frame in data.items():
+                out[cam].append(frame)
+        return dataclasses.replace(
+            batch, sem_seg={cam: torch.stack(frames) for cam, frames in out.items()}
+        )
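Taken together, a hedged usage sketch of the new module (the cache path is a placeholder; the wiring mirrors test_autolabeler.py below):

import torch

from torchdrive.datasets.autolabeler import AutoLabeler, load_tensors, save_tensors
from torchdrive.datasets.dummy import DummyDataset

# Round-trip a dict of tensors through the zstd-compressed safetensors format.
save_tensors("/tmp/example.safetensors.zstd", {"left": torch.rand(3)})
print(load_tensors("/tmp/example.safetensors.zstd"))

# Wrap a dataset; __getitem__ then attaches sem_seg loaded from
# {cache}/{dataset.NAME}/sem_seg/{token}.safetensors.zstd for each frame token.
labeler = AutoLabeler(DummyDataset(), "/path/to/cache")
# batch = labeler[0]  # would populate batch.sem_seg once the cache files exist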
1 change: 1 addition & 0 deletions torchdrive/datasets/dataset.py
@@ -14,6 +14,7 @@ class Datasets(str, Enum):

    RICE = "rice"
    NUSCENES = "nuscenes"
+    DUMMY = "dummy"


class Dataset(TorchDataset, ABC):
15 changes: 15 additions & 0 deletions torchdrive/datasets/dummy.py
@@ -0,0 +1,15 @@
+from torchdrive.data import Batch, dummy_item
+from torchdrive.datasets.dataset import Dataset, Datasets
+
+
+class DummyDataset(Dataset):
+    NAME = Datasets.DUMMY
+    cameras = ["left", "right"]
+
+    def __len__(self) -> int:
+        return 10
+
+    def __getitem__(self, idx: int) -> Batch:
+        if idx >= len(self):
+            raise IndexError("invalid index")
+        return dummy_item()
55 changes: 55 additions & 0 deletions torchdrive/datasets/test_autolabeler.py
@@ -0,0 +1,55 @@
+import os.path
+import tempfile
+import unittest
+
+import torch
+
+from torchdrive.datasets.autolabeler import (
+    AutoLabeler,
+    LabelType,
+    load_tensors,
+    save_tensors,
+)
+
+from torchdrive.datasets.dummy import DummyDataset
+
+
+class TestAutolabeler(unittest.TestCase):
+    def test_save_load(self) -> None:
+        with tempfile.TemporaryDirectory() as path:
+            want = {"foo": torch.rand(1)}
+            path = os.path.join(path, "foo.safetensors.zstd")
+            save_tensors(path, want)
+            out = load_tensors(path)
+
+            self.assertEqual(out, want)
+
+    def test_autolabeler(self) -> None:
+        dataset = DummyDataset()
+        with tempfile.TemporaryDirectory() as path:
+            labeler = AutoLabeler(dataset, path)
+            self.assertEqual(len(labeler), len(dataset))
+
+            sem_seg_path = os.path.join(
+                path,
+                dataset.NAME,
+                LabelType.SEM_SEG,
+            )
+            os.makedirs(sem_seg_path)
+
+            batch = dataset[0]
+            for token in batch.token[0]:
+                tensor_path = os.path.join(
+                    sem_seg_path,
+                    f"{token}.safetensors.zstd",
+                )
+                save_tensors(
+                    tensor_path,
+                    {
+                        "left": torch.tensor(1),
+                        "right": torch.tensor(1),
+                    },
+                )
+
+            out = labeler[0]
+            self.assertIsNotNone(out.sem_seg)
12 changes: 12 additions & 0 deletions torchdrive/datasets/test_dummy.py
@@ -0,0 +1,12 @@
+import unittest
+
+from torchdrive.data import Batch
+from torchdrive.datasets.dummy import DummyDataset
+
+
+class TestDummyDataset(unittest.TestCase):
+    def test_sanity(self) -> None:
+        dataset = DummyDataset()
+        self.assertEqual(len(dataset), 10)
+        for item in dataset:
+            self.assertIsInstance(item, Batch)
