Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Train] Add support for metrics aggregation #22099

Merged
merged 23 commits into from
Mar 8, 2022
Merged
Show file tree
Hide file tree
Changes from 21 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,11 @@
from ray.train.callbacks.results_preprocessors.keys import (
ExcludedKeysResultsPreprocessor,
)
from ray.train.callbacks.results_preprocessors.aggregate import (
AverageResultsPreprocessor,
MaxResultsPreprocessor,
WeightedAverageResultsPreprocessor,
)
from ray.train.callbacks.results_preprocessors.preprocessor import (
SequentialResultsPreprocessor,
ResultsPreprocessor,
Expand All @@ -12,4 +17,7 @@
"IndexedResultsPreprocessor",
"ResultsPreprocessor",
"SequentialResultsPreprocessor",
"AverageResultsPreprocessor",
"MaxResultsPreprocessor",
"WeightedAverageResultsPreprocessor",
]
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from ray.train.callbacks.results_preprocessors.aggregate.aggregate_preprocessor import (
AverageResultsPreprocessor,
MaxResultsPreprocessor,
WeightedAverageResultsPreprocessor,
)

__all__ = [
"AverageResultsPreprocessor",
"MaxResultsPreprocessor",
"WeightedAverageResultsPreprocessor",
]
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
from typing import Dict, List, Union, Optional

import numpy as np

from ray.train.callbacks.results_preprocessors.aggregate.aggregate_utils import (
VALID_AGGREGATE_TYPES,
_get_weights_from_results,
)


class AggregateFn:
jwyyy marked this conversation as resolved.
Show resolved Hide resolved
"""An abstract class for aggregation function."""

def __call__(
self, values: List[Union[VALID_AGGREGATE_TYPES]]
) -> Union[VALID_AGGREGATE_TYPES]:
"""
jwyyy marked this conversation as resolved.
Show resolved Hide resolved
Args:
values (List[Union[VALID_AGGREGATE_TYPES]]): A list of
values returned from workers. The length of the list
is expected to be equal to the number of workers.

Returns:
An aggregated value.
jwyyy marked this conversation as resolved.
Show resolved Hide resolved
"""
raise NotImplementedError

def prepare(self, results: List[Dict]) -> None:
"""Perform some preparation work before aggregation."""
jwyyy marked this conversation as resolved.
Show resolved Hide resolved
pass

def wrap_key(self, key) -> str:
"""Get a string representation of the aggregation."""
return str(self) + f"({key})"


class Average(AggregateFn):
"""Average aggregation class."""

def __call__(
self, values: List[Union[VALID_AGGREGATE_TYPES]]
) -> Union[VALID_AGGREGATE_TYPES]:
return np.nanmean(values)

def __repr__(self) -> str:
return "Average"
jwyyy marked this conversation as resolved.
Show resolved Hide resolved


class Max(AggregateFn):
"""Maximum aggregation class."""

def __call__(
self, values: List[Union[VALID_AGGREGATE_TYPES]]
) -> Union[VALID_AGGREGATE_TYPES]:
return np.nanmax(values)

def __repr__(self) -> str:
return "Max"
jwyyy marked this conversation as resolved.
Show resolved Hide resolved


class WeightedAverage(AggregateFn):
"""Weighted average aggregation class.

Args:
weight_key (Optional[str]): A string that specifies
the average weight to be used. If it is None, then
equal weight will be used.
"""

def __init__(self, weight_key: Optional[str] = None):
self.weight_key = weight_key
self.weights = None

def __call__(
self, values: List[Union[VALID_AGGREGATE_TYPES]]
) -> Union[VALID_AGGREGATE_TYPES]:
return np.nansum(np.array(values) * self.weights / np.nansum(self.weights))

def __repr__(self) -> str:
return f"Weighted average [by {self.weight_key}]"
jwyyy marked this conversation as resolved.
Show resolved Hide resolved

