-
Notifications
You must be signed in to change notification settings - Fork 0
/
output_aggregate_metrics.py
105 lines (87 loc) · 4.46 KB
/
output_aggregate_metrics.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import argparse
import json
import cattrs
import pandas as pd
import numpy
from nltk import ngrams
from collections import defaultdict
from typing import List, Tuple, Any
from dataclasses import dataclass
from data_overlap_spec import AggregateOverlapMetric, AggregateDataOverlapKey, MetricProtocolSpec, PartialOverlapSpec, FrequencySpec, EntryOverlapMetric
from compute_data_overlap_metrics import load_light_scenarios_from_jsonl
from common.util import get_tokenizer
from common.general import asdict_without_nones
def scenario_spec_to_class(scenario_spec) -> str:
    """Return the last dot-separated component of ``scenario_spec.class_name``.

    A name without any dot is returned unchanged.
    """
    _, _, short_name = scenario_spec.class_name.rpartition(".")
    return short_name
# Labels for the two parts named by this module.
PART_INPUT: str = "input"
PART_REF: str = "reference"

# Metric protocol specs, generated for each first FrequencySpec argument
# (0, then 10). Within each group the order is: binary, jaccard, jaccard,
# token, token, where the second FrequencySpec argument is False for binary
# and False-then-True for jaccard and token.
# NOTE(review): the FrequencySpec arguments are presumably
# (filter threshold, weighting flag) - confirm against data_overlap_spec.
metric_protocol_specs_list = [
    MetricProtocolSpec(overlap_kind, FrequencySpec(first_arg, second_arg))
    for first_arg in (0, 10)
    for overlap_kind, second_arg in (
        (PartialOverlapSpec.binary, False),
        (PartialOverlapSpec.jaccard, False),
        (PartialOverlapSpec.jaccard, True),
        (PartialOverlapSpec.token, False),
        (PartialOverlapSpec.token, True),
    )
]

# The specs that aggregate_metrics keeps; entries with any other spec are
# skipped there. All three use FrequencySpec(0, False).
non_weighted_metrics = [
    MetricProtocolSpec(overlap_kind, FrequencySpec(0, False))
    for overlap_kind in (
        PartialOverlapSpec.binary,
        PartialOverlapSpec.jaccard,
        PartialOverlapSpec.token,
    )
]
def aggregate_metrics(path, out_path):
    """Group per-entry overlap scores and write them as aggregate metrics.

    Each non-blank line of ``path`` is parsed as JSON and structured into an
    ``EntryOverlapMetric``. Scores are collected, in input order, into one
    list per ``(stats_key, part, metric_protocol_spec)`` combination. An entry
    is dropped when its metric protocol spec is not in
    ``non_weighted_metrics`` or when its scenario class is
    ``CopyrightScenario``. One ``AggregateOverlapMetric`` per group is then
    written to ``out_path``, one JSON object per line.

    Args:
        path: Path of the input JSON Lines file of entry overlap metrics.
        out_path: Path of the JSON Lines file to write; it is opened in
            ``"w"`` mode, so existing content is replaced.

    Raises:
        json.JSONDecodeError: If a non-blank input line is not valid JSON.
    """
    excluded_class_name = "helm.benchmark.scenarios.copyright_scenario.CopyrightScenario"

    # Maps (stats_key, part, metric_protocol_spec) -> scores in input order.
    scores_by_key = defaultdict(list)

    # Stream the file inside a context manager so the handle is always closed
    # and the raw lines are never all held in memory at once. JSON is read as
    # UTF-8 explicitly rather than relying on the platform default encoding.
    with open(path, "r", encoding="utf-8") as metrics_file:
        for line in metrics_file:
            if not line.strip():
                # Tolerate blank lines instead of failing in json.loads.
                continue
            entry_overlap_metric = cattrs.structure(json.loads(line), EntryOverlapMetric)

            metric_protocol_spec = entry_overlap_metric.overlap_metric.metric_protocol_spec
            if metric_protocol_spec not in non_weighted_metrics:
                continue

            stats_key = entry_overlap_metric.entry_data_overlap_key.stats_key
            if stats_key.light_scenario_key.scenario_spec.class_name == excluded_class_name:
                continue

            part = entry_overlap_metric.entry_data_overlap_key.part
            scores_by_key[(stats_key, part, metric_protocol_spec)].append(
                entry_overlap_metric.overlap_metric.metric_score
            )

    # Build every aggregate before touching the output file, so a failure
    # here cannot leave a truncated or partially written result behind.
    aggregate_overlap_metrics = [
        AggregateOverlapMetric(
            aggregate_data_overlap_key=AggregateDataOverlapKey(stats_key=stats_key, part=part),
            metric_scores=scores,
            metric_protocol_spec=metric_protocol_spec,
        )
        for (stats_key, part, metric_protocol_spec), scores in scores_by_key.items()
    ]

    # ensure_ascii=False emits raw non-ASCII characters, so the file must be
    # opened as UTF-8 to avoid UnicodeEncodeError under other locale encodings.
    with open(out_path, "w", encoding="utf-8") as out_file:
        for aggregate_overlap_metric in aggregate_overlap_metrics:
            serialized = json.dumps(asdict_without_nones(aggregate_overlap_metric), ensure_ascii=False)
            out_file.write(serialized + "\n")
def get_args() -> Any:
    """Parse this script's command-line options.

    Both ``--metrics-path`` and ``--out-path`` are required string options.

    Returns:
        The namespace produced by ``argparse`` with ``metrics_path`` and
        ``out_path`` attributes.
    """
    cli = argparse.ArgumentParser()
    required_options = (
        ("--metrics-path", "Path to your metrics"),
        ("--out-path", "Path to the output metrics file"),
    )
    for flag, description in required_options:
        cli.add_argument(flag, type=str, required=True, help=description)
    namespace = cli.parse_args()
    return namespace
# Script entry point: read the two paths from the command line and aggregate.
if __name__ == "__main__":
    cli_args = get_args()
    aggregate_metrics(path=cli_args.metrics_path, out_path=cli_args.out_path)