train.py

import argparse
import datetime
import os
import time
from typing import NamedTuple

from colorama import Fore, init
from keras.callbacks import CSVLogger, ModelCheckpoint, TensorBoard

import env
from common.decode import create_decoder
from common.files import is_dir, make_dir_if_not_exists
from core.callbacks.error_rates import ErrorRates
from core.generators.dataset_generator import DatasetGenerator
from core.model.lipnet import LipNet
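
# Silence TensorFlow's C++ logging ('3' hides INFO, WARNING and ERROR messages)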
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
init(autoreset=True)
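
# Checkpoint, log and dictionary locations, resolved relative to this script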
ROOT_PATH = os.path.dirname(os.path.realpath(__file__))
OUTPUT_DIR = os.path.realpath(os.path.join(ROOT_PATH, 'data', 'res'))
LOG_DIR = os.path.realpath(os.path.join(ROOT_PATH, 'data', 'res_logs'))
DICTIONARY_PATH = os.path.realpath(os.path.join(ROOT_PATH, 'data', 'dictionaries', 'grid.txt'))


class TrainingConfig(NamedTuple):
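    """Hyper-parameters for a single training run; defaults come from env.py."""
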
    dataset_path: str
    aligns_path: str
    epochs: int = 1
    frame_count: int = env.FRAME_COUNT
    image_width: int = env.IMAGE_WIDTH
    image_height: int = env.IMAGE_HEIGHT
    image_channels: int = env.IMAGE_CHANNELS
    max_string: int = env.MAX_STRING
    batch_size: int = env.BATCH_SIZE
    val_split: float = env.VAL_SPLIT
    use_cache: bool = True


def main():
"""
Entry point of the script for training a model.
i.e: python train.py -d data/dataset -a data/aligns -e 150
"""

    print(r'''
     __         __     ______   __   __    ______     ______
    /\ \       /\ \   /\  == \ /\ "-.\ \  /\  ___\   /\__  _\
    \ \ \____  \ \ \  \ \  _-/ \ \ \-.  \ \ \  __\   \/_/\ \/
     \ \_____\  \ \_\  \ \_\    \ \_\\"\_\ \ \_____\    \ \_\
      \/_____/   \/_/   \/_/     \/_/ \/_/  \/_____/     \/_/

                     implemented by Omar Salinas
    ''')
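
    # Collect and normalise the command-line arguments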
    ap = argparse.ArgumentParser()

    ap.add_argument('-d', '--dataset-path', required=True, help='Path to the dataset root directory')
    ap.add_argument('-a', '--aligns-path', required=True, help='Path to the directory containing all align files')
    ap.add_argument('-e', '--epochs', required=False, help='(Optional) Number of epochs to run', type=int, default=1)
    ap.add_argument('-ic', '--ignore-cache', required=False, help='(Optional) Force the generator to ignore the cache file', action='store_true', default=False)

    args = vars(ap.parse_args())

    dataset_path = os.path.realpath(args['dataset_path'])
    aligns_path = os.path.realpath(args['aligns_path'])
    epochs = args['epochs']
    ignore_cache = args['ignore_cache']
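
    # Fail fast on invalid paths or an invalid epoch count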
    if not is_dir(dataset_path):
        print(Fore.RED + '\nERROR: The dataset path is not a directory')
        return

    if not is_dir(aligns_path):
        print(Fore.RED + '\nERROR: The aligns path is not a directory')
        return

    if not isinstance(epochs, int) or epochs <= 0:
        print(Fore.RED + '\nERROR: The number of epochs must be a valid integer greater than zero')
        return
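
    # Name the run with a timestamp so each run gets its own logs and checkpoints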
    name = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M')
    config = TrainingConfig(dataset_path, aligns_path, epochs=epochs, use_cache=not ignore_cache)

    train(name, config)


def train(run_name: str, config: TrainingConfig):
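    """Build the LipNet model and its data generators, then run the training loop."""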
print("\nTRAINING: {}\n".format(run_name))
print('For dataset at: {}'.format(config.dataset_path))
print('With aligns at: {}'.format(config.aligns_path))
make_dir_if_not_exists(OUTPUT_DIR)
make_dir_if_not_exists(LOG_DIR)
lipnet = LipNet(config.frame_count, config.image_channels, config.image_height, config.image_width, config.max_string).compile_model()
datagen = DatasetGenerator(config.dataset_path, config.aligns_path, config.batch_size, config.max_string, config.val_split, config.use_cache)
callbacks = create_callbacks(run_name, lipnet, datagen)

    print('\nStarting training...\n')

    start_time = time.time()
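
    # Two worker processes keep a queue of up to five batches fed while the model trains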
    lipnet.model.fit_generator(
        generator=datagen.train_generator,
        validation_data=datagen.val_generator,
        epochs=config.epochs,
        verbose=1,
        shuffle=True,
        max_queue_size=5,
        workers=2,
        callbacks=callbacks,
        use_multiprocessing=True
    )

    elapsed_time = time.time() - start_time
    print('\nTraining completed in: {}'.format(datetime.timedelta(seconds=elapsed_time)))


def create_callbacks(run_name: str, lipnet: LipNet, datagen: DatasetGenerator) -> list:
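    """Assemble the checkpoint, CSV, error-rate and TensorBoard callbacks for a run."""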
    run_log_dir = os.path.join(LOG_DIR, run_name)
    make_dir_if_not_exists(run_log_dir)

    # Tensorboard
    tensorboard = TensorBoard(log_dir=run_log_dir)

    # Training logger
    csv_log = os.path.join(run_log_dir, 'training.csv')
    csv_logger = CSVLogger(csv_log, separator=',', append=True)

    # Model checkpoint saver
    checkpoint_dir = os.path.join(OUTPUT_DIR, run_name)
    make_dir_if_not_exists(checkpoint_dir)

    checkpoint_template = os.path.join(checkpoint_dir, 'lipnet_{epoch:03d}_{val_loss:.2f}.hdf5')
    checkpoint = ModelCheckpoint(checkpoint_template, monitor='val_loss', save_weights_only=True, mode='auto', period=1, verbose=1)

    # WER/CER Error rate calculator
    error_rate_log = os.path.join(run_log_dir, 'error_rates.csv')
    decoder = create_decoder(DICTIONARY_PATH, False)
    error_rates = ErrorRates(error_rate_log, lipnet, datagen.val_generator, decoder)

    return [checkpoint, csv_logger, error_rates, tensorboard]


if __name__ == '__main__':
    main()