Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[tune] Initial track integration #4362

Merged
merged 32 commits into from
May 17, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
5d936d4
initial track integration
richardliaw Mar 14, 2019
f991f21
initial pull from track repo
noahgolmant Mar 14, 2019
7431e4b
cut extraneous sync/log/project code
noahgolmant Mar 14, 2019
982600b
small_cleanup
richardliaw Mar 22, 2019
ee34398
Session
richardliaw Mar 31, 2019
eb67943
nit
richardliaw Mar 31, 2019
e7c0e22
Merge branch 'master' into tune_track_integration
richardliaw Mar 31, 2019
c56f7e9
nit
richardliaw Apr 2, 2019
792f1de
remove git
noahgolmant May 4, 2019
581a57d
Integration for functionrunner
richardliaw May 4, 2019
dd69ed6
Merge branch 'master' into tune_track_integration
richardliaw May 4, 2019
c7e1579
use unifiedlogger for json data; save/load gone
noahgolmant May 4, 2019
a7d023a
Merge branch 'tune_track_integration' of https://github.com/noahgolma…
noahgolmant May 4, 2019
6f7ba56
fix to use tune unified logger; add initial test cases
noahgolmant May 4, 2019
6cec7cc
formatting
richardliaw May 10, 2019
2cade6d
Enums
richardliaw May 10, 2019
6d897b8
Reformat tracking
richardliaw May 10, 2019
bdd01ff
full cleanup
richardliaw May 10, 2019
8b678e5
lint
richardliaw May 10, 2019
1c1ee5f
Fix up tests
richardliaw May 10, 2019
e737a33
some formatting
richardliaw May 10, 2019
28d0283
Param, fix up metric test
richardliaw May 11, 2019
923284f
Merge branch 'master' into tune_track_integration
richardliaw May 11, 2019
8037f58
fix up for example
richardliaw May 11, 2019
f52c3f3
Fix up example and test
richardliaw May 11, 2019
b81f6f8
Cleanup
richardliaw May 13, 2019
350188e
lint
richardliaw May 16, 2019
41075fd
localdir
richardliaw May 16, 2019
9ce0403
fix
richardliaw May 16, 2019
732fd12
comments
richardliaw May 16, 2019
38444ae
safer track inspection
richardliaw May 17, 2019
9e13309
lint
richardliaw May 17, 2019
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion python/ray/tune/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,5 @@
__all__ = [
"Trainable", "TuneError", "grid_search", "register_env",
"register_trainable", "run", "run_experiments", "Experiment", "function",
"sample_from", "uniform", "choice", "randint", "randn"
"sample_from", "track", "uniform", "choice", "randint", "randn"
]
6 changes: 3 additions & 3 deletions python/ray/tune/automlboard/backend/collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from ray.tune.automlboard.models.models import JobRecord, \
TrialRecord, ResultRecord
from ray.tune.result import DEFAULT_RESULTS_DIR, JOB_META_FILE, \
EXPR_PARARM_FILE, EXPR_RESULT_FILE, EXPR_META_FILE
EXPR_PARAM_FILE, EXPR_RESULT_FILE, EXPR_META_FILE


