-
Notifications
You must be signed in to change notification settings - Fork 0
/
5_2_train_semantic_baseline.py
125 lines (93 loc) · 3.81 KB
/
5_2_train_semantic_baseline.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
"""
This script trains the end-to-end semantic baseline for comparing with the
proposed SCCS-R system.
"""
# ------------------------------------------------------------------------------
# imports
import os
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from modules.e2e_baseline import EndToEndSemanticModel
# ------------------------------------------------------------------------------
def main(latent_dim: int, mod_type: str, channel: str):
    """Train and save the end-to-end semantic baseline for one configuration.

    Trains one shared encoder plus one decoder per semantic task on the
    memory-mapped sign dataset, then saves the encoder and each decoder
    under ``local/models/end_to_end/``.

    Args:
        latent_dim: Size of the latent (channel-input) representation.
        mod_type: Modulation scheme identifier passed to the model
            (e.g. 'BPSK').
        channel: Channel model identifier passed to the model (e.g. 'awgn').

    Raises:
        ValueError: If SCHEDULE is not a recognized learning-rate schedule,
            or if a decoder's name does not match its task when saving.
    """
    TASKS = ['shapes', 'colors', 'isSpeedLimit']
    N_DECODERS = len(TASKS)
    N_CLASSES = [4, 4, 2]          # classes per task, same order as TASKS
    N_IMAGES = 39209
    IM_SIZE = 64
    N_CHANNELS = 3
    DATA_DIR = 'local/memmap_data/'
    IMAGES_PATH = DATA_DIR + 'signs_dom_trn_data.npy'
    LABELS_PATHS = [
        DATA_DIR + f'signs_dom_trn_labels_{TASK}.npy' for TASK in TASKS
    ]
    STOCH_BIN = True
    CODE_RATE = 1.0 / 8.0
    BATCH_SIZE = 64
    # NOTE(review): the original assigned EPOCHS = 25 and then immediately
    # EPOCHS = 200; the second assignment won, so 200 was (and remains) the
    # effective value. The dead first assignment has been removed.
    EPOCHS: int = 200
    STEPS_PER_EPOCH: int = 200  # TRAIN_SET_SIZE // BATCH_SIZE + 1
    SCHEDULE: str = 'cosine'
    SCH_INIT: float = 1e-4
    SCH_WARMUP_EPOCHS: int = 40
    SCH_WARMUP_STEPS: int = STEPS_PER_EPOCH * SCH_WARMUP_EPOCHS
    SCH_TARGET: float = 1e-3
    SCH_DECAY_STEPS: int = STEPS_PER_EPOCH * (EPOCHS - SCH_WARMUP_EPOCHS)
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs('local/models/end_to_end', exist_ok=True)
    # --- load in the data (memory-mapped, read-only; nothing loaded eagerly) ---
    images = np.memmap(IMAGES_PATH, dtype=np.float32, mode='r',
                       shape=(N_IMAGES, IM_SIZE, IM_SIZE, N_CHANNELS))
    labels = [
        np.memmap(LABELS_PATH, dtype=np.float32, mode='r',
                  shape=(N_IMAGES, N_CLASS))
        for N_CLASS, LABELS_PATH in zip(N_CLASSES, LABELS_PATHS)
    ]
    # --- instantiate the model ---
    model = EndToEndSemanticModel(TASKS, latent_dim, N_CLASSES, N_DECODERS,
                                  STOCH_BIN, mod_type, CODE_RATE, channel)
    model(images[:10])  # build the model via a small forward pass
    model.summary()
    # --- compile and train the model ---
    if SCHEDULE == 'cosine':
        schedule = tf.keras.optimizers.schedules.CosineDecay(
            initial_learning_rate=SCH_INIT,
            decay_steps=SCH_DECAY_STEPS,
            warmup_target=SCH_TARGET,
            warmup_steps=SCH_WARMUP_STEPS
        )
    elif SCHEDULE is None:
        schedule = SCH_INIT  # constant learning rate
    else:
        raise ValueError('Invalid learning rate schedule.')
    optimizer = tf.keras.optimizers.Adam(learning_rate=schedule)
    # One categorical cross-entropy per output head; from_logits=False
    # presumes the decoders emit probabilities — confirm against the model.
    loss = [
        tf.keras.losses.CategoricalCrossentropy(from_logits=False)
        for _ in range(N_DECODERS)
    ]
    # Keras applies a flat metrics list to EVERY output, so the original
    # ['accuracy', 'accuracy', 'accuracy'] tracked accuracy three times per
    # head; a single entry yields exactly one accuracy per output.
    model.compile(optimizer=optimizer, loss=loss,
                  loss_weights=[1.0] * N_DECODERS,
                  metrics=['accuracy'])
    model.fit(images, labels, epochs=EPOCHS, batch_size=BATCH_SIZE)
    # --- save the models ---
    model.encoder.save(
        f'local/models/end_to_end/len{latent_dim}_{mod_type}_{channel}_encoder.keras')
    for i in range(N_DECODERS):
        # Guard: decoder order must match task order, or files would be
        # saved under the wrong task name.
        if TASKS[i] not in model.decoders[i].name:
            raise ValueError('Model name mismatch!')
        model.decoders[i].save(
            f'local/models/end_to_end/{TASKS[i]}_len{latent_dim}_{mod_type}_{channel}_decoder.keras'
        )
# ------------------------------------------------------------------------------
if __name__ == '__main__':
    # Sweep every latent size over each (modulation, channel) pairing.
    latent_sizes = [10]
    mod_channel_pairs = zip(['BPSK'], ['awgn'])
    for modulation, chan in mod_channel_pairs:
        for dim in latent_sizes:
            main(dim, modulation, chan)
# ------------------------------------------------------------------------------