diff --git a/python/ray/train/base_trainer.py b/python/ray/train/base_trainer.py
index c646a7704339..6be7eb249cdf 100644
--- a/python/ray/train/base_trainer.py
+++ b/python/ray/train/base_trainer.py
@@ -362,12 +362,16 @@ def fit(self) -> Result:
             raise TrainingFailedError from e
         return result
 
-    def as_trainable(self) -> Type["Trainable"]:
-        """Convert self to a ``tune.Trainable`` class."""
+    def _generate_trainable_cls(self) -> Type["Trainable"]:
+        """Generate the base Trainable class.
+
+        Returns:
+            A Trainable class to use for training.
+        """
+        from ray.tune.execution.placement_groups import PlacementGroupFactory
         from ray.tune.trainable import wrap_function
 
-        base_config = self._param_dict
         trainer_cls = self.__class__
         scaling_config = self.scaling_config
@@ -414,9 +418,8 @@ def base_scaling_config(cls) -> ScalingConfig:
                 """Returns the unchanged scaling config provided through the Trainer."""
                 return scaling_config
 
-            def __init__(self, *args, **kwargs):
-                super().__init__(*args, **kwargs)
-
+            def setup(self, config, **kwargs):
+                base_config = dict(kwargs)
                 # Create a new config by merging the dicts.
                 # run_config is not a tunable hyperparameter so it does not need to be
                 # merged.
@@ -431,6 +434,7 @@ def __init__(self, *args, **kwargs):
                 ] = self._reconcile_scaling_config_with_trial_resources(
                     merged_scaling_config
                 )
+                super(TrainTrainable, self).setup(config)
 
             def _reconcile_scaling_config_with_trial_resources(
                 self, scaling_config: ScalingConfig
@@ -487,3 +491,13 @@ def default_resource_request(cls, config):
                 return validated_scaling_config.as_placement_group_factory()
 
         return TrainTrainable
+
+    def as_trainable(self) -> Type["Trainable"]:
+        """Convert self to a ``tune.Trainable`` class."""
+        from ray import tune
+
+        base_config = self._param_dict
+        trainable_cls = self._generate_trainable_cls()
+
+        # Wrap with `tune.with_parameters` to handle very large values in base_config
+        return tune.with_parameters(trainable_cls, **base_config)
diff --git a/python/ray/train/gbdt_trainer.py b/python/ray/train/gbdt_trainer.py
index 8375f1919cb7..9cf6cb8074f0 100644
--- a/python/ray/train/gbdt_trainer.py
+++ b/python/ray/train/gbdt_trainer.py
@@ -270,8 +270,8 @@ def training_loop(self) -> None:
         if checkpoint_at_end:
             self._checkpoint_at_end(model, evals_result)
 
-    def as_trainable(self) -> Type[Trainable]:
-        trainable_cls = super().as_trainable()
+    def _generate_trainable_cls(self) -> Type["Trainable"]:
+        trainable_cls = super()._generate_trainable_cls()
         trainer_cls = self.__class__
         scaling_config = self.scaling_config
         ray_params_cls = self._ray_params_cls
diff --git a/python/ray/train/huggingface/huggingface_trainer.py b/python/ray/train/huggingface/huggingface_trainer.py
index 2df3b32f6f70..400505684a28 100644
--- a/python/ray/train/huggingface/huggingface_trainer.py
+++ b/python/ray/train/huggingface/huggingface_trainer.py
@@ -432,7 +432,7 @@ def setup(self) -> None:
             )
         )
 
-    def as_trainable(self) -> Type[Trainable]:
+    def _generate_trainable_cls(self) -> Type["Trainable"]:
         original_param_dict = self._param_dict.copy()
         resume_from_checkpoint: Optional[Checkpoint] = self._param_dict.get(
             "resume_from_checkpoint", None
@@ -444,7 +444,7 @@ def as_trainable(self) -> Type[Trainable]:
                 resume_from_checkpoint
             )
         try:
-            ret = super().as_trainable()
+            ret = super()._generate_trainable_cls()
         finally:
             self._param_dict = original_param_dict
         return ret
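The refactor above splits `as_trainable` in two: `_generate_trainable_cls` builds the `TrainTrainable` class, while `as_trainable` wraps it with `tune.with_parameters`, which places the values in `base_config` into the Ray object store instead of serializing them into the trainable itself. Moving the config merge from `TrainTrainable.__init__` into `setup(self, config, **kwargs)` is what lets `tune.with_parameters` deliver those values as keyword arguments. A minimal sketch of the same pattern in user code follows; `MyTrainable` and `big_array` are illustrative names, not part of this change:

    import numpy as np

    import ray
    from ray import tune
    from ray.tune import Trainable


    class MyTrainable(Trainable):
        def setup(self, config, big_array=None):
            # `big_array` is delivered by tune.with_parameters: it is stored
            # once in the Ray object store and fetched here, rather than
            # being serialized into the trainable with the rest of `config`.
            self.big_array = big_array

        def step(self):
            return {"mean": float(np.mean(self.big_array))}


    ray.init()
    trainable = tune.with_parameters(MyTrainable, big_array=np.zeros(10**7))
    tune.run(trainable, stop={"training_iteration": 1})
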
diff --git a/python/ray/train/tests/test_base_trainer.py b/python/ray/train/tests/test_base_trainer.py
index 0009487de585..0a7126747a6c 100644
--- a/python/ray/train/tests/test_base_trainer.py
+++ b/python/ray/train/tests/test_base_trainer.py
@@ -5,11 +5,13 @@
 from contextlib import redirect_stderr
 from unittest.mock import patch
 
+import numpy as np
 import pytest
 
 import ray
 from ray import tune
 from ray.air import session
+from ray.air.checkpoint import Checkpoint
 from ray.air.constants import MAX_REPR_LENGTH
 from ray.data.preprocessor import Preprocessor
 from ray.tune.impl import tuner_internal
@@ -359,7 +361,7 @@ def test_trainable_name_is_overriden_gbdt_trainer(ray_start_4_cpus):
         _is_trainable_name_overriden(trainer)
 
 
-def test_repr():
+def test_repr(ray_start_4_cpus):
     def training_loop(self):
         pass
 
@@ -376,6 +378,19 @@ def training_loop(self):
     assert len(representation) < MAX_REPR_LENGTH
 
 
+def test_large_params(ray_start_4_cpus):
+    """Tests that large arguments can be serialized by the Trainer."""
+    array_size = int(1e8)
+
+    def training_loop(self):
+        checkpoint = self.resume_from_checkpoint.to_dict()["ckpt"]
+        assert len(checkpoint) == array_size
+
+    checkpoint = Checkpoint.from_dict({"ckpt": np.zeros(shape=array_size)})
+    trainer = DummyTrainer(training_loop, resume_from_checkpoint=checkpoint)
+    trainer.fit()
+
+
 if __name__ == "__main__":
     import sys
diff --git a/python/ray/tune/trainable/util.py b/python/ray/tune/trainable/util.py
index b591f95f0caf..6bcb7269ae5d 100644
--- a/python/ray/tune/trainable/util.py
+++ b/python/ray/tune/trainable/util.py
@@ -339,6 +339,11 @@ def setup(self, config):
                     setup_kwargs[k] = parameter_registry.get(prefix + k)
                 super(_Inner, self).setup(config, **setup_kwargs)
 
+            # Workaround for actor name not being logged correctly
+            # if __repr__ is not directly defined in a class.
+            def __repr__(self):
+                return super().__repr__()
+
         _Inner.__name__ = trainable_name
         return _Inner
     else:
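On the `util.py` change: per the comment in the diff, the actor name is not logged correctly unless `__repr__` is defined directly on the class, and reassigning `_Inner.__name__` alone does not put `__repr__` into the class's own `__dict__`. The no-op override ensures it is there. A minimal sketch of the Python semantics involved, with illustrative class names:

    class Base:
        pass


    class Child(Base):
        # Without this no-op override, "__repr__" would not appear in
        # Child.__dict__ (it would only be inherited from object), so a
        # check for a directly defined __repr__ would come up empty.
        def __repr__(self):
            return super().__repr__()


    # Inherited __repr__ lives on object, not on Base itself.
    assert "__repr__" not in Base.__dict__
    # The explicit override is registered in the class's own namespace.
    assert "__repr__" in Child.__dict__
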