Skip to content

Commit

Permalink
Add mean row aggregation to HELM summarize (#2997)
Browse files Browse the repository at this point in the history
  • Loading branch information
farzaank authored Sep 25, 2024
1 parent 9bae33b commit fce1e5f
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 15 deletions.
3 changes: 3 additions & 0 deletions src/helm/benchmark/presentation/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,9 @@ class MetricGroup(Field):
hide_win_rates: Optional[bool] = None
"""If set to true, do not compute win rates."""

aggregation_strategies: Optional[List[str]] = None
"""List with values in {'win_rate','mean'} that correspond to aggregations"""


BY_METRIC = "by_metric"
BY_GROUP = "by_group"
Expand Down
99 changes: 84 additions & 15 deletions src/helm/benchmark/presentation/summarize.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,39 @@ def compute_aggregate_row_win_rates(table: Table, aggregation: str = "mean") ->
return aggregate_win_rates


def compute_aggregate_row_means(table: Table) -> List[Optional[float]]:
    """
    Computes the aggregate mean of each row across comparable columns.

    Mirrors `compute_aggregate_row_win_rates`: columns whose header does not
    specify `lower_is_better` (e.g. the leading model-name column) have no
    meaningful ordering and are skipped, both for the consistency check and
    for the averaging itself.

    Returns a list of means, one per row, with None if a row was never
    meaningfully comparable (i.e., all non-null values of the row are in
    columns we skip).

    Raises:
        Exception: if the comparable columns disagree on `lower_is_better`,
            since averaging "higher is better" with "lower is better"
            values is meaningless.
    """
    # Only consider header cells where lower_is_better is specified; they must all agree.
    orderings = {
        header_cell.lower_is_better
        for header_cell in table.header
        if header_cell.lower_is_better is not None
    }
    if len(orderings) > 1:
        raise Exception("Cannot mean columns with different values for lower_is_better")

    # Indices of columns that participate in the mean (those with a specified ordering).
    comparable_columns = [
        i for i, header_cell in enumerate(table.header) if header_cell.lower_is_better is not None
    ]

    row_means: List[Optional[float]] = []
    for row in table.rows:
        values = [float(row[i].value) for i in comparable_columns if row[i].value is not None]
        if not values:
            row_means.append(None)
        else:
            row_means.append(sum(values) / len(values))

    return row_means


AGGREGATE_WIN_RATE_COLUMN = 1
AGGREGATION_STRATEGIES = ["mean", "win_rate"]


class Summarizer:
Expand Down Expand Up @@ -891,6 +923,7 @@ def create_group_table(
sub_split: Optional[str] = None,
bold_columns: bool = True,
add_win_rate: bool = False,
aggregation_strategies: List[str] = [],
) -> Table:
"""
Create a table for where each row is an adapter (for which we have a set of runs) and columns are pairs of
Expand Down Expand Up @@ -1073,21 +1106,53 @@ def _adapter_spec_sort_key(spec):

table = Table(title=title, header=header, rows=rows, links=links, name=name)

if add_win_rate:
# add overall win rate as the second column
WIN_RATE_AGGREGATION = "mean"
win_rates = compute_aggregate_row_win_rates(table, aggregation=WIN_RATE_AGGREGATION)
description = "How many models this model outperform on average (over columns)."
table.header.insert(
AGGREGATE_WIN_RATE_COLUMN,
HeaderCell(
f"{WIN_RATE_AGGREGATION.capitalize()} win rate",
description=description,
lower_is_better=False,
),
)
for row, win_rate in zip(table.rows, win_rates):
row.insert(AGGREGATE_WIN_RATE_COLUMN, Cell(win_rate))
if aggregation_strategies is None:
aggregation_strategies = ["win_rate"]

# this preserves backwards compatibility for self.schema.name_to_metric_group[metric_group].hide_win_rates
# hide_win_rate is the inverse of add_win_rate here (see the function call for create_group_table)
hide_aggregation = not add_win_rate
if hide_aggregation:
aggregation_strategies = []

aggregate_header_cells: List[HeaderCell] = []
aggregate_row_values: List[List[Optional[float]]] = []

for strategy in aggregation_strategies:
if strategy == "win_rate":
WIN_RATE_AGGREGATION = "mean"
win_rates = compute_aggregate_row_win_rates(table, aggregation=WIN_RATE_AGGREGATION)
description = "How many models this model outperforms on average (over columns)."
aggregate_header_cells.append(
HeaderCell(
f"{WIN_RATE_AGGREGATION.capitalize()} win rate",
description=description,
lower_is_better=False,
)
)
aggregate_row_values.append(win_rates)
elif strategy == "mean":
means = compute_aggregate_row_means(table)
description = "An average over columns representing the mean performance."
aggregate_header_cells.append(
HeaderCell(
"Mean performance",
description=description,
lower_is_better=table.header[0].lower_is_better,
)
)
aggregate_row_values.append(means)
else:
raise Exception(
f"Unknown aggregation strategy found: {strategy}. Please use one of: {AGGREGATION_STRATEGIES}"
)

for i in range(len(aggregate_header_cells)):
aggregate_header_cell = aggregate_header_cells[i]
aggregate_rows = aggregate_row_values[i]
table.header.insert(i + 1, aggregate_header_cell)
for row, row_val in zip(table.rows, aggregate_rows):
row.insert(i + 1, Cell(row_val))

if bold_columns:
for i, header_cell in enumerate(table.header):
Expand Down Expand Up @@ -1136,13 +1201,17 @@ def create_group_tables_by_metric_group(self, group: RunGroup) -> List[Table]:
if len(adapter_to_runs) > 0:
for metric_group in all_metric_groups:
display_name = self.schema.name_to_metric_group[metric_group].get_short_display_name()
aggregate_strategies: List[str] = (
self.schema.name_to_metric_group[metric_group].aggregation_strategies or []
)
table = self.create_group_table(
name=metric_group,
title=display_name,
adapter_to_runs=adapter_to_runs,
columns=[(subgroup, metric_group) for subgroup in subgroups],
is_scenario_table=False,
add_win_rate=not self.schema.name_to_metric_group[metric_group].hide_win_rates,
aggregation_strategies=aggregate_strategies,
)
tables.append(table)
return tables
Expand Down
3 changes: 3 additions & 0 deletions src/helm/benchmark/static/schema_safety.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,9 @@ perturbations: []
metric_groups:
- name: accuracy
display_name: Accuracy
aggregation_strategies:
- win_rate
- mean
metrics:
- name: ${main_name}
split: ${main_split}
Expand Down

0 comments on commit fce1e5f

Please sign in to comment.