diff --git a/configs/diff_traj_export.py b/configs/diff_traj_export.py
new file mode 100644
index 0000000..f7b4632
--- /dev/null
+++ b/configs/diff_traj_export.py
@@ -0,0 +1,25 @@
+from torchdrive.train_config import Datasets, DiffTrajTrainConfig
+
+
+CONFIG = DiffTrajTrainConfig(
+    # backbone settings
+    cameras=[
+        "main",
+    ],
+    num_frames=1,
+    num_encode_frames=1,
+    cam_shape=(480, 640),
+    # optimizer settings
+    epochs=200,
+    lr=1e-4,
+    grad_clip=1.0,
+    step_size=1000,
+    # dataset
+    dataset=Datasets.RICE,
+    dataset_path="/mnt/ext/openape/snapshots/out-2024/index.txt",
+    autolabel_path=None,
+    mask_path="/mnt/ext/openape/masks/",
+    num_workers=16,
+    batch_size=64,
+    autolabel=False,
+)
diff --git a/export_dataset.py b/export_dataset.py
new file mode 100644
index 0000000..05e9411
--- /dev/null
+++ b/export_dataset.py
@@ -0,0 +1,137 @@
+import argparse
+import os
+from multiprocessing.pool import ThreadPool
+from typing import Dict
+import dataclasses
+import zstd
+
+from tqdm import tqdm
+
+# set device before loading CUDA/PyTorch
+LOCAL_RANK = int(os.environ.get("LOCAL_RANK", 0))
+os.environ.setdefault("CUDA_VISIBLE_DEVICES", str(LOCAL_RANK))
+
+import numpy as np
+import torch
+import torch.distributed as dist
+import torch.nn.functional as F
+from torch import nn
+from torch.utils.data import DataLoader
+from torch.utils.data.distributed import DistributedSampler
+from torchvision.utils import save_image
+from torchvision.transforms import v2
+
+from torchdrive.data import Batch, TransferCollator
+from torchdrive.datasets.autolabeler import AutoLabeler, LabelType, save_tensors
+from torchdrive.datasets.dataset import Dataset
+from torchdrive.train_config import create_parser, TrainConfig
+from torchworld.transforms.img import normalize_img
+
+# pyre-fixme[5]: Global expression must be annotated.
+parser = create_parser()
+parser.add_argument("--num_workers", type=int, required=True)
+args: argparse.Namespace = parser.parse_args()
+
+
+config: TrainConfig = args.config
+
+# overrides
+config.num_frames = 1
+
+if "RANK" in os.environ:
+    WORLD_SIZE: int = int(os.environ["WORLD_SIZE"])
+    RANK: int = int(os.environ["RANK"])
+else:
+    WORLD_SIZE = 1
+    RANK = 0
+
+if torch.cuda.is_available():
+    assert torch.cuda.device_count() <= 1
+    device_id = 0
+    device = torch.device(device_id)
+else:
+    device = torch.device("cpu")
+
+torch.set_float32_matmul_precision("high")
+
+dataset, _ = config.create_dataset(smoke=args.smoke)
+
+def transform_img(t: torch.Tensor):
+    t = normalize_img(t)
+    t.clamp_(min=0.0, max=1.0)
+
+    return [
+        v2.functional.to_pil_image(frame)
+        for frame in t
+    ]
+
+dataset.transform = transform_img
+
+if isinstance(dataset, AutoLabeler):
+    dataset = dataset.dataset
+
+sampler: DistributedSampler[Dataset] = DistributedSampler(
+    dataset,
+    num_replicas=WORLD_SIZE,
+    rank=RANK,
+    shuffle=False,
+    drop_last=False,
+    # seed=1,
+)
+dataloader = DataLoader[Batch](
+    dataset,
+    batch_size=None,
+    num_workers=args.num_workers,
+    pin_memory=False,
+    sampler=sampler,
+)
+
+assert os.path.exists(args.output), "output dir must exist"
+
+pool = ThreadPool(args.num_workers or 4)
+
+# pyre-fixme[5]: Global expression must be annotated.
+handles = []
+
+def run(f, *args):
+    handles.append(pool.apply_async(f, args))
+
+output_path = os.path.join(args.output, config.dataset)
+index_path = os.path.join(output_path, "index.txt")
+
+os.makedirs(output_path, exist_ok=True)
+
+#with open(index_path, "wta") as index_file:
+for batch in tqdm(dataloader, "export"):
+    if batch is None:
+        continue
+
+    token = batch.token[0][0]
+    assert len(token) > 5
+    token_path = os.path.join(output_path, f"{token}.pt")
+
+    #index_file.write(token+"\n")
+    #index_file.flush()
+    if os.path.exists(token_path):
+        continue
+
+    for cam, frames in batch.color.items():
+        for i, frame in enumerate(frames):
+            frame_token = batch.token[0][i]
+            frame_path = os.path.join(output_path, f"{frame_token}_{cam}.jpg")
+            if not os.path.exists(frame_path):
+                run(lambda path, frame: frame.save(path), frame_path, frame)
+
+    # clear color data
+    batch = dataclasses.replace(batch, color=None)
+    run(lambda path, batch: torch.save(dataclasses.asdict(batch), path), token_path, batch)
+
+    while len(handles) > args.num_workers * 2:
+        handles.pop(0).get()
+
+for handle in handles:
+    handle.get()
+pool.terminate()
+pool.join()
+# print(i, len(buf), type(buf), len(compressed), pred.dtype)
+# break
diff --git a/requirements.txt b/requirements.txt
index b4fcf0e..b8b8118 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -33,3 +33,5 @@ pandas
 pythreejs
 lintrunner
 lintrunner-adapters
+torchtune
+diffusers
diff --git a/torchdrive/data.py b/torchdrive/data.py
index 087415e..e8bc6b0 100644
--- a/torchdrive/data.py
+++ b/torchdrive/data.py
@@ -1,7 +1,8 @@
 import random
 from concurrent.futures import Future, ThreadPoolExecutor
 from contextlib import contextmanager
-from dataclasses import dataclass, fields
+from dataclasses import dataclass, fields, asdict
+import io
 from typing import (
     Callable,
     Dict,
@@ -15,6 +16,7 @@
     Union,
 )
 
+import zstd
 import torch
 from torch.utils.data import DataLoader, default_collate
 
@@ -24,7 +26,6 @@
 from torchdrive.render.raymarcher import CustomPerspectiveCameras
 
 
-
 @dataclass(frozen=True)
 class Batch:
     # per frame unique token, must be unique across the entire dataset
@@ -207,6 +208,33 @@ def lidar_to_world(self) -> torch.Tensor:
         """
         return self.car_to_world(0).matmul(self.lidar_T)
 
+    def save(self, path: str, compress_level: int = 3, threads: int = 1) -> None:
+        """
+        Saves the batch to the specified path as a zstd-compressed torch.save payload.
+        """
+        data = asdict(self)
+
+        buffer = io.BytesIO()
+        torch.save(data, buffer)
+        buffer.seek(0)
+        buf = buffer.read()
+        buf = zstd.compress(buf, compress_level, threads)
+
+        with open(path, "wb") as f:
+            f.write(buf)
+
+    @classmethod
+    def load(cls, path: str) -> "Batch":
+        with open(path, "rb") as f:
+            buf = f.read()
+
+        buf = zstd.uncompress(buf)
+        buffer = io.BytesIO(buf)
+        data = torch.load(buffer, weights_only=True)
+
+        return cls(**data)
+
+
 def _rand_det_target() -> torch.Tensor:
     t = torch.rand(2, 5)
diff --git a/torchdrive/datasets/exported.py b/torchdrive/datasets/exported.py
new file mode 100644
index 0000000..dbe0337
--- /dev/null
+++ b/torchdrive/datasets/exported.py
@@ -0,0 +1,7 @@
+import zstd
+
+from torchdrive.datasets.dataset import Dataset
+
+
+class ExportedDataset(Dataset):
+    pass
diff --git a/torchdrive/datasets/rice.py b/torchdrive/datasets/rice.py
index e312f16..fd988ea 100644
--- a/torchdrive/datasets/rice.py
+++ b/torchdrive/datasets/rice.py
@@ -18,6 +18,7 @@
 from av.filter import Graph
 from PIL import Image
 from torch import Tensor
+from tqdm import tqdm
 
 from torchdrive.data import Batch
 from torchdrive.datasets.dataset import Dataset, Datasets
@@ -41,6 +42,8 @@ def compute_bin(v: float, bins: List[int]) -> int:
 
 
 def bin_weights(bins: Dict[int, int]) -> Dict[int, float]:
+    assert len(bins) > 0, "got empty list of bins"
+
     mean = sum(bins.values()) / len(bins)
     return {k: mean / v for k, v in bins.items()}
 
@@ -120,6 +123,7 @@ class MultiCamDataset(Dataset):
         "rightrepeater": ["backup", "rightpillar"],
         "backup": ["leftrepeater", "rightrepeater"],
     }
+    cameras: List[str] = list(CAMERA_OVERLAP)
 
     def __init__(
         self,
@@ -140,6 +144,8 @@ def __init__(
         self.nframes_per_point = nframes_per_point
         self.dtype = dtype
 
+        for cam in cameras:
+            assert cam in self.CAMERA_OVERLAP, f"unknown camera {cam}"
         self.cameras = cameras
         self.per_path_frame_count: Dict[str, int] = {}
 
@@ -147,6 +153,8 @@ def __init__(
         root = os.path.dirname(index_file)
         indexes = glob.glob(os.path.join(root, "*", "info_noradar.json"))
 
+        assert len(indexes) > 0
+
         self.path_heading_bin: Dict[str, int] = {}
         self.speed_bins: Dict[int, int] = defaultdict(lambda: 0)
         self.heading_bins: Dict[int, int] = defaultdict(lambda: 0)
@@ -157,7 +165,7 @@ def __init__(
             DROP_LAST_N += 30
         MIN_DIST_M = 10
 
-        for path in indexes:
+        for path in tqdm(indexes, "indexing"):
             path = os.path.dirname(path)
             infos = self._get_raw_infos(path, 0, -1)
             if infos is None or len(infos) == 0:
@@ -167,7 +175,8 @@ def __init__(
                 for camera in self.cameras:
                     _, _, offsets, _ = self._load_offsets(path, camera)
                     frame_counts.append(len(offsets))
-            except FileNotFoundError:
+            except FileNotFoundError as e:
+                print(e)
                 continue
             frame_count = min(frame_counts)
             infos = infos[:frame_count]
@@ -428,6 +437,9 @@ def __getitem__(self, idx: int) -> Optional[Batch]:
             print(e)
         except av.error.MemoryError as e:
             print(e)
+        except Exception as e:
+            print(e)
+            raise
 
     def _get_alignment(self, path: str) -> Dict[str, int]:
         path = os.path.join(path, "alignment.json")
@@ -552,17 +564,19 @@ def load(cam: str, frames: List[int]) -> None:
             Ks[label] = K
             # out["inv_K", label] = K.pinverse()
             Ts[label] = T
-            colors[label] = torch.stack(color).to(self.dtype)
+
+            frame_colors = torch.stack(color).to(self.dtype)
+            if self.transform is not None:
+                frame_colors = self.transform(frame_colors)
+            colors[label] = frame_colors
             masks[label] = mask.to(self.dtype)
 
         for camera in self.cameras:
             load(camera, frames)
 
         path_base = os.path.basename(path)
-        # pyre-fixme[10]: Name `frame` is used but not defined.
-        tokens = [f"{path_base}_{frame}" for i in frames]
+        tokens = [f"{path_base}_{i}" for i in frames]
 
-        # pyre-fixme[20]: Argument `lidar_T` expected.
         return Batch(
             weight=torch.tensor(self.heading_weights[self.path_heading_bin[path]]),
             K=Ks,
@@ -575,4 +589,5 @@ def load(cam: str, frames: List[int]) -> None:
             frame_T=frame_T,
             frame_time=frame_time,
             token=[tokens],
+            lidar_T=None,
         )
diff --git a/torchdrive/test_data.py b/torchdrive/test_data.py
index db354e4..1d963dd 100644
--- a/torchdrive/test_data.py
+++ b/torchdrive/test_data.py
@@ -1,5 +1,7 @@
 import unittest
 from dataclasses import replace
+import tempfile
+import os.path
 
 import torch
 from torch.utils.data import DataLoader, Dataset
@@ -162,3 +164,14 @@ def test_grid_image(self) -> None:
         self.assertEqual(mask.shape, (2, 1, 48, 64))
         torch.testing.assert_allclose(img.time, batch.frame_time[:, 1])
         self.assertFalse(img.camera.in_ndc())
+
+    def test_save_load(self) -> None:
+        batch = dummy_batch()
+
+        with tempfile.TemporaryDirectory("torchdrive-test_data") as path:
+            file_path = os.path.join(path, "file.pt.zstd")
+            batch.save(file_path)
+
+            out = Batch.load(file_path)
+
+            self.assertIsNotNone(out)
diff --git a/torchdrive/train_config.py b/torchdrive/train_config.py
index ab0567a..1c000cb 100644
--- a/torchdrive/train_config.py
+++ b/torchdrive/train_config.py
@@ -44,8 +44,6 @@ def create_dataset(self, smoke: bool = False) -> Tuple[Dataset, Optional[Dataset
                 cam_shape=self.cam_shape,
                 # 3 encode frames, 3 decode frames, overlap last frame
                 nframes_per_point=self.num_frames,
-                # pyre-fixme[16]: `TrainConfig` has no attribute `limit_size`.
-                limit_size=self.limit_size,
             )
         elif self.dataset == Datasets.NUSCENES:
             from torchdrive.datasets.nuscenes_dataset import NuscenesDataset
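For reviewers, a minimal usage sketch of the Batch.save / Batch.load round trip added in torchdrive/data.py above. This is illustrative only: the output path is a placeholder, and dummy_batch is assumed to be the existing test helper in torchdrive.data (the same helper test_save_load uses), not something added by this diff.

    # Sketch only: assumes dummy_batch() exists in torchdrive.data and that
    # /tmp is writable; Batch.save / Batch.load come from this diff.
    import torch

    from torchdrive.data import Batch, dummy_batch

    batch = dummy_batch()

    # save() serializes asdict(batch) with torch.save and zstd-compresses it
    batch.save("/tmp/example_batch.pt.zstd", compress_level=3)

    # load() decompresses the file and rebuilds the frozen Batch dataclass
    restored = Batch.load("/tmp/example_batch.pt.zstd")
    assert torch.allclose(restored.frame_time, batch.frame_time)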