-
Notifications
You must be signed in to change notification settings - Fork 25
/
train.py
109 lines (80 loc) · 4.36 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import argparse
import importlib
import os
import torch
import torch.nn as nn
import torchvision
import numpy as np
from dataset.dataset_splitter import DatasetSplitter
from dataset.transforms import TransformsGenerator
from dataset.video_dataset import VideoDataset
from evaluation.action_sampler import OneHotActionSampler, GroundTruthActionSampler
from evaluation.evaluator import Evaluator
from training.trainer import Trainer
from utils.configuration import Configuration
from utils.logger import Logger
torch.backends.cudnn.benchmark = True
if __name__ == "__main__":
# Loads configuration file
parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str, required=True)
arguments = parser.parse_args()
config_path = arguments.config
configuration = Configuration(config_path)
configuration.check_config()
configuration.create_directory_structure()
config = configuration.get_config()
logger = Logger(config)
search_name = config["model"]["architecture"]
model = getattr(importlib.import_module(search_name), 'model')(config)
model.cuda()
datasets = {}
dataset_splits = DatasetSplitter.generate_splits(config)
transformations = TransformsGenerator.get_final_transforms(config)
for key in dataset_splits:
path, batching_config, split = dataset_splits[key]
transform = transformations[key]
datasets[key] = VideoDataset(path, batching_config, transform, split)
# Creates trainer and evaluator
trainer = getattr(importlib.import_module(config["training"]["trainer"]), 'trainer')(config, model, datasets["train"], logger)
# Evaluators will be assigned their specific action samplers to implement the evaluation strategy
evaluator_inferred_actions = getattr(importlib.import_module(config["evaluation"]["evaluator"]), 'evaluator')(config, datasets["validation"], logger, action_sampler=None, logger_prefix="validation_inferred_actions")
evaluator_inferred_actions_onehot = getattr(importlib.import_module(config["evaluation"]["evaluator"]), 'evaluator')(config, datasets["validation"], logger, action_sampler=OneHotActionSampler(), logger_prefix="validation_inferred_actions_onehot")
evaluator_ground_truth_actions = getattr(importlib.import_module(config["evaluation"]["evaluator"]), 'evaluator')(config, datasets["validation"], logger, action_sampler=None, logger_prefix="validation_gt_actions")
# Resume training
try:
trainer.load_checkpoint(model)
except Exception as e:
logger.print(e)
logger.print("- Warning: training without loading saved checkpoint")
model = nn.DataParallel(model)
model.cuda()
logger.get_wandb().watch(model, log='all')
last_save_step = 0
last_eval_step = 0
# Makes the model parallel and train
while trainer.global_step < config["training"]["max_steps"]:
model.train()
trainer.train_epoch(model)
# Saves the model
trainer.save_checkpoint(model)
if trainer.global_step > last_save_step + config["training"]["save_freq"]:
trainer.save_checkpoint(model, f"checkpoint_{trainer.global_step}")
last_save_step = trainer.global_step
model.eval()
# Evaluates the model
if trainer.global_step > last_eval_step + config["evaluation"]["eval_freq"]:
# Evaluates with actions predicted from the model
evaluator_inferred_actions.evaluate(model, trainer.global_step)
# Evaluates with actions predicted from the model in one hot version
# Disabled to improve evaluation time
#evaluator_inferred_actions_onehot.evaluate(model, trainer.global_step)
if config["data"]["ground_truth_available"]:
# Evaluates with ground truth actions translated to the model action space
# Uses the mapping between inferred and ground truth actions to configure the
# ground truth action space -> model action space translation function
action_mapping = evaluator_inferred_actions.get_best_action_mappings()
ground_truth_action_sampler = GroundTruthActionSampler(action_mapping)
evaluator_ground_truth_actions.set_action_sampler(ground_truth_action_sampler)
evaluator_ground_truth_actions.evaluate(model, trainer.global_step)
last_eval_step = trainer.global_step