class CollectorService(object):
Expand Down Expand Up @@ -327,7 +327,7 @@ def _build_trial_meta(cls, expr_dir):
if not meta:
job_id = expr_dir.split("/")[-2]
trial_id = expr_dir[-8:]
params = parse_json(os.path.join(expr_dir, EXPR_PARARM_FILE))
params = parse_json(os.path.join(expr_dir, EXPR_PARAM_FILE))
meta = {
"trial_id": trial_id,
"job_id": job_id,
Expand All @@ -349,7 +349,7 @@ def _build_trial_meta(cls, expr_dir):
if meta.get("end_time", None):
meta["end_time"] = timestamp2date(meta["end_time"])

meta["params"] = parse_json(os.path.join(expr_dir, EXPR_PARARM_FILE))
meta["params"] = parse_json(os.path.join(expr_dir, EXPR_PARAM_FILE))

return meta

Expand Down
71 changes: 71 additions & 0 deletions python/ray/tune/examples/track_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import keras
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers import (Dense, Dropout, Flatten, Conv2D, MaxPooling2D)

from ray.tune import track
from ray.tune.examples.utils import TuneKerasCallback, get_mnist_data

# Command-line interface for the example. parse_known_args() lets the script
# tolerate extra flags (e.g. ones injected by launchers or test runners).
parser = argparse.ArgumentParser()
parser.add_argument(
    "--smoke-test", action="store_true", help="Finish quickly for testing")

# The two float hyperparameters share the same argument shape, so declare
# them from a small table instead of repeating the add_argument boilerplate.
_FLOAT_FLAGS = [
    ("--lr", 0.01, "LR", "learning rate (default: 0.01)"),
    ("--momentum", 0.5, "M", "SGD momentum (default: 0.5)"),
]
for _flag, _default, _metavar, _help in _FLOAT_FLAGS:
    parser.add_argument(
        _flag, type=float, default=_default, metavar=_metavar, help=_help)

parser.add_argument(
    "--hidden", type=int, default=64, help="Size of hidden layer.")
args, _ = parser.parse_known_args()


def train_mnist(args):
    """Train a small CNN on MNIST, reporting metrics via ray.tune.track.

    Args:
        args: Parsed CLI namespace providing ``lr``, ``momentum``,
            ``hidden`` and ``smoke_test`` attributes.
    """
    # Register this run with the track API; the CLI arguments become the
    # logged trial configuration.
    track.init(trial_name="track-example", trial_config=vars(args))
    batch_size = 128
    num_classes = 10
    epochs = 1 if args.smoke_test else 12
    # BUG FIX: the original called `mnist.load()`, but `keras.datasets.mnist`
    # exposes only `load_data()`, so that line raised AttributeError at
    # runtime. `get_mnist_data()` below already loads the dataset.
    x_train, y_train, x_test, y_test, input_shape = get_mnist_data()

    model = Sequential()
    model.add(
        Conv2D(
            32, kernel_size=(3, 3), activation="relu",
            input_shape=input_shape))
    model.add(Conv2D(64, (3, 3), activation="relu"))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Dropout(0.5))
    model.add(Flatten())
    model.add(Dense(args.hidden, activation="relu"))
    model.add(Dropout(0.5))
    model.add(Dense(num_classes, activation="softmax"))

    model.compile(
        loss="categorical_crossentropy",
        optimizer=keras.optimizers.SGD(lr=args.lr, momentum=args.momentum),
        metrics=["accuracy"])

    # NOTE(review): `track.metric` is not defined in the track/__init__.py
    # shown in this change (only `track.log`) — confirm the attribute exists.
    model.fit(
        x_train,
        y_train,
        batch_size=batch_size,
        epochs=epochs,
        validation_data=(x_test, y_test),
        callbacks=[TuneKerasCallback(track.metric)])
    track.shutdown()


# Allow running the example directly: `python track_example.py [--smoke-test]`.
if __name__ == "__main__":
    train_mnist(args)
4 changes: 3 additions & 1 deletion python/ray/tune/examples/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@ def __init__(self, reporter, logs={}):

def on_train_end(self, epoch, logs={}):
self.reporter(
timesteps_total=self.iteration, done=1, mean_accuracy=logs["acc"])
timesteps_total=self.iteration,
done=1,
mean_accuracy=logs.get("acc"))

def on_batch_end(self, batch, logs={}):
self.iteration += 1
Expand Down
23 changes: 22 additions & 1 deletion python/ray/tune/function_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@
import logging
import sys
import time
import inspect
import threading
from six.moves import queue

from ray.tune import track
from ray.tune import TuneError
from ray.tune.trainable import Trainable
from ray.tune.result import TIME_THIS_ITER_S, RESULT_DUPLICATE
Expand Down Expand Up @@ -244,6 +246,17 @@ def _report_thread_runner_error(self, block=False):


def wrap_function(train_func):

use_track = False
try:
func_args = inspect.getargspec(train_func).args
use_track = ("reporter" not in func_args and len(func_args) == 1)
if use_track:
logger.info("tune.track signature detected.")
except Exception:
logger.info(
"Function inspection failed - assuming reporter signature.")

