Skip to content

Commit

Permalink
Unify implementation of track and team repositories
Browse files Browse the repository at this point in the history
Closes #308
  • Loading branch information
danielmitterdorfer committed Aug 8, 2017
1 parent 69c06e9 commit 4105331
Show file tree
Hide file tree
Showing 5 changed files with 354 additions and 164 deletions.
82 changes: 13 additions & 69 deletions esrally/mechanic/team.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
import os
import sys
import logging
import configparser

import tabulate

from esrally import exceptions, PROGRAM_NAME
from esrally.utils import console, git, versions, io
from esrally.utils import console, repo, io

logger = logging.getLogger("rally.team")

Expand Down Expand Up @@ -57,78 +56,23 @@ def name_and_config(p):

def team_repo(cfg, update=True):
distribution_version = cfg.opts("mechanic", "distribution.version", mandatory=False)
repo = TeamRepository(cfg)
repo_name = cfg.opts("mechanic", "repository.name")
offline = cfg.opts("system", "offline.mode")
remote_url = cfg.opts("teams", "%s.url" % repo_name, mandatory=False)
root = cfg.opts("node", "root.dir")
team_repositories = cfg.opts("mechanic", "team.repository.dir")
teams_dir = os.path.join(root, team_repositories)

current_team_repo = repo.RallyRepository(remote_url, teams_dir, repo_name, "teams", offline)
if update:
repo.update(distribution_version)
return repo


# TODO #308: This is now generic enough to be merged with the track repo.
class TeamRepository:
"""
Manages teams (consisting of cars and their plugins).
"""

def __init__(self, cfg, fetch=True):
self.cfg = cfg
self.name = cfg.opts("mechanic", "repository.name")
self.offline = cfg.opts("system", "offline.mode")
# If no URL is found, we consider this a local only repo (but still require that it is a git repo)
self.url = cfg.opts("teams", "%s.url" % self.name, mandatory=False)
self.remote = self.url is not None and self.url.strip() != ""
root = cfg.opts("node", "root.dir")
team_repositories = cfg.opts("mechanic", "team.repository.dir")
self.teams_dir = os.path.join(root, team_repositories, self.name)
if self.remote and not self.offline and fetch:
# a normal git repo with a remote
if not git.is_working_copy(self.teams_dir):
git.clone(src=self.teams_dir, remote=self.url)
else:
try:
git.fetch(src=self.teams_dir)
except exceptions.SupplyError:
console.warn("Could not update teams. Continuing with your locally available state.", logger=logger)
else:
if not git.is_working_copy(self.teams_dir):
raise exceptions.SystemSetupError("[{src}] must be a git repository.\n\nPlease run:\ngit -C {src} init"
.format(src=self.teams_dir))

def update(self, distribution_version):
try:
if self.remote and not self.offline:
branch = versions.best_match(git.branches(self.teams_dir, remote=self.remote), distribution_version)
if branch:
# Allow uncommitted changes iff we do not have to change the branch
logger.info(
"Checking out [%s] in [%s] for distribution version [%s]." % (branch, self.teams_dir, distribution_version))
git.checkout(self.teams_dir, branch=branch)
logger.info("Rebasing on [%s] in [%s] for distribution version [%s]." % (branch, self.teams_dir, distribution_version))
try:
git.rebase(self.teams_dir, branch=branch)
except exceptions.SupplyError:
logger.exception("Cannot rebase due to local changes in [%s]" % self.teams_dir)
console.warn(
"Local changes in [%s] prevent team update from remote. Please commit your changes." % self.teams_dir)
return
else:
msg = "Could not find team data remotely for distribution version [%s]. " \
"Trying to find team data locally." % distribution_version
logger.warning(msg)
branch = versions.best_match(git.branches(self.teams_dir, remote=False), distribution_version)
if branch:
logger.info("Checking out [%s] in [%s] for distribution version [%s]." % (branch, self.teams_dir, distribution_version))
git.checkout(self.teams_dir, branch=branch)
else:
raise exceptions.SystemSetupError("Cannot find team data for distribution version %s" % distribution_version)
except exceptions.SupplyError:
tb = sys.exc_info()[2]
raise exceptions.DataError("Cannot update team data in [%s]." % self.teams_dir).with_traceback(tb)
current_team_repo.update(distribution_version)
return current_team_repo


class CarLoader:
def __init__(self, repo):
self.repo = repo
self.cars_dir = os.path.join(self.repo.teams_dir, "cars")
self.cars_dir = os.path.join(self.repo.repo_dir, "cars")

