Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Train] Add support for metrics aggregation #22099

Merged
merged 23 commits into from
Mar 8, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,11 @@
from ray.train.callbacks.results_preprocessors.keys import (
ExcludedKeysResultsPreprocessor,
)
from ray.train.callbacks.results_preprocessors.aggregate import (
AverageResultsPreprocessor,
MaxResultsPreprocessor,
WeightedAverageResultsPreprocessor,
)
from ray.train.callbacks.results_preprocessors.preprocessor import (
SequentialResultsPreprocessor,
ResultsPreprocessor,
Expand All @@ -12,4 +17,7 @@
"IndexedResultsPreprocessor",
"ResultsPreprocessor",
"SequentialResultsPreprocessor",
"AverageResultsPreprocessor",
"MaxResultsPreprocessor",
"WeightedAverageResultsPreprocessor",
]
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from ray.train.callbacks.results_preprocessors.aggregate.aggregate_preprocessor import (
AverageResultsPreprocessor,
MaxResultsPreprocessor,
WeightedAverageResultsPreprocessor,
)

__all__ = [
"AverageResultsPreprocessor",
"MaxResultsPreprocessor",
"WeightedAverageResultsPreprocessor",
]
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
import abc
from typing import Dict, List, Union, Optional

import numpy as np

from ray.train.callbacks.results_preprocessors.aggregate.aggregate_utils import (
VALID_AGGREGATE_TYPES,
_get_weights_from_results,
)


class AggregateFn(abc.ABC):
"""An abstract class for aggregation function."""

def __call__(
self, values: List[Union[VALID_AGGREGATE_TYPES]]
) -> Union[VALID_AGGREGATE_TYPES]:
"""Perform the aggregation of values when being called.

Args:
values (List[Union[VALID_AGGREGATE_TYPES]]): A list of
values returned from workers. The length of the list
is expected to be equal to the number of workers.

Returns:
A single value that should logically be some form of aggregation
of the values from each worker in the ``values`` list.
"""
raise NotImplementedError

def prepare(self, results: List[Dict]) -> None:
"""Perform some preparation work before aggregation.

Unlike ``__call__``, this method is not called separately
for each metric, but is only called once for preparation
before aggregation begins. Any logic that does not need to
be called for each metric should be placed in this method.
"""
pass

def wrap_key(self, key) -> str:
"""Get a string representation of the aggregation."""
return str(self) + f"({key})"


class Average(AggregateFn):
"""Average aggregation class."""

def __call__(
self, values: List[Union[VALID_AGGREGATE_TYPES]]
) -> Union[VALID_AGGREGATE_TYPES]:
# A numpy runtime warning will be thrown if values
# is a list of all ``np.nan``.
return np.nanmean(values)

def __repr__(self) -> str:
return "avg"


class Max(AggregateFn):
"""Maximum aggregation class."""

def __call__(
self, values: List[Union[VALID_AGGREGATE_TYPES]]
) -> Union[VALID_AGGREGATE_TYPES]:
# A numpy runtime warning will be thrown if values
# is a list of all ``np.nan``.
return np.nanmax(values)

def __repr__(self) -> str:
return "max"


class WeightedAverage(AggregateFn):
"""Weighted average aggregation class.

Args:
weight_key (Optional[str]): A key string that specifies
the average weight to be used. If it is None, then
equal weight will be used.
"""

def __init__(self, weight_key: Optional[str] = None):
self.weight_key = weight_key
self.weights = None

def __call__(
self, values: List[Union[VALID_AGGREGATE_TYPES]]
) -> Union[VALID_AGGREGATE_TYPES]:
return np.nansum(
np.array(values)
* self.weights
/ np.nansum(self.weights * (1 - np.isnan(values)))
)

def __repr__(self) -> str:
return f"weight_avg_{self.weight_key}"

def prepare(self, results: List[Dict]):
self.weights = _get_weights_from_results(self.weight_key, results)
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
import logging
from typing import Dict, List, Optional

from ray.util.annotations import DeveloperAPI
from ray.train.callbacks.results_preprocessors.preprocessor import ResultsPreprocessor
from ray.train.callbacks.results_preprocessors.aggregate.aggregate_fn import (
AggregateFn,
Average,
Max,
WeightedAverage,
)
from ray.train.callbacks.results_preprocessors.aggregate.aggregate_utils import (
_get_metrics_from_results,
)

