Skip to content

Commit

Permalink
train_config: allow for CLI overrides of certain params
Browse files Browse the repository at this point in the history
  • Loading branch information
d4l3k committed Oct 24, 2023
1 parent 043223c commit d597f7c
Show file tree
Hide file tree
Showing 4 changed files with 74 additions and 12 deletions.
5 changes: 2 additions & 3 deletions torchdrive/models/path.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

from torchdrive.amp import autocast
from torchdrive.models.mlp import MLP
from torchdrive.models.regnet import ConvPEBlock
from torchdrive.models.transformer import StockTransformerDecoder, transformer_init
from torchdrive.positional_encoding import apply_sin_cos_enc2d

Expand Down Expand Up @@ -48,7 +47,7 @@ def __init__(

self.bev_encoder = nn.Conv2d(bev_dim, dim, 1)

self.bev_project = compile_fn(ConvPEBlock(bev_dim, bev_dim, bev_shape, depth=1))
# self.bev_project = compile_fn(ConvPEBlock(bev_dim, bev_dim, bev_shape, depth=1))

self.pos_encoder = compile_fn(
nn.Sequential(
Expand Down Expand Up @@ -107,7 +106,7 @@ def forward(
static = self.static_encoder(static_feats).permute(0, 2, 1)

# bev features
bev = self.bev_project(bev)
# bev = self.bev_project(bev)
bev = apply_sin_cos_enc2d(bev)
bev = self.bev_encoder(bev).flatten(-2, -1).permute(0, 2, 1)

Expand Down
15 changes: 15 additions & 0 deletions torchdrive/test_train_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from parameterized import parameterized

from torchdrive.datasets.dataset import Datasets
from torchdrive.train_config import create_parser, TrainConfig


CONFIG_DIR = os.path.join(os.path.dirname(__file__), "..", "configs")
Expand All @@ -35,3 +36,17 @@ def test_configs(self, module_name: str) -> None:

config.dataset = Datasets.DUMMY
dataset = config.create_dataset()

def test_parser(self) -> None:
parser = create_parser()
args = parser.parse_args(
[
"--output=foo",
"--config=simplebev3d",
"--config.lr=1234",
"--config.ae=true",
]
)
self.assertIsInstance(args.config, TrainConfig)
self.assertEqual(args.config.lr, 1234)
self.assertEqual(args.config.ae, True)
57 changes: 53 additions & 4 deletions torchdrive/train_config.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
import argparse
from dataclasses import dataclass
from typing import Callable, List, Tuple
import importlib
from dataclasses import dataclass, fields
from typing import Callable, List, Optional, Tuple

import torch
from dataclasses_json import dataclass_json
from torch import nn

from torchdrive.datasets.dataset import Dataset, Datasets
from torchdrive.tasks.bev import BEVTask, BEVTaskVan


@dataclass_json
@dataclass
class TrainConfig:
# backbone settings
Expand Down Expand Up @@ -67,7 +70,7 @@ def create_dataset(self, smoke: bool = False) -> Dataset:
dataset = NuscenesDataset(
data_dir=self.dataset_path,
version="v1.0-mini" if smoke else "v1.0-trainval",
lidar=True,
lidar=False,
num_frames=self.num_frames,
)
elif self.dataset == Datasets.DUMMY:
Expand Down Expand Up @@ -252,6 +255,41 @@ def cam_encoder() -> RegNetEncoder:
return model


class _ConfigAction(argparse.Action):
def __init__(self, dest: str, *args: object, **kwargs: object) -> None:
super().__init__(*args, dest=dest, **kwargs)
self.dest = dest

def __call__(
self,
parser: argparse.ArgumentParser,
namespace: argparse.Namespace,
value: str,
option_string: Optional[str] = None,
) -> None:
config_module = importlib.import_module(f"configs.{value}")
config = config_module.CONFIG

setattr(namespace, self.dest, config)


class _ConfigFieldAction(argparse.Action):
def __init__(self, dest: str, *args: object, **kwargs: object) -> None:
super().__init__(*args, dest=dest, **kwargs)
self.dest = dest

def __call__(
self,
parser: argparse.ArgumentParser,
namespace: argparse.Namespace,
value: str,
option_string: Optional[str] = None,
) -> None:
target, _, field = self.dest.partition(".")
config = getattr(namespace, target)
setattr(config, field, value)


def create_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser(description="train")

Expand All @@ -268,12 +306,23 @@ def create_parser() -> argparse.ArgumentParser:
parser.add_argument(
"--compile", default=False, action="store_true", help="use torch.compile"
)
parser.add_argument("--config", required=True, help="the config file name to use")
parser.add_argument(
"--smoke",
default=False,
action="store_true",
help="run with a smaller smoke test config",
)

parser.add_argument(
"--config",
required=True,
help="the config file name to use",
action=_ConfigAction,
)

for field in fields(TrainConfig):
parser.add_argument(
f"--config.{field.name}", type=field.type, action=_ConfigFieldAction
)

return parser
9 changes: 4 additions & 5 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import json
import os
import os.path
import sys
from collections import defaultdict
from typing import Callable, cast, Dict, Iterator, List, Optional, Set, Union

Expand Down Expand Up @@ -31,10 +32,7 @@
parser = create_parser()
args: argparse.Namespace = parser.parse_args()

import importlib

config_module = importlib.import_module("configs." + args.config)
config = config_module.CONFIG
config = args.config

os.makedirs(args.output, exist_ok=True)

Expand Down Expand Up @@ -65,7 +63,8 @@
max_queue=500,
flush_secs=60,
)
writer.add_text("args", json.dumps(vars(args), indent=4))
writer.add_text("argv", json.dumps(sys.argv, indent=4))
writer.add_text("train_config", config.to_json(indent=4))

import git

Expand Down

0 comments on commit d597f7c

Please sign in to comment.