-
Notifications
You must be signed in to change notification settings - Fork 1
/
train.py
149 lines (128 loc) · 4.26 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
import torch
import args
import cocostuff
import data
import loss
import preprocessing as pre
import setup
import utils
from inc.config_snake.config import ConfigFile
from model import Model
def interface(config_file_path):
    """Load the ``segmentation`` section of a config file and run ``train`` on it.

    Only one configuration combination is accepted here:
    ``training.validation_mode == "IID"`` and ``training.eval_mode == "hung"``
    (presumably Hungarian-matching evaluation — confirm).

    Args:
        config_file_path: Path passed to ``ConfigFile``.

    Raises:
        ValueError: If ``training.validation_mode`` is not ``"IID"`` or
            ``training.eval_mode`` is not ``"hung"``.
    """
    config = ConfigFile(config_file_path).segmentation

    # Explicit checks instead of ``assert``: assert statements are removed when
    # Python runs with -O, which would let an unsupported config reach train().
    validation_mode = config.training.validation_mode
    if validation_mode != "IID":
        raise ValueError(
            f"unsupported training.validation_mode {validation_mode!r}; expected 'IID'"
        )
    eval_mode = config.training.eval_mode
    if eval_mode != "hung":
        raise ValueError(
            f"unsupported training.eval_mode {eval_mode!r}; expected 'hung'"
        )

    train(config)
def train(config):
    """Build every training component described by ``config`` and run training.

    Steps, in order:
      1. ``setup.setup(config)`` provides image/heads info, output handling,
         the image and label folders, and the network.
      2. If ``config.output.force_training_restart`` is set,
         ``output_files.clear_output()`` is called.
      3. A ``Model`` is loaded from the state folder if one exists there,
         otherwise a new one is created (see ``_build_model``).
      4. The shared preprocessing pipeline is built (see ``_build_preprocessing``).
      5. Train and test dataloaders are built (see ``_build_dataloaders``) and
         passed to ``model.train``.

    Args:
        config: The ``segmentation`` section of a ``ConfigFile``. Sub-sections
            are read by attribute and some are ``**``-unpacked, so they are
            presumably mapping-like config objects — TODO confirm.
    """
    # SETUP
    components = setup.setup(config)
    image_info = components["image_info"]
    heads_info = components["heads_info"]
    output_files = components["output_files"]
    image_folder = components["image_folder"]
    label_folder = components["label_folder"]
    net = components["net"]

    # FORCE RESTART
    if config.output.force_training_restart:
        output_files.clear_output()

    model = _build_model(
        config, net=net, heads_info=heads_info, output_files=output_files
    )
    preprocessing = _build_preprocessing(
        config, image_info=image_info, heads_info=heads_info
    )
    dataloaders = _build_dataloaders(
        config,
        image_info=image_info,
        preprocessing=preprocessing,
        output_files=output_files,
        image_folder=image_folder,
        label_folder=label_folder,
    )
    model.train(loaders=dataloaders)


def _build_model(config, net, heads_info, output_files):
    """Return a ``Model`` loaded from the state folder, or a newly constructed one."""
    optimizer = torch.optim.Adam(net.parameters(), lr=config.optimizer.learning_rate)

    # The state folder is queried from output_files here (after a possible
    # clear_output) rather than taken from setup's components; presumably so it
    # reflects the post-clear state — TODO confirm.
    state_folder = output_files.get_sub_root(output_files.STATE)

    if Model.exists(state_folder):
        # Resume from saved state; only the net and optimizer are passed in.
        return Model.load(state_folder, net=net, optimizer=optimizer)

    # Fresh model: build the loss and per-epoch statistics it needs.
    iid_loss = loss.IIDLoss(
        heads=heads_info.order,
        output_files=output_files,
        do_render=config.output.rendering.enabled,
        **config.training.loss
    )
    epoch_stats = utils.EpochStatistics(
        limit=config.training.num_epochs, output_files=output_files
    )
    return Model(
        state_folder=state_folder,
        heads_info=heads_info,
        net=net,
        optimizer=optimizer,
        loss_fn=iid_loss,
        epoch_statistics=epoch_stats,
    )


def _build_label_mapper(config, heads_info):
    """Return a ``LabelMapper`` for the configured label filter, or an identity mapper."""
    # Registry of supported label filters, keyed by config.dataset.label_filter.name.
    label_filters = {"CocoFewLabels": cocostuff.CocoFewLabels}
    if (
        "label_filter" in config.dataset
        and config.dataset.label_filter.name in label_filters
    ):
        label_filter = label_filters[config.dataset.label_filter.name](
            class_count=heads_info.class_count, **config.dataset.label_filter.parameters
        )
        return pre.LabelMapper(mapping_function=label_filter.apply)

    # No label_filter section, or its name is not in the registry above.
    print("unable to find label mapper, using identity mapping")
    return pre.LabelMapper()


def _build_preprocessing(config, image_info, heads_info):
    """Build the shared ``TransformPreprocessing`` used by both train and test data."""
    transformation = pre.Transformation(**config.transformations)
    label_mapper = _build_label_mapper(config, heads_info)
    return pre.TransformPreprocessing(
        transformation=transformation,
        image_info=image_info,
        label_mapper=label_mapper,
        **config.preprocessor
    )


def _make_image_dataset(
    config,
    preprocessor_cls,
    image_info,
    preprocessing,
    output_files,
    image_folder,
    label_folder,
):
    """Instantiate ``preprocessor_cls`` and wrap it in an ``ImageFolderDataset``."""
    preprocessor = preprocessor_cls(
        image_info=image_info,
        preprocessing=preprocessing,
        output_files=output_files,
        do_render=config.output.rendering.enabled,
        render_limit=config.output.rendering.limit,
    )
    return data.ImageFolderDataset(
        image_folder=image_folder,
        preprocessor=preprocessor,
        extensions=config.dataset.extensions,
        label_folder=label_folder,
    )


def _build_dataloaders(
    config, image_info, preprocessing, output_files, image_folder, label_folder
):
    """Build the train/test dataloaders and return them keyed by name for ``model.train``."""
    # NOTE(review): train and test datasets read the same image/label folders;
    # presumably intended for the IID validation mode required by interface() — confirm.
    train_dataset = _make_image_dataset(
        config,
        pre.TrainImagePreprocessor,
        image_info=image_info,
        preprocessing=preprocessing,
        output_files=output_files,
        image_folder=image_folder,
        label_folder=label_folder,
    )
    train_dataloader = data.TrainDataLoader(
        dataset=train_dataset,
        batch_size=config.training.batch_size,
        shuffle=config.training.shuffle,
    )

    test_dataset = _make_image_dataset(
        config,
        pre.TestImagePreprocessor,
        image_info=image_info,
        preprocessing=preprocessing,
        output_files=output_files,
        image_folder=image_folder,
        label_folder=label_folder,
    )
    test_dataloader = data.TestDataLoader(
        dataset=test_dataset, batch_size=config.training.batch_size
    )

    # TODO link this to arch heads
    return {
        "A": train_dataloader,
        "B": train_dataloader,
        "map_assign": test_dataloader,
        "map_test": test_dataloader,
    }
if __name__ == "__main__":
    # Script entry point: read the config path from the command line via the
    # project's args.Arguments and hand it to interface().
    interface(config_file_path=args.Arguments().config_file_path)