Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CLI ignore external parser list fix #379

Merged
merged 1 commit into from
Sep 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 38 additions & 24 deletions pydantic_settings/sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from enum import Enum
from pathlib import Path
from textwrap import dedent
from types import SimpleNamespace
from typing import (
TYPE_CHECKING,
Any,
Expand Down Expand Up @@ -1155,13 +1156,15 @@ def __call__(self, *, args: list[str] | tuple[str, ...] | bool) -> CliSettingsSo
...

@overload
def __call__(self, *, parsed_args: Namespace | dict[str, list[str] | str]) -> CliSettingsSource[T]:
def __call__(
self, *, parsed_args: Namespace | SimpleNamespace | dict[str, list[str] | str]
) -> CliSettingsSource[T]:
"""
Loads parsed command line arguments into the CLI settings source.

Note:
The parsed args must be in `argparse.Namespace` or vars dictionary (e.g., vars(argparse.Namespace))
format.
The parsed args must be in `argparse.Namespace`, `SimpleNamespace`, or vars dictionary
(e.g., vars(argparse.Namespace)) format.

Args:
parsed_args: The parsed args to load.
Expand All @@ -1175,7 +1178,7 @@ def __call__(
self,
*,
args: list[str] | tuple[str, ...] | bool | None = None,
parsed_args: Namespace | dict[str, list[str] | str] | None = None,
parsed_args: Namespace | SimpleNamespace | dict[str, list[str] | str] | None = None,
) -> dict[str, Any] | CliSettingsSource[T]:
if args is not None and parsed_args is not None:
raise SettingsError('`args` and `parsed_args` are mutually exclusive')
Expand All @@ -1194,13 +1197,15 @@ def __call__(
def _load_env_vars(self) -> Mapping[str, str | None]: ...

@overload
def _load_env_vars(self, *, parsed_args: Namespace | dict[str, list[str] | str]) -> CliSettingsSource[T]:
def _load_env_vars(
self, *, parsed_args: Namespace | SimpleNamespace | dict[str, list[str] | str]
) -> CliSettingsSource[T]:
"""
Loads the parsed command line arguments into the CLI environment settings variables.

Note:
The parsed args must be in `argparse.Namespace` or vars dictionary (e.g., vars(argparse.Namespace))
format.
The parsed args must be in `argparse.Namespace`, `SimpleNamespace`, or vars dictionary
(e.g., vars(argparse.Namespace)) format.

Args:
parsed_args: The parsed args to load.
Expand All @@ -1211,12 +1216,12 @@ def _load_env_vars(self, *, parsed_args: Namespace | dict[str, list[str] | str])
...

def _load_env_vars(
self, *, parsed_args: Namespace | dict[str, list[str] | str] | None = None
self, *, parsed_args: Namespace | SimpleNamespace | dict[str, list[str] | str] | None = None
) -> Mapping[str, str | None] | CliSettingsSource[T]:
if parsed_args is None:
return {}

if isinstance(parsed_args, Namespace):
if isinstance(parsed_args, (Namespace, SimpleNamespace)):
parsed_args = vars(parsed_args)

selected_subcommands: list[str] = []
Expand Down Expand Up @@ -1246,26 +1251,35 @@ def _load_env_vars(

return self

def _get_merge_parsed_list_types(
self, parsed_list: list[str], field_name: str
) -> tuple[Optional[type], Optional[type]]:
merge_type = self._cli_dict_args.get(field_name, list)
if (
merge_type is list
or not origin_is_union(get_origin(merge_type))
or not any(
type_
for type_ in get_args(merge_type)
if type_ is not type(None) and get_origin(type_) not in (dict, Mapping)
)
):
inferred_type = merge_type
else:
inferred_type = list if parsed_list and (len(parsed_list) > 1 or parsed_list[0].startswith('[')) else str

return merge_type, inferred_type

def _merge_parsed_list(self, parsed_list: list[str], field_name: str) -> str:
try:
merged_list: list[str] = []
is_last_consumed_a_value = False
merge_type = self._cli_dict_args.get(field_name, list)
if (
merge_type is list
or not origin_is_union(get_origin(merge_type))
or not any(
type_
for type_ in get_args(merge_type)
if type_ is not type(None) and get_origin(type_) not in (dict, Mapping)
)
):
inferred_type = merge_type
else:
inferred_type = (
list if parsed_list and (len(parsed_list) > 1 or parsed_list[0].startswith('[')) else str
)
merge_type, inferred_type = self._get_merge_parsed_list_types(parsed_list, field_name)
for val in parsed_list:
if not isinstance(val, str):
# If val is not a string, it's from an external parser and we can ignore parsing the rest of the
# list.
break
val = val.strip()
if val.startswith('[') and val.endswith(']'):
val = val[1:-1].strip()
Expand Down
36 changes: 33 additions & 3 deletions tests/test_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -3665,21 +3665,51 @@ class Cfg(BaseSettings):
cli_cfg_settings = CliSettingsSource(Cfg, cli_prefix=prefix, root_parser=parser)

add_arg('--fruit', choices=['pear', 'kiwi', 'lime'])
add_arg('--num-list', action='append', type=int)
add_arg('--num', type=int)

args = ['--fruit', 'pear']
args = ['--fruit', 'pear', '--num', '0', '--num-list', '1', '--num-list', '2', '--num-list', '3']
parsed_args = parse_args(args)
assert Cfg(_cli_settings_source=cli_cfg_settings(parsed_args=parsed_args)).model_dump() == {'pet': 'bird'}
assert Cfg(_cli_settings_source=cli_cfg_settings(args=args)).model_dump() == {'pet': 'bird'}
assert Cfg(_cli_settings_source=cli_cfg_settings(args=False)).model_dump() == {'pet': 'bird'}

arg_prefix = f'{prefix}.' if prefix else ''
args = ['--fruit', 'kiwi', f'--{arg_prefix}pet', 'dog']
args = [
'--fruit',
'kiwi',
'--num',
'0',
'--num-list',
'1',
'--num-list',
'2',
'--num-list',
'3',
f'--{arg_prefix}pet',
'dog',
]
parsed_args = parse_args(args)
assert Cfg(_cli_settings_source=cli_cfg_settings(parsed_args=parsed_args)).model_dump() == {'pet': 'dog'}
assert Cfg(_cli_settings_source=cli_cfg_settings(args=args)).model_dump() == {'pet': 'dog'}
assert Cfg(_cli_settings_source=cli_cfg_settings(args=False)).model_dump() == {'pet': 'bird'}

parsed_args = parse_args(['--fruit', 'kiwi', f'--{arg_prefix}pet', 'cat'])
parsed_args = parse_args(
[
'--fruit',
'kiwi',
'--num',
'0',
'--num-list',
'1',
'--num-list',
'2',
'--num-list',
'3',
f'--{arg_prefix}pet',
'cat',
]
)
assert Cfg(_cli_settings_source=cli_cfg_settings(parsed_args=vars(parsed_args))).model_dump() == {'pet': 'cat'}
assert Cfg(_cli_settings_source=cli_cfg_settings(args=False)).model_dump() == {'pet': 'bird'}

Expand Down
Loading