Skip to content

Commit

Permalink
metrics/params: support reading from cached files and directories, us…
Browse files Browse the repository at this point in the history
…e repo-relative path consistently (#9909)

* metrics

* fix path handling

* fix dvc metrics diff

* params show/diff

* share common utils

* remove onerror_collect from exp_show

* extract read_params

* move around code

* fix path handling on Windows

* rename _show to _gather_metrics

* remove _collect_vars

* add tests

* Revert "remove _collect_vars"

This reverts commit 6d2cb50.

* bring back _collect_vars

* fix tests

* fix paths
  • Loading branch information
skshetry authored Sep 11, 2023
1 parent 7e52c23 commit f2673ac
Show file tree
Hide file tree
Showing 20 changed files with 736 additions and 504 deletions.
18 changes: 10 additions & 8 deletions dvc/api/show.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,6 @@
from dvc.repo import Repo


def _onerror_raise(exception: Exception, *args, **kwargs):
raise exception


def _postprocess(results):
processed: Dict[str, Dict] = {}
for rev, rev_data in results.items():
Expand All @@ -23,7 +19,6 @@ def _postprocess(results):
for file_data in rev_data["data"].values():
for k in file_data["data"]:
counts[k] += 1

for file_name, file_data in rev_data["data"].items():
to_merge = {
(k if counts[k] == 1 else f"{file_name}:{k}"): v
Expand Down Expand Up @@ -138,13 +133,17 @@ def metrics_show(
.. _Git revision:
https://git-scm.com/docs/revisions
"""
from dvc.repo.metrics.show import to_relpath

with Repo.open(repo, config=config) as _repo:
metrics = _repo.metrics.show(
targets=targets,
revs=rev if rev is None else [rev],
onerror=_onerror_raise,
on_error="raise",
)
metrics = {
k: to_relpath(_repo.fs, _repo.root_dir, v) for k, v in metrics.items()
}

metrics = _postprocess(metrics)

Expand Down Expand Up @@ -382,17 +381,20 @@ def params_show(
https://git-scm.com/docs/revisions
"""
from dvc.repo.metrics.show import to_relpath

if isinstance(stages, str):
stages = [stages]

with Repo.open(repo, config=config) as _repo:
params = _repo.params.show(
revs=rev if rev is None else [rev],
targets=targets,
deps=deps,
onerror=_onerror_raise,
deps_only=deps,
on_error="raise",
stages=stages,
)
params = {k: to_relpath(_repo.fs, _repo.root_dir, v) for k, v in params.items()}

params = _postprocess(params)

Expand Down
64 changes: 42 additions & 22 deletions dvc/commands/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from dvc.cli import completion
from dvc.cli.command import CmdBase
from dvc.cli.utils import append_doc_link, fix_subparsers
from dvc.exceptions import DvcException
from dvc.ui import ui
from dvc.utils.serialize import encode_exception

Expand All @@ -20,17 +19,25 @@ class CmdMetricsBase(CmdBase):

class CmdMetricsShow(CmdMetricsBase):
def run(self):
try:
metrics = self.repo.metrics.show(
self.args.targets,
all_branches=self.args.all_branches,
all_tags=self.args.all_tags,
all_commits=self.args.all_commits,
recursive=self.args.recursive,
from dvc.repo.metrics.show import to_relpath
from dvc.utils import errored_revisions

metrics = self.repo.metrics.show(
self.args.targets,
all_branches=self.args.all_branches,
all_tags=self.args.all_tags,
all_commits=self.args.all_commits,
)
metrics = {
k: to_relpath(self.repo.fs, self.repo.root_dir, v)
for k, v in metrics.items()
}

if errored := errored_revisions(metrics):
ui.error_write(
"DVC failed to load some metrics for following revisions:"
f" '{', '.join(errored)}'."
)
except DvcException:
logger.exception("")
return 1

if self.args.json:
ui.write_json(metrics, default=encode_exception)
Expand All @@ -52,17 +59,26 @@ def run(self):

class CmdMetricsDiff(CmdMetricsBase):
def run(self):
try:
diff = self.repo.metrics.diff(
a_rev=self.args.a_rev,
b_rev=self.args.b_rev,
targets=self.args.targets,
recursive=self.args.recursive,
all=self.args.all,
import os
from os.path import relpath

diff_result = self.repo.metrics.diff(
a_rev=self.args.a_rev,
b_rev=self.args.b_rev,
targets=self.args.targets,
all=self.args.all,
)

errored = [rev for rev, err in diff_result.get("errors", {}).items() if err]
if errored:
ui.error_write(
"DVC failed to load some metrics for following revisions:"
f" '{', '.join(errored)}'."
)
except DvcException:
logger.exception("failed to show metrics diff")
return 1

start = relpath(os.getcwd(), self.repo.root_dir)
diff = diff_result.get("diff", {})
diff = {relpath(path, start): result for path, result in diff.items()}

if self.args.json:
ui.write_json(diff)
Expand Down Expand Up @@ -184,10 +200,14 @@ def add_parser(subparsers, parent_parser):
formatter_class=argparse.RawDescriptionHelpFormatter,
)
metrics_diff_parser.add_argument(
"a_rev", nargs="?", help="Old Git commit to compare (defaults to HEAD)"
"a_rev",
nargs="?",
help="Old Git commit to compare (defaults to HEAD)",
default="HEAD",
)
metrics_diff_parser.add_argument(
"b_rev",
default="workspace",
nargs="?",
help="New Git commit to compare (defaults to the current workspace)",
)
Expand Down
37 changes: 25 additions & 12 deletions dvc/commands/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from dvc.cli import completion
from dvc.cli.command import CmdBase
from dvc.cli.utils import append_doc_link, fix_subparsers
from dvc.exceptions import DvcException
from dvc.ui import ui

logger = logging.getLogger(__name__)
Expand All @@ -14,17 +13,27 @@ class CmdParamsDiff(CmdBase):
UNINITIALIZED = True

def run(self):
try:
diff = self.repo.params.diff(
a_rev=self.args.a_rev,
b_rev=self.args.b_rev,
targets=self.args.targets,
all=self.args.all,
deps=self.args.deps,
import os
from os.path import relpath

diff_result = self.repo.params.diff(
a_rev=self.args.a_rev,
b_rev=self.args.b_rev,
targets=self.args.targets,
all=self.args.all,
deps_only=self.args.deps,
)

errored = [rev for rev, err in diff_result.get("errors", {}).items() if err]
if errored:
ui.error_write(
"DVC failed to load some metrics for following revisions:"
f" '{', '.join(errored)}'."
)
except DvcException:
logger.exception("failed to show params diff")
return 1

start = relpath(os.getcwd(), self.repo.root_dir)
diff = diff_result.get("diff", {})
diff = {relpath(path, start): result for path, result in diff.items()}

if self.args.json:
ui.write_json(diff)
Expand Down Expand Up @@ -74,10 +83,14 @@ def add_parser(subparsers, parent_parser):
formatter_class=argparse.RawDescriptionHelpFormatter,
)
params_diff_parser.add_argument(
"a_rev", nargs="?", help="Old Git commit to compare (defaults to HEAD)"
"a_rev",
nargs="?",
default="HEAD",
help="Old Git commit to compare (defaults to HEAD)",
)
params_diff_parser.add_argument(
"b_rev",
default="workspace",
nargs="?",
help="New Git commit to compare (defaults to the current workspace)",
)
Expand Down
4 changes: 2 additions & 2 deletions dvc/dependency/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from collections import defaultdict
from typing import Any, Mapping, Set
from typing import Any, Dict, List, Mapping, Set

from dvc.output import ARTIFACT_SCHEMA, DIR_FILES_SCHEMA, Output

Expand Down Expand Up @@ -58,7 +58,7 @@ def loads_from(stage, s_list, erepo=None, fs_config=None):
]


def _merge_params(s_list):
def _merge_params(s_list) -> Dict[str, List[str]]:
d = defaultdict(list)
default_file = ParamsDependency.DEFAULT_PARAMS_FILE

Expand Down
74 changes: 44 additions & 30 deletions dvc/dependency/param.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import os
import typing
from collections import defaultdict
from typing import Any, Dict
from typing import TYPE_CHECKING, Any, Dict, List, Optional

import dpath

Expand All @@ -12,6 +12,9 @@

from .base import Dependency

if TYPE_CHECKING:
from dvc.fs import FileSystem

logger = logging.getLogger(__name__)


Expand All @@ -31,6 +34,36 @@ class BadParamFileError(DvcException):
pass


def read_param_file(
fs: "FileSystem",
path: str,
key_paths: Optional[List[str]] = None,
flatten: bool = False,
) -> Any:
config = load_path(path, fs)
if not key_paths:
return config

ret = {}
if flatten:
for key_path in key_paths:
try:
ret[key_path] = dpath.get(config, key_path, separator=".")
except KeyError:
continue
return ret

from dpath import merge

for key_path in key_paths:
merge(
ret,
dpath.search(config, key_path, separator="."),
separator=".",
)
return ret


class ParamsDependency(Dependency):
PARAM_PARAMS = "params"
DEFAULT_PARAMS_FILE = "params.yaml"
Expand Down Expand Up @@ -75,31 +108,19 @@ def read_params(
self, flatten: bool = True, **kwargs: typing.Any
) -> Dict[str, typing.Any]:
try:
config = self.read_file()
self.validate_filepath()
except MissingParamsFile:
config = {}

if not self.params:
return config

ret = {}
if flatten:
for param in self.params:
try:
ret[param] = dpath.get(config, param, separator=".")
except KeyError:
continue
return ret
return {}

from dpath import merge

for param in self.params:
merge(
ret,
dpath.search(config, param, separator="."),
separator=".",
try:
return read_param_file(
self.repo.fs,
self.fs_path,
list(self.params) if self.params else None,
flatten=flatten,
)
return ret
except ParseError as exc:
raise BadParamFileError(f"Unable to read parameters from '{self}'") from exc

def workspace_status(self):
if not self.exists:
Expand Down Expand Up @@ -149,13 +170,6 @@ def validate_filepath(self):
f"'{self}' is a directory, expected a parameters file"
)

def read_file(self):
self.validate_filepath()
try:
return load_path(self.fs_path, self.repo.fs)
except ParseError as exc:
raise BadParamFileError(f"Unable to read parameters from '{self}'") from exc

def get_hash(self):
info = self.read_params()

Expand Down
5 changes: 3 additions & 2 deletions dvc/repo/experiments/executor/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
Union,
)

from funcy import get_in
from scmrepo.exceptions import SCMError

from dvc.env import DVC_EXP_AUTO_PUSH, DVC_EXP_GIT_REMOTE
Expand Down Expand Up @@ -644,13 +643,15 @@ def _repro_dvc( # noqa: C901
info.status = TaskStatus.FAILED
raise
finally:
from dvc.repo.metrics.show import _gather_metrics

post_live_metrics(
"done",
info.baseline_rev,
info.name, # type: ignore[arg-type]
"dvc",
experiment_rev=dvc.experiments.scm.get_ref(EXEC_BRANCH),
metrics=get_in(dvc.metrics.show(), ["", "data"]),
metrics=_gather_metrics(dvc, on_error="return"),
dvc_studio_config=dvc_studio_config,
)

Expand Down
Loading

0 comments on commit f2673ac

Please sign in to comment.