forked from MasterprojectRK/Hi-cGAN
-
Notifications
You must be signed in to change notification settings - Fork 0
/
training.py
282 lines (268 loc) · 14.7 KB
/
training.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
import csv
import os
import click
import numpy as np
import tensorflow as tf
import hicGAN
import dataContainer
import records
def _parse_chrom_names(chromstring):
    """Parse a chromosome list given on the command line.

    Commas and runs of whitespace both act as separators, a literal leading
    "chr" is removed from every name, empty names are dropped and duplicates
    are removed. The result is sorted as strings ("10" sorts before "2").

    Example: "chr1, chr2 2,10" -> ["1", "10", "2"]
    """
    names = chromstring.replace(",", " ").split()
    #remove the literal prefix only; str.lstrip("chr") would instead strip
    #any run of leading 'c', 'h' and 'r' characters
    names = [name[len("chr"):] if name.startswith("chr") else name for name in names]
    return sorted({name for name in names if name})


def _check_pairing(matrices, chromPaths, label):
    """Raise SystemExit unless there is exactly one chromatin path per matrix.

    label ("train" or "validation") is only used in the error message.
    """
    if len(matrices) != len(chromPaths):
        msg = "Number of " + label + " matrices and chromatin paths must match\n"
        msg += "Current numbers: Matrices: {:d}; Chromatin Paths: {:d}"
        msg = msg.format(len(matrices), len(chromPaths))
        raise SystemExit(msg)


def _build_containers(chromNameList, matrices, chromPaths):
    """Create one DataContainer per (chromosome, matrix/chromatin-path pair).

    Chromosomes form the outer loop, matrix/path pairs the inner loop.
    No data is loaded here.
    """
    return [dataContainer.DataContainer(chromosome=chrom,
                                        matrixfilepath=matrix,
                                        chromatinFolder=chromatinpath)
            for chrom in chromNameList
            for matrix, chromatinpath in zip(matrices, chromPaths)]


def _remove_files(filenames):
    """Delete every file in filenames that (still) exists."""
    for filename in filenames:
        if os.path.exists(filename):
            os.remove(filename)


@click.command()
@click.option("--trainMatrices", "-tm", required=True,
              type=click.Path(exists=True, dir_okay=False, readable=True), multiple=True,
              help="Cooler matrices for training. Use this option multiple times to specify more than one matrix. First matrix belongs to first trainChromPath")
@click.option("--trainChroms", "-tchroms", required=True,
              type=str,
              help="Train chromosomes. Must be present in all train matrices. Specify multiple chroms separated by spaces, e.g. '10 11 12'.")
@click.option("--trainChromPaths", "-tcp", required=True,
              type=click.Path(exists=True, file_okay=False, readable=True), multiple=True,
              help="Path where chromatin factors for training reside (bigwig files). Use this option multiple times to specify more than one path. First path belongs to first train matrix")
@click.option("--valMatrices", "-vm", required=True,
              type=click.Path(exists=True, dir_okay=False, readable=True), multiple=True,
              help="Cooler matrices for validation. Use this option multiple times to specify more than one matrix")
@click.option("--valChroms", "-vchroms", required=True,
              type=str,
              help="Validation chromosomes. Must be present in all validation matrices. Specify multiple chroms separated by spaces, e.g. '1 2 3'.")
@click.option("--valChromPaths", "-vcp", required=True,
              type=click.Path(exists=True, file_okay=False, readable=True), multiple=True,
              help="Path where chromatin factors for validation reside (bigwig files). Use this option multiple times to specify more than one path. First path belongs to first validation matrix etc.")
@click.option("--windowsize", "-ws", required=True,
              type=click.Choice(["64", "128", "256"]),
              default="64", show_default=True,
              help="Windowsize for submatrices. 64, 128 and 256 are supported")
@click.option("--outfolder", "-o", required=True,
              type=click.Path(exists=True, writable=True, file_okay=False),
              help="Folder where trained model and diverse outputs will be stored")
@click.option("--epochs", "-ep", required=True,
              type=click.IntRange(min=1),
              default=2, show_default=True)
@click.option("--batchsize", "-bs", required=True,
              type=click.IntRange(min=1, max=256),
              default=32, show_default=True,
              help="Batch size for training, choose integer in [1, 256]")
@click.option("--lossWeightPixel", "-lwp", required=False,
              type=click.FloatRange(min=1e-10),
              default=100.0, show_default=True,
              help="loss weight for L1/L2 error of generator")
@click.option("--lossWeightDisc", "-lwd", required=False,
              type=click.FloatRange(min=1e-10),
              default=0.5, show_default=True,
              help="loss weight (multiplicator) for the discriminator loss")
@click.option("--lossTypePixel", "-ltp", required=False,
              type=click.Choice(["L1", "L2"]),
              default="L1", show_default=True,
              help="Type of per-pixel loss to use for the generator; choose from L1 (mean abs. error) or L2 (mean squared error)")
@click.option("--lossWeightTv", "-lwt", required=False,
              type=click.FloatRange(min=0.0),
              default=1e-10, show_default=True,
              help="loss weight for Total-Variation-loss of generator; higher value - more smoothing")
