Skip to content

Commit

Permalink
params/metrics: cache remote files (#9932)
Browse files Browse the repository at this point in the history
  • Loading branch information
skshetry authored Sep 11, 2023
1 parent f2673ac commit 33b6c1d
Show file tree
Hide file tree
Showing 9 changed files with 26 additions and 26 deletions.
3 changes: 2 additions & 1 deletion dvc/dependency/param.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,9 @@ def read_param_file(
path: str,
key_paths: Optional[List[str]] = None,
flatten: bool = False,
**load_kwargs,
) -> Any:
config = load_path(path, fs)
config = load_path(path, fs, **load_kwargs)
if not key_paths:
return config

Expand Down
17 changes: 7 additions & 10 deletions dvc/repo/metrics/show.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,30 +55,27 @@ def _extract_metrics(metrics, path: str):
ret[key] = m
else:
logger.debug(
(
"Could not parse '%s' metric from '%s'"
"due to its unsupported type: '%s'"
),
"Could not parse %r metric from %r due to its unsupported type: %r",
key,
path,
type(val),
type(val).__name__,
)

return ret


def _read_metric(fs: "FileSystem", path: str) -> Any:
val = load_path(path, fs)
def _read_metric(fs: "FileSystem", path: str, **load_kwargs) -> Any:
val = load_path(path, fs, **load_kwargs)
val = _extract_metrics(val, path)
return val or {}


def _read_metrics(
fs: "FileSystem", metrics: Iterable[str]
fs: "FileSystem", metrics: Iterable[str], **load_kwargs
) -> Iterator[Tuple[str, Union[Exception, Any]]]:
for metric in metrics:
try:
yield metric, _read_metric(fs, metric)
yield metric, _read_metric(fs, metric, **load_kwargs)
except Exception as exc: # noqa: BLE001 # pylint:disable=broad-exception-caught
logger.debug(exc)
yield metric, exc
Expand Down Expand Up @@ -157,7 +154,7 @@ def _gather_metrics(
data = {}

fs = repo.dvcfs
for fs_path, result in _read_metrics(fs, files):
for fs_path, result in _read_metrics(fs, files, cache=True):
repo_path = fs_path.lstrip(fs.root_marker)
repo_os_path = os.sep.join(fs.path.parts(repo_path))
if not isinstance(result, Exception):
Expand Down
6 changes: 3 additions & 3 deletions dvc/repo/params/show.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,11 +103,11 @@ def _collect_vars(repo, params, stages=None) -> Dict:


def _read_params(
fs: "FileSystem", params: Dict[str, List[str]]
fs: "FileSystem", params: Dict[str, List[str]], **load_kwargs
) -> Iterator[Tuple[str, Union[Exception, Any]]]:
for file_path, key_paths in params.items():
try:
yield file_path, read_param_file(fs, file_path, key_paths)
yield file_path, read_param_file(fs, file_path, key_paths, **load_kwargs)
except Exception as exc: # noqa: BLE001 # pylint:disable=broad-exception-caught
logger.debug(exc)
yield file_path, exc
Expand Down Expand Up @@ -136,7 +136,7 @@ def _gather_params(
data: Dict[str, FileResult] = {}

fs = repo.dvcfs
for fs_path, result in _read_params(fs, files_keypaths):
for fs_path, result in _read_params(fs, files_keypaths, cache=True):
repo_path = fs_path.lstrip(fs.root_marker)
repo_os_path = os.sep.join(fs.path.parts(repo_path))
if not isinstance(result, Exception):
Expand Down
4 changes: 2 additions & 2 deletions dvc/utils/serialize/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@
)


def load_path(fs_path, fs):
def load_path(fs_path, fs, **kwargs):
suffix = fs.path.suffix(fs_path).lower()
loader = LOADERS[suffix]
return loader(fs_path, fs=fs)
return loader(fs_path, fs=fs, **kwargs)


DUMPERS: DefaultDict[str, DumperFn] = defaultdict( # noqa: F405
Expand Down
6 changes: 4 additions & 2 deletions dvc/utils/serialize/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,10 +69,12 @@ def __init__(self, path: "StrPath", encoding: str):
super().__init__(path, f"is not valid {encoding}")


def _load_data(path: "StrPath", parser: ParserFn, fs: Optional["FileSystem"] = None):
def _load_data(
path: "StrPath", parser: ParserFn, fs: Optional["FileSystem"] = None, **kwargs
):
open_fn = fs.open if fs else open
encoding = "utf-8"
with open_fn(path, encoding=encoding) as fd: # type: ignore[operator]
with open_fn(path, encoding=encoding, **kwargs) as fd: # type: ignore[operator]
with reraise(UnicodeDecodeError, EncodingError(path, encoding)):
return parser(fd.read(), path)

Expand Down
4 changes: 2 additions & 2 deletions dvc/utils/serialize/_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ def __init__(self, path):
super().__init__(path, "JSON file structure is corrupted")


def load_json(path, fs=None):
return _load_data(path, parser=parse_json, fs=fs)
def load_json(path, fs=None, **kwargs):
return _load_data(path, parser=parse_json, fs=fs, **kwargs)


def parse_json(text, path, **kwargs):
Expand Down
4 changes: 2 additions & 2 deletions dvc/utils/serialize/_py.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ def __init__(self, path, message="Python file structure is corrupted"):
super().__init__(path, message)


def load_py(path, fs=None):
return _load_data(path, parser=parse_py, fs=fs)
def load_py(path, fs=None, **kwargs):
return _load_data(path, parser=parse_py, fs=fs, **kwargs)


def parse_py(text, path):
Expand Down
4 changes: 2 additions & 2 deletions dvc/utils/serialize/_toml.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ def __init__(self, path):
super().__init__(path, "TOML file structure is corrupted")


def load_toml(path, fs=None):
return _load_data(path, parser=parse_toml, fs=fs)
def load_toml(path, fs=None, **kwargs):
return _load_data(path, parser=parse_toml, fs=fs, **kwargs)


def _parse_toml(text, path):
Expand Down
4 changes: 2 additions & 2 deletions dvc/utils/serialize/_yaml.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ def __init__(self, path):
super().__init__(path, "YAML file structure is corrupted")


def load_yaml(path, fs=None):
return _load_data(path, parser=parse_yaml, fs=fs)
def load_yaml(path, fs=None, **kwargs):
return _load_data(path, parser=parse_yaml, fs=fs, **kwargs)


def parse_yaml(text, path, typ="safe"):
Expand Down

0 comments on commit 33b6c1d

Please sign in to comment.