train.py
import os
import os.path as osp
from time import time
import torch
import pprint
import random
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from configs.config import parse_args
from lib.core.trainer import Trainer
from lib.core.loss import WHAMLoss
from lib.utils.utils import prepare_output_dir
from lib.data.dataloader import setup_dloaders
from lib.utils.utils import create_logger, get_optimizer
from lib.models import build_network, build_body_model


def setup_seed(seed):
    """ Setup seed for reproducibility """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
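

# Main training entry point: build the data loaders, model, optimizer, and loss, then run the Trainer.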
def main(cfg):
    # Seed
    if cfg.SEED_VALUE >= 0:
        setup_seed(cfg.SEED_VALUE)
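
    # ========= Logger and TensorBoard writer ========= #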
    logger = create_logger(cfg.LOGDIR, phase='debug' if cfg.DEBUG else 'train')
    logger.info(f'GPU name -> {torch.cuda.get_device_name()}')
    logger.info(f'GPU feat -> {torch.cuda.get_device_properties("cuda")}')
    logger.info(pprint.pformat(cfg))

    writer = SummaryWriter(log_dir=cfg.LOGDIR)
    writer.add_text('config', pprint.pformat(cfg), 0)

    # ========= Dataloaders ========= #
    data_loaders = setup_dloaders(cfg, cfg.TRAIN.DATASET_EVAL, 'val')
    logger.info('Dataset loaded')

    # ========= Network and Optimizer ========= #
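    # The SMPL body model is built for every frame in a training batch, i.e. batch size x sequence length.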
    smpl_batch_size = cfg.TRAIN.BATCH_SIZE * cfg.DATASET.SEQLEN
    smpl = build_body_model(cfg.DEVICE, smpl_batch_size)
    network = build_network(cfg, smpl)

    optimizer = get_optimizer(
        cfg,
        model=network,
        optim_type=cfg.TRAIN.OPTIM,
        momentum=cfg.TRAIN.MOMENTUM,
        stage=cfg.TRAIN.STAGE)
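
    # Decay the learning rate by LR_DECAY_RATIO at each configured milestone epoch.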
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer,
        milestones=cfg.TRAIN.MILESTONES,
        gamma=cfg.TRAIN.LR_DECAY_RATIO,
        verbose=False,
    )

    # ========= Loss function ========= #
    criterion = WHAMLoss(cfg, cfg.DEVICE)

    # ========= Start Training ========= #
    Trainer(
        data_loaders=data_loaders,
        network=network,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
        criterion=criterion,
        train_stage=cfg.TRAIN.STAGE,
        start_epoch=cfg.TRAIN.START_EPOCH,
        end_epoch=cfg.TRAIN.END_EPOCH,
        checkpoint=cfg.TRAIN.CHECKPOINT,
        device=cfg.DEVICE,
        writer=writer,
        debug=cfg.DEBUG,
        resume=cfg.RESUME,
        logdir=cfg.LOGDIR,
        summary_iter=cfg.SUMMARY_ITER,
    ).fit()


if __name__ == '__main__':
    cfg, cfg_file, _ = parse_args()
    cfg = prepare_output_dir(cfg, cfg_file)
    main(cfg)