[Tune] Fix checkpoint directory assignment for new checkpoints created after restoring a function trainable (#31231)

This PR fixes checkpoint directory creation for restored function trainables to use the restored iteration instead of starting over from `checkpoint_000000`.
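To make the user-facing behavior concrete, here is a hedged sketch of the workflow this affects; the training function below is an illustrative example, not code from this PR:

from ray.air import Checkpoint, session


def train_fn(config):
    # Illustrative function trainable that checkpoints every iteration.
    state = {"it": 0}
    ckpt = session.get_checkpoint()
    if ckpt:
        state = ckpt.to_dict()  # resume from the restored iteration
    for _ in range(config["num_epochs"]):
        state["it"] += 1
        session.report(state, checkpoint=Checkpoint.from_dict(state))

When `Tuner.restore(...)` resumes a failed trial that had reached iteration 2, new checkpoints from `train_fn` should continue as `checkpoint_000002`, `checkpoint_000003`, ... instead of restarting at `checkpoint_000000` and colliding with the checkpoints already on disk.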

Signed-off-by: Justin Yu <[email protected]>
justinvyu authored Dec 22, 2022
1 parent 7146df6 commit 77b94ab
Showing 3 changed files with 91 additions and 20 deletions.
python/ray/tune/tests/test_tuner_restore.py (70 additions, 1 deletion)
@@ -53,7 +53,8 @@ def _train_fn_sometimes_failing(config):

    checkpoint = session.get_checkpoint()
    if checkpoint:
-        state = checkpoint.to_dict()
+        checkpoint_dict = checkpoint.to_dict()
+        state = {"it": checkpoint_dict["it"]}
    else:
        state = {"it": 0}

@@ -800,6 +801,74 @@ def on_trial_result(self, runner, trial, result):
)


+def test_checkpoints_saved_after_resume(tmp_path):
+    """Checkpoints saved after experiment restore should pick up at the correct
+    iteration and should not overwrite the checkpoints from the original run.
+    Old checkpoints should still be deleted if the total number of checkpoints
+    (old + new) exceeds `num_to_keep`.
+    In this test, `num_to_keep=4`:
+    - Initial run saves checkpoint_000000 and checkpoint_000001
+    - Restored run saves checkpoint_000002, checkpoint_000003, and checkpoint_000004
+    - Checkpoint 000000 should be deleted.
+    """
+
+    def get_checkpoints(experiment_dir):
+        checkpoint_dirs = [
+            path
+            for path in os.listdir(experiment_dir)
+            if path.startswith("checkpoint_")
+        ]
+        sorted_checkpoint_dirs = sorted(checkpoint_dirs)
+        checkpoints = [
+            Checkpoint.from_directory(os.path.join(experiment_dir, d))
+            for d in sorted_checkpoint_dirs
+        ]
+        return sorted_checkpoint_dirs, checkpoints
+
+    fail_marker = tmp_path / "fail_marker"
+    fail_marker.write_text("", encoding="utf-8")
+
+    num_to_keep = 4
+    tuner = Tuner(
+        _train_fn_sometimes_failing,
+        tune_config=TuneConfig(num_samples=1),
+        run_config=RunConfig(
+            name="exp_name",
+            local_dir=str(tmp_path),
+            checkpoint_config=CheckpointConfig(num_to_keep=num_to_keep),
+        ),
+        param_space={
+            "failing_hanging": (fail_marker, None),
+            "num_epochs": 2,
+        },
+    )
+    results = tuner.fit()
+    training_iteration = results[0].metrics["training_iteration"]
+    assert (
+        training_iteration == 2
+    ), f"Should be at 2 iters before erroring, got {training_iteration}"
+
+    # Initial run saves the first 2 checkpoints
+    checkpoint_dirs, checkpoints = get_checkpoints(results[0].log_dir)
+    assert checkpoint_dirs == ["checkpoint_000000", "checkpoint_000001"]
+    assert [ckpt.to_dict()["it"] for ckpt in checkpoints] == [1, 2]
+
+    fail_marker.unlink()
+    tuner = Tuner.restore(str(tmp_path / "exp_name"), resume_errored=True)
+    results = tuner.fit()
+
+    assert len(results.errors) == 0
+    training_iteration = results[0].metrics["training_iteration"]
+    # Restored at it=2, reported 3 more times -> should have it=5
+    assert training_iteration == 5

+    # Restored run saves 3 more checkpoints, and the first checkpoint should be deleted
+    checkpoint_dirs, checkpoints = get_checkpoints(results[0].log_dir)
+    assert checkpoint_dirs == [f"checkpoint_00000{i}" for i in range(1, 5)]
+    assert [ckpt.to_dict()["it"] for ckpt in checkpoints] == [2, 3, 4, 5]


if __name__ == "__main__":
    import sys

python/ray/tune/trainable/function_trainable.py (19 additions, 17 deletions)
@@ -123,33 +123,35 @@ def create_perm_checkpoint(checkpoint_dir, logdir, step):
    return perm_checkpoint_dir


@DeveloperAPI
class _StatusReporter:
    def __init__(
        self,
-        result_queue,
-        continue_semaphore,
-        end_event,
-        experiment_name=None,
-        trial_name=None,
-        trial_id=None,
-        logdir=None,
-        trial_resources=None,
+        result_queue: queue.Queue,
+        continue_semaphore: threading.Semaphore,
+        end_event: threading.Event,
+        training_iteration_func: Callable[[], int],
+        experiment_name: Optional[str] = None,
+        trial_name: Optional[str] = None,
+        trial_id: Optional[str] = None,
+        logdir: Optional[str] = None,
+        trial_resources: Optional[Union[Resources, PlacementGroupFactory]] = None,
    ):
        self._queue = result_queue
        self._last_report_time = None
        self._continue_semaphore = continue_semaphore
        self._end_event = end_event
+        self._get_training_iteration = training_iteration_func
        self._experiment_name = experiment_name
        self._trial_name = trial_name
        self._trial_id = trial_id
        self._logdir = logdir
        self._last_checkpoint = None
        self._fresh_checkpoint = False
        self._trial_resources = trial_resources
-        # Also used as a marker of whether new `report()` API is being used,
-        # in which case, `_iter` will be incremented from 0 every time `report`
-        # is called.
-        self._iter = None
+        # Mark whether the `ray.air.session.report()` API is being used,
+        # to throw an error if `tune.report()` is called as well
+        self._air_session_has_reported = False

    def reset(self, trial_name=None, trial_id=None, logdir=None, trial_resources=None):
        self._trial_name = trial_name
@@ -158,7 +160,7 @@ def reset(self, trial_name=None, trial_id=None, logdir=None, trial_resources=None):
        self._last_checkpoint = None
        self._fresh_checkpoint = False
        self._trial_resources = trial_resources
-        self._iter = None
+        self._air_session_has_reported = False

    def __call__(self, _metric=None, **kwargs):
        """Report updated training status.
@@ -237,16 +239,15 @@ def _start(self):

    def report(self, metrics: Dict, *, checkpoint: Optional[Checkpoint] = None) -> None:
        # TODO(xwjiang): Tons of optimizations.
-        if not self._iter:
-            self._iter = 0
+        self._air_session_has_reported = True
        if checkpoint:
-            checkpoint_dir = self.make_checkpoint_dir(step=self._iter)
+            training_iteration = self._get_training_iteration()
+            checkpoint_dir = self.make_checkpoint_dir(step=training_iteration)
            self.set_checkpoint(checkpoint_dir)
            checkpoint.to_directory(checkpoint_dir)
            # TODO(krfricke): Remove this once support is added in Checkpoint.
            open(os.path.join(checkpoint_dir, ".is_checkpoint"), "a").close()
        self.__call__(**metrics)
-        self._iter += 1

    @property
    def loaded_checkpoint(self) -> Optional[Checkpoint]:
@@ -309,6 +310,7 @@ def setup(self, config):
            self._results_queue,
            self._continue_semaphore,
            self._end_event,
+            training_iteration_func=lambda: self.training_iteration,
            experiment_name=(
                self._trial_info.experiment_name if self._trial_info else None
            ),
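The gist of the change above: `_StatusReporter.report()` no longer derives the checkpoint directory from its own zero-based `_iter` counter (which reset on every restore) but asks the owning trainable for its current `training_iteration` through the injected `training_iteration_func`. A minimal sketch of why this keeps numbering correct after a restore, using toy classes with assumed names rather than Ray's actual implementation:

from typing import Callable


class ToyReporter:
    def __init__(self, training_iteration_func: Callable[[], int]):
        # Injected getter, mirroring the `training_iteration_func` argument added
        # in this commit: ask the trainable instead of keeping a local counter.
        self._get_training_iteration = training_iteration_func

    def checkpoint_dir_name(self) -> str:
        # Ray names checkpoint directories checkpoint_000000, checkpoint_000001, ...
        step = self._get_training_iteration()
        return f"checkpoint_{step:06d}"


class ToyTrainable:
    def __init__(self, restored_iteration: int = 0):
        # After restoring from a checkpoint, the iteration count resumes at the
        # restored value rather than 0.
        self.training_iteration = restored_iteration


trainable = ToyTrainable(restored_iteration=2)
reporter = ToyReporter(lambda: trainable.training_iteration)
print(reporter.checkpoint_dir_name())  # checkpoint_000002, not checkpoint_000000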
python/ray/tune/trainable/session.py (2 additions, 2 deletions)
@@ -241,7 +241,7 @@ def run_me(config):
    )
    _session = get_session()
    if _session:
-        if _session._iter:
+        if _session._air_session_has_reported:
            raise ValueError(
                "It is not allowed to mix `tune.report` with `session.report`."
            )
@@ -310,7 +310,7 @@ def func(config, checkpoint_dir=None):
raise ValueError("checkpoint_dir(step) must be provided - got None.")

if _session:
if _session._iter:
if _session._air_session_has_reported:
raise ValueError(
"It is not allowed to mix `with tune.checkpoint_dir` "
"with `session.report`."
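These two checks are where `_air_session_has_reported` is consumed: once `session.report()` has been called in a trial, the legacy `tune.report` / `tune.checkpoint_dir` APIs raise a `ValueError`. A hedged sketch of the mixing that triggers the error; it is meant to be launched as a trainable through `Tuner`, not run standalone:

from ray import tune
from ray.air import session


def mixed_apis(config):
    # New API first: this sets the session's _air_session_has_reported flag.
    session.report({"score": 1})
    # Legacy API second: raises ValueError
    # "It is not allowed to mix `tune.report` with `session.report`."
    tune.report(score=2)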
