[air] pyarrow.fs persistence (9/n): ray.train.Checkpoint restore: Manual restore #38128

Merged

144 commits (diff shown from the first 142):
abb1307
Pipe storage context to Trainable (used now for Trainable syncing)
justinvyu Jul 23, 2023
f6ff90a
Don't use the storage context in the trial/trainable
justinvyu Jul 27, 2023
562369f
Disable all trainable syncing in new codepath
justinvyu Jul 27, 2023
95a3d20
Pipe storage context to Train workers (not actually used yet)
justinvyu Jul 23, 2023
484e67f
Fix race condition for setting checkpoint_uri
justinvyu Jul 24, 2023
2148669
Fix cyclical import
justinvyu Jul 27, 2023
8c856b8
Add simple trainer test
justinvyu Jul 27, 2023
78c525f
Add legacy prefix to train session checkpoint uri
justinvyu Jul 27, 2023
e97f471
Add new checkpoint class
justinvyu Jul 27, 2023
64945be
New train session report implementation using new checkpoint
justinvyu Jul 28, 2023
c6480c9
Simplify checkpoint propagation from user code (in worker) -> trainer…
justinvyu Jul 28, 2023
c681ccb
New tune session.report
justinvyu Jul 28, 2023
795bafe
Save direction works with new checkpoint API
justinvyu Jul 28, 2023
8a084bc
Update test with e2e trainer test
justinvyu Jul 28, 2023
725d802
Make callback supporting new checkpoint a todo for now
justinvyu Jul 28, 2023
877acb9
Remove unnecessary comment
justinvyu Jul 28, 2023
ee4ccbd
Merge branch 'master' of https://github.com/ray-project/ray into air/…
justinvyu Jul 28, 2023
88042b3
Separate out the new set checkpoint id from the old set checkpoint uri
justinvyu Jul 28, 2023
a5eeab2
Merge branch 'master' of https://github.com/ray-project/ray into air/…
justinvyu Jul 31, 2023
a6cd9dc
Update id -> index
justinvyu Jul 31, 2023
01f34bb
Address comments on error to raise with old ckpt type
justinvyu Jul 31, 2023
65e7a27
Move checkpoint upload logic to a helper fn of storage ctx
justinvyu Jul 31, 2023
f2a4c36
Drop a checkpoint marker after uploading
justinvyu Jul 31, 2023
49ee126
Add a simplified checkpoint manager
justinvyu Aug 1, 2023
ffa0dd4
Fixes to checkpoint manager
justinvyu Aug 1, 2023
15553f7
Add unit test for simplified checkpoint manager
justinvyu Aug 1, 2023
00cc9d7
Full test coverage
justinvyu Aug 1, 2023
cb5990e
Add a simplified checkpoint manager
justinvyu Aug 1, 2023
2db9aae
Fixes to checkpoint manager
justinvyu Aug 1, 2023
a2067b7
Add unit test for simplified checkpoint manager
justinvyu Aug 1, 2023
f1216f2
Full test coverage
justinvyu Aug 1, 2023
d4243e6
Merge branch 'master' of https://github.com/ray-project/ray into air/…
justinvyu Aug 1, 2023
6699d81
Merge branch 'master' of https://github.com/ray-project/ray into air/…
justinvyu Aug 1, 2023
9b9ff34
Simplify even more
justinvyu Aug 1, 2023
83aecd9
Merge branch 'master' of https://github.com/ray-project/ray into air/…
justinvyu Aug 1, 2023
913af10
Patch fix for circular imports
justinvyu Aug 1, 2023
6b5d34e
Use new checkpoint manager in Tune ckpt book-keeping
justinvyu Aug 2, 2023
24f441a
Update result to return a train.Checkpoint to the user
justinvyu Aug 2, 2023
504ed54
Update e2e test to try multiple ckpt configs for trainer test
justinvyu Aug 2, 2023
1992161
Merge branch 'master' of https://github.com/ray-project/ray into air/…
justinvyu Aug 2, 2023
b9eb88f
Fix lint for trial.py
justinvyu Aug 2, 2023
a6115b3
Merge branch 'master' of https://github.com/ray-project/ray into air/…
justinvyu Aug 2, 2023
7cc74d9
Rename _TrackedCheckpoint -> _TrainingResult
justinvyu Aug 2, 2023
6a0e1fb
Merge branch 'master' of https://github.com/ray-project/ray into air/…
justinvyu Aug 2, 2023
4662789
Merge branch 'air/persistence/simplified_ckpt_manager' into air/persi…
justinvyu Aug 2, 2023
8da0477
Fixes after merging latest ckpt manager changes
justinvyu Aug 2, 2023
255b149
Remove prints / convert to logger.debug
justinvyu Aug 2, 2023
0971aca
Don't set training iteration as the default checkpoint_score_attr
justinvyu Aug 2, 2023
6e7a873
Merge branch 'master' of https://github.com/ray-project/ray into air/…
justinvyu Aug 2, 2023
6f6a341
Merge branch 'master' of https://github.com/ray-project/ray into air/…
justinvyu Aug 2, 2023
d9804b0
Fix test to reflect working dir change
justinvyu Aug 2, 2023
318158f
Don't upload a .is_checkpoint marker
justinvyu Aug 2, 2023
a54664c
Add back cwd check
justinvyu Aug 2, 2023
c4263ec
Update the dir trees + better naming for ckpt shards and artifacts
justinvyu Aug 2, 2023
0cd7e47
Merge branch 'master' of https://github.com/ray-project/ray into air/…
justinvyu Aug 2, 2023
3a6eba6
A different fix for the circular dep
justinvyu Aug 2, 2023
b65e9fe
Update checkpoint -> _checkpoint imports
justinvyu Aug 2, 2023
b89bd1c
fix lint
justinvyu Aug 2, 2023
cada06a
Merge branch 'master' of https://github.com/ray-project/ray into air/…
justinvyu Aug 2, 2023
3b784d7
Revert all changes to ckpt manager
justinvyu Aug 2, 2023
49c1ead
Don't set checkpoint user metadata
justinvyu Aug 2, 2023
7177940
Remove remaining print
justinvyu Aug 2, 2023
ae8a9ec
Add trial_path property to storage ctx
justinvyu Aug 3, 2023
c1c8441
Use storage context for all experiment/trial path properties
justinvyu Aug 3, 2023
5d2ca07
Don't skip trainer test cases for custom_fs
justinvyu Aug 3, 2023
1fcfb3f
Split some utilities into helper methods + test for ResultGrid paths
justinvyu Aug 3, 2023
5e2a933
Merge branch 'master' of https://github.com/ray-project/ray into air/…
justinvyu Aug 3, 2023
61fdadf
Prepend legacy to old path attributes in trial
justinvyu Aug 3, 2023
d38cd87
Remove todo
justinvyu Aug 3, 2023
3ba944a
Bump the test size
justinvyu Aug 3, 2023
5f71608
Merge branch 'master' of https://github.com/ray-project/ray into air/…
justinvyu Aug 3, 2023
52d6c14
Merge branch 'air/persistence/new_checkpoint' into air/persistence/fi…
justinvyu Aug 3, 2023
9c16120
Clean up experiment path handling
justinvyu Aug 3, 2023
76468d9
Fix for base trainer
justinvyu Aug 3, 2023
b17c17e
Fix for base trainer pt 2
justinvyu Aug 3, 2023
30e3328
Add in missing legacy property
justinvyu Aug 3, 2023
f25ad39
Prepend legacy to old path attributes in experiment
justinvyu Aug 3, 2023
e1846ec
Merge branch 'master' of https://github.com/ray-project/ray into air/…
justinvyu Aug 3, 2023
d11ede8
Merge branch 'master' of https://github.com/ray-project/ray into air/…
justinvyu Aug 3, 2023
c99d30c
too much space
justinvyu Aug 3, 2023
9d11d2d
remove unused var
justinvyu Aug 3, 2023
950b991
Fix lint
justinvyu Aug 3, 2023
ed86255
restore mostly works
justinvyu Aug 4, 2023
de4b924
hacky way of getting checkpoint folders to increment correctly
justinvyu Aug 4, 2023
e060476
Fix for xgboost trainer
justinvyu Aug 4, 2023
bd5c846
Fix race as_directory / download file lock race condition
justinvyu Aug 4, 2023
e51fb17
Update test with auto-recovery fault tolerance
justinvyu Aug 4, 2023
6368075
Merge branch 'master' of https://github.com/ray-project/ray into air/…
justinvyu Aug 4, 2023
eaa26c5
Merge branch 'master' of https://github.com/ray-project/ray into air/…
justinvyu Aug 4, 2023
0c3c5c8
compute storage_prefix
justinvyu Aug 4, 2023
217af77
Remove '_path' properties from storage
justinvyu Aug 4, 2023
8e2330c
Move exp dir name helper to storage ctx
justinvyu Aug 4, 2023
7502cca
Merge branch 'master' of https://github.com/ray-project/ray into air/…
justinvyu Aug 4, 2023
f3f22fd
Fix bugs causing broken CI
justinvyu Aug 4, 2023
2c9adf4
Merge branch 'air/persistence/fix_custom_fs_path_expansion' into air/…
justinvyu Aug 4, 2023
dc1c7c3
Fix syncing needed logic to handle storage path == local path case
justinvyu Aug 4, 2023
a262cb3
working for manual Trainer.restore
justinvyu Aug 4, 2023
e59b408
Add manual restore to e2e test
justinvyu Aug 4, 2023
c6c3dfe
Fix renamed attribute in mock test class
justinvyu Aug 4, 2023
4e56bd0
Merge branch 'master' of https://github.com/ray-project/ray into air/…
justinvyu Aug 4, 2023
314e8bd
Merge branch 'air/persistence/fix_custom_fs_path_expansion' into air/…
justinvyu Aug 4, 2023
36464af
fix storage attr setting to only happen if ff enabled
justinvyu Aug 5, 2023
6e73f6e
cleanup on errors in as_directory
justinvyu Aug 5, 2023
c7de72a
Support + test resume_from_checkpoint
justinvyu Aug 5, 2023
db016da
Fix result grid bug when no checkpoints saved
justinvyu Aug 5, 2023
bcbcec9
Merge branch 'master' of https://github.com/ray-project/ray into air/…
justinvyu Aug 5, 2023
d7497e1
fix merge conflict remainder
justinvyu Aug 5, 2023
aeb89ba
Recover trainable metadata from last_result rather than .tune_metadata
justinvyu Aug 7, 2023
9556371
Fix restore info log
justinvyu Aug 7, 2023
e897eaa
Keep current checkpoint index synchronized on the driver
justinvyu Aug 7, 2023
3eef417
Remove checkpoint dirname parsing
justinvyu Aug 7, 2023
0e52384
Merge branch 'master' of https://github.com/ray-project/ray into air/…
justinvyu Aug 7, 2023
89631ab
Update todo comment
justinvyu Aug 7, 2023
d4e20f2
Fix lint
justinvyu Aug 7, 2023
877c4ae
Merge branch 'air/persistence/restore_new_checkpoint_autoft' into air…
justinvyu Aug 7, 2023
194dc37
Merge branch 'air/persistence/restore_new_checkpoint_autoft' into air…
justinvyu Aug 7, 2023
5c50dd4
Some small imports cleanup
justinvyu Aug 7, 2023
d367e8d
Merge branch 'master' of https://github.com/ray-project/ray into air/…
justinvyu Aug 8, 2023
d329e1b
Fix e2e test for storage_path=None case
justinvyu Aug 8, 2023
e382e29
Remove unused code
justinvyu Aug 8, 2023
550575e
Merge branch 'air/persistence/restore_new_checkpoint_rfc' into air/pe…
justinvyu Aug 8, 2023
86367ef
Guard new codepath correctly
justinvyu Aug 8, 2023
3317e3b
Separate out fs resolution into a helper
justinvyu Aug 8, 2023
6fb9064
Add custom filesystem arg on restore
justinvyu Aug 8, 2023
f388dcc
Don't skip the custom fs test case for restore
justinvyu Aug 8, 2023
d881c0d
Merge branch 'master' of https://github.com/ray-project/ray into air/…
justinvyu Aug 8, 2023
ade1078
clean up some imports
justinvyu Aug 8, 2023
8b78dee
Fix the test fixtures
justinvyu Aug 8, 2023
075c6a6
Merge branch 'master' of https://github.com/ray-project/ray into air/…
justinvyu Aug 8, 2023
cf963fd
Remove done todo
justinvyu Aug 8, 2023
1d87dde
Merge branch 'master' of https://github.com/ray-project/ray into air/…
justinvyu Aug 9, 2023
ca5f5bf
Fix optional can_restore argument
justinvyu Aug 9, 2023
75e947c
Remove duplicate in test
justinvyu Aug 9, 2023
5c1c282
Merge branch 'master' of https://github.com/ray-project/ray into air/…
justinvyu Aug 9, 2023
a13d88d
read file directly from fs for trainer restore
justinvyu Aug 9, 2023
7e468c6
check for existence rather than list
justinvyu Aug 9, 2023
13c224f
Update tuner restore
justinvyu Aug 9, 2023
1085666
Mark experiment_checkpoint_dir as legacy
justinvyu Aug 9, 2023
0957d68
Revert changes to sync down logic in trainer
justinvyu Aug 9, 2023
81c307f
Merge branch 'master' of https://github.com/ray-project/ray into air/…
justinvyu Aug 9, 2023
1716fa7
Fix lint
justinvyu Aug 9, 2023
5b6be43
minor fixes
justinvyu Aug 9, 2023
413ed38
Merge branch 'master' of https://github.com/ray-project/ray into air/…
justinvyu Aug 9, 2023
ca0a3a0
Remove backwards compatibility test
justinvyu Aug 9, 2023
79 changes: 51 additions & 28 deletions python/ray/train/_internal/storage.py
@@ -272,6 +272,44 @@ def _create_directory(fs: pyarrow.fs.FileSystem, fs_path: str) -> None:
)


def get_fs_and_path(
storage_path: Union[str, os.PathLike],
storage_filesystem: Optional[pyarrow.fs.FileSystem] = None,
) -> Tuple[pyarrow.fs.FileSystem, str]:
"""Returns the fs and path from a storage path and an optional custom fs.
Args:
storage_path: A storage path or URI. (ex: s3://bucket/path or /tmp/ray_results)
storage_filesystem: A custom filesystem to use. If not provided,
this will be auto-resolved by pyarrow. If provided, the storage_path
is assumed to be prefix-stripped already, and must be a valid path
on the filesystem.
Raises:
ValueError: if the storage path is a URI and a custom filesystem is given.
"""
storage_path = str(storage_path)

if storage_filesystem:
if is_uri(storage_path):
raise ValueError(
"If you specify a custom `storage_filesystem`, the corresponding "
"`storage_path` must be a *path* on that filesystem, not a URI.\n"
"For example: "
"(storage_filesystem=CustomS3FileSystem(), "
"storage_path='s3://bucket/path') should be changed to "
"(storage_filesystem=CustomS3FileSystem(), "
"storage_path='bucket/path')\n"
"This is what you provided: "
f"(storage_filesystem={storage_filesystem}, "
f"storage_path={storage_path})\n"
"Note that this may depend on the custom filesystem you use."
)
return storage_filesystem, storage_path

return pyarrow.fs.FileSystem.from_uri(storage_path)


class _FilesystemSyncer(_BackgroundSyncer):
"""Syncer between local filesystem and a `storage_filesystem`."""

@@ -402,7 +440,7 @@ def __init__(
trial_dir_name: Optional[str] = None,
current_checkpoint_index: int = 0,
):
storage_path_provided = storage_path is not None
custom_fs_provided = storage_filesystem is not None

self.storage_local_path = _get_defaults_results_dir()
# If `storage_path=None`, then set it to the local path.
@@ -414,32 +452,12 @@ def __init__(
self.current_checkpoint_index = current_checkpoint_index
self.sync_config = dataclasses.replace(sync_config)

if storage_filesystem:
# Custom pyarrow filesystem
self.storage_filesystem = storage_filesystem
if is_uri(self.storage_path):
raise ValueError(
"If you specify a custom `storage_filesystem`, the corresponding "
"`storage_path` must be a *path* on that filesystem, not a URI.\n"
"For example: "
"(storage_filesystem=CustomS3FileSystem(), "
"storage_path='s3://bucket/path') should be changed to "
"(storage_filesystem=CustomS3FileSystem(), "
"storage_path='bucket/path')\n"
"This is what you provided: "
f"(storage_filesystem={storage_filesystem}, "
f"storage_path={storage_path})\n"
"Note that this may depend on the custom filesystem you use."
)
self.storage_fs_path = self.storage_path
else:
(
self.storage_filesystem,
self.storage_fs_path,
) = pyarrow.fs.FileSystem.from_uri(self.storage_path)
self.storage_filesystem, self.storage_fs_path = get_fs_and_path(
self.storage_path, storage_filesystem
)

# The storage prefix is the URI that remains after stripping the
# URI prefix away from the user-provided `storage_path` (using `from_uri`).
# The storage prefix is part of the URI that is stripped away
# from the user-provided `storage_path` by pyarrow's `from_uri`.
# Ex: `storage_path="s3://bucket/path?param=1`
# -> `storage_prefix=URI<s3://.?param=1>`
# See the doctests for more examples.
@@ -450,14 +468,19 @@ def __init__(
Path(self.storage_fs_path)
)

# Only initialize a syncer if a `storage_path` was provided.
# Syncing is always needed if a custom `storage_filesystem` is provided.
# Otherwise, syncing is only needed if storage_local_path
# and storage_fs_path point to different locations.
syncing_needed = (
custom_fs_provided or self.storage_fs_path != self.storage_local_path
)
self.syncer: Optional[Syncer] = (
_FilesystemSyncer(
storage_filesystem=self.storage_filesystem,
sync_period=self.sync_config.sync_period,
sync_timeout=self.sync_config.sync_timeout,
)
if storage_path_provided
if syncing_needed
else None
)

44 changes: 34 additions & 10 deletions python/ray/train/base_trainer.py
@@ -8,6 +8,8 @@
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Type, Union
import warnings

import pyarrow.fs

import ray
import ray.cloudpickle as pickle
from ray.air._internal.config import ensure_only_allowed_dataclass_keys_updated
@@ -23,7 +25,11 @@
from ray.air.result import Result
from ray.train._checkpoint import Checkpoint as NewCheckpoint
from ray.train._internal import session
from ray.train._internal.storage import _use_storage_context
from ray.train._internal.storage import (
_exists_at_fs_path,
_use_storage_context,
get_fs_and_path,
)
from ray.train.constants import TRAIN_DATASET_KEY
from ray.util import PublicAPI
from ray.util.annotations import DeveloperAPI
@@ -195,8 +201,9 @@ def __init__(
self.preprocessor = preprocessor
self.starting_checkpoint = resume_from_checkpoint

# This path should only be set through restore
# These attributes should only be set through `BaseTrainer.restore`
self._restore_path = None
self._restore_storage_filesystem = None

self._validate_attributes()

@@ -214,7 +221,8 @@
@classmethod
def restore(
cls: Type["BaseTrainer"],
path: str,
path: Union[str, os.PathLike],
storage_filesystem: Optional[pyarrow.fs.FileSystem] = None,
datasets: Optional[Dict[str, GenDataset]] = None,
preprocessor: Optional["Preprocessor"] = None,
scaling_config: Optional[ScalingConfig] = None,
@@ -298,17 +306,23 @@ def training_loop(self):
Returns:
BaseTrainer: A restored instance of the class that is calling this method.
"""
if not cls.can_restore(path):
if not cls.can_restore(path, storage_filesystem):
raise ValueError(
f"Invalid restore path: {path}. Make sure that this path exists and "
"is the experiment directory that results from a call to "
"`trainer.fit()`."
)
trainer_state_path = cls._maybe_sync_down_trainer_state(path)
assert trainer_state_path.exists()
if _use_storage_context():
fs, fs_path = get_fs_and_path(path, storage_filesystem)
with fs.open_input_file(os.path.join(fs_path, _TRAINER_PKL)) as f:
trainer_cls, param_dict = pickle.loads(f.readall())
else:
trainer_state_path = cls._maybe_sync_down_trainer_state(path)
assert trainer_state_path.exists()

with open(trainer_state_path, "rb") as fp:
trainer_cls, param_dict = pickle.load(fp)

with open(trainer_state_path, "rb") as fp:
trainer_cls, param_dict = pickle.load(fp)
if trainer_cls is not cls:
warnings.warn(
f"Invalid trainer type. You are attempting to restore a trainer of type"
@@ -355,11 +369,16 @@ def training_loop(self):
f"`{cls.__name__}.restore`\n"
) from e
trainer._restore_path = path
trainer._restore_storage_filesystem = storage_filesystem
return trainer

@PublicAPI(stability="alpha")
@classmethod
def can_restore(cls: Type["BaseTrainer"], path: Union[str, Path]) -> bool:
def can_restore(
cls: Type["BaseTrainer"],
path: Union[str, os.PathLike],
storage_filesystem: Optional[pyarrow.fs.FileSystem] = None,
) -> bool:
"""Checks whether a given directory contains a restorable Train experiment.
Args:
@@ -370,6 +389,10 @@ def can_restore(cls: Type["BaseTrainer"], path: Union[str, Path]) -> bool:
Returns:
bool: Whether this path exists and contains the trainer state to resume from
"""
if _use_storage_context():
fs, fs_path = get_fs_and_path(path, storage_filesystem)
return _exists_at_fs_path(fs, os.path.join(fs_path, _TRAINER_PKL))

return _TRAINER_PKL in list_at_uri(str(path))

def __repr__(self):
@@ -589,11 +612,12 @@ def fit(self) -> Result:

if self._restore_path:
tuner = Tuner.restore(
self._restore_path,
path=self._restore_path,
trainable=trainable,
param_space=param_space,
resume_unfinished=True,
resume_errored=True,
storage_filesystem=self._restore_storage_filesystem,
)
else:
tuner = Tuner(
8 changes: 2 additions & 6 deletions python/ray/train/data_parallel_trainer.py
@@ -338,9 +338,7 @@ def restore(
Union[Callable[[], None], Callable[[Dict], None]]
] = None,
train_loop_config: Optional[Dict] = None,
datasets: Optional[Dict[str, GenDataset]] = None,
preprocessor: Optional["Preprocessor"] = None,
scaling_config: Optional[ScalingConfig] = None,
**kwargs,
) -> "DataParallelTrainer":
"""Restores a DataParallelTrainer from a previously interrupted/failed run.
@@ -366,9 +364,7 @@
path=path,
train_loop_per_worker=train_loop_per_worker,
train_loop_config=train_loop_config,
datasets=datasets,
preprocessor=preprocessor,
scaling_config=scaling_config,
**kwargs,
)

def _validate_attributes(self):
26 changes: 0 additions & 26 deletions python/ray/train/lightning/lightning_trainer.py
@@ -488,32 +488,6 @@ def _unify_checkpoint_configs(
else:
return air_ckpt_config

@PublicAPI(stability="alpha")
@classmethod
def restore(
cls: Type["LightningTrainer"],
path: str,
datasets: Optional[Dict[str, GenDataset]] = None,
preprocessor: Optional["Preprocessor"] = None,
scaling_config: Optional[ScalingConfig] = None,
**kwargs,
) -> "LightningTrainer":
"""Restores a LightningTrainer from a previously interrupted/failed run.

See :meth:`BaseTrainer.restore() <ray.train.trainer.BaseTrainer.restore>`
for descriptions of the arguments.

Returns:
LightningTrainer: A restored instance of `LightningTrainer`
"""
return super(LightningTrainer, cls).restore(
path=path,
datasets=datasets,
preprocessor=preprocessor,
scaling_config=scaling_config,
**kwargs,
)
Review comment on lines 509-515 (justinvyu, Contributor Author): This was just a passthrough - no need for this class to implement it.
def _lightning_train_loop_per_worker(config):
"""Per-worker training loop for a Lightning Trainer."""
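The deleted `LightningTrainer.restore` illustrates the general cleanup in this PR: an override that only forwards its arguments to `super()` is redundant, because a classmethod invoked on the subclass already binds `cls` to that subclass. A toy (non-Ray) illustration of why the passthrough adds nothing:

```python
class Base:
    @classmethod
    def restore(cls, path: str, **kwargs) -> str:
        # cls is the class the call was made on, so subclasses get the
        # right behavior without re-implementing restore().
        return f"{cls.__name__} restored from {path}"


class Lightning(Base):
    pass  # no passthrough override needed


print(Lightning.restore("/tmp/exp"))  # Lightning restored from /tmp/exp
```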
43 changes: 33 additions & 10 deletions python/ray/train/tests/test_new_persistence.py
@@ -9,10 +9,13 @@

import pyarrow.fs

import ray
from ray import train, tune
from ray.air._internal.uri_utils import URI
from ray.air.constants import EXPR_RESULT_FILE
from ray.train._internal.storage import _download_from_fs_path, StorageContext
from ray.train._checkpoint import Checkpoint as NewCheckpoint
from ray.train.base_trainer import TrainingFailedError
from ray.train.data_parallel_trainer import DataParallelTrainer

from ray.air.tests.test_checkpoints import mock_s3_bucket_uri
@@ -26,11 +29,20 @@ def dummy_context_manager():
yield "dummy value"


@pytest.fixture(autouse=True)
def enable_new_persistence_mode(monkeypatch):
monkeypatch.setenv("RAY_AIR_NEW_PERSISTENCE_MODE", "1")
@pytest.fixture(scope="module")
def enable_new_persistence_mode():
with pytest.MonkeyPatch.context() as mp:
mp.setenv("RAY_AIR_NEW_PERSISTENCE_MODE", "1")
yield
mp.setenv("RAY_AIR_NEW_PERSISTENCE_MODE", "0")


@pytest.fixture(autouse=True, scope="module")
def ray_start_4_cpus(enable_new_persistence_mode):
# Make sure to set the env var before calling ray.init()
ray.init(num_cpus=4)
yield
monkeypatch.setenv("RAY_AIR_NEW_PERSISTENCE_MODE", "0")
ray.shutdown()


def _create_mock_custom_fs(custom_fs_root_dir: Path) -> pyarrow.fs.FileSystem:
@@ -309,11 +321,12 @@ def test_trainer(
"""
TODO(justinvyu): Test for these once implemented:
- artifacts
- restoration, train.get_checkpoint
{storage_path}/{exp_name}
├── experiment_state-2023-07-28_10-00-38.json
├── experiment_state-2023-07-28_10-00-38.json <- Initial exp state
├── basic-variant-state-2023-07-28_10-00-38.json
├── experiment_state-2023-07-28_10-01-38.json <- Restored exp state
├── basic-variant-state-2023-07-28_10-01-38.json
├── trainer.pkl
├── tuner.pkl
└── DataParallelTrainer_46367_00000_0_...
@@ -358,11 +371,19 @@
name=exp_name,
verbose=0,
checkpoint_config=checkpoint_config,
failure_config=train.FailureConfig(max_failures=2),
failure_config=train.FailureConfig(max_failures=1),
),
)
print("\nStarting initial run.\n")
result = trainer.fit()
with pytest.raises(TrainingFailedError):
result = trainer.fit()

print("\nStarting manually restored run.\n")
restored_trainer = DataParallelTrainer.restore(
path=str(URI(storage_path or str(LOCAL_CACHE_DIR)) / exp_name),
storage_filesystem=storage_filesystem,
)
result = restored_trainer.fit()

with monkeypatch.context() as m:
# This is so that the `resume_from_checkpoint` run doesn't mess up the
@@ -390,10 +411,12 @@
exp_dir = local_inspect_dir / exp_name

# Files synced by the driver
assert len(list(exp_dir.glob("basic-variant-state-*"))) == 1
assert len(list(exp_dir.glob("experiment_state-*"))) == 1
assert len(list(exp_dir.glob("tuner.pkl"))) == 1
assert len(list(exp_dir.glob("trainer.pkl"))) == 1
# 2 copies of these files:
# 1 for the initial run, and 1 for the manually restored run.
assert len(list(exp_dir.glob("basic-variant-state-*"))) == 2
assert len(list(exp_dir.glob("experiment_state-*"))) == 2

# Files synced by the worker
assert len(list(exp_dir.glob("DataParallelTrainer_*"))) == 1
5 changes: 4 additions & 1 deletion python/ray/tune/analysis/experiment_analysis.py
@@ -21,6 +21,7 @@
EXPR_PARAM_FILE,
TRAINING_ITERATION,
)
from ray.train._internal.storage import _use_storage_context
from ray.tune.syncer import SyncConfig
from ray.tune.utils import flatten_dict
from ray.tune.utils.serialization import TuneFunctionDecoder
@@ -980,7 +981,9 @@ def _get_trial_paths(self) -> List[str]:
for trial_json_state, path in self._checkpoints_and_paths:
try:
trial = Trial.from_json_state(trial_json_state, stub=True)
trial.local_experiment_path = str(path)
# TODO(justinvyu): [handle_moved_storage_path]
if not _use_storage_context():
trial.local_experiment_path = str(path)
except Exception:
logger.warning(
f"Could not load trials from experiment checkpoint. "