# data_manager.py
import json
import pickle
import sys
from enum import Enum

import numpy as np

import definitions
from structured_rep import *
from logical_forms import TokenTypes
from sentence_processing import preprocess_sentences, replace_rare_words_with_unk
from seq2seqModel.hyper_params import SENTENCE_DRIVEN_CONSTRAINTS_ON_BEAM_SEARCH

np.random.seed(1)
class DataSet(Enum):
    TRAIN = 'train'
    DEV = 'dev'
    TEST = 'test'
    TEST2 = 'hidden_test'
if len(sys.argv) == 2:
paths = {DataSet.TRAIN: definitions.TRAIN_JSON,
DataSet.DEV: definitions.DEV_JSON,
DataSet.TEST: definitions.TEST_JSON,
DataSet.TEST2: definitions.TEST2_JSON}
else:
paths = {DataSet.TRAIN: definitions.TRAIN_JSON,
DataSet.DEV: definitions.DEV_JSON,
DataSet.TEST: definitions.TEST_JSON}
def read_data(filename):
data = []
with open(filename) as data_file:
for line in data_file:
data.append(json.loads(line))
return data
def rewrite_data(filename, data, mapping):
    with open(filename, 'w') as output_file:
        for sample in data:
            s_index = int(str.split(sample["identifier"], "-")[0])
            sample["sentence"] = mapping[s_index]
            json.dump(sample, output_file)
            output_file.write('\n')
    return
def build_data(data, preprocessing_type=None, use_unk=True):
    '''
    :param data: a deserialized version of a dataset json file, as returned by read_data (a list of dicts, one per line)
    :param preprocessing_type: the type of sentence preprocessing to be used
    :param use_unk: whether to replace rare words with <UNK> tokens
    :return:
    samples : a list of Sample objects (see structured_rep.py). Each represents a single line
    from the data in a convenient, object-oriented way
    sentences : a dictionary that maps sentence ids to the sentences, where each unique sentence appears only once
    '''
samples = []
sentences = {}
for line in data:
samples.append(Sample(line))
s_index = int(str.split(line["identifier"], "-")[0])
if s_index not in sentences:
sentences[s_index] = line["sentence"]
sentences = preprocess_sentences(sentences, mode=None, processing_type=preprocessing_type)
if preprocessing_type == 'abstraction':
dicts_dict = {idx: sent[1] for idx, sent in sentences.items()}
sentences = {idx: sent[0] for idx, sent in sentences.items()}
if use_unk:
sentences = replace_rare_words_with_unk(sentences)
for s in samples:
s_index = int(str.split(s.identifier, "-")[0])
s.sentence = sentences[s_index]
if preprocessing_type == 'abstraction':
s.abstraction_dict = dicts_dict[s_index]
return samples, sentences
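# A minimal usage sketch for build_data (illustrative only; it assumes the DEV_JSON path from
# definitions and the 'deep' preprocessing mode, both of which are also used further below):
#
#     dev_data = read_data(definitions.DEV_JSON)
#     samples, sentences = build_data(dev_data, preprocessing_type='deep', use_unk=True)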
class CNLVRDataSet:
"""
A wrapper class for an instance of a data set from CNLVR (i.e. train, dev, or test).
This class encapsulates all the processing done for loading the data, and provides
functionality for going over the data set one batch after another, as well as some other
options.
"""
def __init__(self, dataset):
self.__dataset = dataset
self.original_sentences = {}
self.processed_sentences = {}
self.samples = {}
self.__ids = []
self._index_in_epoch = 0
self.epochs_completed = 0
self.__ids_by_complexity = []
self.pictures = {}
self.sentences_quardpled_ids = []
self.processed_sentences_singles = {}
self.get_data(paths[dataset])
@property
def name(self): # i.e. 'TRAIN'
return self.__dataset.name
@property
def num_examples(self):
return len(self.__ids)
@property
def num_single_examples(self):
return len(self.sentences_quardpled_ids)
def build_single_lists(self):
        # create new ids so that sentences and pictures can be handled in a 1:1 ratio
sents_num = len(self.original_sentences.keys())
k=0
for i in range(sents_num):
for j in range(4):
sample_original_name="{0}-{1}".format(list(self.original_sentences.keys())[i], j)
sample_new_name = "{0}".format(k)
if sample_original_name in self.samples:
self.pictures[sample_new_name] = self.samples[sample_original_name]
self.sentences_quardpled_ids.append(k)
self.processed_sentences_singles[k] = self.processed_sentences[list(self.original_sentences.keys())[i]]
k += 1
def get_samples_by_sentence_id(self, sentence_id):
samples_ids = ["{0}-{1}".format(sentence_id, i) for i in range(4)]
return [self.samples[sample_id] for sample_id in samples_ids if sample_id in self.samples]
def get_single_sample_for_sentence(self,sentence_id_q):
return [self.pictures["{0}".format(sentence_id_q)]]
def get_sentence_by_id(self, sentence_id, original=False):
if original:
return self.original_sentences[sentence_id]
return self.processed_sentences[sentence_id]
def get_data(self, path):
        # this method handles all loading and processing of the data set and is therefore
        # called at initialization.
data = read_data(path)
sentences = {}
for line in data:
self.samples[line["identifier"]] = Sample(line)
s_index = int(str.split(line["identifier"], "-")[0])
if s_index not in sentences:
sentences[s_index] = line["sentence"]
self.original_sentences = preprocess_sentences(sentences, processing_type='shallow')
if self.__dataset == DataSet.TRAIN:
mode = None
counts_file = None
else:
mode = 'r'
counts_file = definitions.TOKEN_COUNTS_PROCESSED
if definitions.ABSTRACTION:
self.processed_sentences = preprocess_sentences(sentences, mode=mode, processing_type='abstraction')
dicts_dict = {idx: sent[1] for idx, sent in self.processed_sentences.items()}
self.processed_sentences = {idx: sent[0] for idx, sent in self.processed_sentences.items()}
else:
self.processed_sentences = preprocess_sentences(sentences, mode=mode, processing_type='deep')
self.processed_sentences = replace_rare_words_with_unk(self.processed_sentences, counts_file)
for s in self.samples.values():
s_index = int(str.split(s.identifier, "-")[0])
s.sentence = self.processed_sentences[s_index]
if definitions.ABSTRACTION:
s.abstraction_dict = dicts_dict[s_index]
self.__ids = [k for k in self.original_sentences.keys()]
#self.sentences_quardpled_ids = []
self.build_single_lists()
    def use_subset_by_sentnce_condition(self, f_s):
        """
        limits the data set to sentences that satisfy a given condition.
        :param f_s: a boolean function on processed sentences
        """
new_ids = []
for k, s in self.processed_sentences.items():
if f_s(s):
new_ids.append(k)
self.__ids = new_ids
    def use_subset_by_images_condition(self, f_im):
        """
        limits the data set to sentences whose related images satisfy a given rule.
        :param f_im: a boolean function on a list of samples
        """
new_ids = []
for k, s in self.processed_sentences.items():
related_samples = self.get_samples_by_sentence_id(k)
if f_im(related_samples):
new_ids.append(k)
self.__ids = new_ids
def ignore_all_true_samples(self):
"""
limits the dataset to sentences that are not true about all their images -
should help avoid spurious signal (there are about 10% such sentences in the training set)
"""
all_true_filter = lambda s_samples: not all([s.label == True for s in s_samples])
self.use_subset_by_images_condition(all_true_filter)
def sort_sentences_by_complexity(self, complexity_measure, n_classes):
'''
sorts the data into n_classes sets by some measure of complexity of the sentences.
can be used for curriculum learning
'''
self.__ids_by_complexity = []
ids_sorted_by_sentence_length = sorted(self.processed_sentences.keys(), key=
lambda key: complexity_measure(self.processed_sentences[key]))
class_size = len(self.processed_sentences) // n_classes
for i in range(n_classes):
self.__ids_by_complexity.append(ids_sorted_by_sentence_length[
class_size * i: min(class_size * i + class_size,
len(self.processed_sentences))])
return
def choose_levels_for_curriculum_learning(self, levels):
self.__ids = [idx for idx in set(ind for level in levels for ind in self.__ids_by_complexity[level])]
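    # Illustrative sketch (not from the original code) of a typical curriculum-learning call
    # pattern with the two methods above, assuming sentence length as the complexity measure:
    #
    #     train_set.sort_sentences_by_complexity(lambda sent: len(sent.split()), n_classes=5)
    #     train_set.choose_levels_for_curriculum_learning([0, 1])  # start with the two easiest classes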
def restart(self):
        '''
        restarts the state of the data set (there is no need to reload it from disk - just call this method)
        '''
self.__ids = [k for k in self.original_sentences.keys()]
self._index_in_epoch = 0
self.epochs_completed = 0
self.sentences_quardpled_ids=[x for x in self.processed_sentences_singles.keys()]
def next_batch(self, batch_size):
        '''
        returns the next batch of (sentence, related samples) pairs.
        also handles the logic of moving between epochs.
        '''
if batch_size <= 0 or batch_size > self.num_examples:
raise ValueError("invalid argument for batch size: {}".format(batch_size))
if self._index_in_epoch == 0:
np.random.shuffle(self.__ids) # shuffle index
start = self._index_in_epoch
# go to the next batch
if start + batch_size > self.num_examples:
batch_size = self.num_examples - start
self._index_in_epoch += batch_size
end = self._index_in_epoch
indices = self.__ids[start: end]
batch = {k: (self.processed_sentences[k], self.get_samples_by_sentence_id(k)) for k in indices}
if end == self.num_examples:
self.epochs_completed += 1
self._index_in_epoch = 0
return batch, self._index_in_epoch == 0
def next_batch_singles(self, batch_size):
        '''
        returns the next batch of (sentence, single related sample) pairs.
        also handles the logic of moving between epochs.
        '''
if batch_size <= 0 or batch_size > self.num_single_examples:
raise ValueError("invalid argument for batch size: {}".format(batch_size))
if self._index_in_epoch == 0:
np.random.shuffle(self.sentences_quardpled_ids) # shuffle index
start = self._index_in_epoch
# go to the next batch
if start + batch_size > self.num_single_examples:
batch_size = self.num_single_examples - start
self._index_in_epoch += batch_size
end = self._index_in_epoch
indices = self.sentences_quardpled_ids[start: end]
batch = {k: (self.processed_sentences_singles[k], self.get_single_sample_for_sentence(k)) for k in indices}
if end == self.num_single_examples:
self.epochs_completed += 1
self._index_in_epoch = 0
return batch, self._index_in_epoch == 0
class DataSetForSupervised:
"""
a simpler version of the class above, used only for the supervised learning
"""
def __init__(self, path):
self.__ids = []
self.examples = []
self._index_in_epoch = 0
self.epochs_completed = 0
self.num_examples = 0
self.get_supervised_data(path)
def get_supervised_data(self, path):
sents = pickle.load(open(path, 'rb'))
self.num_examples = len(sents)
self.__ids = [x for x in range(len(sents))]
self.examples = sents
def next_batch(self, batch_size):
start = self._index_in_epoch
if start == 0:
np.random.shuffle(self.__ids) # shuffle index
# go to the next batch
elif start + batch_size > self.num_examples:
self.epochs_completed += 1
self._index_in_epoch = 0
return self.next_batch(batch_size)
self._index_in_epoch += batch_size
end = self._index_in_epoch
indices = self.__ids[start: end]
return [(self.examples[k][0], self.examples[k][1]) for k in indices]
def load_functions(filename):
"""
loads from file a dictionary of all valid tokens in the formal language we use.
each token is is defines by its name, its return types, and its argument types.
tokens that represent known entities, like ALL_BOXES, or Color.BLUE are treated as
functions that take no arguments, and their return type is their own type, i.e.
set<set<Item>>, and set<Color>, rspectively.
"""
functions_dict = {}
with open(filename) as functions_file:
for i, line in enumerate(functions_file):
if line.isspace():
continue
line = line.strip()
if line.startswith('#'):
continue
entry = line.split()
split_idx = entry.index(':') if ':' in entry else len(entry)
entry, necessary_words = entry[:split_idx], entry[split_idx:]
if len(entry) < 3 or not entry[1].isdigit() or int(entry[1]) != len(entry) - 3:
print("could not parse function in line {0}: {1}".format(i, line))
# should use Warning instead
continue
token, return_type, args_types = entry[0], entry[-1], entry[2:-1]
functions_dict[token] = TokenTypes(return_type=return_type, args_types=args_types,
necessity=necessary_words)
if SENTENCE_DRIVEN_CONSTRAINTS_ON_BEAM_SEARCH:
functions_dict['1'] = TokenTypes(return_type='int', args_types=[], necessity=['1', 'one', 'a'])
functions_dict.update(
{str(i): TokenTypes(return_type='int', args_types=[], necessity=[str(i)]) for i in range(2, 10)})
else:
functions_dict['1'] = TokenTypes(return_type='int', args_types=[], necessity=[])
functions_dict.update(
{str(i): TokenTypes(return_type='int', args_types=[], necessity=[]) for i in range(2, 10)})
return functions_dict
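
# A minimal end-to-end usage sketch (illustrative, not part of the original module); it assumes the
# json paths configured in definitions exist on disk.
if __name__ == '__main__':
    train_set = CNLVRDataSet(DataSet.TRAIN)
    print("loaded {} sentences from the {} set".format(train_set.num_examples, train_set.name))
    # iterate over one epoch; next_batch returns (batch, epoch_done)
    epoch_done = False
    while not epoch_done:
        batch, epoch_done = train_set.next_batch(8)
        for sentence_id, (sentence, related_samples) in batch.items():
            pass  # run a model on (sentence, related_samples) here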