-
Notifications
You must be signed in to change notification settings - Fork 2
/
config_file.py
executable file
·68 lines (51 loc) · 2.81 KB
/
config_file.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
# Copyright 2019 Gabriele Valvano
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tensorflow as tf
# Experiment identifier; also the default value of the RUN_ID flag.
RUN_ID: str = 'SDTNet'
# Default value of the CUDA_VISIBLE_DEVICE integer flag ("visible gpu").
CUDA_VISIBLE_DEVICE: int = 0
# Default location of the ACDC data set, used for the acdc_data_path flag.
# This is a relative path, so it is resolved against the current working directory.
data_path: str = './data/acdc_data'
def define_flags():
    """Register the experiment's command-line flags and return the flag registry.

    Every flag is defined on the global ``tf.flags.FLAGS`` registry. The same
    registry object is returned, so callers can read parsed values from it.
    The defaults for ``RUN_ID``, ``CUDA_VISIBLE_DEVICE`` and ``acdc_data_path``
    come from the module-level constants ``RUN_ID``, ``CUDA_VISIBLE_DEVICE``
    and ``data_path``.

    Returns:
        The global ``tf.flags.FLAGS`` object with all flags below defined.
    """
    flags = tf.flags

    # Each entry is (definer, flag name, default value, help text).
    # Flags are registered in exactly this order.
    flag_specs = [
        (flags.DEFINE_string, 'RUN_ID', RUN_ID, ""),

        # ========== ARCHITECTURE HYPER-PARAMETERS ==========
        (flags.DEFINE_float, 'lr', 1e-4, 'learning rate'),
        (flags.DEFINE_integer, 'b_size', 7, "batch size"),
        (flags.DEFINE_integer, 'n_anatomical_masks', 8,
         "number of extracted anatomical masks"),
        (flags.DEFINE_integer, 'n_frame_composing_masks', 8,
         "number composing masks for next frame mask prediction"),
        (flags.DEFINE_integer, 'nz_latent', 8,
         "number latent variable for z code (encoder modality)"),
        (flags.DEFINE_integer, 'CUDA_VISIBLE_DEVICE', CUDA_VISIBLE_DEVICE,
         "visible gpu"),

        # =============== TRAINING STRATEGY ==================
        (flags.DEFINE_bool, 'augment', True, "Perform data augmentation"),
        # Off by default: the data is already pre-processed.
        (flags.DEFINE_bool, 'standardize', False,
         "Perform data standardization (z-score)"),
        # (others, such as learning rate decay params...)

        # =============== INTERNAL VARIABLES =================
        (flags.DEFINE_integer, 'num_threads', 20,
         "number of threads for loading data"),
        (flags.DEFINE_integer, 'skip_step', 4000,
         "frequency of printing batch report"),
        (flags.DEFINE_integer, 'train_summaries_skip', 10,
         "number of skips before writing summaries for training steps "
         "(used to reduce its verbosity; put 1 to avoid this)"),
        (flags.DEFINE_bool, 'tensorboard_verbose', True,
         "if True: save also layers weights every N epochs"),

        # ===================== DATA SET ======================
        # ACDC data set:
        (flags.DEFINE_string, 'acdc_data_path', data_path,
         "Path of data files."),
        # Data specs.
        # NOTE(review): a DEFINE_list value overridden on the command line is
        # likely parsed into a list of strings, while the default holds ints.
        # Confirm that consumers of input_size cast the elements to int.
        (flags.DEFINE_list, 'input_size', [128, 128], "input size"),
        (flags.DEFINE_integer, 'n_classes', 4, "number of classes"),
    ]

    for define, flag_name, default_value, help_text in flag_specs:
        define(flag_name, default_value, help_text)

    return flags.FLAGS