From 704a279ba212be3f12ed0cf82f85ade075c72c5e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Saugat=20Pachhai=20=28=E0=A4=B8=E0=A5=8C=E0=A4=97=E0=A4=BE?= =?UTF-8?q?=E0=A4=A4=29?= Date: Tue, 2 Nov 2021 15:53:19 +0545 Subject: [PATCH] exp init: validate params as soon as it's entered in prompt It'll give a error message to the user and reprompt them. This mechanism can be extended to other prompts, and we can show a warning message and/or reprompt them based on validations. Closes #6865. Part of #6446. --- dvc/repo/experiments/init.py | 92 ++++++++++++++++++++++++++++-------- 1 file changed, 72 insertions(+), 20 deletions(-) diff --git a/dvc/repo/experiments/init.py b/dvc/repo/experiments/init.py index 5c660d49db..a7cf574816 100644 --- a/dvc/repo/experiments/init.py +++ b/dvc/repo/experiments/init.py @@ -1,9 +1,18 @@ import logging import os from contextlib import contextmanager -from typing import TYPE_CHECKING, Dict, Iterable, cast +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + Iterable, + List, + Optional, + cast, +) -from funcy import compact, lremove, post_processing +from funcy import compact, lremove from rich.prompt import Confirm, Prompt from rich.rule import Rule from rich.syntax import Syntax @@ -33,6 +42,10 @@ } +class RetryPrompt(Exception): + """Used for signalling whether prompts should be retried again.""" + + class RichInputMixin: """Prevents exc message from printing in the same line on Ctrl + D/C.""" @@ -107,19 +120,35 @@ def make_prompt(self, default): return prompt -@post_processing(dict) -def _prompts(keys: Iterable[str], defaults: Dict[str, OptStr]): - for key in keys: - if key == "cmd": - prompt_cls = RequiredPrompt - else: - prompt_cls = SkippablePrompt - kwargs = {"default": defaults[key]} if key in defaults else {} - prompt = PROMPTS[key] +def _prompt( + key: str, + validator: Optional[Callable[[str, Any], None]] = None, + default: OptStr = None, +) -> str: + prompt_cls = RequiredPrompt if key == "cmd" else SkippablePrompt + kwargs = {"default": default} if default is not None else {} + while True: value = prompt_cls.ask( # type: ignore[call-overload] - prompt, console=ui.error_console, **kwargs + PROMPTS[key], console=ui.error_console, **kwargs ) - yield key, value + if validator: + try: + validator(key, value) + except RetryPrompt as exc: + ui.error_write(f"[red]{exc}[/]", styled=True) + continue + return value + + +def _prompts( + keys: Iterable[str], + defaults: Dict[str, str], + validator: Callable[[str, Any], None] = None, +) -> Dict[str, str]: + return { + key: _prompt(key, default=defaults.get(key), validator=validator) + for key in keys + } @contextmanager @@ -143,6 +172,7 @@ def init_interactive( name: str, defaults: Dict[str, str], provided: Dict[str, str], + validator: Callable[[str, Any], None] = None, show_tree: bool = False, live: bool = False, ) -> Dict[str, str]: @@ -190,8 +220,8 @@ def init_interactive( ui.error_write(tree, styled=True) ui.error_write() - ret.update(_prompts(primary, defaults)) - ret.update(_prompts(secondary, defaults)) + ret.update(_prompts(primary, defaults, validator=validator)) + ret.update(_prompts(secondary, defaults, validator=validator)) return compact(ret) @@ -207,6 +237,13 @@ def _check_stage_exists( ) +def loadd_params(path: str) -> Dict[str, List[str]]: + from dvc.utils.serialize import LOADERS + + _, ext = os.path.splitext(path) + return {path: list(LOADERS[ext](path))} + + def init( repo: "Repo", name: str = None, @@ -227,9 +264,28 @@ def init( overrides = overrides or {} with_live = type == "live" + + def validate_prompts_input(key: str, value: Any) -> None: + if value is None: + return + + if key == "params": + assert isinstance(value, str) + try: + loadd_params(value) + except (FileNotFoundError, IsADirectoryError) as exc: + reason = "does not exist" + if isinstance(exc, IsADirectoryError): + reason = "is a directory" + raise RetryPrompt( + f"'{value}' {reason}. " + "Please retry with an existing parameters file." + ) + if interactive: defaults = init_interactive( name, + validator=validate_prompts_input, defaults=defaults, live=with_live, provided=overrides, @@ -251,11 +307,7 @@ def init( params_kv = [] params = context.get("params") if params: - from dvc.utils.serialize import LOADERS - - assert isinstance(params, str) - _, ext = os.path.splitext(params) - params_kv = [{params: list(LOADERS[ext](params))}] + params_kv.append(loadd_params(params)) checkpoint_out = bool(context.get("live")) models = context.get("models")