-
Notifications
You must be signed in to change notification settings - Fork 5.8k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Train] Add support for metrics aggregation (#22099)
This PR allows users to aggregate metrics returned from all workers.
- Loading branch information
Showing
6 changed files
with
588 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
11 changes: 11 additions & 0 deletions
11
python/ray/train/callbacks/results_preprocessors/aggregate/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
from ray.train.callbacks.results_preprocessors.aggregate.aggregate_preprocessor import ( | ||
AverageResultsPreprocessor, | ||
MaxResultsPreprocessor, | ||
WeightedAverageResultsPreprocessor, | ||
) | ||
|
||
__all__ = [ | ||
"AverageResultsPreprocessor", | ||
"MaxResultsPreprocessor", | ||
"WeightedAverageResultsPreprocessor", | ||
] |
100 changes: 100 additions & 0 deletions
100
python/ray/train/callbacks/results_preprocessors/aggregate/aggregate_fn.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
import abc | ||
from typing import Dict, List, Union, Optional | ||
|
||
import numpy as np | ||
|
||
from ray.train.callbacks.results_preprocessors.aggregate.aggregate_utils import ( | ||
VALID_AGGREGATE_TYPES, | ||
_get_weights_from_results, | ||
) | ||
|
||
|
||
class AggregateFn(abc.ABC): | ||
"""An abstract class for aggregation function.""" | ||
|
||
def __call__( | ||
self, values: List[Union[VALID_AGGREGATE_TYPES]] | ||
) -> Union[VALID_AGGREGATE_TYPES]: | ||
"""Perform the aggregation of values when being called. | ||
Args: | ||
values (List[Union[VALID_AGGREGATE_TYPES]]): A list of | ||
values returned from workers. The length of the list | ||
is expected to be equal to the number of workers. | ||
Returns: | ||
A single value that should logically be some form of aggregation | ||
of the values from each worker in the ``values`` list. | ||
""" | ||
raise NotImplementedError | ||
|
||
def prepare(self, results: List[Dict]) -> None: | ||
"""Perform some preparation work before aggregation. | ||
Unlike ``__call__``, this method is not called separately | ||
for each metric, but is only called once for preparation | ||
before aggregation begins. Any logic that does not need to | ||
be called for each metric should be placed in this method. | ||
""" | ||
pass | ||
|
||
def wrap_key(self, key) -> str: | ||
"""Get a string representation of the aggregation.""" | ||
return str(self) + f"({key})" | ||
|
||
|
||
class Average(AggregateFn): | ||
"""Average aggregation class.""" | ||
|
||
def __call__( | ||
self, values: List[Union[VALID_AGGREGATE_TYPES]] | ||
) -> Union[VALID_AGGREGATE_TYPES]: | ||
# A numpy runtime warning will be thrown if values | ||
# is a list of all ``np.nan``. | ||
return np.nanmean(values) | ||
|
||
def __repr__(self) -> str: | ||
return "avg" | ||
|
||
|
||
class Max(AggregateFn): | ||
"""Maximum aggregation class.""" | ||
|
||
def __call__( | ||
self, values: List[Union[VALID_AGGREGATE_TYPES]] | ||
) -> Union[VALID_AGGREGATE_TYPES]: | ||
# A numpy runtime warning will be thrown if values | ||
# is a list of all ``np.nan``. | ||
return np.nanmax(values) | ||
|
||
def __repr__(self) -> str: | ||
return "max" | ||
|
||
|
||
class WeightedAverage(AggregateFn): | ||
"""Weighted average aggregation class. | ||
Args: | ||
weight_key (Optional[str]): A key string that specifies | ||
the average weight to be used. If it is None, then | ||
equal weight will be used. | ||
""" | ||
|
||
def __init__(self, weight_key: Optional[str] = None): | ||
self.weight_key = weight_key | ||
self.weights = None | ||
|
||
def __call__( | ||
self, values: List[Union[VALID_AGGREGATE_TYPES]] | ||
) -> Union[VALID_AGGREGATE_TYPES]: | ||
return np.nansum( | ||
np.array(values) | ||
* self.weights | ||
/ np.nansum(self.weights * (1 - np.isnan(values))) | ||
) | ||
|
||
def __repr__(self) -> str: | ||
return f"weight_avg_{self.weight_key}" | ||
|
||
def prepare(self, results: List[Dict]): | ||
self.weights = _get_weights_from_results(self.weight_key, results) |
154 changes: 154 additions & 0 deletions
154
python/ray/train/callbacks/results_preprocessors/aggregate/aggregate_preprocessor.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,154 @@ | ||
import logging | ||
from typing import Dict, List, Optional | ||
|
||
from ray.util.annotations import DeveloperAPI | ||
from ray.train.callbacks.results_preprocessors.preprocessor import ResultsPreprocessor | ||
from ray.train.callbacks.results_preprocessors.aggregate.aggregate_fn import ( | ||
AggregateFn, | ||
Average, | ||
Max, | ||
WeightedAverage, | ||
) | ||
from ray.train.callbacks.results_preprocessors.aggregate.aggregate_utils import ( | ||
_get_metrics_from_results, | ||
) | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
@DeveloperAPI | ||
class AggregateResultsPreprocessor(ResultsPreprocessor): | ||
"""A preprocessor that aggregates training metrics from all workers. | ||
Args: | ||
aggregation_fn (AggregateFn): | ||
An aggregation method that performs the aggregation on results. | ||
keys (Optional[List[str]]): | ||
A list of keys reported in results to be aggregated. Keys should | ||
be saved using ``train.report()``. If a key is invalid or not | ||
reported, it will be ignored. | ||
""" | ||
|
||
def __init__(self, aggregation_fn: AggregateFn, keys: Optional[List[str]] = None): | ||
self.aggregate_fn = aggregation_fn | ||
self.keys = keys | ||
|
||
def preprocess(self, results: Optional[List[Dict]] = None) -> Optional[List[Dict]]: | ||
"""Aggregate results before sending them to callbacks. | ||
A key will be ignored if one of the following occurs: | ||
1. No worker reports it. | ||
2. The values returned from all workers are invalid. | ||
The aggregation WILL be performed even if some but not all | ||
workers report the key with valid values. The aggregation | ||
will only applied to those that report the key. | ||
Args: | ||
results (Optional[List[Dict]]): A list of results | ||
from all workers. The metrics specified in ``keys`` | ||
will be averaged according to ``aggregation_fn``. | ||
Returns: | ||
An updated results list that has aggregated results and | ||
is of the same length as the input list. | ||
""" | ||
if results is None or len(results) == 0: | ||
return results | ||
|
||
self.aggregate_fn.prepare(results) | ||
|
||
keys_to_aggregate = ( | ||
self.keys | ||
if self.keys | ||
else {key for result in results for key in result.keys()} | ||
) | ||
|
||
aggregated_results = {} | ||
|
||
for key in keys_to_aggregate: | ||
values = _get_metrics_from_results(key, results) | ||
if values: | ||
aggregated_results[self.aggregate_fn.wrap_key(key)] = self.aggregate_fn( | ||
values | ||
) | ||
|
||
# Currently we directly update each result dict with aggregated results. | ||
for result in results: | ||
result.update(aggregated_results) | ||
|
||
return results | ||
|
||
|
||
class AverageResultsPreprocessor(AggregateResultsPreprocessor): | ||
"""A preprocessor that averages results with equal weight. | ||
.. code-block:: python | ||
preprocessor = AverageResultsPreprocessor(keys=["loss", "accuracy"]) | ||
update_results = preprocessor.preprocess(results) | ||
Args: | ||
keys (Optional[List[str]]): A list of metrics to be averaged. | ||
If None is specified, then the list will be populated by | ||
reported keys whose value type is valid, that is, one of | ||
``VALID_AGGREGATE_TYPES``. | ||
Returns: | ||
An updated results list with average values. | ||
""" | ||
|
||
def __init__(self, keys: Optional[List[str]] = None): | ||
super().__init__(Average(), keys) | ||
|
||
|
||
class MaxResultsPreprocessor(AggregateResultsPreprocessor): | ||
"""A preprocessor that computes maximum values of specified keys. | ||
.. code-block:: python | ||
preprocessor = MaxResultsPreprocessor(keys=["loss", "accuracy"]) | ||
update_results = preprocessor.preprocess(results) | ||
Args: | ||
keys (Optional[List[str]]): A list of metrics upon which the | ||
maximum value will be taken. If None is specified, then | ||
the list will be populated by reported keys whose value type | ||
is valid, that is, one of ``VALID_AGGREGATE_TYPES``. | ||
Returns: | ||
An updated results list with maximum values. | ||
""" | ||
|
||
def __init__(self, keys: Optional[List[str]] = None): | ||
super().__init__(Max(), keys) | ||
|
||
|
||
class WeightedAverageResultsPreprocessor(AggregateResultsPreprocessor): | ||
"""A preprocessor that performs weighted average over metrics. | ||
.. code-block:: python | ||
preprocessor = WeightedAverageResultsPreprocessor(keys=["loss", "accuracy"], | ||
weight_key="batch_size") | ||
update_results = preprocessor.preprocess(results) | ||
Args: | ||
keys (Optional[List[str]]): A list of metrics to be averaged. | ||
If None is specified, then the list will be populated by | ||
reported keys whose value type is valid, that is, one of | ||
``VALID_AGGREGATE_TYPES``. | ||
weight_key (Optional[str]): A a key from reported metrics that | ||
will be used as the weight in averaging. If None is specified, | ||
then equal weight will be used. | ||
Returns: | ||
An updated results list with weighted average results. | ||
""" | ||
|
||
def __init__( | ||
self, keys: Optional[List[str]] = None, weight_key: Optional[str] = None | ||
): | ||
super().__init__(WeightedAverage(weight_key), keys) |
Oops, something went wrong.