def car_names(self):
def __car_name(path):
Expand Down Expand Up @@ -202,7 +146,7 @@ def __str__(self):
class PluginLoader:
def __init__(self, repo):
self.repo = repo
self.plugins_root_path = os.path.join(self.repo.teams_dir, "plugins")
self.plugins_root_path = os.path.join(self.repo.repo_dir, "plugins")

def plugins(self):
known_plugins = self._official_plugins() + self._configured_plugins()
Expand Down
131 changes: 38 additions & 93 deletions esrally/track/loader.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import json
import logging
import os
import sys
import glob
import urllib.error

Expand All @@ -11,7 +10,7 @@
import tabulate
from esrally import exceptions, time, PROGRAM_NAME
from esrally.track import params, track
from esrally.utils import io, convert, net, git, versions, console, modules
from esrally.utils import io, convert, net, console, modules, repo

logger = logging.getLogger("rally.track")

Expand All @@ -32,17 +31,19 @@ def tracks(cfg):
:param cfg: The config object.
:return: A list of tracks that are available for the provided distribution version or else for the master version.
"""
repo = TrackRepository(cfg)

def track_names(repo):
return filter(lambda p: os.path.exists(track_file(repo, p)), next(os.walk(repo.repo_dir))[1])

repo = track_repo(cfg)
reader = TrackFileReader(cfg)
distribution_version = cfg.opts("mechanic", "distribution.version", mandatory=False)
data_root = cfg.opts("benchmarks", "local.dataset.cache")
return [reader.read(track_name,
# avoid excessive fetch from remote repo
repo.track_file(distribution_version, track_name, needs_update=False),
repo.track_dir(track_name),
track_file(repo, track_name),
track_dir(repo, track_name),
"%s/%s" % (data_root, track_name.lower())
)
for track_name in repo.track_names(distribution_version)]
for track_name in track_names(repo)]


def list_tracks(cfg):
Expand All @@ -66,12 +67,12 @@ def load_track(cfg):
"""
track_name = cfg.opts("track", "track.name")
try:
repo = TrackRepository(cfg)
repo = track_repo(cfg)
reader = TrackFileReader(cfg)
distribution_version = cfg.opts("mechanic", "distribution.version", mandatory=False)
data_root = cfg.opts("benchmarks", "local.dataset.cache")
full_track = reader.read(track_name, repo.track_file(distribution_version, track_name), repo.track_dir(track_name),
"%s/%s" % (data_root, track_name.lower()))

full_track = reader.read(track_name, track_file(repo, track_name), track_dir(repo, track_name),
os.path.join(data_root, track_name.lower()))
if cfg.opts("track", "test.mode.enabled"):
return post_process_for_test_mode(full_track)
else:
Expand All @@ -85,8 +86,8 @@ def load_track(cfg):
def load_track_plugins(cfg, register_runner, register_scheduler):
track_name = cfg.opts("track", "track.name")
# TODO #257: If we distribute drivers we need to ensure that the correct branch in the track repo is checked out
repo = TrackRepository(cfg, fetch=False)
track_plugin_path = repo.track_dir(track_name)
repo = track_repo(cfg, fetch=False, update=False)
track_plugin_path = track_dir(repo, track_name)

plugin_reader = TrackPluginReader(track_plugin_path, register_runner, register_scheduler)

Expand All @@ -96,6 +97,29 @@ def load_track_plugins(cfg, register_runner, register_scheduler):
logger.debug("Track [%s] in path [%s] does not define any track plugins." % (track_name, track_plugin_path))


def track_repo(cfg, fetch=True, update=True):
    """
    Creates the repository object for the configured track repository.

    :param cfg: The config object. All settings (repository name, remote URL, offline mode and directories) are read from it.
    :param fetch: Passed through to ``repo.RallyRepository`` as its ``fetch`` argument.
    :param update: If ``True``, ``update()`` is invoked on the repository with the configured distribution version
                   before it is returned.
    :return: The ``repo.RallyRepository`` instance for tracks.
    """
    # The config lookups are kept in their original order so that a missing mandatory option is reported identically.
    version = cfg.opts("mechanic", "distribution.version", mandatory=False)
    name = cfg.opts("track", "repository.name")
    is_offline = cfg.opts("system", "offline.mode")
    url = cfg.opts("tracks", "%s.url" % name, mandatory=False)
    base_dir = os.path.join(cfg.opts("node", "root.dir"), cfg.opts("benchmarks", "track.repository.dir"))

    repository = repo.RallyRepository(url, base_dir, name, "tracks", is_offline, fetch)
    if not update:
        return repository
    repository.update(version)
    return repository


