[air] pyarrow.fs persistence (7/n): ray.train.Checkpoint restore: Auto-recovery fault tolerance #38141

Merged

Commits
114 commits
abb1307
Pipe storage context to Trainable (used now for Trainable syncing)
justinvyu Jul 23, 2023
f6ff90a
Don't use the storage context in the trial/trainable
justinvyu Jul 27, 2023
562369f
Disable all trainable syncing in new codepath
justinvyu Jul 27, 2023
95a3d20
Pipe storage context to Train workers (not actually used yet)
justinvyu Jul 23, 2023
484e67f
Fix race condition for setting checkpoint_uri
justinvyu Jul 24, 2023
2148669
Fix cyclical import
justinvyu Jul 27, 2023
8c856b8
Add simple trainer test
justinvyu Jul 27, 2023
78c525f
Add legacy prefix to train session checkpoint uri
justinvyu Jul 27, 2023
e97f471
Add new checkpoint class
justinvyu Jul 27, 2023
64945be
New train session report implementation using new checkpoint
justinvyu Jul 28, 2023
c6480c9
Simplify checkpoint propagation from user code (in worker) -> trainer…
justinvyu Jul 28, 2023
c681ccb
New tune session.report
justinvyu Jul 28, 2023
795bafe
Save direction works with new checkpoint API
justinvyu Jul 28, 2023
8a084bc
Update test with e2e trainer test
justinvyu Jul 28, 2023
725d802
Make callback supporting new checkpoint a todo for now
justinvyu Jul 28, 2023
877acb9
Remove unnecessary comment
justinvyu Jul 28, 2023
ee4ccbd
Merge branch 'master' of https://github.com/ray-project/ray into air/…
justinvyu Jul 28, 2023
88042b3
Separate out the new set checkpoint id from the old set checkpoint uri
justinvyu Jul 28, 2023
a5eeab2
Merge branch 'master' of https://github.com/ray-project/ray into air/…
justinvyu Jul 31, 2023
a6cd9dc
Update id -> index
justinvyu Jul 31, 2023
01f34bb
Address comments on error to raise with old ckpt type
justinvyu Jul 31, 2023
65e7a27
Move checkpoint upload logic to a helper fn of storage ctx
justinvyu Jul 31, 2023
f2a4c36
Drop a checkpoint marker after uploading
justinvyu Jul 31, 2023
49ee126
Add a simplified checkpoint manager
justinvyu Aug 1, 2023
ffa0dd4
Fixes to checkpoint manager
justinvyu Aug 1, 2023
15553f7
Add unit test for simplified checkpoint manager
justinvyu Aug 1, 2023
00cc9d7
Full test coverage
justinvyu Aug 1, 2023
cb5990e
Add a simplified checkpoint manager
justinvyu Aug 1, 2023
2db9aae
Fixes to checkpoint manager
justinvyu Aug 1, 2023
a2067b7
Add unit test for simplified checkpoint manager
justinvyu Aug 1, 2023
f1216f2
Full test coverage
justinvyu Aug 1, 2023
d4243e6
Merge branch 'master' of https://github.com/ray-project/ray into air/…
justinvyu Aug 1, 2023
6699d81
Merge branch 'master' of https://github.com/ray-project/ray into air/…
justinvyu Aug 1, 2023
9b9ff34
Simplify even more
justinvyu Aug 1, 2023
83aecd9
Merge branch 'master' of https://github.com/ray-project/ray into air/…
justinvyu Aug 1, 2023
913af10
Patch fix for circular imports
justinvyu Aug 1, 2023
6b5d34e
Use new checkpoint manager in Tune ckpt book-keeping
justinvyu Aug 2, 2023
24f441a
Update result to return a train.Checkpoint to the user
justinvyu Aug 2, 2023
504ed54
Update e2e test to try multiple ckpt configs for trainer test
justinvyu Aug 2, 2023
1992161
Merge branch 'master' of https://github.com/ray-project/ray into air/…
justinvyu Aug 2, 2023
b9eb88f
Fix lint for trial.py
justinvyu Aug 2, 2023
a6115b3
Merge branch 'master' of https://github.com/ray-project/ray into air/…
justinvyu Aug 2, 2023
7cc74d9
Rename _TrackedCheckpoint -> _TrainingResult
justinvyu Aug 2, 2023
6a0e1fb
Merge branch 'master' of https://github.com/ray-project/ray into air/…
justinvyu Aug 2, 2023
4662789
Merge branch 'air/persistence/simplified_ckpt_manager' into air/persi…
justinvyu Aug 2, 2023
8da0477
Fixes after merging latest ckpt manager changes
justinvyu Aug 2, 2023
255b149
Remove prints / convert to logger.debug
justinvyu Aug 2, 2023
0971aca
Don't set training iteration as the default checkpoint_score_attr
justinvyu Aug 2, 2023
6e7a873
Merge branch 'master' of https://github.com/ray-project/ray into air/…
justinvyu Aug 2, 2023
6f6a341
Merge branch 'master' of https://github.com/ray-project/ray into air/…
justinvyu Aug 2, 2023
d9804b0
Fix test to reflect working dir change
justinvyu Aug 2, 2023
318158f
Don't upload a .is_checkpoint marker
justinvyu Aug 2, 2023
a54664c
Add back cwd check
justinvyu Aug 2, 2023
c4263ec
Update the dir trees + better naming for ckpt shards and artifacts
justinvyu Aug 2, 2023
0cd7e47
Merge branch 'master' of https://github.com/ray-project/ray into air/…
justinvyu Aug 2, 2023
3a6eba6
A different fix for the circular dep
justinvyu Aug 2, 2023
b65e9fe
Update checkpoint -> _checkpoint imports
justinvyu Aug 2, 2023
b89bd1c
fix lint
justinvyu Aug 2, 2023
cada06a
Merge branch 'master' of https://github.com/ray-project/ray into air/…
justinvyu Aug 2, 2023
3b784d7
Revert all changes to ckpt manager
justinvyu Aug 2, 2023
49c1ead
Don't set checkpoint user metadata
justinvyu Aug 2, 2023
7177940
Remove remaining print
justinvyu Aug 2, 2023
ae8a9ec
Add trial_path property to storage ctx
justinvyu Aug 3, 2023
c1c8441
Use storage context for all experiment/trial path properties
justinvyu Aug 3, 2023
5d2ca07
Don't skip trainer test cases for custom_fs
justinvyu Aug 3, 2023
1fcfb3f
Split some utilities into helper methods + test for ResultGrid paths
justinvyu Aug 3, 2023
5e2a933
Merge branch 'master' of https://github.com/ray-project/ray into air/…
justinvyu Aug 3, 2023
61fdadf
Prepend legacy to old path attributes in trial
justinvyu Aug 3, 2023
d38cd87
Remove todo
justinvyu Aug 3, 2023
3ba944a
Bump the test size
justinvyu Aug 3, 2023
5f71608
Merge branch 'master' of https://github.com/ray-project/ray into air/…
justinvyu Aug 3, 2023
52d6c14
Merge branch 'air/persistence/new_checkpoint' into air/persistence/fi…
justinvyu Aug 3, 2023
9c16120
Clean up experiment path handling
justinvyu Aug 3, 2023
76468d9
Fix for base trainer
justinvyu Aug 3, 2023
b17c17e
Fix for base trainer pt 2
justinvyu Aug 3, 2023
30e3328
Add in missing legacy property
justinvyu Aug 3, 2023
f25ad39
Prepend legacy to old path attributes in experiment
justinvyu Aug 3, 2023
e1846ec
Merge branch 'master' of https://github.com/ray-project/ray into air/…
justinvyu Aug 3, 2023
d11ede8
Merge branch 'master' of https://github.com/ray-project/ray into air/…
justinvyu Aug 3, 2023
c99d30c
too much space
justinvyu Aug 3, 2023
9d11d2d
remove unused var
justinvyu Aug 3, 2023
950b991
Fix lint
justinvyu Aug 3, 2023
ed86255
restore mostly works
justinvyu Aug 4, 2023
de4b924
hacky way of getting checkpoint folders to increment correctly
justinvyu Aug 4, 2023
e060476
Fix for xgboost trainer
justinvyu Aug 4, 2023
bd5c846
Fix as_directory / download file lock race condition
justinvyu Aug 4, 2023
e51fb17
Update test with auto-recovery fault tolerance
justinvyu Aug 4, 2023
eaa26c5
Merge branch 'master' of https://github.com/ray-project/ray into air/…
justinvyu Aug 4, 2023
0c3c5c8
compute storage_prefix
justinvyu Aug 4, 2023
217af77
Remove '_path' properties from storage
justinvyu Aug 4, 2023
8e2330c
Move exp dir name helper to storage ctx
justinvyu Aug 4, 2023
7502cca
Merge branch 'master' of https://github.com/ray-project/ray into air/…
justinvyu Aug 4, 2023
f3f22fd
Fix bugs causing broken CI
justinvyu Aug 4, 2023
c6c3dfe
Fix renamed attribute in mock test class
justinvyu Aug 4, 2023
4e56bd0
Merge branch 'master' of https://github.com/ray-project/ray into air/…
justinvyu Aug 4, 2023
314e8bd
Merge branch 'air/persistence/fix_custom_fs_path_expansion' into air/…
justinvyu Aug 4, 2023
36464af
fix storage attr setting to only happen if ff enabled
justinvyu Aug 5, 2023
6e73f6e
cleanup on errors in as_directory
justinvyu Aug 5, 2023
bcbcec9
Merge branch 'master' of https://github.com/ray-project/ray into air/…
justinvyu Aug 5, 2023
d7497e1
fix merge conflict remainder
justinvyu Aug 5, 2023
aeb89ba
Recover trainable metadata from last_result rather than .tune_metadata
justinvyu Aug 7, 2023
9556371
Fix restore info log
justinvyu Aug 7, 2023
e897eaa
Keep current checkpoint index synchronized on the driver
justinvyu Aug 7, 2023
3eef417
Remove checkpoint dirname parsing
justinvyu Aug 7, 2023
0e52384
Merge branch 'master' of https://github.com/ray-project/ray into air/…
justinvyu Aug 7, 2023
89631ab
Update todo comment
justinvyu Aug 7, 2023
d4e20f2
Fix lint
justinvyu Aug 7, 2023
72aa1fb
Rename to starting_checkpoint
justinvyu Aug 7, 2023
fb056f8
Merge branch 'master' of https://github.com/ray-project/ray into air/…
justinvyu Aug 7, 2023
ca0df9f
Fix lint
justinvyu Aug 7, 2023
3636a21
fix typo
justinvyu Aug 8, 2023
8743a99
Merge branch 'master' of https://github.com/ray-project/ray into air/…
justinvyu Aug 8, 2023
a0c5a26
Fix repr
justinvyu Aug 8, 2023
0a40c47
Merge branch 'master' of https://github.com/ray-project/ray into air/…
justinvyu Aug 8, 2023
53 changes: 28 additions & 25 deletions python/ray/train/_checkpoint.py
@@ -211,35 +211,38 @@ def as_directory(self) -> Iterator[str]:
         if isinstance(self.filesystem, pyarrow.fs.LocalFileSystem):
             yield self.path
         else:
-            temp_dir = self.to_directory()
-            del_lock_path = _get_del_lock_path(temp_dir)
+            del_lock_path = _get_del_lock_path(self._get_temporary_checkpoint_dir())
             open(del_lock_path, "a").close()

-            yield temp_dir
-
-            # Cleanup
             try:
-                os.remove(del_lock_path)
-            except Exception:
-                logger.warning(
-                    f"Could not remove {del_lock_path} deletion file lock. "
-                    f"Traceback:\n{traceback.format_exc()}"
-                )
-
-            # In the edge case (process crash before del lock file is removed),
-            # we do not remove the directory at all.
-            # Since it's in /tmp, this is not that big of a deal.
-            # check if any lock files are remaining
-            remaining_locks = _list_existing_del_locks(temp_dir)
-            if not remaining_locks:
+                temp_dir = self.to_directory()
+                yield temp_dir
+            finally:
+                # Always cleanup the del lock after we're done with the directory.
+                # This avoids leaving a lock file behind in the case of an exception
+                # in the user code.
                 try:
-                    # Timeout 0 means there will be only one attempt to acquire
-                    # the file lock. If it cannot be acquired, a TimeoutError
-                    # will be thrown.
-                    with TempFileLock(f"{temp_dir}.lock", timeout=0):
-                        shutil.rmtree(temp_dir, ignore_errors=True)
-                except TimeoutError:
-                    pass
+                    os.remove(del_lock_path)
+                except Exception:
+                    logger.warning(
+                        f"Could not remove {del_lock_path} deletion file lock. "
+                        f"Traceback:\n{traceback.format_exc()}"
+                    )
+
+                # In the edge case (process crash before del lock file is removed),
+                # we do not remove the directory at all.
+                # Since it's in /tmp, this is not that big of a deal.
+                # check if any lock files are remaining
+                remaining_locks = _list_existing_del_locks(temp_dir)
+                if not remaining_locks:
+                    try:
+                        # Timeout 0 means there will be only one attempt to acquire
+                        # the file lock. If it cannot be acquired, a TimeoutError
+                        # will be thrown.
+                        with TempFileLock(temp_dir, timeout=0):
+                            shutil.rmtree(temp_dir, ignore_errors=True)
+                    except TimeoutError:
+                        pass

     def _get_temporary_checkpoint_dir(self) -> str:
         """Return the name for the temporary checkpoint dir that this checkpoint
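For context (not part of the diff): the restructuring above moves the temp-directory download and the `yield` inside a `try`, so the per-process deletion lock file is always removed in the `finally` block, even when user code raises inside `with checkpoint.as_directory():`. A standalone sketch of the underlying lock-file pattern (illustration only; the helper names below are hypothetical, not Ray's internal `_get_del_lock_path` / `_list_existing_del_locks`):

import glob
import os
import shutil
import tempfile


def _del_lock_path(temp_dir: str, pid: int) -> str:
    # One marker file per process, next to the shared temp directory.
    return f"{temp_dir}.del_lock_{pid}"


def use_shared_temp_dir(temp_dir: str) -> None:
    del_lock_path = _del_lock_path(temp_dir, os.getpid())
    open(del_lock_path, "a").close()
    try:
        pass  # ... read files from temp_dir; user code may raise here ...
    finally:
        # Always drop our own lock, even if the body raised.
        os.remove(del_lock_path)
        # Only delete the directory once no process holds a lock anymore.
        if not glob.glob(f"{temp_dir}.del_lock_*"):
            shutil.rmtree(temp_dir, ignore_errors=True)


shared_dir = tempfile.mkdtemp()
use_shared_temp_dir(shared_dir)
assert not os.path.exists(shared_dir)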
22 changes: 11 additions & 11 deletions python/ray/train/_internal/storage.py
@@ -400,7 +400,7 @@ def __init__(
         experiment_dir_name: str,
         storage_filesystem: Optional[pyarrow.fs.FileSystem] = None,
         trial_dir_name: Optional[str] = None,
-        current_checkpoint_index: Optional[int] = None,
+        current_checkpoint_index: int = 0,
     ):
         storage_path_provided = storage_path is not None

@@ -591,18 +591,13 @@ def trial_fs_path(self) -> str:

     @property
     def checkpoint_fs_path(self) -> str:
-        """The trial directory path on the `storage_filesystem`.
+        """The current checkpoint directory path on the `storage_filesystem`.

-        Raises a ValueError if `current_checkpoint_index` is not set beforehand.
+        "Current" refers to the checkpoint that is currently being created/persisted.
+        The user of this class is responsible for setting the `current_checkpoint_index`
+        (e.g., incrementing when needed).
         """
-        from ray.tune.trainable.util import TrainableUtil
-
-        if self.current_checkpoint_index is None:
-            raise RuntimeError(
-                "Should not access `checkpoint_fs_path` without setting "
-                "`current_checkpoint_index`"
-            )
-        checkpoint_dir_name = TrainableUtil._make_checkpoint_dir_name(
+        checkpoint_dir_name = StorageContext._make_checkpoint_dir_name(
             self.current_checkpoint_index
         )
         return os.path.join(self.trial_fs_path, checkpoint_dir_name)
@@ -620,6 +615,11 @@ def get_experiment_dir_name(run_obj: Union[str, Callable, Type]) -> str:
         dir_name = "{}_{}".format(run_identifier, date_str())
         return dir_name

+    @staticmethod
+    def _make_checkpoint_dir_name(index: int):
+        """Get the name of the checkpoint directory, given an index."""
+        return f"checkpoint_{index:06d}"
+

 _storage_context: Optional[StorageContext] = None

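For context (not part of the diff): `checkpoint_fs_path` is now computed purely from `trial_fs_path` plus a zero-padded `current_checkpoint_index` that defaults to 0, and the caller is responsible for incrementing that index. A minimal sketch of the composition (hypothetical paths, not the real `StorageContext` setup):

import os


def _make_checkpoint_dir_name(index: int) -> str:
    # Mirrors the static helper added above: zero-padded so directories sort lexicographically.
    return f"checkpoint_{index:06d}"


trial_fs_path = "bucket/my_experiment/my_trial"  # hypothetical trial path
current_checkpoint_index = 2

checkpoint_fs_path = os.path.join(trial_fs_path, _make_checkpoint_dir_name(current_checkpoint_index))
print(checkpoint_fs_path)  # bucket/my_experiment/my_trial/checkpoint_000002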
37 changes: 21 additions & 16 deletions python/ray/train/base_trainer.py
@@ -22,6 +22,7 @@
 from ray.air.config import RunConfig, ScalingConfig
 from ray.air.result import Result
 from ray.train._internal import session
+from ray.train._internal.storage import _use_storage_context
 from ray.train.constants import TRAIN_DATASET_KEY
 from ray.util import PublicAPI
 from ray.util.annotations import DeveloperAPI
@@ -191,7 +192,7 @@ def __init__(
         self.run_config = run_config if run_config is not None else RunConfig()
         self.datasets = datasets if datasets is not None else {}
         self.preprocessor = preprocessor
-        self.resume_from_checkpoint = resume_from_checkpoint
+        self.starting_checkpoint = resume_from_checkpoint

         # This path should only be set through restore
         self._restore_path = None
@@ -377,7 +378,7 @@ def __repr__(self):
             "run_config": RunConfig(),
             "datasets": {},
             "preprocessor": None,
-            "resume_from_checkpoint": None,
+            "starting_checkpoint": None,
         }

         non_default_arguments = []
@@ -452,13 +453,13 @@ def _validate_attributes(self):
                 f"found {type(self.preprocessor)} with value `{self.preprocessor}`."
             )

-        if self.resume_from_checkpoint is not None and not isinstance(
-            self.resume_from_checkpoint, ray.air.Checkpoint
+        if self.starting_checkpoint is not None and not isinstance(
+            self.starting_checkpoint, ray.air.Checkpoint
         ):
             raise ValueError(
                 f"`resume_from_checkpoint` should be an instance of "
-                f"`ray.train.Checkpoint`, found {type(self.resume_from_checkpoint)} "
-                f"with value `{self.resume_from_checkpoint}`."
+                f"`ray.train.Checkpoint`, found {type(self.starting_checkpoint)} "
+                f"with value `{self.starting_checkpoint}`."
             )

     @classmethod
@@ -700,18 +701,22 @@ def train_func(config):
             # Instantiate new Trainer in Trainable.
             trainer = trainer_cls(**config)

-            # Get the checkpoint from the train context, and use it to initialize
-            # the restored trainer.
-            # This handles both worker-level and cluster-level restoration
-            # of the Train experiment.
+            # Get the checkpoint from Tune and pass it to workers later on.
             checkpoint = session.get_checkpoint()
             if checkpoint:
-                trainer.resume_from_checkpoint = checkpoint
-                # Always load the preprocessor from an available checkpoint
-                # Unless we are restoring the experiment and have explicitly
-                # passed in a new preprocessor
-                if not (restored and trainer.preprocessor):
-                    trainer.preprocessor = checkpoint.get_preprocessor()
+                # Set `starting_checkpoint` for auto-recovery fault-tolerance
+                # as well as manual restoration.
+                trainer.starting_checkpoint = checkpoint
+
+                # TODO(justinvyu): Remove this when Preprocessor is removed from Trainer
+                if not _use_storage_context():
+                    # Always load the preprocessor from an available checkpoint
+                    # Unless we are restoring the experiment and have explicitly
+                    # passed in a new preprocessor
+                    if not (restored and trainer.preprocessor):
+                        trainer.preprocessor = checkpoint.get_preprocessor()
+            # Else: Train will restore from the user-provided
+            # `resume_from_checkpoint` == `starting_checkpoint`.

             trainer.setup()
             trainer.preprocess_datasets()
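For context (not part of the diff): the rename keeps `resume_from_checkpoint` as the public constructor argument while storing it internally as `starting_checkpoint`; on auto-recovery, the checkpoint handed back by Tune overwrites it. A rough user-facing sketch, assuming the `ray.train` exports used in the tests later in this diff (`ScalingConfig`, `RunConfig`), that `ray.train.Checkpoint` is the new checkpoint class this PR series introduces, and a hypothetical checkpoint path:

from ray import train
from ray.train import Checkpoint
from ray.train.data_parallel_trainer import DataParallelTrainer


def train_fn(config):
    # Placeholder training loop; a real one would load `train.get_checkpoint()`
    # and periodically call `train.report(..., checkpoint=...)`.
    train.report({"iter": 0})


trainer = DataParallelTrainer(
    train_fn,
    scaling_config=train.ScalingConfig(num_workers=2),
    run_config=train.RunConfig(name="my_experiment"),
    # The public argument name is unchanged; internally it becomes
    # `trainer.starting_checkpoint`.
    resume_from_checkpoint=Checkpoint.from_directory("/tmp/previous_checkpoint"),
)
# On auto-recovery (or manual restoration), Tune supplies its latest tracked
# checkpoint via `session.get_checkpoint()`, which overwrites
# `starting_checkpoint` before the training loop runs (see the hunk above).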
2 changes: 1 addition & 1 deletion python/ray/train/data_parallel_trainer.py
@@ -527,7 +527,7 @@ def clear_lazy_checkpoint_marker():
             datasets=self.datasets,
             data_config=self._data_config,
             checkpoint_manager=checkpoint_manager,
-            checkpoint=self.resume_from_checkpoint,
+            checkpoint=self.starting_checkpoint,
             checkpoint_strategy=checkpoint_strategy,
             storage_path=self.run_config.storage_path,
         )
4 changes: 2 additions & 2 deletions python/ray/train/gbdt_trainer.py
@@ -280,8 +280,8 @@ def training_loop(self) -> None:
         evals_result = {}

         init_model = None
-        if self.resume_from_checkpoint:
-            init_model, _ = self._load_checkpoint(self.resume_from_checkpoint)
+        if self.starting_checkpoint:
+            init_model, _ = self._load_checkpoint(self.starting_checkpoint)

         config.setdefault("verbose_eval", False)
         config.setdefault("callbacks", [])
2 changes: 1 addition & 1 deletion python/ray/train/tests/test_base_trainer.py
@@ -398,7 +398,7 @@ def test_large_params(ray_start_4_cpus):
     array_size = int(1e8)

     def training_loop(self):
-        checkpoint = self.resume_from_checkpoint.to_dict()["ckpt"]
+        checkpoint = self.starting_checkpoint.to_dict()["ckpt"]
         assert len(checkpoint) == array_size

     checkpoint = Checkpoint.from_dict({"ckpt": np.zeros(shape=array_size)})
15 changes: 15 additions & 0 deletions python/ray/train/tests/test_checkpoint.py
@@ -156,6 +156,21 @@ def test_multiprocess_as_directory(checkpoint: Checkpoint, monkeypatch):
     assert not Path(checkpoint_dir_1).exists()


+def test_as_directory_lock_cleanup(checkpoint: Checkpoint):
+    """Errors when accessing a checkpoint with `as_directory`
+    shouldn't leave behind lock files.
+    """
+    with pytest.raises(RuntimeError):
+        with checkpoint.as_directory() as checkpoint_dir:
+            raise RuntimeError
+
+    assert not _list_existing_del_locks(checkpoint_dir)
+
+    is_local_checkpoint = isinstance(checkpoint.filesystem, pyarrow.fs.LocalFileSystem)
+    if not is_local_checkpoint:
+        assert not Path(checkpoint_dir).exists()
+
+
 def test_metadata(checkpoint: Checkpoint):
     assert checkpoint.get_metadata() == {}

26 changes: 19 additions & 7 deletions python/ray/train/tests/test_new_persistence.py
@@ -143,24 +143,28 @@ def train_fn(config):
     for i in range(start, config.get("num_iterations", 5)):
         time.sleep(0.25)

-        checkpoint_file_name = "checkpoint.pkl"
+        temp_dir = tempfile.mkdtemp()
+        with open(os.path.join(temp_dir, "checkpoint.pkl"), "wb") as f:
+            pickle.dump({"iter": i}, f)
+
         artifact_file_name = f"artifact-iter={i}.txt"
         if in_trainer:
             rank = train.get_context().get_world_rank()
-            checkpoint_file_name = f"checkpoint_shard-rank={rank}.pkl"
             artifact_file_name = f"artifact-rank={rank}-iter={i}.txt"

+            checkpoint_file_name = f"checkpoint_shard-rank={rank}.pkl"
+            with open(os.path.join(temp_dir, checkpoint_file_name), "wb") as f:
+                pickle.dump({"iter": i}, f)
+
         with open(artifact_file_name, "w") as f:
             f.write(f"{i}")

-        temp_dir = tempfile.mkdtemp()
-        with open(os.path.join(temp_dir, checkpoint_file_name), "wb") as f:
-            pickle.dump({"iter": i}, f)
-
         train.report(
             {"iter": i, _SCORE_KEY: i},
             checkpoint=NewCheckpoint.from_directory(temp_dir),
         )
+        if i in config.get("fail_iters", []):
+            raise RuntimeError(f"Failing on iter={i}!!")


 @pytest.mark.parametrize("storage_path_type", [None, "nfs", "cloud", "custom_fs"])
@@ -287,6 +291,7 @@ def test_trainer(
     ├── progress.csv
     ├── result.json
     ├── checkpoint_000000
+    │   ├── checkpoint.pkl                <- Shared checkpoint file
     │   ├── checkpoint_shard-rank=0.pkl   <- Worker checkpoint shards
     │   └── checkpoint_shard-rank=1.pkl
     ├── ...
@@ -309,14 +314,19 @@
     NUM_WORKERS = 2
     trainer = DataParallelTrainer(
         train_fn,
-        train_loop_config={"in_trainer": True, "num_iterations": NUM_ITERATIONS},
+        train_loop_config={
+            "in_trainer": True,
+            "num_iterations": NUM_ITERATIONS,
+            "fail_iters": [2, 4],
+        },
         scaling_config=train.ScalingConfig(num_workers=2),
         run_config=train.RunConfig(
             storage_path=storage_path,
             storage_filesystem=storage_filesystem,
             name=exp_name,
             verbose=0,
             checkpoint_config=checkpoint_config,
+            failure_config=train.FailureConfig(max_failures=2),
         ),
     )
     result = trainer.fit()
@@ -352,6 +362,8 @@ def test_trainer(

     assert len(list(trial_dir.glob("checkpoint_*"))) == expected_num_checkpoints
     for checkpoint_dir in trial_dir.glob("checkpoint_*"):
+        # 1 shared checkpoint.pkl file, written by all workers.
+        assert len(list(checkpoint_dir.glob("checkpoint.pkl"))) == 1
         # 1 checkpoint shard per worker.
         assert (
             len(list(checkpoint_dir.glob("checkpoint_shard-*.pkl"))) == NUM_WORKERS
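For context (not part of the diff): with `fail_iters` injecting worker failures and `FailureConfig(max_failures=2)` enabled, the run is expected to auto-recover and resume from the last reported checkpoint. The restore half of the training function (where `start` comes from) is not shown in this hunk; a minimal sketch of that pattern, assuming a `train.get_checkpoint()`-style accessor:

import os
import pickle

from ray import train


def _restored_start_iter() -> int:
    # Fresh run: no checkpoint. Auto-recovery: Train hands back the last
    # checkpoint reported via `train.report(..., checkpoint=...)`.
    checkpoint = train.get_checkpoint()
    if checkpoint is None:
        return 0
    with checkpoint.as_directory() as checkpoint_dir:
        with open(os.path.join(checkpoint_dir, "checkpoint.pkl"), "rb") as f:
            return pickle.load(f)["iter"] + 1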
23 changes: 9 additions & 14 deletions python/ray/train/trainer.py
@@ -71,9 +71,9 @@ def __init__(
         # TrainingResult event. There's no need to do these one at a time.
         self._checkpoint_to_report = None

-        # TODO(justinvyu): Is this the best way to do this? Need to save this
-        # as part of checkpoint metadata and load it back on restore.
-        self._latest_checkpoint_index = 0
+        self._storage = None
+        if _use_storage_context():
+            self._storage = get_storage_context()

         self._start_training(
             train_func=train_func,
@@ -103,7 +103,10 @@ def _start_training(
             run_dir=run_dir,
             latest_checkpoint_id=latest_checkpoint_id,
         )
-        checkpoint = self._checkpoint_manager._load_checkpoint(checkpoint)
+
+        if not _use_storage_context():
+            checkpoint = self._checkpoint_manager._load_checkpoint(checkpoint)
+
         self._run_with_error_handling(
             lambda: self._backend_executor.start_training(
                 train_func=train_func,
@@ -119,18 +122,10 @@ def _send_next_checkpoint_path_to_workers(self):
         # NOTE: Always upload to storage from workers in the new persistence path
         # (no need to check for the `checkpoint_upload_from_workers` flag)
         if _use_storage_context():
-            storage = get_storage_context()
-
-            # NOTE: Idea: this checkpoint dir name should be customizable
-            # and created on the fly when the checkpoint is reported with metrics.
-            # Ex: lambda metrics: f"checkpoint_iter={metrics['training_iteration']}"
-            storage.current_checkpoint_index = self._latest_checkpoint_index
-
             self._backend_executor._set_checkpoint_index(
-                storage.current_checkpoint_index
+                self._storage.current_checkpoint_index
             )
-
-            self._latest_checkpoint_index += 1
+            self._storage.current_checkpoint_index += 1

         elif self._checkpoint_strategy._checkpoint_upload_from_workers:
             self._backend_executor._set_legacy_checkpoint_uri(
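For context (not part of the diff): the driver now owns the checkpoint index through the shared storage context instead of a private `_latest_checkpoint_index` counter; it tells workers which index to upload into, then increments. A standalone sketch of that bookkeeping (illustration only; `_FakeStorage` stands in for the real `StorageContext` and backend executor):

class _FakeStorage:
    def __init__(self, current_checkpoint_index: int = 0):
        self.current_checkpoint_index = current_checkpoint_index


def send_next_checkpoint_dir(storage: _FakeStorage, sent_dirs: list) -> None:
    # Tell the workers which checkpoint_XXXXXX directory to upload into ...
    sent_dirs.append(f"checkpoint_{storage.current_checkpoint_index:06d}")
    # ... then bump the counter so the next report lands in a fresh directory.
    storage.current_checkpoint_index += 1


storage = _FakeStorage()
sent = []
for _ in range(3):
    send_next_checkpoint_dir(storage, sent)
assert sent == ["checkpoint_000000", "checkpoint_000001", "checkpoint_000002"]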
26 changes: 22 additions & 4 deletions python/ray/tune/execution/tune_controller.py
@@ -1956,11 +1956,29 @@ def _checkpoint_trial_if_needed(self, trial, force=False):
     ###
     # RESTORE
     def _schedule_trial_restore(self, trial: Trial) -> bool:
-        checkpoint = trial.checkpoint
-
         if _use_storage_context():
-            # TODO(justinvyu): Skipping restoration altogether for now.
-            return False
+            checkpoint_result = trial.checkpoint_manager.latest_checkpoint_result
+
+            if not checkpoint_result:
+                logger.debug(f"Not restoring trial {trial}: No checkpoint found.")
+                return False
+
+            # TODO(justinvyu): Is this really needed?
+            trial.restoring_from = checkpoint_result
+
+            method_name = "restore"
+            args = (checkpoint_result,)
+            self._schedule_trial_task(
+                trial=trial,
+                method_name=method_name,
+                args=args,
+                kwargs={},
+                on_result=self._on_restoring_result,
+                on_error=self._trial_task_failure,
+            )
+            return True
+
+        checkpoint = trial.checkpoint

         if checkpoint.dir_or_data is None:
             logger.debug(f"Not restoring trial {trial}: No checkpoint found.")