Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

exp init: add basic template support #6630

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 28 additions & 47 deletions dvc/command/experiments.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import argparse
import logging
import os
from collections import Counter, OrderedDict, defaultdict
from datetime import date, datetime
from fnmatch import fnmatch
Expand Down Expand Up @@ -791,58 +790,36 @@ def run(self):


class CmdExperimentsInit(CmdBase):
CODE = "src"
DATA = "data"
MODELS = "models"
DEFAULT_METRICS = "metrics.json"
DEFAULT_PARAMS = "params.yaml"
PLOTS = "plots"
DVCLIVE = "dvclive"
DEFAULT_NAME = "default"

def run(self):
from dvc.command.stage import parse_cmd
from dvc.repo.experiments.init import init

cmd = parse_cmd(self.args.cmd)
if not cmd:
raise InvalidArgumentError("command is not specified")
if self.args.interactive:
raise NotImplementedError(
"'-i/--interactive' is not implemented yet."
)
if self.args.explicit:
raise NotImplementedError("'--explicit' is not implemented yet.")
if self.args.template:
raise NotImplementedError("template is not supported yet.")

from dvc.utils.serialize import LOADERS

code = self.args.code or self.CODE
data = self.args.data or self.DATA
models = self.args.models or self.MODELS
metrics = self.args.metrics or self.DEFAULT_METRICS
params_path = self.args.params or self.DEFAULT_PARAMS
plots = self.args.plots or self.PLOTS
dvclive = self.args.live or self.DVCLIVE

_, ext = os.path.splitext(params_path)
params = list(LOADERS[ext](params_path))

name = self.args.name or self.DEFAULT_NAME
stage = self.repo.stage.add(
name=name,
cmd=cmd,
deps=[code, data],
outs=[models],
params=[{params_path: params}],
metrics_no_cache=[metrics],
plots_no_cache=[plots],
live=dvclive,
force=True,
)

data = {
"cmd": cmd,
"code": self.args.code,
"data": self.args.data,
"models": self.args.models,
"metrics": self.args.metrics,
"params": self.args.params,
"plots": self.args.plots,
"live": self.args.live,
}

initialized_stage = init(
self.repo,
data,
template_name=self.args.template_name,
interactive=self.args.interactive,
explicit=self.args.explicit,
)
if self.args.run:
return self.repo.experiments.run(targets=[stage.addressing])
return self.repo.experiments.run(
targets=[initialized_stage.addressing]
)
return 0


