Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[tune] Better handling of tune.function in global checkpoint #4519

Merged
merged 4 commits into from
Apr 5, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 14 additions & 2 deletions python/ray/tune/suggest/variant_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ class sample_from(object):
"""Specify that tune should sample configuration values from this function.

The use of function arguments in tune configs must be disambiguated by
either wrapped the function in tune.eval() or tune.function().
either wrapping the function in tune.sample_from() or tune.function().

Arguments:
func: A callable to draw a sample from.
Expand All @@ -67,12 +67,18 @@ class sample_from(object):
def __init__(self, func):
self.func = func

def __str__(self):
return "tune.sample_from({})".format(str(self.func))

def __repr__(self):
return "tune.sample_from({})".format(repr(self.func))


class function(object):
"""Wraps `func` to make sure it is not expanded during resolution.

The use of function arguments in tune configs must be disambiguated by
either wrapped the function in tune.eval() or tune.function().
either wrapping the function in tune.sample_from() or tune.function().

Arguments:
func: A function literal.
Expand All @@ -84,6 +90,12 @@ def __init__(self, func):
def __call__(self, *args, **kwargs):
return self.func(*args, **kwargs)

def __str__(self):
return "tune.function({})".format(str(self.func))

def __repr__(self):
return "tune.function({})".format(repr(self.func))


_STANDARD_IMPORTS = {
"random": random,
Expand Down
2 changes: 0 additions & 2 deletions python/ray/tune/trial.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,10 +310,8 @@ def __init__(self,

self._nonjson_fields = [
"_checkpoint",
"config",
"loggers",
"sync_function",
"last_result",
"results",
"best_result",
"param_config",
Expand Down
28 changes: 26 additions & 2 deletions python/ray/tune/trial_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,15 @@
import time
import traceback

import ray.cloudpickle as cloudpickle
from ray.tune import TuneError
from ray.tune.ray_trial_executor import RayTrialExecutor
from ray.tune.result import TIME_THIS_ITER_S, RESULT_DUPLICATE
from ray.tune.trial import Trial, Checkpoint
from ray.tune.suggest import function
from ray.tune.schedulers import FIFOScheduler, TrialScheduler
from ray.tune.util import warn_if_slow
from ray.utils import binary_to_hex, hex_to_binary
from ray.tune.web_server import TuneServer

MAX_DEBUG_TRIALS = 20
Expand All @@ -39,6 +42,27 @@ def _find_newest_ckpt(ckpt_dir):
return max(full_paths)


class _TuneFunctionEncoder(json.JSONEncoder):
    """JSON encoder that can serialize ``function`` wrapper objects.

    Instances of ``function`` are emitted as a two-key marker dict,
    ``{"_type": "function", "value": <encoded pickle>}``. Every other object
    that the stock encoder cannot handle is delegated to the base class.
    """

    def default(self, obj):
        """Return a JSON-serializable stand-in for ``obj``.

        Arguments:
            obj: The object the standard encoder could not serialize.

        Returns:
            A marker dict for ``function`` instances; otherwise whatever
            ``json.JSONEncoder.default`` returns (it raises ``TypeError``
            for unsupported types).
        """
        # Anything that is not a wrapped function is not ours to handle.
        if not isinstance(obj, function):
            return super(_TuneFunctionEncoder, self).default(obj)

        # Pickle the wrapper, then turn the bytes into text so they can be
        # embedded in a JSON document (presumably a hex string, judging by
        # the helper's name -- confirm against ray.utils).
        pickled = cloudpickle.dumps(obj)
        return {"_type": "function", "value": binary_to_hex(pickled)}


class _TuneFunctionDecoder(json.JSONDecoder):
    """JSON decoder that restores pickled marker objects.

    Every decoded JSON object is passed through ``object_hook``. Objects of
    the form ``{"_type": "function", "value": <encoded pickle>}`` are
    replaced by the unpickled payload; all other objects are returned
    unchanged.

    NOTE(review): the payload is loaded with ``cloudpickle.loads``, which can
    execute arbitrary code. Only decode documents from trusted locations.
    """

    def __init__(self, *args, **kwargs):
        """Build a decoder whose ``object_hook`` is ``self.object_hook``.

        Arguments:
            *args: Forwarded to ``json.JSONDecoder.__init__``.
            **kwargs: Forwarded to ``json.JSONDecoder.__init__``. Passing
                ``object_hook`` here raises ``TypeError`` because the hook
                is always supplied by this class.
        """
        super(_TuneFunctionDecoder, self).__init__(
            *args, object_hook=self.object_hook, **kwargs)

    def object_hook(self, obj):
        """Convert a marker dict back into the object it encodes.

        Arguments:
            obj (dict): A JSON object that has just been decoded.

        Returns:
            The unpickled payload when ``obj`` carries both the
            ``"_type": "function"`` tag and a ``"value"`` entry; otherwise
            ``obj`` itself.
        """
        # Require the payload key as well as the tag: a plain user dict that
        # merely happens to contain "_type": "function" must pass through
        # untouched instead of raising KeyError.
        if obj.get("_type") == "function" and "value" in obj:
            # hex_to_binary presumably reverses the text encoding applied
            # when the document was written -- confirm against ray.utils.
            return cloudpickle.loads(hex_to_binary(obj["value"]))
        return obj


class TrialRunner(object):
"""A TrialRunner implements the event loop for scheduling trials on Ray.

Expand Down Expand Up @@ -150,7 +174,7 @@ def checkpoint(self):
tmp_file_name = os.path.join(metadata_checkpoint_dir,
".tmp_checkpoint")
with open(tmp_file_name, "w") as f:
json.dump(runner_state, f, indent=2)
json.dump(runner_state, f, indent=2, cls=_TuneFunctionEncoder)

os.rename(
tmp_file_name,
Expand Down Expand Up @@ -183,7 +207,7 @@ def restore(cls,

newest_ckpt_path = _find_newest_ckpt(metadata_checkpoint_dir)
with open(newest_ckpt_path, "r") as f:
runner_state = json.load(f)
runner_state = json.load(f, cls=_TuneFunctionDecoder)

logger.warning("".join([
"Attempting to resume experiment from {}. ".format(
Expand Down