Skip to content

Commit

Permalink
test: fix slow delete_checkpoint test (#8377)
Browse files Browse the repository at this point in the history
  • Loading branch information
NicholasBlaskey authored Nov 9, 2023
1 parent b0505db commit a590999
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 19 deletions.
39 changes: 21 additions & 18 deletions e2e_tests/tests/cluster/test_checkpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,19 +21,26 @@
EXPECT_TIMEOUT = 5


def wait_for_gc_to_finish(experiment_id: int) -> None:
def wait_for_gc_to_finish(experiment_ids: List[int]) -> None:
certs.cli_cert = certs.default_load(conf.make_master_url())
authentication.cli_auth = authentication.Authentication(conf.make_master_url())
saw_gc = False

seen_gc_experiment_ids = set()
done_gc_experiment_ids = set()
# Don't wait longer than 5 minutes (as 600 half-seconds to improve our sampling resolution).
for _ in range(600):
r = api.get(conf.make_master_url(), "tasks").json()
names = [task["name"] for task in r.values()]
gc_name = f"Checkpoint GC (Experiment {experiment_id})"
if gc_name in names:
saw_gc = True
elif saw_gc:
# We previously saw checkpoint gc but now we don't, so it must have finished.

for experiment_id in experiment_ids:
gc_name = f"Checkpoint GC (Experiment {experiment_id})"
if gc_name in names:
seen_gc_experiment_ids.add(experiment_id)
elif experiment_id in seen_gc_experiment_ids:
# We saw the GC task earlier but no longer do, so we assume it has finished.
done_gc_experiment_ids.add(experiment_id)

if len(done_gc_experiment_ids) == len(experiment_ids):
return
time.sleep(0.5)

Expand Down Expand Up @@ -145,15 +152,12 @@ def test_delete_checkpoints() -> None:
config, model_def_path=conf.fixtures_path("no_op"), expected_trials=1
)

wait_for_gc_to_finish(exp_id_1)
wait_for_gc_to_finish(exp_id_2)

test_session = api_utils.determined_test_session()
exp_1_checkpoints = bindings.get_GetExperimentCheckpoints(
session=test_session, id=exp_id_1
).checkpoints
exp_2_checkpoints = bindings.get_GetExperimentCheckpoints(
session=test_session, id=exp_id_1
session=test_session, id=exp_id_2
).checkpoints
assert len(exp_1_checkpoints) > 0, f"no checkpoints found in experiment with ID:{exp_id_1}"
assert len(exp_2_checkpoints) > 0, f"no checkpoints found in experiment with ID:{exp_id_2}"
Expand Down Expand Up @@ -182,8 +186,7 @@ def test_delete_checkpoints() -> None:
delete_body = bindings.v1DeleteCheckpointsRequest(checkpointUuids=d_checkpoint_uuids)
bindings.delete_DeleteCheckpoints(session=test_session, body=delete_body)

wait_for_gc_to_finish(exp_id_1)
wait_for_gc_to_finish(exp_id_2)
wait_for_gc_to_finish([exp_id_1, exp_id_2])

for d_c in d_checkpoint_uuids:
ensure_checkpoint_deleted(test_session, d_c, storage_manager)
Expand Down Expand Up @@ -265,7 +268,7 @@ def run_gc_checkpoints_test(checkpoint_storage: Dict[str, str]) -> None:

# In some configurations, checkpoint GC will run on an auxiliary machine, which may have to
# be spun up still. So we'll wait for it to run.
wait_for_gc_to_finish(experiment_id)
wait_for_gc_to_finish([experiment_id])

# Checkpoints are not marked as deleted until gc_checkpoint task starts.
retries = 5
Expand Down Expand Up @@ -465,7 +468,7 @@ def assert_checkpoint_state(
checkpointUuids=[completed_checkpoints[0].uuid],
)
bindings.post_CheckpointsRemoveFiles(test_session, body=remove_body)
wait_for_gc_to_finish(exp_id)
wait_for_gc_to_finish([exp_id])

assert_checkpoint_state(
completed_checkpoints[0].uuid,
Expand All @@ -491,7 +494,7 @@ def assert_checkpoint_state(
checkpointUuids=[completed_checkpoints[0].uuid],
)
bindings.post_CheckpointsRemoveFiles(test_session, body=remove_body)
wait_for_gc_to_finish(exp_id)
wait_for_gc_to_finish([exp_id])

assert_checkpoint_state(
completed_checkpoints[0].uuid,
Expand All @@ -509,7 +512,7 @@ def assert_checkpoint_state(
checkpointUuids=[completed_checkpoints[0].uuid],
)
bindings.post_CheckpointsRemoveFiles(test_session, body=remove_body)
wait_for_gc_to_finish(exp_id)
wait_for_gc_to_finish([exp_id])

assert_checkpoint_state(
completed_checkpoints[0].uuid,
Expand All @@ -532,7 +535,7 @@ def assert_checkpoint_state(
checkpointUuids=[completed_checkpoints[1].uuid],
)
bindings.post_CheckpointsRemoveFiles(test_session, body=remove_body)
wait_for_gc_to_finish(exp_id)
wait_for_gc_to_finish([exp_id])

assert_checkpoint_state(
completed_checkpoints[1].uuid,
Expand Down
2 changes: 1 addition & 1 deletion e2e_tests/tests/experiment/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def test_end_to_end_adaptive() -> None:
None,
)

wait_for_gc_to_finish(experiment_id=exp_id)
wait_for_gc_to_finish(experiment_ids=[exp_id])

# Check that validation accuracy looks sane (more than 93% on MNIST).
trials = exp.experiment_trials(exp_id)
Expand Down

0 comments on commit a590999

Please sign in to comment.