-
Notifications
You must be signed in to change notification settings - Fork 5.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[tune] Initial track integration #4362
Changes from 2 commits
5d936d4
f991f21
7431e4b
982600b
ee34398
eb67943
e7c0e22
c56f7e9
792f1de
581a57d
dd69ed6
c7e1579
a7d023a
6f7ba56
6cec7cc
2cade6d
6d897b8
bdd01ff
8b678e5
1c1ee5f
e737a33
28d0283
923284f
8037f58
f52c3f3
b81f6f8
350188e
41075fd
9ce0403
732fd12
38444ae
9e13309
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
import pickle | ||
|
||
from .trial import Trial | ||
from .project import Project | ||
from .log import debug | ||
from .convenience import absl_flags | ||
|
||
|
||
_trial = None | ||
|
||
|
||
def init(log_dir=None, | ||
upload_dir=None, | ||
sync_period=None, | ||
trial_prefix="", | ||
param_map=None, | ||
init_logging=True): | ||
""" | ||
Initializes the global trial context for this process. | ||
This creates a Trial object and the corresponding hooks for logging. | ||
""" | ||
global _trial # pylint: disable=global-statement | ||
if _trial: | ||
# TODO: would be nice to stack crawl at creation time to report | ||
# where that initial trial was created, and that creation line | ||
# info is helpful to keep around anyway. | ||
raise ValueError("A trial already exists in the current context") | ||
local_trial = Trial( | ||
log_dir=log_dir, | ||
upload_dir=upload_dir, | ||
sync_period=sync_period, | ||
trial_prefix=trial_prefix, | ||
param_map=param_map, | ||
init_logging=True) | ||
# try: | ||
_trial = local_trial | ||
_trial.start() | ||
|
||
|
||
def shutdown(): | ||
""" | ||
Cleans up the trial and removes it from the global context. | ||
""" | ||
global _trial # pylint: disable=global-statement | ||
if not _trial: | ||
raise ValueError("Tried to stop trial, but no trial exists") | ||
_trial.close() | ||
_trial = None | ||
|
||
|
||
def save(obj, obj_name, iteration=None, save_fn=pickle.dump, **kwargs): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. write save/load/artifact test cases |
||
""" Applies Trial.save to the trial in the current context """ | ||
return _trial.save(obj=obj, obj_name=obj_name, iteration=iteration, | ||
save_fn=save_fn, **kwargs) | ||
|
||
|
||
def metric(*, iteration=None, **kwargs): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. make python2-compatible |
||
"""Applies Trial.metric to the trial in the current context.""" | ||
return _trial.metric(iteration=iteration, **kwargs) | ||
|
||
|
||
def load(obj_name, iteration=None, load_fn=pickle.load, **kwargs): | ||
"""Applies Trial.load to the trial in the current context.""" | ||
return _trial.load(obj_name=obj_name, iteration=iteration, | ||
load_fn=load_fn, **kwargs) | ||
|
||
|
||
def trial_dir(): | ||
"""Retrieves the trial directory for the trial in the current context.""" | ||
return _trial.trial_dir() | ||
|
||
|
||
__all__ = ["Trial", "Project", "trial", "absl_flags", "debug", "metric", | ||
"save", "load", "trial_dir"] |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
""" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. cut for now, should be modularized from "track.trial" |
||
Hacky, library and usage specific tricks to infer decent defaults. | ||
""" | ||
import os | ||
import subprocess | ||
import sys | ||
import shlex | ||
|
||
|
||
def git_repo(): | ||
""" | ||
Returns the git repository root if the cwd is in a repo, else None | ||
""" | ||
try: | ||
with open(os.devnull, 'wb') as quiet: | ||
reldir = subprocess.check_output( | ||
["git", "rev-parse", "--git-dir"], | ||
stdout=quiet) | ||
reldir = reldir.decode("utf-8") | ||
return os.path.basename(os.path.dirname(os.path.abspath(reldir))) | ||
except subprocess.CalledProcessError: | ||
return None | ||
|
||
|
||
def git_hash(): | ||
"""returns the current git hash or unknown if not in git repo""" | ||
if git_repo() is None: | ||
return "unknown" | ||
git_hash = subprocess.check_output( | ||
["git", "rev-parse", "HEAD"]) | ||
# git_hash is a byte string; we want a string. | ||
git_hash = git_hash.decode("utf-8") | ||
# git_hash also comes with an extra \n at the end, which we remove. | ||
git_hash = git_hash.strip() | ||
return git_hash | ||
|
||
def git_pretty(): | ||
"""returns a pretty summary of the commit or unkown if not in git repo""" | ||
if git_repo() is None: | ||
return "unknown" | ||
pretty = subprocess.check_output( | ||
["git", "log", "--pretty=format:%h %s", "-n", "1"]) | ||
pretty = pretty.decode("utf-8") | ||
pretty = pretty.strip() | ||
return pretty | ||
|
||
def invocation(): | ||
"""reconstructs the invocation for this python program""" | ||
cmdargs = [sys.executable] + sys.argv[:] | ||
invocation = " ".join(shlex.quote(s) for s in cmdargs) | ||
return invocation |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
""" | ||
Both locally and remotely the directory structure is as follows: | ||
|
||
project-directory/ | ||
METADATA_FOLDER/ | ||
trialprefix_uuid_param_map.json | ||
trialprefix_uuid_result.json | ||
... other trials | ||
trailprefix_uuid/ | ||
... trial artifacts | ||
... other trial artifact folders | ||
|
||
Where the param map is a single json containing the trial uuid | ||
and configuration parameters and, the result.json is a json | ||
list file (i.e., not valid json, but valid json on each line), | ||
and the artifacts folder contains the artifacts as supplied by | ||
the user. | ||
""" | ||
import os | ||
|
||
METADATA_FOLDER = "trials" | ||
CONFIG_SUFFIX = "param_map.json" | ||
RESULT_SUFFIX = "result.json" |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
""" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. cut |
||
Miscellaneous helpers for getting some of the arguments to tracking-related | ||
functions automatically, usually involving parameter extraction in a | ||
sensible default way from commonly used libraries. | ||
""" | ||
|
||
import sys | ||
from absl import flags | ||
|
||
def absl_flags(): | ||
""" | ||
Extracts absl-py flags that the user has specified and outputs their | ||
key-value mapping. | ||
|
||
By default, extracts only those flags in the current __package__ | ||
and mainfile. | ||
|
||
Useful to put into a trial's param_map. | ||
""" | ||
# TODO: need same thing for argparse | ||
flags_dict = flags.FLAGS.flags_by_module_dict() | ||
# only include parameters from modules the user probably cares about | ||
def _relevant_module(module_name): | ||
if __package__ and __package__ in module_name: | ||
return True | ||
if module_name == sys.argv[0]: | ||
return True | ||
return False | ||
return { | ||
flag.name: flag.value for module, flags in flags_dict.items() | ||
for flag in flags if _relevant_module(module)} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
from __future__ import absolute_import | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. cut |
||
from __future__ import division | ||
from __future__ import print_function | ||
|
||
|
||
class TrackError(Exception): | ||
"""General error class raised by Track""" | ||
pass |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
""" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. cut (for now) |
||
File adopted from Michael Whittaker. | ||
|
||
This module firstly offers a convenience function called "debug" which | ||
alleviates a couple of inconveniences in python logging: | ||
|
||
* No need to find a logger before logging (uses the one from this package) | ||
* Slightly friendly string interpolation interface. | ||
""" | ||
|
||
from datetime import datetime | ||
import inspect | ||
import hashlib | ||
import subprocess | ||
import os | ||
import shlex | ||
import json | ||
import sys | ||
import logging | ||
|
||
class TrackLogHandler(logging.FileHandler): | ||
"""File-based logging handler for the track package""" | ||
pass | ||
|
||
class StdoutHandler(logging.StreamHandler): | ||
"""As described by the name""" | ||
def __init__(self): | ||
super().__init__(sys.stdout) | ||
|
||
def init(track_log_handler): | ||
""" | ||
(Re)initialize track's file handler for track package logger. | ||
|
||
Adds a stdout-printing handler automatically. | ||
""" | ||
|
||
logger = logging.getLogger(__package__) | ||
|
||
# TODO (just document prominently) | ||
# assume only one trial can run at once right now | ||
# multi-concurrent-trial support will require complex filter logic | ||
# based on the currently-running trial (maybe we shouldn't allow multiple | ||
# trials on different python threads, that's dumb) | ||
to_rm = [h for h in logger.handlers if isinstance(h, TrackLogHandler)] | ||
for h in to_rm: | ||
logger.removeHandler(h) | ||
|
||
if not any(isinstance(h, StdoutHandler) for h in logger.handlers): | ||
handler = StdoutHandler() | ||
handler.setFormatter(_FORMATTER) | ||
logger.addHandler(handler) | ||
|
||
track_log_handler.setFormatter(_FORMATTER) | ||
logger.addHandler(track_log_handler) | ||
|
||
logger.propagate = False | ||
logger.setLevel(logging.DEBUG) | ||
|
||
def debug(s, *args): | ||
"""debug(s, x1, ..., xn) logs s.format(x1, ..., xn).""" | ||
# Get the path name and line number of the function which called us. | ||
previous_frame = inspect.currentframe().f_back | ||
try: | ||
pathname, lineno, _, _, _ = inspect.getframeinfo(previous_frame) | ||
# if path is in cwd, simplify it | ||
cwd = os.path.abspath(os.getcwd()) | ||
pathname = os.path.abspath(pathname) | ||
if os.path.commonprefix([cwd, pathname]) == cwd: | ||
pathname = os.path.relpath(pathname, cwd) | ||
except Exception: # pylint: disable=broad-except | ||
pathname = '<UNKNOWN-FILE>.py' | ||
lineno = 0 | ||
if _FORMATTER: # log could have not been initialized. | ||
_FORMATTER.pathname = pathname | ||
_FORMATTER.lineno = lineno | ||
logger = logging.getLogger(__package__) | ||
logger.debug(s.format(*args)) | ||
|
||
class _StackCrawlingFormatter(logging.Formatter): | ||
""" | ||
If we configure a python logger with the format string | ||
"%(pathname):%(lineno): %(message)", messages logged via `log.debug` will | ||
be prefixed with the path name and line number of the code that called | ||
`log.debug`. Unfortunately, when a `log.debug` call is wrapped in a helper | ||
function (e.g. debug below), the path name and line number is always that | ||
of the helper function, not the function which called the helper function. | ||
|
||
A _StackCrawlingFormatter is a hack to log a different pathname and line | ||
number. Simply set the `pathname` and `lineno` attributes of the formatter | ||
before you call `log.debug`. See `debug` below for an example. | ||
""" | ||
|
||
def __init__(self, format_str): | ||
super().__init__(format_str) | ||
self.pathname = None | ||
self.lineno = None | ||
|
||
def format(self, record): | ||
s = super().format(record) | ||
if self.pathname is not None: | ||
s = s.replace('{pathname}', self.pathname) | ||
if self.lineno is not None: | ||
s = s.replace('{lineno}', str(self.lineno)) | ||
return s | ||
|
||
_FORMAT_STRING = "[%(asctime)-15s {pathname}:{lineno}] %(message)s" | ||
_FORMATTER = _StackCrawlingFormatter(_FORMAT_STRING) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
cut absl integration