def prepare(self, results: List[Dict]):
self.weights = _get_weights_from_results(self.weight_key, results)
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
import logging
from typing import Dict, List, Optional

import numpy as np

from ray.util.debug import log_once
from ray.util.annotations import DeveloperAPI
from ray.train.callbacks.results_preprocessors.preprocessor import ResultsPreprocessor
from ray.train.callbacks.results_preprocessors.aggregate.aggregate_fn import (
AggregateFn,
Average,
Max,
WeightedAverage,
)
from ray.train.callbacks.results_preprocessors.aggregate.aggregate_utils import (
VALID_AGGREGATE_TYPES,
_get_metrics_from_results,
)

logger = logging.getLogger(__name__)


@DeveloperAPI
class AggregateResultsPreprocessor(ResultsPreprocessor):
"""A preprocessor that aggregates training metrics from all workers.

Args:
aggregation_fn (AggregateFn):
An aggregation method that performs the aggregation on results.
keys (Optional[List[str]]):
A list of keys reported in results to be aggregated. Keys should
be saved using ``train.report()``. If a key is invalid or not
reported, it will be ignored.
"""

def __init__(self, aggregation_fn: AggregateFn, keys: Optional[List[str]] = None):
self.aggregate_fn = aggregation_fn
self.keys = keys

def preprocess(self, results: List[Dict]) -> List[Dict]:
jwyyy marked this conversation as resolved.
Show resolved Hide resolved
"""Aggregate results before sending them to callbacks.

Args:
results (List[Dict]): A list of results from all workers. The metrics
specified in ``keys`` will be averaged according by ``aggregation_fn``.
Non-numerical values will be ignored.
Returns:
An updated results list with aggregated results.
jwyyy marked this conversation as resolved.
Show resolved Hide resolved
"""
results = [] if results is None else results
if len(results) == 0:
return results
jwyyy marked this conversation as resolved.
Show resolved Hide resolved

self.aggregate_fn.prepare(results)
reported_metrics = {key for result in results for key in result.keys()}

if self.keys is None:
valid_keys = []
for metric in reported_metrics:
jwyyy marked this conversation as resolved.
Show resolved Hide resolved
if all(
isinstance(result.get(metric, np.nan), VALID_AGGREGATE_TYPES)
for result in results
):
valid_keys.append(metric)
elif log_once(metric):
logger.warning(
f"`{metric}` value type is not "
f"one of {VALID_AGGREGATE_TYPES}, so it is ignored. "
)

self.keys = valid_keys
jwyyy marked this conversation as resolved.
Show resolved Hide resolved

aggregated_results = {}

for key in self.keys:
values = _get_metrics_from_results(key, results)
if values is None:
continue
aggregated_results[self.aggregate_fn.wrap_key(key)] = self.aggregate_fn(
values
)

# Currently we directly update each result dict with aggregated results.
for result in results:
result.update(aggregated_results)

return results


class AverageResultsPreprocessor(AggregateResultsPreprocessor):
"""A preprocessor that averages results with equal weight.

Args:
keys (Optional[List[str]]): A list of metrics to be averaged.
If None is specified, then the list will be populated by
reported keys whose value type is valid, that is, one of
``VALID_AGGREGATE_TYPES``.

Returns:
An updated results list with average values.
"""

def __init__(self, keys: Optional[List[str]] = None):
super().__init__(Average(), keys)


class MaxResultsPreprocessor(AggregateResultsPreprocessor):
"""A preprocessor that computes the maximum values.

Args:
keys (Optional[List[str]]): A list of metrics upon which the
maximum value will be taken. If None is specified, then
the list will be populated by reported keys whose value type
is valid, that is, one of ``VALID_AGGREGATE_TYPES``.

Returns:
An updated results list with maximum values.
"""

def __init__(self, keys: Optional[List[str]] = None):
super().__init__(Max(), keys)


