-
Notifications
You must be signed in to change notification settings - Fork 0
/
output_aggregate_metrics_both.py
118 lines (99 loc) · 5.11 KB
/
output_aggregate_metrics_both.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import argparse
import json
import cattrs
import pandas as pd
import numpy
from nltk import ngrams
from collections import defaultdict
from typing import List, Tuple, Any
from dataclasses import dataclass
from data_overlap_spec import AggregateOverlapMetric, AggregateDataOverlapKey, MetricProtocolSpec, PartialOverlapSpec, FrequencySpec, EntryOverlapMetric
from compute_data_overlap_metrics import load_light_scenarios_from_jsonl
from common.util import get_tokenizer
from common.general import asdict_without_nones
def scenario_spec_to_class(scenario_spec) -> str:
    """Return the unqualified class name of ``scenario_spec``.

    This is everything after the last ``"."`` in ``scenario_spec.class_name``;
    a name containing no dots is returned unchanged.
    """
    _, _, short_name = scenario_spec.class_name.rpartition(".")
    return short_name
# Labels for the two parts of an instance.
PART_INPUT: str = "input"
# NOTE(review): aggregate_metrics matches the reference part against the literal
# "references" (plural), not this value -- confirm which spelling the metrics files use.
PART_REF: str = "reference"

# All metric protocols: for each FrequencySpec first argument (0, then 10), the binary
# metric, then jaccard and token, each with the FrequencySpec flag off and then on.
# Presumably FrequencySpec(filter_value, weighting) -- TODO confirm in data_overlap_spec.
metric_protocol_specs_list = [
    MetricProtocolSpec(partial_overlap_spec, FrequencySpec(frequency_value, flag))
    for frequency_value in (0, 10)
    for partial_overlap_spec, flag in (
        (PartialOverlapSpec.binary, False),
        (PartialOverlapSpec.jaccard, False),
        (PartialOverlapSpec.jaccard, True),
        (PartialOverlapSpec.token, False),
        (PartialOverlapSpec.token, True),
    )
]

# The only protocols aggregate_metrics keeps: binary, jaccard and token with
# FrequencySpec(0, False).
non_weighted_metrics = [
    MetricProtocolSpec(partial_overlap_spec, FrequencySpec(0, False))
    for partial_overlap_spec in (
        PartialOverlapSpec.binary,
        PartialOverlapSpec.jaccard,
        PartialOverlapSpec.token,
    )
]
# Fully qualified class name of the scenario whose entries are never aggregated.
_COPYRIGHT_SCENARIO_CLASS_NAME = "helm.benchmark.scenarios.copyright_scenario.CopyrightScenario"


def _load_entry_overlap_metrics(path: str) -> "List[EntryOverlapMetric]":
    """Read ``path`` as JSONL and structure each non-blank line into an ``EntryOverlapMetric``."""
    entry_overlap_metrics = []
    # JSON text is UTF-8; do not depend on the platform's locale encoding.
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            # Skip blank lines (e.g. a stray empty last line) instead of failing in json.loads.
            if not line.strip():
                continue
            entry_overlap_metrics.append(cattrs.structure(json.loads(line), EntryOverlapMetric))
    return entry_overlap_metrics


def _save_metrics_to_jsonl(overlap_metrics: "List[AggregateOverlapMetric]", filename: str) -> None:
    """Write each metric to ``filename`` as one JSON object per line, overwriting the file."""
    # ensure_ascii=False emits raw non-ASCII characters, so the file must be UTF-8 explicitly;
    # a locale default such as cp1252 could raise UnicodeEncodeError.
    with open(filename, "w", encoding="utf-8") as f:
        for overlap_metric in overlap_metrics:
            f.write(json.dumps(asdict_without_nones(overlap_metric), ensure_ascii=False) + "\n")


def aggregate_metrics(path: str, out_path: str) -> None:
    """Aggregate entry-level overlap scores over instances that have both parts.

    Reads ``EntryOverlapMetric`` records (JSONL) from ``path``. Only records whose
    ``metric_protocol_spec`` is in ``non_weighted_metrics`` are considered. A record is held
    until its partner is seen: the partner has the same ``stats_key``, ``instance_id`` and
    spec, but the other part ("input" <-> "references"). When the pair completes, both scores
    are appended to the score list of their own ``(stats_key, part, spec)``. Instances seen
    with only one part contribute nothing, and ``CopyrightScenario`` records are skipped.
    Each score list is written to ``out_path`` (JSONL) as an ``AggregateOverlapMetric``.

    Args:
        path: Input JSONL file of entry-level overlap metrics.
        out_path: Output JSONL file for the aggregated metrics (overwritten).
    """
    entry_overlap_metric_list = _load_entry_overlap_metrics(path)

    # (stats_key, part, instance_id, spec) -> score of a record still waiting for its partner.
    check_score_dict = {}
    # (stats_key, part, spec) -> paired scores. With a defaultdict, appending for one part can
    # never reset the other part's list; the old "initialize both lists" branch could overwrite it.
    aggregate_score_dict = defaultdict(list)
    for entry_overlap_metric in entry_overlap_metric_list:
        overlap_key = entry_overlap_metric.entry_data_overlap_key
        stats_key = overlap_key.stats_key
        part = overlap_key.part
        instance_id = overlap_key.instance_id
        metric_protocol_spec = entry_overlap_metric.overlap_metric.metric_protocol_spec
        metric_score = entry_overlap_metric.overlap_metric.metric_score

        if metric_protocol_spec not in non_weighted_metrics:
            continue

        # NOTE(review): the reference part is matched as "references", while the module
        # constant PART_REF is "reference" -- confirm the spelling used by the metrics files.
        other_part = "references" if part == "input" else "input"
        check_key = (stats_key, other_part, instance_id, metric_protocol_spec)
        if check_key not in check_score_dict:
            # First half of a pair: remember it, unless it belongs to the excluded scenario.
            if stats_key.light_scenario_key.scenario_spec.class_name == _COPYRIGHT_SCENARIO_CLASS_NAME:
                continue
            check_score_dict[(stats_key, part, instance_id, metric_protocol_spec)] = metric_score
        else:
            # Partner found: record both halves, current part first (keeps the output order).
            aggregate_score_dict[(stats_key, part, metric_protocol_spec)].append(metric_score)
            aggregate_score_dict[(stats_key, other_part, metric_protocol_spec)].append(
                check_score_dict[check_key]
            )

    aggregate_overlap_metrics = [
        AggregateOverlapMetric(
            aggregate_data_overlap_key=AggregateDataOverlapKey(stats_key=stats_key, part=part),
            metric_scores=scores,
            metric_protocol_spec=metric_protocol_spec,
        )
        for (stats_key, part, metric_protocol_spec), scores in aggregate_score_dict.items()
    ]
    _save_metrics_to_jsonl(aggregate_overlap_metrics, out_path)
def get_args() -> Any:
    """Parse the required ``--metrics-path`` and ``--out-path`` command-line options.

    Returns the ``argparse.Namespace``, with attributes ``metrics_path`` and ``out_path``.
    """
    arg_parser = argparse.ArgumentParser()
    for flag, help_text in (
        ("--metrics-path", "Path to your metrics"),
        ("--out-path", "Path to the output metrics file"),
    ):
        arg_parser.add_argument(flag, type=str, required=True, help=help_text)
    return arg_parser.parse_args()
if __name__ == "__main__":
    # Script entry point: read the CLI flags, then run the aggregation.
    cli_args = get_args()
    aggregate_metrics(path=cli_args.metrics_path, out_path=cli_args.out_path)