diff --git a/.github/workflows/python.yml b/.github/workflows/python.yml index cb8223c..30b025b 100644 --- a/.github/workflows/python.yml +++ b/.github/workflows/python.yml @@ -23,6 +23,8 @@ jobs: uses: abatilo/actions-poetry@v2 - name: Install poetry project run: poetry install + - name: Check unused imports + run: poetry run ruff --select F401 . - name: Sort imports run: poetry run isort --check --diff . - name: Run pyright diff --git a/README.md b/README.md index c878778..ab93335 100644 --- a/README.md +++ b/README.md @@ -12,12 +12,12 @@ from arcparse import arcparser, positional @arcparser class Args: name: str = positional() - age: int = positional() + age: int hobbies: list[str] = positional() happy: bool -args = Args.parse("Thomas 25 news coffee running --happy".split()) +args = Args.parse("--age 25 Thomas news coffee running --happy".split()) print(f"Hi, my name is {args.name}!") ``` @@ -27,9 +27,6 @@ For a complete overview of features see [Features](#features). ```shell # Using pip $ pip install arcparse - -# locally using poetry -$ poetry install ``` ## Features @@ -55,12 +52,17 @@ class Args: ``` ### Flags -All arguments type-hinted as `bool` are flags, they use `action="store_true"` in the background. Use `no_flag()` to easily create a `--no-...` flag with `action="store_false"`. Flags as well as options can also define short forms for each argument. They can also disable the long form with `short_only=True`. +All arguments type-hinted as `bool` are flags, they use `action="store_true"` in the background. Flags (as well as options) can also define short forms for each argument. They can also disable the long form with `short_only=True`. + +Use `no_flag()` to easily create a `--no-...` flag with `action="store_false"`. + +Use `tri_flag()` (or type-hint argument as `bool | None`) to create a "true" flag and a "false" flag (e.g. `--clone` and `--no-clone`). Passing `--clone` will store `True`, passing `--no-clone` will store `False` and not passing anything will store `None`. Passing both is an error ensured by an implicit mutually exclusive group. ```py @arcparser class Args: sync: bool recurse: bool = no_flag(help="Do not recurse") + clone: bool | None debug: bool = flag("-d") # both -d and --debug verbose: bool = flag("-v", short_only=True) # only -v @@ -91,13 +93,14 @@ class Args: ``` ### Type conversions -Automatic type conversions are supported. The type-hint is used in `type=...` in the background (unless it's `str`, which does no conversion). Using a `StrEnum` subclass as a type-hint automatically populates `choices`. Using a `re.Pattern` typehint automatically uses `re.compile` as a converter. A custom type-converter can be used by passing `converter=...` to either `option()` or `positional()`. Come common utility converters are defined in [converters.py](arcparse/converters.py). +Automatic type conversions are supported. The type-hint is used in `type=...` in the background (unless it's `str`, which does no conversion). Using a `StrEnum` subclass as a type-hint automatically populates `choices`, using `Literal` also populates choices but does not set converter unlike `StrEnum`. Using a `re.Pattern` typehint automatically uses `re.compile` as a converter. A custom type-converter can be used by passing `converter=...` to either `option()` or `positional()`. Come common utility converters are defined in [converters.py](arcparse/converters.py). Custom converters may be used in combination with multiple values per argument. These converters are called `itemwise` and need to be wrapped in `itemwise()`. This wrapper is used automatically if an argument is typed as `list[...]` and no converter is set. ```py from arcparse.converters import sv, csv, sv_dict, itemwise from enum import StrEnum from re import Pattern +from typing import Literal @arcparser class Args: @@ -112,6 +115,7 @@ class Args: number: int result: Result + literal: Literal["yes", "no"] pattern: Pattern custom: Result = option(converter=Result.from_int) ints: list[int] = option(converter=csv(int)) @@ -120,12 +124,27 @@ class Args: results: list[Result] = option(converter=itemwise(Result.from_int)) ``` +### dict helpers +Sometimes creating an argument able to choose a value from a dict by its key is desired. `dict_option` and `dict_positional` do exactly that. In the following example passing `--foo yes` will result in `.foo` being `True`. +```py +from arcparse import dict_option + +values = { + "yes": True, + "no": False, +} + +@arcparser +class Args: + foo: bool = dict_option(values) +``` + ### Mutually exclusive groups Use `mx_group` to group multiple arguments together in a mutually exclusive group. Each argument has to have a default defined either implicitly through the type (being `bool` or a union with `None`) or explicitly with `default`. ```py @arcparser class Args: - group = MxGroup() # alternatively use `(group := MxGroup())` on the next line + group = mx_group() # alternatively use `(group := mx_group())` on the next line flag: bool = flag(mx_group=group) option: str | None = option(mx_group=group) ``` diff --git a/arcparse/__init__.py b/arcparse/__init__.py index f5b9719..b34af7b 100644 --- a/arcparse/__init__.py +++ b/arcparse/__init__.py @@ -1,14 +1,31 @@ -from ._arguments import MxGroup, flag, no_flag, option, positional -from ._parser import arcparser, subparsers +from ._argument_helpers import ( + dict_option, + dict_positional, + flag, + mx_group, + no_flag, + option, + positional, + subparsers, + tri_flag, +) +from ._parser import InvalidArgument, InvalidParser, InvalidTypehint, arcparser from .converters import itemwise + __all__ = [ "arcparser", "positional", "option", "flag", "no_flag", - "MxGroup", + "tri_flag", + "dict_positional", + "dict_option", + "mx_group", "subparsers", "itemwise", + "InvalidParser", + "InvalidArgument", + "InvalidTypehint", ] diff --git a/arcparse/_argument_helpers.py b/arcparse/_argument_helpers.py new file mode 100644 index 0000000..ee41d40 --- /dev/null +++ b/arcparse/_argument_helpers.py @@ -0,0 +1,159 @@ +from collections.abc import Callable, Collection +from typing import Any + +from arcparse.errors import InvalidArgument + +from ._arguments import Void, void +from ._partial_arguments import ( + PartialFlag, + PartialMxGroup, + PartialNoFlag, + PartialOption, + PartialPositional, + PartialSubparsers, + PartialTriFlag, +) + + +def positional[T]( + *, + default: T | str | Void = void, + choices: Collection[str] | None = None, + converter: Callable[[str], T] | None = None, + name_override: str | None = None, + at_least_one: bool = False, + mx_group: PartialMxGroup | None = None, + help: str | None = None, +) -> T: + return PartialPositional( + default=default, + choices=choices, + converter=converter, + name_override=name_override, + at_least_one=at_least_one, + mx_group=mx_group, + help=help, + ) # type: ignore + + +def option[T]( + short: str | None = None, + *, + short_only: bool = False, + default: T | str | Void = void, + choices: Collection[str] | None = None, + converter: Callable[[str], T] | None = None, + name_override: str | None = None, + append: bool = False, + at_least_one: bool = False, + mx_group: PartialMxGroup | None = None, + help: str | None = None, +) -> T: + if short_only and short is None: + raise ValueError("`short_only` cannot be True if `short` is not provided") + + if append and at_least_one: + raise ValueError("`append` is incompatible with `at_least_one`") + + return PartialOption( + short=short, + short_only=short_only, + default=default, + choices=choices, + converter=converter, + name_override=name_override, + append=append, + at_least_one=at_least_one, + mx_group=mx_group, + help=help, + ) # type: ignore + + +def flag( + short: str | None = None, + *, + short_only: bool = False, + mx_group: PartialMxGroup | None = None, + help: str | None = None, +) -> bool: + if short_only and short is None: + raise ValueError("`short_only` cannot be True if `short` is not provided") + return PartialFlag( + short=short, + short_only=short_only, + help=help, + mx_group=mx_group, + ) # type: ignore + + +def no_flag(*, mx_group: PartialMxGroup | None = None, help: str | None = None) -> bool: + return PartialNoFlag(mx_group=mx_group, help=help) # type: ignore + + +def tri_flag(mx_group: PartialMxGroup | None = None) -> bool | None: + return PartialTriFlag(mx_group=mx_group) # type: ignore + + +def mx_group(*, required: bool = False) -> PartialMxGroup: + return PartialMxGroup(required=required) + + +def subparsers(*args: str) -> Any: + return PartialSubparsers(names=list(args)) + + +def dict_positional[T]( + dict_: dict[str, T], + *, + default: T | Void = void, + name_override: str | None = None, + at_least_one: bool = False, + mx_group: PartialMxGroup | None = None, + help: str | None = None, +) -> T: + """Creates positional() from dict by pre-filling choices and converter""" + + if default is not void and default not in dict_.values(): + raise InvalidArgument("dict_positional default must be a value in dict") + + return positional( + default=default, + choices=list(dict_.keys()), + converter=dict_.__getitem__, + name_override=name_override, + at_least_one=at_least_one, + mx_group=mx_group, + help=help, + ) + + + +def dict_option[T]( + dict_: dict[str, T], + *, + short: str | None = None, + short_only: bool = False, + default: T | Void = void, + name_override: str | None = None, + append: bool = False, + at_least_one: bool = False, + mx_group: PartialMxGroup | None = None, + help: str | None = None, +) -> T: + """Creates option() from dict by pre-filling choices and converter""" + + if default is not void and default not in dict_.values(): + raise InvalidArgument("dict_positional default must be a value in dict") + + return option( + short=short, + short_only=short_only, + default=default, + choices=list(dict_.keys()), + converter=dict_.__getitem__, + name_override=name_override, + append=append, + at_least_one=at_least_one, + mx_group=mx_group, + help=help, + ) diff --git a/arcparse/_arguments.py b/arcparse/_arguments.py index 346c208..1041d88 100644 --- a/arcparse/_arguments.py +++ b/arcparse/_arguments.py @@ -1,15 +1,12 @@ from abc import ABC, abstractmethod -from argparse import _ActionsContainer -from collections.abc import Callable -from dataclasses import dataclass -from typing import Any, Literal, overload +from argparse import Action, _ActionsContainer, _MutuallyExclusiveGroup +from collections.abc import Callable, Collection +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any, Literal, Protocol -from ._typehints import ( - extract_collection_type, - extract_optional_type, - extract_type_from_typehint, -) -from .converters import itemwise + +if TYPE_CHECKING: + from ._parser import Parser class Void: @@ -18,140 +15,126 @@ class Void: void = Void() -@dataclass(kw_only=True, eq=False) -class MxGroup: - required: bool = False +class ContainerApplicable(Protocol): + def apply(self, actions_container: _ActionsContainer, name: str) -> Action: + ... -@dataclass(kw_only=True) -class _BaseArgument(ABC): - mx_group: MxGroup | None = None - help: str | None = None +class BaseSingleArgument(ContainerApplicable, ABC): + def apply(self, actions_container: _ActionsContainer, name: str) -> Action: + args = self.get_argparse_args(name) + kwargs = self.get_argparse_kwargs(name) + return actions_container.add_argument(*args, **kwargs) - def apply(self, actions_container: _ActionsContainer, name: str, typehint: type) -> None: - args = self.get_argparse_args(name, typehint) - kwargs = self.get_argparse_kwargs(name, typehint) - actions_container.add_argument(*args, **kwargs) + @abstractmethod + def get_argparse_args(self, name: str) -> list[str]: + ... @abstractmethod - def get_argparse_args(self, name: str, typehint: type) -> list[str]: + def get_argparse_kwargs(self, name: str) -> dict[str, Any]: ... - def get_argparse_kwargs(self, name: str, typehint: type) -> dict[str, Any]: + +@dataclass(kw_only=True) +class BaseArgument(BaseSingleArgument): + help: str | None = None + + def get_argparse_kwargs(self, name: str) -> dict[str, Any]: kwargs = {} if self.help is not None: kwargs["help"] = self.help return kwargs -@dataclass(kw_only=True) -class _BaseValueArgument[T](_BaseArgument): - default: T | Void = void - choices: list[T] | None = None - converter: Callable[[str], T] | None = None - name_override: str | None = None - at_least_one: bool = False - - def need_multiple(self, typehint: type) -> bool: - return ( - (self.converter is None and extract_collection_type(typehint) is not None) - or isinstance(self.converter, itemwise) - ) - - def get_argparse_kwargs(self, name: str, typehint: type) -> dict[str, Any]: - kwargs = super().get_argparse_kwargs(name, typehint) - - if self.converter is None: - type_ = extract_type_from_typehint(typehint) - if type_ is not str: - if extract_collection_type(typehint): - self.converter = itemwise(type_) # type: ignore (list[T@itemwise] somehow incompatible with T@_BaseValueArgument) - else: - self.converter = type_ - - if self.converter is not None: - kwargs["type"] = self.converter - if self.default is not void: - kwargs["default"] = self.default - if self.choices is not None: - kwargs["choices"] = self.choices +@dataclass +class Flag(BaseArgument): + short: str | None = None + short_only: bool = False - if self.need_multiple(typehint) and not self.at_least_one and self.default is void: - kwargs["default"] = [] + def get_argparse_args(self, name: str) -> list[str]: + args = [f"--{name.replace("_", "-")}"] + if self.short_only: + assert self.short is not None + return [self.short] + elif self.short is not None: + args.insert(0, self.short) - return kwargs + return args + def get_argparse_kwargs(self, name: str) -> dict[str, Any]: + kwargs = super().get_argparse_kwargs(name) + kwargs["action"] = "store_true" -@dataclass -class _Positional[T](_BaseValueArgument[T]): - def get_argparse_args(self, name: str, typehint: type) -> list[str]: - return [name] + if self.short_only: + kwargs["dest"] = name + return kwargs - def get_argparse_kwargs(self, name: str, typehint: type) -> dict[str, Any]: - kwargs = super().get_argparse_kwargs(name, typehint) - type_is_optional = bool(extract_optional_type(typehint)) - type_is_collection = bool(extract_collection_type(typehint)) - optional = type_is_optional or type_is_collection or self.default is not void +@dataclass +class NoFlag(BaseArgument): + def get_argparse_args(self, name: str) -> list[str]: + return [f"--no-{name.replace("_", "-")}"] - if self.name_override is not None: - kwargs["metavar"] = self.name_override + def get_argparse_kwargs(self, name: str) -> dict[str, Any]: + kwargs = super().get_argparse_kwargs(name) + kwargs["action"] = "store_false" - if self.need_multiple(typehint): - kwargs["nargs"] = "+" if self.at_least_one else "*" - kwargs["metavar"] = self.name_override if self.name_override is not None else name.upper() - elif optional: - kwargs["nargs"] = "?" + kwargs["dest"] = name return kwargs -@dataclass -class _Option[T](_BaseValueArgument[T]): - short: str | None = None - short_only: bool = False - append: bool = False +class TriFlag(ContainerApplicable): + def apply(self, actions_container: _ActionsContainer, name: str) -> None: + # if actions_container is not an mx group, make it one, argparse + # doesn't support mx group nesting + if not isinstance(actions_container, _MutuallyExclusiveGroup): + actions_container = actions_container.add_mutually_exclusive_group() - def get_argparse_args(self, name: str, typehint: type) -> list[str]: - name = self.name_override if self.name_override is not None else name.replace("_", "-") - args = [f"--{name}"] - if self.short_only: - assert self.short is not None - return [self.short] - elif self.short is not None: - args.insert(0, self.short) + name = name.replace("_", "-") + actions_container.add_argument(f"--{name}", action="store_true") + actions_container.add_argument(f"--no-{name}", action="store_true") - return args - def get_argparse_kwargs(self, name: str, typehint: type) -> dict[str, Any]: - kwargs = super().get_argparse_kwargs(name, typehint) - if self.need_multiple(typehint): - if self.append: - kwargs["action"] = "append" - else: - kwargs["nargs"] = "+" if self.at_least_one else "*" +@dataclass(kw_only=True) +class BaseValueArgument[T](BaseArgument): + default: T | str | Void = void + converter: Callable[[str], T] | None = None + choices: Collection[T] | None = None + nargs: Literal["?", "*", "+"] | None = None + metavar: str | None = None - if self.name_override is not None: - kwargs["dest"] = name - kwargs["metavar"] = self.name_override.replace("-", "_").upper() - elif self.short_only: - kwargs["dest"] = name + def get_argparse_kwargs(self, name: str) -> dict[str, Any]: + kwargs = super().get_argparse_kwargs(name) - type_is_optional = bool(extract_optional_type(typehint)) - type_is_collection = bool(extract_collection_type(typehint)) - required = (not (type_is_optional or type_is_collection) and self.default is void) or self.at_least_one - if required: - kwargs["required"] = True + if self.default is not void: + kwargs["default"] = self.default + if self.choices is not None: + kwargs["choices"] = self.choices + if self.nargs is not None: + kwargs["nargs"] = self.nargs + if self.metavar is not None: + kwargs["metavar"] = self.metavar return kwargs @dataclass -class _Flag(_BaseArgument): +class Positional[T](BaseValueArgument[T]): + def get_argparse_args(self, name: str) -> list[str]: + return [name] + + +@dataclass +class Option[T](BaseValueArgument[T]): + name_override: str | None = None short: str | None = None short_only: bool = False + required: bool = False + append: bool = False - def get_argparse_args(self, name: str, typehint: type) -> list[str]: - args = [f"--{name.replace("_", "-")}"] + def get_argparse_args(self, name: str) -> list[str]: + name = self.name_override if self.name_override is not None else name.replace("_", "-") + args = [f"--{name}"] if self.short_only: assert self.short is not None return [self.short] @@ -160,164 +143,26 @@ def get_argparse_args(self, name: str, typehint: type) -> list[str]: return args - def get_argparse_kwargs(self, name: str, typehint: type) -> dict[str, Any]: - kwargs = super().get_argparse_kwargs(name, typehint) - kwargs["action"] = "store_true" + def get_argparse_kwargs(self, name: str) -> dict[str, Any]: + kwargs = super().get_argparse_kwargs(name) - if self.short_only: + if self.name_override is not None or self.short_only: kwargs["dest"] = name - return kwargs - + if self.required: + kwargs["required"] = True + if self.append: + kwargs["action"] = "append" -@dataclass -class _NoFlag(_BaseArgument): - def get_argparse_args(self, name: str, typehint: type) -> list[str]: - return [f"--no-{name.replace("_", "-")}"] + return kwargs - def get_argparse_kwargs(self, name: str, typehint: type) -> dict[str, Any]: - kwargs = super().get_argparse_kwargs(name, typehint) - kwargs["action"] = "store_false" - kwargs["dest"] = name - return kwargs +@dataclass(eq=False) +class MxGroup: + arguments: dict[str, BaseArgument] = field(default_factory=dict) + required: bool = False -@overload -def positional[T]( - *, - default: T | Void = void, - choices: list[T] | None = None, - converter: Callable[[str], T] | None = None, - name_override: str | None = None, - at_least_one: Literal[False] = False, - mx_group: MxGroup | None = None, - help: str | None = None, -) -> T: ... - -@overload -def positional[T]( - *, - default: list[T] | Void = void, - choices: list[T] | None = None, - converter: Callable[[str], list[T]] | None = None, - name_override: str | None = None, - at_least_one: Literal[True] = True, - mx_group: MxGroup | None = None, - help: str | None = None, -) -> list[T]: ... - -def positional( # type: ignore - *, - default=void, - choices=None, - converter=None, - name_override=None, - at_least_one=False, - mx_group=None, - help=None, -): - return _Positional( - default=default, - choices=choices, - converter=converter, - name_override=name_override, - at_least_one=at_least_one, - mx_group=mx_group, - help=help, - ) - - -@overload -def option[T]( - short: str | None = None, - *, - short_only: bool = False, - default: T | Void = void, - choices: list[T] | None = None, - converter: Callable[[str], T] | None = None, - name_override: str | None = None, - append: Literal[False] = False, - at_least_one: Literal[False] = False, - mx_group: MxGroup | None = None, - help: str | None = None, -) -> T: ... - - -@overload -def option[T]( - short: str | None = None, - *, - short_only: bool = False, - default: list[T] | Void = void, - choices: list[T] | None = None, - converter: Callable[[str], list[T]] | None = None, - name_override: str | None = None, - append: Literal[True] = True, - at_least_one: Literal[False] = False, - mx_group: MxGroup | None = None, - help: str | None = None, -) -> list[T]: ... - -@overload -def option[T]( - short: str | None = None, - *, - short_only: bool = False, - default: list[T] | Void = void, - choices: list[T] | None = None, - converter: Callable[[str], list[T]] | None = None, - name_override: str | None = None, - append: Literal[False] = False, - at_least_one: Literal[True] = True, - mx_group: MxGroup | None = None, - help: str | None = None, -) -> list[T]: ... - -def option( # type: ignore - short=None, - *, - short_only=False, - default=void, - choices=None, - converter=None, - name_override=None, - append=False, - at_least_one=False, - mx_group=None, - help=None, -): - if short_only and short is None: - raise Exception("`short_only` cannot be True if `short` is not provided") - return _Option( - short=short, - short_only=short_only, - default=default, - choices=choices, - converter=converter, - name_override=name_override, - append=append, - at_least_one=at_least_one, - mx_group=mx_group, - help=help, - ) - - -def flag( - short: str | None = None, - *, - short_only: bool = False, - mx_group: MxGroup | None = None, - help: str | None = None, -) -> bool: - if short_only and short is None: - raise Exception("`short_only` cannot be True if `short` is not provided") - return _Flag( - short=short, - short_only=short_only, - help=help, - mx_group=mx_group, - ) # type: ignore - - -def no_flag(*, mx_group: MxGroup | None = None, help: str | None = None) -> bool: - return _NoFlag(mx_group=mx_group, help=help) # type: ignore +@dataclass +class Subparsers: + sub_parsers: dict[str, "Parser"] = field(default_factory=dict) + required: bool = False diff --git a/arcparse/_parser.py b/arcparse/_parser.py index f71977c..825d2cb 100644 --- a/arcparse/_parser.py +++ b/arcparse/_parser.py @@ -1,24 +1,104 @@ -from __future__ import annotations - -from argparse import ArgumentParser, _ActionsContainer -from collections.abc import Sequence -from dataclasses import dataclass -from enum import StrEnum +from collections.abc import Iterator, Sequence +from dataclasses import dataclass, field from types import NoneType, UnionType from typing import Any, Union, get_args, get_origin +import argparse import inspect -import re - -from ._arguments import MxGroup, _BaseArgument, _BaseValueArgument, _Flag, _Option, void -from ._typehints import ( - extract_collection_type, - extract_optional_type, - extract_subparsers_from_typehint, - extract_type_from_typehint, + +from arcparse.converters import itemwise + +from ._arguments import ( + BaseArgument, + BaseValueArgument, + MxGroup, + Subparsers, + TriFlag, + void, +) +from ._partial_arguments import ( + BasePartialArgument, + PartialFlag, + PartialMxGroup, + PartialOption, + PartialSubparsers, + PartialTriFlag, ) -from .converters import itemwise +from ._typehints import extract_optional_type, extract_subparsers_from_typehint +from .errors import InvalidArgument, InvalidParser, InvalidTypehint + + +@dataclass +class Parser[T]: + shape: type[T] + arguments: dict[str, BaseArgument] = field(default_factory=dict) + mx_groups: list[MxGroup] = field(default_factory=list) + + @property + def all_arguments(self) -> Iterator[tuple[str, BaseArgument]]: + yield from self.arguments.items() + for mx_group in self.mx_groups: + yield from mx_group.arguments.items() + + def apply(self, actions_container: argparse._ActionsContainer) -> None: + for name, argument in self.arguments.items(): + argument.apply(actions_container, name) + + for mx_group in self.mx_groups: + group = actions_container.add_mutually_exclusive_group(required=mx_group.required) + for name, argument in mx_group.arguments.items(): + argument.apply(group, name) -type NameTypeArg[TArg: "_BaseArgument | _Subparsers"] = tuple[str, type, TArg] + +@dataclass +class RootParser[T]: + parser: Parser + subparsers: tuple[str, Subparsers] | None = None + + def parse(self, args: Sequence[str] | None = None) -> T: + ap_parser = argparse.ArgumentParser() + self.parser.apply(ap_parser) + if self.subparsers is not None: + name, subparsers = self.subparsers + ap_subparsers = ap_parser.add_subparsers(dest=name, required=subparsers.required) + for name, subparser in subparsers.sub_parsers.items(): + ap_subparser = ap_subparsers.add_parser(name) + subparser.apply(ap_subparser) + + parsed = ap_parser.parse_args(args) + + ret = parsed.__dict__ + if self.subparsers is not None: + name, subparsers = self.subparsers + + # optional subparsers will result in `dict[name]` being `None` + if chosen_subparser := getattr(parsed, name, None): + sub_parser = subparsers.sub_parsers[chosen_subparser] + ret[name] = _construct_object_with_parsed(sub_parser, ret) + + return _construct_object_with_parsed(self.parser, ret) + + +def _construct_object_with_parsed[T](parser: Parser[T], parsed: dict[str, Any]) -> T: + # apply argument converters + for name, argument in parser.all_arguments: + if not isinstance(argument, BaseValueArgument) or argument.converter is None: + continue + + value = parsed.get(name, argument.default) + if isinstance(argument.converter, itemwise): + assert isinstance(value, list) + parsed[name] = [ + argument.converter(item) if isinstance(item, str) else item + for item in value + ] + else: + parsed[name] = argument.converter(value) if isinstance(value, str) else value + + # reduce tri_flags + tri_flag_names = [name for name, arg in parser.all_arguments if isinstance(arg, TriFlag)] + _reduce_tri_flags(parsed, tri_flag_names) + + return _instantiate_from_dict(parser.shape, parsed) def _instantiate_from_dict[T](cls: type[T], dict_: dict[str, Any]) -> T: @@ -32,7 +112,23 @@ def _instantiate_from_dict[T](cls: type[T], dict_: dict[str, Any]) -> T: return obj -def _collect_arguments(cls: type) -> list[NameTypeArg[_BaseArgument]]: +def _reduce_tri_flags(dict_: dict[str, Any], tri_flag_names: list[str]) -> None: + for name in tri_flag_names: + no_flag_name = f"no_{name}" + yes_case = dict_[name] + no_case = dict_[no_flag_name] + assert not yes_case or not no_case + + del dict_[no_flag_name] + if yes_case: + dict_[name] = True + elif no_case: + dict_[name] = False + else: + dict_[name] = None + + +def _collect_partial_arguments(cls: type) -> dict[str, tuple[type, BasePartialArgument]]: # collect declared typehints all_params: dict[str, tuple[type, Any]] = { name: (typehint, void) @@ -47,177 +143,92 @@ def _collect_arguments(cls: type) -> list[NameTypeArg[_BaseArgument]]: # ignore untyped class variables un if key not in all_params: - if isinstance(value, _BaseArgument): - raise Exception(f"Argument {key} is missing a type-hint and would be ignored") + if isinstance(value, BasePartialArgument): + raise InvalidTypehint(f"Argument {key} is missing a type-hint and would be ignored") continue typehint, _ = all_params[key] all_params[key] = (typehint, value) # construct arguments - arguments: list[NameTypeArg] = [] + arguments: dict[str, tuple[type, BasePartialArgument]] = {} for name, (typehint, value) in all_params.items(): - if isinstance(value, _Subparsers): + if isinstance(value, PartialSubparsers): continue if get_origin(typehint) in {Union, UnionType}: union_args = get_args(typehint) if len(union_args) > 2 or NoneType not in union_args: - raise Exception("Union can be used only for optional arguments (length of 2, 1 of them being None)") + raise InvalidTypehint("Union can be used only for optional arguments (length of 2, 1 of them being None)") - if isinstance(value, _BaseArgument): + if isinstance(value, BasePartialArgument): argument = value + elif typehint is bool: + if value is not void: + raise InvalidArgument("defaults don't make sense for flags") + argument = PartialFlag() + elif extract_optional_type(typehint) == bool: + argument = PartialTriFlag() else: - argument = _construct_argument(typehint, value) - - arguments.append((name, typehint, argument)) + argument = PartialOption(default=value) + arguments[name] = (typehint, argument) return arguments -def _construct_argument(typehint: type, default: Any) -> _BaseArgument: - if typehint is bool: - if default is not void: - raise Exception("defaults don't make sense for flags") - return _Flag() - - actual_type = extract_type_from_typehint(typehint) - if actual_type is bool: - raise Exception("Can't construct argument with inner type bool, conversion would be always True") - elif getattr(actual_type, "_is_protocol", False): - raise Exception("Argument with no converter can't be typed as a Protocol subclass") - - if type_ := extract_collection_type(typehint): - converter = itemwise(type_) - else: - converter = actual_type - - if issubclass(actual_type, StrEnum): - return _Option(default=default, choices=list(actual_type), converter=converter) - elif actual_type == re.Pattern: - return _Option(default=default, converter=re.compile) - - return _Option(default=default, converter=converter if actual_type is not str else None) - - -def _check_argument_sanity(name: str, typehint: type, arg: _BaseArgument) -> None: - if isinstance(arg, _BaseValueArgument) and extract_type_from_typehint(typehint) is bool and arg.converter is None: - raise Exception(f"Argument \"{name}\" yielding a value can't be typed as `bool`") - - if arg.mx_group is not None and isinstance(arg, _BaseValueArgument) and extract_optional_type(typehint) is None and arg.default is void: - raise Exception(f"Argument \"{name}\" in mutually exclusive group has to have a default") - - -def _collect_subparsers(cls: type) -> NameTypeArg[_Subparsers] | None: - all_subparsers = [(key, value) for key, value in vars(cls).items() if isinstance(value, _Subparsers)] +def _collect_subparsers(shape: type) -> tuple[str, type, PartialSubparsers] | None: + all_subparsers = [(key, value) for key, value in vars(shape).items() if isinstance(value, PartialSubparsers)] if not all_subparsers: return None elif len(all_subparsers) > 1: - raise Exception(f"Multiple subparsers definitions found on {cls}") - - name, subparsers = all_subparsers[0] - if not (typehint := inspect.get_annotations(cls, eval_str=True)[name]): - raise Exception("subparsers have to be type-hinted") + raise InvalidParser(f"Multiple subparsers definitions found on {shape}") - if not extract_subparsers_from_typehint(typehint): - raise Exception(f"Unable to extract subparser types from {typehint}, expected a non-empty union of ArcParser types") + name, partial_subparsers = all_subparsers[0] + if not (typehint := inspect.get_annotations(shape, eval_str=True).get(name)): + raise InvalidTypehint("subparsers have to be type-hinted") - return (name, typehint, subparsers) + return name, typehint, partial_subparsers -@dataclass -class _Subparsers: - names: list[str] - - def apply( - self, - parser: ArgumentParser, - name: str, - typehint: type, - ) -> None: - if not (subparser_types := extract_subparsers_from_typehint(typehint)): - raise Exception(f"Unable to extract subparser types from {typehint}, expected a non-empty union of ArcParser types") - - subparsers_kwargs: dict = {"dest": name} - if NoneType not in subparser_types: - subparsers_kwargs["required"] = True - subparsers = parser.add_subparsers(**subparsers_kwargs) - - nonnull_subparser_types: list[type[_Parser]] = [ - typ for typ in subparser_types if typ is not NoneType - ] # type: ignore (NoneType is getting confused with None) - - for name, subparser_type in zip(self.names, nonnull_subparser_types): - subparser = _make_parser(subparser_type) - subparser.apply(subparsers.add_parser(name)) - - -class _Parser[T]: - def __init__( - self, - cls: type[T], - arguments: list[NameTypeArg[_BaseArgument]] | None = None, - mx_groups: dict[MxGroup, list[NameTypeArg[_BaseArgument]]] | None = None, - subparsers: NameTypeArg[_Subparsers] | None = None, - ): - self._cls = cls - self._arguments = arguments if arguments is not None else {} - self._mx_groups = mx_groups if mx_groups is not None else {} - self._subparsers = subparsers - - assert all(arg.mx_group is None for _, _, arg in self._arguments) +def _make_parser[T](shape: type[T]) -> Parser[T]: + arguments = {} + mx_groups: dict[PartialMxGroup, MxGroup] = {} + for name, (typehint, partial_argument) in _collect_partial_arguments(shape).items(): + mx_group = partial_argument.mx_group + argument = partial_argument.resolve_with_typehint(typehint) - def parse(self, args: Sequence[str] | None = None) -> T: - parser = ArgumentParser() - self.apply(parser) - if self._subparsers is not None: - name, typehint, subparsers = self._subparsers - subparsers.apply(parser, name, typehint) - parsed = parser.parse_args(args) - - ret = parsed.__dict__.copy() - if self._subparsers is not None: - name, typehint, subparsers = self._subparsers - - # optional subparsers will result in `dict[name]` being `None` - subshape_classes = extract_subparsers_from_typehint(typehint) - assert subshape_classes is not None - if chosen_subparser := getattr(parsed, name): - subshape_class = subshape_classes[subparsers.names.index(chosen_subparser)] - ret[name] = _instantiate_from_dict(subshape_class, parsed.__dict__) - - return _instantiate_from_dict(self._cls, ret) - - def apply(self, actions_container: _ActionsContainer) -> None: - for (name, typehint, argument) in self._arguments: - argument.apply(actions_container, name, typehint) - - for mx_group, arguments in self._mx_groups.items(): - group = actions_container.add_mutually_exclusive_group(required=mx_group.required) - for name, typehint, arg in arguments: - arg.apply(group, name, typehint) - - -def _make_parser[T](cls: type[T]) -> _Parser[T]: - arguments = _collect_arguments(cls) - subparsers = _collect_subparsers(cls) - - for name, type, arg in arguments: - _check_argument_sanity(name, type, arg) - - mx_groups: dict[MxGroup, list[NameTypeArg[_BaseArgument]]] = {} - for argument in arguments: - if argument[2].mx_group is not None: - mx_groups.setdefault(argument[2].mx_group, []).append(argument) - arguments = [argument for argument in arguments if argument[2].mx_group is None] - - return _Parser(cls, arguments, mx_groups, subparsers) - - -def arcparser[T](cls: type[T]) -> _Parser[T]: - return _make_parser(cls) - - -def subparsers(*args: str) -> Any: - return _Subparsers(names=list(args)) + if mx_group is None: + arguments[name] = argument + else: + if mx_group not in mx_groups: + mx_groups[mx_group] = MxGroup(required=mx_group.required) + mx_groups[mx_group].arguments[name] = argument + + return Parser( + shape, + arguments, + list(mx_groups.values()), + ) + + +def _make_root_parser[T](shape: type[T]) -> RootParser[T]: + match _collect_subparsers(shape): + case (name, typehint, partial_subparsers): + subshapes = extract_subparsers_from_typehint(typehint) + subparsers_by_name = { + name: _make_parser(subshape) + for name, subshape in zip(partial_subparsers.names, subshapes) + } + subparsers = (name, Subparsers(subparsers_by_name, required=NoneType not in subshapes)) + case _: + subparsers = None + + return RootParser( + _make_parser(shape), + subparsers, + ) + + +def arcparser[T](shape: type[T]) -> RootParser[T]: + return _make_root_parser(shape) diff --git a/arcparse/_partial_arguments.py b/arcparse/_partial_arguments.py new file mode 100644 index 0000000..58bf2c0 --- /dev/null +++ b/arcparse/_partial_arguments.py @@ -0,0 +1,196 @@ +from abc import ABC, abstractmethod +from collections.abc import Callable, Collection +from dataclasses import dataclass +from typing import Any, Literal, get_origin +import re + +from arcparse.errors import InvalidArgument, InvalidTypehint, MissingConverter + +from ._arguments import ( + BaseValueArgument, + ContainerApplicable, + Flag, + NoFlag, + Option, + Positional, + TriFlag, + Void, + void, +) +from ._typehints import ( + extract_collection_type, + extract_literal_strings, + extract_optional_type, + extract_type_from_typehint, +) +from .converters import itemwise + + +@dataclass(kw_only=True, eq=False) +class PartialMxGroup: + required: bool = False + + +@dataclass(kw_only=True) +class BasePartialArgument[R: ContainerApplicable](ABC): + mx_group: PartialMxGroup | None = None + + @abstractmethod + def resolve_with_typehint(self, typehint: type) -> R: + ... + + def resolve_to_kwargs(self, typehint: type) -> dict[str, Any]: + return {} + + +@dataclass(kw_only=True) +class BaseSinglePartialArgument[R: ContainerApplicable](BasePartialArgument[R]): + help: str | None = None + + def resolve_to_kwargs(self, typehint: type) -> dict[str, Any]: + return super().resolve_to_kwargs(typehint) | { + "help": self.help, + } + + +@dataclass(kw_only=True) +class BasePartialValueArgument[T, R: BaseValueArgument](BaseSinglePartialArgument[R]): + default: T | str | Void = void + choices: Collection[str] | None = None + converter: Callable[[str], T] | None = None + name_override: str | None = None + at_least_one: bool = False + + def resolve_to_kwargs(self, typehint: type) -> dict[str, Any]: + kwargs = super().resolve_to_kwargs(typehint) + + type_ = extract_type_from_typehint(typehint) + if self.converter is None: + if type_ is bool: + raise InvalidTypehint("Arguments yielding a value cannot be typed as `bool`") + elif getattr(type_, "_is_protocol", False): + raise MissingConverter("Argument with no converter can't be typed as a Protocol subclass") + + if type_ is not str and get_origin(type_) != Literal: + if extract_collection_type(typehint): + self.converter = itemwise(type_) # type: ignore (list[T@itemwise] somehow incompatible with T@_BaseValueArgument) + elif type_ == re.Pattern: + self.converter = re.compile # type: ignore (somehow incompatible) + else: + self.converter = type_ + + choices = self.choices + if literal_choices := extract_literal_strings(type_): + if self.choices is None: + choices = literal_choices + elif not (set(self.choices) <= set(literal_choices)): + raise InvalidArgument("explicit choices have to be a subset of target literal typehint") + + if self.converter is not None: + kwargs["converter"] = self.converter + if self.default is not void: + kwargs["default"] = self.default + if choices is not None: + kwargs["choices"] = choices + + if self.need_multiple(typehint) and not self.at_least_one and self.default is void: + kwargs["default"] = [] + + return kwargs + + def need_multiple(self, typehint: type) -> bool: + return ( + (self.converter is None and extract_collection_type(typehint) is not None) + or isinstance(self.converter, itemwise) + ) + + +@dataclass +class PartialPositional[T](BasePartialValueArgument[T, Positional]): + def resolve_with_typehint(self, typehint: type) -> Positional: + kwargs = self.resolve_to_kwargs(typehint) + return Positional(**kwargs) + + def resolve_to_kwargs(self, typehint: type) -> dict[str, Any]: + kwargs = super().resolve_to_kwargs(typehint) + + type_is_optional = bool(extract_optional_type(typehint)) + type_is_collection = bool(extract_collection_type(typehint)) + optional = type_is_optional or type_is_collection or self.default is not void + if not optional and self.mx_group is not None: + raise InvalidArgument("Arguments in mutually exclusive group have to have a default") + + if self.name_override is not None: + kwargs["metavar"] = self.name_override + + if self.need_multiple(typehint): + kwargs["nargs"] = "+" if self.at_least_one else "*" + kwargs["metavar"] = self.name_override + elif optional: + kwargs["nargs"] = "?" + + return kwargs + + +@dataclass +class PartialOption[T](BasePartialValueArgument[T, Option]): + short: str | None = None + short_only: bool = False + append: bool = False + + def resolve_with_typehint(self, typehint: type) -> Option: + kwargs = self.resolve_to_kwargs(typehint) + return Option(**kwargs) + + def resolve_to_kwargs(self, typehint: type) -> dict[str, Any]: + kwargs = super().resolve_to_kwargs(typehint) + kwargs["short"] = self.short + kwargs["short_only"] = self.short_only + + if self.need_multiple(typehint): + if self.append: + kwargs["append"] = True + else: + kwargs["nargs"] = "+" if self.at_least_one else "*" + + if self.name_override is not None: + kwargs["metavar"] = self.name_override.replace("-", "_").upper() + + type_is_optional = bool(extract_optional_type(typehint)) + type_is_collection = bool(extract_collection_type(typehint)) + required = (not (type_is_optional or type_is_collection) and self.default is void) or self.at_least_one + if required: + if self.mx_group is not None: + raise InvalidArgument("Arguments in mutually exclusive group have to have a default") + kwargs["required"] = True + + return kwargs + + +@dataclass +class PartialFlag(BaseSinglePartialArgument[Flag]): + short: str | None = None + short_only: bool = False + + def resolve_with_typehint(self, typehint: type) -> Flag: + kwargs = self.resolve_to_kwargs(typehint) + kwargs["short"] = self.short + kwargs["short_only"] = self.short_only + return Flag(**kwargs) + + +@dataclass +class PartialNoFlag(BaseSinglePartialArgument[NoFlag]): + def resolve_with_typehint(self, typehint: type) -> NoFlag: + kwargs = self.resolve_to_kwargs(typehint) + return NoFlag(**kwargs) + + +class PartialTriFlag(BasePartialArgument[TriFlag]): + def resolve_with_typehint(self, typehint: type) -> TriFlag: + return TriFlag() + + +@dataclass +class PartialSubparsers: + names: list[str] diff --git a/arcparse/_typehints.py b/arcparse/_typehints.py index b9d2ef3..e0c2315 100644 --- a/arcparse/_typehints.py +++ b/arcparse/_typehints.py @@ -1,5 +1,7 @@ from types import NoneType, UnionType -from typing import Optional, Union, get_args, get_origin +from typing import Literal, Optional, Union, get_args, get_origin + +from arcparse.errors import InvalidTypehint def extract_optional_type(typehint: type) -> type | None: @@ -23,11 +25,11 @@ def extract_collection_type(typehint: type) -> type | None: return None -def extract_subparsers_from_typehint(typehint: type) -> list[type] | None: +def extract_subparsers_from_typehint(typehint: type) -> list[type]: origin = get_origin(typehint) if origin in {Union, UnionType}: return list(get_args(typehint)) - return None + raise InvalidTypehint(f"Unable to extract subparser types from {typehint}, expected a non-empty union of ArcParser types") def extract_type_from_typehint(typehint: type) -> type: @@ -36,3 +38,15 @@ def extract_type_from_typehint(typehint: type) -> type: elif collection_type := extract_collection_type(typehint): return collection_type return typehint + + +def extract_literal_strings(typehint: type) -> list[str] | None: + origin = get_origin(typehint) + if origin != Literal: + return None + + args = get_args(typehint) + if not all(isinstance(arg, str) for arg in args): + return None + + return list(args) diff --git a/arcparse/converters.py b/arcparse/converters.py index bb41753..fe17e15 100644 --- a/arcparse/converters.py +++ b/arcparse/converters.py @@ -16,6 +16,9 @@ def __init__(self, converter: Callable[[str], T]) -> None: def __call__(self, string: str) -> list[T]: return self._converter(string) # type: ignore + def __repr__(self) -> str: + return f"itemwise({self._converter})" + def sv[T](separator: str, type_: type[T] = str, /) -> Callable[[str], list[T]]: def conv(arg: str) -> list[T]: diff --git a/arcparse/errors.py b/arcparse/errors.py new file mode 100644 index 0000000..cd1a831 --- /dev/null +++ b/arcparse/errors.py @@ -0,0 +1,16 @@ + + +class InvalidParser(Exception): + pass + + +class InvalidArgument(InvalidParser): + pass + + +class InvalidTypehint(InvalidArgument): + pass + + +class MissingConverter(InvalidArgument): + pass diff --git a/poetry.lock b/poetry.lock index 57f451c..6c5f27b 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.7.0 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand. [[package]] name = "black" @@ -208,6 +208,32 @@ pluggy = ">=0.12,<2.0" [package.extras] testing = ["argcomplete", "attrs (>=19.2.0)", "hypothesis (>=3.56)", "mock", "nose", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"] +[[package]] +name = "ruff" +version = "0.1.9" +description = "An extremely fast Python linter and code formatter, written in Rust." +optional = false +python-versions = ">=3.7" +files = [ + {file = "ruff-0.1.9-py3-none-macosx_10_12_x86_64.macosx_11_0_arm64.macosx_10_12_universal2.whl", hash = "sha256:e6a212f436122ac73df851f0cf006e0c6612fe6f9c864ed17ebefce0eff6a5fd"}, + {file = "ruff-0.1.9-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:28d920e319783d5303333630dae46ecc80b7ba294aeffedf946a02ac0b7cc3db"}, + {file = "ruff-0.1.9-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:104aa9b5e12cb755d9dce698ab1b97726b83012487af415a4512fedd38b1459e"}, + {file = "ruff-0.1.9-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:1e63bf5a4a91971082a4768a0aba9383c12392d0d6f1e2be2248c1f9054a20da"}, + {file = "ruff-0.1.9-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4d0738917c203246f3e275b37006faa3aa96c828b284ebfe3e99a8cb413c8c4b"}, + {file = "ruff-0.1.9-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:69dac82d63a50df2ab0906d97a01549f814b16bc806deeac4f064ff95c47ddf5"}, + {file = "ruff-0.1.9-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2aec598fb65084e41a9c5d4b95726173768a62055aafb07b4eff976bac72a592"}, + {file = "ruff-0.1.9-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:744dfe4b35470fa3820d5fe45758aace6269c578f7ddc43d447868cfe5078bcb"}, + {file = "ruff-0.1.9-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:479ca4250cab30f9218b2e563adc362bd6ae6343df7c7b5a7865300a5156d5a6"}, + {file = "ruff-0.1.9-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:aa8344310f1ae79af9ccd6e4b32749e93cddc078f9b5ccd0e45bd76a6d2e8bb6"}, + {file = "ruff-0.1.9-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:837c739729394df98f342319f5136f33c65286b28b6b70a87c28f59354ec939b"}, + {file = "ruff-0.1.9-py3-none-musllinux_1_2_i686.whl", hash = "sha256:e6837202c2859b9f22e43cb01992373c2dbfeae5c0c91ad691a4a2e725392464"}, + {file = "ruff-0.1.9-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:331aae2cd4a0554667ac683243b151c74bd60e78fb08c3c2a4ac05ee1e606a39"}, + {file = "ruff-0.1.9-py3-none-win32.whl", hash = "sha256:8151425a60878e66f23ad47da39265fc2fad42aed06fb0a01130e967a7a064f4"}, + {file = "ruff-0.1.9-py3-none-win_amd64.whl", hash = "sha256:c497d769164df522fdaf54c6eba93f397342fe4ca2123a2e014a5b8fc7df81c7"}, + {file = "ruff-0.1.9-py3-none-win_arm64.whl", hash = "sha256:0e17f53bcbb4fff8292dfd84cf72d767b5e146f009cccd40c2fad27641f8a7a9"}, + {file = "ruff-0.1.9.tar.gz", hash = "sha256:b041dee2734719ddbb4518f762c982f2e912e7f28b8ee4fe1dee0b15d1b6e800"}, +] + [[package]] name = "setuptools" version = "68.2.2" @@ -227,4 +253,4 @@ testing-integration = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "jar [metadata] lock-version = "2.0" python-versions = "^3.12" -content-hash = "1eb854ba2a09da1231f9b02cd14523dd22038c30ca25e63bc910d54b40ca38fc" +content-hash = "28ba4a494627646b95f276bf778ca79d266c7a34046487e571f82ecbc3adbb0f" diff --git a/pyproject.toml b/pyproject.toml index 9a11267..56bea9d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,10 +25,12 @@ black = "^23.11.0" pyright = "^1.1.335" pytest = "^7.4.3" isort = "^5.12.0" +ruff = "^0.1.9" [tool.isort] profile = "black" from_first = "true" +lines_after_imports = 2 [build-system] requires = ["poetry-core"] diff --git a/tests/test_auto_converter.py b/tests/test_auto_converter.py index 83c2df9..1c62e6b 100644 --- a/tests/test_auto_converter.py +++ b/tests/test_auto_converter.py @@ -1,10 +1,11 @@ from enum import StrEnum, auto -from typing import Any +from typing import Any, Literal import re import pytest -from arcparse import arcparser +from arcparse import arcparser, option, positional +from arcparse.errors import InvalidArgument class Result(StrEnum): @@ -15,22 +16,30 @@ class Result(StrEnum): @arcparser class Args: num: int | None + num_default: int = option(default="123") result: Result | None regex: re.Pattern | None + literal: Literal["yes", "no"] | None defaults = { "num": None, + "num_default": 123, "result": None, + "regex": None, + "literal": None, } @pytest.mark.parametrize( "arg_string,provided", [ ("--num 123", {"num": 123}), + ("--num-default 456", {"num_default": 456}), ("--result pass", {"result": Result.PASS}), ("--result fail", {"result": Result.FAIL}), ("--regex ^\\d+$", {"regex": re.compile(r"^\d+$")}), + ("--literal yes", {"literal": "yes"}), + ("--literal no", {"literal": "no"}), ] ) def test_auto_converter_valid(arg_string: str, provided: dict[str, Any]) -> None: @@ -51,3 +60,48 @@ def test_auto_converter_valid(arg_string: str, provided: dict[str, Any]) -> None def test_option_invalid(arg_string: str) -> None: with pytest.raises(BaseException): Args.parse(args = arg_string.split()) + + +def test_enum_positional() -> None: + @arcparser + class Args: + result: Result = positional() + + parsed = Args.parse("pass".split()) + assert parsed.result == Result.PASS + + parsed = Args.parse("fail".split()) + assert parsed.result == Result.FAIL + + +def test_literal_positional() -> None: + @arcparser + class Args: + literal: Literal["yes", "no"] = positional() + + parsed = Args.parse("yes".split()) + assert parsed.literal == "yes" + + parsed = Args.parse("no".split()) + assert parsed.literal == "no" + + with pytest.raises(SystemExit): + Args.parse("maybe".split()) + + +def test_literal_choices_subset() -> None: + class InvalidArgs: + literal: Literal["yes", "no"] = positional(choices={"yes", "no", "maybe"}) + + with pytest.raises(InvalidArgument): + arcparser(InvalidArgs) + + @arcparser + class ValidArgs: + literal: Literal["yes", "no"] = positional(choices={"yes"}) + + parsed = ValidArgs.parse("yes".split()) + assert parsed.literal == "yes" + + with pytest.raises(SystemExit): + Args.parse("no".split()) diff --git a/tests/test_dict_helpers.py b/tests/test_dict_helpers.py new file mode 100644 index 0000000..6f9efcc --- /dev/null +++ b/tests/test_dict_helpers.py @@ -0,0 +1,52 @@ +import pytest + +from arcparse import arcparser, dict_option, dict_positional +from arcparse.errors import InvalidArgument + + +dict_ = { + "foo": 1, + "bar": 0, +} + + +def test_dict_positional_default_in_dict() -> None: + with pytest.raises(InvalidArgument): + dict_positional(dict_, default=2) + + +@pytest.mark.parametrize( + "arg_string,value", + [ + ("foo", 1), + ("bar", 0), + ] +) +def test_dict_positional(arg_string: str, value: int) -> None: + @arcparser + class Args: + foo_bar: int = dict_positional(dict_) + + parsed = Args.parse(arg_string.split()) + assert parsed.foo_bar == value + + +def test_dict_option_default_in_dict() -> None: + with pytest.raises(InvalidArgument): + dict_option(dict_, default=2) + + +@pytest.mark.parametrize( + "arg_string,value", + [ + ("--foo-bar foo", 1), + ("--foo-bar bar", 0), + ] +) +def test_dict_option(arg_string: str, value: int) -> None: + @arcparser + class Args: + foo_bar: int = dict_option(dict_) + + parsed = Args.parse(arg_string.split()) + assert parsed.foo_bar == value diff --git a/tests/test_flag.py b/tests/test_flag.py index d6801e0..5743767 100644 --- a/tests/test_flag.py +++ b/tests/test_flag.py @@ -2,7 +2,7 @@ import pytest -from arcparse import arcparser, flag, no_flag +from arcparse import arcparser, flag, no_flag, tri_flag @arcparser @@ -31,7 +31,7 @@ class Args: ("-o", {"boo": True}), ] ) -def test_option_valid(arg_string: str, provided: dict[str, Any]) -> None: +def test_flag_valid(arg_string: str, provided: dict[str, Any]) -> None: parsed = Args.parse(arg_string.split()) for k, v in (defaults | provided).items(): @@ -41,11 +41,40 @@ def test_option_valid(arg_string: str, provided: dict[str, Any]) -> None: @pytest.mark.parametrize( "arg_string", [ - "flag --bar", - "flag --no-foo", - "flag --boo", + "--bar", + "--no-foo", + "--boo", ] ) -def test_option_invalid(arg_string: str) -> None: +def test_flag_invalid(arg_string: str) -> None: with pytest.raises(SystemExit): Args.parse(args = arg_string.split()) + + + +@pytest.mark.parametrize( + "string,result", + [ + ("", {"foo": None, "bar": None}), + ("--foo", {"foo": True, "bar": None}), + ("--no-foo", {"foo": False, "bar": None}), + ("--foo --no-foo", None), + ("", {"foo": None, "bar": None}), + ("--bar", {"foo": None, "bar": True}), + ("--no-bar", {"foo": None, "bar": False}), + ("--bar --no-bar", None), + ], +) +def test_tri_flag(string: str, result: dict[str, Any]): + @arcparser + class Args: + foo: bool | None + bar: bool | None = tri_flag() + + if result is None: + with pytest.raises(SystemExit): + Args.parse(string.split()) + else: + args = Args.parse(string.split()) + for k, v in result.items(): + assert getattr(args, k) == v diff --git a/tests/test_invalid.py b/tests/test_invalid.py index 9cf0051..2302093 100644 --- a/tests/test_invalid.py +++ b/tests/test_invalid.py @@ -1,36 +1,56 @@ import pytest from arcparse import arcparser, positional +from arcparse.errors import InvalidArgument, InvalidTypehint -class Invalid1: - x: bool | None +def test_no_bool_valued_type_without_converter() -> None: + class Args: + x: bool = positional() + + with pytest.raises(InvalidTypehint): + arcparser(Args) + + +def test_no_nonnone_union() -> None: + class Args: + x: int | str + + with pytest.raises(InvalidTypehint): + arcparser(Args) -class Invalid2: - x: bool = positional() -class Invalid3: - x: int | str +def test_no_large_union_typehint() -> None: + class Args: + x: int | str | None + + with pytest.raises(InvalidTypehint): + arcparser(Args) + + +def test_no_typehint_invalid() -> None: + class Args: + x = positional() + + with pytest.raises(InvalidTypehint): + arcparser(Args) -class Invalid4: - x: int | str | None -class Invalid5: - x = positional() +def test_no_default_for_flag() -> None: + class ArgsTrue: + x: bool = True -class Invalid6: - x: bool = True + class ArgsFalse: + x: bool = False -class Invalid7: - x: bool = False + with pytest.raises(InvalidArgument): + arcparser(ArgsTrue) -@pytest.mark.parametrize("args_shape", [Invalid1, Invalid2, Invalid3, Invalid4, Invalid5, Invalid6, Invalid7]) -def test_invalid(args_shape: type) -> None: - with pytest.raises(Exception): - arcparser(args_shape) + with pytest.raises(InvalidArgument): + arcparser(ArgsFalse) -def test_untyped_variable() -> None: +def test_untyped_nonargument_variable_valid() -> None: @arcparser class Args: foo = 1 diff --git a/tests/test_mutual_exclusion.py b/tests/test_mutual_exclusion.py index d6b5c2c..c557f1d 100644 --- a/tests/test_mutual_exclusion.py +++ b/tests/test_mutual_exclusion.py @@ -2,13 +2,14 @@ import pytest -from arcparse import MxGroup, arcparser, flag, option +from arcparse import arcparser, flag, mx_group, option, subparsers, tri_flag +from arcparse.errors import InvalidArgument def test_group_as_untyped_attribute() -> None: @arcparser class Args: - group = MxGroup() + group = mx_group() foo: str | None = option(mx_group=group) bar: str | None = option(mx_group=group) @@ -17,29 +18,28 @@ class Args: def test_group_elements_both_nonoptional() -> None: class Args: - foo: str = option(mx_group=(group := MxGroup())) + foo: str = option(mx_group=(group := mx_group())) bar: str = option(mx_group=group) - # TODO: raise mx-arg-no-default instead of required-arguments error - with pytest.raises(Exception): + with pytest.raises(InvalidArgument): arcparser(Args) def test_group_elements_some_nonoptional() -> None: class Args: - foo: str = option(mx_group=(group := MxGroup())) + foo: str = option(mx_group=(group := mx_group())) bar: str | None = option(mx_group=group) - with pytest.raises(Exception): + with pytest.raises(InvalidArgument): arcparser(Args) @arcparser class Args: - foo: str | None = option(mx_group=(option_group := MxGroup())) + foo: str | None = option(mx_group=(option_group := mx_group())) bar: str | None = option(mx_group=option_group) - flag1: bool = flag(mx_group=(flag_group := MxGroup())) + flag1: bool = flag(mx_group=(flag_group := mx_group())) flag2: bool = flag(mx_group=flag_group) @@ -66,3 +66,80 @@ def test_mutual_exclusion_valid(string: str, result: dict[str, Any]) -> None: args = Args.parse(string.split()) for k, v in result.items(): assert getattr(args, k) == v + + +def test_mutual_exclusion_required() -> None: + @arcparser + class Args: + foo: str | None = option(mx_group=(option_group := mx_group(required=True))) + bar: str | None = option(mx_group=option_group) + + with pytest.raises(SystemExit): + Args.parse("".split()) + + parsed = Args.parse("--foo foo".split()) + assert parsed.foo == "foo" + assert parsed.bar is None + + parsed = Args.parse("--bar bar".split()) + assert parsed.foo is None + assert parsed.bar == "bar" + + with pytest.raises(SystemExit): + Args.parse("--foo foo --bar bar".split()) + + +def test_tri_flag_inside_mx_group() -> None: + @arcparser + class Args: + foo: str | None = option(mx_group=(group := mx_group())) + bar: bool | None = tri_flag(mx_group=group) + + parsed = Args.parse("".split()) + assert parsed.foo is None + assert parsed.bar is None + + parsed = Args.parse("--foo foo".split()) + assert parsed.foo == "foo" + assert parsed.bar is None + + parsed = Args.parse("--bar".split()) + assert parsed.foo is None + assert parsed.bar is True + + parsed = Args.parse("--no-bar".split()) + assert parsed.foo is None + assert parsed.bar is False + + with pytest.raises(SystemExit): + Args.parse("--foo foo --bar".split()) + + with pytest.raises(SystemExit): + Args.parse("--foo foo --no-bar".split()) + + +def test_tri_flag_inside_subparser() -> None: + class FooArgs: + foo: str + + class BarArgs: + bar: bool | None + + @arcparser + class Args: + foo_bar: FooArgs | BarArgs = subparsers("foo", "bar") + + parsed = Args.parse("bar".split()) + assert isinstance(foo_bar := parsed.foo_bar, BarArgs) + assert foo_bar.bar is None + + parsed = Args.parse("bar --bar".split()) + assert isinstance(foo_bar := parsed.foo_bar, BarArgs) + assert foo_bar.bar is True + + parsed = Args.parse("bar --no-bar".split()) + assert isinstance(foo_bar := parsed.foo_bar, BarArgs) + assert foo_bar.bar is False + + with pytest.raises(SystemExit): + Args.parse("bar --bar --no-bar".split()) diff --git a/tests/test_protocol.py b/tests/test_protocol.py index c2cac1d..648b8ce 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -4,6 +4,7 @@ import pytest from arcparse import arcparser, option +from arcparse.errors import MissingConverter class FooLike(Protocol): @@ -32,7 +33,7 @@ def test_protocol_no_converter_invalid(): class Args: foo: FooLike - with pytest.raises(Exception): + with pytest.raises(MissingConverter): arcparser(Args) @@ -40,7 +41,7 @@ def test_inh_protocol_no_converter_invalid(): class Args: foo_bar: FooBarLike - with pytest.raises(Exception): + with pytest.raises(MissingConverter): arcparser(Args) diff --git a/tests/test_readme.py b/tests/test_readme.py index d576aeb..9c1fe2b 100644 --- a/tests/test_readme.py +++ b/tests/test_readme.py @@ -31,8 +31,4 @@ def test_python_codeblock_functional(code: str, tmp_path: Path) -> None: path = tmp_path / "code.py" path.write_text(code) - # try: subprocess.check_call(["python3", str(path)]) - # runpy.run_path(str(tmp_path)) - # except BaseException as e: - # raise Exception(code) from e diff --git a/tests/test_subparsers.py b/tests/test_subparsers.py index b61d550..2a8ecc6 100644 --- a/tests/test_subparsers.py +++ b/tests/test_subparsers.py @@ -4,6 +4,7 @@ import pytest from arcparse import arcparser, positional, subparsers +from arcparse.errors import InvalidParser class FooArgs: @@ -24,15 +25,15 @@ class OptArgs: @pytest.mark.parametrize( "string,result", [ - ("", None), + ("", SystemExit), ("foo --arg1 foo", (FooArgs, {"arg1": "foo"})), ("bar 123", (BarArgs, {"arg2": 123})), - ("bar bar", None), + ("bar bar", ValueError), ], ) -def test_subparsers_required(string: str, result: tuple[type, dict[str, Any]] | None) -> None: - if result is None: - with pytest.raises(SystemExit): +def test_subparsers_required(string: str, result: tuple[type, dict[str, Any]] | type[BaseException]) -> None: + if isinstance(result, type): + with pytest.raises(result): ReqArgs.parse(string.split()) else: args = ReqArgs.parse(string.split()) @@ -48,12 +49,12 @@ def test_subparsers_required(string: str, result: tuple[type, dict[str, Any]] | ("", (NoneType, None)), ("foo --arg1 foo", (FooArgs, {"arg1": "foo"})), ("bar 123", (BarArgs, {"arg2": 123})), - ("bar bar", None), + ("bar bar", ValueError), ], ) -def test_subparsers_optional(string: str, result: tuple[type, dict[str, Any] | None] | None) -> None: - if result is None: - with pytest.raises(SystemExit): +def test_subparsers_optional(string: str, result: tuple[type, dict[str, Any] | None] | type[BaseException]) -> None: + if isinstance(result, type): + with pytest.raises(result): OptArgs.parse(string.split()) else: args = OptArgs.parse(string.split()) @@ -83,5 +84,5 @@ class Args: foo_or_bar: Foo | Bar = subparsers("foo", "bar") baz_or_boo: Baz | Boo = subparsers("baz", "boo") - with pytest.raises(Exception): + with pytest.raises(InvalidParser): arcparser(Args)