From a2f2b681cbc7b0438c6113dc29c25580dd4dcb61 Mon Sep 17 00:00:00 2001
From: Yifan Mai
Date: Wed, 16 Aug 2023 10:20:01 -0700
Subject: [PATCH] Remove stats with NaN values (#1784)

---
Reviewer notes (this region is ignored when the patch is applied):

- Adds remove_stats_nans(), which returns a new list omitting every Stat
  whose `sum` is NaN and logs a WARNING via hlog() for each one dropped.
  Only `sum` is inspected; infinite values are not filtered.
- Adds remove_per_instance_stats_nans(), which copies each PerInstanceStats
  with dataclasses.replace() and applies the same filter to its `stats`.
- Runner.run_one() applies the filters only at the point where stats.json
  and per_instance_stats.json are serialized with json.dumps().

 src/helm/benchmark/runner.py | 38 ++++++++++++++++++++++++++++++++++--
 1 file changed, 36 insertions(+), 2 deletions(-)

diff --git a/src/helm/benchmark/runner.py b/src/helm/benchmark/runner.py
index 102d7b0f90..7466d7dfbc 100644
--- a/src/helm/benchmark/runner.py
+++ b/src/helm/benchmark/runner.py
@@ -1,9 +1,11 @@
 import dacite
 import json
+import math
 import os
 import traceback
 import typing
 from collections import Counter
+import dataclasses
 from dataclasses import dataclass, field
 from typing import Any, Dict, List
 
@@ -70,6 +72,37 @@ def __post_init__(self):
         object.__setattr__(self, "name", self.name.replace(os.path.sep, "_"))
 
 
+def remove_stats_nans(stats: List[Stat]) -> List[Stat]:
+    """Return a new list of stats with stats with NaNs removed.
+
+    Python's stdlib json.dumps() will produce invalid JSON when serializing a NaN. See:
+
+    - https://github.com/stanford-crfm/helm/issues/1765
+    - https://bugs.python.org/issue40633
+    - https://docs.python.org/3/library/json.html#infinite-and-nan-number-values"""
+    result: List[Stat] = []
+    for stat in stats:
+        if math.isnan(stat.sum):
+            hlog(f"WARNING: Removing stat {stat.name.name} because its value is NaN")
+            continue
+        result.append(stat)
+    return result
+
+
+def remove_per_instance_stats_nans(per_instance_stats_list: List[PerInstanceStats]) -> List[PerInstanceStats]:
+    """Return a new list of PerInstanceStats with stats with NaNs removed.
+
+    Python's stdlib json.dumps() will produce invalid JSON when serializing a NaN. See:
+
+    - https://github.com/stanford-crfm/helm/issues/1765
+    - https://bugs.python.org/issue40633
+    - https://docs.python.org/3/library/json.html#infinite-and-nan-number-values"""
+    result: List[PerInstanceStats] = []
+    for per_instance_stats in per_instance_stats_list:
+        result.append(dataclasses.replace(per_instance_stats, stats=remove_stats_nans(per_instance_stats.stats)))
+    return result
+
+
 class Runner:
     """
     The main entry point for running the entire benchmark. Mostly just
@@ -257,11 +290,12 @@ def run_one(self, run_spec: RunSpec):
         write(os.path.join(run_path, "scenario_state.json"), json.dumps(asdict_without_nones(scenario_state), indent=2))
 
         write(
-            os.path.join(run_path, "stats.json"), json.dumps([asdict_without_nones(stat) for stat in stats], indent=2)
+            os.path.join(run_path, "stats.json"),
+            json.dumps([asdict_without_nones(stat) for stat in remove_stats_nans(stats)], indent=2),
         )
         write(
             os.path.join(run_path, "per_instance_stats.json"),
-            json.dumps(list(map(asdict_without_nones, per_instance_stats)), indent=2),
+            json.dumps(list(map(asdict_without_nones, remove_per_instance_stats_nans(per_instance_stats))), indent=2),
         )
 
         cache_stats.print_status()