From cbce66de2a8b61eddb3e62580acccb92c8169cd8 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Thu, 27 Oct 2022 20:24:06 +0200 Subject: [PATCH] [AIR] Avoid checkpoint conversion, move encoding logic to checkpoints (#28794) This PR avoids always converting to dictionary when reporting a checkpoint in Train and uses Checkpoints instead of dicts to transfer data. This is a low hanging fruit change for better consistency and performance with non-dict checkpoints. In order to facilitate that, the data encoding logic in Ray Train has been modified. Encoding and decoding is now done in the checkpoint classes. I believe this is the cleanest solution as it is both generic and inherently tied to the checkpoint itself - however, this has the downside of requiring users to use correct checkpoint classes for torch and horovod. In order to maintain backwards compatibility, the checkpoint class is automatically changed in session.py if a torch checkpoint is required (which has extra encoding and decoding logic to deal with serialization issues). Warnings are printed where necessary. Signed-off-by: Antoni Baum Co-authored-by: Kai Fricke Signed-off-by: Weichen Xu --- python/ray/air/_internal/tensorflow_utils.py | 20 +++ python/ray/air/checkpoint.py | 9 +- .../ray/train/_internal/backend_executor.py | 2 +- python/ray/train/_internal/checkpoint.py | 12 +- python/ray/train/_internal/session.py | 51 ++++--- python/ray/train/backend.py | 65 ++++++++- python/ray/train/data_parallel_trainer.py | 6 +- .../horovod/horovod_pytorch_example.py | 9 +- .../examples/pytorch/torch_linear_example.py | 11 +- python/ray/train/horovod/config.py | 58 ++++---- python/ray/train/horovod/horovod_trainer.py | 7 +- .../train/huggingface/_huggingface_utils.py | 5 +- .../train/huggingface/huggingface_trainer.py | 8 ++ python/ray/train/tensorflow/config.py | 12 +- python/ray/train/tests/test_gpu_amp.py | 2 +- .../ray/train/tests/test_gpu_auto_transfer.py | 13 +- .../train/tests/test_huggingface_trainer.py | 8 +- python/ray/train/tests/test_session.py | 18 ++- .../train/tests/test_tensorflow_trainer.py | 36 ++--- python/ray/train/tests/test_torch_trainer.py | 126 ++++++++++++++---- python/ray/train/torch/config.py | 45 ++----- python/ray/train/torch/torch_checkpoint.py | 47 +++++++ python/ray/train/torch/torch_trainer.py | 2 +- python/ray/train/trainer.py | 6 +- 24 files changed, 404 insertions(+), 174 deletions(-) diff --git a/python/ray/air/_internal/tensorflow_utils.py b/python/ray/air/_internal/tensorflow_utils.py index bbb08efc2b83..61d38292e6d9 100644 --- a/python/ray/air/_internal/tensorflow_utils.py +++ b/python/ray/air/_internal/tensorflow_utils.py @@ -64,6 +64,26 @@ def convert_ndarray_batch_to_tf_tensor_batch( return batch +# This is not foolproof, but it's better than nothing +# The place it is used in will be deprecated soon +def contains_tensorflow_object(obj): + if hasattr(obj, "__module__") and ( + "keras" in obj.__module__ or "tensorflow" in obj.__module__ + ): + return True + elif isinstance(obj, dict): + for k, v in obj.items(): + if contains_tensorflow_object(k): + return True + if contains_tensorflow_object(v): + return True + elif isinstance(obj, (list, tuple)): + for v in obj: + if contains_tensorflow_object(v): + return True + return False + + def get_type_spec( schema: Union["pyarrow.lib.Schema", "PandasBlockSchema"], columns: Union[str, List[str]], diff --git a/python/ray/air/checkpoint.py b/python/ray/air/checkpoint.py index 9ef8a88780a4..01d62f04b3dc 100644 --- a/python/ray/air/checkpoint.py +++ 
b/python/ray/air/checkpoint.py @@ -501,6 +501,11 @@ def _get_temporary_checkpoint_dir(self) -> str: ) return os.path.join(tmp_dir_path, checkpoint_dir_name) + def _save_checkpoint_metadata_in_directory(self, path: str) -> None: + checkpoint_metadata_path = os.path.join(path, _CHECKPOINT_METADATA_FILE_NAME) + with open(checkpoint_metadata_path, "wb") as file: + pickle.dump(self._metadata, file) + def _to_directory(self, path: str) -> None: if self._data_dict or self._obj_ref: # This is a object ref or dict @@ -547,9 +552,7 @@ def _to_directory(self, path: str) -> None: f"No valid location found for checkpoint {self}: {self._uri}" ) - checkpoint_metadata_path = os.path.join(path, _CHECKPOINT_METADATA_FILE_NAME) - with open(checkpoint_metadata_path, "wb") as file: - pickle.dump(self._metadata, file) + self._save_checkpoint_metadata_in_directory(path) def to_directory(self, path: Optional[str] = None) -> str: """Write checkpoint data to directory. diff --git a/python/ray/train/_internal/backend_executor.py b/python/ray/train/_internal/backend_executor.py index 855ce43d0f60..e8d82388a920 100644 --- a/python/ray/train/_internal/backend_executor.py +++ b/python/ray/train/_internal/backend_executor.py @@ -346,7 +346,7 @@ def initialize_session( train_func=train_func, dataset_shard=self.dataset_shards[index], checkpoint=checkpoint, - encode_data_fn=self._backend.encode_data, + encode_data_fn=self._backend._encode_data, ) ) diff --git a/python/ray/train/_internal/checkpoint.py b/python/ray/train/_internal/checkpoint.py index 55a11abdd400..6e18fedbff9c 100644 --- a/python/ray/train/_internal/checkpoint.py +++ b/python/ray/train/_internal/checkpoint.py @@ -101,13 +101,15 @@ def _process_checkpoint( """Ray Train entrypoint. Perform all processing for a checkpoint.""" # Get checkpoint from first worker. checkpoint_data = checkpoint_results[0].data + checkpoint_metadata = checkpoint_results[0].metadata or {} - # Decode checkpoint. - checkpoint_data = decode_checkpoint_fn(checkpoint_data) + # TODO(ml-team): Remove once we remove Backend.decode_data + checkpoint_data = decode_checkpoint_fn(checkpoint_data).to_dict() score_attr = self._checkpoint_strategy.checkpoint_score_attribute if ( self._checkpoint_strategy.num_to_keep != 0 + and score_attr not in checkpoint_metadata and score_attr not in checkpoint_data ): raise ValueError( @@ -122,7 +124,11 @@ def _process_checkpoint( dir_or_data=checkpoint_data, checkpoint_id=self._latest_checkpoint_id, storage_mode=CheckpointStorage.MEMORY, - metrics={score_attr: checkpoint_data.get(score_attr, 0.0)}, + metrics={ + score_attr: checkpoint_metadata.get( + score_attr, checkpoint_data.get(score_attr, 0.0) + ) + }, ) self.register_checkpoint(checkpoint=tracked_checkpoint) diff --git a/python/ray/train/_internal/session.py b/python/ray/train/_internal/session.py index 9ab1717615a3..a8a86a594083 100644 --- a/python/ray/train/_internal/session.py +++ b/python/ray/train/_internal/session.py @@ -2,6 +2,7 @@ import logging import platform import queue +import sys import threading import time from dataclasses import dataclass @@ -51,7 +52,8 @@ class TrialInfo: @dataclass class TrainingResult: type: TrainingResultType - data: Dict + data: Union[Dict, Checkpoint] + metadata: Optional[Dict] = None # TODO(xwjiang): This needs a better name. @@ -68,8 +70,9 @@ def __init__( trial_info: Optional[TrialInfo] = None, dataset_shard: Optional[Union[Dataset, DatasetPipeline]] = None, # TODO(xwjiang): Legacy Ray Train trainer clean up! 
- checkpoint: Optional[Union[Dict, Checkpoint]] = None, - encode_data_fn: Callable = None, + checkpoint: Optional[Checkpoint] = None, + # Deprecated + encode_data_fn: Optional[Callable] = None, detailed_autofilled_metrics: bool = False, ): @@ -80,7 +83,7 @@ def __init__( self.world_size = world_size self.trial_info = trial_info # TODO(xwjiang): Legacy Ray Train trainer clean up! - self.loaded_checkpoint: Optional[Union[Dict, Checkpoint]] = checkpoint + self.loaded_checkpoint = checkpoint # Function to encode checkpoint dict before sending to the driver. if not encode_data_fn: @@ -240,9 +243,9 @@ def _report_legacy(self, **kwargs): if self.ignore_report: return - kwargs = self._encode_data_fn(self._auto_fill_metrics(kwargs)) + kwargs = self._auto_fill_metrics(kwargs) - result = TrainingResult(TrainingResultType.REPORT, kwargs) + result = TrainingResult(type=TrainingResultType.REPORT, data=kwargs) # Add result to a thread-safe queue. self.result_queue.put(result, block=True) @@ -269,22 +272,26 @@ def _report_thread_runner_error(self, block=False): except queue.Empty: pass - def checkpoint(self, **kwargs): + def checkpoint(self, checkpoint: Checkpoint): """Adds kwargs to the queue to be consumed by main thread. Also stores the checkpoint in ``self.loaded_checkpoint``. """ # Update session checkpoint to latest checkpoint. - self.loaded_checkpoint = kwargs + self.loaded_checkpoint = checkpoint # Only store checkpoints on worker with rank 0. if self.world_rank != 0: - kwargs = {} - else: - kwargs = self._encode_data_fn(self._auto_fill_checkpoint_metrics(kwargs)) - - result = TrainingResult(TrainingResultType.CHECKPOINT, kwargs) + checkpoint = None + elif checkpoint: + checkpoint = self._encode_data_fn(checkpoint) + + result = TrainingResult( + type=TrainingResultType.CHECKPOINT, + data=checkpoint, + metadata=self._auto_fill_checkpoint_metrics({}), + ) # Add result to a thread-safe queue. self.result_queue.put(result, block=True) @@ -294,9 +301,23 @@ def checkpoint(self, **kwargs): def report(self, metrics: Dict, checkpoint: Optional[Checkpoint] = None) -> None: # TODO(xwjiang): tons of optimizations. + + # Special case: early fail for Torch tensors + if "torch" in sys.modules: + from ray.air._internal.torch_utils import contains_tensor + + if contains_tensor(metrics): + raise ValueError( + "Passing objects containg Torch tensors as metrics " + "is not supported as it will throw an exception on " + "deserialization. You can either convert the tensors " + "to Python objects or use a `TorchCheckpoint` as the " + "`checkpoint` argument of `ray.air.session.report` to " + "store your Torch objects." + ) + if checkpoint: - checkpoint_dict = checkpoint.to_dict() - self.checkpoint(**checkpoint_dict) + self.checkpoint(checkpoint) self._report_legacy(**metrics) diff --git a/python/ray/train/backend.py b/python/ray/train/backend.py index 78dde0c36836..7d1ae6848d6f 100644 --- a/python/ray/train/backend.py +++ b/python/ray/train/backend.py @@ -1,16 +1,43 @@ import logging -from typing import TypeVar, Dict +import warnings +from typing import Type, TypeVar, Dict +from ray.air.checkpoint import Checkpoint from ray.train._internal.utils import Singleton from ray.train._internal.worker_group import WorkerGroup -from ray.util.annotations import DeveloperAPI - +from ray.util.annotations import Deprecated, DeveloperAPI from ray.widgets import make_table_html_repr EncodedData = TypeVar("EncodedData") logger = logging.getLogger(__name__) +# This is used in several places to print a warning. 
+_encode_decode_deprecation_message = ( + "``encode_data`` and ``decode_data`` are deprecated in favor of " + "framework-specific ``ray.air.Checkpoint`` subclasses (reported " + "using ``ray.air.session.report()``) which can implement " + "encoding and decoding logic. In the future, ``encode_data`` and " + "``decode_data`` will throw an exception if overriden." +) + + +def _warn_about_bad_checkpoint_type(expected_checkpoint_cls: Type[Checkpoint]): + return + # Do not print warnings in 2.1 yet. + # TODO(ml-team): Change this once we have full API parity with framework + # checkpoints. Also turn on test_torch_trainer::test_torch_bad_checkpoint_warning + # warnings.warn( + # f"You have reported a checkpoint with the `{Checkpoint}` " + # "type, but the intended checkpoint type for the Trainer " + # f"you are using is `{expected_checkpoint_cls}`. " + # "Not using the intended checkpoint type may cause " + # "exceptions or other issues, especially during " + # "serialization and deserialization. The checkpoint " + # "type will be changed automatically. " + # "This behavior may change in the future." + # ) + @DeveloperAPI class BackendConfig: @@ -46,6 +73,37 @@ def on_shutdown(self, worker_group: WorkerGroup, backend_config: BackendConfig): """Logic for shutting down the backend.""" pass + @classmethod + def _encode_data(cls, checkpoint: Checkpoint) -> Checkpoint: + """Temporary method until ``encode_data`` is deprecated.""" + if cls.encode_data != Backend.encode_data: + warnings.warn( + _encode_decode_deprecation_message, DeprecationWarning, stacklevel=2 + ) + # We wrap the return of encode_data in dict in case it is + # not a dict itself. + checkpoint = checkpoint.from_dict( + {"encoded_data": cls.encode_data(checkpoint.to_dict())} + ) + return checkpoint + + @classmethod + def _decode_data(cls, checkpoint: Checkpoint) -> Checkpoint: + """Temporary method until ``decode_data`` is deprecated.""" + if cls.decode_data != Backend.decode_data: + warnings.warn( + _encode_decode_deprecation_message, DeprecationWarning, stacklevel=2 + ) + checkpoint_dict = checkpoint.to_dict() + # If "encoded_data" is not in the dict, then the data was + # not encoded, but the user may want to just do decoding + # anyway. + checkpoint = checkpoint.from_dict( + cls.decode_data(checkpoint_dict.get("encoded_data", checkpoint_dict)) + ) + return checkpoint + + @Deprecated(message=_encode_decode_deprecation_message) @staticmethod def encode_data(data_dict: Dict) -> EncodedData: """Logic to encode a data dict before sending to the driver. @@ -56,6 +114,7 @@ def encode_data(data_dict: Dict) -> EncodedData: return data_dict + @Deprecated(message=_encode_decode_deprecation_message) @staticmethod def decode_data(encoded_data: EncodedData) -> Dict: """Logic to decode an encoded data dict. 
diff --git a/python/ray/train/data_parallel_trainer.py b/python/ray/train/data_parallel_trainer.py index ca2c595e4190..06a828605853 100644 --- a/python/ray/train/data_parallel_trainer.py +++ b/python/ray/train/data_parallel_trainer.py @@ -8,6 +8,7 @@ from ray import tune from ray.air import session from ray.air.checkpoint import Checkpoint +from ray.air._internal.checkpointing import save_preprocessor_to_dir from ray.air.config import DatasetConfig, RunConfig, ScalingConfig, CheckpointConfig from ray.air.constants import MODEL_KEY, PREPROCESSOR_KEY from ray.air._internal.checkpoint_manager import _TrackedCheckpoint @@ -41,7 +42,10 @@ def __init__( ) def _process_persistent_checkpoint(self, checkpoint: _TrackedCheckpoint): - checkpoint.dir_or_data[PREPROCESSOR_KEY] = self.preprocessor + if isinstance(checkpoint.dir_or_data, dict): + checkpoint.dir_or_data[PREPROCESSOR_KEY] = self.preprocessor + else: + save_preprocessor_to_dir(self.preprocessor, checkpoint.dir_or_data) super(_DataParallelCheckpointManager, self)._process_persistent_checkpoint( checkpoint=checkpoint ) diff --git a/python/ray/train/examples/horovod/horovod_pytorch_example.py b/python/ray/train/examples/horovod/horovod_pytorch_example.py index 5197d900aba6..f4d15ae0515b 100644 --- a/python/ray/train/examples/horovod/horovod_pytorch_example.py +++ b/python/ray/train/examples/horovod/horovod_pytorch_example.py @@ -9,9 +9,9 @@ from torchvision import datasets, transforms from ray.air import session -from ray.air.checkpoint import Checkpoint from ray.air.config import ScalingConfig from ray.train.horovod import HorovodTrainer +from ray.train.torch.torch_checkpoint import TorchCheckpoint import ray.train.torch @@ -152,12 +152,11 @@ def train_func(config): model, optimizer, train_sampler, train_loader, epoch, log_interval, use_cuda ) if save_model_as_dict: - checkpoint_dict = dict(model=model.state_dict()) + checkpoint = TorchCheckpoint.from_state_dict(model.state_dict()) else: - checkpoint_dict = dict(model=model) - checkpoint_dict = Checkpoint.from_dict(checkpoint_dict) + checkpoint = TorchCheckpoint.from_model(model) results.append(loss) - session.report(dict(loss=loss), checkpoint=checkpoint_dict) + session.report(dict(loss=loss), checkpoint=checkpoint) # Only used for testing. 
return results diff --git a/python/ray/train/examples/pytorch/torch_linear_example.py b/python/ray/train/examples/pytorch/torch_linear_example.py index 03cedc5e0751..647b51a0db6b 100644 --- a/python/ray/train/examples/pytorch/torch_linear_example.py +++ b/python/ray/train/examples/pytorch/torch_linear_example.py @@ -5,7 +5,7 @@ import torch import torch.nn as nn import ray.train as train -from ray.train.torch import TorchTrainer +from ray.train.torch import TorchTrainer, TorchCheckpoint from ray.air.config import ScalingConfig @@ -48,8 +48,7 @@ def validate_epoch(dataloader, model, loss_fn): import copy model_copy = copy.deepcopy(model) - result = {"model": model_copy.cpu().state_dict(), "loss": loss} - return result + return model_copy.cpu().state_dict(), loss def train_func(config): @@ -76,12 +75,12 @@ def train_func(config): optimizer = torch.optim.SGD(model.parameters(), lr=lr) results = [] - for _ in range(epochs): train_epoch(train_loader, model, loss_fn, optimizer) - result = validate_epoch(validation_loader, model, loss_fn) + state_dict, loss = validate_epoch(validation_loader, model, loss_fn) + result = dict(loss=loss) results.append(result) - session.report(result) + session.report(result, checkpoint=TorchCheckpoint.from_state_dict(state_dict)) return results diff --git a/python/ray/train/horovod/config.py b/python/ray/train/horovod/config.py index 989c8f06f6d6..5960137f45fa 100644 --- a/python/ray/train/horovod/config.py +++ b/python/ray/train/horovod/config.py @@ -1,18 +1,20 @@ import sys -from typing import Optional, Set, Dict +from typing import Optional, Set import os from dataclasses import dataclass import ray -from ray.air._internal.torch_utils import contains_tensor -from ray.train.backend import BackendConfig, Backend, EncodedData +from ray.air.checkpoint import Checkpoint +from ray.train.backend import BackendConfig, Backend, _warn_about_bad_checkpoint_type from ray.train._internal.utils import update_env_vars from ray.train._internal.worker_group import WorkerGroup, Worker from horovod.ray.runner import Coordinator from horovod.ray.utils import detect_nics, nics_to_env_var from horovod.runner.common.util import secret, timeout +from ray.train.tensorflow.tensorflow_checkpoint import TensorflowCheckpoint +from ray.train.torch.torch_checkpoint import TorchCheckpoint from ray.util import PublicAPI @@ -131,36 +133,26 @@ def on_start(self, worker_group: WorkerGroup, backend_config: HorovodConfig): worker_group.execute(update_env_vars, coordinator_envs) - @staticmethod - def encode_data(data_dict: Dict) -> EncodedData: - """Logic to encode a data dict before sending to the driver. - - This function will be called on the workers for any data that is - sent to the driver via ``session.report()``. - """ - # If torch is imported, we can use it to serialize the data dict - # into bytes. This will prevent e.g. GPU deserialization errors. - if "torch" in sys.modules and contains_tensor(data_dict): - from ray.train.torch.config import _TorchBackend - - return _TorchBackend.encode_data(data_dict) - - return data_dict - - @staticmethod - def decode_data(encoded_data: EncodedData) -> Dict: - """Logic to decode an encoded data dict. - - This function will be called on the driver after receiving the - encoded data dict from the worker. 
- """ - # See encode_data - if "torch" in sys.modules and isinstance(encoded_data, bytes): - from ray.train.torch.config import _TorchBackend - - return _TorchBackend.decode_data(encoded_data) - - return encoded_data + @classmethod + def _encode_data(cls, checkpoint: Checkpoint): + checkpoint = super()._encode_data(checkpoint) + if type(checkpoint) is Checkpoint: + if checkpoint.get_internal_representation()[0] == "data_dict": + if "tensorflow" in sys.modules: + from ray.air._internal.tensorflow_utils import ( + contains_tensorflow_object, + ) + + if contains_tensorflow_object(checkpoint.to_dict()): + _warn_about_bad_checkpoint_type(TensorflowCheckpoint) + checkpoint = TensorflowCheckpoint.from_checkpoint(checkpoint) + if "torch" in sys.modules: + from ray.air._internal.torch_utils import contains_tensor + + if contains_tensor(checkpoint.to_dict()): + _warn_about_bad_checkpoint_type(TorchCheckpoint) + checkpoint = TorchCheckpoint.from_checkpoint(checkpoint) + return checkpoint def _init_env_vars(world_rank: int, world_size: int, node_id: str): diff --git a/python/ray/train/horovod/horovod_trainer.py b/python/ray/train/horovod/horovod_trainer.py index beded4a43bb8..c4a4c957cc6b 100644 --- a/python/ray/train/horovod/horovod_trainer.py +++ b/python/ray/train/horovod/horovod_trainer.py @@ -87,8 +87,9 @@ def train_loop_per_worker(): import horovod.torch as hvd import torch import torch.nn as nn - from ray.air import session, Checkpoint + from ray.air import session from ray.train.horovod import HorovodTrainer + from ray.train.torch import TorchCheckpoint from ray.air.config import ScalingConfig input_size = 1 @@ -136,8 +137,8 @@ def train_loop_per_worker(): print(f"epoch: {epoch}, loss: {loss.item()}") session.report( {}, - checkpoint=Checkpoint.from_dict( - dict(model=model.state_dict()) + checkpoint=TorchCheckpoint.from_state_dict( + model.state_dict() ), ) train_dataset = ray.data.from_items([{"x": x, "y": x + 1} for x in range(32)]) diff --git a/python/ray/train/huggingface/_huggingface_utils.py b/python/ray/train/huggingface/_huggingface_utils.py index 623b7ca00ef4..8b61590147a8 100644 --- a/python/ray/train/huggingface/_huggingface_utils.py +++ b/python/ray/train/huggingface/_huggingface_utils.py @@ -7,9 +7,9 @@ from transformers.trainer_utils import IntervalStrategy from ray.air import session -from ray.air.checkpoint import Checkpoint from ray.util import get_node_ip_address from ray.data.dataset import Dataset +from ray.train.huggingface.huggingface_checkpoint import HuggingFaceCheckpoint if TYPE_CHECKING: from torch.utils.data import IterableDataset @@ -152,7 +152,8 @@ def on_save(self, args, state, control, **kwargs): transformers.trainer.get_last_checkpoint(args.output_dir) ).absolute() if checkpoint_path: - self.delayed_report["checkpoint"] = Checkpoint.from_dict( + # Use HuggingFaceCheckpoint here to avoid a warning in _TrainSession + self.delayed_report["checkpoint"] = HuggingFaceCheckpoint.from_dict( { NODE_IP_KEY: get_node_ip_address(), CHECKPOINT_PATH_ON_NODE_KEY: str(checkpoint_path), diff --git a/python/ray/train/huggingface/huggingface_trainer.py b/python/ray/train/huggingface/huggingface_trainer.py index 400505684a28..9ca24ce3ea9e 100644 --- a/python/ray/train/huggingface/huggingface_trainer.py +++ b/python/ray/train/huggingface/huggingface_trainer.py @@ -7,6 +7,7 @@ import warnings from pathlib import Path from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Type +from ray.train.huggingface.huggingface_checkpoint import HuggingFaceCheckpoint try: from 
packaging.version import Version @@ -129,6 +130,13 @@ def commit(self, path: Optional[Path] = None) -> None: with open(path.joinpath(TUNE_CHECKPOINT_ID), "w") as f: f.write(str(self.id)) + # Add checkpoint class metadata + # A bit of a hack but this will be removed with the rest + # of this special case eventually + # TODO(ml-team): remove this when HF checkpointing is refactored + checkpoint = HuggingFaceCheckpoint.from_directory(path) + checkpoint._save_checkpoint_metadata_in_directory(path) + class _DataParallelSyncingCheckpointManager(_DataParallelCheckpointManager): def _process_persistent_checkpoint(self, checkpoint: _TrackedCheckpoint): diff --git a/python/ray/train/tensorflow/config.py b/python/ray/train/tensorflow/config.py index 88ccafbb839b..f16caf88b420 100644 --- a/python/ray/train/tensorflow/config.py +++ b/python/ray/train/tensorflow/config.py @@ -5,9 +5,11 @@ from typing import List import ray -from ray.train.backend import BackendConfig, Backend +from ray.air.checkpoint import Checkpoint +from ray.train.backend import BackendConfig, Backend, _warn_about_bad_checkpoint_type from ray.train._internal.utils import get_address_and_port from ray.train._internal.worker_group import WorkerGroup +from ray.train.tensorflow.tensorflow_checkpoint import TensorflowCheckpoint from ray.util import PublicAPI @@ -56,3 +58,11 @@ def get_url(): ) ) ray.get(setup_futures) + + @classmethod + def _encode_data(cls, checkpoint: Checkpoint): + checkpoint = super()._encode_data(checkpoint) + if type(checkpoint) is Checkpoint: + _warn_about_bad_checkpoint_type(TensorflowCheckpoint) + checkpoint = TensorflowCheckpoint.from_checkpoint(checkpoint) + return checkpoint diff --git a/python/ray/train/tests/test_gpu_amp.py b/python/ray/train/tests/test_gpu_amp.py index 455f65db106b..2e8f9522b248 100644 --- a/python/ray/train/tests/test_gpu_amp.py +++ b/python/ray/train/tests/test_gpu_amp.py @@ -65,7 +65,7 @@ def train_func(): model = torchvision.models.resnet101() model = train.torch.prepare_model(model) - session.report({"model": model}, checkpoint=TorchCheckpoint.from_model(model)) + session.report({}, checkpoint=TorchCheckpoint.from_model(model)) trainer = TorchTrainer( train_func, scaling_config=ScalingConfig(num_workers=2, use_gpu=True) diff --git a/python/ray/train/tests/test_gpu_auto_transfer.py b/python/ray/train/tests/test_gpu_auto_transfer.py index 242dbe7644c1..3cea67266a49 100644 --- a/python/ray/train/tests/test_gpu_auto_transfer.py +++ b/python/ray/train/tests/test_gpu_auto_transfer.py @@ -8,8 +8,7 @@ from ray.air import session from ray.air.constants import MODEL_KEY from ray.air.config import ScalingConfig -from ray.train.torch.torch_checkpoint import TorchCheckpoint -from ray.train.torch.torch_trainer import TorchTrainer +from ray.train.torch import TorchTrainer, TorchCheckpoint import ray.train.torch.train_loop_utils @@ -106,7 +105,7 @@ def train_func(): assert next(model.parameters()).is_cuda - session.report({"model": model}, checkpoint=TorchCheckpoint.from_model(model)) + session.report({}, checkpoint=TorchCheckpoint.from_model(model)) trainer = TorchTrainer( train_func, scaling_config=ScalingConfig(num_workers=num_workers, use_gpu=True) @@ -114,9 +113,7 @@ def train_func(): results = trainer.fit() model_checkpoint = results.checkpoint.get_model() - model_report = results.metrics["model"] assert not next(model_checkpoint.parameters()).is_cuda - assert not next(model_report.parameters()).is_cuda # Test the same thing for state dict. 
@@ -134,7 +131,7 @@ def train_func(): assert tensor.is_cuda session.report( - {"state_dict": state_dict}, + {}, checkpoint=TorchCheckpoint.from_state_dict(state_dict), ) @@ -144,10 +141,6 @@ def train_func(): results = trainer.fit() state_dict_checkpoint = results.checkpoint.to_dict()[MODEL_KEY] - state_dict_report = results.metrics["state_dict"] - - for tensor in state_dict_report.values(): - assert not tensor.is_cuda for tensor in state_dict_checkpoint.values(): assert not tensor.is_cuda diff --git a/python/ray/train/tests/test_huggingface_trainer.py b/python/ray/train/tests/test_huggingface_trainer.py index b741d3475574..4951f6444370 100644 --- a/python/ray/train/tests/test_huggingface_trainer.py +++ b/python/ray/train/tests/test_huggingface_trainer.py @@ -11,7 +11,11 @@ import ray.data from ray.exceptions import RayTaskError from ray.train.batch_predictor import BatchPredictor -from ray.train.huggingface import HuggingFacePredictor, HuggingFaceTrainer +from ray.train.huggingface import ( + HuggingFacePredictor, + HuggingFaceTrainer, + HuggingFaceCheckpoint, +) from ray.air.config import ScalingConfig from ray.train.tests._huggingface_data import train_data, validation_data from ray import tune @@ -91,6 +95,7 @@ def test_e2e(ray_start_4_cpus, save_strategy): assert result.metrics["epoch"] == 4 assert result.metrics["training_iteration"] == 4 assert result.checkpoint + assert isinstance(result.checkpoint, HuggingFaceCheckpoint) assert "eval_loss" in result.metrics trainer2 = HuggingFaceTrainer( @@ -108,6 +113,7 @@ def test_e2e(ray_start_4_cpus, save_strategy): assert result2.metrics["epoch"] == 5 assert result2.metrics["training_iteration"] == 1 assert result2.checkpoint + assert isinstance(result2.checkpoint, HuggingFaceCheckpoint) assert "eval_loss" in result2.metrics predictor = BatchPredictor.from_checkpoint( diff --git a/python/ray/train/tests/test_session.py b/python/ray/train/tests/test_session.py index 9eb5dd51985b..545d611f4882 100644 --- a/python/ray/train/tests/test_session.py +++ b/python/ray/train/tests/test_session.py @@ -139,7 +139,7 @@ def validate_zero(expected): next = session.get_next() assert next is not None assert next.type == TrainingResultType.CHECKPOINT - assert next.data["epoch"] == expected + assert next.data.to_dict()["epoch"] == expected init_session(training_func=train_func, world_rank=0, local_rank=0, world_size=1) session = get_session() @@ -155,7 +155,7 @@ def validate_nonzero(): next = session.get_next() assert next is not None assert next.type == TrainingResultType.CHECKPOINT - assert next.data == {} + assert not next.data init_session(training_func=train_func, world_rank=1, local_rank=1, world_size=1) session = get_session() @@ -173,13 +173,17 @@ def train_func(): report(dict(epoch=0), checkpoint=Checkpoint.from_dict(dict(epoch=0))) def encode_checkpoint(checkpoint): - checkpoint.update({"encoded": True}) - return checkpoint + data = checkpoint.to_dict() + data["encoded"] = True + return checkpoint.from_dict(data) def validate_encoded(result_type: TrainingResultType): next = session.get_next() assert next.type is result_type - assert next.data["encoded"] is True + data = next.data + if isinstance(data, Checkpoint): + data = data.to_dict() + assert data["encoded"] is True init_session( training_func=train_func, @@ -193,8 +197,7 @@ def validate_encoded(result_type: TrainingResultType): session.start() # Validate checkpoint is encoded. validate_encoded(TrainingResultType.CHECKPOINT) - # Validate report is encoded. 
- validate_encoded(TrainingResultType.REPORT) + session.get_next() session.finish() shutdown_session() @@ -211,6 +214,7 @@ def train_func(): session.start() for i in range(2): session.get_next() + session.get_next() session.finish() shutdown_session() diff --git a/python/ray/train/tests/test_tensorflow_trainer.py b/python/ray/train/tests/test_tensorflow_trainer.py index fc7ee1f97a9f..df87b665b5c6 100644 --- a/python/ray/train/tests/test_tensorflow_trainer.py +++ b/python/ray/train/tests/test_tensorflow_trainer.py @@ -1,6 +1,4 @@ import os - -import numpy as np import pytest import ray @@ -10,12 +8,14 @@ train_func as tensorflow_linear_train_func, ) from ray.air.config import ScalingConfig +from ray.train.batch_predictor import BatchPredictor from ray.train.constants import TRAIN_DATASET_KEY from ray.train.tensorflow import ( - TensorflowCheckpoint, TensorflowPredictor, TensorflowTrainer, + TensorflowCheckpoint, ) +from ray.train.tests.dummy_preprocessor import DummyPreprocessor @pytest.fixture @@ -74,23 +74,19 @@ def train_func(): scaling_config = ScalingConfig(num_workers=2) trainer = TensorflowTrainer( - train_loop_per_worker=train_func, scaling_config=scaling_config + train_loop_per_worker=train_func, + scaling_config=scaling_config, + preprocessor=DummyPreprocessor(), ) result = trainer.fit() + assert isinstance(result.checkpoint.get_preprocessor(), DummyPreprocessor) - class TensorflowScorer: - def __init__(self): - self.pred = TensorflowPredictor.from_checkpoint( - result.checkpoint, build_model - ) - - def __call__(self, x): - return self.pred.predict(x, dtype=np.float) + batch_predictor = BatchPredictor.from_checkpoint( + result.checkpoint, TensorflowPredictor, model_definition=build_model + ) predict_dataset = ray.data.range(3) - predictions = predict_dataset.map_batches( - TensorflowScorer, batch_format="pandas", compute="actors" - ) + predictions = batch_predictor.predict(predict_dataset) assert predictions.count() == 3 @@ -112,17 +108,23 @@ def train_func(): scaling_config = ScalingConfig(num_workers=2) trainer = TensorflowTrainer( - train_loop_per_worker=train_func, scaling_config=scaling_config + train_loop_per_worker=train_func, + scaling_config=scaling_config, + preprocessor=DummyPreprocessor(), ) result = trainer.fit() + checkpoint = result.checkpoint + assert isinstance(checkpoint.get_preprocessor(), DummyPreprocessor) trainer2 = TensorflowTrainer( train_loop_per_worker=train_func, scaling_config=scaling_config, - resume_from_checkpoint=result.checkpoint, + resume_from_checkpoint=checkpoint, + preprocessor=DummyPreprocessor(), ) result = trainer2.fit() checkpoint = result.checkpoint + assert isinstance(checkpoint.get_preprocessor(), DummyPreprocessor) with checkpoint.as_directory() as ckpt_dir: assert os.path.exists(os.path.join(ckpt_dir, "saved_model.pb")) assert result.metrics["iter"] == 1 diff --git a/python/ray/train/tests/test_torch_trainer.py b/python/ray/train/tests/test_torch_trainer.py index 2dbec40eb1d5..905717961fc7 100644 --- a/python/ray/train/tests/test_torch_trainer.py +++ b/python/ray/train/tests/test_torch_trainer.py @@ -1,14 +1,13 @@ import contextlib import pytest -from ray.air import session -from ray.air.checkpoint import Checkpoint -from ray.train.torch.torch_checkpoint import TorchCheckpoint import torch +import os import ray from ray.train.examples.pytorch.torch_linear_example import ( train_func as linear_train_func, ) +from ray.train.batch_predictor import BatchPredictor from ray.train.torch import TorchPredictor, TorchTrainer from ray.tune 
import TuneError from ray.air.config import ScalingConfig @@ -16,6 +15,9 @@ import ray.train as train from unittest.mock import patch from ray.cluster_utils import Cluster +from ray.air import session +from ray.train.tests.dummy_preprocessor import DummyPreprocessor +from ray.train.torch.torch_checkpoint import TorchCheckpoint @pytest.fixture @@ -62,25 +64,21 @@ def train_func(config): def test_torch_e2e(ray_start_4_cpus): def train_func(): model = torch.nn.Linear(3, 1) - session.report({}, checkpoint=Checkpoint.from_dict(dict(model=model))) + session.report({}, checkpoint=TorchCheckpoint.from_model(model)) scaling_config = ScalingConfig(num_workers=2) trainer = TorchTrainer( - train_loop_per_worker=train_func, scaling_config=scaling_config + train_loop_per_worker=train_func, + scaling_config=scaling_config, + preprocessor=DummyPreprocessor(), ) result = trainer.fit() + assert isinstance(result.checkpoint.get_preprocessor(), DummyPreprocessor) predict_dataset = ray.data.range(9) - - class TorchScorer: - def __init__(self): - self.pred = TorchPredictor.from_checkpoint(result.checkpoint) - - def __call__(self, x): - return self.pred.predict(x, dtype=torch.float) - - predictions = predict_dataset.map_batches( - TorchScorer, batch_size=3, batch_format="pandas", compute="actors" + batch_predictor = BatchPredictor.from_checkpoint(result.checkpoint, TorchPredictor) + predictions = batch_predictor.predict( + predict_dataset, batch_size=3, dtype=torch.float ) assert predictions.count() == 3 @@ -88,22 +86,55 @@ def __call__(self, x): def test_torch_e2e_state_dict(ray_start_4_cpus): def train_func(): model = torch.nn.Linear(3, 1).state_dict() - session.report({}, checkpoint=Checkpoint.from_dict(dict(model=model))) + session.report({}, checkpoint=TorchCheckpoint.from_state_dict(model)) scaling_config = ScalingConfig(num_workers=2) trainer = TorchTrainer( - train_loop_per_worker=train_func, scaling_config=scaling_config + train_loop_per_worker=train_func, + scaling_config=scaling_config, + preprocessor=DummyPreprocessor(), ) result = trainer.fit() + isinstance(result.checkpoint.get_preprocessor(), DummyPreprocessor) # If loading from a state dict, a model definition must be passed in. 
with pytest.raises(ValueError): TorchPredictor.from_checkpoint(result.checkpoint) + predict_dataset = ray.data.range(9) + batch_predictor = BatchPredictor.from_checkpoint( + result.checkpoint, TorchPredictor, model=torch.nn.Linear(3, 1) + ) + predictions = batch_predictor.predict( + predict_dataset, batch_size=3, dtype=torch.float + ) + assert predictions.count() == 3 + + +def test_torch_e2e_dir(ray_start_4_cpus, tmpdir): + def train_func(): + model = torch.nn.Linear(3, 1) + torch.save(model, os.path.join(tmpdir, "model")) + session.report({}, checkpoint=TorchCheckpoint.from_directory(tmpdir)) + + scaling_config = ScalingConfig(num_workers=2) + trainer = TorchTrainer( + train_loop_per_worker=train_func, + scaling_config=scaling_config, + preprocessor=DummyPreprocessor(), + ) + result = trainer.fit() + isinstance(result.checkpoint.get_preprocessor(), DummyPreprocessor) + + # TODO(ml-team): Add a way for TorchCheckpoint to natively support + # models from files class TorchScorer: def __init__(self): + with result.checkpoint.as_directory() as checkpoint_path: + model = torch.load(os.path.join(checkpoint_path, "model")) + preprocessor = result.checkpoint.get_preprocessor() self.pred = TorchPredictor.from_checkpoint( - result.checkpoint, model=torch.nn.Linear(3, 1) + TorchCheckpoint.from_model(model, preprocessor=preprocessor) ) def __call__(self, x): @@ -131,6 +162,58 @@ def test_checkpoint_freq(ray_start_4_cpus): trainer.fit() +def test_torch_session_errors(ray_start_4_cpus): + """Test fail-fast behavior when reporting dicts with Torch tensors""" + + def train_func(): + model = torch.nn.Linear(1, 1).state_dict() + with pytest.raises(ValueError): + session.report(model) + + scaling_config = ScalingConfig(num_workers=2) + trainer = TorchTrainer( + train_loop_per_worker=train_func, + scaling_config=scaling_config, + ) + trainer.fit() + + +# See comment in backend.py::_warn_about_bad_checkpoint_type +# for why test_torch_bad_checkpoint_warning is commented out + +# def test_torch_bad_checkpoint_warning(ray_start_4_cpus): +# """Test that a warning is printed if bad checkpoint type is used.""" + +# def train_func(): +# model = torch.nn.Linear(1, 1).state_dict() +# session.report({}, checkpoint=TorchCheckpoint.from_dict({"model": model})) + +# scaling_config = ScalingConfig(num_workers=2) +# trainer = TorchTrainer( +# train_loop_per_worker=train_func, +# scaling_config=scaling_config, +# ) +# output = io.StringIO() +# with redirect_stdout(output), redirect_stderr(output): +# trainer.fit() +# output = output.getvalue() +# assert "You have reported a checkpoint" not in output + +# def train_func(): +# model = torch.nn.Linear(1, 1).state_dict() +# session.report({}, checkpoint=Checkpoint.from_dict({"model": model})) + +# trainer = TorchTrainer( +# train_loop_per_worker=train_func, +# scaling_config=scaling_config, +# ) +# output = io.StringIO() +# with redirect_stdout(output), redirect_stderr(output): +# trainer.fit() +# output = output.getvalue() +# assert "You have reported a checkpoint" in output + + @pytest.mark.parametrize( "num_gpus_per_worker,expected_devices", [(0.5, [0]), (1, [0]), (2, [0, 1])] ) @@ -204,7 +287,7 @@ def train_fn(): model = train.torch.prepare_model(model) # Save DDP wrapped model. 
- session.report({"model": model}, checkpoint=TorchCheckpoint.from_model(model)) + session.report({}, checkpoint=TorchCheckpoint.from_model(model)) trainer = TorchTrainer( train_loop_per_worker=train_fn, @@ -218,11 +301,6 @@ def train_fn(): model, torch.nn.parallel.DistributedDataParallel ) - model_report = results.metrics["model"] - assert isinstance(model_report, torch.nn.Module) and not isinstance( - model_report, torch.nn.parallel.DistributedDataParallel - ) - def test_torch_amp(ray_start_4_cpus): def train_fn(): diff --git a/python/ray/train/torch/config.py b/python/ray/train/torch/config.py index 346de18d5815..fca7ec745169 100644 --- a/python/ray/train/torch/config.py +++ b/python/ray/train/torch/config.py @@ -1,21 +1,20 @@ from dataclasses import dataclass -import io import logging import os from datetime import timedelta -from typing import Dict, Optional +from typing import Optional import ray -import ray.cloudpickle -from ray.train.backend import BackendConfig, Backend, EncodedData +from ray.air.checkpoint import Checkpoint +from ray.train.backend import BackendConfig, Backend, _warn_about_bad_checkpoint_type from ray.train.constants import DEFAULT_NCCL_SOCKET_IFNAME from ray.train._internal.worker_group import WorkerGroup from ray.train._internal.utils import get_address_and_port +from ray.train.torch.torch_checkpoint import TorchCheckpoint from ray.util import PublicAPI import torch import torch.distributed as dist -from torch.nn.parallel import DistributedDataParallel try: from torch.profiler import profile @@ -178,32 +177,10 @@ def on_shutdown(self, worker_group: WorkerGroup, backend_config: TorchConfig): _shutdown_torch, destroy_process_group=len(worker_group) > 1 ) - @staticmethod - def encode_data(data_dict: Dict) -> EncodedData: - """Special handling for moving model from worker to driver.""" - - # If model is being checkpointed and is wrapped in DDP, then extract - # out the underlying module. If not, then deserialization will fail - # since the torch process group is not initialized on the driver. - - for k, v in data_dict.items(): - if isinstance(v, DistributedDataParallel) and hasattr(v, "module"): - data_dict[k] = v.module - - # Convert the checkpoint dict to bytes, so that any GPU tensors that - # are in the checkpoint dict can be properly deserialized on the - # driver side, even if the driver does not have access to a GPU device. - _buffer = io.BytesIO() - # If a custom torch model contains a function that cannot be pickled normally, - # we need to use ray.cloudpickle. This is also consistent with how Ray - # serialization works in general and has no downsides - # (this can still be unpickled without ray using normal pickle). - torch.save(data_dict, _buffer, pickle_module=ray.cloudpickle) - return _buffer.getvalue() - - @staticmethod - def decode_data(encoded_data: EncodedData) -> Dict: - # When decoding the bytes on the driver side, always map to CPU. 
- _buffer = io.BytesIO(encoded_data) - checkpoint_dict = torch.load(_buffer, map_location="cpu") - return checkpoint_dict + @classmethod + def _encode_data(cls, checkpoint: Checkpoint): + checkpoint = super()._encode_data(checkpoint) + if type(checkpoint) is Checkpoint: + _warn_about_bad_checkpoint_type(TorchCheckpoint) + checkpoint = TorchCheckpoint.from_checkpoint(checkpoint) + return checkpoint diff --git a/python/ray/train/torch/torch_checkpoint.py b/python/ray/train/torch/torch_checkpoint.py index 774256965c8b..9d040d99060c 100644 --- a/python/ray/train/torch/torch_checkpoint.py +++ b/python/ray/train/torch/torch_checkpoint.py @@ -1,8 +1,10 @@ from typing import TYPE_CHECKING, Any, Dict, Optional +import io import torch import warnings +import ray.cloudpickle from ray.air.checkpoint import Checkpoint from ray.air.constants import MODEL_KEY, PREPROCESSOR_KEY from ray.train.data_parallel_trainer import _load_checkpoint_dict @@ -12,6 +14,8 @@ if TYPE_CHECKING: from ray.data.preprocessor import Preprocessor +ENCODED_DATA_KEY = "torch_encoded_data" + @PublicAPI(stability="beta") class TorchCheckpoint(Checkpoint): @@ -21,6 +25,49 @@ class TorchCheckpoint(Checkpoint): ``TorchCheckpoint.from_checkpoint(ckpt)``. """ + # Special encoding logic to avoid serialization errors with torch. + def _encode_data_dict(self, data_dict: dict) -> dict: + """Encode data_dict using torch.save.""" + from torch.nn.parallel import DistributedDataParallel + + for k, v in data_dict.items(): + if isinstance(v, DistributedDataParallel) and hasattr(v, "module"): + data_dict[k] = v.module + + # Convert the checkpoint dict to bytes, so that any GPU tensors that + # are in the checkpoint dict can be properly deserialized on the + # driver side, even if the driver does not have access to a GPU device. + _buffer = io.BytesIO() + torch.save(data_dict, _buffer, pickle_module=ray.cloudpickle) + return {ENCODED_DATA_KEY: _buffer.getvalue()} + + def _decode_data_dict(self, data_dict: dict) -> dict: + """Decode data_dict using torch_load if needed.""" + if ENCODED_DATA_KEY not in data_dict: + return data_dict + encoded_data = data_dict[ENCODED_DATA_KEY] + _buffer = io.BytesIO(encoded_data) + data_dict = torch.load( + _buffer, + map_location="cpu" + # Not using ray.cloudpickle here as it doesn't + # define an Unpickler (as it is not necessary). 
+ ) + return data_dict + + def __getstate__(self) -> dict: + if self._data_dict: + state = self.__dict__.copy() + state["_data_dict"] = self._encode_data_dict(self._data_dict) + return state + return super().__getstate__() + + def __setstate__(self, state: dict): + if "_data_dict" in state: + state = state.copy() + state["_data_dict"] = self._decode_data_dict(state["_data_dict"]) + super().__setstate__(state) + @classmethod def from_state_dict( cls, diff --git a/python/ray/train/torch/torch_trainer.py b/python/ray/train/torch/torch_trainer.py index bbd5848c79b9..270d130badb9 100644 --- a/python/ray/train/torch/torch_trainer.py +++ b/python/ray/train/torch/torch_trainer.py @@ -142,7 +142,7 @@ def train_loop_per_worker(): session.report( {}, checkpoint=Checkpoint.from_dict( - dict(epoch=epoch, model=model.state_dict()) + dict(epoch=epoch, model=model.state_dict() ), ) diff --git a/python/ray/train/trainer.py b/python/ray/train/trainer.py index 5ede3472a331..9aadf2930f09 100644 --- a/python/ray/train/trainer.py +++ b/python/ray/train/trainer.py @@ -260,11 +260,11 @@ def _fetch_next_result(self) -> Optional[List[Dict]]: first_result = results[0] result_type = first_result.type if result_type is TrainingResultType.REPORT: - result_data = [self._backend.decode_data(r.data) for r in results] + result_data = [r.data for r in results] return result_data elif result_type is TrainingResultType.CHECKPOINT: self._checkpoint_manager._process_checkpoint( - results, decode_checkpoint_fn=self._backend.decode_data + results, decode_checkpoint_fn=self._backend._decode_data ) # Iterate until next REPORT call or training has finished. else: @@ -284,7 +284,7 @@ def _finish_checkpointing(self): # Process checkpoints and ignore other result types. if result_type is TrainingResultType.CHECKPOINT: self._checkpoint_manager._process_checkpoint( - results, decode_checkpoint_fn=self._backend.decode_data + results, decode_checkpoint_fn=self._backend._decode_data ) def _finish_training(self):
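
Note (not part of the patch above): the usage change this commit describes — reporting a framework-specific Checkpoint subclass so that encoding/decoding lives on the checkpoint class instead of the now-deprecated Backend.encode_data/decode_data hooks — looks roughly like the sketch below. This is a minimal illustration assuming the Ray 2.1-era AIR APIs touched by this patch (ray.air.session, ScalingConfig, TorchTrainer, TorchCheckpoint); the metric value is made up and the snippet is not itself part of the change.

import torch

from ray.air import session
from ray.air.config import ScalingConfig
from ray.train.torch import TorchCheckpoint, TorchTrainer


def train_loop_per_worker():
    model = torch.nn.Linear(3, 1)
    # Report a TorchCheckpoint rather than a plain dict or base Checkpoint.
    # The TorchCheckpoint subclass now carries the torch-specific
    # serialization logic (e.g. unwrapping DDP modules and moving GPU
    # tensors to CPU during transfer), so no dict conversion or
    # Backend.encode_data override is needed.
    session.report(
        {"loss": 0.0},
        checkpoint=TorchCheckpoint.from_state_dict(model.state_dict()),
    )


trainer = TorchTrainer(
    train_loop_per_worker=train_loop_per_worker,
    scaling_config=ScalingConfig(num_workers=2),
)
result = trainer.fit()
# The checkpoint reported above is available on the driver as
# result.checkpoint, even if the driver has no GPU.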