@click.option("--lossWeightAdv", "-lwa", required=False,
              type=click.FloatRange(min=1e-10),
              default=1.0, show_default=True,
              help="loss weight for adversarial loss in generator")
@click.option("--learningRateGen", "-lrg", required=False,
              type=click.FloatRange(min=1e-10, max=1.0),
              default=2e-5, show_default=True,
              help="learning rate for Adam optimizer of generator")
@click.option("--learningRateDisc", "-lrd", required=False,
              type=click.FloatRange(min=1e-10, max=1.0),
              default=1e-6, show_default=True,
              help="learning rate for Adam optimizer of discriminator")
@click.option("--beta1", "-b1", required=False,
              type=click.FloatRange(min=1e-2, max=1.0),
              default=0.5, show_default=True,
              help="beta1 parameter for Adam optimizers (gen. and disc.)")
@click.option("--flipsamples", "-fs", required=False,
              type=bool, default=False, show_default=True,
              help="Flip training matrices and chromatin features (data augmentation)")
@click.option("--embeddingType", "-emb", required=False,
              type=click.Choice(["CNN", "DNN", "mixed"]),
              default="CNN", show_default=True,
              help="Type of embedding to use for generator and discriminator. CNN, DNN, or mixed (Gen: CNN, Disc: DNN)")
@click.option("--pretrainedIntroModel", "-ptm", required=False,
              type=click.Path(exists=True, dir_okay=False, readable=True),
              help="pretrained model for 1D-2D conversion of inputs")
@click.option("--figuretype", "-ft", required=False,
              type=click.Choice(["png", "pdf", "svg"]),
              default="png", show_default=True,
              help="Figure type for all plots")
@click.option("--recordsize", "-rs", required=False,
              type=click.IntRange(min=10),
              default=2000, show_default=True,
              help="Approx. size (number of samples) of the tfRecords used in the data pipeline for training. Lower values = less memory consumption, but maybe longer runtime")
@click.option("--plotFrequency", "-pfreq", required=False,
              type=click.IntRange(min=1),
              default=10, show_default=True,
              help="Update loss over epoch plots after this number of epochs")
