-
Notifications
You must be signed in to change notification settings - Fork 5.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[tune] Initial track integration #4362
Merged
richardliaw
merged 32 commits into
ray-project:master
from
noahgolmant:tune_track_integration
May 17, 2019
Merged
Changes from all commits
Commits
Show all changes
32 commits
Select commit
Hold shift + click to select a range
5d936d4
initial track integration
richardliaw f991f21
initial pull from track repo
noahgolmant 7431e4b
cut extraneous sync/log/project code
noahgolmant 982600b
small_cleanup
richardliaw ee34398
Session
richardliaw eb67943
nit
richardliaw e7c0e22
Merge branch 'master' into tune_track_integration
richardliaw c56f7e9
nit
richardliaw 792f1de
remove git
noahgolmant 581a57d
Integration for functionrunner
richardliaw dd69ed6
Merge branch 'master' into tune_track_integration
richardliaw c7e1579
use unifiedlogger for json data; save/load gone
noahgolmant a7d023a
Merge branch 'tune_track_integration' of https://github.com/noahgolma…
noahgolmant 6f7ba56
fix to use tune unified logger; add initial test cases
noahgolmant 6cec7cc
formatting
richardliaw 2cade6d
Enums
richardliaw 6d897b8
Reformat tracking
richardliaw bdd01ff
full cleanup
richardliaw 8b678e5
lint
richardliaw 1c1ee5f
Fix up tests
richardliaw e737a33
some formatting
richardliaw 28d0283
Param, fix up metric test
richardliaw 923284f
Merge branch 'master' into tune_track_integration
richardliaw 8037f58
fix up for example
richardliaw f52c3f3
Fix up example and test
richardliaw b81f6f8
Cleanup
richardliaw 350188e
lint
richardliaw 41075fd
localdir
richardliaw 9ce0403
fix
richardliaw 732fd12
comments
richardliaw 38444ae
safer track inspection
richardliaw 9e13309
lint
richardliaw File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
from __future__ import absolute_import | ||
from __future__ import division | ||
from __future__ import print_function | ||
|
||
import argparse | ||
import keras | ||
from keras.datasets import mnist | ||
from keras.models import Sequential | ||
from keras.layers import (Dense, Dropout, Flatten, Conv2D, MaxPooling2D) | ||
|
||
from ray.tune import track | ||
from ray.tune.examples.utils import TuneKerasCallback, get_mnist_data | ||
|
||
parser = argparse.ArgumentParser() | ||
parser.add_argument( | ||
"--smoke-test", action="store_true", help="Finish quickly for testing") | ||
parser.add_argument( | ||
"--lr", | ||
type=float, | ||
default=0.01, | ||
metavar="LR", | ||
help="learning rate (default: 0.01)") | ||
parser.add_argument( | ||
"--momentum", | ||
type=float, | ||
default=0.5, | ||
metavar="M", | ||
help="SGD momentum (default: 0.5)") | ||
parser.add_argument( | ||
"--hidden", type=int, default=64, help="Size of hidden layer.") | ||
args, _ = parser.parse_known_args() | ||
|
||
|
||
def train_mnist(args): | ||
track.init(trial_name="track-example", trial_config=vars(args)) | ||
batch_size = 128 | ||
num_classes = 10 | ||
epochs = 1 if args.smoke_test else 12 | ||
mnist.load() | ||
x_train, y_train, x_test, y_test, input_shape = get_mnist_data() | ||
|
||
model = Sequential() | ||
model.add( | ||
Conv2D( | ||
32, kernel_size=(3, 3), activation="relu", | ||
input_shape=input_shape)) | ||
model.add(Conv2D(64, (3, 3), activation="relu")) | ||
model.add(MaxPooling2D(pool_size=(2, 2))) | ||
model.add(Dropout(0.5)) | ||
model.add(Flatten()) | ||
model.add(Dense(args.hidden, activation="relu")) | ||
model.add(Dropout(0.5)) | ||
model.add(Dense(num_classes, activation="softmax")) | ||
|
||
model.compile( | ||
loss="categorical_crossentropy", | ||
optimizer=keras.optimizers.SGD(lr=args.lr, momentum=args.momentum), | ||
metrics=["accuracy"]) | ||
|
||
model.fit( | ||
x_train, | ||
y_train, | ||
batch_size=batch_size, | ||
epochs=epochs, | ||
validation_data=(x_test, y_test), | ||
callbacks=[TuneKerasCallback(track.metric)]) | ||
track.shutdown() | ||
|
||
|
||
if __name__ == "__main__": | ||
train_mnist(args) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
from __future__ import absolute_import | ||
from __future__ import division | ||
from __future__ import print_function | ||
|
||
import os | ||
import pandas as pd | ||
import unittest | ||
|
||
import ray | ||
from ray import tune | ||
from ray.tune import track | ||
from ray.tune.result import EXPR_PARAM_FILE, EXPR_RESULT_FILE | ||
|
||
|
||
def _check_json_val(fname, key, val): | ||
with open(fname, "r") as f: | ||
df = pd.read_json(f, typ="frame", lines=True) | ||
return key in df.columns and (df[key].tail(n=1) == val).all() | ||
|
||
|
||
class TrackApiTest(unittest.TestCase): | ||
def tearDown(self): | ||
track.shutdown() | ||
ray.shutdown() | ||
|
||
def testSessionInitShutdown(self): | ||
self.assertTrue(track._session is None) | ||
|
||
# Checks that the singleton _session is created/destroyed | ||
# by track.init() and track.shutdown() | ||
for _ in range(2): | ||
# do it twice to see that we can reopen the session | ||
track.init(trial_name="test_init") | ||
self.assertTrue(track._session is not None) | ||
track.shutdown() | ||
self.assertTrue(track._session is None) | ||
|
||
def testLogCreation(self): | ||
"""Checks that track.init() starts logger and creates log files.""" | ||
track.init(trial_name="test_init") | ||
session = track.get_session() | ||
self.assertTrue(session is not None) | ||
|
||
self.assertTrue(os.path.isdir(session.logdir)) | ||
|
||
params_path = os.path.join(session.logdir, EXPR_PARAM_FILE) | ||
result_path = os.path.join(session.logdir, EXPR_RESULT_FILE) | ||
|
||
self.assertTrue(os.path.exists(params_path)) | ||
self.assertTrue(os.path.exists(result_path)) | ||
self.assertTrue(session.logdir == track.trial_dir()) | ||
|
||
def testMetric(self): | ||
track.init(trial_name="test_log") | ||
session = track.get_session() | ||
for i in range(5): | ||
track.log(test=i) | ||
result_path = os.path.join(session.logdir, EXPR_RESULT_FILE) | ||
self.assertTrue(_check_json_val(result_path, "test", i)) | ||
|
||
def testRayOutput(self): | ||
"""Checks that local and remote log format are the same.""" | ||
ray.init() | ||
|
||
def testme(config): | ||
for i in range(config["iters"]): | ||
track.log(iteration=i, hi="test") | ||
|
||
trials = tune.run(testme, config={"iters": 5}) | ||
trial_res = trials[0].last_result | ||
self.assertTrue(trial_res["hi"], "test") | ||
self.assertTrue(trial_res["training_iteration"], 5) | ||
|
||
def testLocalMetrics(self): | ||
"""Checks that metric state is updated correctly.""" | ||
track.init(trial_name="test_logs") | ||
session = track.get_session() | ||
self.assertEqual(set(session.trial_config.keys()), {"trial_id"}) | ||
|
||
result_path = os.path.join(session.logdir, EXPR_RESULT_FILE) | ||
track.log(test=1) | ||
self.assertTrue(_check_json_val(result_path, "test", 1)) | ||
track.log(iteration=1, test=2) | ||
self.assertTrue(_check_json_val(result_path, "test", 2)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
from __future__ import absolute_import | ||
from __future__ import division | ||
from __future__ import print_function | ||
|
||
import logging | ||
|
||
from ray.tune.track.session import TrackSession | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
_session = None | ||
|
||
|
||
def get_session(): | ||
global _session | ||
if not _session: | ||
raise ValueError("Session not detected. Try `track.init()`?") | ||
return _session | ||
|
||
|
||
def init(ignore_reinit_error=True, **session_kwargs): | ||
"""Initializes the global trial context for this process. | ||
|
||
This creates a TrackSession object and the corresponding hooks for logging. | ||
|
||
Examples: | ||
>>> from ray.tune import track | ||
>>> track.init() | ||
""" | ||
global _session | ||
|
||
if _session: | ||
# TODO(ng): would be nice to stack crawl at creation time to report | ||
# where that initial trial was created, and that creation line | ||
# info is helpful to keep around anyway. | ||
reinit_msg = "A session already exists in the current context." | ||
if ignore_reinit_error: | ||
if not _session.is_tune_session: | ||
logger.warning(reinit_msg) | ||
return | ||
else: | ||
raise ValueError(reinit_msg) | ||
|
||
_session = TrackSession(**session_kwargs) | ||
|
||
|
||
def shutdown(): | ||
"""Cleans up the trial and removes it from the global context.""" | ||
|
||
global _session | ||
if _session: | ||
_session.close() | ||
_session = None | ||
|
||
|
||
def log(**kwargs): | ||
"""Applies TrackSession.log to the trial in the current context.""" | ||
_session = get_session() | ||
return _session.log(**kwargs) | ||
|
||
|
||
def trial_dir(): | ||
"""Returns the directory where trial results are saved. | ||
|
||
This includes json data containing the session's parameters and metrics. | ||
""" | ||
_session = get_session() | ||
return _session.logdir | ||
|
||
|
||
__all__ = ["TrackSession", "session", "log", "trial_dir", "init", "shutdown"] |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is the main usage showcase of track -> ray integration