Commit f7ec683: autolabel loading progress, rice support
d4l3k committed Oct 12, 2023 (1 parent: adc55bb)
Showing 7 changed files with 170 additions and 25 deletions.
107 changes: 95 additions & 12 deletions autolabel.py
@@ -1,24 +1,28 @@
-import os
-from tqdm import tqdm
 import argparse
 import importlib
+import os
+from multiprocessing.pool import ThreadPool
+
+from tqdm import tqdm

 # set device before loading CUDA/PyTorch
 LOCAL_RANK = int(os.environ.get("LOCAL_RANK", 0))
 os.environ.setdefault("CUDA_VISIBLE_DEVICES", str(LOCAL_RANK))

+import safetensors.torch
+
 import torch
-from torchdrive.train_config import create_parser
 from torch.nn.parallel import DistributedDataParallel
 from torch.nn.parameter import Parameter
 import torch.distributed as dist
 import torch.nn.functional as F
+import zstd
 from torch.utils.data import DataLoader
 from torch.utils.data.distributed import DistributedSampler
 from torch.utils.tensorboard import SummaryWriter
-from torchdrive.checkpoint import remap_state_dict
-from torchdrive.data import Batch, transfer, TransferCollator
+from torchdrive.data import Batch, TransferCollator
 from torchdrive.datasets.dataset import Dataset
+from torchdrive.models.semantic import BDD100KSemSeg
+
+from torchdrive.train_config import create_parser

 parser = create_parser()
 parser.add_argument("--num_workers", type=int, required=True)
 parser.add_argument("--batch_size", type=int, required=True)
@@ -55,7 +59,7 @@
     dataset,
     num_replicas=WORLD_SIZE,
     rank=RANK,
-    shuffle=False,
+    shuffle=True,
     drop_last=False,
     seed=0,
 )
@@ -68,15 +72,94 @@
 )
 collator = TransferCollator(dataloader, batch_size=args.batch_size, device=device)

+# model_config = "upernet_convnext-t_fp16_512x1024_80k_sem_seg_bdd100k.py" # 1.02s/it bs12
+# model_config = "upernet_convnext-s_fp16_512x1024_80k_sem_seg_bdd100k.py" # 1.20s/it bs12
+model_config = (
+    "upernet_convnext-b_fp16_512x1024_80k_sem_seg_bdd100k.py"  # 1.39s/it bs12
+)
 model = BDD100KSemSeg(
     device=device,
     compile_fn=torch.compile if args.compile else lambda x: x,
     mmlab=True,
+    half=True,
+    config=model_config,
 )

+"""
+print("Quantizing...")
+model_fp32 = model.orig_model
+print(model_fp32)
+model_fp32.qconfig = torch.ao.quantization.get_default_qconfig('x86')
+#model_fp32_fused = torch.ao.quantization.fuse_modules(model_fp32, [['conv', 'relu']])
+model_fp32_prepared = torch.ao.quantization.prepare(model_fp32)
+print("Feeding data...")
+for batch in tqdm(collator):
+    cam_data = {}
+    for cam, frames in batch.color.items():
+        squashed = frames.squeeze(1)
+        out = model(squashed)
+        print(squashed.shape, out.shape)
+        frames = frames.squeeze(1)
+        frames = model.normalize(frames)
+        frames = model.transform(frames)
+        model_fp32_prepared(frames)
+    break
+print("Converting...")
+model_int8 = torch.ao.quantization.convert(model_fp32_prepared)
+"""

+assert os.path.exists(args.output), "output dir must exist"
+sem_seg_path = os.path.join(args.output, config.dataset, "sem_seg")
+print(f"writing to {sem_seg_path}")
+os.makedirs(sem_seg_path, exist_ok=True)

+pool = ThreadPool(args.batch_size)
+
+
+def save(path, frame_data):
+    buf = safetensors.torch.save(frame_data)
+    compressed = zstd.ZSTD_compress(buf, 3, 1)  # level, threads
+    with open(path, "wb") as f:
+        f.write(compressed)
+
+
+handles = []
+
+for batch in tqdm(collator):
+    cam_data = {}
+    for cam, frames in batch.color.items():
+        squashed = frames.squeeze(1)
+        pred = model(squashed)
+        pred = F.interpolate(pred, scale_factor=1 / 2, mode="bilinear")
+        # pred = pred.argmax(dim=1).byte()
+        pred = (pred.sigmoid() * 255).byte()
+        cam_data[cam] = pred
+
+    for i in range(args.batch_size):
+        frame_data = {}
+        for cam, pred in cam_data.items():
+            frame_data[cam] = pred[i]
+
+        token = batch.token[i][0]
+        path = os.path.join(sem_seg_path, f"{token}.safetensors.zstd")
+        handles.append(
+            pool.apply_async(
+                save,
+                (
+                    path,
+                    frame_data,
+                ),
+            )
+        )
+
+    while len(handles) > args.batch_size * 2:
+        handles.pop(0).get()
+
+for handle in handles:
+    handle.get()
+pool.terminate()
+pool.join()
+# print(i, len(buf), type(buf), len(compressed), pred.dtype)
+# break
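
The files written by save() can be read back by reversing the two steps, as the notebook cell further down does. A minimal load sketch, assuming the same zstd bindings (ZSTD_compress / ZSTD_uncompress) used above:

import safetensors.torch
import zstd


def load(path: str) -> dict:
    # Reverse of save(): zstd-decompress the file, then load the tensor dict.
    with open(path, "rb") as f:
        return safetensors.torch.load(zstd.ZSTD_uncompress(f.read()))


# Each value is a (num_classes, H/2, W/2) uint8 tensor of sigmoid scores scaled
# to 0-255, keyed by camera name, e.g. load(path)["CAM_FRONT"].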
57 changes: 50 additions & 7 deletions notebooks/model-debug.ipynb
@@ -2,12 +2,12 @@
  "cells": [
   {
    "cell_type": "code",
-   "execution_count": 1,
+   "execution_count": 2,
    "id": "08b1582e",
    "metadata": {
     "ExecuteTime": {
-     "end_time": "2023-10-06T03:40:11.683764Z",
-     "start_time": "2023-10-06T03:40:11.618526Z"
+     "end_time": "2023-10-11T05:34:22.297327Z",
+     "start_time": "2023-10-11T05:34:22.267654Z"
     }
    },
    "outputs": [],
@@ -17,12 +17,12 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 2,
+   "execution_count": 4,
    "id": "32dfa27f",
    "metadata": {
     "ExecuteTime": {
-     "end_time": "2023-10-06T03:40:18.378199Z",
-     "start_time": "2023-10-06T03:40:12.994199Z"
+     "end_time": "2023-10-11T05:34:37.103075Z",
+     "start_time": "2023-10-11T05:34:35.588210Z"
     },
     "scrolled": true
    },
@@ -46,7 +46,7 @@
    "31206 sample_data,\n",
    "18538 sample_annotation,\n",
    "4 map,\n",
-   "Done loading in 0.544 seconds.\n",
+   "Done loading in 0.476 seconds.\n",
    "======\n",
    "Reverse indexing ...\n",
    "Done reverse indexing in 0.1 seconds.\n",
@@ -86,6 +86,7 @@
    "        data_dir=\"../../../ext3/nuscenes\",\n",
    "        version=\"v1.0-mini\",\n",
    "        lidar=True,\n",
+   "        num_frames=5,\n",
    "    )\n",
    "else:\n",
    "    from torchdrive.datasets.rice import MultiCamDataset\n",
@@ -112,6 +113,14 @@
    "# display_img(example.color[cam][0].float())"
   ]
  },
+ {
+  "cell_type": "code",
+  "execution_count": null,
+  "id": "f0fd1c4d",
+  "metadata": {},
+  "outputs": [],
+  "source": []
+ },
  {
   "cell_type": "code",
   "execution_count": 3,
@@ -2139,6 +2148,40 @@
   "render_color(img)"
   ]
  },
+ {
+  "cell_type": "code",
+  "execution_count": 17,
+  "id": "e342e7cd",
+  "metadata": {
+   "ExecuteTime": {
+    "end_time": "2023-10-11T05:41:36.509127Z",
+    "start_time": "2023-10-11T05:41:36.425486Z"
+   }
+  },
+  "outputs": [
+   {
+    "data": {
+     "image/png": "<base64 PNG data omitted (320x240 RGB segmentation preview)>",
+     "text/plain": [
+      "<PIL.Image.Image image mode=RGB size=320x240>"
+     ]
+    },
+    "metadata": {},
+    "output_type": "display_data"
+   }
+  ],
+  "source": [
+   "import safetensors.torch\n",
+   "import zstd\n",
+   "\n",
+   "path = \"/mnt/ext3/autolabel/nuscenes/sem_seg/00889f8a9549450aa2f32cf310a3e305.safetensors.zstd\"\n",
+   "with open(path, \"rb\") as f:\n",
+   "    decomp = zstd.ZSTD_uncompress(f.read())\n",
+   "out = safetensors.torch.load(decomp)\n",
+   "front = out['CAM_FRONT'].bfloat16()/255\n",
+   "display_color(front.argmax(dim=0))"
+  ]
+ },
  {
   "cell_type": "code",
   "execution_count": null,
15 changes: 15 additions & 0 deletions torchdrive/datasets/autolabeler.py
@@ -1,5 +1,20 @@
 from torchdrive.data import Batch
 from torchdrive.datasets.dataset import Dataset


+class Autolabeler:
+    """
+    Autolabeler takes in a dataset and a cache location and automatically loads
+    the autolabeled data based off of the batch tokens.
+    """
+
+    def __init__(self, dataset: Dataset, path: str) -> None:
+        self.dataset = dataset
+        self.path = path
+
+    def __len__(self) -> int:
+        return len(self.dataset)
+
+    def __getitem__(self, idx: int) -> Batch:
+        batch = self.dataset[idx]
+        tokens = batch.token
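
As committed, __getitem__ stops after reading the batch tokens; the actual label loading is still in progress, per the commit title. A minimal sketch of the likely continuation, assuming self.path points at the per-dataset cache root written by autolabel.py and deliberately leaving the merge into Batch unspecified:

import os

import safetensors.torch
import zstd


def _load_sem_seg(self, token: str) -> dict:
    # Hypothetical helper: reverse of autolabel.py's save(). Looks up the
    # cached per-camera uint8 predictions for a single frame token.
    label_path = os.path.join(self.path, "sem_seg", f"{token}.safetensors.zstd")
    with open(label_path, "rb") as f:
        return safetensors.torch.load(zstd.ZSTD_uncompress(f.read()))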
1 change: 1 addition & 0 deletions torchdrive/datasets/dataset.py
@@ -21,6 +21,7 @@ class Dataset(TorchDataset, ABC):
     Base class for datasets used by TorchDrive.
     """

+    NAME: Datasets
     cameras: List[str]
     CAMERA_OVERLAP: Dict[str, List[str]]
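
The Datasets type referenced by the new NAME annotation is not part of this diff. Judging from the members used below (Datasets.NUSCENES, Datasets.RICE) and from autolabel.py joining config.dataset directly into a filesystem path, it is presumably a str-valued enum along these lines (an assumption, not code from the commit):

from enum import Enum


class Datasets(str, Enum):
    # Members inferred from NuscenesDataset.NAME and MultiCamDataset.NAME.
    NUSCENES = "nuscenes"
    RICE = "rice"

Mixing in str would keep os.path.join(args.output, config.dataset, "sem_seg") working without an explicit .value.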
3 changes: 2 additions & 1 deletion torchdrive/datasets/nuscenes_dataset.py
@@ -13,7 +13,7 @@
 from torch.utils.data import ConcatDataset, DataLoader, Dataset as TorchDataset
 from tqdm import tqdm

-from torchdrive.datasets.dataset import Dataset
+from torchdrive.datasets.dataset import Dataset, Datasets

 Tensor = torch.Tensor

@@ -324,6 +324,7 @@ def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:


 class NuscenesDataset(Dataset):
+    NAME = Datasets.NUSCENES
     CAMERA_OVERLAP: Dict[str, List[str]] = {
         CamTypes.CAM_FRONT: [CamTypes.CAM_FRONT_LEFT, CamTypes.CAM_FRONT_RIGHT],
         CamTypes.CAM_FRONT_LEFT: [CamTypes.CAM_FRONT, CamTypes.CAM_BACK_LEFT],
10 changes: 6 additions & 4 deletions torchdrive/datasets/rice.py
@@ -21,7 +21,7 @@
 from torchvision import transforms

 from torchdrive.data import Batch
-from torchdrive.datasets.dataset import Dataset
+from torchdrive.datasets.dataset import Dataset, Datasets
 from torchdrive.transforms.mat import transformation_from_parameters

 av.logging.set_level(logging.DEBUG)  # pyre-fixme
@@ -107,6 +107,7 @@ def heading_diff(a: float, b: float) -> float:


 class MultiCamDataset(Dataset):
+    NAME = Datasets.RICE
     CAMERA_OVERLAP = {
         "main": ["narrow", "fisheye"],
         "narrow": ["main"],
@@ -482,9 +483,6 @@ def _cam_T(self, infos: Dict[str, torch.Tensor]) -> Tuple[Tensor, Tensor]:
         return cam_T, frame_T

     def _getitem(self, idx: int) -> Batch:
-        path: str
-        camera: str
-        idx: int
         path, idx = self.frames[idx]

         # metadata
@@ -557,6 +555,9 @@ def load(cam: str, frames: List[int]) -> None:
         for camera in self.cameras:
             load(camera, frames)

+        path_base = os.path.basename(path)
+        tokens = [f"{path_base}_{frame}" for frame in frames]
+
         return Batch(
             weight=torch.tensor(self.heading_weights[self.path_heading_bin[path]]),
             K=Ks,
@@ -568,4 +569,5 @@
             distances=dists,
             frame_T=frame_T,
             frame_time=frame_time,
+            token=[tokens],
         )
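
The per-frame tokens added here are what autolabel.py keys its output files on: a token of the form {clip_basename}_{frame_index} becomes {output}/{dataset}/sem_seg/{token}.safetensors.zstd. An illustrative check with hypothetical values:

import os

path_base = "2023-10-01-drive-0042"  # hypothetical rice clip basename
frames = [100, 105, 110]
tokens = [f"{path_base}_{frame}" for frame in frames]

sem_seg_path = "/mnt/ext3/autolabel/rice/sem_seg"  # assumed output layout
print(os.path.join(sem_seg_path, f"{tokens[0]}.safetensors.zstd"))
# /mnt/ext3/autolabel/rice/sem_seg/2023-10-01-drive-0042_100.safetensors.zstd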
2 changes: 1 addition & 1 deletion train.py
@@ -4,7 +4,7 @@
 import os
 import os.path
 from collections import defaultdict
+from typing import Callable, cast, Dict, Iterator, List, Optional, Set, Union

 # set device before loading CUDA/PyTorch
 LOCAL_RANK = int(os.environ.get("LOCAL_RANK", 0))