Expand Down Expand Up @@ -1385,10 +1362,14 @@ def add_parser(subparsers, parent_parser):
help="Prompt for values that are not provided",
)
experiments_init_parser.add_argument(
"--template", help="Stage template to use to fill with provided values"
"--template",
dest="template_name",
help="Stage template to use to fill with provided values",
)
experiments_init_parser.add_argument(
"--explicit", help="Only use the path values explicitly provided"
"--explicit",
action="store_true",
help="Only use the path values explicitly provided",
)
experiments_init_parser.add_argument(
"--name", "-n", help="Name of the stage to create"
Expand Down
121 changes: 121 additions & 0 deletions dvc/repo/experiments/init.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
import dataclasses
import os
from collections import ChainMap
from pathlib import Path
from typing import TYPE_CHECKING, Callable, Dict, Optional

from funcy import compact
from voluptuous import MultipleInvalid, Schema

from dvc.exceptions import DvcException
from dvc.schema import STAGE_DEFINITION

if TYPE_CHECKING:
from jinja2 import BaseLoader

from dvc.repo import Repo


DEFAULT_TEMPLATE = "default"


@dataclasses.dataclass
class TemplateDefaults:
code: str = "src"
data: str = "data"
models: str = "models"
metrics: str = "metrics.json"
params: str = "params.yaml"
plots: str = "plots"
live: str = "dvclive"


DEFAULT_VALUES = dataclasses.asdict(TemplateDefaults())
STAGE_SCHEMA = Schema(STAGE_DEFINITION)


def get_loader(repo: "Repo") -> "BaseLoader":
from jinja2 import ChoiceLoader, FileSystemLoader

default_path = Path(__file__).parents[3] / "resources" / "stages"
return ChoiceLoader(
[
# not initialized yet
FileSystemLoader(Path(repo.dvc_dir) / "stages"),
# won't work for other packages
FileSystemLoader(default_path),
]
)


def init(
repo: "Repo",
data: Dict[str, Optional[object]],
template_name: str = None,
interactive: bool = False,
explicit: bool = False,
template_loader: Callable[["Repo"], "BaseLoader"] = get_loader,
force: bool = False,
):
from jinja2 import Environment

from dvc.dvcfile import make_dvcfile
skshetry marked this conversation as resolved.
Show resolved Hide resolved
from dvc.stage import check_circular_dependency, check_duplicated_arguments
skshetry marked this conversation as resolved.
Show resolved Hide resolved
from dvc.stage.loader import StageLoader
from dvc.utils.serialize import LOADERS, parse_yaml_for_update

data = compact(data) # remove None values
loader = template_loader(repo)
environment = Environment(loader=loader)
name = template_name or DEFAULT_TEMPLATE

dvcfile = make_dvcfile(repo, "dvc.yaml")
if not force and dvcfile.exists() and name in dvcfile.stages:
raise DvcException(f"stage '{name}' already exists.")

template = environment.get_template(f"{name}.yaml")
context = ChainMap(data)
if interactive:
# TODO: interactive requires us to check for variables present
# in the template and, adapt our prompts accordingly.
raise NotImplementedError("'-i/--interactive' is not supported yet.")
if not explicit:
context.maps.append(DEFAULT_VALUES)
else:
# TODO: explicit requires us to check for undefined variables.
raise NotImplementedError("'--explicit' is not implemented yet.")

assert "params" in context
# See https://github.com/iterative/dvc/issues/6605 for the support
# for depending on all params of a file.
param_path = str(context["params"])
_, ext = os.path.splitext(param_path)
param_names = list(LOADERS[ext](param_path))

# render, parse yaml and then validate schema
rendered = template.render(**context, param_names=param_names)
template_path = os.path.relpath(template.filename)
skshetry marked this conversation as resolved.
Show resolved Hide resolved
data = parse_yaml_for_update(rendered, template_path)
try:
validated = STAGE_SCHEMA(data)
except MultipleInvalid as exc:
raise DvcException(
f"template '{template_path}' "
"failed schema validation while rendering"
) from exc

stage = StageLoader.load_stage(dvcfile, name, validated)
# ensure correctness, similar to what we have in `repo.stage.add`
check_circular_dependency(stage)
check_duplicated_arguments(stage)
new_index = repo.index.add(stage)
new_index.check_graph()

with repo.scm.track_file_changes(config=repo.config):
# note that we are not dumping the "template" as-is
# we are dumping a stage data, which is processed
# so formatting-wise, it may look different.
stage.dump(update_lock=False)
stage.ignore_outs()

return stage
1 change: 1 addition & 0 deletions requirements/default.txt
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,4 @@ typing_extensions>=3.10.0.2
fsspec[http]>=2021.8.1
aiohttp-retry==2.4.5
diskcache>=5.2.1
jinja2>=2.11.3
17 changes: 17 additions & 0 deletions resources/stages/default.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
cmd: {{ cmd }}
deps:
- {{ code }}
- {{ data }}
params:
- {{ params }}:
{% for p in param_names %}
- {{ p }}
{% endfor %}
outs:
- {{ models }}
metrics:
- {{ metrics }}:
cache: false
plots:
- {{ plots }}:
cache: false
15 changes: 15 additions & 0 deletions resources/stages/live.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
cmd: {{ cmd }}
deps:
- {{ code }}
- {{ data }}
params:
- {{ params }}:
{% for p in param_names %}
- {{ p }}
{% endfor %}
outs:
- {{ models }}
live:
{{ live }}:
summary: true
html: true
36 changes: 30 additions & 6 deletions tests/func/experiments/test_init.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,58 @@
import os

from dvc.command.experiments import CmdExperimentsInit
from dvc.main import main
from dvc.utils.serialize import load_yaml


def test_init(tmp_dir, dvc):
tmp_dir.gen(
{
CmdExperimentsInit.CODE: {"copy.py": ""},
"src": {"copy.py": ""},
"data": "data",
"params.yaml": '{"foo": 1}',
"dvclive": {},
"plots": {},
}
)
code_path = os.path.join(CmdExperimentsInit.CODE, "copy.py")
code_path = os.path.join("src", "copy.py")
script = f"python {code_path}"

assert main(["exp", "init", script]) == 0
assert load_yaml(tmp_dir / "dvc.yaml") == {
assert (tmp_dir / "dvc.yaml").parse() == {
"stages": {
"default": {
"cmd": script,
"deps": ["data", "src"],
"live": {"dvclive": {"html": True, "summary": True}},
"metrics": [{"metrics.json": {"cache": False}}],
"outs": ["models"],
"params": ["foo"],
"plots": [{"plots": {"cache": False}}],
}
}
}


def test_init_live(tmp_dir, dvc):
tmp_dir.gen(
{
"src": {"copy.py": ""},
"data": "data",
"params.yaml": '{"foo": 1}',
"dvclive": {},
"plots": {},
}
)
code_path = os.path.join("src", "copy.py")
script = f"python {code_path}"

assert main(["exp", "init", "--template", "live", script]) == 0
assert (tmp_dir / "dvc.yaml").parse() == {
"stages": {
"live": {
"cmd": script,
"deps": ["data", "src"],
"outs": ["models"],
"params": ["foo"],
"live": {"dvclive": {"html": True, "summary": True}},
}
}
}