From 3b1466420548d01cdb309729240f1e234048a8f7 Mon Sep 17 00:00:00 2001 From: Junwen Date: Sat, 12 Feb 2022 20:54:29 -0800 Subject: [PATCH] handle nan+missing metrics --- .../results_preprocessors/average.py | 26 ++++++++++++------- 1 file changed, 17 insertions(+), 9 deletions(-) diff --git a/python/ray/train/callbacks/results_preprocessors/average.py b/python/ray/train/callbacks/results_preprocessors/average.py index 449c9e1f64c4..68b8176c51d8 100644 --- a/python/ray/train/callbacks/results_preprocessors/average.py +++ b/python/ray/train/callbacks/results_preprocessors/average.py @@ -1,3 +1,4 @@ +import logging from typing import Dict, List, Tuple import numpy as np @@ -5,6 +6,9 @@ from ray.train.callbacks.results_preprocessors.preprocessor import ResultsPreprocessor +logger = logging.getLogger(__name__) + + class AverageResultsPreprocessor(ResultsPreprocessor): """A preprocessor that average training metrics from all workers. @@ -13,7 +17,7 @@ class AverageResultsPreprocessor(ResultsPreprocessor): The key is the metric to be averaged across all workers. The value is the magic key in the results that will be used as weights. If the value is None, the weight will be taken to be equal for all - workers. + workers. Both key and value should be reported using `train.report()`. """ VALID_SUMMARY_TYPES: Tuple[type] = ( @@ -44,31 +48,35 @@ def preprocess(self, results: List[Dict] = None) -> List[Dict]: if len(results) == 0 or len(self.metrics_to_average) == 0: return results + reported_metrics = set(results[0].keys()) average_metrics = {} for metrics, weight in self.metrics_to_average.items(): - if not isinstance(results[0][metrics], self.VALID_SUMMARY_TYPES): + if metrics not in reported_metrics: + logger.warning( + f"`{metrics}` is not reported from workers, so it is ignored. " + "Please make sure that it is saved using `train.report()`." + ) continue metrics_from_workers = np.array( - [result[metrics] for result in results if not np.isnan(result[metrics])] + [result.get(metrics, np.nan) for result in results] ) if weight: weights_from_workers = np.array( - [ - result[weight] - for result in results - if not np.isnan(result[metrics]) - ] + [result.get(weight, np.nan) for result in results] ) else: # if no weight is provided, equal weight will be used. + logger.warning( + f"No weight is provided for `{metrics}`. Use equal weight instead." + ) weights_from_workers = np.array([1] * len(metrics_from_workers)) average_metrics["_average_" + metrics] = np.nanmean( metrics_from_workers * weights_from_workers - / np.sum(weights_from_workers) + / np.nansum(weights_from_workers) ) for result in results: