From 28e84b8f493f54c5a079f47b72a3d3f9d99fe806 Mon Sep 17 00:00:00 2001 From: ilee300a <112217335+ilee300a@users.noreply.github.com> Date: Thu, 27 Oct 2022 16:43:16 -0700 Subject: [PATCH] [AIR] Added Ray Logging to MosaicTrainer (#29620) Added RayLogger to MosaicTrainer to relay all reported information. RayLogger is a subclass of LoggerDestination, just like all other native composer loggers. The information to be logged is given via log_metrics call, which is saved in the RayLogger object. The logger reports the logged information every batch checkpoint and epoch checkpoint. All other composer loggers besides RayLogger loggers are removed from the trainer. Note that because at the moment, the result metrics_dataframe will only include the keys that are reported in the very first report call, to have metrics that are not reported every batch in the final metrics dataframe, the keys should be passed in via 'log_keys' in the trainer_init_config. Co-authored-by: Amog Kamsetty Signed-off-by: ilee300a <112217335+ilee300a@users.noreply.github.com> --- doc/source/custom_directives.py | 2 + .../train/examples/mosaic_cifar10_example.py | 8 +- python/ray/train/mosaic/_mosaic_utils.py | 68 +++++++++ python/ray/train/mosaic/mosaic_trainer.py | 12 +- python/ray/train/tests/test_mosaic_trainer.py | 129 +++++++++++++++++- 5 files changed, 210 insertions(+), 9 deletions(-) create mode 100644 python/ray/train/mosaic/_mosaic_utils.py diff --git a/doc/source/custom_directives.py b/doc/source/custom_directives.py index d95e5c92f5c3..d3f3bb6f547c 100644 --- a/doc/source/custom_directives.py +++ b/doc/source/custom_directives.py @@ -116,6 +116,8 @@ def update_context(app, pagename, templatename, context, doctree): "composer.trainer", "composer.loggers", "composer.loggers.logger_destination", + "composer.core", + "composer.core.state", ] diff --git a/python/ray/train/examples/mosaic_cifar10_example.py b/python/ray/train/examples/mosaic_cifar10_example.py index 71c31cd2ec37..162100b3ed3e 100644 --- a/python/ray/train/examples/mosaic_cifar10_example.py +++ b/python/ray/train/examples/mosaic_cifar10_example.py @@ -19,7 +19,7 @@ def trainer_init_per_worker(config): from composer.models.tasks import ComposerClassifier import composer.optim - BATCH_SIZE = 32 + BATCH_SIZE = 64 # prepare the model for distributed training and wrap with ComposerClassifier for # Composer Trainer compatibility model = torchvision.models.resnet18(num_classes=10) @@ -37,13 +37,13 @@ def trainer_init_per_worker(config): datasets.CIFAR10( data_directory, train=True, download=True, transform=cifar10_transforms ), - list(range(64)), + list(range(BATCH_SIZE * 10)), ) test_dataset = torch.utils.data.Subset( datasets.CIFAR10( data_directory, train=False, download=True, transform=cifar10_transforms ), - list(range(64)), + list(range(BATCH_SIZE * 10)), ) batch_size_per_worker = BATCH_SIZE // session.get_world_size() @@ -82,7 +82,7 @@ def train_mosaic_cifar10(num_workers=2, use_gpu=False): from ray.train.mosaic import MosaicTrainer trainer_init_config = { - "max_duration": "1ep", + "max_duration": "2ep", "algorithms": [LabelSmoothing()], "should_eval": False, } diff --git a/python/ray/train/mosaic/_mosaic_utils.py b/python/ray/train/mosaic/_mosaic_utils.py new file mode 100644 index 000000000000..27bec274a179 --- /dev/null +++ b/python/ray/train/mosaic/_mosaic_utils.py @@ -0,0 +1,68 @@ +from typing import Any, Dict, Optional, List +import torch + +from composer.loggers import Logger +from composer.loggers.logger_destination import LoggerDestination +from composer.core.state import State + +from ray.air import session + + +class RayLogger(LoggerDestination): + """A logger to relay information logged by composer models to ray. + + This logger allows utilizing all necessary logging and logged data handling provided + by the Composer library. All the logged information is saved in the data dictionary + every time a new information is logged, but to reduce unnecessary reporting, the + most up-to-date logged information is reported as metrics every batch checkpoint and + epoch checkpoint (see Composer's Event module for more details). + + Because ray's metric dataframe will not include new keys that is reported after the + very first report call, any logged information with the keys not included in the + first batch checkpoint would not be retrievable after training. In other words, if + the log level is greater than `LogLevel.BATCH` for some data, they would not be + present in `Result.metrics_dataframe`. To allow preserving those information, the + user can provide keys to be always included in the reported data by using `keys` + argument in the constructor. For `MosaicTrainer`, use + `trainer_init_config['log_keys']` to populate these keys. + + Note that in the Event callback functions, we remove unused variables, as this is + practiced in Mosaic's composer library. + + Args: + keys: the key values that will be included in the reported metrics. + """ + + def __init__(self, keys: Optional[List[str]] = None) -> None: + self.data = {} + # report at fit end only if there are additional training batches run after the + # last epoch checkpoint report + self.should_report_fit_end = False + if keys: + for key in keys: + self.data[key] = None + + def log_metrics(self, metrics: Dict[str, Any], step: Optional[int] = None) -> None: + self.data.update(metrics.items()) + for key, val in self.data.items(): + if isinstance(val, torch.Tensor): + self.data[key] = val.item() + + def batch_checkpoint(self, state: State, logger: Logger) -> None: + del logger # unused + self.should_report_fit_end = True + + def epoch_checkpoint(self, state: State, logger: Logger) -> None: + del logger # unused + self.should_report_fit_end = False + session.report(self.data) + + # flush the data + self.data = {} + + def fit_end(self, state: State, logger: Logger) -> None: + # report at close in case the trainer stops in the middle of an epoch. + # this may be double counted with epoch checkpoint. + del logger # unused + if self.should_report_fit_end: + session.report(self.data) diff --git a/python/ray/train/mosaic/mosaic_trainer.py b/python/ray/train/mosaic/mosaic_trainer.py index bf4ffdd2c630..8d8481004d61 100644 --- a/python/ray/train/mosaic/mosaic_trainer.py +++ b/python/ray/train/mosaic/mosaic_trainer.py @@ -9,6 +9,7 @@ from ray.air import session from ray.air.checkpoint import Checkpoint from ray.air.config import DatasetConfig, RunConfig, ScalingConfig +from ray.train.mosaic._mosaic_utils import RayLogger from ray.train.torch import TorchConfig, TorchTrainer from ray.train.trainer import GenDataset from ray.util import PublicAPI @@ -207,16 +208,23 @@ def _mosaic_train_loop_per_worker(config): os.environ["WORLD_SIZE"] = str(session.get_world_size()) os.environ["LOCAL_RANK"] = str(session.get_local_rank()) + # Replace Composer's Loggers with RayLogger + ray_logger = RayLogger(keys=config.pop("log_keys", [])) + # initialize Composer trainer - config["progress_bar"] = False trainer: Trainer = trainer_init_per_worker(config) - # Remove Composer's Loggers + # Remove Composer's Loggers if there are any added in the trainer_init_per_worker + # this removes the logging part of the loggers filtered_callbacks = list() for callback in trainer.state.callbacks: if not isinstance(callback, LoggerDestination): filtered_callbacks.append(callback) + filtered_callbacks.append(ray_logger) trainer.state.callbacks = filtered_callbacks + # this prevents data to be routed to all the Composer Loggers + trainer.logger.destinations = (ray_logger,) + # call the trainer trainer.fit() diff --git a/python/ray/train/tests/test_mosaic_trainer.py b/python/ray/train/tests/test_mosaic_trainer.py index 6110e3c49acc..df0f5c6b0c39 100644 --- a/python/ray/train/tests/test_mosaic_trainer.py +++ b/python/ray/train/tests/test_mosaic_trainer.py @@ -71,7 +71,7 @@ def trainer_init_per_worker(config): weight_decay=2.0e-3, ) - if config.pop("eval", False): + if config.pop("should_eval", False): config["eval_dataloader"] = evaluator return composer.trainer.Trainer( @@ -85,9 +85,17 @@ def trainer_init_per_worker(config): def test_mosaic_cifar10(ray_start_4_cpus): from ray.train.examples.mosaic_cifar10_example import train_mosaic_cifar10 - _ = train_mosaic_cifar10() + result = train_mosaic_cifar10().metrics_dataframe - # TODO : add asserts once reporting has been integrated + # check the max epoch value + assert result["epoch"][result.index[-1]] == 1 + + # check train_iterations + assert result["_training_iteration"][result.index[-1]] == 2 + + # check metrics/train/Accuracy has increased + acc = list(result["metrics/train/Accuracy"]) + assert acc[-1] > acc[0] def test_init_errors(ray_start_4_cpus): @@ -149,6 +157,10 @@ class DummyCallback(Callback): def fit_start(self, state: State, logger: Logger) -> None: raise ValueError("Composer Callback object exists.") + class DummyMonitorCallback(Callback): + def fit_start(self, state: State, logger: Logger) -> None: + logger.log_metrics({"dummy_callback": "test"}) + # DummyLogger should not throw an error since it should be removed before `fit` call trainer_init_config = { "max_duration": "1ep", @@ -175,6 +187,117 @@ def fit_start(self, state: State, logger: Logger) -> None: trainer.fit() assert e == "Composer Callback object exists." + trainer_init_config["callbacks"] = DummyMonitorCallback() + trainer = MosaicTrainer( + trainer_init_per_worker=trainer_init_per_worker, + trainer_init_config=trainer_init_config, + scaling_config=scaling_config, + ) + + result = trainer.fit() + + assert "dummy_callback" in result.metrics + assert result.metrics["dummy_callback"] == "test" + + +def test_log_count(ray_start_4_cpus): + from ray.train.mosaic import MosaicTrainer + + trainer_init_config = { + "max_duration": "1ep", + "should_eval": False, + } + + trainer = MosaicTrainer( + trainer_init_per_worker=trainer_init_per_worker, + trainer_init_config=trainer_init_config, + scaling_config=scaling_config, + ) + + result = trainer.fit() + + assert len(result.metrics_dataframe) == 1 + + trainer_init_config["max_duration"] = "1ba" + + trainer = MosaicTrainer( + trainer_init_per_worker=trainer_init_per_worker, + trainer_init_config=trainer_init_config, + scaling_config=scaling_config, + ) + + result = trainer.fit() + + assert len(result.metrics_dataframe) == 1 + + +def test_metrics_key(ray_start_4_cpus): + from ray.train.mosaic import MosaicTrainer + + """Tests if `log_keys` defined in `trianer_init_config` appears in result + metrics_dataframe. + """ + trainer_init_config = { + "max_duration": "1ep", + "should_eval": True, + "log_keys": ["metrics/my_evaluator/Accuracy"], + } + + trainer = MosaicTrainer( + trainer_init_per_worker=trainer_init_per_worker, + trainer_init_config=trainer_init_config, + scaling_config=scaling_config, + ) + + result = trainer.fit() + + # check if the passed in log key exists + assert "metrics/my_evaluator/Accuracy" in result.metrics_dataframe.columns + + +def test_monitor_callbacks(ray_start_4_cpus): + from ray.train.mosaic import MosaicTrainer + + # Test Callbacks involving logging (SpeedMonitor, LRMonitor) + from composer.callbacks import SpeedMonitor, LRMonitor, GradMonitor + + trainer_init_config = { + "max_duration": "1ep", + "should_eval": True, + } + trainer_init_config["log_keys"] = [ + "grad_l2_norm/step", + ] + trainer_init_config["callbacks"] = [ + SpeedMonitor(window_size=3), + LRMonitor(), + GradMonitor(), + ] + + trainer = MosaicTrainer( + trainer_init_per_worker=trainer_init_per_worker, + trainer_init_config=trainer_init_config, + scaling_config=scaling_config, + ) + + result = trainer.fit() + + assert len(result.metrics_dataframe) == 1 + + metrics_columns = result.metrics_dataframe.columns + columns_to_check = [ + "wall_clock/train", + "wall_clock/val", + "wall_clock/total", + "lr-DecoupledSGDW/group0", + "grad_l2_norm/step", + ] + for column in columns_to_check: + assert column in metrics_columns, column + " is not found" + assert result.metrics_dataframe[column].isnull().sum() == 0, ( + column + " column has a null value" + ) + if __name__ == "__main__": import sys