-
Notifications
You must be signed in to change notification settings - Fork 3
/
main.py
60 lines (42 loc) · 1.73 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import time
import torch
import config
import data_model
import utils
from utils import save_results
# Training entry point: runs training and validation epochs until the user
# interrupts with Ctrl+C, then evaluates on the test data and stores the
# final statistics and results.

# The test data is evaluated during training whenever the epoch counter is a
# multiple of this value.
TEST_EVAL_INTERVAL = 30

# Epoch number -> (learning rate written to the optimizer's first parameter
# group, text used in the console message for that change).
LR_SCHEDULE = {
    30: (0.00001, '1e-5'),
    200: (0.000001, '1e-6'),
}

if not utils.continue_training(enabled=False):
    print("Starting new training routine...")
    config.CURRENT_PATH = utils.create_current_folder()
    # NOTE(review): assumed to belong to the "new training" branch; confirm
    # whether the environment setup must also run when training is resumed.
    utils.setup_env()

# Optional dataset previews, currently disabled.
# utils.save_visualization(data_model.train_visualize, "Train", enabled=True)
# utils.save_visualization(data_model.valid_visualize, "Valid", enabled=True)
# utils.save_visualization(data_model.test_visualize, "Test", enabled=True)

max_score = 0  # best validation IoU score seen so far
start_time = time.time()

try:
    # No stop condition on purpose: the loop ends only through Ctrl+C.
    while True:
        # Hand unused cached GPU memory back before starting the next epoch.
        torch.cuda.empty_cache()

        print(f'\nEpoch: {config.EPOCH_COUNT}')
        train_logs = data_model.train_epoch.run(data_model.train_loader)
        valid_logs = data_model.valid_epoch.run(data_model.valid_loader)

        # 0 is the placeholder given to the plot on epochs without a test
        # evaluation. It is assigned first so that the periodic evaluation
        # result is the value that actually reaches update_plot.
        test_log = 0
        if config.EPOCH_COUNT % TEST_EVAL_INTERVAL == 0:
            test_log = data_model.evaluate_test_data()

        learning_score = valid_logs['iou_score']
        utils.update_plot(train_logs, valid_logs, test_log, enabled=True)

        # Save the model only when the validation score improves.
        if max_score < learning_score:
            max_score = learning_score
            utils.save_model()

        # Step the learning rate down at the configured epochs.
        if config.EPOCH_COUNT in LR_SCHEDULE:
            new_lr, lr_label = LR_SCHEDULE[config.EPOCH_COUNT]
            config.optimizer.param_groups[0]['lr'] = new_lr
            print(f'Decrease decoder learning rate to {lr_label}!')

        config.EPOCH_COUNT += 1
except KeyboardInterrupt:
    print("Learning interrupted by user.")

# Final bookkeeping after training was stopped.
config.ELAPSED_TIME = time.time() - start_time
test_log = data_model.evaluate_test_data()
config.save_stats(test_log)
save_results(data_model.test_visualize, data_model.test_dataset, count=-1)