def track_dir(repo, track_name):
    """
    :param repo: A repository object exposing a ``repo_dir`` attribute.
    :param track_name: Name of the track.
    :return: The path of the directory named after the track directly below ``repo.repo_dir``.
    """
    repo_root = repo.repo_dir
    return os.path.join(repo_root, track_name)


def track_file(repo, track_name):
    """
    :param repo: A repository object exposing a ``repo_dir`` attribute.
    :param track_name: Name of the track.
    :return: The path of the file ``track.json`` inside the track's directory.
    """
    containing_dir = track_dir(repo, track_name)
    return os.path.join(containing_dir, "track.json")


def operation_parameters(t, op):
if op.param_source:
logger.debug("Creating parameter source with name [%s]" % op.param_source)
Expand Down Expand Up @@ -208,85 +232,6 @@ def decompress(data_set_path, expected_size_in_bytes):
(type.name, index.name))


class TrackRepository:
"""
Manages track specifications.
"""

def __init__(self, cfg, fetch=True):
self.cfg = cfg
self.name = cfg.opts("track", "repository.name")
self.offline = cfg.opts("system", "offline.mode")
# If no URL is found, we consider this a local only repo (but still require that it is a git repo)
self.url = cfg.opts("tracks", "%s.url" % self.name, mandatory=False)
self.remote = self.url is not None and self.url.strip() != ""
root = cfg.opts("node", "root.dir")
track_repositories = cfg.opts("benchmarks", "track.repository.dir")
self.tracks_dir = "%s/%s/%s" % (root, track_repositories, self.name)
if self.remote and not self.offline and fetch:
# a normal git repo with a remote
if not git.is_working_copy(self.tracks_dir):
git.clone(src=self.tracks_dir, remote=self.url)
else:
try:
git.fetch(src=self.tracks_dir)
except exceptions.SupplyError:
console.warn("Could not update tracks. Continuing with your locally available state.", logger=logger)
else:
if not git.is_working_copy(self.tracks_dir):
raise exceptions.SystemSetupError("[{src}] must be a git repository.\n\nPlease run:\ngit -C {src} init"
.format(src=self.tracks_dir))

def track_names(self, distribution_version):
self._update(distribution_version)
return filter(self._is_track, next(os.walk(self.tracks_dir))[1])

def _is_track(self, path):
return os.path.exists(self._track_file(path))

def track_dir(self, track_name):
return "%s/%s" % (self.tracks_dir, track_name)

def _track_file(self, track_name):
return "%s/track.json" % self.track_dir(track_name)

def track_file(self, distribution_version, track_name, needs_update=True):
if needs_update:
self._update(distribution_version)
return self._track_file(track_name)

def _update(self, distribution_version):
try:
if self.remote and not self.offline:
branch = versions.best_match(git.branches(self.tracks_dir, remote=self.remote), distribution_version)
if branch:
# Allow uncommitted changes iff we do not have to change the branch
logger.info(
"Checking out [%s] in [%s] for distribution version [%s]." % (branch, self.tracks_dir, distribution_version))
git.checkout(self.tracks_dir, branch=branch)
logger.info("Rebasing on [%s] in [%s] for distribution version [%s]." % (branch, self.tracks_dir, distribution_version))
try:
git.rebase(self.tracks_dir, branch=branch)
except exceptions.SupplyError:
logger.exception("Cannot rebase due to local changes in [%s]" % self.tracks_dir)
console.warn(
"Local changes in [%s] prevent track update from remote. Please commit your changes." % self.tracks_dir)
return
else:
msg = "Could not find track data remotely for distribution version [%s]. " \
"Trying to find track data locally." % distribution_version
logger.warning(msg)
branch = versions.best_match(git.branches(self.tracks_dir, remote=False), distribution_version)
if branch:
logger.info("Checking out [%s] in [%s] for distribution version [%s]." % (branch, self.tracks_dir, distribution_version))
git.checkout(self.tracks_dir, branch=branch)
else:
raise exceptions.SystemSetupError("Cannot find track data for distribution version %s" % distribution_version)
except exceptions.SupplyError:
tb = sys.exc_info()[2]
raise exceptions.DataError("Cannot update track data in [%s]." % self.tracks_dir).with_traceback(tb)


