-
Notifications
You must be signed in to change notification settings - Fork 0
/
compute_metrics_from_ngrams.py
314 lines (264 loc) · 13.9 KB
/
compute_metrics_from_ngrams.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
import argparse
from typing import Any
import ast
import json
import cattrs
from nltk import ngrams
from collections import defaultdict
from typing import List, Tuple
from dataclasses import dataclass
from data_overlap_spec import EntryOverlapNgrams, EntryDataOverlapKey, FrequencySpec, PartialOverlapSpec, MetricProtocolSpec, OverlapMetric, EntryOverlapMetric
from light_scenario import ScenarioSpecInstanceIds
from compute_data_overlap_metrics import load_light_scenarios_from_jsonl
from common.util import get_tokenizer
from common.general import asdict_without_nones
def get_metrics(ngrams_path, scenario_path, out_path, filter_path, N):
    """Compute data-overlap metrics from precomputed overlapping ngrams.

    Reads EntryOverlapNgrams records (JSONL) from `ngrams_path`, optionally
    restricts them to the (scenario_spec, instance_id) pairs listed in
    `filter_path`, computes binary / jaccard / token overlap metrics — each
    both unfiltered and with a frequency cutoff of 10 — against the scenario
    instances from `scenario_path`, and writes EntryOverlapMetric records as
    JSONL to `out_path`.

    Args:
        ngrams_path: path to a JSONL file of EntryOverlapNgrams.
        scenario_path: path to the light-scenario (benchmark) JSONL data.
        out_path: destination path for the EntryOverlapMetric JSONL output.
        filter_path: optional path to a JSONL file of ScenarioSpecInstanceIds;
            pass '' to disable filtering.
        N: ngram order to process; entries computed with a different n are skipped.
    """
    # Optional filter: scenario_spec -> allowed instance ids
    scenario_spec_instance_id_dict = dict()
    if filter_path:
        # Fix: use a context manager instead of leaking the file handle
        with open(filter_path, "r") as f:
            scenario_spec_instance_ids_jsons = f.readlines()
        for scenario_spec_instance_ids_json in scenario_spec_instance_ids_jsons:
            scenario_spec_instance_ids_dict = json.loads(scenario_spec_instance_ids_json)
            scenario_spec_instance_ids = cattrs.structure(scenario_spec_instance_ids_dict, ScenarioSpecInstanceIds)
            scenario_spec_instance_id_dict[
                scenario_spec_instance_ids.scenario_spec
            ] = scenario_spec_instance_ids.instance_ids

    # Read Ngrams, dropping entries excluded by the filter (when present)
    # Fix: use a context manager instead of leaking the file handle
    with open(ngrams_path, "r") as f:
        ngram_jsons = f.readlines()
    entry_overlap_ngrams_list = []
    for ngram_json in ngram_jsons:
        entry_overlap_ngrams = json.loads(ngram_json)
        entry_overlap_ngrams = cattrs.structure(entry_overlap_ngrams, EntryOverlapNgrams)
        scenario_spec = entry_overlap_ngrams.entry_data_overlap_key.stats_key.light_scenario_key.scenario_spec
        if scenario_spec_instance_id_dict:
            if scenario_spec not in scenario_spec_instance_id_dict:
                continue
            instance_ids = scenario_spec_instance_id_dict[scenario_spec]
            if entry_overlap_ngrams.entry_data_overlap_key.instance_id not in instance_ids:
                continue
            entry_overlap_ngrams_list.append(entry_overlap_ngrams)
        else:
            entry_overlap_ngrams_list.append(entry_overlap_ngrams)

    def merge_entries(entry_overlap_ngrams_list):
        """Merge EntryOverlapNgrams sharing a key by summing per-ngram counts."""
        overlapping_counts = defaultdict(int)
        for entry_overlap_ngrams in entry_overlap_ngrams_list:
            entry_data_overlap_key = entry_overlap_ngrams.entry_data_overlap_key
            overlapping_ngram_counts = entry_overlap_ngrams.overlapping_ngram_counts
            for ngram, count in overlapping_ngram_counts:
                overlapping_counts[ngram] += count
        overlapping_ngram_counts_list = []
        for ngram, count in overlapping_counts.items():
            overlapping_ngram_counts_list.append((ngram, count))
        # All entries in the list share the same key, so reusing the loop's
        # last entry_data_overlap_key is safe.
        return [EntryOverlapNgrams(
            entry_data_overlap_key=entry_data_overlap_key, overlapping_ngram_counts=overlapping_ngram_counts_list
        )]

    # create entry_overlap_ngrams_dict, a dict of entry_data_overlap_key -> EntryOverlapNgrams
    entry_overlap_ngrams_dict = defaultdict(list)
    for entry_overlap_ngrams in entry_overlap_ngrams_list:
        entry_data_overlap_key = entry_overlap_ngrams.entry_data_overlap_key
        overlapping_ngram_counts = entry_overlap_ngrams.overlapping_ngram_counts
        ngram_count = entry_data_overlap_key.stats_key.overlap_protocol_spec.n
        # Only keep entries computed with the requested ngram order
        if ngram_count not in [N]:
            continue
        entry_overlap_ngrams_dict[entry_data_overlap_key].append(entry_overlap_ngrams)
        # We need to merge entries if sharded by training data, since there'll be redundancy
        # Can refactor to no list later
        if len(entry_overlap_ngrams_dict[entry_data_overlap_key]) > 1:
            entry_overlap_ngrams_dict[entry_data_overlap_key] = merge_entries(entry_overlap_ngrams_dict[entry_data_overlap_key])

    # Read Scenarios into scenario_key -> {instance_id -> instance}
    light_scenarios = load_light_scenarios_from_jsonl(scenario_path)
    light_scenario_instance_dict = dict()
    for light_scenario in light_scenarios:
        instances = light_scenario.instances
        instance_dict = dict()
        for instance in instances:
            instance_dict[instance.id] = instance
        light_scenario_instance_dict[light_scenario.scenario_key] = instance_dict

    def compute_binary_overlap(instance_str, overlapping_ngram_counts, tokenizer, frequency=0):
        """
        Compute binary overlap (1 if any overlapping ngram occurs, else 0).
        If pass in frequency, include only the ngrams with count <= frequency.
        """
        tokens = tokenizer.tokenize(instance_str)
        ngram_counts_dict = defaultdict(int)
        # construct a dict of ngram -> count; stored ngrams are stringified
        # tuples, so parse them back with ast.literal_eval
        for ngram, count in overlapping_ngram_counts:
            ngram = tuple(ast.literal_eval(ngram))
            ngram_counts_dict[ngram] = count
        metric_score = 0
        for ngram in ngrams(tokens, N):
            count = ngram_counts_dict[ngram]
            if frequency == 0 or count <= frequency:
                if count != 0:
                    metric_score = 1
                    break
        overlap_metric = OverlapMetric(
            metric_score=metric_score,
            metric_protocol_spec=MetricProtocolSpec(
                partial_overlap_spec=PartialOverlapSpec.binary,
                frequency_spec=FrequencySpec(
                    filter_value=frequency,
                    weighting=False
                )
            )
        )
        return overlap_metric

    def compute_jaccard_overlap(instance_str, overlapping_ngram_counts, tokenizer, frequency=0):
        """
        Compute weighted and unweighted jaccard overlap.
        If pass in frequency, include only the ngrams with count <= frequency.
        """
        tokens = tokenizer.tokenize(instance_str)
        ngram_counts_dict = defaultdict(int)
        # construct a dict of ngram -> count
        for ngram, count in overlapping_ngram_counts:
            ngram = tuple(ast.literal_eval(ngram))
            ngram_counts_dict[ngram] = count
        total_ngram_count = 0
        counts = 0
        weighted_score = 0
        for ngram in ngrams(tokens, N):
            count = ngram_counts_dict[ngram]
            if frequency == 0 or count <= frequency:
                if count != 0:
                    counts += 1
                    weighted_score += 1 / count
            total_ngram_count += 1
        # Fix: instances with fewer than N tokens yield zero ngrams; report 0
        # instead of raising ZeroDivisionError
        if total_ngram_count:
            unweighted_score = counts / total_ngram_count
            weighted_score = weighted_score / total_ngram_count
        else:
            unweighted_score = 0
            weighted_score = 0
        unweighted_overlap_metric = OverlapMetric(
            metric_score=unweighted_score,
            metric_protocol_spec=MetricProtocolSpec(
                partial_overlap_spec=PartialOverlapSpec.jaccard,
                frequency_spec=FrequencySpec(
                    filter_value=frequency,
                    weighting=False
                )
            )
        )
        weighted_overlap_metric = OverlapMetric(
            metric_score=weighted_score,
            metric_protocol_spec=MetricProtocolSpec(
                partial_overlap_spec=PartialOverlapSpec.jaccard,
                frequency_spec=FrequencySpec(
                    filter_value=frequency,
                    weighting=True
                )
            )
        )
        return unweighted_overlap_metric, weighted_overlap_metric

    # Token overlap
    def compute_token_overlap(instance_str, overlapping_ngram_counts, tokenizer, frequency=0):
        """
        Compute weighted and unweighted token overlap.
        If pass in frequency, include only the ngrams with count <= frequency.
        """
        tokens = tokenizer.tokenize(instance_str)
        ngram_counts_dict = defaultdict(int)
        # construct a dict of ngram -> count
        for ngram, count in overlapping_ngram_counts:
            ngram = tuple(ast.literal_eval(ngram))
            ngram_counts_dict[ngram] = count
        total_token_count = 0
        counts = 0
        weighted_score = 0
        weight = 0
        token_budget = 0
        # Fix: predefine so the tail loop below cannot raise NameError when
        # the instance has fewer than N tokens (no ngrams at all)
        ngram = ()
        for ngram in ngrams(tokens, N):
            curr_count = ngram_counts_dict[ngram]
            # either no frequency, or check current count is less than frequency
            # or a previous contiguous count (weight != 0) less than frequency
            if frequency == 0 or curr_count <= frequency or (weight != 0 and weight <= frequency):
                if curr_count != 0:
                    token_budget = N
                    if weight > 0:
                        weight = min(curr_count, weight)
                    else:
                        weight = curr_count
                # Fix: guard weight > 0 — a frequency-filtered ngram can reset
                # weight to 0 while token_budget is still positive, which used
                # to raise ZeroDivisionError here
                if token_budget > 0 and weight > 0:
                    token_budget -= 1
                    counts += 1
                    weighted_score += 1 / weight
            else:
                weight = 0
            total_token_count += 1
        # The loop consumes one token per ngram; the last ngram's remaining
        # N-1 tokens are accounted for here so total_token_count == len(tokens)
        for token in ngram[1:]:
            if token_budget > 0 and weight > 0:
                token_budget -= 1
                counts += 1
                weighted_score += 1 / weight
            total_token_count += 1
        # Fix: avoid ZeroDivisionError on empty token streams
        if total_token_count:
            unweighted_score = counts / total_token_count
            weighted_score = weighted_score / total_token_count
        else:
            unweighted_score = 0
            weighted_score = 0
        unweighted_overlap_metric = OverlapMetric(
            metric_score=unweighted_score,
            metric_protocol_spec=MetricProtocolSpec(
                partial_overlap_spec=PartialOverlapSpec.token,
                frequency_spec=FrequencySpec(
                    filter_value=frequency,
                    weighting=False
                )
            )
        )
        weighted_overlap_metric = OverlapMetric(
            metric_score=weighted_score,
            metric_protocol_spec=MetricProtocolSpec(
                partial_overlap_spec=PartialOverlapSpec.token,
                frequency_spec=FrequencySpec(
                    filter_value=frequency,
                    weighting=True
                )
            )
        )
        return unweighted_overlap_metric, weighted_overlap_metric

    def compute_and_add_metrics(instance_str, overlapping_ngram_counts, tokenizer, entry_data_overlap_key, entry_overlap_metric_list, frequency=0):
        """Append binary, jaccard (x2) and token (x2) metrics for one instance part."""
        overlap_metric = compute_binary_overlap(instance_str, overlapping_ngram_counts, tokenizer, frequency)
        binary_metric = EntryOverlapMetric(entry_data_overlap_key=entry_data_overlap_key, overlap_metric=overlap_metric)
        entry_overlap_metric_list.append(binary_metric)
        unweighted_overlap_metric, weighted_overlap_metric = compute_jaccard_overlap(instance_str, overlapping_ngram_counts, tokenizer, frequency)
        unweighted_jaccard = EntryOverlapMetric(entry_data_overlap_key=entry_data_overlap_key, overlap_metric=unweighted_overlap_metric)
        weighted_jaccard = EntryOverlapMetric(entry_data_overlap_key=entry_data_overlap_key, overlap_metric=weighted_overlap_metric)
        entry_overlap_metric_list.append(unweighted_jaccard)
        entry_overlap_metric_list.append(weighted_jaccard)
        unweighted_overlap_metric, weighted_overlap_metric = compute_token_overlap(instance_str, overlapping_ngram_counts, tokenizer, frequency)
        unweighted_token = EntryOverlapMetric(entry_data_overlap_key=entry_data_overlap_key, overlap_metric=unweighted_overlap_metric)
        weighted_token = EntryOverlapMetric(entry_data_overlap_key=entry_data_overlap_key, overlap_metric=weighted_overlap_metric)
        entry_overlap_metric_list.append(unweighted_token)
        entry_overlap_metric_list.append(weighted_token)

    def save_metrics_to_jsonl(overlap_metrics: List[EntryOverlapMetric], filename: str):
        """Write metrics as one JSON object per line."""
        with open(filename, "w") as f:
            for overlap_metric in overlap_metrics:
                f.write(json.dumps(asdict_without_nones(overlap_metric), ensure_ascii=False) + "\n")

    entry_overlap_metric_list = []
    tokenizer = get_tokenizer('default')
    for entry_data_overlap_key, entry_overlap_ngrams_list in entry_overlap_ngrams_dict.items():
        data_overlap_stats_key = entry_data_overlap_key.stats_key
        light_scenario_key = data_overlap_stats_key.light_scenario_key
        instance_dict = light_scenario_instance_dict[light_scenario_key]
        for entry_overlap_ngrams in entry_overlap_ngrams_list:
            entry_data_overlap_key = entry_overlap_ngrams.entry_data_overlap_key
            instance_id = entry_data_overlap_key.instance_id
            instance = instance_dict[instance_id]
            part = entry_data_overlap_key.part
            overlapping_ngram_counts = entry_overlap_ngrams.overlapping_ngram_counts
            # Metrics are computed twice per part: unfiltered (frequency=0)
            # and with a frequency cutoff of 10
            if part == 'input':
                compute_and_add_metrics(instance.input, overlapping_ngram_counts, tokenizer, entry_data_overlap_key, entry_overlap_metric_list)
                compute_and_add_metrics(instance.input, overlapping_ngram_counts, tokenizer, entry_data_overlap_key, entry_overlap_metric_list, frequency=10)
            elif part == 'references':
                reference = ' '.join(instance.references)
                compute_and_add_metrics(reference, overlapping_ngram_counts, tokenizer, entry_data_overlap_key, entry_overlap_metric_list)
                compute_and_add_metrics(reference, overlapping_ngram_counts, tokenizer, entry_data_overlap_key, entry_overlap_metric_list, frequency=10)
    save_metrics_to_jsonl(entry_overlap_metric_list, out_path)
def get_args() -> Any:
    """Parse and return the command-line arguments for this script."""
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--ngrams-path", type=str, required=True, help="Path to your ngrams"
    )
    cli.add_argument(
        "--scenario-path",
        type=str,
        required=True,
        help="Path to scenario data (benchmarking data)",
    )
    cli.add_argument(
        "--out-path", type=str, required=True, help="Path to the output metrics file"
    )
    cli.add_argument(
        "--filter-path",
        type=str,
        default='',
        help="Path to file for filtering a subset of tests",
    )
    cli.add_argument("--N", type=int, default=13, help="N of input ngrams")
    return cli.parse_args()
if __name__ == "__main__":
    # Parse CLI flags, then run the metric computation end to end
    cli_args = get_args()
    get_metrics(
        cli_args.ngrams_path,
        cli_args.scenario_path,
        cli_args.out_path,
        cli_args.filter_path,
        cli_args.N,
    )