Skip to content

Commit

Permalink
[Train] Add support for metrics aggregation (#22099)
Browse files Browse the repository at this point in the history
This PR allows users to aggregate metrics returned from all workers.
  • Loading branch information
jwyyy authored Mar 8, 2022
1 parent c8aa6cd commit d1009c8
Show file tree
Hide file tree
Showing 6 changed files with 588 additions and 0 deletions.
8 changes: 8 additions & 0 deletions python/ray/train/callbacks/results_preprocessors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,11 @@
from ray.train.callbacks.results_preprocessors.keys import (
ExcludedKeysResultsPreprocessor,
)
from ray.train.callbacks.results_preprocessors.aggregate import (
AverageResultsPreprocessor,
MaxResultsPreprocessor,
WeightedAverageResultsPreprocessor,
)
from ray.train.callbacks.results_preprocessors.preprocessor import (
SequentialResultsPreprocessor,
ResultsPreprocessor,
Expand All @@ -12,4 +17,7 @@
"IndexedResultsPreprocessor",
"ResultsPreprocessor",
"SequentialResultsPreprocessor",
"AverageResultsPreprocessor",
"MaxResultsPreprocessor",
"WeightedAverageResultsPreprocessor",
]
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from ray.train.callbacks.results_preprocessors.aggregate.aggregate_preprocessor import (
AverageResultsPreprocessor,
MaxResultsPreprocessor,
WeightedAverageResultsPreprocessor,
)

# Names exported by this module for ``from ... import *``; these are the three
# preprocessor classes imported above from ``aggregate_preprocessor``.
__all__ = [
    "AverageResultsPreprocessor",
    "MaxResultsPreprocessor",
    "WeightedAverageResultsPreprocessor",
]
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
import abc
from typing import Dict, List, Union, Optional

import numpy as np

from ray.train.callbacks.results_preprocessors.aggregate.aggregate_utils import (
VALID_AGGREGATE_TYPES,
_get_weights_from_results,
)


class AggregateFn(abc.ABC):
    """Base class for callables that reduce per-worker values to one value.

    Subclasses override ``__call__`` with the actual reduction and may
    override ``prepare`` for one-off setup. ``wrap_key`` builds its label
    from ``str(self)``, so subclasses that want a readable label define
    ``__repr__`` (or ``__str__``).
    """

    def __call__(
        self, values: List[Union[VALID_AGGREGATE_TYPES]]
    ) -> Union[VALID_AGGREGATE_TYPES]:
        """Reduce the values reported by the workers to a single value.

        Args:
            values (List[Union[VALID_AGGREGATE_TYPES]]): Values returned
                from workers. The list is expected to contain one entry
                per worker.

        Returns:
            One value that logically represents some aggregation of the
            entries in ``values``.

        Raises:
            NotImplementedError: Always, in this base class; subclasses
                provide the implementation.
        """
        raise NotImplementedError

    def prepare(self, results: List[Dict]) -> None:
        """Hook for setup work done once before any metric is aggregated.

        ``__call__`` runs once per metric, whereas this method is meant to
        run a single time per batch of results. Logic that does not have to
        be repeated for every metric belongs here. The base implementation
        does nothing.
        """

    def wrap_key(self, key) -> str:
        """Return the label for an aggregated metric, e.g. ``"<fn>(<key>)"``."""
        return f"{self!s}({key})"


class Average(AggregateFn):
    """Aggregation that takes the arithmetic mean, skipping ``np.nan``."""

    def __call__(
        self, values: List[Union[VALID_AGGREGATE_TYPES]]
    ) -> Union[VALID_AGGREGATE_TYPES]:
        """Return ``np.nanmean`` of ``values``.

        When every entry of ``values`` is ``np.nan``, numpy emits a
        runtime warning.
        """
        mean_value = np.nanmean(values)
        return mean_value

    def __repr__(self) -> str:
        """Return the short name used by ``wrap_key`` to label results."""
        return "avg"


class Max(AggregateFn):
    """Aggregation that takes the largest value, skipping ``np.nan``."""

    def __call__(
        self, values: List[Union[VALID_AGGREGATE_TYPES]]
    ) -> Union[VALID_AGGREGATE_TYPES]:
        """Return ``np.nanmax`` of ``values``.

        When every entry of ``values`` is ``np.nan``, numpy emits a
        runtime warning.
        """
        largest = np.nanmax(values)
        return largest

    def __repr__(self) -> str:
        """Return the short name used by ``wrap_key`` to label results."""
        return "max"


class WeightedAverage(AggregateFn):
    """Aggregation that computes a weighted mean, skipping ``np.nan`` values.

    Args:
        weight_key (Optional[str]): Key naming the reported entry to use as
            the per-worker weight. If it is None, equal weights are used
            (the weights themselves are produced by
            ``_get_weights_from_results``).
    """

    def __init__(self, weight_key: Optional[str] = None):
        self.weight_key = weight_key
        # Filled in by ``prepare``; ``__call__`` multiplies by it, so
        # ``prepare`` has to run before the first aggregation.
        self.weights = None

    def prepare(self, results: List[Dict]):
        """Look up the weights for ``results`` once, ahead of aggregation."""
        self.weights = _get_weights_from_results(self.weight_key, results)

    def __call__(
        self, values: List[Union[VALID_AGGREGATE_TYPES]]
    ) -> Union[VALID_AGGREGATE_TYPES]:
        """Return the weighted mean of ``values`` using ``self.weights``.

        The weights are renormalised over the positions whose value is not
        ``np.nan``, so missing values do not dilute the result.
        """
        weighted_values = np.array(values) * self.weights
        # 1 where the value is usable, 0 where it is NaN.
        valid_mask = 1 - np.isnan(values)
        effective_weight = np.nansum(self.weights * valid_mask)
        return np.nansum(weighted_values / effective_weight)

    def __repr__(self) -> str:
        """Return the short name used by ``wrap_key`` to label results."""
        return f"weight_avg_{self.weight_key}"
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
import logging
from typing import Dict, List, Optional

from ray.util.annotations import DeveloperAPI
from ray.train.callbacks.results_preprocessors.preprocessor import ResultsPreprocessor
from ray.train.callbacks.results_preprocessors.aggregate.aggregate_fn import (
AggregateFn,
Average,
Max,
WeightedAverage,
)
from ray.train.callbacks.results_preprocessors.aggregate.aggregate_utils import (
_get_metrics_from_results,
)

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


@DeveloperAPI
class AggregateResultsPreprocessor(ResultsPreprocessor):
    """A preprocessor that aggregates training metrics from all workers.

    Args:
        aggregation_fn (AggregateFn):
            An aggregation method that performs the aggregation on results.
        keys (Optional[List[str]]):
            A list of keys reported in results to be aggregated. Keys should
            be saved using ``train.report()``. If a key is invalid or not
            reported, it will be ignored. If None (or empty), every key
            found in the results is considered.
    """

    def __init__(self, aggregation_fn: AggregateFn, keys: Optional[List[str]] = None):
        self.aggregate_fn = aggregation_fn
        self.keys = keys

    def preprocess(self, results: Optional[List[Dict]] = None) -> Optional[List[Dict]]:
        """Aggregate results before sending them to callbacks.

        A key is skipped when ``_get_metrics_from_results`` yields nothing
        for it, which is intended to cover the following cases:

        1. No worker reports it.
        2. The values returned from all workers are invalid.

        The aggregation WILL be performed even if some but not all
        workers report the key with valid values. The aggregation
        will only be applied to those that report the key.

        Each aggregated value is stored under
        ``aggregation_fn.wrap_key(key)`` and written into every result
        dict in place.

        Args:
            results (Optional[List[Dict]]): A list of results
                from all workers. The metrics specified in ``keys``
                will be aggregated according to ``aggregation_fn``.

        Returns:
            The same ``results`` list, of unchanged length, whose dicts
            now also contain the aggregated entries. ``None`` or an empty
            list is returned untouched.
        """
        if not results:
            return results

        self.aggregate_fn.prepare(results)

        # Fall back to every reported key. ``dict.fromkeys`` de-duplicates
        # while keeping first-seen order, so the aggregated entries are
        # inserted in the same order on every run (a ``set`` of strings
        # iterates in an order that changes with hash randomisation).
        keys_to_aggregate = self.keys or list(
            dict.fromkeys(key for result in results for key in result)
        )

        aggregated_results = {}
        for key in keys_to_aggregate:
            values = _get_metrics_from_results(key, results)
            if not values:
                # Nothing usable was reported for this key.
                continue
            wrapped_key = self.aggregate_fn.wrap_key(key)
            aggregated_results[wrapped_key] = self.aggregate_fn(values)

        # Currently we directly update each result dict with aggregated results.
        for result in results:
            result.update(aggregated_results)

        return results


class AverageResultsPreprocessor(AggregateResultsPreprocessor):
    """Preprocessor that adds the equally weighted mean of each metric.

    Example:

    .. code-block:: python

        preprocessor = AverageResultsPreprocessor(keys=["loss", "accuracy"])
        update_results = preprocessor.preprocess(results)

    Args:
        keys (Optional[List[str]]): The metrics to average. When None is
            given, the reported keys are used instead; only those whose
            value type is valid (one of ``VALID_AGGREGATE_TYPES``) are
            expected to be aggregated.
    """

    def __init__(self, keys: Optional[List[str]] = None):
        super().__init__(aggregation_fn=Average(), keys=keys)


class MaxResultsPreprocessor(AggregateResultsPreprocessor):
    """Preprocessor that adds the maximum of each metric across workers.

    Example:

    .. code-block:: python

        preprocessor = MaxResultsPreprocessor(keys=["loss", "accuracy"])
        update_results = preprocessor.preprocess(results)

    Args:
        keys (Optional[List[str]]): The metrics whose maximum is taken.
            When None is given, the reported keys are used instead; only
            those whose value type is valid (one of
            ``VALID_AGGREGATE_TYPES``) are expected to be aggregated.
    """

    def __init__(self, keys: Optional[List[str]] = None):
        super().__init__(aggregation_fn=Max(), keys=keys)


class WeightedAverageResultsPreprocessor(AggregateResultsPreprocessor):
    """Preprocessor that adds the weighted mean of each metric.

    Example:

    .. code-block:: python

        preprocessor = WeightedAverageResultsPreprocessor(
            keys=["loss", "accuracy"], weight_key="batch_size"
        )
        update_results = preprocessor.preprocess(results)

    Args:
        keys (Optional[List[str]]): The metrics to average. When None is
            given, the reported keys are used instead; only those whose
            value type is valid (one of ``VALID_AGGREGATE_TYPES``) are
            expected to be aggregated.
        weight_key (Optional[str]): A key from the reported metrics whose
            value serves as the weight in the average. When None is given,
            equal weights are used.
    """

    def __init__(
        self, keys: Optional[List[str]] = None, weight_key: Optional[str] = None
    ):
        weighted_fn = WeightedAverage(weight_key)
        super().__init__(aggregation_fn=weighted_fn, keys=keys)
Loading

0 comments on commit d1009c8

Please sign in to comment.