diff --git a/python/ray/train/base_trainer.py b/python/ray/train/base_trainer.py
index 0760208d72a9..b011a39f8293 100644
--- a/python/ray/train/base_trainer.py
+++ b/python/ray/train/base_trainer.py
@@ -5,6 +5,7 @@
 import logging
 import os
 import warnings
+from functools import partial
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Type, Union
 
@@ -68,6 +69,40 @@ class TrainingFailedError(RuntimeError):
     )
 
 
+def _train_coordinator_fn(
+    config: dict, trainer_cls: Type["BaseTrainer"], metadata: dict
+):
+    """This function defines the logic of the Ray Train coordinator.
+    It is responsible for setting up a remote instance of `trainer_cls`
+    (a different instance than the one `trainer.fit` is called on at the driver!)
+    and running the training loop.
+    """
+    assert metadata is not None, metadata
+    # Propagate user metadata from the Trainer constructor.
+    _get_session().metadata = metadata
+
+    # config already contains merged values.
+    # Instantiate new Trainer in Trainable.
+    trainer = trainer_cls(**config)
+
+    # Get the checkpoint from Tune and pass it to workers later on.
+    checkpoint = ray.train.get_checkpoint()
+    if checkpoint:
+        # Set `starting_checkpoint` for auto-recovery fault-tolerance
+        # as well as manual restoration.
+        trainer.starting_checkpoint = checkpoint
+    # else: Train will restore from the user-provided
+    # `resume_from_checkpoint` == `starting_checkpoint`.
+
+    # Evaluate datasets if they are wrapped in a factory.
+    trainer.datasets = {
+        k: d() if callable(d) else d for k, d in trainer.datasets.items()
+    }
+
+    trainer.setup()
+    trainer.training_loop()
+
+
 @DeveloperAPI
 class BaseTrainer(abc.ABC):
     """Defines interface for distributed training on Ray.
@@ -656,38 +691,15 @@ def _generate_trainable_cls(self) -> Type["Trainable"]:
         scaling_config = self.scaling_config
         metadata = self.metadata
 
-        def train_func(config):
-            assert metadata is not None, metadata
-            # Propagate user metadata from the Trainer constructor.
-            _get_session().metadata = metadata
-
-            # config already contains merged values.
-            # Instantiate new Trainer in Trainable.
-            trainer = trainer_cls(**config)
-
-            # Get the checkpoint from Tune and pass it to workers later on.
-            checkpoint = ray.train.get_checkpoint()
-            if checkpoint:
-                # Set `starting_checkpoint` for auto-recovery fault-tolerance
-                # as well as manual restoration.
-                trainer.starting_checkpoint = checkpoint
-            # else: Train will restore from the user-provided
-            # `resume_from_checkpoint` == `starting_checkpoint`.
-
-            # Evaluate datasets if they are wrapped in a factory.
-            trainer.datasets = {
-                k: d() if callable(d) else d for k, d in self.datasets.items()
-            }
-
-            trainer.setup()
-            trainer.training_loop()
-
+        train_coordinator_fn = partial(
+            _train_coordinator_fn, trainer_cls=trainer_cls, metadata=metadata
+        )
         # Change the name of the training function to match the name of the Trainer
         # class. This will mean the Tune trial name will match the name of Trainer on
        # stdout messages and the results directory.
-        train_func.__name__ = trainer_cls.__name__
+        train_coordinator_fn.__name__ = trainer_cls.__name__
 
-        trainable_cls = wrap_function(train_func)
+        trainable_cls = wrap_function(train_coordinator_fn)
         has_base_dataset = bool(self.datasets)
         if has_base_dataset:
             from ray.data.context import DataContext
diff --git a/python/ray/train/tests/test_base_trainer.py b/python/ray/train/tests/test_base_trainer.py
index 03a61db7629b..900e7c337b5c 100644
--- a/python/ray/train/tests/test_base_trainer.py
+++ b/python/ray/train/tests/test_base_trainer.py
@@ -1,6 +1,7 @@
 import logging
 import tempfile
 
+import numpy as np
 import pytest
 
 import ray
@@ -187,6 +188,18 @@ def training_loop(self):
         trainer.fit()
 
 
+def test_large_params(ray_start_4_cpus):
+    """Tests that large params are not serialized with the trainer actor
+    and are instead put into the object store separately."""
+    huge_array = np.zeros(shape=int(1e8))
+
+    def training_loop(self):
+        huge_array  # Capture the large array in the training loop's closure.
+
+    trainer = DummyTrainer(training_loop)
+    trainer.fit()
+
+
 if __name__ == "__main__":
     import sys
 
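Note (an illustrative sketch, not part of the patch): the refactor relies on the fact that `functools.partial` objects accept attribute assignment, so the existing `__name__` trick for naming Tune trials keeps working after the closure is replaced with a module-level function. A minimal standalone sketch, with a hypothetical `MyTrainer` standing in for a real `BaseTrainer` subclass:

    from functools import partial

    class MyTrainer:  # hypothetical stand-in for a BaseTrainer subclass
        pass

    def _train_coordinator_fn(config, trainer_cls, metadata):
        # The real coordinator instantiates `trainer_cls(**config)` and runs
        # the training loop; this stub only shows the bound arguments.
        print(trainer_cls.__name__, metadata, config)

    # Bind everything except `config`, which Tune supplies per trial.
    fn = partial(_train_coordinator_fn, trainer_cls=MyTrainer, metadata={})

    # `functools.partial` instances carry a `__dict__`, so `__name__` can be
    # assigned just as on a plain function, keeping the trial name equal to
    # the Trainer class name.
    fn.__name__ = MyTrainer.__name__
    fn({"lr": 1e-3})  # -> MyTrainer {} {'lr': 0.001}

A side benefit of the module-level function plus `partial` is that the coordinator no longer closes over `self`, so the driver-side trainer and anything it holds are not dragged into the serialized trainable.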
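On `test_large_params` (again a sketch of the underlying idea, not Train internals): a `float64` array of `1e8` elements is roughly 800 MB, far more than Ray will serialize inline with a function or actor definition, so the test only passes if large captured objects are placed in the object store instead. The general Ray pattern the test exercises is to `ray.put` the data once and hand around the `ObjectRef`:

    import numpy as np
    import ray

    ray.init()

    # ~800 MB of float64 zeros; too large to pickle inline with a task.
    big_ref = ray.put(np.zeros(int(1e8)))

    @ray.remote
    def uses_big(arr):
        # Ray resolves the ObjectRef argument to the array before the call,
        # reading it from shared memory rather than copying it per task.
        return arr.shape

    print(ray.get(uses_big.remote(big_ref)))  # (100000000,)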