Skip to content

Commit

Permalink
[air] pyarrow.fs persistence (9/n): ray.train.Checkpoint restore:…
Browse files Browse the repository at this point in the history
… Manual restore (ray-project#38128)

Signed-off-by: Shreyas Krishnaswamy <[email protected]>
  • Loading branch information
justinvyu authored and shrekris-anyscale committed Aug 10, 2023
1 parent 3f4ed33 commit c5c313a
Show file tree
Hide file tree
Showing 10 changed files with 225 additions and 168 deletions.
79 changes: 51 additions & 28 deletions python/ray/train/_internal/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,44 @@ def _create_directory(fs: pyarrow.fs.FileSystem, fs_path: str) -> None:
)


def get_fs_and_path(
storage_path: Union[str, os.PathLike],
storage_filesystem: Optional[pyarrow.fs.FileSystem] = None,
) -> Tuple[pyarrow.fs.FileSystem, str]:
"""Returns the fs and path from a storage path and an optional custom fs.
Args:
storage_path: A storage path or URI. (ex: s3://bucket/path or /tmp/ray_results)
storage_filesystem: A custom filesystem to use. If not provided,
this will be auto-resolved by pyarrow. If provided, the storage_path
is assumed to be prefix-stripped already, and must be a valid path
on the filesystem.
Raises:
ValueError: if the storage path is a URI and a custom filesystem is given.
"""
storage_path = str(storage_path)

if storage_filesystem:
if is_uri(storage_path):
raise ValueError(
"If you specify a custom `storage_filesystem`, the corresponding "
"`storage_path` must be a *path* on that filesystem, not a URI.\n"
"For example: "
"(storage_filesystem=CustomS3FileSystem(), "
"storage_path='s3://bucket/path') should be changed to "
"(storage_filesystem=CustomS3FileSystem(), "
"storage_path='bucket/path')\n"
"This is what you provided: "
f"(storage_filesystem={storage_filesystem}, "
f"storage_path={storage_path})\n"
"Note that this may depend on the custom filesystem you use."
)
return storage_filesystem, storage_path

return pyarrow.fs.FileSystem.from_uri(storage_path)


class _FilesystemSyncer(_BackgroundSyncer):
"""Syncer between local filesystem and a `storage_filesystem`."""

Expand Down Expand Up @@ -402,7 +440,7 @@ def __init__(
trial_dir_name: Optional[str] = None,
current_checkpoint_index: int = 0,
):
storage_path_provided = storage_path is not None
custom_fs_provided = storage_filesystem is not None

