Skip to content

Commit

Permalink
snapshot for extra repos (#1454)
Browse files Browse the repository at this point in the history
* snapshot for extra repos

* revert a change

* revision and diff for extra repos
  • Loading branch information
hnyu authored Mar 29, 2023
1 parent 028f9bb commit fc5f8fd
Show file tree
Hide file tree
Showing 6 changed files with 112 additions and 68 deletions.
12 changes: 6 additions & 6 deletions alf/bin/grid_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,12 +327,11 @@ def _worker(self, root_dir, parameters, device_queue):
time.sleep(random.uniform(0, 3))

conf_file = common.get_conf_file()
# This is the snapshot stored in grid-search root dir
alf_repo = common.abs_path(
os.path.join(FLAGS.root_dir, "alf.tar.gz"))

# We still need to keep the repo snapshots (every ``*.tar.gz``) at ``<root_dir>``
# for playing an individual search job later
os.system(f"mkdir -p {root_dir}; cp {alf_repo} {root_dir}/")
os.system(f"mkdir -p {root_dir}; "
f"cp {FLAGS.root_dir}/*.tar.gz {root_dir}/")

device = device_queue.get()
if self._conf.use_gpu:
Expand Down Expand Up @@ -402,9 +401,10 @@ def launch_snapshot_gridsearch():
common.write_config(root_dir, common.read_conf_file(root_dir))

# generate snapshots of ALF and any extra repos as ``<root_dir>/<name>.tar.gz``
common.generate_alf_root_snapshot(common.alf_root(), root_dir)
common.generate_alf_snapshot(common.alf_root(), conf_file, root_dir)

# point the grid search to the snapshot paths
# point the grid search to the snapshot paths, in case the code has been
# changed when launching a job in the queue
env_vars = common.get_alf_snapshot_env_vars(root_dir)

# remove the conf file option since we will retrieve it from ``root_dir``
Expand Down
3 changes: 0 additions & 3 deletions alf/bin/play.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,9 +182,6 @@ def launch_snapshot_play():
"""
# assert the current path is not ALF_ROOT because sys.path will always prepend
# the current path to the path list, which makes our snapshot ALF path shadowed
assert not common.is_alf_root(
os.getcwd()), ("Play with a snapshot is not allowed under ALF root!")

root_dir = common.abs_path(FLAGS.root_dir)

env_vars = common.get_alf_snapshot_env_vars(root_dir)
Expand Down
6 changes: 3 additions & 3 deletions alf/bin/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,11 +197,11 @@ def main(_):
root_dir = common.abs_path(FLAGS.root_dir)
os.makedirs(root_dir, exist_ok=True)

if FLAGS.store_snapshot:
common.generate_alf_root_snapshot(common.alf_root(), root_dir)

conf_file = common.get_conf_file()

if FLAGS.store_snapshot:
common.generate_alf_snapshot(common.alf_root(), conf_file, root_dir)

# FLAGS.distributed is guaranteed to be one of the possible values.
if FLAGS.distributed == 'none':
training_worker(
Expand Down
16 changes: 14 additions & 2 deletions alf/trainers/policy_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,8 +393,20 @@ def _markdownify(paragraph):
alf.summary.text(
'unoptimized_parameters',
_markdownify(self._algorithm.get_unoptimized_parameter_info()))
alf.summary.text('revision', git_utils.get_revision())
alf.summary.text('diff', _markdownify(git_utils.get_diff()))

repo_roots = {
**common._extra_repo_roots_,
**{
'alf': common.alf_root()
}
}
for name, root in repo_roots.items():
alf.summary.text(f'{name}/revision',
git_utils.get_revision(f'{root}/{name}'))
alf.summary.text(
f'{name}/diff',
_markdownify(git_utils.get_diff(f'{root}/{name}')))

alf.summary.text('seed', str(self._random_seed))

# Save a rendered directed graph of the algorithm to the root
Expand Down
115 changes: 73 additions & 42 deletions alf/utils/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
import traceback
import types
from typing import Callable, List
import runpy

import alf
from alf.algorithms.config import TrainerConfig
Expand Down Expand Up @@ -1412,7 +1413,22 @@ def get_all_parameters(obj):
return all_parameters


def generate_alf_root_snapshot(alf_root, dest_path):
# Maps an importable module name (e.g. 'hobot') to the parent directory of that
# module. Entries are added through ``add_snapshot_repo_root()``.
_extra_repo_roots_ = {}


def add_snapshot_repo_root(root: str, module_name: str) -> None:
    """Call this function to add any other repo root dir which needs to be
    included in the alf snapshot.

    Registering the same ``module_name`` again overwrites the previous root.

    Args:
        root: the parent path of the repo module.
        module_name: the module name for importing in python, e.g., 'alf',
            'hobot'.
    """
    # The registry dict is mutated in place (never rebound), so no ``global``
    # declaration is required here.
    _extra_repo_roots_[module_name] = root


def generate_alf_snapshot(alf_root: str, conf_file: str, dest_path: str):
"""Given a destination path, copy the local ALF root dir to the path. To
save disk space, only files matching ``*.py``, ``*.gin``, ``*.so``, ``*.json``
and ``*.xml`` will be copied.
Expand All @@ -1421,8 +1437,9 @@ def generate_alf_root_snapshot(alf_root, dest_path):
model or launching a grid-search job in the waiting queue.
Args:
alf_root (str): the parent path of the 'alf' module
dest_path (str): the path to generate a snapshot of ALF repo
alf_root: the parent path of the 'alf' module
conf_file: the alf config file
dest_path: the path to generate a snapshot of ALF repo
"""

def _is_subdir(path, directory):
Expand All @@ -1438,57 +1455,71 @@ def rsync(src, target, includes):
subprocess.check_call(
" ".join(args), stdout=sys.stdout, stderr=sys.stdout, shell=True)

assert not _is_subdir(dest_path, alf_root), (
"Snapshot path '%s' is not allowed under ALF root! Use a different one!"
% dest_path)

# these files are important for code status
includes = ["*.py", "*.gin", "*.so", "*.json"]
# Only copy the 'alf' module dir because the root dir might contain many
# other modules in the case where alf is pip installed in 'site-packages'.
rsync(alf_root + '/alf', dest_path, includes)

# compress the snapshot repo into a ".tar.gz" file
os.system("cd %s; tar -czf alf.tar.gz alf" % dest_path)
os.system("rm -rf %s/alf" % dest_path)
if conf_file.endswith('.py'):
# Currently, snapshot generation is always before calling ``parse_conf_file``
# which will generate an environment in each subprocess.
# So we need to first (informally) run the conf file so that any extra repo
# can have a chance of calling ``add_snapshot_repo_root()`` in their
# ``__init__.py``.
# After this, we need to reset all configs.
# .gin config is not supported here.
runpy.run_path(conf_file)
alf.reset_configs()

includes = ["*.py", "*.gin", "*.so", "*.json", "*.xml"]
repo_roots = {**_extra_repo_roots_, **{'alf': alf_root}}
for name, root in repo_roots.items():
assert not _is_subdir(dest_path, root), (
"Snapshot path '%s' is not allowed under any repo root '%s'! " %
(dest_path, root) + "Use a different one!")
# Only copy the module dir because the root dir might contain many
# other modules in the case where repo is pip installed in 'site-packages'.
rsync(root + f'/{name}', dest_path, includes)
# compress the snapshot repo into a ".tar.gz" file
os.system(
f"cd {dest_path}; tar -czf {name}.tar.gz {name}; rm -rf {name}")


def unzip_alf_snapshot(root_dir: str):
"""Restore an ALF snapshot from a job directory by unzipping the snapshot
'tar' file.
'tar.gz' files.
Args:
root_dir: the tensorboard job directory
"""
alf_zipped_repo = os.path.join(root_dir, "alf.tar.gz")
alf_repo = os.path.join(root_dir, "alf")
if os.path.isfile(alf_zipped_repo):
info("=== Using an ALF snapshot at '%s' ===", alf_zipped_repo)
os.system("rm -rf %s/alf" % root_dir)
os.system("cd %s; tar -xzf alf.tar.gz" % root_dir)
elif os.path.isdir(alf_repo):
# To be backward compatible of snapshots as an unzipped dirs
info("=== Using an ALF snapshot at '%s' ===", alf_repo)
else:
info("=== Didn't find a snapshot; using update-to-date ALF ===")
module_names = []
for zipped_repo in glob.glob(f"{root_dir}/*.tar.gz"):
# assuming all '*.tar.gz' under root_dir are repo snapshots
name = os.path.basename(zipped_repo).split('.')[0]
info("=== Using an ALF snapshot at '%s' ===", zipped_repo)
os.system(f"rm -rf {root_dir}/{name}")
os.system(f"cd {root_dir}; tar -xzf {name}.tar.gz")
module_names.append(name)
return module_names


def get_alf_snapshot_env_vars(root_dir):
    """Given a ``root_dir``, return modified env variable dict so that
    ``PYTHONPATH`` points to the repo snapshots under this directory.

    All ``*.tar.gz`` snapshots under ``root_dir`` are first unpacked via
    ``unzip_alf_snapshot()``. For every unpacked module, ``root_dir`` is
    prepended to ``PYTHONPATH``; for the ``alf`` module its ``alf/examples``
    directory is prepended as well.

    Args:
        root_dir: the job directory containing the snapshot(s).
    Returns:
        the environment variables with ``PYTHONPATH`` updated.
    Raises:
        AssertionError: if the current working directory contains any of the
            snapshot modules (i.e., it is a valid repo root for that module).
    """
    module_names = unzip_alf_snapshot(root_dir)
    python_path = os.environ.get("PYTHONPATH", "")
    for name in module_names:
        # ``sys.path`` always starts with the current path, so a repo under
        # the current path would shadow the snapshot.
        assert not is_repo_root(os.getcwd(), name), (
            "Using a snapshot is not allowed under a valid repo root: " +
            "'%s' (contains '%s')!" % (os.getcwd(), name) +
            " Try running the command in a different directory.")
        root = root_dir
        if name == "alf":
            legacy_alf_root = os.path.join(root, "alf")
            # legacy tb dirs: root_dir/alf/alf/__init__.py
            # new tb dirs: root_dir/alf/__init__.py
            # ``legacy_alf_root/alf`` is a *directory* in the legacy layout,
            # so ``isfile()`` on it could never succeed; probe the package's
            # ``__init__.py`` instead.
            if os.path.isfile(
                    os.path.join(legacy_alf_root, "alf", "__init__.py")):
                # legacy alf repo path for backward compatibility
                root = legacy_alf_root
            alf_examples = os.path.join(root, "alf/examples")
            python_path = ":".join([root, alf_examples, python_path])
        else:
            python_path = ":".join([root, python_path])
    # NOTE(review): ``copy.copy(os.environ)`` appears to create another
    # ``os._Environ`` backed by the same underlying data, so the update below
    # most likely also changes this process's own ``PYTHONPATH``. Confirm
    # whether any caller relies on that before switching to
    # ``os.environ.copy()``.
    env_vars = copy.copy(os.environ)
    env_vars.update({"PYTHONPATH": python_path})
    return env_vars
Expand All @@ -1512,11 +1543,11 @@ def alf_root():
return _alf_root


def is_alf_root(dir):
"""Given a directory, check if it is a valid ALF root. Currently the way
def is_repo_root(dir, module_name):
"""Given a directory, check if it is a valid repo root. Currently the way
of checking is to see if there is a valid module ``__init__.py`` under it.
"""
return os.path.isfile(os.path.join(dir, 'alf/__init__.py'))
return os.path.isfile(os.path.join(dir, f'{module_name}/__init__.py'))


def compute_summary_or_eval_interval(config, summary_or_eval_calls=100):
Expand Down
28 changes: 16 additions & 12 deletions alf/utils/git_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,31 +15,35 @@
import os


def _get_repo_root():
"""Get ALF repo root path."""
return os.path.join(os.path.dirname(__file__), "..", "..")


def _exec(command, module_root):
    """Run a shell command inside ``module_root`` and return its stdout.

    Args:
        command: the shell command line to execute.
        module_root: the directory in which the command is run.
    Returns:
        str: everything the command wrote to stdout (stderr is not captured).
    """
    # Function-scope import: this module only imports ``os`` at file level.
    import subprocess

    # ``cwd=`` confines the directory change to the child process. The former
    # ``os.chdir()`` round-trip changed the working directory of the whole
    # process and did not restore it if running the command raised.
    result = subprocess.run(
        command,
        shell=True,
        cwd=module_root,
        stdout=subprocess.PIPE,
        universal_newlines=True)
    return result.stdout


def get_revision():
"""Get the current revision of ALF at HEAD."""
return _exec("git rev-parse HEAD").strip()
def get_revision(module_root: str):
"""Get the current revision of a python module at HEAD.
Args:
module_root: the path to the module root
"""
return _exec("git rev-parse HEAD", module_root).strip()


def get_diff():
def get_diff(module_root: str):
"""Get the diff of a python module at HEAD.
The diff is restricted to modified files (``--diff-filter=M``) and ignores
file-mode changes (``core.fileMode=false``).
If the repo is clean, the returned value is an empty string.
Args:
module_root: the path to the module root
Returns:
current diff.
"""
return _exec("git -c core.fileMode=false diff --diff-filter=M")
return _exec("git -c core.fileMode=false diff --diff-filter=M",
module_root)

0 comments on commit fc5f8fd

Please sign in to comment.