Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[tune] Initial track integration #4362

Merged
merged 32 commits into from
May 17, 2019
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
5d936d4
initial track integration
richardliaw Mar 14, 2019
f991f21
initial pull from track repo
noahgolmant Mar 14, 2019
7431e4b
cut extraneous sync/log/project code
noahgolmant Mar 14, 2019
982600b
small_cleanup
richardliaw Mar 22, 2019
ee34398
Session
richardliaw Mar 31, 2019
eb67943
nit
richardliaw Mar 31, 2019
e7c0e22
Merge branch 'master' into tune_track_integration
richardliaw Mar 31, 2019
c56f7e9
nit
richardliaw Apr 2, 2019
792f1de
remove git
noahgolmant May 4, 2019
581a57d
Integration for functionrunner
richardliaw May 4, 2019
dd69ed6
Merge branch 'master' into tune_track_integration
richardliaw May 4, 2019
c7e1579
use unifiedlogger for json data; save/load gone
noahgolmant May 4, 2019
a7d023a
Merge branch 'tune_track_integration' of https://github.com/noahgolma…
noahgolmant May 4, 2019
6f7ba56
fix to use tune unified logger; add initial test cases
noahgolmant May 4, 2019
6cec7cc
formatting
richardliaw May 10, 2019
2cade6d
Enums
richardliaw May 10, 2019
6d897b8
Reformat tracking
richardliaw May 10, 2019
bdd01ff
full cleanup
richardliaw May 10, 2019
8b678e5
lint
richardliaw May 10, 2019
1c1ee5f
Fix up tests
richardliaw May 10, 2019
e737a33
some formatting
richardliaw May 10, 2019
28d0283
Param, fix up metric test
richardliaw May 11, 2019
923284f
Merge branch 'master' into tune_track_integration
richardliaw May 11, 2019
8037f58
fix up for example
richardliaw May 11, 2019
f52c3f3
Fix up example and test
richardliaw May 11, 2019
b81f6f8
Cleanup
richardliaw May 13, 2019
350188e
lint
richardliaw May 16, 2019
41075fd
localdir
richardliaw May 16, 2019
9ce0403
fix
richardliaw May 16, 2019
732fd12
comments
richardliaw May 16, 2019
38444ae
safer track inspection
richardliaw May 17, 2019
9e13309
lint
richardliaw May 17, 2019
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 74 additions & 0 deletions python/ray/tune/track/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
import pickle

from .trial import Trial
from .project import Project
from .log import debug
from .convenience import absl_flags
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cut absl integration



# The trial active in this process; assigned by init() and reset to None by
# shutdown(). The module-level helpers below delegate to this object.
_trial = None


def init(log_dir=None,
         upload_dir=None,
         sync_period=None,
         trial_prefix="",
         param_map=None,
         init_logging=True):
    """Initialize the global trial context for this process.

    Creates a ``Trial`` from the given arguments, stores it as the
    module-level current trial and starts it.

    Args:
        log_dir: forwarded unchanged to ``Trial``.
        upload_dir: forwarded unchanged to ``Trial``.
        sync_period: forwarded unchanged to ``Trial``.
        trial_prefix: forwarded unchanged to ``Trial``.
        param_map: forwarded unchanged to ``Trial``.
        init_logging: forwarded unchanged to ``Trial``.

    Raises:
        ValueError: if a trial already exists in the current context.
    """
    global _trial  # pylint: disable=global-statement
    if _trial:
        # TODO: would be nice to stack crawl at creation time to report
        # where that initial trial was created, and that creation line
        # info is helpful to keep around anyway.
        raise ValueError("A trial already exists in the current context")
    local_trial = Trial(
        log_dir=log_dir,
        upload_dir=upload_dir,
        sync_period=sync_period,
        trial_prefix=trial_prefix,
        param_map=param_map,
        # Honour the caller's choice; this used to be hard-coded to True,
        # which silently ignored the ``init_logging`` argument.
        init_logging=init_logging)
    # Publish the trial only once construction has succeeded, so a failing
    # constructor does not leave a stale global behind.
    _trial = local_trial
    _trial.start()


