diff --git a/python/ray/train/_checkpoint.py b/python/ray/train/_checkpoint.py
index 6505c70f877e..d7214fb7fb77 100644
--- a/python/ray/train/_checkpoint.py
+++ b/python/ray/train/_checkpoint.py
@@ -211,35 +211,38 @@ def as_directory(self) -> Iterator[str]:
         if isinstance(self.filesystem, pyarrow.fs.LocalFileSystem):
             yield self.path
         else:
-            temp_dir = self.to_directory()
-            del_lock_path = _get_del_lock_path(temp_dir)
+            del_lock_path = _get_del_lock_path(self._get_temporary_checkpoint_dir())
             open(del_lock_path, "a").close()
 
-            yield temp_dir
-
-            # Cleanup
             try:
-                os.remove(del_lock_path)
-            except Exception:
-                logger.warning(
-                    f"Could not remove {del_lock_path} deletion file lock. "
-                    f"Traceback:\n{traceback.format_exc()}"
-                )
-
-            # In the edge case (process crash before del lock file is removed),
-            # we do not remove the directory at all.
-            # Since it's in /tmp, this is not that big of a deal.
-            # check if any lock files are remaining
-            remaining_locks = _list_existing_del_locks(temp_dir)
-            if not remaining_locks:
+                temp_dir = self.to_directory()
+                yield temp_dir
+            finally:
+                # Always cleanup the del lock after we're done with the directory.
+                # This avoids leaving a lock file behind in the case of an exception
+                # in the user code.
                 try:
-                    # Timeout 0 means there will be only one attempt to acquire
-                    # the file lock. If it cannot be acquired, a TimeoutError
-                    # will be thrown.
-                    with TempFileLock(f"{temp_dir}.lock", timeout=0):
-                        shutil.rmtree(temp_dir, ignore_errors=True)
-                except TimeoutError:
-                    pass
+                    os.remove(del_lock_path)
+                except Exception:
+                    logger.warning(
+                        f"Could not remove {del_lock_path} deletion file lock. "
+                        f"Traceback:\n{traceback.format_exc()}"
+                    )
+
+                # In the edge case (process crash before del lock file is removed),
+                # we do not remove the directory at all.
+                # Since it's in /tmp, this is not that big of a deal.
+                # check if any lock files are remaining
+                remaining_locks = _list_existing_del_locks(temp_dir)
+                if not remaining_locks:
+                    try:
+                        # Timeout 0 means there will be only one attempt to acquire
+                        # the file lock. If it cannot be acquired, a TimeoutError
+                        # will be thrown.
+                        with TempFileLock(temp_dir, timeout=0):
+                            shutil.rmtree(temp_dir, ignore_errors=True)
+                    except TimeoutError:
+                        pass
 
     def _get_temporary_checkpoint_dir(self) -> str:
         """Return the name for the temporary checkpoint dir that this checkpoint
diff --git a/python/ray/train/_internal/storage.py b/python/ray/train/_internal/storage.py
index eddb142bb504..7e92ecd01af0 100644
--- a/python/ray/train/_internal/storage.py
+++ b/python/ray/train/_internal/storage.py
@@ -400,7 +400,7 @@ def __init__(
         experiment_dir_name: str,
         storage_filesystem: Optional[pyarrow.fs.FileSystem] = None,
         trial_dir_name: Optional[str] = None,
-        current_checkpoint_index: Optional[int] = None,
+        current_checkpoint_index: int = 0,
     ):
         storage_path_provided = storage_path is not None
 
@@ -591,18 +591,13 @@ def trial_fs_path(self) -> str:
 
     @property
     def checkpoint_fs_path(self) -> str:
-        """The trial directory path on the `storage_filesystem`.
+        """The current checkpoint directory path on the `storage_filesystem`.
 
-        Raises a ValueError if `current_checkpoint_index` is not set beforehand.
+        "Current" refers to the checkpoint that is currently being created/persisted.
+        The user of this class is responsible for setting the `current_checkpoint_index`
+        (e.g., incrementing when needed).
""" - from ray.tune.trainable.util import TrainableUtil - - if self.current_checkpoint_index is None: - raise RuntimeError( - "Should not access `checkpoint_fs_path` without setting " - "`current_checkpoint_index`" - ) - checkpoint_dir_name = TrainableUtil._make_checkpoint_dir_name( + checkpoint_dir_name = StorageContext._make_checkpoint_dir_name( self.current_checkpoint_index ) return os.path.join(self.trial_fs_path, checkpoint_dir_name) @@ -620,6 +615,11 @@ def get_experiment_dir_name(run_obj: Union[str, Callable, Type]) -> str: dir_name = "{}_{}".format(run_identifier, date_str()) return dir_name + @staticmethod + def _make_checkpoint_dir_name(index: int): + """Get the name of the checkpoint directory, given an index.""" + return f"checkpoint_{index:06d}" + _storage_context: Optional[StorageContext] = None diff --git a/python/ray/train/base_trainer.py b/python/ray/train/base_trainer.py index 1c8cd9514465..2b9e247749b2 100644 --- a/python/ray/train/base_trainer.py +++ b/python/ray/train/base_trainer.py @@ -22,6 +22,7 @@ from ray.air.config import RunConfig, ScalingConfig from ray.air.result import Result from ray.train._internal import session +from ray.train._internal.storage import _use_storage_context from ray.train.constants import TRAIN_DATASET_KEY from ray.util import PublicAPI from ray.util.annotations import DeveloperAPI @@ -191,7 +192,7 @@ def __init__( self.run_config = run_config if run_config is not None else RunConfig() self.datasets = datasets if datasets is not None else {} self.preprocessor = preprocessor - self.resume_from_checkpoint = resume_from_checkpoint + self.starting_checkpoint = resume_from_checkpoint # This path should only be set through restore self._restore_path = None @@ -377,7 +378,7 @@ def __repr__(self): "run_config": RunConfig(), "datasets": {}, "preprocessor": None, - "resume_from_checkpoint": None, + "starting_checkpoint": None, } non_default_arguments = [] @@ -452,13 +453,13 @@ def _validate_attributes(self): f"found {type(self.preprocessor)} with value `{self.preprocessor}`." ) - if self.resume_from_checkpoint is not None and not isinstance( - self.resume_from_checkpoint, ray.air.Checkpoint + if self.starting_checkpoint is not None and not isinstance( + self.starting_checkpoint, ray.air.Checkpoint ): raise ValueError( f"`resume_from_checkpoint` should be an instance of " - f"`ray.train.Checkpoint`, found {type(self.resume_from_checkpoint)} " - f"with value `{self.resume_from_checkpoint}`." + f"`ray.train.Checkpoint`, found {type(self.starting_checkpoint)} " + f"with value `{self.starting_checkpoint}`." ) @classmethod @@ -700,18 +701,22 @@ def train_func(config): # Instantiate new Trainer in Trainable. trainer = trainer_cls(**config) - # Get the checkpoint from the train context, and use it to initialize - # the restored trainer. - # This handles both worker-level and cluster-level restoration - # of the Train experiment. + # Get the checkpoint from Tune and pass it to workers later on. checkpoint = session.get_checkpoint() if checkpoint: - trainer.resume_from_checkpoint = checkpoint - # Always load the preprocessor from an available checkpoint - # Unless we are restoring the experiment and have explicitly - # passed in a new preprocessor - if not (restored and trainer.preprocessor): - trainer.preprocessor = checkpoint.get_preprocessor() + # Set `starting_checkpoint` for auto-recovery fault-tolerance + # as well as manual restoration. 
+                trainer.starting_checkpoint = checkpoint
+
+                # TODO(justinvyu): Remove this when Preprocessor is removed from Trainer
+                if not _use_storage_context():
+                    # Always load the preprocessor from an available checkpoint
+                    # Unless we are restoring the experiment and have explicitly
+                    # passed in a new preprocessor
+                    if not (restored and trainer.preprocessor):
+                        trainer.preprocessor = checkpoint.get_preprocessor()
+                # Else: Train will restore from the user-provided
+                # `resume_from_checkpoint` == `starting_checkpoint`.
 
             trainer.setup()
             trainer.preprocess_datasets()
diff --git a/python/ray/train/data_parallel_trainer.py b/python/ray/train/data_parallel_trainer.py
index 80077fac2309..b17b02c5054e 100644
--- a/python/ray/train/data_parallel_trainer.py
+++ b/python/ray/train/data_parallel_trainer.py
@@ -527,7 +527,7 @@ def clear_lazy_checkpoint_marker():
             datasets=self.datasets,
             data_config=self._data_config,
             checkpoint_manager=checkpoint_manager,
-            checkpoint=self.resume_from_checkpoint,
+            checkpoint=self.starting_checkpoint,
             checkpoint_strategy=checkpoint_strategy,
             storage_path=self.run_config.storage_path,
         )
diff --git a/python/ray/train/gbdt_trainer.py b/python/ray/train/gbdt_trainer.py
index 997be4230008..df9564fed25f 100644
--- a/python/ray/train/gbdt_trainer.py
+++ b/python/ray/train/gbdt_trainer.py
@@ -280,8 +280,8 @@ def training_loop(self) -> None:
         evals_result = {}
 
         init_model = None
-        if self.resume_from_checkpoint:
-            init_model, _ = self._load_checkpoint(self.resume_from_checkpoint)
+        if self.starting_checkpoint:
+            init_model, _ = self._load_checkpoint(self.starting_checkpoint)
 
         config.setdefault("verbose_eval", False)
         config.setdefault("callbacks", [])
diff --git a/python/ray/train/tests/test_base_trainer.py b/python/ray/train/tests/test_base_trainer.py
index f8eaf64c108a..5ef729f96caa 100644
--- a/python/ray/train/tests/test_base_trainer.py
+++ b/python/ray/train/tests/test_base_trainer.py
@@ -398,7 +398,7 @@ def test_large_params(ray_start_4_cpus):
     array_size = int(1e8)
 
     def training_loop(self):
-        checkpoint = self.resume_from_checkpoint.to_dict()["ckpt"]
+        checkpoint = self.starting_checkpoint.to_dict()["ckpt"]
         assert len(checkpoint) == array_size
         checkpoint = Checkpoint.from_dict({"ckpt": np.zeros(shape=array_size)})
 
diff --git a/python/ray/train/tests/test_checkpoint.py b/python/ray/train/tests/test_checkpoint.py
index 91f192e3cf23..f01352829e0f 100644
--- a/python/ray/train/tests/test_checkpoint.py
+++ b/python/ray/train/tests/test_checkpoint.py
@@ -156,6 +156,21 @@ def test_multiprocess_as_directory(checkpoint: Checkpoint, monkeypatch):
     assert not Path(checkpoint_dir_1).exists()
 
 
+def test_as_directory_lock_cleanup(checkpoint: Checkpoint):
+    """Errors when accessing a checkpoint with `as_directory`
+    shouldn't leave behind lock files.
+ """ + with pytest.raises(RuntimeError): + with checkpoint.as_directory() as checkpoint_dir: + raise RuntimeError + + assert not _list_existing_del_locks(checkpoint_dir) + + is_local_checkpoint = isinstance(checkpoint.filesystem, pyarrow.fs.LocalFileSystem) + if not is_local_checkpoint: + assert not Path(checkpoint_dir).exists() + + def test_metadata(checkpoint: Checkpoint): assert checkpoint.get_metadata() == {} diff --git a/python/ray/train/tests/test_new_persistence.py b/python/ray/train/tests/test_new_persistence.py index b01a73d550ed..8de2a3641eff 100644 --- a/python/ray/train/tests/test_new_persistence.py +++ b/python/ray/train/tests/test_new_persistence.py @@ -143,24 +143,28 @@ def train_fn(config): for i in range(start, config.get("num_iterations", 5)): time.sleep(0.25) - checkpoint_file_name = "checkpoint.pkl" + temp_dir = tempfile.mkdtemp() + with open(os.path.join(temp_dir, "checkpoint.pkl"), "wb") as f: + pickle.dump({"iter": i}, f) + artifact_file_name = f"artifact-iter={i}.txt" if in_trainer: rank = train.get_context().get_world_rank() - checkpoint_file_name = f"checkpoint_shard-rank={rank}.pkl" artifact_file_name = f"artifact-rank={rank}-iter={i}.txt" + checkpoint_file_name = f"checkpoint_shard-rank={rank}.pkl" + with open(os.path.join(temp_dir, checkpoint_file_name), "wb") as f: + pickle.dump({"iter": i}, f) + with open(artifact_file_name, "w") as f: f.write(f"{i}") - temp_dir = tempfile.mkdtemp() - with open(os.path.join(temp_dir, checkpoint_file_name), "wb") as f: - pickle.dump({"iter": i}, f) - train.report( {"iter": i, _SCORE_KEY: i}, checkpoint=NewCheckpoint.from_directory(temp_dir), ) + if i in config.get("fail_iters", []): + raise RuntimeError(f"Failing on iter={i}!!") @pytest.mark.parametrize("storage_path_type", [None, "nfs", "cloud", "custom_fs"]) @@ -287,6 +291,7 @@ def test_trainer( ├── progress.csv ├── result.json ├── checkpoint_000000 + │ ├── checkpoint.pkl <- Shared checkpoint file │ ├── checkpoint_shard-rank=0.pkl <- Worker checkpoint shards │ └── checkpoint_shard-rank=1.pkl ├── ... @@ -309,7 +314,11 @@ def test_trainer( NUM_WORKERS = 2 trainer = DataParallelTrainer( train_fn, - train_loop_config={"in_trainer": True, "num_iterations": NUM_ITERATIONS}, + train_loop_config={ + "in_trainer": True, + "num_iterations": NUM_ITERATIONS, + "fail_iters": [2, 4], + }, scaling_config=train.ScalingConfig(num_workers=2), run_config=train.RunConfig( storage_path=storage_path, @@ -317,6 +326,7 @@ def test_trainer( name=exp_name, verbose=0, checkpoint_config=checkpoint_config, + failure_config=train.FailureConfig(max_failures=2), ), ) result = trainer.fit() @@ -352,6 +362,8 @@ def test_trainer( assert len(list(trial_dir.glob("checkpoint_*"))) == expected_num_checkpoints for checkpoint_dir in trial_dir.glob("checkpoint_*"): + # 1 shared checkpoint.pkl file, written by all workers. + assert len(list(checkpoint_dir.glob("checkpoint.pkl"))) == 1 # 1 checkpoint shard per worker. assert ( len(list(checkpoint_dir.glob("checkpoint_shard-*.pkl"))) == NUM_WORKERS diff --git a/python/ray/train/trainer.py b/python/ray/train/trainer.py index 833a9537acc5..d64c66b7e33d 100644 --- a/python/ray/train/trainer.py +++ b/python/ray/train/trainer.py @@ -71,9 +71,9 @@ def __init__( # TrainingResult event. There's no need to do these one at a time. self._checkpoint_to_report = None - # TODO(justinvyu): Is this the best way to do this? Need to save this - # as part of checkpoint metadata and load it back on restore. 
-        self._latest_checkpoint_index = 0
+        self._storage = None
+        if _use_storage_context():
+            self._storage = get_storage_context()
 
         self._start_training(
             train_func=train_func,
@@ -103,7 +103,10 @@ def _start_training(
             run_dir=run_dir,
             latest_checkpoint_id=latest_checkpoint_id,
         )
-        checkpoint = self._checkpoint_manager._load_checkpoint(checkpoint)
+
+        if not _use_storage_context():
+            checkpoint = self._checkpoint_manager._load_checkpoint(checkpoint)
+
         self._run_with_error_handling(
             lambda: self._backend_executor.start_training(
                 train_func=train_func,
@@ -119,18 +122,10 @@ def _send_next_checkpoint_path_to_workers(self):
         # NOTE: Always upload to storage from workers in the new persistence path
         # (no need to check for the `checkpoint_upload_from_workers` flag)
        if _use_storage_context():
-            storage = get_storage_context()
-
-            # NOTE: Idea: this checkpoint dir name should be customizable
-            # and created on the fly when the checkpoint is reported with metrics.
-            # Ex: lambda metrics: f"checkpoint_iter={metrics['training_iteration']}"
-            storage.current_checkpoint_index = self._latest_checkpoint_index
-
             self._backend_executor._set_checkpoint_index(
-                storage.current_checkpoint_index
+                self._storage.current_checkpoint_index
             )
-
-            self._latest_checkpoint_index += 1
+            self._storage.current_checkpoint_index += 1
 
         elif self._checkpoint_strategy._checkpoint_upload_from_workers:
             self._backend_executor._set_legacy_checkpoint_uri(
diff --git a/python/ray/tune/execution/tune_controller.py b/python/ray/tune/execution/tune_controller.py
index b880257bab69..61e675baef7b 100644
--- a/python/ray/tune/execution/tune_controller.py
+++ b/python/ray/tune/execution/tune_controller.py
@@ -1956,11 +1956,29 @@ def _checkpoint_trial_if_needed(self, trial, force=False):
     ###
     # RESTORE
     def _schedule_trial_restore(self, trial: Trial) -> bool:
-        checkpoint = trial.checkpoint
-
         if _use_storage_context():
-            # TODO(justinvyu): Skipping restoration altogether for now.
-            return False
+            checkpoint_result = trial.checkpoint_manager.latest_checkpoint_result
+
+            if not checkpoint_result:
+                logger.debug(f"Not restoring trial {trial}: No checkpoint found.")
+                return False
+
+            # TODO(justinvyu): Is this really needed?
+            trial.restoring_from = checkpoint_result
+
+            method_name = "restore"
+            args = (checkpoint_result,)
+            self._schedule_trial_task(
+                trial=trial,
+                method_name=method_name,
+                args=args,
+                kwargs={},
+                on_result=self._on_restoring_result,
+                on_error=self._trial_task_failure,
+            )
+            return True
+
+        checkpoint = trial.checkpoint
 
         if checkpoint.dir_or_data is None:
             logger.debug(f"Not restoring trial {trial}: No checkpoint found.")
diff --git a/python/ray/tune/experiment/trial.py b/python/ray/tune/experiment/trial.py
index 267918050f9f..b0f2713934ae 100644
--- a/python/ray/tune/experiment/trial.py
+++ b/python/ray/tune/experiment/trial.py
@@ -384,8 +384,6 @@ def __init__(
         self.storage = copy.copy(storage)
 
         if _use_storage_context():
-            assert self.storage
-
             self._legacy_orig_experiment_path = None
             self._legacy_orig_experiment_dir_name = None
             self._legacy_local_experiment_path = None
@@ -1079,8 +1077,13 @@ def on_checkpoint(self, checkpoint: _TrackedCheckpoint):
         if _use_storage_context():
             from ray.train._internal.checkpoint_manager import _TrainingResult
 
-            assert isinstance(checkpoint, _TrainingResult)
-            self.checkpoint_manager.register_checkpoint(checkpoint)
+            checkpoint_result = checkpoint
+            assert isinstance(checkpoint_result, _TrainingResult)
+            self.checkpoint_manager.register_checkpoint(checkpoint_result)
+            # Increment the checkpoint index to keep the checkpoint index in sync.
+            # This index will get restored when the trial is restored and will
+            # be passed to the Trainable as the starting checkpoint index.
+            self.storage.current_checkpoint_index += 1
         else:
             self.checkpoint_manager.on_checkpoint(checkpoint)
         self.invalidate_json_state()
@@ -1088,6 +1091,12 @@ def on_checkpoint(self, checkpoint: _TrackedCheckpoint):
     def on_restore(self):
         """Handles restoration completion."""
         assert self.is_restoring
+
+        if _use_storage_context():
+            from ray.train._internal.checkpoint_manager import _TrainingResult
+
+            assert isinstance(self.restoring_from, _TrainingResult)
+
         self.last_result = self.restoring_from.metrics
         self.last_result.setdefault("config", self.config)
         self.restoring_from = None
diff --git a/python/ray/tune/trainable/function_trainable.py b/python/ray/tune/trainable/function_trainable.py
index 569f2f9731d1..d5dd7b0aff59 100644
--- a/python/ray/tune/trainable/function_trainable.py
+++ b/python/ray/tune/trainable/function_trainable.py
@@ -9,7 +9,7 @@
 import warnings
 from functools import partial
 from numbers import Number
-from typing import Any, Callable, Dict, Optional, Type
+from typing import Any, Callable, Dict, Optional, Type, TYPE_CHECKING
 
 from ray.air._internal.util import StartTraceback, RunnerThread
 import queue
@@ -37,6 +37,10 @@
 from ray.util.annotations import DeveloperAPI
 from ray.util.debug import log_once
 
+if TYPE_CHECKING:
+    from ray.train._internal.checkpoint_manager import _TrainingResult
+
+
 logger = logging.getLogger(__name__)
 
 # Time between FunctionTrainable checks when fetching
@@ -149,6 +153,7 @@ def __init__(
         self._trial_id = trial_id
         self._logdir = logdir
         self._last_checkpoint = None
+        self._latest_checkpoint_result: Optional["_TrainingResult"] = None
         self._fresh_checkpoint = False
         self._trial_resources = trial_resources
         # Mark whether the `ray.air.session.report()` API is being used,
@@ -160,6 +165,7 @@ def reset(self, trial_name=None, trial_id=None, logdir=None, trial_resources=Non
         self._trial_id = trial_id
         self._logdir = logdir
         self._last_checkpoint = None
+        self._latest_checkpoint_result = None
         self._fresh_checkpoint = False
         self._trial_resources = trial_resources
         self._air_session_has_reported = False
@@ -242,6 +248,14 @@ def set_checkpoint(self, checkpoint, is_new=True):
     def has_new_checkpoint(self):
         return self._fresh_checkpoint
 
+    def get_checkpoint_result(self) -> Optional["_TrainingResult"]:
+        from ray.train._internal.storage import _use_storage_context
+
+        assert _use_storage_context()
+        # The checkpoint is no longer fresh after it's been handed off to Tune.
+        self._fresh_checkpoint = False
+        return self._latest_checkpoint_result
+
     def get_checkpoint(self):
         # NOTE: This is not the same as `train.get_checkpoint`.
         # This is used internally by `FunctionTrainable.save_checkpoint`.
@@ -263,15 +277,18 @@ def report(self, metrics: Dict, *, checkpoint: Optional[Checkpoint] = None) -> N
         # TODO(justinvyu): With a unified session, we'll still run into this doubled
         # report problem. This should be fixed by checking if the checkpoint has been
         # uploaded already (via some marker), then skipping the repeat upload.
-        if _use_storage_context() and isinstance(checkpoint, NewCheckpoint):
+        if _use_storage_context():
+            assert isinstance(checkpoint, NewCheckpoint)
             logger.debug(f"Checkpoint received by the Tune session: {checkpoint}")
             self._fresh_checkpoint = True
             # TODO(justinvyu): `metrics` doesn't include the autofilled metrics
             # like `training_iteration` and `time_total_s`.
             # Should the session be the source of truth for these metrics?
-            self._last_checkpoint = _TrainingResult(
+            self._latest_checkpoint_result = _TrainingResult(
                 checkpoint=checkpoint, metrics=metrics
             )
+
+            self._last_checkpoint = None
         else:
             if checkpoint:
                 training_iteration = self._get_training_iteration()
@@ -284,15 +301,17 @@ def report(self, metrics: Dict, *, checkpoint: Optional[Checkpoint] = None) -> N
 
     @property
     def loaded_checkpoint(self) -> Optional[Checkpoint]:
-        if self._last_checkpoint:
-            from ray.train._internal.storage import _use_storage_context
-            from ray.train._internal.checkpoint_manager import _TrainingResult
+        from ray.train._internal.storage import _use_storage_context
+        from ray.train._internal.checkpoint_manager import _TrainingResult
 
-            if _use_storage_context() and isinstance(
-                self._last_checkpoint, _TrainingResult
-            ):
-                return self._last_checkpoint.checkpoint
+        if _use_storage_context():
+            if not self._latest_checkpoint_result:
+                return None
+            assert isinstance(self._latest_checkpoint_result, _TrainingResult)
+            return self._latest_checkpoint_result.checkpoint
+
+        if self._last_checkpoint:
             assert isinstance(self._last_checkpoint, str)
             return Checkpoint.from_directory(self._last_checkpoint)
         return None
@@ -483,6 +502,14 @@ def execute(self, fn):
     def get_state(self):
         state = super().get_state()
 
+        from ray.train._internal.storage import _use_storage_context
+
+        if _use_storage_context():
+            # TODO(justinvyu): This is only used to populate the tune metadata
+            # file within the checkpoint, so it can be removed once the metadata
+            # file is removed.
+            return state
+
         checkpoint = self._status_reporter.get_checkpoint()
         if not checkpoint:
             state.update(iteration=0, timesteps_total=0, episodes_total=0)
@@ -492,13 +519,15 @@ def save_checkpoint(self, checkpoint_dir: str = ""):
         if checkpoint_dir:
             raise ValueError("Checkpoint dir should not be used with function API.")
 
-        checkpoint = self._status_reporter.get_checkpoint()
-
         from ray.train._internal.storage import _use_storage_context
         from ray.train._internal.checkpoint_manager import _TrainingResult
 
-        if _use_storage_context() and isinstance(checkpoint, _TrainingResult):
-            return checkpoint
+        if _use_storage_context():
+            checkpoint_result = self._status_reporter.get_checkpoint_result()
+            assert isinstance(checkpoint_result, _TrainingResult)
+            return checkpoint_result
+
+        checkpoint = self._status_reporter.get_checkpoint()
 
         if not checkpoint:
             # We drop a marker here to indicate that the checkpoint is empty
@@ -537,6 +566,16 @@ def save_to_object(self):
         return checkpoint.to_bytes()
 
     def load_checkpoint(self, checkpoint):
+        from ray.train._internal.storage import _use_storage_context
+        from ray.train._internal.checkpoint_manager import _TrainingResult
+
+        if _use_storage_context():
+            checkpoint_result = checkpoint
+            assert isinstance(checkpoint_result, _TrainingResult)
+            self._status_reporter._latest_checkpoint_result = checkpoint_result
+            self._status_reporter._fresh_checkpoint = False
+            return
+
         # This should be removed once Trainables are refactored.
         if "tune_checkpoint_path" in checkpoint:
             del checkpoint["tune_checkpoint_path"]
diff --git a/python/ray/tune/trainable/trainable.py b/python/ray/tune/trainable/trainable.py
index d5488b52a8e9..8d65618def26 100644
--- a/python/ray/tune/trainable/trainable.py
+++ b/python/ray/tune/trainable/trainable.py
@@ -506,7 +506,11 @@ def save(
         if _use_storage_context() and isinstance(
             checkpoint_dict_or_path, _TrainingResult
         ):
-            return checkpoint_dict_or_path
+            checkpoint_result = checkpoint_dict_or_path
+            assert self._last_result
+            # Update the checkpoint result to include auto-filled metrics.
+            checkpoint_result.metrics.update(self._last_result)
+            return checkpoint_result
 
         if checkpoint_dict_or_path is None:
             # checkpoint_dict_or_path can only be None in class trainables.
@@ -864,6 +868,35 @@ def restore(
                 could not be found.
 
         """
+        if _use_storage_context():
+            from ray.train._internal.checkpoint_manager import _TrainingResult
+
+            checkpoint_result = checkpoint_path
+            assert isinstance(checkpoint_result, _TrainingResult)
+
+            checkpoint_metrics = checkpoint_result.metrics
+            self._iteration = checkpoint_metrics[TRAINING_ITERATION]
+            self._time_total = checkpoint_metrics[TIME_TOTAL_S]
+            self._time_since_restore = 0.0
+            self._iterations_since_restore = 0
+
+            # TODO(justinvyu): This stuff should be moved to rllib.
+            self._timesteps_total = checkpoint_metrics.get(TIMESTEPS_TOTAL)
+            self._timesteps_since_restore = 0
+            self._episodes_total = checkpoint_metrics.get(EPISODES_TOTAL)
+
+            # TODO(justinvyu): The Trainable `load_checkpoint` interface
+            # should be updated to take in a `_TrainingResult` / Checkpoint
+            self.load_checkpoint(checkpoint_result)
+
+            self._restored = True
+
+            logger.info(
+                f"Restored on {self._local_ip} from checkpoint: "
+                f"{checkpoint_result.checkpoint}"
+            )
+            return True
+
         # Ensure Checkpoints are converted
         if isinstance(checkpoint_path, Checkpoint):
            return self._restore_from_checkpoint_obj(checkpoint_path)
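Note on the `as_directory` change above: materializing the checkpoint directory is now wrapped in try/finally, so the deletion lock file is cleaned up even when user code inside the context manager raises; that is what the new `test_as_directory_lock_cleanup` exercises. Below is a minimal, self-contained sketch of the same pattern, not Ray's actual implementation; `as_local_directory`, `materialize`, and the `.del_lock` suffix are hypothetical stand-ins.

    import os
    import shutil
    import tempfile
    from contextlib import contextmanager

    @contextmanager
    def as_local_directory(materialize):
        """Copy data into a temp dir, guaranteeing lock cleanup on exit."""
        temp_dir = tempfile.mkdtemp()
        del_lock_path = temp_dir + ".del_lock"  # hypothetical lock naming
        open(del_lock_path, "a").close()
        try:
            materialize(temp_dir)  # analogous to self.to_directory()
            yield temp_dir
        finally:
            # Always remove the lock, even if the caller's block raised.
            try:
                os.remove(del_lock_path)
            except OSError:
                pass
            shutil.rmtree(temp_dir, ignore_errors=True)

    # Usage: the lock and temp dir are cleaned up although the body raises.
    try:
        with as_local_directory(lambda d: open(os.path.join(d, "ckpt.pkl"), "wb").close()) as d:
            raise RuntimeError("user code failure")
    except RuntimeError:
        pass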