new impl
jwyyy committed Feb 12, 2022
1 parent 4f33310 commit be279dd
Showing 4 changed files with 91 additions and 44 deletions.
2 changes: 2 additions & 0 deletions python/ray/train/callbacks/results_preprocessors/__init__.py
@@ -2,6 +2,7 @@
from ray.train.callbacks.results_preprocessors.keys import (
    ExcludedKeysResultsPreprocessor,
)
from ray.train.callbacks.results_preprocessors.average import AverageResultsPreprocessor
from ray.train.callbacks.results_preprocessors.preprocessor import (
    SequentialResultsPreprocessor,
    ResultsPreprocessor,
@@ -12,4 +13,5 @@
"IndexedResultsPreprocessor",
"ResultsPreprocessor",
"SequentialResultsPreprocessor",
"AverageResultsPreprocessor",
]
74 changes: 74 additions & 0 deletions python/ray/train/callbacks/results_preprocessors/average.py
@@ -0,0 +1,74 @@
from typing import Dict, List, Optional, Tuple

import numpy as np

from ray.train.callbacks.results_preprocessors.preprocessor import ResultsPreprocessor


class AverageResultsPreprocessor(ResultsPreprocessor):
    """A preprocessor that averages training metrics from all workers.

    Args:
        metrics_to_average (Dict): A dict of metrics in the results to average.
            Each key is a metric to be averaged across all workers. The
            corresponding value is the key in the results whose value is used
            as that worker's weight. If the value is None, all workers are
            weighted equally.
    """

    VALID_SUMMARY_TYPES: Tuple[type, ...] = (
        int,
        float,
        np.float32,
        np.float64,
        np.int32,
        np.int64,
    )

    def __init__(self, metrics_to_average: Optional[Dict] = None):
        # Use None as the default to avoid a shared mutable default argument.
        self.metrics_to_average = metrics_to_average or {}

    def preprocess(self, results: Optional[List[Dict]] = None) -> List[Dict]:
        """Averages results before sending them to callbacks.

        Args:
            results (List[Dict]): A list of results from all workers. The
                metrics specified in ``metrics_to_average`` will be averaged
                according to their weights. Non-numerical values are ignored.

        Returns:
            An updated list of results.
        """
        results = results or []
        if len(results) == 0 or len(self.metrics_to_average) == 0:
            return results

        average_metrics = {}
        for metric, weight_key in self.metrics_to_average.items():

            if not isinstance(results[0][metric], self.VALID_SUMMARY_TYPES):
                continue

            metrics_from_workers = np.array(
                [result[metric] for result in results if not np.isnan(result[metric])]
            )
            if weight_key:
                weights_from_workers = np.array(
                    [
                        result[weight_key]
                        for result in results
                        if not np.isnan(result[metric])
                    ]
                )
            else:
                # If no weight key is provided, all workers are weighted equally.
                weights_from_workers = np.ones(len(metrics_from_workers))

            # Weighted average: sum(metric_i * weight_i) / sum(weight_i).
            average_metrics["_average_" + metric] = np.sum(
                metrics_from_workers * weights_from_workers
            ) / np.sum(weights_from_workers)

        for result in results:
            result.update(average_metrics)

        return results
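
For reference, here is a minimal usage sketch of the preprocessor above. The two-worker results are made-up values, and the expected output assumes the weighted-average computation shown in preprocess():

from ray.train.callbacks.results_preprocessors.average import AverageResultsPreprocessor

# Hypothetical per-worker results: average "loss", weighted by "batch_size".
results = [
    {"loss": 0.5, "batch_size": 4},
    {"loss": 1.0, "batch_size": 12},
]
preprocessor = AverageResultsPreprocessor({"loss": "batch_size"})
results = preprocessor.preprocess(results)

# Weighted average: (0.5 * 4 + 1.0 * 12) / (4 + 12) = 0.875, attached to
# every worker's result under the "_average_loss" key.
assert results[0]["_average_loss"] == 0.875
assert results[1]["_average_loss"] == 0.875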
28 changes: 8 additions & 20 deletions python/ray/train/examples/train_linear_example.py
@@ -5,8 +5,8 @@
import torch.nn as nn
import ray.train as train
from ray.train import Trainer
from ray.train.callbacks import JsonLoggerCallback, TBXLoggerCallback
from ray.train.constants import TIME_THIS_ITER_S
from ray.train.callbacks import JsonLoggerCallback, TBXLoggerCallback, PrintCallback
from ray.train.callbacks.results_preprocessors.average import AverageResultsPreprocessor


class LinearDataset(torch.utils.data.Dataset):
@@ -45,10 +45,11 @@ def validate_epoch(dataloader, model, loss_fn):
            pred = model(X)
            loss += loss_fn(pred, y).item()
    loss /= num_batches
    import copy

    model_copy = copy.deepcopy(model)
    result = {"model": model_copy.cpu().state_dict(), "loss": loss}
    result = {
        "loss": loss,
        "batch_size": num_batches,
    }
    return result


@@ -86,32 +87,19 @@ def train_func(config):
    return results


def average_validation_loss(intermediate_results):
    worker_results = [worker_result["loss"] for worker_result in intermediate_results]
    return np.mean(worker_results)


def average_iter_time(intermediate_results):
    worker_results = [
        worker_result[TIME_THIS_ITER_S] for worker_result in intermediate_results
    ]
    return np.mean(worker_results)


def train_linear(num_workers=2, use_gpu=False, epochs=3):
    trainer = Trainer(backend="torch", num_workers=num_workers, use_gpu=use_gpu)
    config = {"lr": 1e-2, "hidden_size": 1, "batch_size": 4, "epochs": epochs}
    trainer.start()
    results = trainer.run(
        train_func,
        config,
        callbacks=[JsonLoggerCallback(), TBXLoggerCallback()],
        aggregate_funcs=[average_validation_loss, average_iter_time],
        preprocessors=[AverageResultsPreprocessor({"loss": "batch_size"})],
        callbacks=[JsonLoggerCallback(), TBXLoggerCallback(), PrintCallback()],
    )
    trainer.shutdown()

    print(results)
    print(trainer.aggregated_metrics)
    return results


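For intuition, the preprocessors argument above makes each worker's validation loss count in proportion to the number of validation batches it processed. A minimal sketch of the arithmetic, with made-up numbers:

# Two hypothetical workers report {"loss": 1.0, "batch_size": 8} and
# {"loss": 3.0, "batch_size": 24}. The batch-size-weighted average that
# PrintCallback sees under "_average_loss" is:
weighted_loss = (1.0 * 8 + 3.0 * 24) / (8 + 24)
assert weighted_loss == 2.5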
31 changes: 7 additions & 24 deletions python/ray/train/trainer.py
@@ -1,5 +1,4 @@
from datetime import datetime
import collections
import inspect
import logging
import os
@@ -16,6 +15,7 @@
    TrainingWorkerError,
)
from ray.train.callbacks.callback import TrainingCallback
from ray.train.callbacks.results_preprocessors import ResultsPreprocessor
from ray.train.session import TrainingResultType
from ray.train.utils import RayDataset
from ray.train.checkpoint import (
@@ -268,11 +268,11 @@ def run(
        self,
        train_func: Union[Callable[[], T], Callable[[Dict[str, Any]], T]],
        config: Optional[Dict[str, Any]] = None,
        preprocessors: Optional[List[ResultsPreprocessor]] = None,
        callbacks: Optional[List[TrainingCallback]] = None,
        dataset: Optional[Union[RayDataset, Dict[str, RayDataset]]] = None,
        checkpoint: Optional[Union[Dict, str, Path]] = None,
        checkpoint_strategy: Optional[CheckpointStrategy] = None,
        aggregate_funcs: Optional[Union[Dict, List]] = None,
    ) -> List[T]:
        """Runs a training function in a distributed manner.
@@ -281,6 +281,9 @@ def run(
                This can either take in no arguments or a ``config`` dict.
            config (Optional[Dict]): Configurations to pass into
                ``train_func``. If None then an empty Dict will be created.
            preprocessors (Optional[List[ResultsPreprocessor]]): A list of
                preprocessors which will be applied to the results before
                they are passed to callbacks. If this is not set, no
                preprocessing is applied by default.
            callbacks (Optional[List[TrainingCallback]]): A list of Callbacks
                which will be executed during training. If this is not set,
                currently there are NO default Callbacks.

                ``None`` then no checkpoint will be loaded.
            checkpoint_strategy (Optional[CheckpointStrategy]): The
                configurations for saving checkpoints.
            aggregate_funcs (Optional[Union[Dict, List]]): The methods
                used to aggregate intermediate results returned
                by `train.report()` on each worker.

        Returns:
            A list of results from the training function. Each value in the
@@ -337,20 +337,12 @@ def run(
                checkpoint_strategy=checkpoint_strategy,
                run_dir=self.latest_run_dir,
            )
            aggregated_results = collections.defaultdict(list)
            if aggregate_funcs is None or len(aggregate_funcs) == 0:
                aggregate_funcs = {}
            elif isinstance(aggregate_funcs, list):
                aggregate_funcs = {e.__name__: e for e in aggregate_funcs}

            for intermediate_result in iterator:
                for aggregate_name, func in aggregate_funcs.items():
                    aggregated_results[aggregate_name].append(func(intermediate_result))
                for preprocessor in preprocessors:
                    intermediate_result = preprocessor.preprocess(intermediate_result)
                for callback in callbacks:
                    callback.process_results(intermediate_result)

            self._aggregated_metrics = aggregated_results

            assert iterator.is_finished()
            return iterator.get_final_results()
        finally:
@@ -499,15 +491,6 @@ def latest_checkpoint(self) -> Optional[Dict]:
"""
return self.checkpoint_manager.latest_checkpoint

    @property
    def aggregated_metrics(self) -> Optional[Dict]:
        """A ``Dict`` of aggregated metrics across all workers.

        Returns ``None`` if ``run()`` has not been called or an empty
        ``Dict`` if ``train.report()`` has not been called from ``train_func``.
        """
        return self._aggregated_metrics

    def shutdown(self):
        """Shuts down the training execution service."""
        ray.get(self._backend_executor_actor.shutdown.remote())
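Because run() now accepts arbitrary ResultsPreprocessor instances and applies them in list order before the callbacks, callers can plug in their own. A minimal sketch of a custom preprocessor follows; the DropKeysResultsPreprocessor name and its key-dropping behavior are illustrative, not part of this commit:

from typing import Dict, List

from ray.train.callbacks.results_preprocessors import ResultsPreprocessor


class DropKeysResultsPreprocessor(ResultsPreprocessor):
    """Hypothetical preprocessor that strips noisy keys before callbacks run."""

    def __init__(self, keys_to_drop: List[str]):
        self.keys_to_drop = keys_to_drop

    def preprocess(self, results: List[Dict]) -> List[Dict]:
        # Return copies of each worker's result without the dropped keys.
        return [
            {k: v for k, v in result.items() if k not in self.keys_to_drop}
            for result in results
        ]

It could be chained after the averaging preprocessor, e.g. preprocessors=[AverageResultsPreprocessor({"loss": "batch_size"}), DropKeysResultsPreprocessor(["batch_size"])].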