def shutdown():
    """Close the current trial and clear it from the global context.

    Raises:
        ValueError: if there is no trial to stop.
    """
    global _trial  # pylint: disable=global-statement
    active = _trial
    if not active:
        raise ValueError("Tried to stop trial, but no trial exists")
    active.close()
    # Reset so that a later init() is allowed to create a fresh trial.
    _trial = None


def save(obj, obj_name, iteration=None, save_fn=pickle.dump, **kwargs):
    """Apply ``Trial.save`` to the trial in the current context.

    Every argument, including any extra keyword arguments, is forwarded
    unchanged and the result of ``Trial.save`` is returned as-is.
    """
    # TODO(review): write save/load/artifact test cases
    current = _trial
    return current.save(
        obj=obj,
        obj_name=obj_name,
        iteration=iteration,
        save_fn=save_fn,
        **kwargs)


def metric(iteration=None, **kwargs):
    """Apply ``Trial.metric`` to the trial in the current context.

    ``iteration`` is declared as an ordinary keyword argument instead of a
    keyword-only one: the bare ``*`` separator is Python 3 only syntax and
    would make this module fail to import under Python 2. Callers that pass
    ``iteration=...`` by keyword are unaffected.

    Args:
        iteration: forwarded to ``Trial.metric`` as ``iteration``.
        **kwargs: forwarded unchanged to ``Trial.metric``.

    Returns:
        Whatever ``Trial.metric`` returns.
    """
    return _trial.metric(iteration=iteration, **kwargs)


def load(obj_name, iteration=None, load_fn=pickle.load, **kwargs):
    """Apply ``Trial.load`` to the trial in the current context.

    Every argument, including any extra keyword arguments, is forwarded
    unchanged and the result of ``Trial.load`` is returned as-is.

    NOTE: the default ``load_fn`` is ``pickle.load``, which can execute
    arbitrary code when given untrusted data; only load trusted artifacts.
    """
    current = _trial
    return current.load(
        obj_name=obj_name,
        iteration=iteration,
        load_fn=load_fn,
        **kwargs)


def trial_dir():
    """Return ``Trial.trial_dir()`` for the trial in the current context."""
    current = _trial
    return current.trial_dir()


# Names exported by a star-import of this package. ``init`` and ``shutdown``
# are the entry points that create and tear down the global trial, so they
# are exported alongside the per-trial helpers.
# NOTE(review): "trial" presumably refers to the ``.trial`` submodule that is
# bound on the package by ``from .trial import Trial`` -- confirm.
__all__ = ["Trial", "Project", "trial", "absl_flags", "debug", "init",
           "shutdown", "metric", "save", "load", "trial_dir"]
51 changes: 51 additions & 0 deletions python/ray/tune/track/autodetect.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
"""
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cut for now, should be modularized from "track.trial"

Hacky, library and usage specific tricks to infer decent defaults.
"""
import os
import subprocess
import sys
import shlex


def git_repo():
    """Return the name of the enclosing git repository, or None.

    The name is the basename of the directory that contains the git
    directory reported by ``git rev-parse --git-dir``.

    Returns:
        The repository directory name as a string, or None when the
        current working directory is not inside a git repository or the
        ``git`` executable cannot be run.
    """
    try:
        with open(os.devnull, 'wb') as quiet:
            # check_output captures stdout itself and raises ValueError if
            # a ``stdout`` argument is passed; only stderr may be silenced.
            reldir = subprocess.check_output(
                ["git", "rev-parse", "--git-dir"],
                stderr=quiet)
    except (subprocess.CalledProcessError, OSError):
        # CalledProcessError: git exited non-zero (not inside a repo).
        # OSError: the git executable could not be started.
        return None
    # The output is a byte string ending with a newline.
    reldir = reldir.decode("utf-8").strip()
    return os.path.basename(os.path.dirname(os.path.abspath(reldir)))


def git_hash():
    """Return the current git commit hash, or "unknown".

    "unknown" is returned when the working directory is not in a git
    repository, and also when ``git rev-parse HEAD`` itself fails (for
    example in a repository that has no commits yet).
    """
    if git_repo() is None:
        return "unknown"
    try:
        with open(os.devnull, 'wb') as quiet:
            raw_hash = subprocess.check_output(
                ["git", "rev-parse", "HEAD"],
                stderr=quiet)
    except (subprocess.CalledProcessError, OSError):
        return "unknown"
    # The output is a byte string with a trailing newline; return clean text.
    return raw_hash.decode("utf-8").strip()

