Skip to content

Commit

Permalink
Remove stats with NaN values (#1784)
Browse files Browse the repository at this point in the history
  • Loading branch information
yifanmai authored Aug 16, 2023
1 parent 5f9794c commit a2f2b68
Showing 1 changed file with 36 additions and 2 deletions.
38 changes: 36 additions & 2 deletions src/helm/benchmark/runner.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import dacite
import json
import math
import os
import traceback
import typing
from collections import Counter
import dataclasses
from dataclasses import dataclass, field
from typing import Any, Dict, List

Expand Down Expand Up @@ -70,6 +72,37 @@ def __post_init__(self):
object.__setattr__(self, "name", self.name.replace(os.path.sep, "_"))


def remove_stats_nans(stats: List[Stat]) -> List[Stat]:
"""Return a new list of stats with stats with NaNs removed.
Python's stdlib json.dumps() will produce invalid JSON when serializing a NaN. See:
- https://github.com/stanford-crfm/helm/issues/1765
- https://bugs.python.org/issue40633
- https://docs.python.org/3/library/json.html#infinite-and-nan-number-values"""
result: List[Stat] = []
for stat in stats:
if math.isnan(stat.sum):
hlog(f"WARNING: Removing stat {stat.name.name} because its value is NaN")
continue
result.append(stat)
return result


def remove_per_instance_stats_nans(per_instance_stats_list: List[PerInstanceStats]) -> List[PerInstanceStats]:
"""Return a new list of PerInstanceStats with stats with NaNs removed.
Python's stdlib json.dumps() will produce invalid JSON when serializing a NaN. See:
- https://github.com/stanford-crfm/helm/issues/1765
- https://bugs.python.org/issue40633
- https://docs.python.org/3/library/json.html#infinite-and-nan-number-values"""
result: List[PerInstanceStats] = []
for per_instance_stats in per_instance_stats_list:
result.append(dataclasses.replace(per_instance_stats, stats=remove_stats_nans(per_instance_stats.stats)))
return result


class Runner:
"""
The main entry point for running the entire benchmark. Mostly just
Expand Down Expand Up @@ -257,11 +290,12 @@ def run_one(self, run_spec: RunSpec):
write(os.path.join(run_path, "scenario_state.json"), json.dumps(asdict_without_nones(scenario_state), indent=2))

write(
os.path.join(run_path, "stats.json"), json.dumps([asdict_without_nones(stat) for stat in stats], indent=2)
os.path.join(run_path, "stats.json"),
json.dumps([asdict_without_nones(stat) for stat in remove_stats_nans(stats)], indent=2),
)
write(
os.path.join(run_path, "per_instance_stats.json"),
json.dumps(list(map(asdict_without_nones, per_instance_stats)), indent=2),
json.dumps(list(map(asdict_without_nones, remove_per_instance_stats_nans(per_instance_stats))), indent=2),
)

cache_stats.print_status()

0 comments on commit a2f2b68

Please sign in to comment.