[air] pyarrow.fs persistence (8/n): ray.train.Checkpoint restore: `resume_from_checkpoint` (ray-project#38143)

This PR adds support for `Trainer(resume_from_checkpoint=...)` with the new Checkpoint class and covers it in a new section of the e2e test.
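
For context, the end-to-end usage this enables looks roughly like the sketch below. It is illustrative only: the training function, checkpoint contents, and run names are made up, and it assumes the new checkpoint class is importable as `ray.train.Checkpoint` (in this PR's tests it still lives at `ray.train._checkpoint`).

import os
import pickle
import tempfile

from ray import train
from ray.train import Checkpoint  # assumed public import path for the new Checkpoint
from ray.train.data_parallel_trainer import DataParallelTrainer


def train_fn(config):
    # If the run was launched with `resume_from_checkpoint`, the starting
    # checkpoint is visible to every worker via `train.get_checkpoint()`.
    start_iter = 0
    ckpt = train.get_checkpoint()
    if ckpt:
        with ckpt.as_directory() as ckpt_dir:
            with open(os.path.join(ckpt_dir, "state.pkl"), "rb") as f:
                start_iter = pickle.load(f)["iter"] + 1

    for i in range(start_iter, 10):
        tmpdir = tempfile.mkdtemp()
        with open(os.path.join(tmpdir, "state.pkl"), "wb") as f:
            pickle.dump({"iter": i}, f)
        train.report({"iter": i}, checkpoint=Checkpoint.from_directory(tmpdir))


# First run produces checkpoints; the second run resumes from the latest one.
first_result = DataParallelTrainer(
    train_fn,
    scaling_config=train.ScalingConfig(num_workers=2),
    run_config=train.RunConfig(name="first_run"),
).fit()

resumed_result = DataParallelTrainer(
    train_fn,
    scaling_config=train.ScalingConfig(num_workers=2),
    run_config=train.RunConfig(name="resumed_run"),
    resume_from_checkpoint=first_result.checkpoint,
).fit()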

This PR also fixes a bug where a run that reported no checkpoints caused the `Result` object to error during construction.

Signed-off-by: e428265 <[email protected]>
justinvyu authored and arvind-chandra committed Aug 31, 2023
1 parent 730daf5 commit 39925a9
Showing 4 changed files with 53 additions and 4 deletions.

python/ray/train/BUILD (1 addition, 1 deletion)
@@ -464,7 +464,7 @@ py_test(

 py_test(
     name = "test_new_persistence",
-    size = "medium",
+    size = "large",
     srcs = ["tests/test_new_persistence.py"],
     tags = ["team:ml", "exclusive"],
     deps = [":train_lib", ":conftest"]

python/ray/train/base_trainer.py (5 additions, 1 deletion)
@@ -21,6 +21,7 @@
 from ray.air.checkpoint import Checkpoint
 from ray.air.config import RunConfig, ScalingConfig
 from ray.air.result import Result
+from ray.train._checkpoint import Checkpoint as NewCheckpoint
 from ray.train._internal import session
 from ray.train._internal.storage import _use_storage_context
 from ray.train.constants import TRAIN_DATASET_KEY
@@ -453,8 +454,11 @@ def _validate_attributes(self):
f"found {type(self.preprocessor)} with value `{self.preprocessor}`."
)

expected_checkpoint_type = (
NewCheckpoint if _use_storage_context() else ray.air.Checkpoint
)
if self.starting_checkpoint is not None and not isinstance(
self.starting_checkpoint, ray.air.Checkpoint
self.starting_checkpoint, expected_checkpoint_type
):
raise ValueError(
f"`resume_from_checkpoint` should be an instance of "
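
In practice this means that when the new storage context is active, passing a legacy `ray.air.Checkpoint` to `resume_from_checkpoint` now fails validation at trainer construction time. A rough sketch of that behavior (assuming the new persistence mode is enabled; the exact error text is truncated in the diff above):

import pytest

import ray.air
from ray.train import ScalingConfig
from ray.train.data_parallel_trainer import DataParallelTrainer

# Assuming _use_storage_context() is True, only the new ray.train Checkpoint
# class passes `_validate_attributes()`; the legacy class is rejected.
legacy_ckpt = ray.air.Checkpoint.from_dict({"iter": 0})

with pytest.raises(ValueError):
    DataParallelTrainer(
        lambda config: None,
        scaling_config=ScalingConfig(num_workers=1),
        resume_from_checkpoint=legacy_ckpt,
    )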

python/ray/train/tests/test_new_persistence.py (42 additions, 1 deletion)
@@ -11,7 +11,7 @@

 from ray import train, tune
 from ray.air.constants import EXPR_RESULT_FILE
-from ray.train._internal.storage import _download_from_fs_path
+from ray.train._internal.storage import _download_from_fs_path, StorageContext
 from ray.train._checkpoint import Checkpoint as NewCheckpoint
 from ray.train.data_parallel_trainer import DataParallelTrainer

@@ -167,6 +167,38 @@ def train_fn(config):
             raise RuntimeError(f"Failing on iter={i}!!")


+def _resume_from_checkpoint(checkpoint: NewCheckpoint, expected_state: dict):
+    print(f"\nStarting run with `resume_from_checkpoint`: {checkpoint}\n")
+
+    def assert_fn(config):
+        checkpoint_to_check = train.get_checkpoint()
+        with checkpoint_to_check.as_directory() as checkpoint_dir:
+            with open(os.path.join(checkpoint_dir, "checkpoint.pkl"), "rb") as f:
+                state = pickle.load(f)
+
+        print("Loaded state from `resume_from_checkpoint`:", state)
+        print("Expected state:", expected_state)
+        assert state == expected_state, (state, expected_state)
+
+        dummy_ckpt = tempfile.mkdtemp()
+        with open(os.path.join(dummy_ckpt, "dummy.txt"), "w") as f:
+            f.write("data")
+        train.report({"dummy": 1}, checkpoint=NewCheckpoint.from_directory(dummy_ckpt))
+
+    trainer = DataParallelTrainer(
+        assert_fn,
+        scaling_config=train.ScalingConfig(num_workers=2),
+        run_config=train.RunConfig(name="test_resume_from_checkpoint"),
+        resume_from_checkpoint=checkpoint,
+    )
+    result = trainer.fit()
+
+    # Make sure that the checkpoint indexing starts from scratch.
+    assert Path(
+        result.checkpoint.path
+    ).name == StorageContext._make_checkpoint_dir_name(0)


 @pytest.mark.parametrize("storage_path_type", [None, "nfs", "cloud", "custom_fs"])
 def test_tuner(monkeypatch, storage_path_type, tmp_path):
     """End-to-end test that the new persistence mode works with the Tuner API.
@@ -329,8 +361,17 @@ def test_trainer(
             failure_config=train.FailureConfig(max_failures=2),
         ),
     )
+    print("\nStarting initial run.\n")
     result = trainer.fit()

+    with monkeypatch.context() as m:
+        # This is so that the `resume_from_checkpoint` run doesn't mess up the
+        # assertions later for the `storage_path=None` case.
+        m.setenv("RAY_AIR_LOCAL_CACHE_DIR", tmp_path / "resume_from_checkpoint")
+        _resume_from_checkpoint(
+            result.checkpoint, expected_state={"iter": NUM_ITERATIONS - 1}
+        )
+
     local_inspect_dir, storage_fs_path = _get_local_inspect_dir(
         root_local_path=tmp_path,
         storage_path=storage_path,
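
The assertion at the end of `_resume_from_checkpoint` pins down the intended semantics: a run started with `resume_from_checkpoint` begins its own checkpoint numbering at zero rather than continuing the parent run's indices. Assuming the new persistence layout names checkpoint directories with a zero-padded index (e.g. `checkpoint_000000`), the check amounts to:

from pathlib import Path


def _assert_fresh_checkpoint_index(result) -> None:
    # Hypothetical stand-in for StorageContext._make_checkpoint_dir_name(0);
    # the zero-padded "checkpoint_NNNNNN" naming convention is an assumption here.
    expected_dir_name = f"checkpoint_{0:06d}"  # -> "checkpoint_000000"
    assert Path(result.checkpoint.path).name == expected_dir_name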

python/ray/tune/result_grid.py (5 additions, 1 deletion)
@@ -277,7 +277,11 @@ def _trial_to_result(self, trial: Trial) -> Result:
             )

             assert isinstance(trial.checkpoint_manager, _NewCheckpointManager)
-            checkpoint = trial.checkpoint_manager.latest_checkpoint_result.checkpoint
+            checkpoint = None
+            if trial.checkpoint_manager.latest_checkpoint_result:
+                checkpoint = (
+                    trial.checkpoint_manager.latest_checkpoint_result.checkpoint
+                )
             best_checkpoint_results = trial.checkpoint_manager.best_checkpoint_results
             best_checkpoints = [
                 (checkpoint_result.checkpoint, checkpoint_result.metrics)
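
The `result_grid.py` change means a trial that never reported a checkpoint now yields a `Result` whose `checkpoint` is `None`, instead of failing while the `Result` is being built. A minimal sketch of the caller-side handling (the training function and metric names are illustrative, not from this PR):

from ray import train
from ray.train.data_parallel_trainer import DataParallelTrainer


def no_checkpoint_fn(config):
    # Reports metrics only; never attaches a checkpoint.
    train.report({"loss": 0.0})


result = DataParallelTrainer(
    no_checkpoint_fn,
    scaling_config=train.ScalingConfig(num_workers=1),
).fit()

# Previously this construction path could error out when no checkpoints were
# reported; now the Result simply has no checkpoint attached.
if result.checkpoint is None:
    print("Run finished without reporting any checkpoints.")
else:
    with result.checkpoint.as_directory() as checkpoint_dir:
        print("Latest checkpoint available at:", checkpoint_dir)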