def git_pretty():
    """Return a one-line summary of the current commit, or "unknown".

    The summary is the abbreviated hash followed by the commit subject
    (``git log --pretty=format:%h %s -n 1``). "unknown" is returned when
    the working directory is not in a git repository, and also when the
    ``git log`` call itself fails (for example when there are no commits).
    """
    if git_repo() is None:
        return "unknown"
    try:
        with open(os.devnull, 'wb') as quiet:
            raw_summary = subprocess.check_output(
                ["git", "log", "--pretty=format:%h %s", "-n", "1"],
                stderr=quiet)
    except (subprocess.CalledProcessError, OSError):
        return "unknown"
    # The output is a byte string; decode and trim surrounding whitespace.
    return raw_summary.decode("utf-8").strip()

def invocation():
    """Rebuild the shell command line used to start this Python program.

    The interpreter path and every entry of ``sys.argv`` are shell-quoted
    and joined with single spaces.
    """
    quoted_parts = []
    for part in [sys.executable] + list(sys.argv):
        quoted_parts.append(shlex.quote(part))
    return " ".join(quoted_parts)
23 changes: 23 additions & 0 deletions python/ray/tune/track/constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
"""
Both locally and remotely the directory structure is as follows:

project-directory/
METADATA_FOLDER/
trialprefix_uuid_param_map.json
trialprefix_uuid_result.json
... other trials
trialprefix_uuid/
... trial artifacts
... other trial artifact folders

Where the param map is a single json containing the trial uuid
and configuration parameters, the result.json is a json-lines
file (i.e., not valid json as a whole, but valid json on each line),
and the artifacts folder contains the artifacts as supplied by
the user.
"""
import os

# Name of the folder inside the project directory that holds the per-trial
# metadata files (see the layout in the module docstring).
METADATA_FOLDER = "trials"
# Filename suffix of a trial's parameter-map json file.
CONFIG_SUFFIX = "param_map.json"
# Filename suffix of a trial's result file (one json document per line).
RESULT_SUFFIX = "result.json"
31 changes: 31 additions & 0 deletions python/ray/tune/track/convenience.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
"""
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cut

Miscellaneous helpers for getting some of the arguments to tracking-related
functions automatically, usually involving parameter extraction in a
sensible default way from commonly used libraries.
"""

import sys
from absl import flags

def absl_flags():
    """Collect absl-py flags into a ``{flag name: flag value}`` mapping.

    Only flags registered by modules the user probably cares about are
    included: modules whose name contains the current ``__package__`` and
    the module whose name equals ``sys.argv[0]``.

    Useful to put into a trial's param_map.
    """
    # TODO: need same thing for argparse
    flags_by_module = flags.FLAGS.flags_by_module_dict()

    def _relevant_module(module_name):
        # only include parameters from modules the user probably cares about
        in_package = bool(__package__) and __package__ in module_name
        return in_package or module_name == sys.argv[0]

    selected = {}
    for module, module_flags in flags_by_module.items():
        if not _relevant_module(module):
            continue
        for flag in module_flags:
            selected[flag.name] = flag.value
    return selected
8 changes: 8 additions & 0 deletions python/ray/tune/track/error.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from __future__ import absolute_import
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cut

from __future__ import division
from __future__ import print_function


class TrackError(Exception):
    """Generic exception type raised by the track package."""
107 changes: 107 additions & 0 deletions python/ray/tune/track/log.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
"""
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cut (for now)

File adopted from Michael Whittaker.

This module firstly offers a convenience function called "debug" which
alleviates a couple of inconveniences in python logging:

* No need to find a logger before logging (uses the one from this package)
* Slightly friendly string interpolation interface.
"""

from datetime import datetime
import inspect
import hashlib
import subprocess
import os
import shlex
import json
import sys
import logging

class TrackLogHandler(logging.FileHandler):
    """File-based logging handler for the track package.

    Adds no behaviour of its own; the distinct type lets ``init`` find and
    remove previously installed track file handlers via ``isinstance``.
    """

class StdoutHandler(logging.StreamHandler):
    """Stream handler that always writes to ``sys.stdout``."""

    def __init__(self):
        # sys.stdout is looked up at construction time.
        super(StdoutHandler, self).__init__(sys.stdout)

