From 3361d4e491ccfa7869950cbe75de1edecb6f67d1 Mon Sep 17 00:00:00 2001 From: Justin Yu Date: Fri, 16 Feb 2024 11:28:26 -0800 Subject: [PATCH 1/6] FIX: don't pickle self along with train trainable Signed-off-by: Justin Yu --- python/ray/train/base_trainer.py | 78 +++++++++++++++++++------------- 1 file changed, 46 insertions(+), 32 deletions(-) diff --git a/python/ray/train/base_trainer.py b/python/ray/train/base_trainer.py index 0760208d72a9..c5f8dc8ee490 100644 --- a/python/ray/train/base_trainer.py +++ b/python/ray/train/base_trainer.py @@ -5,6 +5,7 @@ import logging import os import warnings +from functools import partial from pathlib import Path from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Type, Union @@ -68,6 +69,40 @@ class TrainingFailedError(RuntimeError): ) +def train_coordinator_fn( + config: dict, trainer_cls: Type["BaseTrainer"], metadata: dict +): + """This is the function that defines the logic of the Ray Train coordinator. + This is responsible for setting up a remote instance of the `trainer_cls` + (a different instance than the one calling `trainer.fit` on the driver!) + and running the training loop. + """ + assert metadata is not None, metadata + # Propagate user metadata from the Trainer constructor. + _get_session().metadata = metadata + + # config already contains merged values. + # Instantiate new Trainer in Trainable. + trainer = trainer_cls(**config) + + # Get the checkpoint from Tune and pass it to workers later on. + checkpoint = ray.train.get_checkpoint() + if checkpoint: + # Set `starting_checkpoint` for auto-recovery fault-tolerance + # as well as manual restoration. + trainer.starting_checkpoint = checkpoint + # else: Train will restore from the user-provided + # `resume_from_checkpoint` == `starting_checkpoint`. + + # Evaluate datasets if they are wrapped in a factory. + trainer.datasets = { + k: d() if callable(d) else d for k, d in trainer.datasets.items() + } + + trainer.setup() + trainer.training_loop() + + @DeveloperAPI class BaseTrainer(abc.ABC): """Defines interface for distributed training on Ray. @@ -656,38 +691,17 @@ def _generate_trainable_cls(self) -> Type["Trainable"]: scaling_config = self.scaling_config metadata = self.metadata - def train_func(config): - assert metadata is not None, metadata - # Propagate user metadata from the Trainer constructor. - _get_session().metadata = metadata - - # config already contains merged values. - # Instantiate new Trainer in Trainable. - trainer = trainer_cls(**config) - - # Get the checkpoint from Tune and pass it to workers later on. - checkpoint = ray.train.get_checkpoint() - if checkpoint: - # Set `starting_checkpoint` for auto-recovery fault-tolerance - # as well as manual restoration. - trainer.starting_checkpoint = checkpoint - # else: Train will restore from the user-provided - # `resume_from_checkpoint` == `starting_checkpoint`. - - # Evaluate datasets if they are wrapped in a factory. - trainer.datasets = { - k: d() if callable(d) else d for k, d in self.datasets.items() - } - - trainer.setup() - trainer.training_loop() - + # Create a local copy of the training function to avoid modifying attributes + # of the globally accessible function. + _train_coordinator_fn = copy.copy(train_coordinator_fn) # Change the name of the training function to match the name of the Trainer # class. This will mean the Tune trial name will match the name of Trainer on # stdout messages and the results directory. 
- train_func.__name__ = trainer_cls.__name__ + _train_coordinator_fn.__name__ = trainer_cls.__name__ - trainable_cls = wrap_function(train_func) + trainable_cls = wrap_function( + partial(_train_coordinator_fn, trainer_cls=trainer_cls, metadata=metadata) + ) has_base_dataset = bool(self.datasets) if has_base_dataset: from ray.data.context import DataContext @@ -723,10 +737,10 @@ def setup(self, config, **kwargs): merged_scaling_config = self._merged_config.get("scaling_config") if isinstance(merged_scaling_config, dict): merged_scaling_config = ScalingConfig(**merged_scaling_config) - self._merged_config[ - "scaling_config" - ] = self._reconcile_scaling_config_with_trial_resources( - merged_scaling_config + self._merged_config["scaling_config"] = ( + self._reconcile_scaling_config_with_trial_resources( + merged_scaling_config + ) ) if self.has_base_dataset(): # Set the DataContext on the Trainer actor to the DataContext From ab4e8283f3d4ef326492c13a9358cb9a6c685bda Mon Sep 17 00:00:00 2001 From: Justin Yu Date: Fri, 16 Feb 2024 11:33:59 -0800 Subject: [PATCH 2/6] add test Signed-off-by: Justin Yu --- python/ray/train/tests/test_base_trainer.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/python/ray/train/tests/test_base_trainer.py b/python/ray/train/tests/test_base_trainer.py index 03a61db7629b..900e7c337b5c 100644 --- a/python/ray/train/tests/test_base_trainer.py +++ b/python/ray/train/tests/test_base_trainer.py @@ -1,6 +1,7 @@ import logging import tempfile +import numpy as np import pytest import ray @@ -187,6 +188,18 @@ def training_loop(self): trainer.fit() +def test_large_params(ray_start_4_cpus): + """Tests that large params are not serialized with the trainer actor + and are instead put into the object store separately.""" + huge_array = np.zeros(shape=int(1e8)) + + def training_loop(self): + huge_array + + trainer = DummyTrainer(training_loop) + trainer.fit() + + if __name__ == "__main__": import sys From fb1a3c31d2da2a9f7cec3ef6f15a13cc5a588d23 Mon Sep 17 00:00:00 2001 From: Justin Yu Date: Fri, 16 Feb 2024 11:34:55 -0800 Subject: [PATCH 3/6] fix lint Signed-off-by: Justin Yu --- python/ray/train/base_trainer.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/ray/train/base_trainer.py b/python/ray/train/base_trainer.py index c5f8dc8ee490..3704dbbf98b2 100644 --- a/python/ray/train/base_trainer.py +++ b/python/ray/train/base_trainer.py @@ -737,10 +737,10 @@ def setup(self, config, **kwargs): merged_scaling_config = self._merged_config.get("scaling_config") if isinstance(merged_scaling_config, dict): merged_scaling_config = ScalingConfig(**merged_scaling_config) - self._merged_config["scaling_config"] = ( - self._reconcile_scaling_config_with_trial_resources( - merged_scaling_config - ) + self._merged_config[ + "scaling_config" + ] = self._reconcile_scaling_config_with_trial_resources( + merged_scaling_config ) if self.has_base_dataset(): # Set the DataContext on the Trainer actor to the DataContext From fd75ea3c5d50e5333bfc2ed4d484111056ada8a7 Mon Sep 17 00:00:00 2001 From: Justin Yu Date: Fri, 16 Feb 2024 13:20:09 -0800 Subject: [PATCH 4/6] fix annotation lint error Signed-off-by: Justin Yu --- python/ray/train/base_trainer.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/python/ray/train/base_trainer.py b/python/ray/train/base_trainer.py index 3704dbbf98b2..4cffa05abc60 100644 --- a/python/ray/train/base_trainer.py +++ b/python/ray/train/base_trainer.py @@ -69,7 +69,7 @@ class 
TrainingFailedError(RuntimeError): ) -def train_coordinator_fn( +def _train_coordinator_fn( config: dict, trainer_cls: Type["BaseTrainer"], metadata: dict ): """This is the function that defines the logic of the Ray Train coordinator. @@ -693,14 +693,14 @@ def _generate_trainable_cls(self) -> Type["Trainable"]: # Create a local copy of the training function to avoid modifying attributes # of the globally accessible function. - _train_coordinator_fn = copy.copy(train_coordinator_fn) + train_coordinator_fn = copy.copy(_train_coordinator_fn) # Change the name of the training function to match the name of the Trainer # class. This will mean the Tune trial name will match the name of Trainer on # stdout messages and the results directory. - _train_coordinator_fn.__name__ = trainer_cls.__name__ + train_coordinator_fn.__name__ = trainer_cls.__name__ trainable_cls = wrap_function( - partial(_train_coordinator_fn, trainer_cls=trainer_cls, metadata=metadata) + partial(train_coordinator_fn, trainer_cls=trainer_cls, metadata=metadata) ) has_base_dataset = bool(self.datasets) if has_base_dataset: @@ -737,10 +737,10 @@ def setup(self, config, **kwargs): merged_scaling_config = self._merged_config.get("scaling_config") if isinstance(merged_scaling_config, dict): merged_scaling_config = ScalingConfig(**merged_scaling_config) - self._merged_config[ - "scaling_config" - ] = self._reconcile_scaling_config_with_trial_resources( - merged_scaling_config + self._merged_config["scaling_config"] = ( + self._reconcile_scaling_config_with_trial_resources( + merged_scaling_config + ) ) if self.has_base_dataset(): # Set the DataContext on the Trainer actor to the DataContext From d73e35e1d9fece1ecc63e084e31d011ebe60516d Mon Sep 17 00:00:00 2001 From: Justin Yu Date: Fri, 16 Feb 2024 13:21:44 -0800 Subject: [PATCH 5/6] fix lint Signed-off-by: Justin Yu --- python/ray/train/base_trainer.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/ray/train/base_trainer.py b/python/ray/train/base_trainer.py index 4cffa05abc60..bbd8a26c5674 100644 --- a/python/ray/train/base_trainer.py +++ b/python/ray/train/base_trainer.py @@ -737,10 +737,10 @@ def setup(self, config, **kwargs): merged_scaling_config = self._merged_config.get("scaling_config") if isinstance(merged_scaling_config, dict): merged_scaling_config = ScalingConfig(**merged_scaling_config) - self._merged_config["scaling_config"] = ( - self._reconcile_scaling_config_with_trial_resources( - merged_scaling_config - ) + self._merged_config[ + "scaling_config" + ] = self._reconcile_scaling_config_with_trial_resources( + merged_scaling_config ) if self.has_base_dataset(): # Set the DataContext on the Trainer actor to the DataContext From 4cb6d3648d009e285fccb3aa77299d27016d9bd8 Mon Sep 17 00:00:00 2001 From: Justin Yu Date: Fri, 16 Feb 2024 14:10:40 -0800 Subject: [PATCH 6/6] make sure the partial doesn't mess up the __name__ Signed-off-by: Justin Yu --- python/ray/train/base_trainer.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/python/ray/train/base_trainer.py b/python/ray/train/base_trainer.py index bbd8a26c5674..b011a39f8293 100644 --- a/python/ray/train/base_trainer.py +++ b/python/ray/train/base_trainer.py @@ -691,17 +691,15 @@ def _generate_trainable_cls(self) -> Type["Trainable"]: scaling_config = self.scaling_config metadata = self.metadata - # Create a local copy of the training function to avoid modifying attributes - # of the globally accessible function. 
- train_coordinator_fn = copy.copy(_train_coordinator_fn) + train_coordinator_fn = partial( + _train_coordinator_fn, trainer_cls=trainer_cls, metadata=metadata + ) # Change the name of the training function to match the name of the Trainer # class. This will mean the Tune trial name will match the name of Trainer on # stdout messages and the results directory. train_coordinator_fn.__name__ = trainer_cls.__name__ - trainable_cls = wrap_function( - partial(train_coordinator_fn, trainer_cls=trainer_cls, metadata=metadata) - ) + trainable_cls = wrap_function(train_coordinator_fn) has_base_dataset = bool(self.datasets) if has_base_dataset: from ray.data.context import DataContext
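
The pattern behind the series, in a minimal standalone sketch: the old `train_func` was a closure defined inside `_generate_trainable_cls`, so it captured `self`, and serializing the function pulled the entire trainer (and anything reachable from it, such as a user `training_loop` that closes over a large array, as exercised by `test_large_params` in patch 2) into the pickled trainable. A module-level coordinator bound with `functools.partial` carries only the class and the metadata; large values travel through `config` and the object store instead. The sketch below is illustrative, not Ray code: `Trainer`, `payload`, `closure_trainable`, and `_coordinator_fn` are hypothetical stand-ins, and it assumes `cloudpickle` is installed (Ray vendors it as `ray.cloudpickle`).

    import cloudpickle  # stand-in for Ray's function serializer
    from functools import partial

    class Trainer:  # hypothetical stand-in, not ray.train.BaseTrainer
        def __init__(self, payload=b""):
            self.payload = payload  # imagine a large user-provided object

    def closure_trainable(trainer):
        # Old pattern: the inner function captures `trainer`, so pickling
        # the function also pickles the payload.
        def train_func(config):
            return len(trainer.payload)
        return train_func

    def _coordinator_fn(config, trainer_cls):
        # New pattern: a module-level function; the trainer instance is
        # rebuilt inside the trainable from `config`.
        trainer = trainer_cls(**config)
        return len(trainer.payload)

    big_trainer = Trainer(payload=bytes(10**7))  # ~10 MB payload

    old = closure_trainable(big_trainer)
    new = partial(_coordinator_fn, trainer_cls=Trainer)
    new.__name__ = Trainer.__name__  # partial objects accept attribute assignment

    print(len(cloudpickle.dumps(old)))  # roughly 10 MB: the payload rides along
    print(len(cloudpickle.dumps(new)))  # a few hundred bytes

This also shows why the last commit matters: a `functools.partial` object has no `__name__` of its own, so patch 6 builds the partial first and then assigns `__name__` on it directly, which keeps the Tune trial name matching the Trainer class.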