This repository has been archived by the owner on Feb 24, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 0
/
data_train.py
78 lines (72 loc) · 2.76 KB
/
data_train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
from skimage.io import imread
import time
import os
import configparser
from sklearn.model_selection import train_test_split
from keras.callbacks import ModelCheckpoint
from utils.processing import *
from utils.net import Unet
from utils.generator import m_gen
from utils.save_info import save_info
from utils.visualization import Training_visualization
import json
########################################
# Set up training configuration
########################################
# All settings come from config.txt in the current working directory.
# RawConfigParser.read() skips files it cannot open and returns the list of
# files it actually parsed, so an empty list means the config is missing;
# fail here rather than with a confusing NoSectionError on the first get().
config = configparser.RawConfigParser()
if not config.read('config.txt'):
    raise FileNotFoundError(
        "Could not read 'config.txt' from the current working directory")
# os.path.join(p, '') appends a path separator only when p lacks one, so a
# file name appended to these directory strings (or by helpers that receive
# dir_path) lands inside the directory whether or not the config value ends
# with a slash. An empty value stays empty, as before.
train_path = os.path.join(config.get('data path', 'train_path'), '')
dir_path = os.path.join(config.get('data path', 'result_dir'), '')
# Kept as a raw string: it is assigned verbatim to CUDA_VISIBLE_DEVICES.
gpu_usage = config.get('model settings', 'gpu_usage')
seed = config.getint('model settings', 'seed')
epoch = config.getint('train settings', 'epoch')
steps_per_epoch = config.getint('train settings', 'step_per_epoch')
# Fraction of samples held out for validation (passed as test_size).
split_rate = config.getfloat('train settings', 'split_rate')
# File names of the raw image stack and its mask, relative to train_path.
rfile = config.get('task settings', 'raw_file')
mfile = config.get('task settings', 'mask_file')
########################################
# Select visible GPU devices
########################################
# Expose only the GPU id(s) named in the config to this process; the config
# value is written to CUDA_VISIBLE_DEVICES unchanged.
# NOTE(review): keras is imported at the top of the file, before this line.
# The variable only takes effect if the backend has not initialised CUDA yet;
# confirm for the installed Keras/TensorFlow version.
os.environ.update(CUDA_VISIBLE_DEVICES=gpu_usage)
########################################
# Read & Prepare training data
########################################
# Load the raw image stack and its mask from train_path, then preprocess each
# with the helpers star-imported from utils.processing.
raw = prep_raw(imread(os.path.join(train_path, rfile)))
mask = train_mask(imread(os.path.join(train_path, mfile)))
# Split raw and mask in ONE call so the same shuffled indices are applied to
# both, keeping each raw sample paired with its own mask. train_test_split
# also raises ValueError if the two have different numbers of samples,
# instead of silently producing misaligned pairs.
rtrain, rval, mtrain, mval = train_test_split(
    raw, mask, test_size=split_rate, random_state=seed)
########################################
# Training
########################################
net = Unet()
# Save weights to <dir_path>/model.hdf5 only when val_loss reaches a new
# minimum, so the file always holds the best model seen so far.
checkpoint = ModelCheckpoint(os.path.join(dir_path, 'model.hdf5'),
                             monitor='val_loss', verbose=1, mode='min',
                             save_best_only=True)
print('\n...Training...')
# perf_counter() is a monotonic clock intended for measuring intervals;
# time.time() can jump if the system clock is adjusted during a long run.
start_time = time.perf_counter()
# NOTE(review): m_gen presumably yields (raw, mask) training batches without
# end, since the epoch length is fixed by steps_per_epoch — confirm.
# NOTE(review): fit_generator is deprecated in newer Keras releases in favour
# of fit(); verify against the installed Keras version before upgrading.
history = net.fit_generator(m_gen(rtrain, mtrain),
                            validation_data=(rval, mval),
                            shuffle=True,
                            callbacks=[checkpoint],
                            epochs=epoch,
                            steps_per_epoch=steps_per_epoch,
                            verbose=1)
end_time = time.perf_counter()
# Elapsed training time in seconds; also recorded in the saved info below.
sum_time = end_time - start_time
print('last %.2f seconds' % sum_time)
print('\nTraining Finished!')
########################################
# Plot training curves
########################################
# Hand the fitted History object, the epoch count and the results directory
# to the project plotting helper; it presumably writes its figure(s) into
# dir_path (its internals live in utils.visualization and are not shown here).
Training_visualization(history, epoch, dir_path)
########################################
# Persist a summary of this training run
########################################
# Final-epoch metrics plus run metadata. The lookup keys must match what
# Keras recorded in history.history; presumably Unet() compiles the model
# with a 'binary_accuracy' metric — confirm if the metrics change.
hist_log = history.history
info = {
    # metrics from the last epoch (training, then validation)
    'acc': hist_log['binary_accuracy'][-1],
    'loss': hist_log['loss'][-1],
    'val_acc': hist_log['val_binary_accuracy'][-1],
    'val_loss': hist_log['val_loss'][-1],
    # run metadata
    'duration': sum_time,  # seconds spent in the training call
    'epoch': epoch,
}
save_info(info, dir_path)