Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[air] pyarrow.fs persistence (5/n): ray.train.Checkpoint save direction #37888

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
64 commits
Select commit Hold shift + click to select a range
abb1307
Pipe storage context to Trainable (used now for Trainable syncing)
justinvyu Jul 23, 2023
f6ff90a
Don't use the storage context in the trial/trainable
justinvyu Jul 27, 2023
562369f
Disable all trainable syncing in new codepath
justinvyu Jul 27, 2023
95a3d20
Pipe storage context to Train workers (not actually used yet)
justinvyu Jul 23, 2023
484e67f
Fix race condition for setting checkpoint_uri
justinvyu Jul 24, 2023
2148669
Fix cyclical import
justinvyu Jul 27, 2023
8c856b8
Add simple trainer test
justinvyu Jul 27, 2023
78c525f
Add legacy prefix to train session checkpoint uri
justinvyu Jul 27, 2023
e97f471
Add new checkpoint class
justinvyu Jul 27, 2023
64945be
New train session report implementation using new checkpoint
justinvyu Jul 28, 2023
c6480c9
Simplify checkpoint propagation from user code (in worker) -> trainer…
justinvyu Jul 28, 2023
c681ccb
New tune session.report
justinvyu Jul 28, 2023
795bafe
Save direction works with new checkpoint API
justinvyu Jul 28, 2023
8a084bc
Update test with e2e trainer test
justinvyu Jul 28, 2023
725d802
Make callback supporting new checkpoint a todo for now
justinvyu Jul 28, 2023
877acb9
Remove unnecessary comment
justinvyu Jul 28, 2023
ee4ccbd
Merge branch 'master' of https://github.com/ray-project/ray into air/…
justinvyu Jul 28, 2023
88042b3
Separate out the new set checkpoint id from the old set checkpoint uri
justinvyu Jul 28, 2023
a5eeab2
Merge branch 'master' of https://github.com/ray-project/ray into air/…
justinvyu Jul 31, 2023
a6cd9dc
Update id -> index
justinvyu Jul 31, 2023
01f34bb
Address comments on error to raise with old ckpt type
justinvyu Jul 31, 2023
65e7a27
Move checkpoint upload logic to a helper fn of storage ctx
justinvyu Jul 31, 2023
f2a4c36
Drop a checkpoint marker after uploading
justinvyu Jul 31, 2023
49ee126
Add a simplified checkpoint manager
justinvyu Aug 1, 2023
ffa0dd4
Fixes to checkpoint manager
justinvyu Aug 1, 2023
15553f7
Add unit test for simplified checkpoint manager
justinvyu Aug 1, 2023
00cc9d7
Full test coverage
justinvyu Aug 1, 2023
cb5990e
Add a simplified checkpoint manager
justinvyu Aug 1, 2023
2db9aae
Fixes to checkpoint manager
justinvyu Aug 1, 2023
a2067b7
Add unit test for simplified checkpoint manager
justinvyu Aug 1, 2023
f1216f2
Full test coverage
justinvyu Aug 1, 2023
d4243e6
Merge branch 'master' of https://github.com/ray-project/ray into air/…
justinvyu Aug 1, 2023
6699d81
Merge branch 'master' of https://github.com/ray-project/ray into air/…
justinvyu Aug 1, 2023
9b9ff34
Simplify even more
justinvyu Aug 1, 2023
83aecd9
Merge branch 'master' of https://github.com/ray-project/ray into air/…
justinvyu Aug 1, 2023
913af10
Patch fix for circular imports
justinvyu Aug 1, 2023
6b5d34e
Use new checkpoint manager in Tune ckpt book-keeping
justinvyu Aug 2, 2023
24f441a
Update result to return a train.Checkpoint to the user
justinvyu Aug 2, 2023
504ed54
Update e2e test to try multiple ckpt configs for trainer test
justinvyu Aug 2, 2023
1992161
Merge branch 'master' of https://github.com/ray-project/ray into air/…
justinvyu Aug 2, 2023
b9eb88f
Fix lint for trial.py
justinvyu Aug 2, 2023
a6115b3
Merge branch 'master' of https://github.com/ray-project/ray into air/…
justinvyu Aug 2, 2023
7cc74d9
Rename _TrackedCheckpoint -> _TrainingResult
justinvyu Aug 2, 2023
6a0e1fb
Merge branch 'master' of https://github.com/ray-project/ray into air/…
justinvyu Aug 2, 2023
4662789
Merge branch 'air/persistence/simplified_ckpt_manager' into air/persi…
justinvyu Aug 2, 2023
8da0477
Fixes after merging latest ckpt manager changes
justinvyu Aug 2, 2023
255b149
Remove prints / convert to logger.debug
justinvyu Aug 2, 2023
0971aca
Don't set training iteration as the default checkpoint_score_attr
justinvyu Aug 2, 2023
6e7a873
Merge branch 'master' of https://github.com/ray-project/ray into air/…
justinvyu Aug 2, 2023
6f6a341
Merge branch 'master' of https://github.com/ray-project/ray into air/…
justinvyu Aug 2, 2023
d9804b0
Fix test to reflect working dir change
justinvyu Aug 2, 2023
318158f
Don't upload a .is_checkpoint marker
justinvyu Aug 2, 2023
a54664c
Add back cwd check
justinvyu Aug 2, 2023
c4263ec
Update the dir trees + better naming for ckpt shards and artifacts
justinvyu Aug 2, 2023
0cd7e47
Merge branch 'master' of https://github.com/ray-project/ray into air/…
justinvyu Aug 2, 2023
3a6eba6
A different fix for the circular dep
justinvyu Aug 2, 2023
b65e9fe
Update checkpoint -> _checkpoint imports
justinvyu Aug 2, 2023
b89bd1c
fix lint
justinvyu Aug 2, 2023
cada06a
Merge branch 'master' of https://github.com/ray-project/ray into air/…
justinvyu Aug 2, 2023
3b784d7
Revert all changes to ckpt manager
justinvyu Aug 2, 2023
49c1ead
Don't set checkpoint user metadata
justinvyu Aug 2, 2023
7177940
Remove remaining print
justinvyu Aug 2, 2023
3ba944a
Bump the test size
justinvyu Aug 3, 2023
5f71608
Merge branch 'master' of https://github.com/ray-project/ray into air/…
justinvyu Aug 3, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion python/ray/train/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -464,7 +464,7 @@ py_test(

py_test(
name = "test_new_persistence",
size = "small",
size = "medium",
srcs = ["tests/test_new_persistence.py"],
tags = ["team:ml", "exclusive"],
deps = [":train_lib", ":conftest"]
Expand Down
43 changes: 43 additions & 0 deletions python/ray/train/_internal/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,6 +440,46 @@ def _set_legacy_checkpoint_uri(self, uri: str):
"""
self.legacy_checkpoint_uri = uri

def new_checkpoint(self, checkpoint):
from ray.train._checkpoint import Checkpoint as NewCheckpoint

if not isinstance(checkpoint, NewCheckpoint):
raise ValueError(
"You must pass a `ray.train.checkpoint.Checkpoint` "
"object to `train.report`. `ray.air.Checkpoint` is deprecated."
)

# Persist the reported checkpoint files to storage.
persisted_checkpoint = self.storage.persist_current_checkpoint(checkpoint)

self.loaded_checkpoint = persisted_checkpoint

metadata = self._auto_fill_checkpoint_metrics({})

# Save the rank of the worker that created this checkpoint.
metadata.update({CHECKPOINT_RANK_KEY: self.world_rank})

result = TrainingResult(
type=TrainingResultType.CHECKPOINT,
data=persisted_checkpoint,
metadata=metadata,
)

# Add result to a thread-safe queue.
self.result_queue.put(result, block=True)

# Acquire lock to stop the training thread until
# checkpoint has been processed.
self.continue_lock.acquire()
justinvyu marked this conversation as resolved.
Show resolved Hide resolved

def new_report(self, metrics: Dict, checkpoint=None) -> None:
if checkpoint:
self.new_checkpoint(checkpoint)

# TODO(justinvyu): Unify checkpoint / report logic to just report a single
# (metrics, Checkpoint) result for the consumer to handle.
self._report_legacy(**metrics)

def report(self, metrics: Dict, checkpoint: Optional[Checkpoint] = None) -> None:
# TODO(xwjiang): tons of optimizations.

Expand All @@ -457,6 +497,9 @@ def report(self, metrics: Dict, checkpoint: Optional[Checkpoint] = None) -> None
"store your Torch objects."
)

if _use_storage_context():
return self.new_report(metrics, checkpoint=checkpoint)

if checkpoint:
self.checkpoint(checkpoint)
self._report_legacy(**metrics)
Expand Down
54 changes: 53 additions & 1 deletion python/ray/train/_internal/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import os
from pathlib import Path
import shutil
from typing import Callable, Dict, List, Optional, Tuple
from typing import Callable, Dict, List, Optional, Tuple, TYPE_CHECKING

try:
import fsspec
Expand All @@ -30,6 +30,9 @@
from ray.tune.syncer import Syncer, SyncConfig, _BackgroundSyncer
from ray.tune.result import _get_defaults_results_dir

if TYPE_CHECKING:
from ray.train._checkpoint import Checkpoint
Comment on lines +33 to +34
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This circular dependency is:

  • ray.train._internal.storage ->ray.train._checkpoint -> ray.train._internal.storage._download_fs_path

We can solve this by moving the filesystem utils to a different file.



logger = logging.getLogger(__file__)

Expand Down Expand Up @@ -472,6 +475,55 @@ def _check_validation_file(self):
"to the configured storage path."
)

def persist_current_checkpoint(self, checkpoint: "Checkpoint") -> "Checkpoint":
"""Persists a given checkpoint to the current checkpoint path on the filesystem.

"Current" is defined by the `current_checkpoint_index` attribute of the
storage context.

This method copies the checkpoint files to the storage location,
drops a marker at the storage path to indicate that the checkpoint
is completely uploaded, then deletes the original checkpoint directory.
For example, the original directory is typically a local temp directory.

Args:
checkpoint: The checkpoint to persist to (fs, checkpoint_fs_path).

Returns:
Checkpoint: A Checkpoint pointing to the persisted checkpoint location.
"""
# TODO(justinvyu): Fix this cyclical import.
from ray.train._checkpoint import Checkpoint

logger.debug(
"Copying checkpoint files to storage path:\n"
"({source_fs}, {source}) -> ({dest_fs}, {destination})".format(
source=checkpoint.path,
destination=self.checkpoint_fs_path,
source_fs=checkpoint.filesystem,
dest_fs=self.storage_filesystem,
)
)
self.storage_filesystem.create_dir(self.checkpoint_fs_path)
_pyarrow_fs_copy_files(
source=checkpoint.path,
destination=self.checkpoint_fs_path,
source_filesystem=checkpoint.filesystem,
destination_filesystem=self.storage_filesystem,
)

# Delete local checkpoint files.
# TODO(justinvyu): What if checkpoint.path == self.checkpoint_fs_path?
# TODO(justinvyu): What if users don't want to delete the local checkpoint?
checkpoint.filesystem.delete_dir(checkpoint.path)

uploaded_checkpoint = Checkpoint(
filesystem=self.storage_filesystem,
path=self.checkpoint_fs_path,
)
logger.debug(f"Checkpoint successfully created at: {uploaded_checkpoint}")
return uploaded_checkpoint

@property
def experiment_path(self) -> str:
"""The path the experiment directory, where the format matches the
Expand Down
17 changes: 15 additions & 2 deletions python/ray/train/data_parallel_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from ray._private.thirdparty.tabulate.tabulate import tabulate

import ray
from ray import tune
from ray import train, tune
from ray.air.checkpoint import Checkpoint
from ray.air._internal.checkpointing import add_preprocessor_to_checkpoint
from ray.air.config import DatasetConfig, RunConfig, ScalingConfig, CheckpointConfig
Expand All @@ -17,6 +17,7 @@
from ray.train._internal.backend_executor import BackendExecutor, TrialInfo
from ray.train._internal.checkpoint import TuneCheckpointManager
from ray.train._internal.data_config import DataConfig, _LegacyDataConfigWrapper
from ray.train._internal.storage import _use_storage_context
from ray.train._internal.utils import construct_train_func
from ray.train.constants import TRAIN_DATASET_KEY, WILDCARD_KEY
from ray.train.trainer import BaseTrainer, GenDataset
Expand Down Expand Up @@ -429,7 +430,19 @@ def _report(self, training_iterator: TrainingIterator) -> None:
for results in training_iterator:
# TODO(ml-team): add ability to report results from multiple workers.
first_worker_results = results[0]
tune.report(**first_worker_results)
if _use_storage_context():
assert (
isinstance(first_worker_results, tuple)
and len(first_worker_results) == 2
)
metrics, checkpoint = first_worker_results
logger.debug(
"Report (metrics, checkpoint) to the Tune session:\n"
f" metrics={metrics}\n checkpoint={checkpoint}"
)
train.report(metrics, checkpoint=checkpoint)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Btw, a way to make this perhaps easier to read is to not use the public report API for this inner reporting, but call a private _report that doesn't do the repeated checkpoint uploading.

However, this is a minor comment.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ericl I think that makes sense -- I think this will be easier to do when we unify the 2 sessions -- just introduce a private method on the session that does the Train -> Tune communication.

Let's merge with what we have for now

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, hold off on that, need to rebase and up the test timeout limit.

else:
tune.report(**first_worker_results)

def training_loop(self) -> None:
scaling_config = self._validate_scaling_config(self.scaling_config)
Expand Down
Loading
Loading