[tune] Use Checkpoint.to_bytes() for store_to_object (#25805)
We currently use our own serialization to ship checkpoints as objects. Instead, we should use the Checkpoint class. This PR also adds support for creating results from checkpoints that point to object references.

Depends on #26351
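
For illustration, a minimal sketch of the byte-level round trip the Checkpoint class provides; the dict contents below are made up and not taken from this PR:

from ray.air import Checkpoint

# Round-trip a checkpoint through bytes. save_to_object/restore_from_object
# can lean on this instead of a hand-rolled serialization format.
original = Checkpoint.from_dict({"iteration": 3, "weights": [0.1, 0.2]})
blob = original.to_bytes()

restored = Checkpoint.from_bytes(blob)
assert restored.to_dict()["iteration"] == 3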

Signed-off-by: Kai Fricke <[email protected]>
krfricke authored Jul 8, 2022
1 parent 0e259ff commit e1a7efe
Showing 12 changed files with 324 additions and 186 deletions.
8 changes: 8 additions & 0 deletions python/ray/tune/BUILD
@@ -304,6 +304,14 @@ py_test(
tags = ["team:ml", "exclusive", "tests_dir_S"],
)

py_test(
name = "test_trainable",
size = "medium",
srcs = ["tests/test_trainable.py"],
deps = [":tune_lib"],
tags = ["team:ml", "exclusive", "tests_dir_T"],
)

py_test(
name = "test_trainable_util",
size = "small",
18 changes: 12 additions & 6 deletions python/ray/tune/examples/xgboost_dynamic_resources_example.py
@@ -22,12 +22,18 @@

def get_best_model_checkpoint(analysis):
best_bst = xgb.Booster()
try:
with open(analysis.best_checkpoint, "rb") as inputFile:
_, _, raw_model = pickle.load(inputFile)
best_bst.load_model(bytearray(raw_model))
except IsADirectoryError:
best_bst.load_model(os.path.join(analysis.best_checkpoint, CHECKPOINT_FILENAME))

with analysis.best_checkpoint.as_directory() as checkpoint_dir:
to_load = os.path.join(checkpoint_dir, CHECKPOINT_FILENAME)

if not os.path.exists(to_load):
# Class trainable
with open(os.path.join(checkpoint_dir, "checkpoint"), "rb") as f:
_, _, raw_model = pickle.load(f)
to_load = bytearray(raw_model)

best_bst.load_model(to_load)

accuracy = 1.0 - analysis.best_result["eval-logloss"]
print(f"Best model parameters: {analysis.best_config}")
print(f"Best model total accuracy: {accuracy:.4f}")
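The updated example works because Checkpoint.as_directory() always yields a local directory path, materializing the checkpoint into a temporary directory when it is not already backed by one. A small sketch of that behavior, assuming nothing beyond the ray.air Checkpoint API:

from ray.air import Checkpoint

# A dict-backed checkpoint has no directory of its own; as_directory()
# materializes it into a temporary directory for the duration of the block.
ckpt = Checkpoint.from_dict({"weights": [1, 2, 3]})
with ckpt.as_directory() as local_dir:
    print(local_dir)  # temporary path holding the serialized checkpoint data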
4 changes: 3 additions & 1 deletion python/ray/tune/execution/ray_trial_executor.py
@@ -13,6 +13,7 @@
from typing import Callable, Dict, Iterable, List, Optional, Set, Union

import ray
from ray.air import Checkpoint
from ray.exceptions import GetTimeoutError, RayTaskError
from ray.tune.error import (
TuneError,
@@ -794,7 +795,8 @@ def restore(self, trial: Trial) -> None:
# This provides FT backwards compatibility in the
# case where no cloud checkpoints are provided.
logger.debug("Trial %s: Reading checkpoint into memory", trial)
obj = TrainableUtil.checkpoint_to_object(checkpoint_dir)
checkpoint_path = TrainableUtil.find_checkpoint_dir(checkpoint_dir)
obj = Checkpoint.from_directory(checkpoint_path).to_bytes()
with self._change_working_directory(trial):
remote = trial.runner.restore_from_object.remote(obj)
else:
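For context, the bytes produced by Checkpoint.from_directory(...).to_bytes() are ordinary Python bytes, so they can be handed to a remote task or actor and unpacked on the other side. A rough sketch under that assumption; the directory contents and the unpack helper are illustrative, not the actual Trainable method:

import os
import tempfile

import ray
from ray.air import Checkpoint

ray.init(num_cpus=1)

# Build a throwaway checkpoint directory to ship across processes.
checkpoint_dir = tempfile.mkdtemp()
with open(os.path.join(checkpoint_dir, "checkpoint"), "w") as f:
    f.write("state")

blob = Checkpoint.from_directory(checkpoint_dir).to_bytes()

@ray.remote
def unpack(data: bytes):
    # Recreate the checkpoint remotely and list what was shipped.
    with Checkpoint.from_bytes(data).as_directory() as local_dir:
        return sorted(os.listdir(local_dir))

print(ray.get(unpack.remote(blob)))
ray.shutdown()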
2 changes: 1 addition & 1 deletion python/ray/tune/tests/test_cluster.py
@@ -346,7 +346,7 @@ def test_migration_checkpoint_removal(
cluster.wait_for_nodes()

# Remove checkpoint on "remote" node
shutil.rmtree(os.path.dirname(t1.checkpoint.dir_or_data))
shutil.rmtree(t1.checkpoint.dir_or_data)

if not durable:
# Recover from driver file
24 changes: 12 additions & 12 deletions python/ray/tune/tests/test_experiment_analysis.py
@@ -131,31 +131,31 @@ def testGetTrialCheckpointsPathsByTrial(self):
best_trial = self.ea.get_best_trial(self.metric, mode="max")
checkpoints_metrics = self.ea.get_trial_checkpoints_paths(best_trial)
logdir = self.ea.get_best_logdir(self.metric, mode="max")
expected_path = os.path.join(logdir, "checkpoint_000001", "checkpoint")
assert checkpoints_metrics[0][0] == expected_path
expected_path = os.path.join(logdir, "checkpoint_000001")
assert os.path.normpath(checkpoints_metrics[0][0]) == expected_path
assert checkpoints_metrics[0][1] == 1

def testGetTrialCheckpointsPathsByPath(self):
logdir = self.ea.get_best_logdir(self.metric, mode="max")
checkpoints_metrics = self.ea.get_trial_checkpoints_paths(logdir)
expected_path = os.path.join(logdir, "checkpoint_000001/", "checkpoint")
assert checkpoints_metrics[0][0] == expected_path
expected_path = os.path.join(logdir, "checkpoint_000001")
assert os.path.normpath(checkpoints_metrics[0][0]) == expected_path
assert checkpoints_metrics[0][1] == 1

def testGetTrialCheckpointsPathsWithMetricByTrial(self):
best_trial = self.ea.get_best_trial(self.metric, mode="max")
paths = self.ea.get_trial_checkpoints_paths(best_trial, self.metric)
logdir = self.ea.get_best_logdir(self.metric, mode="max")
expected_path = os.path.join(logdir, "checkpoint_000001", "checkpoint")
assert paths[0][0] == expected_path
expected_path = os.path.join(logdir, "checkpoint_000001")
assert os.path.normpath(paths[0][0]) == expected_path
assert paths[0][1] == best_trial.metric_analysis[self.metric]["last"]

def testGetTrialCheckpointsPathsWithMetricByPath(self):
best_trial = self.ea.get_best_trial(self.metric, mode="max")
logdir = self.ea.get_best_logdir(self.metric, mode="max")
paths = self.ea.get_trial_checkpoints_paths(best_trial, self.metric)
expected_path = os.path.join(logdir, "checkpoint_000001", "checkpoint")
assert paths[0][0] == expected_path
expected_path = os.path.join(logdir, "checkpoint_000001")
assert os.path.normpath(paths[0][0]) == expected_path
assert paths[0][1] == best_trial.metric_analysis[self.metric]["last"]

def testGetBestCheckpoint(self):
@@ -266,8 +266,8 @@ def testGetTrialCheckpointsPathsByPathWithSpecialCharacters(self):
)
logdir = analysis.get_best_logdir(self.metric, mode="max")
checkpoints_metrics = analysis.get_trial_checkpoints_paths(logdir)
expected_path = os.path.join(logdir, "checkpoint_000001/", "checkpoint")
assert checkpoints_metrics[0][0] == expected_path
expected_path = os.path.join(logdir, "checkpoint_000001")
assert os.path.normpath(checkpoints_metrics[0][0]) == expected_path
assert checkpoints_metrics[0][1] == 1

def testGetTrialCheckpointsPathsWithTemporaryCheckpoints(self):
@@ -288,11 +288,11 @@ def testGetTrialCheckpointsPathsWithTemporaryCheckpoints(self):
)

checkpoints_metrics = analysis.get_trial_checkpoints_paths(logdir)
expected_path = os.path.join(logdir, "checkpoint_000002/", "checkpoint")
expected_path = os.path.join(logdir, "checkpoint_000002")

assert len(checkpoints_metrics) == 1

assert checkpoints_metrics[0][0] == expected_path
assert os.path.normpath(checkpoints_metrics[0][0]) == expected_path
assert checkpoints_metrics[0][1] == 2


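The switch to os.path.normpath in these assertions is presumably needed because the returned checkpoint paths are now directories and are not guaranteed to be normalized; a quick illustration with made-up paths:

import os

# A trailing separator survives os.path.join, so the raw strings differ even
# though they name the same directory; normpath makes them comparable.
a = os.path.join("/tmp/logdir", "checkpoint_000001/")
b = os.path.join("/tmp/logdir", "checkpoint_000001")
assert a != b
assert os.path.normpath(a) == os.path.normpath(b)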
149 changes: 149 additions & 0 deletions python/ray/tune/tests/test_trainable.py
@@ -0,0 +1,149 @@
import json
import os
import tempfile
from typing import Dict, Union

import pytest

import ray
from ray import tune
from ray.air import session, Checkpoint
from ray.tune.trainable import wrap_function


@pytest.fixture
def ray_start_2_cpus():
address_info = ray.init(num_cpus=2)
yield address_info
# The code after the yield will run as teardown code.
ray.shutdown()


class SavingTrainable(tune.Trainable):
def __init__(self, return_type: str, *args, **kwargs):
self.return_type = return_type
super(SavingTrainable, self).__init__(*args, **kwargs)

def save_checkpoint(self, tmp_checkpoint_dir: str):
checkpoint_data = {"data": 1}

if self.return_type == "object":
return checkpoint_data

subdir = os.path.join(tmp_checkpoint_dir, "subdir")
os.makedirs(subdir, exist_ok=True)
checkpoint_file = os.path.join(subdir, "checkpoint.pkl")
with open(checkpoint_file, "w") as f:
f.write(json.dumps(checkpoint_data))

if self.return_type == "root":
return tmp_checkpoint_dir
elif self.return_type == "subdir":
return subdir
elif self.return_type == "checkpoint":
return checkpoint_file

def load_checkpoint(self, checkpoint: Union[Dict, str]):
if self.return_type == "object":
assert isinstance(checkpoint, dict)
checkpoint_data = checkpoint
checkpoint_file = None
elif self.return_type == "root":
assert "subdir" not in checkpoint
checkpoint_file = os.path.join(checkpoint, "subdir", "checkpoint.pkl")
elif self.return_type == "subdir":
assert "subdir" in checkpoint
assert "checkpoint.pkl" not in checkpoint
checkpoint_file = os.path.join(checkpoint, "checkpoint.pkl")
else: # self.return_type == "checkpoint"
assert checkpoint.endswith("subdir/checkpoint.pkl")
checkpoint_file = checkpoint

if checkpoint_file:
with open(checkpoint_file, "rb") as f:
checkpoint_data = json.load(f)

assert checkpoint_data == {"data": 1}, checkpoint_data


def function_trainable_dict(config):
session.report(
{"metric": 2}, checkpoint=Checkpoint.from_dict({"checkpoint_data": 3})
)


def function_trainable_directory(config):
tmpdir = tempfile.mkdtemp("checkpoint_test")
with open(os.path.join(tmpdir, "data.json"), "w") as f:
json.dump({"checkpoint_data": 5}, f)
session.report({"metric": 4}, checkpoint=Checkpoint.from_directory(tmpdir))


@pytest.mark.parametrize("return_type", ["object", "root", "subdir", "checkpoint"])
def test_save_load_checkpoint_path_class(ray_start_2_cpus, return_type):
trainable = ray.remote(SavingTrainable).remote(return_type=return_type)

saving_future = trainable.save.remote()

# Check for errors
ray.get(saving_future)

restoring_future = trainable.restore.remote(saving_future)

ray.get(restoring_future)


@pytest.mark.parametrize("return_type", ["object", "root", "subdir", "checkpoint"])
def test_save_load_checkpoint_object_class(ray_start_2_cpus, return_type):
trainable = ray.remote(SavingTrainable).remote(return_type=return_type)

saving_future = trainable.save_to_object.remote()

# Check for errors
ray.get(saving_future)

restoring_future = trainable.restore_from_object.remote(saving_future)

ray.get(restoring_future)


@pytest.mark.parametrize(
"fn_trainable", [function_trainable_dict, function_trainable_directory]
)
def test_save_load_checkpoint_path_fn(ray_start_2_cpus, fn_trainable):
trainable_cls = wrap_function(fn_trainable)
trainable = ray.remote(trainable_cls).remote()
ray.get(trainable.train.remote())

saving_future = trainable.save.remote()

# Check for errors
ray.get(saving_future)

restoring_future = trainable.restore.remote(saving_future)

ray.get(restoring_future)


@pytest.mark.parametrize(
"fn_trainable", [function_trainable_dict, function_trainable_directory]
)
def test_save_load_checkpoint_object_fn(ray_start_2_cpus, fn_trainable):
trainable_cls = wrap_function(fn_trainable)
trainable = ray.remote(trainable_cls).remote()
ray.get(trainable.train.remote())

saving_future = trainable.save_to_object.remote()

# Check for errors
ray.get(saving_future)

restoring_future = trainable.restore_from_object.remote(saving_future)

ray.get(restoring_future)


if __name__ == "__main__":
import sys

sys.exit(pytest.main(["-v", __file__]))
4 changes: 2 additions & 2 deletions python/ray/tune/tests/test_trial_runner_2.py
@@ -224,7 +224,7 @@ def testCheckpointing(self):
self.assertEqual(trials[0].status, Trial.TERMINATED)
self.assertEqual(trials[1].status, Trial.RUNNING)
self.assertEqual(ray.get(trials[1].runner.get_info.remote()), 1)
self.addCleanup(os.remove, trials[0].checkpoint.dir_or_data)
self.addCleanup(shutil.rmtree, trials[0].checkpoint.dir_or_data)

def testRestoreMetricsAfterCheckpointing(self):
ray.init(num_cpus=1, num_gpus=1)
@@ -263,7 +263,7 @@ def testRestoreMetricsAfterCheckpointing(self):
self.assertEqual(trials[1].last_result["timesteps_since_restore"], 20)
self.assertEqual(trials[1].last_result["iterations_since_restore"], 2)
self.assertGreater(trials[1].last_result["time_since_restore"], 0)
self.addCleanup(os.remove, trials[0].checkpoint.dir_or_data)
self.addCleanup(shutil.rmtree, trials[0].checkpoint.dir_or_data)

def testCheckpointingAtEnd(self):
ray.init(num_cpus=1, num_gpus=1)
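The cleanup switches from os.remove to shutil.rmtree because trial checkpoints are now tracked as directories rather than single files; a tiny sketch of the difference, using a made-up checkpoint layout:

import os
import shutil
import tempfile

# os.remove only deletes single files and fails on a directory;
# shutil.rmtree deletes the whole checkpoint directory tree.
checkpoint_dir = tempfile.mkdtemp(prefix="checkpoint_000001_")
open(os.path.join(checkpoint_dir, "checkpoint"), "w").close()

shutil.rmtree(checkpoint_dir)
assert not os.path.exists(checkpoint_dir)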
