autolabel: added det support
d4l3k committed Oct 17, 2023
1 parent bcac59c commit b8fd677
Showing 3 changed files with 199 additions and 78 deletions.
180 changes: 119 additions & 61 deletions autolabel.py
@@ -2,29 +2,30 @@
import importlib
import os
from multiprocessing.pool import ThreadPool
from typing import Dict

from tqdm import tqdm

# set device before loading CUDA/PyTorch
LOCAL_RANK = int(os.environ.get("LOCAL_RANK", 0))
os.environ.setdefault("CUDA_VISIBLE_DEVICES", str(LOCAL_RANK))
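# with one visible device per rank, the assigned GPU shows up as cuda:0 in this process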

import numpy as np
import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

from torchdrive.data import Batch, TransferCollator
from torchdrive.datasets.autolabeler import LabelType, save_tensors
from torchdrive.datasets.dataset import Dataset
from torchdrive.models.semantic import BDD100KSemSeg
from torchdrive.train_config import create_parser

parser = create_parser()
parser.add_argument("--num_workers", type=int, required=True)
parser.add_argument("--batch_size", type=int, required=True)
parser.add_argument("--smoke", action="store_true")
args: argparse.Namespace = parser.parse_args()


@@ -70,80 +71,137 @@
)
collator = TransferCollator(dataloader, batch_size=args.batch_size, device=device)

compile_fn = torch.compile if args.compile else lambda x: x

"""
print("Quantizing...")

model_fp32 = model.orig_model
print(model_fp32)
model_fp32.qconfig = torch.ao.quantization.get_default_qconfig('x86')
#model_fp32_fused = torch.ao.quantization.fuse_modules(model_fp32, [['conv', 'relu']])
model_fp32_prepared = torch.ao.quantization.prepare(model_fp32)
class LabelSemSeg(nn.Module):
    TYPE = LabelType.SEM_SEG

    def __init__(self) -> None:
        super().__init__()

        from torchdrive.models.semantic import BDD100KSemSeg

        # model_config = "upernet_convnext-t_fp16_512x1024_80k_sem_seg_bdd100k.py"  # 1.02s/it bs12
        # model_config = "upernet_convnext-s_fp16_512x1024_80k_sem_seg_bdd100k.py"  # 1.20s/it bs12
        model_config = (
            "upernet_convnext-b_fp16_512x1024_80k_sem_seg_bdd100k.py"  # 1.39s/it bs12
        )
        self.model = BDD100KSemSeg(
            device=device,
            compile_fn=compile_fn,
            mmlab=True,
            half=True,
            config=model_config,
        )

    def forward(self, img: torch.Tensor) -> torch.Tensor:
        pred = self.model(img)
        pred = F.interpolate(pred, scale_factor=1 / 2, mode="bilinear")
        # uses 1 byte to represent probabilities between 0 and 1 (0: 0.0, 255: 1.0)
        pred = (pred.sigmoid() * 255).byte()
        return pred
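# Decoding sketch (illustrative, not part of this commit): the byte-encoded
# probabilities map back to floats via division by 255:
#
#   probs = stored_pred.float() / 255.0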


class LabelDet(nn.Module):
    TYPE = LabelType.DET

    def __init__(self) -> None:
        super().__init__()

        from torchdrive.models.det import BDD100KDet

        model_config = "cascade_rcnn_convnext-s_fpn_fp16_3x_det_bdd100k.py"
        self.model = BDD100KDet(
            config=model_config,
            device=device,
            half=True,
            compile_fn=compile_fn,
        )

    def forward(self, img: torch.Tensor) -> object:
        # detection outputs are nested lists/dicts rather than a single tensor,
        # so the return type matches what flatten() below consumes
        pred = self.model(img)
        return pred


TASKS = [
    LabelDet(),
    # LabelSemSeg(),
]
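# each task labels every batch; per-token output files double as a cache,
# so reruns skip frames that were already labeled (see the existence check below)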


def flatten(v: object) -> Dict[str, torch.Tensor]:
    out = {}

    def _flatten(prefix: str, v: object) -> None:
        if isinstance(v, torch.Tensor):
            out[prefix] = v
        elif isinstance(v, np.ndarray):
            out[prefix] = torch.from_numpy(v)
        elif isinstance(v, dict):
            for k, v2 in v.items():
                _flatten(os.path.join(prefix, k), v2)
        elif isinstance(v, list):
            for i, v2 in enumerate(v):
                _flatten(os.path.join(prefix, str(i)), v2)
        else:
            raise TypeError(f"unknown type of {type(v)} - {v}")

    _flatten("", v)
    return out
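# Behavior sketch (illustrative, not part of this commit): nested dicts/lists
# flatten into slash-separated keys, since os.path.join("", "cam1") == "cam1":
#
#   flatten({"cam1": [torch.zeros(2), {"scores": torch.ones(3)}]})
#   # -> {"cam1/0": tensor of shape (2,), "cam1/1/scores": tensor of shape (3,)}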


def get_task_path(task: nn.Module) -> str:
    return os.path.join(args.output, config.dataset, task.TYPE)

assert os.path.exists(args.output), "output dir must exist"

for task in TASKS:
    task_path = get_task_path(task)
    print(f"{task.TYPE}: writing to {task_path}")
    os.makedirs(task_path, exist_ok=True)

pool = ThreadPool(args.batch_size)

handles = []

for batch in tqdm(collator):
    for task in TASKS:
        task_path = get_task_path(task)
        token_paths = []
        idxs = []
        for i in range(batch.batch_size()):
            token = batch.token[i][0]
            assert len(token) > 5
            token_path = os.path.join(task_path, f"{token}.safetensors.zstd")
            token_paths.append(token_path)
            if not os.path.exists(token_path):
                idxs.append(i)

        if len(idxs) == 0:
            continue

        cam_data = {}
        for cam, frames in batch.color.items():
            squashed = frames[idxs].squeeze(1)
            cam_data[cam] = task(squashed)

        for j, i in enumerate(idxs):
            frame_data = {}
            for cam, pred in cam_data.items():
                frame_data[cam] = pred[j]

            path = token_paths[i]
            handles.append(
                pool.apply_async(
                    save_tensors,
                    (
                        path,
                        flatten(frame_data),
                    ),
                )
            )

    # bounded backpressure: once 2x batch_size saves are in flight, block on the oldest
    while len(handles) > args.batch_size * 2:
        handles.pop(0).get()
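A minimal launch sketch (assumptions: flags other than --num_workers, --batch_size, --output, and --compile come from torchdrive.train_config's create_parser and are not shown in this diff; the output path is hypothetical):

torchrun --nproc_per_node=2 autolabel.py --num_workers 2 --batch_size 12 --output /data/autolabel

torchrun exports LOCAL_RANK for each worker, which the header above uses to pin CUDA_VISIBLE_DEVICES before torch is imported.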
55 changes: 54 additions & 1 deletion torchdrive/datasets/autolabeler.py
@@ -1,5 +1,6 @@
import dataclasses
import os.path
import sys
from enum import Enum
from typing import Dict, Optional

@@ -17,6 +18,7 @@

class LabelType(str, Enum):
    SEM_SEG = "sem_seg"
    DET = "det"


def save_tensors(path: str, data: Dict[str, torch.Tensor]) -> None:
@@ -26,11 +28,62 @@ def save_tensors(path: str, data: Dict[str, torch.Tensor]) -> None:
        f.write(buf)


_SIZE = {
    torch.int64: 8,
    torch.float32: 4,
    torch.int32: 4,
    torch.bfloat16: 2,
    torch.float16: 2,
    torch.int16: 2,
    torch.uint8: 1,
    torch.int8: 1,
    torch.bool: 1,
    torch.float64: 8,
}

_TYPES = {
    "F64": torch.float64,
    "F32": torch.float32,
    "F16": torch.float16,
    "BF16": torch.bfloat16,
    "I64": torch.int64,
    # "U64": torch.uint64,
    "I32": torch.int32,
    # "U32": torch.uint32,
    "I16": torch.int16,
    # "U16": torch.uint16,
    "I8": torch.int8,
    "U8": torch.uint8,
    "BOOL": torch.bool,
}


def _getdtype(dtype_str: str) -> torch.dtype:
    return _TYPES[dtype_str]


def _view2torch(safeview) -> Dict[str, torch.Tensor]:
    result = {}
    for k, v in safeview:
        dtype = _getdtype(v["dtype"])
        if len(v["data"]) == 0:
            arr = torch.zeros(*v["shape"], dtype=dtype)
            assert arr.numel() == 0
        else:
            arr = torch.frombuffer(v["data"], dtype=dtype).reshape(v["shape"])
        if sys.byteorder == "big":
            arr = torch.from_numpy(arr.numpy().byteswap(inplace=False))
        result[k] = arr

    return result


def load_tensors(path: str) -> object:
    with open(path, "rb") as f:
        buf = f.read()
    buf = zstd.uncompress(buf)
    raw = safetensors.deserialize(buf)
    return _view2torch(raw)
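# Round-trip sketch (illustrative, not part of this commit):
#
#   save_tensors("/tmp/example.safetensors.zstd", {"mask": torch.zeros(2, 3, dtype=torch.uint8)})
#   assert load_tensors("/tmp/example.safetensors.zstd")["mask"].shape == (2, 3)
#
# load_tensors mirrors safetensors.torch.load but builds the tensors itself via
# safetensors.deserialize + torch.frombuffer, using the dtype tables above.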


class AutoLabeler(Dataset):
