-
Notifications
You must be signed in to change notification settings - Fork 1
/
example.py
44 lines (34 loc) · 1.77 KB
/
example.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
from dataclasses import dataclass
from pathlib import Path
from typing import Literal
from src.dargparser import Choice, dArg, dargparse
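# flake8 (pyflakes) reports string arguments inside Choice[...] as undefined names,
# because it only treats strings inside Literal[...] as literals. Use Literal instead,
# or keep the file-wide noqa below to silence those warnings.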
# flake8: noqa
@dataclass
class Args:
    """Training options to be filled in by ``dargparse``.

    Per-field CLI options (default, extra aliases, help text and an optional
    custom ``parsing_function``) are supplied through ``dArg``.
    """

    # `epochs` and `learning_rate` declare no default, so a value must be supplied.
    epochs: int
    learning_rate: float = dArg(
        aliases="--lr",
        help="Required argument (no default).",
    )
    data_path: Path = dArg(
        default="./data/",
        aliases=["--data", "-d"],
        # Convert the raw value into an absolute, resolved Path.
        parsing_function=lambda raw: Path(raw).resolve(),
    )
    # The `X | None` union syntax needs Python >= 3.10; older interpreters need Optional[str].
    extra_data: str | None = dArg(default=None)
    cuda: bool = dArg(
        default=True,
        help="We automatically create a `--no_<arg>` flag for bools.",
    )
    precision: Choice[32, 16, 8, "bf16", "tf32"] = dArg(
        default=32,
        help="Choices with mixed types are supported.",
    )
    some_list_arg: list[int] = dArg(default=[1, 2, 3])
    evaluation_datasets: list[Choice["xnli", "tydiqa", "wikiann", "squad"]] = dArg(
        default=["xnli", "wikiann"],
        help="Select arbitrary number of datasets to evaluate on.",
    )
    # Custom parser: "1,a,b" -> (1, ["a", "b"]). The text before the first comma
    # becomes the int, and every remaining comma-separated piece goes into the list.
    complex_arg: tuple[int, list[str]] = dArg(
        default=(1, ["a", "b"]),
        parsing_function=lambda raw: (int(raw.partition(",")[0]), raw.split(",")[1:]),
    )
@dataclass
class LoggingArgs:
    """Logging options, parsed by ``dargparse`` together with ``Args``."""

    log_dir: str | None = dArg(default=None)
    log_backends: list[Choice["wandb", "tensorboard", "neptune"]] = dArg(
        default=["wandb"],
    )
    # Plain typing.Literal works here too; Choice is merely a more
    # descriptive alias for Literal.
    log_level: Literal["debug", "info", "warning", "error", "critical"] = dArg(
        default="info",
    )
def main(args: Args, logging_args: LoggingArgs):
    """Print each parsed options object on its own labelled line.

    Args:
        args: The parsed training options.
        logging_args: The parsed logging options.
    """
    labelled = (
        ("Arguments:", args),
        ("Logging arguments:", logging_args),
    )
    for label, value in labelled:
        print(label, value)
if __name__ == "__main__":
    # Parse the command line into both dataclasses. The result is unpacked
    # in the same order as the `dataclasses` tuple.
    args, logging_args = dargparse(dataclasses=(Args, LoggingArgs))
    main(args=args, logging_args=logging_args)