[Tune] Fix checkpoint directory assignment for new checkpoints created after restoring a function trainable #31231

Merged
merged 6 commits on Dec 22, 2022
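In short (this summarizes the diff below): `_StatusReporter.report()` previously numbered new checkpoint directories with its own `_iter` counter, which was reset when a trial was restored, so a restored function trainable wrote `checkpoint_000000` again and clobbered the checkpoints from the original run. With this change the reporter receives a `training_iteration_func` and derives the directory index from the trainable's `training_iteration` instead. A minimal sketch of the old vs. new indexing, using a toy stand-in class rather than the real `_StatusReporter`:

```python
# Toy illustration of the checkpoint-directory indexing before and after this PR.
# `ToyReporter` is a stand-in for illustration only, not Ray's `_StatusReporter`.

class ToyReporter:
    def __init__(self, training_iteration_func):
        self._iter = 0  # old scheme: private counter that restarts after a restore
        self._get_training_iteration = training_iteration_func  # new scheme

    def old_checkpoint_dir(self) -> str:
        path = f"checkpoint_{self._iter:06d}"
        self._iter += 1
        return path

    def new_checkpoint_dir(self) -> str:
        # The index now follows the trainable's training_iteration, which is
        # restored along with the rest of the trial state.
        return f"checkpoint_{self._get_training_iteration():06d}"


# Suppose the original run already produced checkpoint_000000 and checkpoint_000001,
# and the trial is restored with two results already counted:
reporter = ToyReporter(training_iteration_func=lambda: 2)
assert reporter.old_checkpoint_dir() == "checkpoint_000000"  # collides with the old run
assert reporter.new_checkpoint_dir() == "checkpoint_000002"  # continues the sequence
```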
71 changes: 70 additions & 1 deletion python/ray/tune/tests/test_tuner_restore.py
@@ -53,7 +53,8 @@ def _train_fn_sometimes_failing(config):

checkpoint = session.get_checkpoint()
if checkpoint:
state = checkpoint.to_dict()
checkpoint_dict = checkpoint.to_dict()
state = {"it": checkpoint_dict["it"]}
else:
state = {"it": 0}

@@ -800,6 +801,74 @@ def on_trial_result(self, runner, trial, result):
)


def test_checkpoints_saved_after_resume(tmp_path):
"""Checkpoints saved after experiment restore should pick up at the correct
iteration and should not overwrite the checkpoints from the original run.
Old checkpoints should still be deleted if the total number of checkpoints
(old + new) exceeds `num_to_keep`.

In this test, `num_to_keep=4`:
- Initial run saves checkpoint_000000 and checkpoint_000001
- Restored run saves checkpoint_000002, checkpoint_000003, and checkpoint_000004
- Checkpoint 000000 should be deleted.
"""

def get_checkpoints(experiment_dir):
checkpoint_dirs = [
path
for path in os.listdir(experiment_dir)
if path.startswith("checkpoint_")
]
sorted_checkpoint_dirs = sorted(checkpoint_dirs)
checkpoints = [
Checkpoint.from_directory(os.path.join(experiment_dir, d))
for d in sorted_checkpoint_dirs
]
return sorted_checkpoint_dirs, checkpoints

fail_marker = tmp_path / "fail_marker"
fail_marker.write_text("", encoding="utf-8")

num_to_keep = 4
tuner = Tuner(
_train_fn_sometimes_failing,
tune_config=TuneConfig(num_samples=1),
run_config=RunConfig(
name="exp_name",
local_dir=str(tmp_path),
checkpoint_config=CheckpointConfig(num_to_keep=num_to_keep),
),
param_space={
"failing_hanging": (fail_marker, None),
"num_epochs": 2,
},
)
results = tuner.fit()
training_iteration = results[0].metrics["training_iteration"]
assert (
training_iteration == 2
), f"Should be at 2 iters before erroring, got {training_iteration}"

# Initial run saves the first 2 checkpoints
checkpoint_dirs, checkpoints = get_checkpoints(results[0].log_dir)
assert checkpoint_dirs == ["checkpoint_000000", "checkpoint_000001"]
assert [ckpt.to_dict()["it"] for ckpt in checkpoints] == [1, 2]

fail_marker.unlink()
tuner = Tuner.restore(str(tmp_path / "exp_name"), resume_errored=True)
results = tuner.fit()

assert len(results.errors) == 0
training_iteration = results[0].metrics["training_iteration"]
# Restored at it=2, reported 3 more times -> should have it=5
assert training_iteration == 5

    # Restored run saves 3 more checkpoints, and the first checkpoint should be deleted
checkpoint_dirs, checkpoints = get_checkpoints(results[0].log_dir)
assert checkpoint_dirs == [f"checkpoint_00000{i}" for i in range(1, 5)]
assert [ckpt.to_dict()["it"] for ckpt in checkpoints] == [2, 3, 4, 5]


if __name__ == "__main__":
import sys

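Only a small hunk of `_train_fn_sometimes_failing` is visible above. Based on that hunk and the assertions in `test_checkpoints_saved_after_resume`, a trainable with the shape the test relies on would look roughly like the sketch below (a reconstruction for orientation, not the verbatim helper): it reports `num_epochs` checkpointed results, raises while the fail marker exists, and makes one final checkpointed report on the run that gets past the marker.

```python
import time

from ray.air import session
from ray.air.checkpoint import Checkpoint


def _train_fn_sometimes_failing(config):
    failing, hanging = config["failing_hanging"]

    # Pick up the iteration counter from the checkpoint on restore.
    checkpoint = session.get_checkpoint()
    if checkpoint:
        checkpoint_dict = checkpoint.to_dict()
        state = {"it": checkpoint_dict["it"]}
    else:
        state = {"it": 0}

    for _ in range(config.get("num_epochs", 1)):
        state["it"] += 1
        session.report(state, checkpoint=Checkpoint.from_dict(state))

    # Simulate a crash (or hang) while the marker file exists.
    if failing and failing.exists():
        raise RuntimeError("Failing on purpose")
    if hanging and hanging.exists():
        time.sleep(60)

    # One final checkpointed report on the run that gets past the marker.
    state["it"] += 1
    session.report(state, checkpoint=Checkpoint.from_dict(state))
```

With `num_epochs=2`, this matches the numbers asserted in the test: the initial run reports `it=1, 2` and then fails; the restored run resumes at `it=2`, reports `it=3, 4, 5`, and finishes.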
36 changes: 19 additions & 17 deletions python/ray/tune/trainable/function_trainable.py
@@ -123,33 +123,35 @@ def create_perm_checkpoint(checkpoint_dir, logdir, step):
return perm_checkpoint_dir


@DeveloperAPI
class _StatusReporter:
def __init__(
self,
result_queue,
continue_semaphore,
end_event,
experiment_name=None,
trial_name=None,
trial_id=None,
logdir=None,
trial_resources=None,
result_queue: queue.Queue,
continue_semaphore: threading.Semaphore,
end_event: threading.Event,
training_iteration_func: Callable[[], int],
experiment_name: Optional[str] = None,
trial_name: Optional[str] = None,
trial_id: Optional[str] = None,
logdir: Optional[str] = None,
trial_resources: Optional[Union[Resources, PlacementGroupFactory]] = None,
Comment on lines +130 to +138
Contributor: Thanks!
):
self._queue = result_queue
self._last_report_time = None
self._continue_semaphore = continue_semaphore
self._end_event = end_event
self._get_training_iteration = training_iteration_func
self._experiment_name = experiment_name
self._trial_name = trial_name
self._trial_id = trial_id
self._logdir = logdir
self._last_checkpoint = None
self._fresh_checkpoint = False
self._trial_resources = trial_resources
# Also used as a marker of whether new `report()` API is being used,
# in which case, `_iter` will be incremented from 0 every time `report`
# is called.
self._iter = None
# Mark whether the `ray.air.session.report()` API is being used,
# to throw an error if `tune.report()` is called as well
self._air_session_has_reported = False

def reset(self, trial_name=None, trial_id=None, logdir=None, trial_resources=None):
self._trial_name = trial_name
@@ -158,7 +160,7 @@ def reset(self, trial_name=None, trial_id=None, logdir=None, trial_resources=None):
self._last_checkpoint = None
self._fresh_checkpoint = False
self._trial_resources = trial_resources
self._iter = None
self._air_session_has_reported = False

def __call__(self, _metric=None, **kwargs):
"""Report updated training status.
@@ -237,16 +239,15 @@ def _start(self):

def report(self, metrics: Dict, *, checkpoint: Optional[Checkpoint] = None) -> None:
# TODO(xwjiang): Tons of optimizations.
if not self._iter:
self._iter = 0
self._air_session_has_reported = True
if checkpoint:
checkpoint_dir = self.make_checkpoint_dir(step=self._iter)
training_iteration = self._get_training_iteration()
checkpoint_dir = self.make_checkpoint_dir(step=training_iteration)
self.set_checkpoint(checkpoint_dir)
checkpoint.to_directory(checkpoint_dir)
# TODO(krfricke): Remove this once support is added in Checkpoint.
open(os.path.join(checkpoint_dir, ".is_checkpoint"), "a").close()
self.__call__(**metrics)
self._iter += 1

@property
def loaded_checkpoint(self) -> Optional[Checkpoint]:
@@ -309,6 +310,7 @@ def setup(self, config):
self._results_queue,
self._continue_semaphore,
self._end_event,
training_iteration_func=lambda: self.training_iteration,
experiment_name=(
self._trial_info.experiment_name if self._trial_info else None
),
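Note that `setup()` passes `training_iteration_func` as a lambda rather than a snapshot value, so the reporter reads the trainable's current `training_iteration` at each report (late binding). A tiny sketch of that pattern with toy names, not Ray code:

```python
# Late-binding sketch: the reporter evaluates the callable at report time,
# so it always sees the trainable's current iteration count.

class ToyTrainable:
    def __init__(self):
        self.training_iteration = 0


class ToyReporter:
    def __init__(self, training_iteration_func):
        self._get_training_iteration = training_iteration_func

    def checkpoint_step(self) -> int:
        return self._get_training_iteration()


trainable = ToyTrainable()
reporter = ToyReporter(training_iteration_func=lambda: trainable.training_iteration)

trainable.training_iteration = 7  # e.g. after a restore plus a few more reports
assert reporter.checkpoint_step() == 7  # the reporter sees the updated value
```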
4 changes: 2 additions & 2 deletions python/ray/tune/trainable/session.py
@@ -241,7 +241,7 @@ def run_me(config):
)
_session = get_session()
if _session:
if _session._iter:
if _session._air_session_has_reported:
raise ValueError(
"It is not allowed to mix `tune.report` with `session.report`."
)
@@ -310,7 +310,7 @@ def func(config, checkpoint_dir=None):
raise ValueError("checkpoint_dir(step) must be provided - got None.")

if _session:
if _session._iter:
if _session._air_session_has_reported:
raise ValueError(
"It is not allowed to mix `with tune.checkpoint_dir` "
"with `session.report`."
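The flag rename does not change behavior from the user's side: mixing the legacy `tune.report`/`tune.checkpoint_dir` APIs with `session.report` in the same trainable still raises a `ValueError`. A sketch of the call pattern the guard rejects (illustrative only; when run inside a Tune trial it hits the errors shown in the diff above):

```python
from ray import tune
from ray.air import session


def mixed_apis(config):
    # Report once through the AIR session API; this sets
    # `_air_session_has_reported` on the session.
    session.report({"loss": 1.0})

    # A subsequent call to the legacy Tune API is now rejected:
    # "It is not allowed to mix `tune.report` with `session.report`."
    tune.report(loss=0.5)
```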