Skip to content

Commit

Permalink
exp init: validate params as soon as it's entered in prompt
Browse files Browse the repository at this point in the history
It'll give a error message to the user and reprompt them.
This mechanism can be extended to other prompts, and we
can show a warning message and/or reprompt them based on
validations.

Closes #6865. Part of #6446.
  • Loading branch information
skshetry committed Nov 11, 2021
1 parent 7e3e257 commit a466241
Showing 1 changed file with 72 additions and 20 deletions.
92 changes: 72 additions & 20 deletions dvc/repo/experiments/init.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,18 @@
import logging
import os
from contextlib import contextmanager
from typing import TYPE_CHECKING, Dict, Iterable, cast
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Iterable,
List,
Optional,
cast,
)

from funcy import compact, lremove, post_processing
from funcy import compact, lremove
from rich.prompt import Confirm, Prompt
from rich.rule import Rule
from rich.syntax import Syntax
Expand Down Expand Up @@ -33,6 +42,10 @@
}


class RetryPrompt(Exception):
"""Used for signalling whether prompts should be retried again."""


class RichInputMixin:
"""Prevents exc message from printing in the same line on Ctrl + D/C."""

Expand Down Expand Up @@ -107,19 +120,35 @@ def make_prompt(self, default):
return prompt


@post_processing(dict)
def _prompts(keys: Iterable[str], defaults: Dict[str, OptStr]):
for key in keys:
if key == "cmd":
prompt_cls = RequiredPrompt
else:
prompt_cls = SkippablePrompt
kwargs = {"default": defaults[key]} if key in defaults else {}
prompt = PROMPTS[key]
def _prompt(
key: str,
validator: Optional[Callable[[str, Any], None]] = None,
default: OptStr = None,
) -> str:
prompt_cls = RequiredPrompt if key == "cmd" else SkippablePrompt
kwargs = {"default": default} if default is not None else {}
while True:
value = prompt_cls.ask( # type: ignore[call-overload]
prompt, console=ui.error_console, **kwargs
PROMPTS[key], console=ui.error_console, **kwargs
)
yield key, value
if validator:
try:
validator(key, value)
except RetryPrompt as exc:
ui.error_write(f"[red]{exc}[/]", styled=True)
continue
return value


def _prompts(
keys: Iterable[str],
defaults: Dict[str, str],
validator: Callable[[str, Any], None] = None,
) -> Dict[str, str]:
return {
key: _prompt(key, default=defaults.get(key), validator=validator)
for key in keys
}


@contextmanager
Expand All @@ -143,6 +172,7 @@ def init_interactive(
name: str,
defaults: Dict[str, str],
provided: Dict[str, str],
validator: Callable[[str, Any], None] = None,
show_tree: bool = False,
live: bool = False,
) -> Dict[str, str]:
Expand Down Expand Up @@ -190,8 +220,8 @@ def init_interactive(
ui.error_write(tree, styled=True)

ui.error_write()
ret.update(_prompts(primary, defaults))
ret.update(_prompts(secondary, defaults))
ret.update(_prompts(primary, defaults, validator=validator))
ret.update(_prompts(secondary, defaults, validator=validator))
return compact(ret)


Expand All @@ -207,6 +237,13 @@ def _check_stage_exists(
)


def loadd_params(path: str) -> Dict[str, List[str]]:
from dvc.utils.serialize import LOADERS

_, ext = os.path.splitext(path)
return {path: list(LOADERS[ext](path))}


def init(
repo: "Repo",
name: str = None,
Expand All @@ -227,9 +264,28 @@ def init(
overrides = overrides or {}

with_live = type == "live"

def validate_prompts_input(key: str, value: Any) -> None:
if value is None:
return

if key == "params":
assert isinstance(value, str)
try:
loadd_params(value)
except (FileNotFoundError, IsADirectoryError) as exc:
reason = "does not exist"
if isinstance(exc, IsADirectoryError):
reason = "is a directory"
raise RetryPrompt(
f"'{value}' {reason}. "
"Please retry with an existing parameters file."
)

if interactive:
defaults = init_interactive(
name,
validator=validate_prompts_input,
defaults=defaults,
live=with_live,
provided=overrides,
Expand All @@ -251,11 +307,7 @@ def init(
params_kv = []
params = context.get("params")
if params:
from dvc.utils.serialize import LOADERS

assert isinstance(params, str)
_, ext = os.path.splitext(params)
params_kv = [{params: list(LOADERS[ext](params))}]
params_kv.append(loadd_params(params))

checkpoint_out = bool(context.get("live"))
models = context.get("models")
Expand Down

0 comments on commit a466241

Please sign in to comment.