Skip to content

Commit

Permalink
Support cli_cmd default, verify main subcommands.
Browse files Browse the repository at this point in the history
  • Loading branch information
kschwab committed Jul 25, 2024
1 parent 6b1c76a commit 78aee2b
Showing 1 changed file with 32 additions and 10 deletions.
42 changes: 32 additions & 10 deletions pydantic_settings/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import inspect
from pathlib import Path
from typing import Any, Callable, ClassVar, cast
from typing import Any, Callable, ClassVar, TypeVar, cast, get_args

from pydantic import ConfigDict, create_model
from pydantic._internal._config import config_keys
Expand All @@ -25,6 +25,8 @@
_CliSubCommand,
)

Model = TypeVar('Model', bound='BaseModel')


class SettingsConfigDict(ConfigDict, total=False):
case_sensitive: bool
Expand Down Expand Up @@ -370,14 +372,16 @@ def _settings_build_values(

class CliApp:
@staticmethod
def _get_command_entrypoint(model_cls: type[BaseModel]) -> Callable[[Any], None]:
def _get_command_entrypoint(model_cls: type[Model]) -> Callable[[Any], None]:
for _, function in inspect.getmembers(model_cls, predicate=inspect.isfunction):
if hasattr(unwrap_wrapped_function(function), '_cli_app_command_entrypoint'):
return function
if hasattr(model_cls, 'cli_cmd'):
return getattr(model_cls, 'cli_cmd')
raise SettingsError(f'Error: {model_cls.__name__} class is missing AppCli.command entrypoint')

@staticmethod
def _cli_settings_source(model_cls: type[BaseModel]) -> CliSettingsSource[Any]:
def _cli_settings_source(model_cls: type[Model]) -> CliSettingsSource[Any]:
fields = model_cls.__pydantic_fields__ if is_pydantic_dataclass(model_cls) else model_cls.model_fields
field_definitions: dict[str, tuple[type, Any]] = {
name: (info.annotation, info) for name, info in fields.items() if info.annotation is not None
Expand All @@ -386,9 +390,23 @@ def _cli_settings_source(model_cls: type[BaseModel]) -> CliSettingsSource[Any]:
return CliSettingsSource(base_settings)

@staticmethod
def main(cli_app_main_cls: type[BaseModel] | type[BaseSettings]) -> type[BaseModel] | type[BaseSettings]:
def _validate_main_subcommands(model_cls: type[Model]) -> None:
fields = model_cls.__pydantic_fields__ if is_pydantic_dataclass(model_cls) else model_cls.model_fields
for _, field_info in fields.items():
if _CliSubCommand in field_info.metadata:
field_types = [type_ for type_ in get_args(field_info.annotation) if type_ is not type(None)]
for subcommand_cls in field_types:
if hasattr(subcommand_cls.__init__, '_cli_app_main_entrypoint'):
raise SettingsError(
f'Error: CliApp.main "{model_cls.__name__}" cannot have a '
f'subcommand of CliApp.main "{subcommand_cls.__name__}"'
)

@staticmethod
def main(cli_app_main_cls: type[Model]) -> type[Model]:
if cli_app_main_cls.__init__.__name__ != '_cli_app_init':
original_init = cli_app_main_cls.__init__
CliApp._validate_main_subcommands(cli_app_main_cls)

def _cli_app_init(*args: Any, **kwargs: Any) -> None:
if issubclass(cli_app_main_cls, BaseSettings):
Expand All @@ -406,6 +424,7 @@ def _cli_app_init(*args: Any, **kwargs: Any) -> None:
command_entry_point = CliApp._get_command_entrypoint(cli_app_main_cls)
command_entry_point(args[0])

setattr(_cli_app_init, '_cli_app_main_entrypoint', True)
setattr(cli_app_main_cls, '__init__', _cli_app_init)
return cli_app_main_cls

Expand All @@ -416,7 +435,7 @@ def command(function: Callable[[Any], Any]) -> Callable[[Any], Any]:

@staticmethod
def run(
model_cls: type[BaseModel] | type[BaseSettings],
model_cls: type[Model],
cli_args: list[str] | None = None,
cli_settings_source: CliSettingsSource[Any] | None = None,
**model_args: Any,
Expand All @@ -439,8 +458,11 @@ def run(
cli_settings_source = CliApp._cli_settings_source(model_cls)
cli_settings_source = cli_settings_source(args=cli_parse_args)
init_kwargs = deep_update(cli_settings_source(), init_kwargs)
command_entry_point = CliApp._get_command_entrypoint(model_cls)
command_entry_point(model_cls(**init_kwargs))
if hasattr(model_cls.__init__, '_cli_app_main_entrypoint'):
model_cls(**init_kwargs)
else:
command_entry_point = CliApp._get_command_entrypoint(model_cls)
command_entry_point(model_cls(**init_kwargs))

@staticmethod
def get_subcommand(model: BaseModel, is_required: bool = True) -> Any:
Expand All @@ -461,10 +483,10 @@ def get_subcommand(model: BaseModel, is_required: bool = True) -> Any:
return None

@staticmethod
def run_subcommand(model: BaseModel, cli_exit_on_error: bool | None = None) -> Any:
def run_subcommand(model: BaseModel, cli_exit_on_error: bool | None = None) -> None:
try:
subcommand = CliApp.get_subcommand(model)
return CliApp._get_command_entrypoint(subcommand.__class__)(subcommand)
subcommand = CliApp.get_subcommand(model, is_required=True)
CliApp._get_command_entrypoint(subcommand.__class__)(subcommand)
except SettingsError as err:
if (cli_exit_on_error is not None and cli_exit_on_error) or model.model_config.get('cli_exit_on_error'):
raise SystemExit(err)
Expand Down

0 comments on commit 78aee2b

Please sign in to comment.