[Tune] Make ResultGrid return cloud checkpoints #31437

Merged
48 changes: 33 additions & 15 deletions python/ray/air/_internal/checkpoint_manager.py
@@ -55,6 +55,10 @@ class _TrackedCheckpoint:
into `"evaluation/episode_reward_mean"`.
node_ip: IP of the node where the checkpoint was generated. Defaults
to the current node.
local_to_remote_path_fn: Function that converts the local checkpoint path
into the corresponding remote (cloud) path. If specified, the data is
assumed to have been synced to cloud storage. Only applied during
conversion to an AIR checkpoint, and only if ``dir_or_data`` is or
resolves to a directory path.
"""

def __init__(
@@ -64,12 +68,16 @@ def __init__(
checkpoint_id: Optional[int] = None,
metrics: Optional[Dict] = None,
node_ip: Optional[str] = None,
local_to_remote_path_fn: Optional[Callable[[str], str]] = None,
):
from ray.tune.result import NODE_IP

self.dir_or_data = dir_or_data
self.id = checkpoint_id
self.storage_mode = storage_mode
# This is a function because dir_or_data may be an object ref
# and we need to wait until it is resolved first.
self.local_to_remote_path_fn = local_to_remote_path_fn

self.metrics = flatten_dict(metrics) if metrics else {}
self.node_ip = node_ip or self.metrics.get(NODE_IP, None)
@@ -144,22 +152,32 @@ def to_air_checkpoint(self) -> Optional[Checkpoint]:
if isinstance(checkpoint_data, ray.ObjectRef):
checkpoint_data = ray.get(checkpoint_data)

if isinstance(checkpoint_data, Checkpoint):
return checkpoint_data

if isinstance(checkpoint_data, str):
try:
checkpoint_dir = TrainableUtil.find_checkpoint_dir(checkpoint_data)
except FileNotFoundError:
if log_once("checkpoint_not_available"):
logger.error(
f"The requested checkpoint is not available on this node, "
f"most likely because you are using Ray client or disabled "
f"checkpoint synchronization. To avoid this, enable checkpoint "
f"synchronization to cloud storage by specifying a "
f"`SyncConfig`. The checkpoint may be available on a different "
f"node - please check this location on worker nodes: "
f"{checkpoint_data}"
)
return None
checkpoint = Checkpoint.from_directory(checkpoint_dir)
# Prefer cloud checkpoints.
if self.local_to_remote_path_fn:
checkpoint = Checkpoint.from_uri(
self.local_to_remote_path_fn(checkpoint_data)
)
else:
try:
checkpoint_dir = TrainableUtil.find_checkpoint_dir(checkpoint_data)
except FileNotFoundError:
if log_once("checkpoint_not_available"):
logger.error(
f"The requested checkpoint is not available on this node, "
f"most likely because you are using Ray client or disabled "
f"checkpoint synchronization. To avoid this, enable "
f"checkpoint synchronization to cloud storage by "
f"specifying a `SyncConfig`. The checkpoint may be "
f"available on a different node - please check this "
f"location on worker nodes: "
f"{checkpoint_data}"
)
return None
checkpoint = Checkpoint.from_directory(checkpoint_dir)
elif isinstance(checkpoint_data, bytes):
checkpoint = Checkpoint.from_bytes(checkpoint_data)
elif isinstance(checkpoint_data, dict):
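A minimal, self-contained sketch (not part of the diff) of the decision the new branch in to_air_checkpoint implements: when a path-conversion function is present, the checkpoint resolves to a cloud URI instead of the local directory. The helper resolve_checkpoint_location, the bucket, and the paths are hypothetical; only the name local_to_remote_path_fn and the prefer-cloud logic come from the diff above.

import os
from typing import Callable, Optional


def resolve_checkpoint_location(
    checkpoint_data: str,
    local_to_remote_path_fn: Optional[Callable[[str], str]] = None,
) -> str:
    """Return the cloud URI if a conversion function is set, else the local dir."""
    if local_to_remote_path_fn:
        # Prefer the cloud copy: rebase the local path onto the remote URI.
        return local_to_remote_path_fn(checkpoint_data)
    # Fall back to the local checkpoint directory.
    return os.path.abspath(checkpoint_data)


def to_remote(path: str) -> str:
    # Hypothetical conversion for a trial logdir synced to s3://bucket/exp/trial_0.
    return "s3://bucket/exp/trial_0/" + os.path.relpath(path, "/tmp/trial_0")


print(resolve_checkpoint_location("/tmp/trial_0/checkpoint_000004", to_remote))
# -> s3://bucket/exp/trial_0/checkpoint_000004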
11 changes: 10 additions & 1 deletion python/ray/tune/execution/ray_trial_executor.py
@@ -986,7 +986,16 @@ def save(
else:
value = trial.runner.save.remote()
checkpoint = _TrackedCheckpoint(
dir_or_data=value, storage_mode=storage, metrics=result
dir_or_data=value,
storage_mode=storage,
metrics=result,
local_to_remote_path_fn=partial(
TrainableUtil.get_remote_storage_path,
logdir=trial.logdir,
remote_checkpoint_dir=trial.remote_checkpoint_dir,
)
if trial.uses_cloud_checkpointing
else None,
)
trial.saving_to = checkpoint
self._futures[value] = (_ExecutorEventType.SAVING_RESULT, trial)
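The partial(...) call above is what turns the three-argument helper into the one-argument Callable[[str], str] that _TrackedCheckpoint expects. A hedged sketch of the pattern (not part of the diff), with a stand-in for TrainableUtil.get_remote_storage_path and made-up paths:

import os
from functools import partial


def get_remote_storage_path(local_path: str, logdir: str, remote_checkpoint_dir: str) -> str:
    # Mirrors the static method added in python/ray/tune/trainable/util.py further down.
    rel_local_path = os.path.relpath(local_path, logdir)
    return os.path.join(remote_checkpoint_dir, rel_local_path)


# Binding logdir and remote_checkpoint_dir leaves a function of one argument,
# matching the Callable[[str], str] signature of local_to_remote_path_fn.
local_to_remote_path_fn = partial(
    get_remote_storage_path,
    logdir="/tmp/ray_results/exp/trial_0",
    remote_checkpoint_dir="s3://bucket/exp/trial_0",
)
print(local_to_remote_path_fn("/tmp/ray_results/exp/trial_0/checkpoint_000004"))
# -> s3://bucket/exp/trial_0/checkpoint_000004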
48 changes: 47 additions & 1 deletion python/ray/tune/tests/test_result_grid.py
@@ -1,8 +1,9 @@
import json
import os
import pickle
from pathlib import Path
import shutil
from pathlib import Path
from typing import Optional, List

import pytest
import pandas as pd
@@ -14,6 +15,7 @@
from ray.tune.registry import get_trainable_cls
from ray.tune.result_grid import ResultGrid
from ray.tune.experiment import Trial
from ray.tune.syncer import Syncer
from ray.tune.tests.tune_test_util import create_tune_experiment_checkpoint


@@ -25,6 +27,21 @@ def ray_start_2_cpus():
ray.shutdown()


class MockSyncer(Syncer):
def sync_up(
self, local_dir: str, remote_dir: str, exclude: Optional[List] = None
) -> bool:
return True

def sync_down(
self, remote_dir: str, local_dir: str, exclude: Optional[List] = None
) -> bool:
return True

def delete(self, remote_dir: str) -> bool:
return True


def test_result_grid(ray_start_2_cpus):
def f(config):
# Simulate the case where no report is called in train.
@@ -344,6 +361,35 @@ def train_func(config):
assert set(checkpoint_data) == {5, 6}


def test_result_grid_cloud_path(ray_start_2_cpus, tmpdir):
# Test that checkpoints returned by ResultGrid point to a cloud URI
# if upload_dir is specified in SyncConfig.
local_dir = Path(tmpdir) / "local_dir"
sync_config = tune.SyncConfig(upload_dir="s3://bucket", syncer=MockSyncer())

def trainable(config):
for i in range(5):
checkpoint = Checkpoint.from_dict({"model": i})
session.report(metrics={"metric": i}, checkpoint=checkpoint)

tuner = tune.Tuner(
trainable,
run_config=air.RunConfig(sync_config=sync_config, local_dir=local_dir),
tune_config=tune.TuneConfig(
metric="metric",
mode="max",
),
)
results = tuner.fit()
shutil.rmtree(local_dir)
best_checkpoint = results.get_best_result().checkpoint
assert not best_checkpoint.uri.startswith("file://")
assert (
best_checkpoint.get_internal_representation()
== results._experiment_analysis.best_checkpoint.get_internal_representation()
)


if __name__ == "__main__":
import sys

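As a usage note (not part of the diff), this is roughly how a user would consume the cloud checkpoint that ResultGrid now returns. The snippet continues from results = tuner.fit() in the test above; the printed bucket path is illustrative, and Checkpoint.uri / Checkpoint.to_directory() are standard AIR checkpoint methods.

# Continues from `results = tuner.fit()` above; bucket and paths are illustrative.
best_checkpoint = results.get_best_result().checkpoint
print(best_checkpoint.uri)  # e.g. "s3://bucket/<experiment>/<trial>/checkpoint_000004"
# to_directory() materializes the checkpoint locally, downloading it from the
# URI when the checkpoint is cloud-backed.
local_copy = best_checkpoint.to_directory("/tmp/best_checkpoint")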
5 changes: 3 additions & 2 deletions python/ray/tune/trainable/trainable.py
@@ -195,8 +195,9 @@ def uses_cloud_checkpointing(self):
def _storage_path(self, local_path):
"""Converts a `local_path` to be based off of
`self.remote_checkpoint_dir`."""
rel_local_path = os.path.relpath(local_path, self.logdir)
return os.path.join(self.remote_checkpoint_dir, rel_local_path)
return TrainableUtil.get_remote_storage_path(
local_path, self.logdir, self.remote_checkpoint_dir
)

@classmethod
def default_resource_request(
9 changes: 9 additions & 0 deletions python/ray/tune/trainable/util.py
@@ -185,6 +185,15 @@ def get_checkpoints_paths(logdir):
)
return chkpt_df

@staticmethod
def get_remote_storage_path(
local_path: str, logdir: str, remote_checkpoint_dir: str
Review comment (Contributor): Can we rename remote_checkpoint_dir to remote_logdir? Seems like two different concepts with the current naming but one is just the cloud version of the other.

Reply (Member, Author): Wanted to use the same names as in Trial.

) -> str:
"""Converts a ``local_path`` to be based off of
``remote_checkpoint_dir`` instead of ``logdir``."""
rel_local_path = os.path.relpath(local_path, logdir)
return os.path.join(remote_checkpoint_dir, rel_local_path)


@DeveloperAPI
class PlacementGroupUtil:
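A brief usage sketch of the new helper (not part of the diff), assuming a Ray build that includes this change; the import path follows the file location in the diff, and the paths and bucket are made up:

# Assumes a Ray version containing this PR; paths and bucket are illustrative.
from ray.tune.trainable.util import TrainableUtil

remote = TrainableUtil.get_remote_storage_path(
    local_path="/tmp/ray_results/exp/trial_0/checkpoint_000002",
    logdir="/tmp/ray_results/exp/trial_0",
    remote_checkpoint_dir="s3://bucket/exp/trial_0",
)
print(remote)
# -> s3://bucket/exp/trial_0/checkpoint_000002 (with POSIX path joining)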