diff --git a/python/ray/air/_internal/checkpoint_manager.py b/python/ray/air/_internal/checkpoint_manager.py
index 07573ed57334..cc008aa199ca 100644
--- a/python/ray/air/_internal/checkpoint_manager.py
+++ b/python/ray/air/_internal/checkpoint_manager.py
@@ -55,6 +55,11 @@ class _TrackedCheckpoint:
             into `"evaluation/episode_reward_mean"`.
         node_ip: IP of the node where the checkpoint was generated. Defaults to
             the current node.
+        local_to_remote_path_fn: Function that takes in this checkpoint's local
+            directory path and returns the corresponding remote URI in the cloud.
+            This should only be specified if the checkpoint data has been synced
+            to cloud storage. Only applied during conversion to an AIR checkpoint,
+            and only if ``dir_or_data`` is or resolves to a directory path.
     """
 
     def __init__(
@@ -64,12 +69,16 @@ def __init__(
         checkpoint_id: Optional[int] = None,
         metrics: Optional[Dict] = None,
         node_ip: Optional[str] = None,
+        local_to_remote_path_fn: Optional[Callable[[str], str]] = None,
     ):
         from ray.tune.result import NODE_IP
 
         self.dir_or_data = dir_or_data
         self.id = checkpoint_id
         self.storage_mode = storage_mode
+        # This is a function because `dir_or_data` may be an object ref,
+        # and we need to wait until it's resolved first.
+        self.local_to_remote_path_fn = local_to_remote_path_fn
         self.metrics = flatten_dict(metrics) if metrics else {}
         self.node_ip = node_ip or self.metrics.get(NODE_IP, None)
 
@@ -144,22 +153,32 @@ def to_air_checkpoint(self) -> Optional[Checkpoint]:
         if isinstance(checkpoint_data, ray.ObjectRef):
             checkpoint_data = ray.get(checkpoint_data)
 
+        if isinstance(checkpoint_data, Checkpoint):
+            return checkpoint_data
+
         if isinstance(checkpoint_data, str):
-            try:
-                checkpoint_dir = TrainableUtil.find_checkpoint_dir(checkpoint_data)
-            except FileNotFoundError:
-                if log_once("checkpoint_not_available"):
-                    logger.error(
-                        f"The requested checkpoint is not available on this node, "
-                        f"most likely because you are using Ray client or disabled "
-                        f"checkpoint synchronization. To avoid this, enable checkpoint "
-                        f"synchronization to cloud storage by specifying a "
-                        f"`SyncConfig`. The checkpoint may be available on a different "
-                        f"node - please check this location on worker nodes: "
-                        f"{checkpoint_data}"
-                    )
-                return None
-            checkpoint = Checkpoint.from_directory(checkpoint_dir)
+            # Prefer cloud checkpoints.
+            if self.local_to_remote_path_fn:
+                checkpoint = Checkpoint.from_uri(
+                    self.local_to_remote_path_fn(checkpoint_data)
+                )
+            else:
+                try:
+                    checkpoint_dir = TrainableUtil.find_checkpoint_dir(checkpoint_data)
+                except FileNotFoundError:
+                    if log_once("checkpoint_not_available"):
+                        logger.error(
+                            f"The requested checkpoint is not available on this node, "
+                            f"most likely because you are using Ray client or disabled "
+                            f"checkpoint synchronization. To avoid this, enable "
+                            f"checkpoint synchronization to cloud storage by "
+                            f"specifying a `SyncConfig`. The checkpoint may be "
+                            f"available on a different node - please check this "
+                            f"location on worker nodes: "
+                            f"{checkpoint_data}"
+                        )
+                    return None
+                checkpoint = Checkpoint.from_directory(checkpoint_dir)
         elif isinstance(checkpoint_data, bytes):
             checkpoint = Checkpoint.from_bytes(checkpoint_data)
         elif isinstance(checkpoint_data, dict):
diff --git a/python/ray/tune/execution/ray_trial_executor.py b/python/ray/tune/execution/ray_trial_executor.py
index b7522a00362e..a3e8f3f56af8 100644
--- a/python/ray/tune/execution/ray_trial_executor.py
+++ b/python/ray/tune/execution/ray_trial_executor.py
@@ -986,7 +986,16 @@ def save(
         else:
             value = trial.runner.save.remote()
             checkpoint = _TrackedCheckpoint(
-                dir_or_data=value, storage_mode=storage, metrics=result
+                dir_or_data=value,
+                storage_mode=storage,
+                metrics=result,
+                local_to_remote_path_fn=partial(
+                    TrainableUtil.get_remote_storage_path,
+                    logdir=trial.logdir,
+                    remote_checkpoint_dir=trial.remote_checkpoint_dir,
+                )
+                if trial.uses_cloud_checkpointing
+                else None,
             )
             trial.saving_to = checkpoint
             self._futures[value] = (_ExecutorEventType.SAVING_RESULT, trial)
diff --git a/python/ray/tune/tests/test_result_grid.py b/python/ray/tune/tests/test_result_grid.py
index 914e17a91505..d2873573ef96 100644
--- a/python/ray/tune/tests/test_result_grid.py
+++ b/python/ray/tune/tests/test_result_grid.py
@@ -1,8 +1,9 @@
 import json
 import os
 import pickle
-from pathlib import Path
 import shutil
+from pathlib import Path
+from typing import Optional, List
 
 import pytest
 import pandas as pd
@@ -14,6 +15,7 @@
 from ray.tune.registry import get_trainable_cls
 from ray.tune.result_grid import ResultGrid
 from ray.tune.experiment import Trial
+from ray.tune.syncer import Syncer
 from ray.tune.tests.tune_test_util import create_tune_experiment_checkpoint
 
 
@@ -25,6 +27,21 @@ def ray_start_2_cpus():
     ray.shutdown()
 
 
+class MockSyncer(Syncer):
+    def sync_up(
+        self, local_dir: str, remote_dir: str, exclude: Optional[List] = None
+    ) -> bool:
+        return True
+
+    def sync_down(
+        self, remote_dir: str, local_dir: str, exclude: Optional[List] = None
+    ) -> bool:
+        return True
+
+    def delete(self, remote_dir: str) -> bool:
+        return True
+
+
 def test_result_grid(ray_start_2_cpus):
     def f(config):
         # simulating the case that no report is called in train.
@@ -344,6 +361,35 @@ def train_func(config):
     assert set(checkpoint_data) == {5, 6}
 
 
+def test_result_grid_cloud_path(ray_start_2_cpus, tmpdir):
+    # Test that checkpoints returned by ResultGrid point to a URI
+    # if upload_dir is specified in SyncConfig.
+    local_dir = Path(tmpdir) / "local_dir"
+    sync_config = tune.SyncConfig(upload_dir="s3://bucket", syncer=MockSyncer())
+
+    def trainable(config):
+        for i in range(5):
+            checkpoint = Checkpoint.from_dict({"model": i})
+            session.report(metrics={"metric": i}, checkpoint=checkpoint)
+
+    tuner = tune.Tuner(
+        trainable,
+        run_config=air.RunConfig(sync_config=sync_config, local_dir=local_dir),
+        tune_config=tune.TuneConfig(
+            metric="metric",
+            mode="max",
+        ),
+    )
+    results = tuner.fit()
+    shutil.rmtree(local_dir)
+    best_checkpoint = results.get_best_result().checkpoint
+    assert not best_checkpoint.uri.startswith("file://")
+    assert (
+        best_checkpoint.get_internal_representation()
+        == results._experiment_analysis.best_checkpoint.get_internal_representation()
+    )
+
+
 if __name__ == "__main__":
     import sys
 
diff --git a/python/ray/tune/trainable/trainable.py b/python/ray/tune/trainable/trainable.py
index 508997f0e6db..fb1558d8366a 100644
--- a/python/ray/tune/trainable/trainable.py
+++ b/python/ray/tune/trainable/trainable.py
@@ -195,8 +195,9 @@ def uses_cloud_checkpointing(self):
 
     def _storage_path(self, local_path):
         """Converts a `local_path` to be based off of `self.remote_checkpoint_dir`."""
-        rel_local_path = os.path.relpath(local_path, self.logdir)
-        return os.path.join(self.remote_checkpoint_dir, rel_local_path)
+        return TrainableUtil.get_remote_storage_path(
+            local_path, self.logdir, self.remote_checkpoint_dir
+        )
 
     @classmethod
     def default_resource_request(
diff --git a/python/ray/tune/trainable/util.py b/python/ray/tune/trainable/util.py
index 52095d31d92a..0f696e7871c9 100644
--- a/python/ray/tune/trainable/util.py
+++ b/python/ray/tune/trainable/util.py
@@ -185,6 +185,17 @@ def get_checkpoints_paths(logdir):
         )
         return chkpt_df
 
+    @staticmethod
+    def get_remote_storage_path(
+        local_path: str, logdir: str, remote_checkpoint_dir: str
+    ) -> str:
+        """Converts a ``local_path`` to be based off of
+        ``remote_checkpoint_dir`` instead of ``logdir``.
+
+        ``logdir`` is assumed to be a prefix of ``local_path``."""
+        rel_local_path = os.path.relpath(local_path, logdir)
+        return os.path.join(remote_checkpoint_dir, rel_local_path)
+
 
 @DeveloperAPI
 class PlacementGroupUtil:
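
A minimal usage sketch of the path mapping introduced above: `TrainableUtil.get_remote_storage_path` rebases a local checkpoint path from the trial's `logdir` onto its `remote_checkpoint_dir`, and `_TrackedCheckpoint.to_air_checkpoint` feeds the result to `Checkpoint.from_uri`. The helper is re-implemented standalone here so the sketch runs without Ray; the directory layout and bucket name are illustrative assumptions.

import os
from functools import partial

# Standalone mirror of the two-line helper added in python/ray/tune/trainable/util.py.
def get_remote_storage_path(local_path: str, logdir: str, remote_checkpoint_dir: str) -> str:
    # Rebase `local_path` from `logdir` onto `remote_checkpoint_dir`.
    rel_local_path = os.path.relpath(local_path, logdir)
    return os.path.join(remote_checkpoint_dir, rel_local_path)

# Hypothetical trial layout (illustrative only).
logdir = "/home/ray/ray_results/exp/trial_0"
remote_checkpoint_dir = "s3://bucket/exp/trial_0"

# Mirrors the partial that ray_trial_executor.save() attaches to the
# _TrackedCheckpoint when cloud checkpointing is enabled.
local_to_remote_path_fn = partial(
    get_remote_storage_path,
    logdir=logdir,
    remote_checkpoint_dir=remote_checkpoint_dir,
)

# A local checkpoint directory resolves to its cloud counterpart, which
# to_air_checkpoint() would then wrap via Checkpoint.from_uri(...).
print(local_to_remote_path_fn(os.path.join(logdir, "checkpoint_000004")))
# -> s3://bucket/exp/trial_0/checkpoint_000004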