Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Train] Add support for metrics aggregation #22099

Merged
merged 23 commits into from
Mar 8, 2022
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 21 additions & 2 deletions python/ray/train/callbacks/print.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
import json
from typing import Dict, List
from typing import Dict, List, Optional

from ray.train.callbacks import TrainingCallback
from ray.train.callbacks.results_preprocessors import ResultsPreprocessor
from ray.train.callbacks.results_preprocessors.preprocessor import (
SequentialResultsPreprocessor,
)
from ray.train.utils import ResultsList


class PrintCallback(TrainingCallback):
Expand Down Expand Up @@ -52,6 +57,14 @@ class PrintCallback(TrainingCallback):
]
"""

def __init__(
    self, results_preprocessors: Optional[List[ResultsPreprocessor]] = None
):
    """Optionally combine results preprocessors into a sequential one.

    Args:
        results_preprocessors (Optional[List[ResultsPreprocessor]]):
            Preprocessors to wrap, in the given order, into a single
            ``SequentialResultsPreprocessor`` stored on
            ``self.results_preprocessor``. When ``None`` or empty, this
            method does not assign the attribute.
    """
    if not results_preprocessors:
        # Nothing to chain; ``results_preprocessor`` is left untouched.
        return

    chained = SequentialResultsPreprocessor(results_preprocessors)
    self.results_preprocessor = chained

def handle_result(self, results: List[Dict], **info):
"""Prints results to STDOUT.