def render_template(loader, template_name, glob_helper=lambda f: [], clock=time.Clock):
macros = """
{% macro collect(parts) -%}
Expand Down
67 changes: 67 additions & 0 deletions esrally/utils/repo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import os
import sys
import logging

from esrally import exceptions
from esrally.utils import git, console, versions

logger = logging.getLogger("rally.repo")


class RallyRepository:
    """
    Manages Rally resources (e.g. teams or tracks) that are stored in a git repository located
    in ``<root_dir>/<repo_name>``.
    """

    def __init__(self, remote_url, root_dir, repo_name, resource_name, offline, fetch=True):
        """
        Sets up the repository and, if applicable, clones it or fetches the latest state from the remote.

        :param remote_url: URL of the remote git repository. ``None`` or a blank string marks this as a local-only repository.
        :param root_dir: Directory that contains the repository directory.
        :param repo_name: Name of the repository; it is also the name of the directory below ``root_dir``.
        :param resource_name: Name of the managed resource (e.g. "teams" or "tracks"). Only used in log and error messages.
        :param offline: If ``True``, the remote is never contacted.
        :param fetch: If ``False``, the constructor neither clones nor fetches.
        :raises exceptions.SystemSetupError: If nothing is cloned or fetched and the repository directory is not a
                                             git working copy.
        """
        # If no URL is found, we consider this a local only repo (but still require that it is a git repo)
        self.url = remote_url
        self.remote = self.url is not None and self.url.strip() != ""
        self.repo_dir = os.path.join(root_dir, repo_name)
        self.resource_name = resource_name
        self.offline = offline
        if self.remote and not self.offline and fetch:
            # a normal git repo with a remote
            if not git.is_working_copy(self.repo_dir):
                git.clone(src=self.repo_dir, remote=self.url)
            else:
                try:
                    git.fetch(src=self.repo_dir)
                except exceptions.SupplyError:
                    # best effort: a failed fetch must not prevent working with the local state
                    console.warn("Could not update %s. Continuing with your locally available state." % self.resource_name, logger=logger)
        else:
            if not git.is_working_copy(self.repo_dir):
                raise exceptions.SystemSetupError("[{src}] must be a git repository.\n\nPlease run:\ngit -C {src} init"
                                                  .format(src=self.repo_dir))

    def update(self, distribution_version):
        """
        Checks out the branch that best matches the provided distribution version.

        For a remote repository that is not in offline mode, the remote branches are consulted first; a matching
        branch is checked out and rebased. If rebasing fails, a warning is issued and the current local state is kept.
        If there is no matching remote branch, or the repository is local-only or offline, the local branches are
        consulted instead.

        :param distribution_version: The distribution version for which a matching branch should be checked out.
        :raises exceptions.SystemSetupError: If no local branch matches the distribution version.
        :raises exceptions.DataError: If a git operation fails with ``exceptions.SupplyError``.
        """
        try:
            if self.remote and not self.offline:
                branch = versions.best_match(git.branches(self.repo_dir, remote=self.remote), distribution_version)
                if branch:
                    # Allow uncommitted changes iff we do not have to change the branch
                    logger.info(
                        "Checking out [%s] in [%s] for distribution version [%s]." % (branch, self.repo_dir, distribution_version))
                    git.checkout(self.repo_dir, branch=branch)
                    logger.info("Rebasing on [%s] in [%s] for distribution version [%s]." % (branch, self.repo_dir, distribution_version))
                    try:
                        git.rebase(self.repo_dir, branch=branch)
                    except exceptions.SupplyError:
                        logger.exception("Cannot rebase due to local changes in [%s]" % self.repo_dir)
                        # The directory belongs into the bracketed placeholder, the resource name into the second one.
                        console.warn(
                            "Local changes in [%s] prevent %s update from remote. Please commit your changes." %
                            (self.repo_dir, self.resource_name))
                    return
                else:
                    msg = "Could not find %s remotely for distribution version [%s]. Trying to find %s locally." % \
                          (self.resource_name, distribution_version, self.resource_name)
                    logger.warning(msg)
            branch = versions.best_match(git.branches(self.repo_dir, remote=False), distribution_version)
            if branch:
                logger.info("Checking out [%s] in [%s] for distribution version [%s]." % (branch, self.repo_dir, distribution_version))
                git.checkout(self.repo_dir, branch=branch)
            else:
                raise exceptions.SystemSetupError("Cannot find %s for distribution version %s" % (self.resource_name, distribution_version))
        except exceptions.SupplyError:
            # Re-raise as a data error but keep the traceback of the original failure for diagnosis.
            tb = sys.exc_info()[2]
            raise exceptions.DataError("Cannot update [%s] in [%s]." % (self.resource_name, self.repo_dir)).with_traceback(tb)
4 changes: 2 additions & 2 deletions tests/mechanic/team_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@


class UnitTestRepo:
def __init__(self, teams_dir):
self.teams_dir = teams_dir
def __init__(self, repo_dir):
self.repo_dir = repo_dir


class CarLoaderTests(TestCase):
Expand Down
Loading

0 comments on commit 4105331

Please sign in to comment.