Skip to content

Commit

Permalink
handle nan+missing metrics
Browse files Browse the repository at this point in the history
  • Loading branch information
jwyyy committed Feb 13, 2022
1 parent b3fa427 commit 3b14664
Showing 1 changed file with 17 additions and 9 deletions.
26 changes: 17 additions & 9 deletions python/ray/train/callbacks/results_preprocessors/average.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
import logging
from typing import Dict, List, Tuple

import numpy as np

from ray.train.callbacks.results_preprocessors.preprocessor import ResultsPreprocessor


logger = logging.getLogger(__name__)


class AverageResultsPreprocessor(ResultsPreprocessor):
"""A preprocessor that average training metrics from all workers.
Expand All @@ -13,7 +17,7 @@ class AverageResultsPreprocessor(ResultsPreprocessor):
The key is the metric to be averaged across all workers. The value
is the magic key in the results that will be used as weights. If
the value is None, the weight will be taken to be equal for all
workers.
workers. Both key and value should be reported using `train.report()`.
"""

VALID_SUMMARY_TYPES: Tuple[type] = (
Expand Down Expand Up @@ -44,31 +48,35 @@ def preprocess(self, results: List[Dict] = None) -> List[Dict]:
if len(results) == 0 or len(self.metrics_to_average) == 0:
return results

reported_metrics = set(results[0].keys())
average_metrics = {}
for metrics, weight in self.metrics_to_average.items():

if not isinstance(results[0][metrics], self.VALID_SUMMARY_TYPES):
if metrics not in reported_metrics:
logger.warning(
f"`{metrics}` is not reported from workers, so it is ignored. "
"Please make sure that it is saved using `train.report()`."
)
continue

metrics_from_workers = np.array(
[result[metrics] for result in results if not np.isnan(result[metrics])]
[result.get(metrics, np.nan) for result in results]
)
if weight:
weights_from_workers = np.array(
[
result[weight]
for result in results
if not np.isnan(result[metrics])
]
[result.get(weight, np.nan) for result in results]
)
else:
# if no weight is provided, equal weight will be used.
logger.warning(
f"No weight is provided for `{metrics}`. Use equal weight instead."
)
weights_from_workers = np.array([1] * len(metrics_from_workers))

average_metrics["_average_" + metrics] = np.nanmean(
metrics_from_workers
* weights_from_workers
/ np.sum(weights_from_workers)
/ np.nansum(weights_from_workers)
)

for result in results:
Expand Down

0 comments on commit 3b14664

Please sign in to comment.