Expand All @@ -61,4 +74,10 @@ def handle_result(self, results: List[Dict], **info):
the training function from each worker.
**info: kwargs dict for forward compatibility.
"""
print(json.dumps(results, indent=4))

output = list(results)

if isinstance(results, ResultsList):
output.append(results.aggregated_results)

print(json.dumps(output, indent=4))
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,11 @@
from ray.train.callbacks.results_preprocessors.keys import (
ExcludedKeysResultsPreprocessor,
)
from ray.train.callbacks.results_preprocessors.aggregate import (
AverageResultsPreprocessor,
MaxResultsPreprocessor,
WeightedAverageResultsPreprocessor,
)
from ray.train.callbacks.results_preprocessors.preprocessor import (
SequentialResultsPreprocessor,
ResultsPreprocessor,
Expand All @@ -12,4 +17,7 @@
"IndexedResultsPreprocessor",
"ResultsPreprocessor",
"SequentialResultsPreprocessor",
"AverageResultsPreprocessor",
"MaxResultsPreprocessor",
"WeightedAverageResultsPreprocessor",
]
95 changes: 95 additions & 0 deletions python/ray/train/callbacks/results_preprocessors/aggregate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
import logging
from typing import Dict, List, Optional

import numpy as np

from ray.train.callbacks.results_preprocessors.preprocessor import ResultsPreprocessor
from ray.train.callbacks.results_preprocessors.aggregate_fn import (
AggregateFn,
Average,
Max,
WeightedAverage,
VALID_AGGREGATE_TYPES,
_get_values_from_results,
)
from ray.train.utils import ResultsList

logger = logging.getLogger(__name__)


class AggregateResultsPreprocessor(ResultsPreprocessor):
    """A preprocessor that aggregates training metrics from all workers.

    Args:
        aggregation_fn (AggregateFn):
            An aggregation method that performs the aggregation on results.
        keys (Optional[List[str]]):
            A list of keys reported in results to be aggregated. Keys should
            be saved using `train.report()`. If ``None``, the keys are
            detected on every call: each key reported by the first worker
            whose values have a valid aggregation type on all workers is
            aggregated.
    """

    def __init__(
        self, aggregation_fn: AggregateFn, keys: Optional[List[str]] = None
    ):
        self.aggregate_fn = aggregation_fn
        self.keys = keys

    def preprocess(self, results: List[Dict]) -> List[Dict]:
        """Aggregate results before sending them to callbacks.

        Args:
            results List[Dict]: A list of results from all workers. The
                metrics specified in `keys` will be aggregated by
                `aggregation_fn`. Keys that are not reported or hold
                non-numerical values will be ignored.
        Returns:
            A ResultsList (a subclass of Python list) instance. The aggregated
            results is stored as the property `aggregated_results`. Empty
            input is returned as is (``None`` becomes an empty list).
        """
        results = [] if results is None else results
        if len(results) == 0:
            return results

        self.aggregate_fn.prepare(results)

        reported_metrics = set(results[0].keys())

        # Resolve the keys into a local variable instead of overwriting
        # ``self.keys``: auto-detection must run again on every call so that
        # metrics reported later are picked up and the user's configuration
        # (``keys=None``) is never silently replaced.
        if self.keys is None:
            keys = [
                metric
                for metric in reported_metrics
                if all(
                    isinstance(result.get(metric, np.nan), VALID_AGGREGATE_TYPES)
                    for result in results
                )
            ]
        else:
            keys = self.keys

        aggregated_results = {}
        for key in keys:
            values = _get_values_from_results(key, reported_metrics, results)
            if values is None:
                # The helper already logged why this key cannot be aggregated.
                continue
            aggregated_results[self.aggregate_fn.wrap_key(key)] = self.aggregate_fn(
                values
            )

        # Keep an existing ResultsList so that aggregations accumulated by
        # earlier preprocessors on the same object are preserved.
        if not isinstance(results, ResultsList):
            results = ResultsList(results)

        results.update_aggregated_results(aggregated_results)

        return results


class AverageResultsPreprocessor(AggregateResultsPreprocessor):
    """Aggregate reported metrics across workers with ``Average``.

    Args:
        keys (Optional[List[str]]):
            Keys to aggregate; forwarded unchanged to
            ``AggregateResultsPreprocessor``.
    """

    def __init__(self, keys: Optional[List[str]] = None):
        aggregation_fn = Average()
        super().__init__(aggregation_fn, keys)


class MaxResultsPreprocessor(AggregateResultsPreprocessor):
    """Aggregate reported metrics across workers with ``Max``.

    Args:
        keys (Optional[List[str]]):
            Keys to aggregate; forwarded unchanged to
            ``AggregateResultsPreprocessor``.
    """

    def __init__(self, keys: Optional[List[str]] = None):
        aggregation_fn = Max()
        super().__init__(aggregation_fn, keys)


class WeightedAverageResultsPreprocessor(AggregateResultsPreprocessor):
    """Aggregate reported metrics across workers with ``WeightedAverage``.

    Args:
        keys (Optional[List[str]]):
            Keys to aggregate; forwarded unchanged to
            ``AggregateResultsPreprocessor``.
        weight_key (Optional[str]):
            Key passed to ``WeightedAverage`` to select the reported metric
            used as the averaging weight.
    """

    def __init__(
        self, keys: Optional[List[str]] = None, weight_key: Optional[str] = None
    ):
        aggregation_fn = WeightedAverage(weight_key)
        super().__init__(aggregation_fn, keys)
104 changes: 104 additions & 0 deletions python/ray/train/callbacks/results_preprocessors/aggregate_fn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
import logging
from typing import Dict, List, Union, Tuple, Optional

import numpy as np

# Scalar types accepted for aggregation. The annotation is a variadic tuple
# (`Tuple[type, ...]`); a bare `Tuple[type]` would describe a 1-tuple.
VALID_AGGREGATE_TYPES: Tuple[type, ...] = (
    int,
    float,
    np.float32,
    np.float64,
    np.int32,
    np.int64,
)

logger = logging.getLogger(__name__)


def _get_values_from_results(key, reported_metrics, results, get_weights=False):
    """Return the values reported under ``key`` by every worker.

    Args:
        key: Name of the metric to read from each worker's result dict.
        reported_metrics: Collection of metric names regarded as reported;
            ``key`` must be a member for its values to be used.
        results: List of per-worker result dicts.
        get_weights: Whether the values are fetched as averaging weights.
            This changes both the warning text and the fallback value.

    Returns:
        A list with one value per worker (``np.nan`` for a worker that lacks
        ``key``) when ``key`` is reported and every value has a type in
        ``VALID_AGGREGATE_TYPES``. Otherwise a warning is logged and the
        function returns ``None``, or, when ``get_weights`` is true, an array
        of ones (equal weights) with one entry per worker.
    """
    values = [result.get(key, np.nan) for result in results]

    warning_message = None
    if key not in reported_metrics:
        if get_weights:
            warning_message = (
                f"Averaging weight `{key}` is not reported in `train.report()`. "
            )
        else:
            warning_message = (
                f"`{key}` is not reported from workers, so it is ignored. "
                "Please make sure that it is saved using `train.report()`."
            )
    else:
        invalid_types = [
            type(value)
            for value in values
            if not isinstance(value, VALID_AGGREGATE_TYPES)
        ]
        if invalid_types:
            # Name the type that is actually invalid; the first worker's
            # value may be valid even though a later worker's is not.
            offending_type = invalid_types[0]
            if get_weights:
                warning_message = (
                    f"Averaging weight `{key}` value type "
                    f"(`{offending_type}`) is not valid. "
                )
            else:
                warning_message = (
                    f"`{key}` value type (`{offending_type}`) is not valid, "
                    f"so it is ignored. "
                )
            warning_message += (
                f"Make sure that its type is one of {VALID_AGGREGATE_TYPES}. "
            )

    if warning_message is None:
        return values

    if get_weights:
        logger.warning(warning_message + "Use equal weight instead.")
        return np.array([1] * len(results))

    logger.warning(warning_message)
    return None


class AggregateFn:
    """Base class for aggregation functions.

    Subclasses implement ``__call__`` to perform the aggregation and may
    override ``prepare``, which does nothing here.
    """

    def __call__(self, *args, **kwargs):
        """Perform the aggregation; must be implemented by subclasses."""
        raise NotImplementedError

    def prepare(self, *args, **kwargs):
        """Hook with no effect in the base class."""
        return None

    def wrap_key(self, key):
        """Return ``key`` wrapped in this function's string form.

        The result is ``"<str(self)>(<key>)"``.
        """
        return f"{self}({key})"


class Average(AggregateFn):
    """Aggregation function returning the mean of the values, ignoring NaNs."""

    def __call__(self, values: List[Union[VALID_AGGREGATE_TYPES]]):
        """Return ``np.nanmean`` of ``values``."""
        mean_ignoring_nan = np.nanmean(values)
        return mean_ignoring_nan

    def __repr__(self):
        label = "Average"
        return label


class Max(AggregateFn):
    """Aggregation function returning the maximum of the values, ignoring NaNs."""

    def __call__(self, values: List[Union[VALID_AGGREGATE_TYPES]]):
        """Return ``np.nanmax`` of ``values``."""
        max_ignoring_nan = np.nanmax(values)
        return max_ignoring_nan

    def __repr__(self):
        label = "Max"
        return label


class WeightedAverage(AggregateFn):
    """Aggregation function returning a weighted average of the values.

    The weights are read from the workers' results under ``weight_key`` by
    ``prepare``, which must be given the same results the values come from.

    Args:
        weight_key (Optional[str]): Key of the reported metric that holds
            each worker's weight.
    """

    def __init__(self, weight_key: Optional[str] = None):
        self.weight_key = weight_key
        self.weights = None

    def __call__(self, values: List[Union[VALID_AGGREGATE_TYPES]]):
        """Return the weighted average of ``values``.

        Pairs in which either the value or the weight is NaN are left out of
        both the numerator and the normalising weight sum, so a worker that
        did not report a metric does not drag the average towards zero. With
        equal weights the result therefore matches ``Average``.

        Returns:
            The weighted average, or ``np.nan`` when the usable weights sum
            to zero (including when no usable pair exists).
        """
        values = np.asarray(values, dtype=float)
        if self.weights is None:
            # ``prepare`` was not called; fall back to equal weights rather
            # than failing on ``array * None``.
            weights = np.ones_like(values)
        else:
            weights = np.asarray(self.weights, dtype=float)

        usable = ~(np.isnan(values) | np.isnan(weights))
        total_weight = weights[usable].sum()
        if total_weight == 0:
            return np.nan
        return np.sum(values[usable] * weights[usable]) / total_weight

    def __repr__(self):
        return f"Weighted average [by {self.weight_key}]"

    def prepare(self, results: List[Dict]):
        """Collect the per-worker weights from ``results``.

        The weights are looked up under ``weight_key``; the first worker's
        keys are taken as the set of reported metrics.
        """
        reported_metrics = set(results[0].keys())
        self.weights = _get_values_from_results(
            self.weight_key, reported_metrics, results, get_weights=True
        )
36 changes: 28 additions & 8 deletions python/ray/train/examples/train_linear_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,12 @@
import torch.nn as nn
import ray.train as train
from ray.train import Trainer
from ray.train.callbacks import JsonLoggerCallback, TBXLoggerCallback
from ray.train.callbacks import JsonLoggerCallback, TBXLoggerCallback, PrintCallback
from ray.train.callbacks.results_preprocessors.aggregate import (
AverageResultsPreprocessor,
MaxResultsPreprocessor,
WeightedAverageResultsPreprocessor,
)


class LinearDataset(torch.utils.data.Dataset):
Expand Down Expand Up @@ -35,7 +40,7 @@ def train_epoch(dataloader, model, loss_fn, optimizer):
optimizer.step()


def validate_epoch(dataloader, model, loss_fn):
def validate_epoch(epoch, dataloader, model, loss_fn):
num_batches = len(dataloader)
model.eval()
loss = 0
Expand All @@ -44,10 +49,13 @@ def validate_epoch(dataloader, model, loss_fn):
pred = model(X)
loss += loss_fn(pred, y).item()
loss /= num_batches
import copy

model_copy = copy.deepcopy(model)
result = {"model": model_copy.cpu().state_dict(), "loss": loss}
result = {
"worker_idx": train.world_rank(),
"epoch": epoch,
"loss": loss,
"batch_size": num_batches,
}
return result


Expand Down Expand Up @@ -76,9 +84,9 @@ def train_func(config):

results = []

for _ in range(epochs):
for i in range(epochs):
train_epoch(train_loader, model, loss_fn, optimizer)
result = validate_epoch(validation_loader, model, loss_fn)
result = validate_epoch(i, validation_loader, model, loss_fn)
train.report(**result)
results.append(result)

Expand All @@ -90,7 +98,19 @@ def train_linear(num_workers=2, use_gpu=False, epochs=3):
config = {"lr": 1e-2, "hidden_size": 1, "batch_size": 4, "epochs": epochs}
trainer.start()
results = trainer.run(
train_func, config, callbacks=[JsonLoggerCallback(), TBXLoggerCallback()]
train_func,
config,
callbacks=[
PrintCallback(
results_preprocessors=[
AverageResultsPreprocessor(["loss", "batch_size"]),
MaxResultsPreprocessor(["time", "loss"]),
WeightedAverageResultsPreprocessor(["loss"], "batch_size"),
]
),
JsonLoggerCallback(),
TBXLoggerCallback(),
],
)
trainer.shutdown()

Expand Down
15 changes: 15 additions & 0 deletions python/ray/train/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,3 +112,18 @@ def __call__(cls, *args, **kwargs):
if cls not in cls._instances:
cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs)
return cls._instances[cls]


class ResultsList(list):
    """A ``list`` subclass that also carries a dict of aggregated results."""

    def __init__(self, *arg):
        """Build the list from ``arg`` and start with no aggregated results."""
        super(ResultsList, self).__init__(*arg)
        self._aggregate_results = dict()

    def update_aggregated_results(self, s):
        """Merge ``s`` into the stored aggregated results via ``dict.update``."""
        stored = self._aggregate_results
        stored.update(s)

    @property
    def aggregated_results(self):
        """The dict of aggregated results accumulated so far."""
        return self._aggregate_results