Commit b5d157f

ExportedDataset: wip

d4l3k committed Jul 20, 2024
1 parent 6db10ef commit b5d157f
Showing 8 changed files with 235 additions and 10 deletions.
25 changes: 25 additions & 0 deletions configs/diff_traj_export.py
@@ -0,0 +1,25 @@
from torchdrive.train_config import Datasets, DiffTrajTrainConfig


CONFIG = DiffTrajTrainConfig(
# backbone settings
cameras=[
"main",
],
num_frames=1,
num_encode_frames=1,
cam_shape=(480, 640),
# optimizer settings
epochs=200,
lr=1e-4,
grad_clip=1.0,
step_size=1000,
# dataset
dataset=Datasets.RICE,
dataset_path="/mnt/ext/openape/snapshots/out-2024/index.txt",
autolabel_path=None,
mask_path="/mnt/ext/openape/masks/",
num_workers=16,
batch_size=64,
autolabel=False,
)
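
For reference, a config module like this is normally picked up through the training CLI built by torchdrive.train_config.create_parser; as a minimal sketch it can also be loaded directly by module path (assuming the repo root is importable — the attribute name CONFIG matches the file above, but the real CLI may resolve configs differently):

import importlib

# Load the exported-dataset config by module path (assumption: configs/
# is importable from the repo root).
module = importlib.import_module("configs.diff_traj_export")
config = module.CONFIG
print(config.dataset, config.batch_size, config.cam_shape)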
137 changes: 137 additions & 0 deletions export_dataset.py
@@ -0,0 +1,137 @@
import argparse
import os
from multiprocessing.pool import ThreadPool
from typing import Dict
import dataclasses
import zstd

from tqdm import tqdm

# set device before loading CUDA/PyTorch
LOCAL_RANK = int(os.environ.get("LOCAL_RANK", 0))
os.environ.setdefault("CUDA_VISIBLE_DEVICES", str(LOCAL_RANK))

import numpy as np
import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from torchvision.utils import save_image
from torchvision.transforms import v2

from torchdrive.data import Batch, TransferCollator
from torchdrive.datasets.autolabeler import AutoLabeler, LabelType, save_tensors
from torchdrive.datasets.dataset import Dataset
from torchdrive.train_config import create_parser, TrainConfig
from torchworld.transforms.img import normalize_img

# pyre-fixme[5]: Global expression must be annotated.
parser = create_parser()
parser.add_argument("--num_workers", type=int, required=True)
args: argparse.Namespace = parser.parse_args()


config: TrainConfig = args.config

# overrides
config.num_frames = 1

if "RANK" in os.environ:
WORLD_SIZE: int = int(os.environ["WORLD_SIZE"])
RANK: int = int(os.environ["RANK"])
else:
WORLD_SIZE = 1
RANK = 0

if torch.cuda.is_available():
assert torch.cuda.device_count() <= 1
device_id = 0
device = torch.device(device_id)
else:
device = torch.device("cpu")

torch.set_float32_matmul_precision("high")

dataset, _ = config.create_dataset(smoke=args.smoke)

def transform_img(t: torch.Tensor):
t = normalize_img(t)
t.clamp_(min=0.0, max=1.0)

return [
v2.functional.to_pil_image(frame)
for frame in t
]

dataset.transform = transform_img

if isinstance(dataset, AutoLabeler):
dataset = dataset.dataset

sampler: DistributedSampler[Dataset] = DistributedSampler(
dataset,
num_replicas=WORLD_SIZE,
rank=RANK,
shuffle=False,
drop_last=False,
# seed=1,
)
dataloader = DataLoader[Batch](
dataset,
batch_size=None,
num_workers=args.num_workers,
pin_memory=False,
sampler=sampler,
)

assert os.path.exists(args.output), "output dir must exist"

pool = ThreadPool(args.num_workers or 4)

# pyre-fixme[5]: Global expression must be annotated.
handles = []

def run(f, *args):
handles.append(pool.apply_async(f, args))

output_path = os.path.join(args.output, config.dataset)
index_path = os.path.join(output_path, "index.txt")

os.makedirs(output_path, exist_ok=True)

#with open(index_path, "wta") as index_file:
for batch in tqdm(dataloader, "export"):
if batch is None:
continue

token = batch.token[0][0]
assert len(token) > 5
token_path = os.path.join(output_path, f"{token}.pt")

#index_file.write(token+"\n")
#index_file.flush()
if os.path.exists(token_path):
continue

for cam, frames in batch.color.items():
for i, frame in enumerate(frames):
frame_token = batch.token[0][i]
frame_path = os.path.join(output_path, f"{frame_token}_{cam}.jpg")
if not os.path.exists(frame_path):
run(lambda path, frame: frame.save(path), frame_path, frame)

# clear color data
batch = dataclasses.replace(batch, color=None)
run(lambda path, batch: torch.save(dataclasses.asdict(batch), path), token_path, batch)

while len(handles) > args.num_workers * 2:
handles.pop(0).get()

for handle in handles:
handle.get()
pool.terminate()
pool.join()
# print(i, len(buf), type(buf), len(compressed), pred.dtype)
# break
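
The loop above writes one <token>.pt per batch (the Batch as a dict, with color stripped) plus per-camera JPEG frames into output_path. A sketch of reading one sample back — the directory, token, and "main" camera name below are hypothetical, following the layout written above:

import os

import torch
from PIL import Image

export_dir = "/path/to/export/rice"  # hypothetical export directory
token = "example_token"  # hypothetical sample token

# Mirrors the writer: torch.save of dataclasses.asdict(batch), plus one
# f"{token}_{cam}.jpg" file per camera frame.
sample = torch.load(os.path.join(export_dir, f"{token}.pt"), weights_only=True)
frame = Image.open(os.path.join(export_dir, f"{token}_main.jpg"))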
2 changes: 2 additions & 0 deletions requirements.txt
@@ -33,3 +33,5 @@ pandas
pythreejs
lintrunner
lintrunner-adapters
torchtune
diffusers
32 changes: 30 additions & 2 deletions torchdrive/data.py
@@ -1,7 +1,8 @@
import random
from concurrent.futures import Future, ThreadPoolExecutor
from contextlib import contextmanager
from dataclasses import dataclass, fields
from dataclasses import dataclass, fields, asdict
import io
from typing import (
Callable,
Dict,
@@ -15,6 +16,7 @@
Union,
)

import zstd
import torch
from torch.utils.data import DataLoader, default_collate

@@ -24,7 +26,6 @@

from torchdrive.render.raymarcher import CustomPerspectiveCameras


@dataclass(frozen=True)
class Batch:
# per frame unique token, must be unique across the entire dataset
@@ -207,6 +208,33 @@ def lidar_to_world(self) -> torch.Tensor:
"""
return self.car_to_world(0).matmul(self.lidar_T)

def save(self, path: str, compress_level: int = 3, threads: int = 1) -> None:
"""
Saves the batch to the specified path.
"""
data = asdict(self)

buffer = io.BytesIO()
torch.save(data, buffer)
buffer.seek(0)
buf = buffer.read()
buf = zstd.compress(buf, compress_level, threads)

with open(path, "wb") as f:
f.write(buf)

@classmethod
def load(cls, path: str) -> "Batch":
with open(path, "rb") as f:
buf = f.read()

buf = zstd.uncompress(buf)
buffer = io.BytesIO(buf)
data = torch.load(buffer, weights_only=True)

return cls(**data)



def _rand_det_target() -> torch.Tensor:
t = torch.rand(2, 5)
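
Together, save and load give a compact zstd-compressed on-disk format for batches; a minimal usage sketch (the file name is arbitrary):

from torchdrive.data import Batch

# Given any Batch instance `batch`, e.g. one yielded by a dataset:
batch.save("/tmp/sample.pt.zstd", compress_level=3)
restored = Batch.load("/tmp/sample.pt.zstd")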
7 changes: 7 additions & 0 deletions torchdrive/datasets/exported.py
@@ -0,0 +1,7 @@
import zstd

from torchdrive.datasets.dataset import Dataset


class ExportedDataset(Dataset):
pass
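
ExportedDataset is an empty stub in this commit. Purely as an illustration of where it seems headed given the Batch.save/Batch.load helpers above — every name below other than Batch and Dataset is an assumption — a reader over the exported files might look like:

import os

from torchdrive.data import Batch
from torchdrive.datasets.dataset import Dataset


class ExportedDatasetSketch(Dataset):
    """Illustrative sketch only; not part of this commit."""

    def __init__(self, index_path: str) -> None:
        # Assumes an index.txt of tokens next to the exported .pt files.
        self.root = os.path.dirname(index_path)
        with open(index_path) as f:
            self.tokens = [line.strip() for line in f if line.strip()]

    def __len__(self) -> int:
        return len(self.tokens)

    def __getitem__(self, idx: int) -> Batch:
        return Batch.load(os.path.join(self.root, f"{self.tokens[idx]}.pt"))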
27 changes: 21 additions & 6 deletions torchdrive/datasets/rice.py
@@ -18,6 +18,7 @@
from av.filter import Graph
from PIL import Image
from torch import Tensor
from tqdm import tqdm

from torchdrive.data import Batch
from torchdrive.datasets.dataset import Dataset, Datasets
@@ -41,6 +42,8 @@ def compute_bin(v: float, bins: List[int]) -> int:


def bin_weights(bins: Dict[int, int]) -> Dict[int, float]:
assert len(bins) > 0, "got empty list of bins"

mean = sum(bins.values()) / len(bins)
return {k: mean / v for k, v in bins.items()}

Expand Down Expand Up @@ -120,6 +123,7 @@ class MultiCamDataset(Dataset):
"rightrepeater": ["backup", "rightpillar"],
"backup": ["leftrepeater", "rightrepeater"],
}
cameras: List[str] = list(CAMERA_OVERLAP)

def __init__(
self,
@@ -140,13 +144,17 @@ def __init__(
self.nframes_per_point = nframes_per_point
self.dtype = dtype

for cam in cameras:
assert cam in self.CAMERA_OVERLAP, f"unknown camera {cam}"
self.cameras = cameras

self.per_path_frame_count: Dict[str, int] = {}

root = os.path.dirname(index_file)
indexes = glob.glob(os.path.join(root, "*", "info_noradar.json"))

assert len(indexes) > 0

self.path_heading_bin: Dict[str, int] = {}
self.speed_bins: Dict[int, int] = defaultdict(lambda: 0)
self.heading_bins: Dict[int, int] = defaultdict(lambda: 0)
@@ -157,7 +165,7 @@
DROP_LAST_N += 30
MIN_DIST_M = 10

for path in indexes:
for path in tqdm(indexes, "indexing"):
path = os.path.dirname(path)
infos = self._get_raw_infos(path, 0, -1)
if infos is None or len(infos) == 0:
@@ -167,7 +175,8 @@
for camera in self.cameras:
_, _, offsets, _ = self._load_offsets(path, camera)
frame_counts.append(len(offsets))
except FileNotFoundError:
except FileNotFoundError as e:
print(e)
continue
frame_count = min(frame_counts)
infos = infos[:frame_count]
@@ -428,6 +437,9 @@ def __getitem__(self, idx: int) -> Optional[Batch]:
print(e)
except av.error.MemoryError as e:
print(e)
except Exception as e:
print(e)
raise

def _get_alignment(self, path: str) -> Dict[str, int]:
path = os.path.join(path, "alignment.json")
@@ -552,17 +564,19 @@ def load(cam: str, frames: List[int]) -> None:
Ks[label] = K
# out["inv_K", label] = K.pinverse()
Ts[label] = T
colors[label] = torch.stack(color).to(self.dtype)

frame_colors = torch.stack(color).to(self.dtype)
if self.transform is not None:
frame_colors = self.transform(frame_colors)
colors[label] = frame_colors
masks[label] = mask.to(self.dtype)

for camera in self.cameras:
load(camera, frames)

path_base = os.path.basename(path)
# pyre-fixme[10]: Name `frame` is used but not defined.
tokens = [f"{path_base}_{frame}" for i in frames]
tokens = [f"{path_base}_{i}" for i in frames]

# pyre-fixme[20]: Argument `lidar_T` expected.
return Batch(
weight=torch.tensor(self.heading_weights[self.path_heading_bin[path]]),
K=Ks,
Expand All @@ -575,4 +589,5 @@ def load(cam: str, frames: List[int]) -> None:
frame_T=frame_T,
frame_time=frame_time,
token=[tokens],
lidar_T=None,
)
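
The new transform hook on the dataset is what export_dataset.py uses to convert stacked frame tensors to PIL images before saving. Any callable with the same contract works; a sketch of an alternative transform (assuming frames arrive as the (num_frames, C, H, W) float tensor produced by the torch.stack call above):

import torch.nn.functional as F

def half_res(frames):
    # Downsample each frame by 2x; bilinear matches typical image resizing.
    return F.interpolate(frames, scale_factor=0.5, mode="bilinear")

# Given a MultiCamDataset instance `dataset`:
dataset.transform = half_res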
13 changes: 13 additions & 0 deletions torchdrive/test_data.py
@@ -1,5 +1,7 @@
import unittest
from dataclasses import replace
import tempfile
import os.path

import torch
from torch.utils.data import DataLoader, Dataset
@@ -162,3 +164,14 @@ def test_grid_image(self) -> None:
self.assertEqual(mask.shape, (2, 1, 48, 64))
torch.testing.assert_allclose(img.time, batch.frame_time[:, 1])
self.assertFalse(img.camera.in_ndc())

def test_save_load(self) -> None:
batch = dummy_batch()

with tempfile.TemporaryDirectory("torchdrive-test_data") as path:
file_path = os.path.join(path, "file.pt.zstd")
batch.save(file_path)

out = Batch.load(file_path)

self.assertIsNotNone(out)
2 changes: 0 additions & 2 deletions torchdrive/train_config.py
@@ -44,8 +44,6 @@ def create_dataset(self, smoke: bool = False) -> Tuple[Dataset, Optional[Dataset]
cam_shape=self.cam_shape,
# 3 encode frames, 3 decode frames, overlap last frame
nframes_per_point=self.num_frames,
# pyre-fixme[16]: `TrainConfig` has no attribute `limit_size`.
limit_size=self.limit_size,
)
elif self.dataset == Datasets.NUSCENES:
from torchdrive.datasets.nuscenes_dataset import NuscenesDataset