def training(trainmatrices,
             trainchroms,
             trainchrompaths,
             valmatrices,
             valchroms,
             valchrompaths,
             windowsize,
             outfolder,
             epochs,
             batchsize,
             lossweightpixel,
             lossweightdisc,
             lossweightadv,
             losstypepixel,
             lossweighttv,
             learningrategen,
             learningratedisc,
             beta1,
             flipsamples,
             embeddingtype,
             pretrainedintromodel,
             figuretype,
             recordsize,
             plotfrequency):
    """Train a Hi-cGAN model from Hi-C matrices and chromatin features.

    Writes trainParams.csv, model plots and the outputs of the HiCGAN model
    into the output folder. Temporary TFRecord files are also written there
    and are deleted again when training ends, even if it fails.
    """
    #few constants
    windowsize = int(windowsize)
    #if not None, plot features and save the matrix of every container
    #(at sample index debugstate if it is an int, else with idx=None)
    debugstate = None
    #snapshot of the CLI parameters (plus the two locals above) for trainParams.csv;
    #must stay before any other local variable is created
    paramDict = locals().copy()

    #parse the chromosome lists (commas/whitespace separated, "chr" prefix removed,
    #each name only once); the same chrom may be used for train and validation
    trainChromNameList = _parse_chrom_names(trainchroms)
    paramDict["trainChromNameList"] = trainChromNameList
    valChromNameList = _parse_chrom_names(valchroms)
    paramDict["valChromNameList"] = valChromNameList

    #ensure there are as many matrices as chromatin paths
    _check_pairing(trainmatrices, trainchrompaths, "train")
    _check_pairing(valmatrices, valchrompaths, "validation")

    #prepare the data containers. No data is loaded yet.
    traindataContainerList = _build_containers(trainChromNameList, trainmatrices, trainchrompaths)
    valdataContainerList = _build_containers(valChromNameList, valmatrices, valchrompaths)

    #define the load params for the containers
    loadParams = {"scaleFeatures": True,
                  "clampFeatures": False,
                  "scaleTargets": True,
                  "windowsize": windowsize,
                  "flankingsize": windowsize,
                  "maxdist": None}

    if not traindataContainerList:
        msg = "Exiting. No data found"
        print(msg)
        return #nothing to do

    container0 = traindataContainerList[0]
    nr_trainContainers = len(traindataContainerList)
    allContainers = traindataContainerList + valdataContainerList
    tfRecordFilenames = [] #one list of record files per container, train first
    nr_samples_list = []
    try:
        #now load the data and write TFRecords, one container at a time.
        for container in allContainers:
            container.loadData(**loadParams)
            if not container0.checkCompatibility(container):
                msg = "Aborting. Incompatible data"
                raise SystemExit(msg)
            tfRecordFilenames.append(container.writeTFRecord(pOutfolder=outfolder,
                                                              pRecordSize=recordsize))
            if debugstate is not None:
                idx = debugstate if isinstance(debugstate, int) else None
                container.plotFeatureAtIndex(idx=idx,
                                             outpath=outfolder,
                                             figuretype=figuretype)
                container.saveMatrix(outputpath=outfolder, index=idx)
            nr_samples_list.append(container.getNumberSamples())
        #data is no longer needed
        for container in allContainers:
            container.unloadData()
        traindataRecords = [item for sublist in tfRecordFilenames[:nr_trainContainers] for item in sublist]
        valdataRecords = [item for sublist in tfRecordFilenames[nr_trainContainers:] for item in sublist]

        #different binsizes are ok
        #not clear which binsize to use for prediction when they differ during training.
        #For now, store the max.
        binsize = max(container.binsize for container in traindataContainerList)
        paramDict["binsize"] = binsize
        #because of compatibility checks above,
        #the following properties are the same with all containers,
        #so just use data from first container
        nr_factors = container0.nr_factors
        paramDict["nr_factors"] = nr_factors
        for i in range(nr_factors):
            paramDict["chromFactor_" + str(i)] = container0.factorNames[i]
        storedFeaturesDict = container0.storedFeatures

        #with flipping, a flipped copy of every training sample is appended;
        #drop_remainder=True below means only full batches are produced
        nr_trainingSamples = sum(nr_samples_list[:nr_trainContainers])
        nr_effectiveSamples = 2 * nr_trainingSamples if flipsamples else nr_trainingSamples
        steps_per_epoch = int(nr_effectiveSamples // batchsize)
        if steps_per_epoch == 0:
            msg = "Aborting. Not enough training samples ({}) for a single batch of size {}"
            msg = msg.format(nr_effectiveSamples, batchsize)
            raise SystemExit(msg)

        #save the training parameters to a file before starting to train
        #(allows recovering the parameters even if training is aborted
        # and only intermediate models are available)
        parameterFile = os.path.join(outfolder, "trainParams.csv")
        with open(parameterFile, "w", newline="") as csvfile:
            dictWriter = csv.DictWriter(csvfile, fieldnames=sorted(paramDict.keys()))
            dictWriter.writeheader()
            dictWriter.writerow(paramDict)

        #build the input streams for training
        shuffleBufferSize = 3*recordsize
        trainDs = tf.data.TFRecordDataset(traindataRecords,
                                          num_parallel_reads=tf.data.experimental.AUTOTUNE,
                                          compression_type="GZIP")
        trainDs = trainDs.map(lambda x: records.parse_function(x, storedFeaturesDict),
                              num_parallel_calls=tf.data.experimental.AUTOTUNE)
        if flipsamples:
            flippedDs = trainDs.map(lambda a, b: records.mirror_function(a["factorData"], b["out_matrixData"]))
            trainDs = trainDs.concatenate(flippedDs)
        trainDs = trainDs.shuffle(buffer_size=shuffleBufferSize, reshuffle_each_iteration=True)
        trainDs = trainDs.batch(batchsize, drop_remainder=True)
        trainDs = trainDs.prefetch(tf.data.experimental.AUTOTUNE)

        #build the input streams for validation
        validationDs = tf.data.TFRecordDataset(valdataRecords,
                                               num_parallel_reads=tf.data.experimental.AUTOTUNE,
                                               compression_type="GZIP")
        validationDs = validationDs.map(lambda x: records.parse_function(x, storedFeaturesDict),
                                        num_parallel_calls=tf.data.experimental.AUTOTUNE)
        validationDs = validationDs.batch(batchsize)
        validationDs = validationDs.prefetch(tf.data.experimental.AUTOTUNE)

        if pretrainedintromodel is None:
            pretrainedintromodel = ""
        hicGanModel = hicGAN.HiCGAN(log_dir=outfolder,
                                    number_factors=nr_factors,
                                    loss_weight_pixel=lossweightpixel,
                                    loss_weight_adversarial=lossweightadv,
                                    loss_weight_discriminator=lossweightdisc,
                                    loss_type_pixel=losstypepixel,
                                    loss_weight_tv=lossweighttv,
                                    input_size=windowsize,
                                    learning_rate_generator=learningrategen,
                                    learning_rate_discriminator=learningratedisc,
                                    adam_beta_1=beta1,
                                    plot_type=figuretype,
                                    plot_frequency=plotfrequency,
                                    embedding_model_type=embeddingtype,
                                    pretrained_model_path=pretrainedintromodel)
        hicGanModel.plotModels(outputpath=outfolder, figuretype=figuretype)
        hicGanModel.fit(train_ds=trainDs, epochs=epochs, test_ds=validationDs, steps_per_epoch=steps_per_epoch)
    finally:
        #the TFRecords are temporary; remove them also when training fails
        _remove_files([item for sublist in tfRecordFilenames for item in sublist])
if __name__ == "__main__":
    #script entry point: Command.main (what calling the command does)
    #parses sys.argv and supplies every training() parameter
    training.main() #pylint: disable=no-member