diff --git a/python/ray/air/_internal/tensorflow_utils.py b/python/ray/air/_internal/tensorflow_utils.py index bbb08efc2b83..61d38292e6d9 100644 --- a/python/ray/air/_internal/tensorflow_utils.py +++ b/python/ray/air/_internal/tensorflow_utils.py @@ -64,6 +64,26 @@ def convert_ndarray_batch_to_tf_tensor_batch( return batch +# This is not foolproof, but it's better than nothing +# The place it is used in will be deprecated soon +def contains_tensorflow_object(obj): + if hasattr(obj, "__module__") and ( + "keras" in obj.__module__ or "tensorflow" in obj.__module__ + ): + return True + elif isinstance(obj, dict): + for k, v in obj.items(): + if contains_tensorflow_object(k): + return True + if contains_tensorflow_object(v): + return True + elif isinstance(obj, (list, tuple)): + for v in obj: + if contains_tensorflow_object(v): + return True + return False + + def get_type_spec( schema: Union["pyarrow.lib.Schema", "PandasBlockSchema"], columns: Union[str, List[str]], diff --git a/python/ray/air/checkpoint.py b/python/ray/air/checkpoint.py index 9ef8a88780a4..01d62f04b3dc 100644 --- a/python/ray/air/checkpoint.py +++ b/python/ray/air/checkpoint.py @@ -501,6 +501,11 @@ def _get_temporary_checkpoint_dir(self) -> str: ) return os.path.join(tmp_dir_path, checkpoint_dir_name) + def _save_checkpoint_metadata_in_directory(self, path: str) -> None: + checkpoint_metadata_path = os.path.join(path, _CHECKPOINT_METADATA_FILE_NAME) + with open(checkpoint_metadata_path, "wb") as file: + pickle.dump(self._metadata, file) + def _to_directory(self, path: str) -> None: if self._data_dict or self._obj_ref: # This is a object ref or dict @@ -547,9 +552,7 @@ def _to_directory(self, path: str) -> None: f"No valid location found for checkpoint {self}: {self._uri}" ) - checkpoint_metadata_path = os.path.join(path, _CHECKPOINT_METADATA_FILE_NAME) - with open(checkpoint_metadata_path, "wb") as file: - pickle.dump(self._metadata, file) + self._save_checkpoint_metadata_in_directory(path) def to_directory(self, path: Optional[str] = None) -> str: """Write checkpoint data to directory. diff --git a/python/ray/train/_internal/backend_executor.py b/python/ray/train/_internal/backend_executor.py index 855ce43d0f60..e8d82388a920 100644 --- a/python/ray/train/_internal/backend_executor.py +++ b/python/ray/train/_internal/backend_executor.py @@ -346,7 +346,7 @@ def initialize_session( train_func=train_func, dataset_shard=self.dataset_shards[index], checkpoint=checkpoint, - encode_data_fn=self._backend.encode_data, + encode_data_fn=self._backend._encode_data, ) ) diff --git a/python/ray/train/_internal/checkpoint.py b/python/ray/train/_internal/checkpoint.py index 55a11abdd400..6e18fedbff9c 100644 --- a/python/ray/train/_internal/checkpoint.py +++ b/python/ray/train/_internal/checkpoint.py @@ -101,13 +101,15 @@ def _process_checkpoint( """Ray Train entrypoint. Perform all processing for a checkpoint.""" # Get checkpoint from first worker. checkpoint_data = checkpoint_results[0].data + checkpoint_metadata = checkpoint_results[0].metadata or {} - # Decode checkpoint. 
-        checkpoint_data = decode_checkpoint_fn(checkpoint_data)
+        # TODO(ml-team): Remove once we remove Backend.decode_data
+        checkpoint_data = decode_checkpoint_fn(checkpoint_data).to_dict()

         score_attr = self._checkpoint_strategy.checkpoint_score_attribute
         if (
             self._checkpoint_strategy.num_to_keep != 0
+            and score_attr not in checkpoint_metadata
             and score_attr not in checkpoint_data
         ):
             raise ValueError(
@@ -122,7 +124,11 @@ def _process_checkpoint(
             dir_or_data=checkpoint_data,
             checkpoint_id=self._latest_checkpoint_id,
             storage_mode=CheckpointStorage.MEMORY,
-            metrics={score_attr: checkpoint_data.get(score_attr, 0.0)},
+            metrics={
+                score_attr: checkpoint_metadata.get(
+                    score_attr, checkpoint_data.get(score_attr, 0.0)
+                )
+            },
         )
         self.register_checkpoint(checkpoint=tracked_checkpoint)
diff --git a/python/ray/train/_internal/session.py b/python/ray/train/_internal/session.py
index 9ab1717615a3..a8a86a594083 100644
--- a/python/ray/train/_internal/session.py
+++ b/python/ray/train/_internal/session.py
@@ -2,6 +2,7 @@
 import logging
 import platform
 import queue
+import sys
 import threading
 import time
 from dataclasses import dataclass
@@ -51,7 +52,8 @@ class TrialInfo:
 @dataclass
 class TrainingResult:
     type: TrainingResultType
-    data: Dict
+    data: Union[Dict, Checkpoint]
+    metadata: Optional[Dict] = None


 # TODO(xwjiang): This needs a better name.
@@ -68,8 +70,9 @@ def __init__(
         trial_info: Optional[TrialInfo] = None,
         dataset_shard: Optional[Union[Dataset, DatasetPipeline]] = None,
         # TODO(xwjiang): Legacy Ray Train trainer clean up!
-        checkpoint: Optional[Union[Dict, Checkpoint]] = None,
-        encode_data_fn: Callable = None,
+        checkpoint: Optional[Checkpoint] = None,
+        # Deprecated
+        encode_data_fn: Optional[Callable] = None,
         detailed_autofilled_metrics: bool = False,
     ):
@@ -80,7 +83,7 @@ def __init__(
         self.world_size = world_size
         self.trial_info = trial_info
         # TODO(xwjiang): Legacy Ray Train trainer clean up!
-        self.loaded_checkpoint: Optional[Union[Dict, Checkpoint]] = checkpoint
+        self.loaded_checkpoint = checkpoint

         # Function to encode checkpoint dict before sending to the driver.
         if not encode_data_fn:
@@ -240,9 +243,9 @@ def _report_legacy(self, **kwargs):
         if self.ignore_report:
             return

-        kwargs = self._encode_data_fn(self._auto_fill_metrics(kwargs))
+        kwargs = self._auto_fill_metrics(kwargs)

-        result = TrainingResult(TrainingResultType.REPORT, kwargs)
+        result = TrainingResult(type=TrainingResultType.REPORT, data=kwargs)

         # Add result to a thread-safe queue.
         self.result_queue.put(result, block=True)
@@ -269,22 +272,26 @@ def _report_thread_runner_error(self, block=False):
         except queue.Empty:
             pass

-    def checkpoint(self, **kwargs):
-        """Adds kwargs to the queue to be consumed by main thread.
+    def checkpoint(self, checkpoint: Checkpoint):
+        """Adds the checkpoint to the queue to be consumed by the main thread.

         Also stores the checkpoint in ``self.loaded_checkpoint``.
         """
         # Update session checkpoint to latest checkpoint.
-        self.loaded_checkpoint = kwargs
+        self.loaded_checkpoint = checkpoint

         # Only store checkpoints on worker with rank 0.
         if self.world_rank != 0:
-            kwargs = {}
-        else:
-            kwargs = self._encode_data_fn(self._auto_fill_checkpoint_metrics(kwargs))
-
-        result = TrainingResult(TrainingResultType.CHECKPOINT, kwargs)
+            checkpoint = None
+        elif checkpoint:
+            checkpoint = self._encode_data_fn(checkpoint)
+
+        result = TrainingResult(
+            type=TrainingResultType.CHECKPOINT,
+            data=checkpoint,
+            metadata=self._auto_fill_checkpoint_metrics({}),
+        )

         # Add result to a thread-safe queue.
         self.result_queue.put(result, block=True)
@@ -294,9 +301,23 @@ def checkpoint(self, **kwargs):
     def report(self, metrics: Dict, checkpoint: Optional[Checkpoint] = None) -> None:
         # TODO(xwjiang): tons of optimizations.
+
+        # Special case: early fail for Torch tensors
+        if "torch" in sys.modules:
+            from ray.air._internal.torch_utils import contains_tensor
+
+            if contains_tensor(metrics):
+                raise ValueError(
+                    "Passing objects containing Torch tensors as metrics "
+                    "is not supported as it will throw an exception on "
+                    "deserialization. You can either convert the tensors "
+                    "to Python objects or use a `TorchCheckpoint` as the "
+                    "`checkpoint` argument of `ray.air.session.report` to "
+                    "store your Torch objects."
+                )
+
         if checkpoint:
-            checkpoint_dict = checkpoint.to_dict()
-            self.checkpoint(**checkpoint_dict)
+            self.checkpoint(checkpoint)
         self._report_legacy(**metrics)
diff --git a/python/ray/train/backend.py b/python/ray/train/backend.py
index 78dde0c36836..7d1ae6848d6f 100644
--- a/python/ray/train/backend.py
+++ b/python/ray/train/backend.py
@@ -1,16 +1,43 @@
 import logging
-from typing import TypeVar, Dict
+import warnings
+from typing import Type, TypeVar, Dict

+from ray.air.checkpoint import Checkpoint
 from ray.train._internal.utils import Singleton
 from ray.train._internal.worker_group import WorkerGroup
-from ray.util.annotations import DeveloperAPI
-
+from ray.util.annotations import Deprecated, DeveloperAPI
 from ray.widgets import make_table_html_repr

 EncodedData = TypeVar("EncodedData")

 logger = logging.getLogger(__name__)

+# This is used in several places to print a warning.
+_encode_decode_deprecation_message = (
+    "``encode_data`` and ``decode_data`` are deprecated in favor of "
+    "framework-specific ``ray.air.Checkpoint`` subclasses (reported "
+    "using ``ray.air.session.report()``) which can implement "
+    "encoding and decoding logic. In the future, ``encode_data`` and "
+    "``decode_data`` will throw an exception if overridden."
+)
+
+
+def _warn_about_bad_checkpoint_type(expected_checkpoint_cls: Type[Checkpoint]):
+    return
+    # Do not print warnings in 2.1 yet.
+    # TODO(ml-team): Change this once we have full API parity with framework
+    # checkpoints. Also turn on test_torch_trainer::test_torch_bad_checkpoint_warning
+    # warnings.warn(
+    #     f"You have reported a checkpoint with the `{Checkpoint}` "
+    #     "type, but the intended checkpoint type for the Trainer "
+    #     f"you are using is `{expected_checkpoint_cls}`. "
+    #     "Not using the intended checkpoint type may cause "
+    #     "exceptions or other issues, especially during "
+    #     "serialization and deserialization. The checkpoint "
+    #     "type will be changed automatically. "
+    #     "This behavior may change in the future."
+    # )
+

 @DeveloperAPI
 class BackendConfig:
@@ -46,6 +73,37 @@ def on_shutdown(self, worker_group: WorkerGroup, backend_config: BackendConfig):
         """Logic for shutting down the backend."""
         pass

+    @classmethod
+    def _encode_data(cls, checkpoint: Checkpoint) -> Checkpoint:
+        """Temporary method until ``encode_data`` is deprecated."""
+        if cls.encode_data != Backend.encode_data:
+            warnings.warn(
+                _encode_decode_deprecation_message, DeprecationWarning, stacklevel=2
+            )
+            # We wrap the return of encode_data in a dict in case it is
+            # not a dict itself.
+ checkpoint = checkpoint.from_dict( + {"encoded_data": cls.encode_data(checkpoint.to_dict())} + ) + return checkpoint + + @classmethod + def _decode_data(cls, checkpoint: Checkpoint) -> Checkpoint: + """Temporary method until ``decode_data`` is deprecated.""" + if cls.decode_data != Backend.decode_data: + warnings.warn( + _encode_decode_deprecation_message, DeprecationWarning, stacklevel=2 + ) + checkpoint_dict = checkpoint.to_dict() + # If "encoded_data" is not in the dict, then the data was + # not encoded, but the user may want to just do decoding + # anyway. + checkpoint = checkpoint.from_dict( + cls.decode_data(checkpoint_dict.get("encoded_data", checkpoint_dict)) + ) + return checkpoint + + @Deprecated(message=_encode_decode_deprecation_message) @staticmethod def encode_data(data_dict: Dict) -> EncodedData: """Logic to encode a data dict before sending to the driver. @@ -56,6 +114,7 @@ def encode_data(data_dict: Dict) -> EncodedData: return data_dict + @Deprecated(message=_encode_decode_deprecation_message) @staticmethod def decode_data(encoded_data: EncodedData) -> Dict: """Logic to decode an encoded data dict. diff --git a/python/ray/train/data_parallel_trainer.py b/python/ray/train/data_parallel_trainer.py index ca2c595e4190..06a828605853 100644 --- a/python/ray/train/data_parallel_trainer.py +++ b/python/ray/train/data_parallel_trainer.py @@ -8,6 +8,7 @@ from ray import tune from ray.air import session from ray.air.checkpoint import Checkpoint +from ray.air._internal.checkpointing import save_preprocessor_to_dir from ray.air.config import DatasetConfig, RunConfig, ScalingConfig, CheckpointConfig from ray.air.constants import MODEL_KEY, PREPROCESSOR_KEY from ray.air._internal.checkpoint_manager import _TrackedCheckpoint @@ -41,7 +42,10 @@ def __init__( ) def _process_persistent_checkpoint(self, checkpoint: _TrackedCheckpoint): - checkpoint.dir_or_data[PREPROCESSOR_KEY] = self.preprocessor + if isinstance(checkpoint.dir_or_data, dict): + checkpoint.dir_or_data[PREPROCESSOR_KEY] = self.preprocessor + else: + save_preprocessor_to_dir(self.preprocessor, checkpoint.dir_or_data) super(_DataParallelCheckpointManager, self)._process_persistent_checkpoint( checkpoint=checkpoint ) diff --git a/python/ray/train/examples/horovod/horovod_pytorch_example.py b/python/ray/train/examples/horovod/horovod_pytorch_example.py index 5197d900aba6..f4d15ae0515b 100644 --- a/python/ray/train/examples/horovod/horovod_pytorch_example.py +++ b/python/ray/train/examples/horovod/horovod_pytorch_example.py @@ -9,9 +9,9 @@ from torchvision import datasets, transforms from ray.air import session -from ray.air.checkpoint import Checkpoint from ray.air.config import ScalingConfig from ray.train.horovod import HorovodTrainer +from ray.train.torch.torch_checkpoint import TorchCheckpoint import ray.train.torch @@ -152,12 +152,11 @@ def train_func(config): model, optimizer, train_sampler, train_loader, epoch, log_interval, use_cuda ) if save_model_as_dict: - checkpoint_dict = dict(model=model.state_dict()) + checkpoint = TorchCheckpoint.from_state_dict(model.state_dict()) else: - checkpoint_dict = dict(model=model) - checkpoint_dict = Checkpoint.from_dict(checkpoint_dict) + checkpoint = TorchCheckpoint.from_model(model) results.append(loss) - session.report(dict(loss=loss), checkpoint=checkpoint_dict) + session.report(dict(loss=loss), checkpoint=checkpoint) # Only used for testing. 
return results diff --git a/python/ray/train/examples/pytorch/torch_linear_example.py b/python/ray/train/examples/pytorch/torch_linear_example.py index 03cedc5e0751..647b51a0db6b 100644 --- a/python/ray/train/examples/pytorch/torch_linear_example.py +++ b/python/ray/train/examples/pytorch/torch_linear_example.py @@ -5,7 +5,7 @@ import torch import torch.nn as nn import ray.train as train -from ray.train.torch import TorchTrainer +from ray.train.torch import TorchTrainer, TorchCheckpoint from ray.air.config import ScalingConfig @@ -48,8 +48,7 @@ def validate_epoch(dataloader, model, loss_fn): import copy model_copy = copy.deepcopy(model) - result = {"model": model_copy.cpu().state_dict(), "loss": loss} - return result + return model_copy.cpu().state_dict(), loss def train_func(config): @@ -76,12 +75,12 @@ def train_func(config): optimizer = torch.optim.SGD(model.parameters(), lr=lr) results = [] - for _ in range(epochs): train_epoch(train_loader, model, loss_fn, optimizer) - result = validate_epoch(validation_loader, model, loss_fn) + state_dict, loss = validate_epoch(validation_loader, model, loss_fn) + result = dict(loss=loss) results.append(result) - session.report(result) + session.report(result, checkpoint=TorchCheckpoint.from_state_dict(state_dict)) return results diff --git a/python/ray/train/horovod/config.py b/python/ray/train/horovod/config.py index 989c8f06f6d6..5960137f45fa 100644 --- a/python/ray/train/horovod/config.py +++ b/python/ray/train/horovod/config.py @@ -1,18 +1,20 @@ import sys -from typing import Optional, Set, Dict +from typing import Optional, Set import os from dataclasses import dataclass import ray -from ray.air._internal.torch_utils import contains_tensor -from ray.train.backend import BackendConfig, Backend, EncodedData +from ray.air.checkpoint import Checkpoint +from ray.train.backend import BackendConfig, Backend, _warn_about_bad_checkpoint_type from ray.train._internal.utils import update_env_vars from ray.train._internal.worker_group import WorkerGroup, Worker from horovod.ray.runner import Coordinator from horovod.ray.utils import detect_nics, nics_to_env_var from horovod.runner.common.util import secret, timeout +from ray.train.tensorflow.tensorflow_checkpoint import TensorflowCheckpoint +from ray.train.torch.torch_checkpoint import TorchCheckpoint from ray.util import PublicAPI @@ -131,36 +133,26 @@ def on_start(self, worker_group: WorkerGroup, backend_config: HorovodConfig): worker_group.execute(update_env_vars, coordinator_envs) - @staticmethod - def encode_data(data_dict: Dict) -> EncodedData: - """Logic to encode a data dict before sending to the driver. - - This function will be called on the workers for any data that is - sent to the driver via ``session.report()``. - """ - # If torch is imported, we can use it to serialize the data dict - # into bytes. This will prevent e.g. GPU deserialization errors. - if "torch" in sys.modules and contains_tensor(data_dict): - from ray.train.torch.config import _TorchBackend - - return _TorchBackend.encode_data(data_dict) - - return data_dict - - @staticmethod - def decode_data(encoded_data: EncodedData) -> Dict: - """Logic to decode an encoded data dict. - - This function will be called on the driver after receiving the - encoded data dict from the worker. 
- """ - # See encode_data - if "torch" in sys.modules and isinstance(encoded_data, bytes): - from ray.train.torch.config import _TorchBackend - - return _TorchBackend.decode_data(encoded_data) - - return encoded_data + @classmethod + def _encode_data(cls, checkpoint: Checkpoint): + checkpoint = super()._encode_data(checkpoint) + if type(checkpoint) is Checkpoint: + if checkpoint.get_internal_representation()[0] == "data_dict": + if "tensorflow" in sys.modules: + from ray.air._internal.tensorflow_utils import ( + contains_tensorflow_object, + ) + + if contains_tensorflow_object(checkpoint.to_dict()): + _warn_about_bad_checkpoint_type(TensorflowCheckpoint) + checkpoint = TensorflowCheckpoint.from_checkpoint(checkpoint) + if "torch" in sys.modules: + from ray.air._internal.torch_utils import contains_tensor + + if contains_tensor(checkpoint.to_dict()): + _warn_about_bad_checkpoint_type(TorchCheckpoint) + checkpoint = TorchCheckpoint.from_checkpoint(checkpoint) + return checkpoint def _init_env_vars(world_rank: int, world_size: int, node_id: str): diff --git a/python/ray/train/horovod/horovod_trainer.py b/python/ray/train/horovod/horovod_trainer.py index beded4a43bb8..c4a4c957cc6b 100644 --- a/python/ray/train/horovod/horovod_trainer.py +++ b/python/ray/train/horovod/horovod_trainer.py @@ -87,8 +87,9 @@ def train_loop_per_worker(): import horovod.torch as hvd import torch import torch.nn as nn - from ray.air import session, Checkpoint + from ray.air import session from ray.train.horovod import HorovodTrainer + from ray.train.torch import TorchCheckpoint from ray.air.config import ScalingConfig input_size = 1 @@ -136,8 +137,8 @@ def train_loop_per_worker(): print(f"epoch: {epoch}, loss: {loss.item()}") session.report( {}, - checkpoint=Checkpoint.from_dict( - dict(model=model.state_dict()) + checkpoint=TorchCheckpoint.from_state_dict( + model.state_dict() ), ) train_dataset = ray.data.from_items([{"x": x, "y": x + 1} for x in range(32)]) diff --git a/python/ray/train/huggingface/_huggingface_utils.py b/python/ray/train/huggingface/_huggingface_utils.py index 623b7ca00ef4..8b61590147a8 100644 --- a/python/ray/train/huggingface/_huggingface_utils.py +++ b/python/ray/train/huggingface/_huggingface_utils.py @@ -7,9 +7,9 @@ from transformers.trainer_utils import IntervalStrategy from ray.air import session -from ray.air.checkpoint import Checkpoint from ray.util import get_node_ip_address from ray.data.dataset import Dataset +from ray.train.huggingface.huggingface_checkpoint import HuggingFaceCheckpoint if TYPE_CHECKING: from torch.utils.data import IterableDataset @@ -152,7 +152,8 @@ def on_save(self, args, state, control, **kwargs): transformers.trainer.get_last_checkpoint(args.output_dir) ).absolute() if checkpoint_path: - self.delayed_report["checkpoint"] = Checkpoint.from_dict( + # Use HuggingFaceCheckpoint here to avoid a warning in _TrainSession + self.delayed_report["checkpoint"] = HuggingFaceCheckpoint.from_dict( { NODE_IP_KEY: get_node_ip_address(), CHECKPOINT_PATH_ON_NODE_KEY: str(checkpoint_path), diff --git a/python/ray/train/huggingface/huggingface_trainer.py b/python/ray/train/huggingface/huggingface_trainer.py index 400505684a28..9ca24ce3ea9e 100644 --- a/python/ray/train/huggingface/huggingface_trainer.py +++ b/python/ray/train/huggingface/huggingface_trainer.py @@ -7,6 +7,7 @@ import warnings from pathlib import Path from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Type +from ray.train.huggingface.huggingface_checkpoint import HuggingFaceCheckpoint try: from 
packaging.version import Version @@ -129,6 +130,13 @@ def commit(self, path: Optional[Path] = None) -> None: with open(path.joinpath(TUNE_CHECKPOINT_ID), "w") as f: f.write(str(self.id)) + # Add checkpoint class metadata + # A bit of a hack but this will be removed with the rest + # of this special case eventually + # TODO(ml-team): remove this when HF checkpointing is refactored + checkpoint = HuggingFaceCheckpoint.from_directory(path) + checkpoint._save_checkpoint_metadata_in_directory(path) + class _DataParallelSyncingCheckpointManager(_DataParallelCheckpointManager): def _process_persistent_checkpoint(self, checkpoint: _TrackedCheckpoint): diff --git a/python/ray/train/tensorflow/config.py b/python/ray/train/tensorflow/config.py index 88ccafbb839b..f16caf88b420 100644 --- a/python/ray/train/tensorflow/config.py +++ b/python/ray/train/tensorflow/config.py @@ -5,9 +5,11 @@ from typing import List import ray -from ray.train.backend import BackendConfig, Backend +from ray.air.checkpoint import Checkpoint +from ray.train.backend import BackendConfig, Backend, _warn_about_bad_checkpoint_type from ray.train._internal.utils import get_address_and_port from ray.train._internal.worker_group import WorkerGroup +from ray.train.tensorflow.tensorflow_checkpoint import TensorflowCheckpoint from ray.util import PublicAPI @@ -56,3 +58,11 @@ def get_url(): ) ) ray.get(setup_futures) + + @classmethod + def _encode_data(cls, checkpoint: Checkpoint): + checkpoint = super()._encode_data(checkpoint) + if type(checkpoint) is Checkpoint: + _warn_about_bad_checkpoint_type(TensorflowCheckpoint) + checkpoint = TensorflowCheckpoint.from_checkpoint(checkpoint) + return checkpoint diff --git a/python/ray/train/tests/test_gpu_amp.py b/python/ray/train/tests/test_gpu_amp.py index 455f65db106b..2e8f9522b248 100644 --- a/python/ray/train/tests/test_gpu_amp.py +++ b/python/ray/train/tests/test_gpu_amp.py @@ -65,7 +65,7 @@ def train_func(): model = torchvision.models.resnet101() model = train.torch.prepare_model(model) - session.report({"model": model}, checkpoint=TorchCheckpoint.from_model(model)) + session.report({}, checkpoint=TorchCheckpoint.from_model(model)) trainer = TorchTrainer( train_func, scaling_config=ScalingConfig(num_workers=2, use_gpu=True) diff --git a/python/ray/train/tests/test_gpu_auto_transfer.py b/python/ray/train/tests/test_gpu_auto_transfer.py index 242dbe7644c1..3cea67266a49 100644 --- a/python/ray/train/tests/test_gpu_auto_transfer.py +++ b/python/ray/train/tests/test_gpu_auto_transfer.py @@ -8,8 +8,7 @@ from ray.air import session from ray.air.constants import MODEL_KEY from ray.air.config import ScalingConfig -from ray.train.torch.torch_checkpoint import TorchCheckpoint -from ray.train.torch.torch_trainer import TorchTrainer +from ray.train.torch import TorchTrainer, TorchCheckpoint import ray.train.torch.train_loop_utils @@ -106,7 +105,7 @@ def train_func(): assert next(model.parameters()).is_cuda - session.report({"model": model}, checkpoint=TorchCheckpoint.from_model(model)) + session.report({}, checkpoint=TorchCheckpoint.from_model(model)) trainer = TorchTrainer( train_func, scaling_config=ScalingConfig(num_workers=num_workers, use_gpu=True) @@ -114,9 +113,7 @@ def train_func(): results = trainer.fit() model_checkpoint = results.checkpoint.get_model() - model_report = results.metrics["model"] assert not next(model_checkpoint.parameters()).is_cuda - assert not next(model_report.parameters()).is_cuda # Test the same thing for state dict. 
@@ -134,7 +131,7 @@ def train_func(): assert tensor.is_cuda session.report( - {"state_dict": state_dict}, + {}, checkpoint=TorchCheckpoint.from_state_dict(state_dict), ) @@ -144,10 +141,6 @@ def train_func(): results = trainer.fit() state_dict_checkpoint = results.checkpoint.to_dict()[MODEL_KEY] - state_dict_report = results.metrics["state_dict"] - - for tensor in state_dict_report.values(): - assert not tensor.is_cuda for tensor in state_dict_checkpoint.values(): assert not tensor.is_cuda diff --git a/python/ray/train/tests/test_huggingface_trainer.py b/python/ray/train/tests/test_huggingface_trainer.py index b741d3475574..4951f6444370 100644 --- a/python/ray/train/tests/test_huggingface_trainer.py +++ b/python/ray/train/tests/test_huggingface_trainer.py @@ -11,7 +11,11 @@ import ray.data from ray.exceptions import RayTaskError from ray.train.batch_predictor import BatchPredictor -from ray.train.huggingface import HuggingFacePredictor, HuggingFaceTrainer +from ray.train.huggingface import ( + HuggingFacePredictor, + HuggingFaceTrainer, + HuggingFaceCheckpoint, +) from ray.air.config import ScalingConfig from ray.train.tests._huggingface_data import train_data, validation_data from ray import tune @@ -91,6 +95,7 @@ def test_e2e(ray_start_4_cpus, save_strategy): assert result.metrics["epoch"] == 4 assert result.metrics["training_iteration"] == 4 assert result.checkpoint + assert isinstance(result.checkpoint, HuggingFaceCheckpoint) assert "eval_loss" in result.metrics trainer2 = HuggingFaceTrainer( @@ -108,6 +113,7 @@ def test_e2e(ray_start_4_cpus, save_strategy): assert result2.metrics["epoch"] == 5 assert result2.metrics["training_iteration"] == 1 assert result2.checkpoint + assert isinstance(result2.checkpoint, HuggingFaceCheckpoint) assert "eval_loss" in result2.metrics predictor = BatchPredictor.from_checkpoint( diff --git a/python/ray/train/tests/test_session.py b/python/ray/train/tests/test_session.py index 9eb5dd51985b..545d611f4882 100644 --- a/python/ray/train/tests/test_session.py +++ b/python/ray/train/tests/test_session.py @@ -139,7 +139,7 @@ def validate_zero(expected): next = session.get_next() assert next is not None assert next.type == TrainingResultType.CHECKPOINT - assert next.data["epoch"] == expected + assert next.data.to_dict()["epoch"] == expected init_session(training_func=train_func, world_rank=0, local_rank=0, world_size=1) session = get_session() @@ -155,7 +155,7 @@ def validate_nonzero(): next = session.get_next() assert next is not None assert next.type == TrainingResultType.CHECKPOINT - assert next.data == {} + assert not next.data init_session(training_func=train_func, world_rank=1, local_rank=1, world_size=1) session = get_session() @@ -173,13 +173,17 @@ def train_func(): report(dict(epoch=0), checkpoint=Checkpoint.from_dict(dict(epoch=0))) def encode_checkpoint(checkpoint): - checkpoint.update({"encoded": True}) - return checkpoint + data = checkpoint.to_dict() + data["encoded"] = True + return checkpoint.from_dict(data) def validate_encoded(result_type: TrainingResultType): next = session.get_next() assert next.type is result_type - assert next.data["encoded"] is True + data = next.data + if isinstance(data, Checkpoint): + data = data.to_dict() + assert data["encoded"] is True init_session( training_func=train_func, @@ -193,8 +197,7 @@ def validate_encoded(result_type: TrainingResultType): session.start() # Validate checkpoint is encoded. validate_encoded(TrainingResultType.CHECKPOINT) - # Validate report is encoded. 
- validate_encoded(TrainingResultType.REPORT) + session.get_next() session.finish() shutdown_session() @@ -211,6 +214,7 @@ def train_func(): session.start() for i in range(2): session.get_next() + session.get_next() session.finish() shutdown_session() diff --git a/python/ray/train/tests/test_tensorflow_trainer.py b/python/ray/train/tests/test_tensorflow_trainer.py index fc7ee1f97a9f..df87b665b5c6 100644 --- a/python/ray/train/tests/test_tensorflow_trainer.py +++ b/python/ray/train/tests/test_tensorflow_trainer.py @@ -1,6 +1,4 @@ import os - -import numpy as np import pytest import ray @@ -10,12 +8,14 @@ train_func as tensorflow_linear_train_func, ) from ray.air.config import ScalingConfig +from ray.train.batch_predictor import BatchPredictor from ray.train.constants import TRAIN_DATASET_KEY from ray.train.tensorflow import ( - TensorflowCheckpoint, TensorflowPredictor, TensorflowTrainer, + TensorflowCheckpoint, ) +from ray.train.tests.dummy_preprocessor import DummyPreprocessor @pytest.fixture @@ -74,23 +74,19 @@ def train_func(): scaling_config = ScalingConfig(num_workers=2) trainer = TensorflowTrainer( - train_loop_per_worker=train_func, scaling_config=scaling_config + train_loop_per_worker=train_func, + scaling_config=scaling_config, + preprocessor=DummyPreprocessor(), ) result = trainer.fit() + assert isinstance(result.checkpoint.get_preprocessor(), DummyPreprocessor) - class TensorflowScorer: - def __init__(self): - self.pred = TensorflowPredictor.from_checkpoint( - result.checkpoint, build_model - ) - - def __call__(self, x): - return self.pred.predict(x, dtype=np.float) + batch_predictor = BatchPredictor.from_checkpoint( + result.checkpoint, TensorflowPredictor, model_definition=build_model + ) predict_dataset = ray.data.range(3) - predictions = predict_dataset.map_batches( - TensorflowScorer, batch_format="pandas", compute="actors" - ) + predictions = batch_predictor.predict(predict_dataset) assert predictions.count() == 3 @@ -112,17 +108,23 @@ def train_func(): scaling_config = ScalingConfig(num_workers=2) trainer = TensorflowTrainer( - train_loop_per_worker=train_func, scaling_config=scaling_config + train_loop_per_worker=train_func, + scaling_config=scaling_config, + preprocessor=DummyPreprocessor(), ) result = trainer.fit() + checkpoint = result.checkpoint + assert isinstance(checkpoint.get_preprocessor(), DummyPreprocessor) trainer2 = TensorflowTrainer( train_loop_per_worker=train_func, scaling_config=scaling_config, - resume_from_checkpoint=result.checkpoint, + resume_from_checkpoint=checkpoint, + preprocessor=DummyPreprocessor(), ) result = trainer2.fit() checkpoint = result.checkpoint + assert isinstance(checkpoint.get_preprocessor(), DummyPreprocessor) with checkpoint.as_directory() as ckpt_dir: assert os.path.exists(os.path.join(ckpt_dir, "saved_model.pb")) assert result.metrics["iter"] == 1 diff --git a/python/ray/train/tests/test_torch_trainer.py b/python/ray/train/tests/test_torch_trainer.py index 2dbec40eb1d5..905717961fc7 100644 --- a/python/ray/train/tests/test_torch_trainer.py +++ b/python/ray/train/tests/test_torch_trainer.py @@ -1,14 +1,13 @@ import contextlib import pytest -from ray.air import session -from ray.air.checkpoint import Checkpoint -from ray.train.torch.torch_checkpoint import TorchCheckpoint import torch +import os import ray from ray.train.examples.pytorch.torch_linear_example import ( train_func as linear_train_func, ) +from ray.train.batch_predictor import BatchPredictor from ray.train.torch import TorchPredictor, TorchTrainer from ray.tune 
import TuneError
from ray.air.config import ScalingConfig
@@ -16,6 +15,9 @@ import ray.train as train
 from unittest.mock import patch
 from ray.cluster_utils import Cluster
+from ray.air import session
+from ray.train.tests.dummy_preprocessor import DummyPreprocessor
+from ray.train.torch.torch_checkpoint import TorchCheckpoint
@@ -62,25 +64,21 @@ def train_func(config):
 def test_torch_e2e(ray_start_4_cpus):
     def train_func():
         model = torch.nn.Linear(3, 1)
-        session.report({}, checkpoint=Checkpoint.from_dict(dict(model=model)))
+        session.report({}, checkpoint=TorchCheckpoint.from_model(model))

     scaling_config = ScalingConfig(num_workers=2)
     trainer = TorchTrainer(
-        train_loop_per_worker=train_func, scaling_config=scaling_config
+        train_loop_per_worker=train_func,
+        scaling_config=scaling_config,
+        preprocessor=DummyPreprocessor(),
     )
     result = trainer.fit()
+    assert isinstance(result.checkpoint.get_preprocessor(), DummyPreprocessor)

     predict_dataset = ray.data.range(9)
-
-    class TorchScorer:
-        def __init__(self):
-            self.pred = TorchPredictor.from_checkpoint(result.checkpoint)
-
-        def __call__(self, x):
-            return self.pred.predict(x, dtype=torch.float)
-
-    predictions = predict_dataset.map_batches(
-        TorchScorer, batch_size=3, batch_format="pandas", compute="actors"
+    batch_predictor = BatchPredictor.from_checkpoint(result.checkpoint, TorchPredictor)
+    predictions = batch_predictor.predict(
+        predict_dataset, batch_size=3, dtype=torch.float
     )
     assert predictions.count() == 3
@@ -88,22 +86,55 @@ def __call__(self, x):
 def test_torch_e2e_state_dict(ray_start_4_cpus):
     def train_func():
         model = torch.nn.Linear(3, 1).state_dict()
-        session.report({}, checkpoint=Checkpoint.from_dict(dict(model=model)))
+        session.report({}, checkpoint=TorchCheckpoint.from_state_dict(model))

     scaling_config = ScalingConfig(num_workers=2)
     trainer = TorchTrainer(
-        train_loop_per_worker=train_func, scaling_config=scaling_config
+        train_loop_per_worker=train_func,
+        scaling_config=scaling_config,
+        preprocessor=DummyPreprocessor(),
     )
     result = trainer.fit()
+    assert isinstance(result.checkpoint.get_preprocessor(), DummyPreprocessor)

     # If loading from a state dict, a model definition must be passed in.
     with pytest.raises(ValueError):
         TorchPredictor.from_checkpoint(result.checkpoint)

+    predict_dataset = ray.data.range(9)
+    batch_predictor = BatchPredictor.from_checkpoint(
+        result.checkpoint, TorchPredictor, model=torch.nn.Linear(3, 1)
+    )
+    predictions = batch_predictor.predict(
+        predict_dataset, batch_size=3, dtype=torch.float
+    )
+    assert predictions.count() == 3
+
+
+def test_torch_e2e_dir(ray_start_4_cpus, tmpdir):
+    def train_func():
+        model = torch.nn.Linear(3, 1)
+        torch.save(model, os.path.join(tmpdir, "model"))
+        session.report({}, checkpoint=TorchCheckpoint.from_directory(tmpdir))
+
+    scaling_config = ScalingConfig(num_workers=2)
+    trainer = TorchTrainer(
+        train_loop_per_worker=train_func,
+        scaling_config=scaling_config,
+        preprocessor=DummyPreprocessor(),
+    )
+    result = trainer.fit()
+    assert isinstance(result.checkpoint.get_preprocessor(), DummyPreprocessor)
+
+    # TODO(ml-team): Add a way for TorchCheckpoint to natively support
+    # models from files
     class TorchScorer:
         def __init__(self):
+            with result.checkpoint.as_directory() as checkpoint_path:
+                model = torch.load(os.path.join(checkpoint_path, "model"))
+            preprocessor = result.checkpoint.get_preprocessor()
             self.pred = TorchPredictor.from_checkpoint(
-                result.checkpoint, model=torch.nn.Linear(3, 1)
+                TorchCheckpoint.from_model(model, preprocessor=preprocessor)
             )

         def __call__(self, x):
@@ -131,6 +162,58 @@ def test_checkpoint_freq(ray_start_4_cpus):
     trainer.fit()


+def test_torch_session_errors(ray_start_4_cpus):
+    """Test fail-fast behavior when reporting dicts with Torch tensors."""
+
+    def train_func():
+        model = torch.nn.Linear(1, 1).state_dict()
+        with pytest.raises(ValueError):
+            session.report(model)
+
+    scaling_config = ScalingConfig(num_workers=2)
+    trainer = TorchTrainer(
+        train_loop_per_worker=train_func,
+        scaling_config=scaling_config,
+    )
+    trainer.fit()
+
+
+# See comment in backend.py::_warn_about_bad_checkpoint_type
+# for why test_torch_bad_checkpoint_warning is commented out
+
+# def test_torch_bad_checkpoint_warning(ray_start_4_cpus):
+#     """Test that a warning is printed if bad checkpoint type is used."""
+
+#     def train_func():
+#         model = torch.nn.Linear(1, 1).state_dict()
+#         session.report({}, checkpoint=TorchCheckpoint.from_dict({"model": model}))
+
+#     scaling_config = ScalingConfig(num_workers=2)
+#     trainer = TorchTrainer(
+#         train_loop_per_worker=train_func,
+#         scaling_config=scaling_config,
+#     )
+#     output = io.StringIO()
+#     with redirect_stdout(output), redirect_stderr(output):
+#         trainer.fit()
+#     output = output.getvalue()
+#     assert "You have reported a checkpoint" not in output
+
+#     def train_func():
+#         model = torch.nn.Linear(1, 1).state_dict()
+#         session.report({}, checkpoint=Checkpoint.from_dict({"model": model}))
+
+#     trainer = TorchTrainer(
+#         train_loop_per_worker=train_func,
+#         scaling_config=scaling_config,
+#     )
+#     output = io.StringIO()
+#     with redirect_stdout(output), redirect_stderr(output):
+#         trainer.fit()
+#     output = output.getvalue()
+#     assert "You have reported a checkpoint" in output
+
+
 @pytest.mark.parametrize(
     "num_gpus_per_worker,expected_devices", [(0.5, [0]), (1, [0]), (2, [0, 1])]
 )
@@ -204,7 +287,7 @@ def train_fn():
         model = train.torch.prepare_model(model)

         # Save DDP wrapped model.
- session.report({"model": model}, checkpoint=TorchCheckpoint.from_model(model)) + session.report({}, checkpoint=TorchCheckpoint.from_model(model)) trainer = TorchTrainer( train_loop_per_worker=train_fn, @@ -218,11 +301,6 @@ def train_fn(): model, torch.nn.parallel.DistributedDataParallel ) - model_report = results.metrics["model"] - assert isinstance(model_report, torch.nn.Module) and not isinstance( - model_report, torch.nn.parallel.DistributedDataParallel - ) - def test_torch_amp(ray_start_4_cpus): def train_fn(): diff --git a/python/ray/train/torch/config.py b/python/ray/train/torch/config.py index 346de18d5815..fca7ec745169 100644 --- a/python/ray/train/torch/config.py +++ b/python/ray/train/torch/config.py @@ -1,21 +1,20 @@ from dataclasses import dataclass -import io import logging import os from datetime import timedelta -from typing import Dict, Optional +from typing import Optional import ray -import ray.cloudpickle -from ray.train.backend import BackendConfig, Backend, EncodedData +from ray.air.checkpoint import Checkpoint +from ray.train.backend import BackendConfig, Backend, _warn_about_bad_checkpoint_type from ray.train.constants import DEFAULT_NCCL_SOCKET_IFNAME from ray.train._internal.worker_group import WorkerGroup from ray.train._internal.utils import get_address_and_port +from ray.train.torch.torch_checkpoint import TorchCheckpoint from ray.util import PublicAPI import torch import torch.distributed as dist -from torch.nn.parallel import DistributedDataParallel try: from torch.profiler import profile @@ -178,32 +177,10 @@ def on_shutdown(self, worker_group: WorkerGroup, backend_config: TorchConfig): _shutdown_torch, destroy_process_group=len(worker_group) > 1 ) - @staticmethod - def encode_data(data_dict: Dict) -> EncodedData: - """Special handling for moving model from worker to driver.""" - - # If model is being checkpointed and is wrapped in DDP, then extract - # out the underlying module. If not, then deserialization will fail - # since the torch process group is not initialized on the driver. - - for k, v in data_dict.items(): - if isinstance(v, DistributedDataParallel) and hasattr(v, "module"): - data_dict[k] = v.module - - # Convert the checkpoint dict to bytes, so that any GPU tensors that - # are in the checkpoint dict can be properly deserialized on the - # driver side, even if the driver does not have access to a GPU device. - _buffer = io.BytesIO() - # If a custom torch model contains a function that cannot be pickled normally, - # we need to use ray.cloudpickle. This is also consistent with how Ray - # serialization works in general and has no downsides - # (this can still be unpickled without ray using normal pickle). - torch.save(data_dict, _buffer, pickle_module=ray.cloudpickle) - return _buffer.getvalue() - - @staticmethod - def decode_data(encoded_data: EncodedData) -> Dict: - # When decoding the bytes on the driver side, always map to CPU. 
-        _buffer = io.BytesIO(encoded_data)
-        checkpoint_dict = torch.load(_buffer, map_location="cpu")
-        return checkpoint_dict
+    @classmethod
+    def _encode_data(cls, checkpoint: Checkpoint):
+        checkpoint = super()._encode_data(checkpoint)
+        if type(checkpoint) is Checkpoint:
+            _warn_about_bad_checkpoint_type(TorchCheckpoint)
+            checkpoint = TorchCheckpoint.from_checkpoint(checkpoint)
+        return checkpoint
diff --git a/python/ray/train/torch/torch_checkpoint.py b/python/ray/train/torch/torch_checkpoint.py
index 774256965c8b..9d040d99060c 100644
--- a/python/ray/train/torch/torch_checkpoint.py
+++ b/python/ray/train/torch/torch_checkpoint.py
@@ -1,8 +1,10 @@
 from typing import TYPE_CHECKING, Any, Dict, Optional
+import io

 import torch
 import warnings

+import ray.cloudpickle
 from ray.air.checkpoint import Checkpoint
 from ray.air.constants import MODEL_KEY, PREPROCESSOR_KEY
 from ray.train.data_parallel_trainer import _load_checkpoint_dict
@@ -12,6 +14,8 @@ if TYPE_CHECKING:
     from ray.data.preprocessor import Preprocessor

+ENCODED_DATA_KEY = "torch_encoded_data"
+

 @PublicAPI(stability="beta")
 class TorchCheckpoint(Checkpoint):
@@ -21,6 +25,49 @@ class TorchCheckpoint(Checkpoint):
     ``TorchCheckpoint.from_checkpoint(ckpt)``.
     """

+    # Special encoding logic to avoid serialization errors with torch.
+    def _encode_data_dict(self, data_dict: dict) -> dict:
+        """Encode data_dict using torch.save."""
+        from torch.nn.parallel import DistributedDataParallel
+
+        for k, v in data_dict.items():
+            if isinstance(v, DistributedDataParallel) and hasattr(v, "module"):
+                data_dict[k] = v.module
+
+        # Convert the checkpoint dict to bytes, so that any GPU tensors that
+        # are in the checkpoint dict can be properly deserialized on the
+        # driver side, even if the driver does not have access to a GPU device.
+        _buffer = io.BytesIO()
+        torch.save(data_dict, _buffer, pickle_module=ray.cloudpickle)
+        return {ENCODED_DATA_KEY: _buffer.getvalue()}
+
+    def _decode_data_dict(self, data_dict: dict) -> dict:
+        """Decode data_dict using torch.load if needed."""
+        if ENCODED_DATA_KEY not in data_dict:
+            return data_dict
+        encoded_data = data_dict[ENCODED_DATA_KEY]
+        _buffer = io.BytesIO(encoded_data)
+        data_dict = torch.load(
+            _buffer,
+            map_location="cpu"
+            # Not using ray.cloudpickle here as it doesn't
+            # define an Unpickler (as it is not necessary).
+        )
+        return data_dict
+
+    def __getstate__(self) -> dict:
+        if self._data_dict:
+            state = self.__dict__.copy()
+            state["_data_dict"] = self._encode_data_dict(self._data_dict)
+            return state
+        return super().__getstate__()
+
+    def __setstate__(self, state: dict):
+        if "_data_dict" in state:
+            state = state.copy()
+            state["_data_dict"] = self._decode_data_dict(state["_data_dict"])
+        super().__setstate__(state)
+
     @classmethod
     def from_state_dict(
         cls,
diff --git a/python/ray/train/torch/torch_trainer.py b/python/ray/train/torch/torch_trainer.py
index bbd5848c79b9..270d130badb9 100644
--- a/python/ray/train/torch/torch_trainer.py
+++ b/python/ray/train/torch/torch_trainer.py
@@ -142,7 +142,7 @@ def train_loop_per_worker():
                 session.report(
                     {},
                     checkpoint=Checkpoint.from_dict(
-                        dict(epoch=epoch, model=model.state_dict())
+                        dict(epoch=epoch, model=model.state_dict())
                     ),
                 )
diff --git a/python/ray/train/trainer.py b/python/ray/train/trainer.py
index 5ede3472a331..9aadf2930f09 100644
--- a/python/ray/train/trainer.py
+++ b/python/ray/train/trainer.py
@@ -260,11 +260,11 @@ def _fetch_next_result(self) -> Optional[List[Dict]]:
             first_result = results[0]
             result_type = first_result.type
             if result_type is TrainingResultType.REPORT:
-                result_data = [self._backend.decode_data(r.data) for r in results]
+                result_data = [r.data for r in results]
                 return result_data
             elif result_type is TrainingResultType.CHECKPOINT:
                 self._checkpoint_manager._process_checkpoint(
-                    results, decode_checkpoint_fn=self._backend.decode_data
+                    results, decode_checkpoint_fn=self._backend._decode_data
                 )
                 # Iterate until next REPORT call or training has finished.
             else:
@@ -284,7 +284,7 @@ def _finish_checkpointing(self):
             # Process checkpoints and ignore other result types.
             if result_type is TrainingResultType.CHECKPOINT:
                 self._checkpoint_manager._process_checkpoint(
-                    results, decode_checkpoint_fn=self._backend.decode_data
+                    results, decode_checkpoint_fn=self._backend._decode_data
                )

     def _finish_training(self):
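
Usage sketch of the reporting pattern the examples above are migrated to (not part of the patch; the toy model, metric name, and worker count are illustrative only, and it assumes a Ray 2.x install with ray[air] and torch available):

import torch.nn as nn

from ray.air import session
from ray.air.config import ScalingConfig
from ray.train.torch import TorchCheckpoint, TorchTrainer


def train_loop_per_worker():
    # Toy single-step "training" loop; the real examples use prepare_model,
    # dataset shards, and multiple epochs.
    model = nn.Linear(3, 1)
    loss = 0.0  # placeholder metric

    # Metrics must stay plain Python objects (no Torch tensors); the model
    # weights travel to the driver inside the framework-specific checkpoint.
    session.report(
        {"loss": loss},
        checkpoint=TorchCheckpoint.from_state_dict(model.state_dict()),
    )


trainer = TorchTrainer(
    train_loop_per_worker,
    scaling_config=ScalingConfig(num_workers=2),
)
result = trainer.fit()
# result.checkpoint is a TorchCheckpoint; a module of the same architecture
# can be rebuilt with result.checkpoint.get_model(nn.Linear(3, 1)).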
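
A rough sketch of how the new Backend._encode_data/_decode_data shims treat a backend that still overrides the deprecated hooks (the LegacyBackend class and its "wrapped" key are made up for illustration; the behavior follows the python/ray/train/backend.py changes above):

import warnings

from ray.air.checkpoint import Checkpoint
from ray.train.backend import Backend


class LegacyBackend(Backend):
    """A backend that still overrides the deprecated encode/decode hooks."""

    @staticmethod
    def encode_data(data_dict):
        return {"wrapped": data_dict}

    @staticmethod
    def decode_data(encoded_data):
        return encoded_data["wrapped"]


ckpt = Checkpoint.from_dict({"model": "weights"})

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # The override is detected, a DeprecationWarning is emitted, and the
    # encoded payload lands under the "encoded_data" key of a new Checkpoint.
    encoded = LegacyBackend._encode_data(ckpt)
    # The driver-side counterpart unwraps it again before checkpoint processing.
    decoded = LegacyBackend._decode_data(encoded)

print(any(issubclass(w.category, DeprecationWarning) for w in caught))  # True
print("model" in decoded.to_dict())  # True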
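
And a minimal sketch of what the new __getstate__/__setstate__ pair on TorchCheckpoint buys: dict-backed checkpoints are run through torch.save when pickled and restored with map_location="cpu" when unpickled, so they survive the hop to a driver without GPUs (CPU tensors are used here only to keep the example runnable anywhere):

import pickle

import torch

from ray.train.torch import TorchCheckpoint

ckpt = TorchCheckpoint.from_state_dict({"weight": torch.zeros(2, 2)})

# Pickling triggers __getstate__, which torch.save-encodes the data dict;
# unpickling triggers __setstate__, which torch.load-decodes it onto the CPU.
restored = pickle.loads(pickle.dumps(ckpt))

print(type(restored).__name__)     # TorchCheckpoint
print(sorted(restored.to_dict()))  # the usual model/preprocessor entries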