diff --git a/python/ray/train/callbacks/results_preprocessors/__init__.py b/python/ray/train/callbacks/results_preprocessors/__init__.py index 88fd5738c6a3..81bbfb3fb936 100644 --- a/python/ray/train/callbacks/results_preprocessors/__init__.py +++ b/python/ray/train/callbacks/results_preprocessors/__init__.py @@ -2,6 +2,11 @@ from ray.train.callbacks.results_preprocessors.keys import ( ExcludedKeysResultsPreprocessor, ) +from ray.train.callbacks.results_preprocessors.aggregate import ( + AverageResultsPreprocessor, + MaxResultsPreprocessor, + WeightedAverageResultsPreprocessor, +) from ray.train.callbacks.results_preprocessors.preprocessor import ( SequentialResultsPreprocessor, ResultsPreprocessor, @@ -12,4 +17,7 @@ "IndexedResultsPreprocessor", "ResultsPreprocessor", "SequentialResultsPreprocessor", + "AverageResultsPreprocessor", + "MaxResultsPreprocessor", + "WeightedAverageResultsPreprocessor", ] diff --git a/python/ray/train/callbacks/results_preprocessors/aggregate/__init__.py b/python/ray/train/callbacks/results_preprocessors/aggregate/__init__.py new file mode 100644 index 000000000000..e3e125b2fd75 --- /dev/null +++ b/python/ray/train/callbacks/results_preprocessors/aggregate/__init__.py @@ -0,0 +1,11 @@ +from ray.train.callbacks.results_preprocessors.aggregate.aggregate_preprocessor import ( + AverageResultsPreprocessor, + MaxResultsPreprocessor, + WeightedAverageResultsPreprocessor, +) + +__all__ = [ + "AverageResultsPreprocessor", + "MaxResultsPreprocessor", + "WeightedAverageResultsPreprocessor", +] diff --git a/python/ray/train/callbacks/results_preprocessors/aggregate/aggregate_fn.py b/python/ray/train/callbacks/results_preprocessors/aggregate/aggregate_fn.py new file mode 100644 index 000000000000..f22de03c161e --- /dev/null +++ b/python/ray/train/callbacks/results_preprocessors/aggregate/aggregate_fn.py @@ -0,0 +1,100 @@ +import abc +from typing import Dict, List, Union, Optional + +import numpy as np + +from 
ray.train.callbacks.results_preprocessors.aggregate.aggregate_utils import ( + VALID_AGGREGATE_TYPES, + _get_weights_from_results, +) + + +class AggregateFn(abc.ABC): + """An abstract class for aggregation function.""" + + def __call__( + self, values: List[Union[VALID_AGGREGATE_TYPES]] + ) -> Union[VALID_AGGREGATE_TYPES]: + """Perform the aggregation of values when being called. + + Args: + values (List[Union[VALID_AGGREGATE_TYPES]]): A list of + values returned from workers. The length of the list + is expected to be equal to the number of workers. + + Returns: + A single value that should logically be some form of aggregation + of the values from each worker in the ``values`` list. + """ + raise NotImplementedError + + def prepare(self, results: List[Dict]) -> None: + """Perform some preparation work before aggregation. + + Unlike ``__call__``, this method is not called separately + for each metric, but is only called once for preparation + before aggregation begins. Any logic that does not need to + be called for each metric should be placed in this method. + """ + pass + + def wrap_key(self, key) -> str: + """Get a string representation of the aggregation.""" + return str(self) + f"({key})" + + +class Average(AggregateFn): + """Average aggregation class.""" + + def __call__( + self, values: List[Union[VALID_AGGREGATE_TYPES]] + ) -> Union[VALID_AGGREGATE_TYPES]: + # A numpy runtime warning will be thrown if values + # is a list of all ``np.nan``. + return np.nanmean(values) + + def __repr__(self) -> str: + return "avg" + + +class Max(AggregateFn): + """Maximum aggregation class.""" + + def __call__( + self, values: List[Union[VALID_AGGREGATE_TYPES]] + ) -> Union[VALID_AGGREGATE_TYPES]: + # A numpy runtime warning will be thrown if values + # is a list of all ``np.nan``. + return np.nanmax(values) + + def __repr__(self) -> str: + return "max" + + +class WeightedAverage(AggregateFn): + """Weighted average aggregation class. 
+ + Args: + weight_key (Optional[str]): A key string that specifies + the average weight to be used. If it is None, then + equal weight will be used. + """ + + def __init__(self, weight_key: Optional[str] = None): + self.weight_key = weight_key + self.weights = None + + def __call__( + self, values: List[Union[VALID_AGGREGATE_TYPES]] + ) -> Union[VALID_AGGREGATE_TYPES]: + return np.nansum( + np.array(values) + * self.weights + / np.nansum(self.weights * (1 - np.isnan(values))) + ) + + def __repr__(self) -> str: + return f"weight_avg_{self.weight_key}" + + def prepare(self, results: List[Dict]): + self.weights = _get_weights_from_results(self.weight_key, results) diff --git a/python/ray/train/callbacks/results_preprocessors/aggregate/aggregate_preprocessor.py b/python/ray/train/callbacks/results_preprocessors/aggregate/aggregate_preprocessor.py new file mode 100644 index 000000000000..bd95a2100a1b --- /dev/null +++ b/python/ray/train/callbacks/results_preprocessors/aggregate/aggregate_preprocessor.py @@ -0,0 +1,154 @@ +import logging +from typing import Dict, List, Optional + +from ray.util.annotations import DeveloperAPI +from ray.train.callbacks.results_preprocessors.preprocessor import ResultsPreprocessor +from ray.train.callbacks.results_preprocessors.aggregate.aggregate_fn import ( + AggregateFn, + Average, + Max, + WeightedAverage, +) +from ray.train.callbacks.results_preprocessors.aggregate.aggregate_utils import ( + _get_metrics_from_results, +) + +logger = logging.getLogger(__name__) + + +@DeveloperAPI +class AggregateResultsPreprocessor(ResultsPreprocessor): + """A preprocessor that aggregates training metrics from all workers. + + Args: + aggregation_fn (AggregateFn): + An aggregation method that performs the aggregation on results. + keys (Optional[List[str]]): + A list of keys reported in results to be aggregated. Keys should + be saved using ``train.report()``. If a key is invalid or not + reported, it will be ignored. 
+ """ + + def __init__(self, aggregation_fn: AggregateFn, keys: Optional[List[str]] = None): + self.aggregate_fn = aggregation_fn + self.keys = keys + + def preprocess(self, results: Optional[List[Dict]] = None) -> Optional[List[Dict]]: + """Aggregate results before sending them to callbacks. + + A key will be ignored if one of the following occurs: + 1. No worker reports it. + 2. The values returned from all workers are invalid. + The aggregation WILL be performed even if some but not all + workers report the key with valid values. The aggregation + will only be applied to those that report the key. + + Args: + results (Optional[List[Dict]]): A list of results + from all workers. The metrics specified in ``keys`` + will be aggregated according to ``aggregation_fn``. + + Returns: + An updated results list that has aggregated results and + is of the same length as the input list. + """ + if results is None or len(results) == 0: + return results + + self.aggregate_fn.prepare(results) + + keys_to_aggregate = ( + self.keys + if self.keys + else {key for result in results for key in result.keys()} + ) + + aggregated_results = {} + + for key in keys_to_aggregate: + values = _get_metrics_from_results(key, results) + if values: + aggregated_results[self.aggregate_fn.wrap_key(key)] = self.aggregate_fn( + values + ) + + # Currently we directly update each result dict with aggregated results. + for result in results: + result.update(aggregated_results) + + return results + + +class AverageResultsPreprocessor(AggregateResultsPreprocessor): + """A preprocessor that averages results with equal weight. + + .. code-block:: python + + preprocessor = AverageResultsPreprocessor(keys=["loss", "accuracy"]) + update_results = preprocessor.preprocess(results) + + + Args: + keys (Optional[List[str]]): A list of metrics to be averaged. + If None is specified, then the list will be populated by + reported keys whose value type is valid, that is, one of + ``VALID_AGGREGATE_TYPES``. 
+ + Returns: + An updated results list with average values. + """ + + def __init__(self, keys: Optional[List[str]] = None): + super().__init__(Average(), keys) + + +class MaxResultsPreprocessor(AggregateResultsPreprocessor): + """A preprocessor that computes maximum values of specified keys. + + .. code-block:: python + + preprocessor = MaxResultsPreprocessor(keys=["loss", "accuracy"]) + update_results = preprocessor.preprocess(results) + + + Args: + keys (Optional[List[str]]): A list of metrics upon which the + maximum value will be taken. If None is specified, then + the list will be populated by reported keys whose value type + is valid, that is, one of ``VALID_AGGREGATE_TYPES``. + + Returns: + An updated results list with maximum values. + """ + + def __init__(self, keys: Optional[List[str]] = None): + super().__init__(Max(), keys) + + +class WeightedAverageResultsPreprocessor(AggregateResultsPreprocessor): + """A preprocessor that performs weighted average over metrics. + + + .. code-block:: python + + preprocessor = WeightedAverageResultsPreprocessor(keys=["loss", "accuracy"], + weight_key="batch_size") + update_results = preprocessor.preprocess(results) + + Args: + keys (Optional[List[str]]): A list of metrics to be averaged. + If None is specified, then the list will be populated by + reported keys whose value type is valid, that is, one of + ``VALID_AGGREGATE_TYPES``. + weight_key (Optional[str]): A key from reported metrics that + will be used as the weight in averaging. If None is specified, + then equal weight will be used. + + Returns: + An updated results list with weighted average results. 
+ """ + + def __init__( + self, keys: Optional[List[str]] = None, weight_key: Optional[str] = None + ): + super().__init__(WeightedAverage(weight_key), keys) diff --git a/python/ray/train/callbacks/results_preprocessors/aggregate/aggregate_utils.py b/python/ray/train/callbacks/results_preprocessors/aggregate/aggregate_utils.py new file mode 100644 index 000000000000..25b1c4abf0a1 --- /dev/null +++ b/python/ray/train/callbacks/results_preprocessors/aggregate/aggregate_utils.py @@ -0,0 +1,144 @@ +import logging +from typing import Dict, List, Optional, Tuple, Union + +import numpy as np + +from ray.util.debug import log_once + +VALID_AGGREGATE_TYPES: Tuple[type] = ( + int, + float, + np.float32, + np.float64, + np.int32, + np.int64, +) + +logger = logging.getLogger(__name__) + + +def _check_if_key_is_reported(key: str, results: List[Dict]) -> bool: + """Check if a particular key is reported by some workers. + + Args: + key (str): A key string. + results (List[Dict]): The results list returned from workers. + + Returns: + A boolean value. True if ``key`` exists in some worker's result dict. + Otherwise, False. + """ + return key in {key for result in results for key in result.keys()} + + +def _check_if_any_value_is_valid(key: str, results: List[Dict]) -> bool: + """Check if some values of ``key`` are valid types. + + Args: + key (str): A key string. + results (List[Dict]): The results list returned from workers. + + Returns: + A boolean value. True if some values of ``key`` are one of + ``VALID_AGGREGATE_TYPES``. Otherwise, False. + """ + values = [result.get(key, np.nan) for result in results] + return any(isinstance(value, VALID_AGGREGATE_TYPES) for value in values) + + +def _get_valid_values_from_results( + key: str, results: List[Dict] +) -> List[Union[VALID_AGGREGATE_TYPES]]: + """Get the list of values specified by ``key``. + + Args: + key (str): A key string. + results (List[Dict]): The results list returned from workers. 
+ + Returns: + A list of values specified by ``key``. Invalid values are + replaced by ``np.nan``. This should be called after + ``_check_if_any_value_is_valid()`` returns True. + """ + return [ + ( + result.get(key, np.nan) + if isinstance(result.get(key, np.nan), VALID_AGGREGATE_TYPES) + else np.nan + ) + for result in results + ] + + +def _get_metrics_from_results( + key: str, results: List[Dict] +) -> Optional[List[Union[VALID_AGGREGATE_TYPES]]]: + """Return the metric values specified by ``key`` from each worker's result dict. + + Args: + key (str): A key string that specifies the metric. + results (List[Dict]): The results list returned from workers. + + Returns: + A list of values for ``key`` from each worker. If ``key`` is + missing in every result dict, or if ``key`` is not a valid + type in each result dict, then it will return None. If some + workers report valid ``key`` values but others don't, a list + of values will still be returned and invalid values are + replaced by ``np.nan``. + """ + warning_message = None + if not _check_if_key_is_reported(key, results): + warning_message = ( + f"`{key}` is not reported from workers, so it is ignored. " + "Please make sure that it is saved using `train.report()`." + ) + elif not _check_if_any_value_is_valid(key, results): + warning_message = ( + f"`{key}` value type is not valid, so it is ignored. " + f"Make sure that its type is one of {VALID_AGGREGATE_TYPES}. " + ) + + if warning_message: + if log_once(key): + logger.warning(warning_message) + return None + + return _get_valid_values_from_results(key, results) + + +def _get_weights_from_results( + key: str, results: List[Dict] +) -> List[Union[VALID_AGGREGATE_TYPES]]: + """Return weight values specified by ``key`` from all workers. + + Args: + key (str): A key string that specifies the weight metric. + results (List[Dict]): The results list returned from workers. + + Returns: + A list of valid weight values from each worker, if key exists. 
+ Invalid values are replaced by ``np.nan`` and will be ignored + in the subsequent weighted average aggregation. If ``key`` + doesn't exist in every single result or its value from every + single worker is invalid, then equal weight will be used. + That is, a list of all ones. + """ + warning_message = None + if not _check_if_key_is_reported(key, results): + warning_message = ( + f"Averaging weight `{key}` is not reported " + "by all workers in `train.report()`. " + ) + elif not _check_if_any_value_is_valid(key, results): + warning_message = ( + f"Averaging weight `{key}` value type is not valid. " + f"Make sure that its type is one of {VALID_AGGREGATE_TYPES}. " + ) + + if warning_message: + if log_once(key): + logger.warning(warning_message + "Use equal weight instead.") + return [1] * len(results) + + return _get_valid_values_from_results(key, results) diff --git a/python/ray/train/tests/test_results_preprocessors.py b/python/ray/train/tests/test_results_preprocessors.py index 7cef2155b90b..77736ec9682c 100644 --- a/python/ray/train/tests/test_results_preprocessors.py +++ b/python/ray/train/tests/test_results_preprocessors.py @@ -1,7 +1,12 @@ +import pytest + from ray.train.callbacks.results_preprocessors import ( ExcludedKeysResultsPreprocessor, IndexedResultsPreprocessor, SequentialResultsPreprocessor, + AverageResultsPreprocessor, + MaxResultsPreprocessor, + WeightedAverageResultsPreprocessor, ) @@ -39,6 +44,172 @@ def test_sequential_results_preprocessor(): assert preprocessed_results == expected +def test_average_results_preprocessor(): + from copy import deepcopy + import numpy as np + + results = [{"a": 1, "b": 2}, {"a": 3, "b": 4}, {"a": 5, "b": 6}, {"a": 7, "b": 8}] + expected = deepcopy(results) + for res in expected: + res.update( + { + "avg(a)": np.mean([result["a"] for result in results]), + "avg(b)": np.mean([result["b"] for result in results]), + } + ) + + preprocessor = AverageResultsPreprocessor(["a", "b"]) + preprocessed_results = 
preprocessor.preprocess(results) + + assert preprocessed_results == expected + + +def test_max_results_preprocessor(): + from copy import deepcopy + import numpy as np + + results = [{"a": 1, "b": 2}, {"a": 3, "b": 4}, {"a": 5, "b": 6}, {"a": 7, "b": 8}] + expected = deepcopy(results) + for res in expected: + res.update( + { + "max(a)": np.max([result["a"] for result in results]), + "max(b)": np.max([result["b"] for result in results]), + } + ) + + preprocessor = MaxResultsPreprocessor(["a", "b"]) + preprocessed_results = preprocessor.preprocess(results) + + assert preprocessed_results == expected + + +def test_weighted_average_results_preprocessor(): + from copy import deepcopy + import numpy as np + + results = [{"a": 1, "b": 2}, {"a": 3, "b": 4}, {"a": 5, "b": 6}, {"a": 7, "b": 8}] + expected = deepcopy(results) + total_weight = np.sum([result["b"] for result in results]) + for res in expected: + res.update( + { + "weight_avg_b(a)": np.sum( + [result["a"] * result["b"] / total_weight for result in results] + ) + } + ) + + preprocessor = WeightedAverageResultsPreprocessor(["a"], "b") + preprocessed_results = preprocessor.preprocess(results) + + assert preprocessed_results == expected + + +@pytest.mark.parametrize( + ("results_preprocessor", "expected_value"), + [(AverageResultsPreprocessor, 2.0), (MaxResultsPreprocessor, 3.0)], +) +def test_warning_in_aggregate_results_preprocessors( + caplog, results_preprocessor, expected_value +): + import logging + from copy import deepcopy + from ray.util import debug + + caplog.at_level(logging.WARNING) + + results1 = [{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}] + results2 = [{"a": 1}, {"a": "invalid"}, {"a": 3}, {"a": "invalid"}] + results3 = [{"a": "invalid"}, {"a": "invalid"}, {"a": "invalid"}, {"a": "invalid"}] + results4 = [{"a": 1}, {"a": 2}, {"a": 3}, {"c": 4}] + + # test case 1: metric key `b` is missing from all workers + results_preprocessor1 = results_preprocessor(["b"]) + results_preprocessor1.preprocess(results1) + 
assert "`b` is not reported from workers, so it is ignored." in caplog.text + + # test case 2: some values of key `a` have invalid data type + results_preprocessor2 = results_preprocessor(["a"]) + expected2 = deepcopy(results2) + aggregation_key = results_preprocessor2.aggregate_fn.wrap_key("a") + for res in expected2: + res.update({aggregation_key: expected_value}) + assert results_preprocessor2.preprocess(results2) == expected2 + + # test case 3: all key `a` values are invalid + results_preprocessor2.preprocess(results3) + assert "`a` value type is not valid, so it is ignored." in caplog.text + + # test case 4: some workers don't report key `a` + expected4 = deepcopy(results4) + aggregation_key = results_preprocessor2.aggregate_fn.wrap_key("a") + for res in expected4: + res.update({aggregation_key: expected_value}) + assert results_preprocessor2.preprocess(results4) == expected4 + + for record in caplog.records: + assert record.levelname == "WARNING" + + debug.reset_log_once("b") + debug.reset_log_once("a") + + +def test_warning_in_weighted_average_results_preprocessors(caplog): + import logging + from copy import deepcopy + + caplog.at_level(logging.WARNING) + + results1 = [{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}] + results2 = [{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}] + results3 = [ + {"a": 1, "c": 3}, + {"a": 2, "c": "invalid"}, + {"a": "invalid", "c": 1}, + {"a": 4, "c": "invalid"}, + ] + results4 = [ + {"a": 1, "c": "invalid"}, + {"a": 2, "c": "invalid"}, + {"a": 3, "c": "invalid"}, + {"a": 4, "c": "invalid"}, + ] + + # test case 1: weight key `b` is not reported from all workers + results_preprocessor1 = WeightedAverageResultsPreprocessor(["a"], "b") + expected1 = deepcopy(results1) + for res in expected1: + res.update({"weight_avg_b(a)": 2.5}) + assert results_preprocessor1.preprocess(results1) == expected1 + assert ( + "Averaging weight `b` is not reported by all workers in `train.report()`." + in caplog.text + ) + assert "Use equal weight instead." 
in caplog.text + + # test case 2: metric key `a` (to be averaged) is not reported from all workers + results_preprocessor1.preprocess(results2) + assert "`a` is not reported from workers, so it is ignored." in caplog.text + + # test case 3: both metric and weight keys have invalid data type + results_preprocessor2 = WeightedAverageResultsPreprocessor(["a"], "c") + expected3 = deepcopy(results3) + for res in expected3: + res.update({"weight_avg_c(a)": 1.0}) + assert results_preprocessor2.preprocess(results3) == expected3 + + # test case 4: all weight values are invalid + expected4 = deepcopy(results4) + for res in expected4: + res.update({"weight_avg_c(a)": 2.5}) + assert results_preprocessor2.preprocess(results4) == expected4 + assert "Averaging weight `c` value type is not valid." in caplog.text + + for record in caplog.records: + assert record.levelname == "WARNING" + + if __name__ == "__main__": import pytest import sys