def init(track_log_handler):
    """(Re)install track's file handler on the track package logger.

    Any TrackLogHandler already attached to the logger is removed, a
    StdoutHandler is added if the logger does not have one yet, and then
    ``track_log_handler`` is attached. Both handlers use the module-level
    formatter. The logger stops propagating to its parents and is set to
    DEBUG level.
    """
    logger = logging.getLogger(__package__)

    # TODO (just document prominently)
    # assume only one trial can run at once right now
    # multi-concurrent-trial support will require complex filter logic
    # based on the currently-running trial (maybe we shouldn't allow multiple
    # trials on different python threads, that's dumb)
    stale_handlers = [
        h for h in logger.handlers if isinstance(h, TrackLogHandler)
    ]
    for stale in stale_handlers:
        logger.removeHandler(stale)

    has_stdout_handler = False
    for existing in logger.handlers:
        if isinstance(existing, StdoutHandler):
            has_stdout_handler = True
            break
    if not has_stdout_handler:
        stdout_handler = StdoutHandler()
        stdout_handler.setFormatter(_FORMATTER)
        logger.addHandler(stdout_handler)

    track_log_handler.setFormatter(_FORMATTER)
    logger.addHandler(track_log_handler)

    logger.propagate = False
    logger.setLevel(logging.DEBUG)

def debug(s, *args):
    """debug(s, x1, ..., xn) logs s.format(x1, ..., xn).

    The message is logged at DEBUG level on the track package logger. The
    caller's file path and line number are stored on the module-level
    formatter beforehand so they appear in the formatted record; paths
    located under the current working directory are shown relative to it.

    Args:
        s: format string, interpolated with ``str.format``.
        *args: positional values substituted into ``s``.
    """
    previous_frame = None
    try:
        # Get the path name and line number of the function which called us.
        previous_frame = inspect.currentframe().f_back
        pathname, lineno, _, _, _ = inspect.getframeinfo(previous_frame)
        # if path is in cwd, simplify it
        cwd = os.path.abspath(os.getcwd())
        pathname = os.path.abspath(pathname)
        # Compare against cwd plus a trailing separator so only whole path
        # components match; os.path.commonprefix works character by
        # character and would treat /a/foobar as being inside /a/foo.
        if pathname.startswith(os.path.join(cwd, "")):
            pathname = os.path.relpath(pathname, cwd)
    except Exception:  # pylint: disable=broad-except
        # Best effort only: never let caller lookup break logging.
        pathname = '<UNKNOWN-FILE>.py'
        lineno = 0
    finally:
        # Drop the frame reference so it cannot keep a reference cycle alive.
        previous_frame = None
    if _FORMATTER:  # log could have not been initialized.
        _FORMATTER.pathname = pathname
        _FORMATTER.lineno = lineno
    logger = logging.getLogger(__package__)
    logger.debug(s.format(*args))

class _StackCrawlingFormatter(logging.Formatter):
"""
If we configure a python logger with the format string
"%(pathname):%(lineno): %(message)", messages logged via `log.debug` will
be prefixed with the path name and line number of the code that called
`log.debug`. Unfortunately, when a `log.debug` call is wrapped in a helper
function (e.g. debug below), the path name and line number is always that
of the helper function, not the function which called the helper function.

A _StackCrawlingFormatter is a hack to log a different pathname and line
number. Simply set the `pathname` and `lineno` attributes of the formatter
before you call `log.debug`. See `debug` below for an example.
"""

def __init__(self, format_str):
super().__init__(format_str)
self.pathname = None
self.lineno = None

def format(self, record):
s = super().format(record)
if self.pathname is not None:
s = s.replace('{pathname}', self.pathname)
if self.lineno is not None:
s = s.replace('{lineno}', str(self.lineno))
return s

# Format applied to every track log record. "{pathname}" and "{lineno}" are
# literal placeholders that _StackCrawlingFormatter.format substitutes after
# the standard %-style fields have been rendered.
_FORMAT_STRING = "[%(asctime)-15s {pathname}:{lineno}] %(message)s"
# Shared formatter instance: debug() sets its pathname/lineno before logging
# and init() attaches it to the handlers it installs.
_FORMATTER = _StackCrawlingFormatter(_FORMAT_STRING)
Loading