# ################################
# Model: Space-based GW detection and extraction (se-mlp.yaml)
# ################################
#
# Basic parameters
# Seed needs to be set at the top of the yaml, before any objects with parameters are instantiated
#
seed: 1607
__set_seed: !apply:torch.manual_seed [!ref <seed>]
# CUDA device number
cuda: 1
# Data params
data_folder:
data_hdf5:
noise_hdf5:
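# NOTE: the three paths above are deliberately left blank and must be
# supplied before training (typically as command-line overrides; see the
# usage sketch at the end of this file)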
experiment_name: gw
#----------------------------------------
output_folder: !ref results/<experiment_name>/<seed>
train_log: !ref <output_folder>/train_log.txt
save_folder: !ref <output_folder>/save
# Experiment params
auto_mix_prec: False
test_only: False
num_spks: 1
progressbar: True
save_inf_data: False
save_attention_weights: False
# total loss = SE (extraction) loss * alpha + classification loss * (1 - alpha)
alpha: 1
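# with alpha = 1, the classification term is switched off and training is
# driven by the signal-extraction (MSE) loss alone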
inf_data: !ref <save_folder>/inf_test/
# att_data: !ref <save_folder>/inf_test/
# Training parameters
N_epochs: 100
batch_size: 16
lr: 0.0005
clip_grad_norm: 5
loss_upper_lim: 999999 # this is the upper limit for an acceptable loss
# if True, training sequences are cut to the length specified below
limit_training_signal_len: False
# length (in samples) to which training sequences are cut when limiting is
# enabled; it also fixes the input shapes of the Encoder and Decoder below
training_signal_len: 4000
# loss thresholding -- controls which per-batch loss values contribute to training
threshold_byloss: True
threshold: -50
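# (in standard SpeechBrain separation recipes, when threshold_byloss is True
# only loss values above <threshold> are kept; the exact behaviour here
# depends on the accompanying training script)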
# Encoder parameters
N_encoder_out: 256
out_channels: 256
kernel_size: 16
kernel_stride: 8
# Specifying the network
Encoder: !new:speechbrain.nnet.linear.Linear
input_shape: [!ref <batch_size>, !ref <training_signal_len>, 1]
n_neurons: !ref <N_encoder_out>
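# note: unlike the usual SepFormer conv encoder, this Encoder is a plain
# linear layer that lifts each 1-dim time sample to an <N_encoder_out>-dim
# feature; kernel_size/kernel_stride above appear unused in this variant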
SBtfintra: !new:speechbrain.lobes.models.dual_path.SBTransformerBlock
num_layers: 2
d_model: !ref <out_channels>
nhead: 4
d_ffn: 256
dropout: 0
use_positional_encoding: True
norm_before: True
SBtfinter: !new:speechbrain.lobes.models.dual_path.SBTransformerBlock
num_layers: 2
d_model: !ref <out_channels>
nhead: 4
d_ffn: 256
dropout: 0
use_positional_encoding: True
norm_before: True
MaskNet: !new:speechbrain.lobes.models.dual_path.Dual_Path_Model
num_spks: !ref <num_spks>
in_channels: !ref <N_encoder_out>
out_channels: !ref <out_channels>
num_layers: 2
K: 25
intra_model: !ref <SBtfintra>
inter_model: !ref <SBtfinter>
norm: ln
linear_layer_after_inter_intra: False
skip_around_intra: True
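# Dual_Path_Model chops the encoded sequence into chunks of K=25 frames;
# SBtfintra models dependencies within each chunk, SBtfinter models them
# across chunks, and the network outputs <num_spks> masks (one, here)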
Decoder: !new:speechbrain.nnet.linear.Linear
input_shape: [!ref <batch_size>, !ref <training_signal_len>, !ref <N_encoder_out>]
n_neurons: 1
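# the Decoder mirrors the Encoder: a linear layer projecting the masked
# <N_encoder_out>-dim features back to a single-channel time series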
linear_1: !new:speechbrain.nnet.linear.Linear
input_size: !ref <training_signal_len>
n_neurons: 512
relu: !new:torch.nn.ReLU
linear_2: !new:speechbrain.nnet.linear.Linear
input_size: 512
n_neurons: 1
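# linear_1 -> relu -> linear_2 form the MLP head of the recipe name
# ("se-mlp"), presumably producing the detection score trained with the
# BCE loss (loss2) below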
optimizer: !name:torch.optim.Adam
lr: !ref <lr>
weight_decay: 0
loss: !name:speechbrain.nnet.losses.mse_loss
# loss: !name:speechbrain.nnet.losses.get_si_snr_with_pitwrapper
loss2: !name:speechbrain.nnet.losses.bce_loss
lr_scheduler: !new:speechbrain.nnet.schedulers.ReduceLROnPlateau
factor: 0.5
patience: 2
dont_halve_until_epoch: 35
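# halves the learning rate (factor 0.5) once the validation metric has not
# improved for <patience> epochs, but never before epoch 35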
epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
limit: !ref <N_epochs>
modules:
encoder: !ref <Encoder>
decoder: !ref <Decoder>
masknet: !ref <MaskNet>
linear_1: !ref <linear_1>
linear_2: !ref <linear_2>
checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
checkpoints_dir: !ref <save_folder>
recoverables:
encoder: !ref <Encoder>
decoder: !ref <Decoder>
masknet: !ref <MaskNet>
linear_1: !ref <linear_1>
linear_2: !ref <linear_2>
counter: !ref <epoch_counter>
lr_scheduler: !ref <lr_scheduler>
# mlp: !ref <MLP>
train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
save_file: !ref <train_log>
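# ------------------------------------------------------------------
# Usage (a sketch, assuming this recipe follows the standard SpeechBrain
# pattern with an accompanying train.py; the script name and the override
# values below are assumptions, not part of this file):
#
#   python train.py se-mlp.yaml \
#       --data_folder=/path/to/data \
#       --data_hdf5=/path/to/signals.hdf5 \
#       --noise_hdf5=/path/to/noise.hdf5
#
# Blank hyperparameters such as data_folder are filled in at load time via
# hyperpyyaml command-line overrides.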