From d597f7cd89cab46f92c5458d32411d62b35c9fd5 Mon Sep 17 00:00:00 2001 From: Tristan Rice Date: Mon, 23 Oct 2023 22:36:03 -0700 Subject: [PATCH] train_config: allow for CLI overrides of certain params --- torchdrive/models/path.py | 5 ++- torchdrive/test_train_config.py | 15 +++++++++ torchdrive/train_config.py | 57 ++++++++++++++++++++++++++++++--- train.py | 9 +++--- 4 files changed, 74 insertions(+), 12 deletions(-) diff --git a/torchdrive/models/path.py b/torchdrive/models/path.py index 9255dd8..69a82d4 100644 --- a/torchdrive/models/path.py +++ b/torchdrive/models/path.py @@ -7,7 +7,6 @@ from torchdrive.amp import autocast from torchdrive.models.mlp import MLP -from torchdrive.models.regnet import ConvPEBlock from torchdrive.models.transformer import StockTransformerDecoder, transformer_init from torchdrive.positional_encoding import apply_sin_cos_enc2d @@ -48,7 +47,7 @@ def __init__( self.bev_encoder = nn.Conv2d(bev_dim, dim, 1) - self.bev_project = compile_fn(ConvPEBlock(bev_dim, bev_dim, bev_shape, depth=1)) + # self.bev_project = compile_fn(ConvPEBlock(bev_dim, bev_dim, bev_shape, depth=1)) self.pos_encoder = compile_fn( nn.Sequential( @@ -107,7 +106,7 @@ def forward( static = self.static_encoder(static_feats).permute(0, 2, 1) # bev features - bev = self.bev_project(bev) + # bev = self.bev_project(bev) bev = apply_sin_cos_enc2d(bev) bev = self.bev_encoder(bev).flatten(-2, -1).permute(0, 2, 1) diff --git a/torchdrive/test_train_config.py b/torchdrive/test_train_config.py index 2f3326b..85bd15c 100644 --- a/torchdrive/test_train_config.py +++ b/torchdrive/test_train_config.py @@ -9,6 +9,7 @@ from parameterized import parameterized from torchdrive.datasets.dataset import Datasets +from torchdrive.train_config import create_parser, TrainConfig CONFIG_DIR = os.path.join(os.path.dirname(__file__), "..", "configs") @@ -35,3 +36,17 @@ def test_configs(self, module_name: str) -> None: config.dataset = Datasets.DUMMY dataset = config.create_dataset() + + def test_parser(self) -> None: + parser = create_parser() + args = parser.parse_args( + [ + "--output=foo", + "--config=simplebev3d", + "--config.lr=1234", + "--config.ae=true", + ] + ) + self.assertIsInstance(args.config, TrainConfig) + self.assertEqual(args.config.lr, 1234) + self.assertEqual(args.config.ae, True) diff --git a/torchdrive/train_config.py b/torchdrive/train_config.py index c460606..839bfe7 100644 --- a/torchdrive/train_config.py +++ b/torchdrive/train_config.py @@ -1,14 +1,17 @@ import argparse -from dataclasses import dataclass -from typing import Callable, List, Tuple +import importlib +from dataclasses import dataclass, fields +from typing import Callable, List, Optional, Tuple import torch +from dataclasses_json import dataclass_json from torch import nn from torchdrive.datasets.dataset import Dataset, Datasets from torchdrive.tasks.bev import BEVTask, BEVTaskVan +@dataclass_json @dataclass class TrainConfig: # backbone settings @@ -67,7 +70,7 @@ def create_dataset(self, smoke: bool = False) -> Dataset: dataset = NuscenesDataset( data_dir=self.dataset_path, version="v1.0-mini" if smoke else "v1.0-trainval", - lidar=True, + lidar=False, num_frames=self.num_frames, ) elif self.dataset == Datasets.DUMMY: @@ -252,6 +255,41 @@ def cam_encoder() -> RegNetEncoder: return model +class _ConfigAction(argparse.Action): + def __init__(self, dest: str, *args: object, **kwargs: object) -> None: + super().__init__(*args, dest=dest, **kwargs) + self.dest = dest + + def __call__( + self, + parser: argparse.ArgumentParser, + namespace: argparse.Namespace, + value: str, + option_string: Optional[str] = None, + ) -> None: + config_module = importlib.import_module(f"configs.{value}") + config = config_module.CONFIG + + setattr(namespace, self.dest, config) + + +class _ConfigFieldAction(argparse.Action): + def __init__(self, dest: str, *args: object, **kwargs: object) -> None: + super().__init__(*args, dest=dest, **kwargs) + self.dest = dest + + def __call__( + self, + parser: argparse.ArgumentParser, + namespace: argparse.Namespace, + value: str, + option_string: Optional[str] = None, + ) -> None: + target, _, field = self.dest.partition(".") + config = getattr(namespace, target) + setattr(config, field, value) + + def create_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser(description="train") @@ -268,7 +306,6 @@ def create_parser() -> argparse.ArgumentParser: parser.add_argument( "--compile", default=False, action="store_true", help="use torch.compile" ) - parser.add_argument("--config", required=True, help="the config file name to use") parser.add_argument( "--smoke", default=False, @@ -276,4 +313,16 @@ def create_parser() -> argparse.ArgumentParser: help="run with a smaller smoke test config", ) + parser.add_argument( + "--config", + required=True, + help="the config file name to use", + action=_ConfigAction, + ) + + for field in fields(TrainConfig): + parser.add_argument( + f"--config.{field.name}", type=field.type, action=_ConfigFieldAction + ) + return parser diff --git a/train.py b/train.py index 51780b9..120d714 100644 --- a/train.py +++ b/train.py @@ -3,6 +3,7 @@ import json import os import os.path +import sys from collections import defaultdict from typing import Callable, cast, Dict, Iterator, List, Optional, Set, Union @@ -31,10 +32,7 @@ parser = create_parser() args: argparse.Namespace = parser.parse_args() -import importlib - -config_module = importlib.import_module("configs." + args.config) -config = config_module.CONFIG +config = args.config os.makedirs(args.output, exist_ok=True) @@ -65,7 +63,8 @@ max_queue=500, flush_secs=60, ) - writer.add_text("args", json.dumps(vars(args), indent=4)) + writer.add_text("argv", json.dumps(sys.argv, indent=4)) + writer.add_text("train_config", config.to_json(indent=4)) import git