logger = logging.getLogger(__name__)


@DeveloperAPI
class AggregateResultsPreprocessor(ResultsPreprocessor):
"""A preprocessor that aggregates training metrics from all workers.

Args:
aggregation_fn (AggregateFn):
An aggregation method that performs the aggregation on results.
keys (Optional[List[str]]):
A list of keys reported in results to be aggregated. Keys should
be saved using ``train.report()``. If a key is invalid or not
reported, it will be ignored.
"""

def __init__(self, aggregation_fn: AggregateFn, keys: Optional[List[str]] = None):
self.aggregate_fn = aggregation_fn
self.keys = keys

def preprocess(self, results: Optional[List[Dict]] = None) -> Optional[List[Dict]]:
"""Aggregate results before sending them to callbacks.

A key will be ignored if one of the following occurs:
1. No worker reports it.
2. The values returned from all workers are invalid.
The aggregation WILL be performed even if some but not all
workers report the key with valid values. The aggregation
will only applied to those that report the key.

Args:
results (Optional[List[Dict]]): A list of results
from all workers. The metrics specified in ``keys``
will be averaged according to ``aggregation_fn``.

Returns:
An updated results list that has aggregated results and
is of the same length as the input list.
"""
if results is None or len(results) == 0:
return results
jwyyy marked this conversation as resolved.
Show resolved Hide resolved

self.aggregate_fn.prepare(results)

keys_to_aggregate = (
self.keys
if self.keys
else {key for result in results for key in result.keys()}
)

aggregated_results = {}

for key in keys_to_aggregate:
values = _get_metrics_from_results(key, results)
if values:
aggregated_results[self.aggregate_fn.wrap_key(key)] = self.aggregate_fn(
values
)

# Currently we directly update each result dict with aggregated results.
for result in results:
result.update(aggregated_results)

return results


class AverageResultsPreprocessor(AggregateResultsPreprocessor):
"""A preprocessor that averages results with equal weight.

.. code-block:: python

preprocessor = AverageResultsPreprocessor(keys=["loss", "accuracy"])
update_results = preprocessor.preprocess(results)


Args:
keys (Optional[List[str]]): A list of metrics to be averaged.
If None is specified, then the list will be populated by
reported keys whose value type is valid, that is, one of
``VALID_AGGREGATE_TYPES``.

Returns:
An updated results list with average values.
"""

def __init__(self, keys: Optional[List[str]] = None):
super().__init__(Average(), keys)


class MaxResultsPreprocessor(AggregateResultsPreprocessor):
"""A preprocessor that computes maximum values of specified keys.

.. code-block:: python

preprocessor = MaxResultsPreprocessor(keys=["loss", "accuracy"])
update_results = preprocessor.preprocess(results)


Args:
keys (Optional[List[str]]): A list of metrics upon which the
maximum value will be taken. If None is specified, then
the list will be populated by reported keys whose value type
is valid, that is, one of ``VALID_AGGREGATE_TYPES``.

Returns:
An updated results list with maximum values.
"""

def __init__(self, keys: Optional[List[str]] = None):
super().__init__(Max(), keys)


class WeightedAverageResultsPreprocessor(AggregateResultsPreprocessor):
"""A preprocessor that performs weighted average over metrics.


.. code-block:: python

preprocessor = WeightedAverageResultsPreprocessor(keys=["loss", "accuracy"],
weight_key="batch_size")
update_results = preprocessor.preprocess(results)

Args:
keys (Optional[List[str]]): A list of metrics to be averaged.
If None is specified, then the list will be populated by
reported keys whose value type is valid, that is, one of
``VALID_AGGREGATE_TYPES``.
weight_key (Optional[str]): A a key from reported metrics that
will be used as the weight in averaging. If None is specified,
then equal weight will be used.

Returns:
An updated results list with weighted average results.
"""

def __init__(
self, keys: Optional[List[str]] = None, weight_key: Optional[str] = None
):
super().__init__(WeightedAverage(weight_key), keys)
Loading