class WeightedAverageResultsPreprocessor(AggregateResultsPreprocessor):
"""A preprocessor that performs weighted average over metrics.


.. code-block:: python

preprocessor = WeightedAverageResultsPreprocessor(keys=["loss", "accuracy"],
weight_key="batch_size")
update_results = preprocessor.preprocess(results)

Args:
keys (Optional[List[str]]): A list of metrics to be averaged.
If None is specified, then the list will be populated by
reported keys whose value type is valid, that is, one of
``VALID_AGGREGATE_TYPES``.
weight_key (Optional[str]): A a key from reported metrics that
will be used as the weight in averaging. If None is specified,
then equal weight will be used.

Returns:
An updated results list with weighted average results.
"""

def __init__(
self, keys: Optional[List[str]] = None, weight_key: Optional[str] = None
):
super().__init__(WeightedAverage(weight_key), keys)
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
import logging
from typing import Dict, List, Tuple, Union

import numpy as np

VALID_AGGREGATE_TYPES: Tuple[type] = (
int,
float,
np.float32,
np.float64,
np.int32,
np.int64,
)

logger = logging.getLogger(__name__)


def _get_metrics_from_results(
key: str, results: List[Dict]
) -> List[Union[VALID_AGGREGATE_TYPES]]:
jwyyy marked this conversation as resolved.
Show resolved Hide resolved
"""Return metrics values in the results list from all workers.
jwyyy marked this conversation as resolved.
Show resolved Hide resolved

Args:
key (str): A key string. If it doesn't exist in results,
jwyyy marked this conversation as resolved.
Show resolved Hide resolved
i.e. it is not reported, the None will be returned.
jwyyy marked this conversation as resolved.
Show resolved Hide resolved
results (List[Dict]): The results list returned from workers.

Returns:
A list of valid key values from each worker, if key exists.
jwyyy marked this conversation as resolved.
Show resolved Hide resolved
Otherwise, None.
"""
reported_metrics = {key for result in results for key in result.keys()}
values = [result.get(key, np.nan) for result in results]
warning_message = None
if key not in reported_metrics:
jwyyy marked this conversation as resolved.
Show resolved Hide resolved
warning_message = (
f"`{key}` is not reported from workers, so it is ignored. "
"Please make sure that it is saved using `train.report()`."
)
elif not all(isinstance(value, VALID_AGGREGATE_TYPES) for value in values):
warning_message = (
f"`{key}` value type (`{type(values[0])}`) is not valid, "
"so it is ignored. "
f"Make sure that its type is one of {VALID_AGGREGATE_TYPES}. "
)

if warning_message:
logger.warning(warning_message)
jwyyy marked this conversation as resolved.
Show resolved Hide resolved
return None

return values


def _get_weights_from_results(
key: str, results: List[Dict]
) -> List[Union[VALID_AGGREGATE_TYPES]]:
"""Return weight values in the results list from all workers.

Args:
key (str): A key string specifies the weight metric.
If it doesn't exist in results, then equal weight
jwyyy marked this conversation as resolved.
Show resolved Hide resolved
will be used.
results (List[Dict]): The results list returned from workers.

Returns:
A list of valid weight values from each worker, if key exists.
Otherwise, a list of all ones, that is, equal weight.
"""
reported_metrics = {key for result in results for key in result.keys()}
weights = [result.get(key, np.nan) for result in results]
warning_message = None
if key not in reported_metrics:
warning_message = (
f"Averaging weight `{key}` is not reported in `train.report()`. "
jwyyy marked this conversation as resolved.
Show resolved Hide resolved
)
elif not all(isinstance(value, VALID_AGGREGATE_TYPES) for value in weights):
warning_message = (
f"Averaging weight `{key}` value type (`{type(weights[0])}`) is not valid. "
f"Make sure that its type is one of {VALID_AGGREGATE_TYPES}. "
)

if warning_message:
logger.warning(warning_message + "Use equal weight instead.")
jwyyy marked this conversation as resolved.
Show resolved Hide resolved
weights = np.array([1] * len(results))

return weights
Loading