self.storage_local_path = _get_defaults_results_dir()
# If `storage_path=None`, then set it to the local path.
Expand All @@ -414,32 +452,12 @@ def __init__(
self.current_checkpoint_index = current_checkpoint_index
self.sync_config = dataclasses.replace(sync_config)

if storage_filesystem:
# Custom pyarrow filesystem
self.storage_filesystem = storage_filesystem
if is_uri(self.storage_path):
raise ValueError(
"If you specify a custom `storage_filesystem`, the corresponding "
"`storage_path` must be a *path* on that filesystem, not a URI.\n"
"For example: "
"(storage_filesystem=CustomS3FileSystem(), "
"storage_path='s3://bucket/path') should be changed to "
"(storage_filesystem=CustomS3FileSystem(), "
"storage_path='bucket/path')\n"
"This is what you provided: "
f"(storage_filesystem={storage_filesystem}, "
f"storage_path={storage_path})\n"
"Note that this may depend on the custom filesystem you use."
)
self.storage_fs_path = self.storage_path
else:
(
self.storage_filesystem,
self.storage_fs_path,
) = pyarrow.fs.FileSystem.from_uri(self.storage_path)
self.storage_filesystem, self.storage_fs_path = get_fs_and_path(
self.storage_path, storage_filesystem
)

# The storage prefix is the URI that remains after stripping the
# URI prefix away from the user-provided `storage_path` (using `from_uri`).
# The storage prefix is part of the URI that is stripped away
# from the user-provided `storage_path` by pyarrow's `from_uri`.
# Ex: `storage_path="s3://bucket/path?param=1`
# -> `storage_prefix=URI<s3://.?param=1>`
# See the doctests for more examples.
Expand All @@ -450,14 +468,19 @@ def __init__(
Path(self.storage_fs_path)
)

# Only initialize a syncer if a `storage_path` was provided.
# Syncing is always needed if a custom `storage_filesystem` is provided.
# Otherwise, syncing is only needed if storage_local_path
# and storage_fs_path point to different locations.
syncing_needed = (
custom_fs_provided or self.storage_fs_path != self.storage_local_path
)
self.syncer: Optional[Syncer] = (
_FilesystemSyncer(
storage_filesystem=self.storage_filesystem,
sync_period=self.sync_config.sync_period,
sync_timeout=self.sync_config.sync_timeout,
)
if storage_path_provided
if syncing_needed
else None
)

Expand Down
44 changes: 34 additions & 10 deletions python/ray/train/base_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Type, Union
import warnings

import pyarrow.fs

import ray
import ray.cloudpickle as pickle
from ray.air._internal.config import ensure_only_allowed_dataclass_keys_updated
Expand All @@ -23,7 +25,11 @@
from ray.air.result import Result
from ray.train._checkpoint import Checkpoint as NewCheckpoint
from ray.train._internal import session
from ray.train._internal.storage import _use_storage_context
from ray.train._internal.storage import (
_exists_at_fs_path,
_use_storage_context,
get_fs_and_path,
)
from ray.train.constants import TRAIN_DATASET_KEY
from ray.util import PublicAPI
from ray.util.annotations import DeveloperAPI
Expand Down Expand Up @@ -195,8 +201,9 @@ def __init__(
self.preprocessor = preprocessor
self.starting_checkpoint = resume_from_checkpoint

# This path should only be set through restore
# These attributes should only be set through `BaseTrainer.restore`
self._restore_path = None
self._restore_storage_filesystem = None

self._validate_attributes()

Expand All @@ -214,7 +221,8 @@ def __init__(
@classmethod
def restore(
cls: Type["BaseTrainer"],
path: str,
path: Union[str, os.PathLike],
storage_filesystem: Optional[pyarrow.fs.FileSystem] = None,
datasets: Optional[Dict[str, GenDataset]] = None,
preprocessor: Optional["Preprocessor"] = None,
scaling_config: Optional[ScalingConfig] = None,
Expand Down Expand Up @@ -298,17 +306,23 @@ def training_loop(self):
Returns:
BaseTrainer: A restored instance of the class that is calling this method.
"""
if not cls.can_restore(path):
if not cls.can_restore(path, storage_filesystem):
raise ValueError(
f"Invalid restore path: {path}. Make sure that this path exists and "
"is the experiment directory that results from a call to "
"`trainer.fit()`."
)
trainer_state_path = cls._maybe_sync_down_trainer_state(path)
assert trainer_state_path.exists()
if _use_storage_context():
fs, fs_path = get_fs_and_path(path, storage_filesystem)
with fs.open_input_file(os.path.join(fs_path, _TRAINER_PKL)) as f:
trainer_cls, param_dict = pickle.loads(f.readall())
else:
trainer_state_path = cls._maybe_sync_down_trainer_state(path)
assert trainer_state_path.exists()

with open(trainer_state_path, "rb") as fp:
trainer_cls, param_dict = pickle.load(fp)

with open(trainer_state_path, "rb") as fp:
trainer_cls, param_dict = pickle.load(fp)
if trainer_cls is not cls:
warnings.warn(
f"Invalid trainer type. You are attempting to restore a trainer of type"
Expand Down Expand Up @@ -355,11 +369,16 @@ def training_loop(self):
f"`{cls.__name__}.restore`\n"
) from e
trainer._restore_path = path
trainer._restore_storage_filesystem = storage_filesystem
return trainer

@PublicAPI(stability="alpha")
@classmethod
def can_restore(cls: Type["BaseTrainer"], path: Union[str, Path]) -> bool:
def can_restore(
cls: Type["BaseTrainer"],
path: Union[str, os.PathLike],
storage_filesystem: Optional[pyarrow.fs.FileSystem] = None,
) -> bool:
"""Checks whether a given directory contains a restorable Train experiment.
Args:
Expand All @@ -370,6 +389,10 @@ def can_restore(cls: Type["BaseTrainer"], path: Union[str, Path]) -> bool:
Returns:
bool: Whether this path exists and contains the trainer state to resume from
"""
if _use_storage_context():
fs, fs_path = get_fs_and_path(path, storage_filesystem)
return _exists_at_fs_path(fs, os.path.join(fs_path, _TRAINER_PKL))

return _TRAINER_PKL in list_at_uri(str(path))

def __repr__(self):
Expand Down Expand Up @@ -589,11 +612,12 @@ def fit(self) -> Result:

if self._restore_path:
tuner = Tuner.restore(
self._restore_path,
path=self._restore_path,
trainable=trainable,
param_space=param_space,
resume_unfinished=True,
resume_errored=True,
storage_filesystem=self._restore_storage_filesystem,
)
else:
tuner = Tuner(
Expand Down
8 changes: 2 additions & 6 deletions python/ray/train/data_parallel_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,9 +338,7 @@ def restore(
Union[Callable[[], None], Callable[[Dict], None]]
] = None,
train_loop_config: Optional[Dict] = None,
datasets: Optional[Dict[str, GenDataset]] = None,
preprocessor: Optional["Preprocessor"] = None,
scaling_config: Optional[ScalingConfig] = None,
**kwargs,
) -> "DataParallelTrainer":
"""Restores a DataParallelTrainer from a previously interrupted/failed run.
Expand All @@ -366,9 +364,7 @@ def restore(
path=path,
train_loop_per_worker=train_loop_per_worker,
train_loop_config=train_loop_config,
datasets=datasets,
preprocessor=preprocessor,
scaling_config=scaling_config,
**kwargs,
)

def _validate_attributes(self):
Expand Down
26 changes: 0 additions & 26 deletions python/ray/train/lightning/lightning_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,32 +487,6 @@ def _unify_checkpoint_configs(
else:
return air_ckpt_config

@PublicAPI(stability="alpha")
@classmethod
def restore(
cls: Type["LightningTrainer"],
path: str,
datasets: Optional[Dict[str, GenDataset]] = None,
preprocessor: Optional["Preprocessor"] = None,
scaling_config: Optional[ScalingConfig] = None,
**kwargs,
) -> "LightningTrainer":
"""Restores a LightningTrainer from a previously interrupted/failed run.
See :meth:`BaseTrainer.restore() <ray.train.trainer.BaseTrainer.restore>`
for descriptions of the arguments.
Returns:
LightningTrainer: A restored instance of `LightningTrainer`
"""
return super(LightningTrainer, cls).restore(
path=path,
datasets=datasets,
preprocessor=preprocessor,
scaling_config=scaling_config,
**kwargs,
)


def _lightning_train_loop_per_worker(config):
"""Per-worker training loop for a Lightning Trainer."""
Expand Down
43 changes: 33 additions & 10 deletions python/ray/train/tests/test_new_persistence.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,13 @@

import pyarrow.fs

import ray
from ray import train, tune
from ray.air._internal.uri_utils import URI
from ray.air.constants import EXPR_RESULT_FILE
from ray.train._internal.storage import _download_from_fs_path, StorageContext
from ray.train._checkpoint import Checkpoint as NewCheckpoint
from ray.train.base_trainer import TrainingFailedError
from ray.train.data_parallel_trainer import DataParallelTrainer

from ray.air.tests.test_checkpoints import mock_s3_bucket_uri
Expand All @@ -26,11 +29,20 @@ def dummy_context_manager():
yield "dummy value"


@pytest.fixture(autouse=True)
def enable_new_persistence_mode(monkeypatch):
monkeypatch.setenv("RAY_AIR_NEW_PERSISTENCE_MODE", "1")
@pytest.fixture(scope="module")
def enable_new_persistence_mode():
with pytest.MonkeyPatch.context() as mp:
mp.setenv("RAY_AIR_NEW_PERSISTENCE_MODE", "1")
yield
mp.setenv("RAY_AIR_NEW_PERSISTENCE_MODE", "0")


@pytest.fixture(autouse=True, scope="module")
def ray_start_4_cpus(enable_new_persistence_mode):
# Make sure to set the env var before calling ray.init()
ray.init(num_cpus=4)
yield
monkeypatch.setenv("RAY_AIR_NEW_PERSISTENCE_MODE", "0")
ray.shutdown()


def _create_mock_custom_fs(custom_fs_root_dir: Path) -> pyarrow.fs.FileSystem:
Expand Down Expand Up @@ -309,11 +321,12 @@ def test_trainer(
"""
TODO(justinvyu): Test for these once implemented:
- artifacts
- restoration, train.get_checkpoint
{storage_path}/{exp_name}
├── experiment_state-2023-07-28_10-00-38.json
├── experiment_state-2023-07-28_10-00-38.json <- Initial exp state
├── basic-variant-state-2023-07-28_10-00-38.json
├── experiment_state-2023-07-28_10-01-38.json <- Restored exp state
├── basic-variant-state-2023-07-28_10-01-38.json
├── trainer.pkl
├── tuner.pkl
└── DataParallelTrainer_46367_00000_0_...
Expand Down Expand Up @@ -358,11 +371,19 @@ def test_trainer(
name=exp_name,
verbose=0,
checkpoint_config=checkpoint_config,
failure_config=train.FailureConfig(max_failures=2),
failure_config=train.FailureConfig(max_failures=1),
),
)
print("\nStarting initial run.\n")
result = trainer.fit()
with pytest.raises(TrainingFailedError):
result = trainer.fit()

print("\nStarting manually restored run.\n")
restored_trainer = DataParallelTrainer.restore(
path=str(URI(storage_path or str(LOCAL_CACHE_DIR)) / exp_name),
storage_filesystem=storage_filesystem,
)
result = restored_trainer.fit()

with monkeypatch.context() as m:
# This is so that the `resume_from_checkpoint` run doesn't mess up the
Expand Down Expand Up @@ -390,10 +411,12 @@ def test_trainer(
exp_dir = local_inspect_dir / exp_name

# Files synced by the driver
assert len(list(exp_dir.glob("basic-variant-state-*"))) == 1
assert len(list(exp_dir.glob("experiment_state-*"))) == 1
assert len(list(exp_dir.glob("tuner.pkl"))) == 1
assert len(list(exp_dir.glob("trainer.pkl"))) == 1
# 2 copies of these files:
# 1 for the initial run, and 1 for the manually restored run.
assert len(list(exp_dir.glob("basic-variant-state-*"))) == 2
assert len(list(exp_dir.glob("experiment_state-*"))) == 2

# Files synced by the worker
assert len(list(exp_dir.glob("DataParallelTrainer_*"))) == 1
Expand Down
5 changes: 4 additions & 1 deletion python/ray/tune/analysis/experiment_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
EXPR_PARAM_FILE,
TRAINING_ITERATION,
)
from ray.train._internal.storage import _use_storage_context
from ray.tune.syncer import SyncConfig
from ray.tune.utils import flatten_dict
from ray.tune.utils.serialization import TuneFunctionDecoder
Expand Down Expand Up @@ -980,7 +981,9 @@ def _get_trial_paths(self) -> List[str]:
for trial_json_state, path in self._checkpoints_and_paths:
try:
trial = Trial.from_json_state(trial_json_state, stub=True)
trial.local_experiment_path = str(path)
# TODO(justinvyu): [handle_moved_storage_path]
if not _use_storage_context():
trial.local_experiment_path = str(path)
except Exception:
logger.warning(
f"Could not load trials from experiment checkpoint. "
Expand Down
Loading

0 comments on commit c5c313a

Please sign in to comment.