[train] Fix regression where large Trainer attributes get serialized along with actor class #43234

Merged
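In short: the coordinator logic used to live in a closure defined inside `BaseTrainer._generate_trainable_cls`, and a stray `self` reference pulled the entire Trainer (large datasets and all) into the serialized actor class. The fix moves the logic to the module-level `_train_coordinator_fn` and binds only `trainer_cls` and `metadata` via `functools.partial`. A minimal sketch of the mechanism, with illustrative names rather than this PR's real classes:

from functools import partial

import numpy as np
from ray import cloudpickle


def _coordinator_fn(config, trainer_cls):
    # Module-level function: pickling it captures no instance state.
    return trainer_cls, config


class SketchTrainer:
    def __init__(self):
        self.big_attribute = np.zeros(int(1e7))  # ~80 MB of state

    def closure_version(self):
        def train_func(config):
            # Touching `self` makes the whole SketchTrainer part of the
            # closure, so cloudpickle serializes `big_attribute` too.
            return type(self), config

        return train_func

    def partial_version(self):
        # Binds only the class object, never the instance.
        return partial(_coordinator_fn, trainer_cls=type(self))


trainer = SketchTrainer()
print(len(cloudpickle.dumps(trainer.closure_version())))  # tens of megabytes
print(len(cloudpickle.dumps(trainer.partial_version())))  # a few kilobytes
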
68 changes: 40 additions & 28 deletions python/ray/train/base_trainer.py
@@ -5,6 +5,7 @@
import logging
import os
import warnings
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Type, Union

@@ -68,6 +69,40 @@ class TrainingFailedError(RuntimeError):
)


def _train_coordinator_fn(
    config: dict, trainer_cls: Type["BaseTrainer"], metadata: dict
):
    """Defines the logic of the Ray Train coordinator.

    The coordinator is responsible for setting up a remote instance of the
    `trainer_cls` (a different instance than the one `trainer.fit()` was
    called on at the driver!) and running the training loop.
    """
    assert metadata is not None, metadata
    # Propagate user metadata from the Trainer constructor.
    _get_session().metadata = metadata

    # `config` already contains merged values.
    # Instantiate a new Trainer in the Trainable.
    trainer = trainer_cls(**config)

    # Get the checkpoint from Tune and pass it to workers later on.
    checkpoint = ray.train.get_checkpoint()
    if checkpoint:
        # Set `starting_checkpoint` for auto-recovery fault-tolerance
        # as well as manual restoration.
        trainer.starting_checkpoint = checkpoint
    # else: Train will restore from the user-provided
    # `resume_from_checkpoint` == `starting_checkpoint`.

    # Evaluate datasets if they are wrapped in a factory.
    trainer.datasets = {
        k: d() if callable(d) else d for k, d in trainer.datasets.items()
    }

    trainer.setup()
    trainer.training_loop()

@DeveloperAPI
class BaseTrainer(abc.ABC):
"""Defines interface for distributed training on Ray.
@@ -656,38 +691,15 @@ def _generate_trainable_cls(self) -> Type["Trainable"]:
        scaling_config = self.scaling_config
        metadata = self.metadata

        def train_func(config):
            assert metadata is not None, metadata
            # Propagate user metadata from the Trainer constructor.
            _get_session().metadata = metadata

            # config already contains merged values.
            # Instantiate new Trainer in Trainable.
            trainer = trainer_cls(**config)

            # Get the checkpoint from Tune and pass it to workers later on.
            checkpoint = ray.train.get_checkpoint()
            if checkpoint:
                # Set `starting_checkpoint` for auto-recovery fault-tolerance
                # as well as manual restoration.
                trainer.starting_checkpoint = checkpoint
            # else: Train will restore from the user-provided
            # `resume_from_checkpoint` == `starting_checkpoint`.

            # Evaluate datasets if they are wrapped in a factory.
            trainer.datasets = {
                k: d() if callable(d) else d for k, d in self.datasets.items()

[Review thread on the line above]
Member: Is this the `self` reference that caused serialization of the huge BaseTrainer object?
Contributor (author): Yep, that's it!

            }

            trainer.setup()
            trainer.training_loop()

        train_coordinator_fn = partial(
            _train_coordinator_fn, trainer_cls=trainer_cls, metadata=metadata
        )
        # Change the name of the training function to match the name of the
        # Trainer class, so that the Tune trial name matches the Trainer on
        # stdout messages and in the results directory.
        train_func.__name__ = trainer_cls.__name__
        train_coordinator_fn.__name__ = trainer_cls.__name__

        trainable_cls = wrap_function(train_func)
        trainable_cls = wrap_function(train_coordinator_fn)
        has_base_dataset = bool(self.datasets)
        if has_base_dataset:
            from ray.data.context import DataContext
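
One detail worth calling out in the hunk above: unlike a plain function, a `functools.partial` object has no `__name__` by default, but it does accept having one assigned, which is what keeps the Tune trial name matched to the Trainer class. A quick standalone check (the names here are illustrative):

from functools import partial


def _train_coordinator_fn(config, trainer_cls=None, metadata=None):
    ...


fn = partial(_train_coordinator_fn, trainer_cls=object, metadata={})
assert not hasattr(fn, "__name__")  # partials carry no __name__ by default
fn.__name__ = "SketchTrainer"  # hypothetical Trainer class name
assert fn.__name__ == "SketchTrainer"
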
13 changes: 13 additions & 0 deletions python/ray/train/tests/test_base_trainer.py
@@ -1,6 +1,7 @@
import logging
import tempfile

import numpy as np
import pytest

import ray
@@ -187,6 +188,18 @@ def training_loop(self):
    trainer.fit()


def test_large_params(ray_start_4_cpus):
    """Tests that large params are not serialized with the trainer actor
    and are instead put into the object store separately."""
    huge_array = np.zeros(shape=int(1e8))

    def training_loop(self):
        # Referencing `huge_array` captures it in this closure.
        huge_array

    trainer = DummyTrainer(training_loop)
    trainer.fit()


if __name__ == "__main__":
    import sys

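
The new test reproduces the failure mode directly: an ~800 MB array captured by the user-supplied `training_loop` must not be pickled into the trainer actor class. As general Ray hygiene (independent of this PR), large objects can also be placed in the object store explicitly with `ray.put` so that only a cheap `ObjectRef` travels with each task; a minimal sketch:

import numpy as np
import ray

ray.init()

huge_array = np.zeros(int(1e8))  # ~800 MB

# One copy lands in the shared-memory object store; tasks get an ObjectRef.
array_ref = ray.put(huge_array)


@ray.remote
def consume(arr):
    # Ray resolves the ObjectRef to a read-only, zero-copy numpy view.
    return arr.shape


print(ray.get(consume.remote(array_ref)))  # (100000000,)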