class WrappedFunc(FunctionRunner):
def _trainable_func(self, config, reporter):
output = train_func(config, reporter)
Expand All @@ -253,4 +266,12 @@ def _trainable_func(self, config, reporter):
reporter(**{RESULT_DUPLICATE: True})
return output

return WrappedFunc
class WrappedTrackFunc(FunctionRunner):
def _trainable_func(self, config, reporter):
track.init(_tune_reporter=reporter)
output = train_func(config)
reporter(**{RESULT_DUPLICATE: True})
track.shutdown()
return output

return WrappedTrackFunc if use_track else WrappedFunc
30 changes: 19 additions & 11 deletions python/ray/tune/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,11 @@ def on_result(self, result):

raise NotImplementedError

    def update_config(self, config):
        """Updates the logger's stored trial config.

        No-op by default; subclasses that persist the config to disk
        (e.g. JsonLogger) override this hook.
        """

        pass

def close(self):
"""Releases all resources used by this logger."""

Expand All @@ -68,17 +73,7 @@ def on_result(self, result):

class JsonLogger(Logger):
def _init(self):
config_out = os.path.join(self.logdir, "params.json")
with open(config_out, "w") as f:
json.dump(
self.config,
f,
indent=2,
sort_keys=True,
cls=_SafeFallbackEncoder)
config_pkl = os.path.join(self.logdir, "params.pkl")
with open(config_pkl, "wb") as f:
cloudpickle.dump(self.config, f)
self.update_config(self.config)
local_file = os.path.join(self.logdir, "result.json")
self.local_out = open(local_file, "a")

Expand All @@ -96,6 +91,15 @@ def flush(self):
def close(self):
self.local_out.close()

    def update_config(self, config):
        """Persist `config` for this trial as params.json and params.pkl.

        Called from _init() and whenever the trial config is updated;
        rewrites both files in place under self.logdir.
        """
        self.config = config
        config_out = os.path.join(self.logdir, "params.json")
        with open(config_out, "w") as f:
            # NOTE(review): unlike the previous inline dump this writes
            # compact JSON (no indent=2 / sort_keys) — confirm downstream
            # readers only parse, never diff, this file.
            json.dump(self.config, f, cls=_SafeFallbackEncoder)
        config_pkl = os.path.join(self.logdir, "params.pkl")
        with open(config_pkl, "wb") as f:
            # Pickle copy preserves values that are not JSON-serializable.
            cloudpickle.dump(self.config, f)


def to_tf_values(result, path):
values = []
Expand Down Expand Up @@ -231,6 +235,10 @@ def on_result(self, result):
self._log_syncer.set_worker_ip(result.get(NODE_IP))
self._log_syncer.sync_if_needed()

def update_config(self, config):
for _logger in self._loggers:
_logger.update_config(config)

def close(self):
for _logger in self._loggers:
_logger.close()
Expand Down
2 changes: 1 addition & 1 deletion python/ray/tune/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@
EXPR_META_FILE = "trial_status.json"

# File that stores parameters of the trial.
EXPR_PARARM_FILE = "params.json"
EXPR_PARAM_FILE = "params.json"

# File that stores the progress of the trial.
EXPR_PROGRESS_FILE = "progress.csv"
Expand Down
84 changes: 84 additions & 0 deletions python/ray/tune/tests/test_track.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import pandas as pd
import unittest

import ray
from ray import tune
from ray.tune import track
from ray.tune.result import EXPR_PARAM_FILE, EXPR_RESULT_FILE


def _check_json_val(fname, key, val):
    """Return True iff newline-delimited JSON file `fname` has a column
    `key` whose most recent (last-line) value equals `val`."""
    with open(fname, "r") as result_file:
        frame = pd.read_json(result_file, typ="frame", lines=True)
    if key not in frame.columns:
        return False
    return (frame[key].tail(n=1) == val).all()


class TrackApiTest(unittest.TestCase):
    def tearDown(self):
        # Tear down both the track session and ray so each test starts clean.
        track.shutdown()
        ray.shutdown()

    def testSessionInitShutdown(self):
        """init()/shutdown() must create and destroy the global singleton."""
        self.assertTrue(track._session is None)

        # Checks that the singleton _session is created/destroyed
        # by track.init() and track.shutdown()
        for _ in range(2):
            # do it twice to see that we can reopen the session
            track.init(trial_name="test_init")
            self.assertTrue(track._session is not None)
            track.shutdown()
            self.assertTrue(track._session is None)

    def testLogCreation(self):
        """Checks that track.init() starts logger and creates log files."""
        track.init(trial_name="test_init")
        session = track.get_session()
        self.assertTrue(session is not None)

        # init() must have created the trial's log directory.
        self.assertTrue(os.path.isdir(session.logdir))

        # Both the parameter dump and the result log should exist
        # immediately after init(), before anything is logged.
        params_path = os.path.join(session.logdir, EXPR_PARAM_FILE)
        result_path = os.path.join(session.logdir, EXPR_RESULT_FILE)

        self.assertTrue(os.path.exists(params_path))
        self.assertTrue(os.path.exists(result_path))
        # track.trial_dir() is the public accessor for session.logdir.
        self.assertTrue(session.logdir == track.trial_dir())

    def testMetric(self):
        """track.log() should append values to the trial's result file."""
        track.init(trial_name="test_log")
        session = track.get_session()
        for i in range(5):
            track.log(test=i)
        # After the loop, `i` is 4: the last line of result.json should
        # carry the final logged value.
        result_path = os.path.join(session.logdir, EXPR_RESULT_FILE)
        self.assertTrue(_check_json_val(result_path, "test", i))

def testRayOutput(self):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the main usage showcase of track -> ray integration

"""Checks that local and remote log format are the same."""
ray.init()

def testme(config):
for i in range(config["iters"]):
track.log(iteration=i, hi="test")

trials = tune.run(testme, config={"iters": 5})
trial_res = trials[0].last_result
self.assertTrue(trial_res["hi"], "test")
self.assertTrue(trial_res["training_iteration"], 5)

    def testLocalMetrics(self):
        """Checks that metric state is updated correctly."""
        track.init(trial_name="test_logs")
        session = track.get_session()
        # A bare init() records only the auto-generated trial_id.
        self.assertEqual(set(session.trial_config.keys()), {"trial_id"})

        result_path = os.path.join(session.logdir, EXPR_RESULT_FILE)
        track.log(test=1)
        self.assertTrue(_check_json_val(result_path, "test", 1))
        # A second log() call appends a new line; the latest value wins.
        track.log(iteration=1, test=2)
        self.assertTrue(_check_json_val(result_path, "test", 2))
71 changes: 71 additions & 0 deletions python/ray/tune/track/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import logging

from ray.tune.track.session import TrackSession

logger = logging.getLogger(__name__)

# Process-wide singleton TrackSession; managed by init()/shutdown().
_session = None


def get_session():
    """Return the active TrackSession, raising if none has been created."""
    global _session
    if _session:
        return _session
    raise ValueError("Session not detected. Try `track.init()`?")


def init(ignore_reinit_error=True, **session_kwargs):
    """Create the process-wide TrackSession used by the track API.

    If a session already exists, either warn and return (the default) or
    raise, depending on `ignore_reinit_error`.

    Examples:
        >>> from ray.tune import track
        >>> track.init()
    """
    global _session

    if _session:
        # TODO(ng): would be nice to stack crawl at creation time to report
        # where that initial trial was created, and that creation line
        # info is helpful to keep around anyway.
        reinit_msg = "A session already exists in the current context."
        if not ignore_reinit_error:
            raise ValueError(reinit_msg)
        # Re-entry from inside a tune run is expected; only warn for
        # plain (non-tune) sessions.
        if not _session.is_tune_session:
            logger.warning(reinit_msg)
        return

    _session = TrackSession(**session_kwargs)


def shutdown():
    """Close the active session (if any) and clear the global slot."""
    global _session
    if not _session:
        return
    _session.close()
    _session = None


def log(**kwargs):
    """Record the given metrics on the session in the current context."""
    return get_session().log(**kwargs)


def trial_dir():
    """Returns the directory where trial results are saved.

    This includes json data containing the session's parameters and metrics.
    """
    return get_session().logdir


__all__ = ["TrackSession", "session", "log", "trial_dir", "init", "shutdown"]
Loading