Feature Request: Settings source - allow non-JSON complex fields #126

Closed
dbendall opened this issue Jul 13, 2023 · 4 comments
@dbendall
Contributor

dbendall commented Jul 13, 2023

Use case

When defining a settings model with complex fields, the default environment variable settings source class only decodes JSON-encoded strings for those fields. It would be great if there were an easy way to decode non-JSON strings too (for example, a comma-separated list).

import os

from pydantic import AnyHttpUrl
from pydantic_settings import BaseSettings


class Settings(BaseSettings):
    list_of_urls: list[AnyHttpUrl]


os.environ["LIST_OF_URLS"] = '["https://pydantic.dev","https://github.com"]'
settings_from_json = Settings()  # OK

os.environ["LIST_OF_URLS"] = "https://pydantic.dev,https://github.com"
settings_from_csl = Settings()  # SettingsError from json.decoder.JSONDecodeError
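
The failure can be reproduced without pydantic at all. The snippet below is only a sketch using the standard library (the variable names are illustrative): it shows that `json.loads` rejects the comma-separated form, and that the resulting `json.JSONDecodeError` is a subclass of `ValueError`, which is why the `except ValueError` clauses in the code further down also catch it.

```python
import json

json_value = '["https://pydantic.dev","https://github.com"]'
csv_value = "https://pydantic.dev,https://github.com"

# A JSON array decodes to a Python list of strings.
decoded = json.loads(json_value)

# A comma-separated string is not valid JSON, so decoding raises.
try:
    json.loads(csv_value)
    error = None
except json.JSONDecodeError as exc:
    error = exc

# JSONDecodeError subclasses ValueError, so `except ValueError` also catches it.
is_value_error = isinstance(error, ValueError)
```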

Current solution

One can subclass EnvSettingsSource to implement this and register it as a custom settings source. However, the subclass has to copy the whole of the prepare_field_value and explode_env_vars methods just to replace the json.loads() call, which results in a large amount of duplicated code.

Example code...
import json
import os
from typing import Any, Mapping, Optional

from pydantic import AnyHttpUrl, fields
from pydantic_settings import (
    BaseSettings,
    EnvSettingsSource,
    PydanticBaseSettingsSource,
    sources,
)


class ModifiedEnvSettingsSource(EnvSettingsSource):
    """Environment variables settings source that uses custom complex object parser"""

    def prepare_field_value(
        self,
        field_name: str,
        field: fields.FieldInfo,
        value: Any,
        value_is_complex: bool,
    ) -> Any:
        """
        Prepare value for the field.
        * Extract value for nested field.
        * Deserialize value to python object for complex field.
        """
        is_complex, allow_parse_failure = self._field_is_complex(field)
        if is_complex or value_is_complex:
            if value is None:
                # field is complex but no value found so far, try explode_env_vars
                env_val_built = self.explode_env_vars(field_name, field, self.env_vars)
                if env_val_built:
                    return env_val_built
            else:
                # field is complex and there's a value, decode that as JSON, then add explode_env_vars
                try:
                    value = self.decode_complex_value(value, field)
                except ValueError as ex:
                    if not allow_parse_failure:
                        raise ex

                if isinstance(value, dict):
                    return sources.deep_update(
                        value, self.explode_env_vars(field_name, field, self.env_vars)
                    )
                return value
        elif value is not None:
            # simplest case, field is not complex, we only need to add the value if it was found
            return value
        return None

    def explode_env_vars(
        self,
        field_name: str,
        field: fields.FieldInfo,
        env_vars: Mapping[str, Optional[str]],
    ) -> dict[str, Any]:
        """
        Process env_vars and extract the values of keys containing env_nested_delimiter into nested dictionaries.
        This is applied to a single field, hence filtering by env_var prefix.
        """
        prefixes = [
            f"{env_name}{self.env_nested_delimiter}"
            for _, env_name, _ in self._extract_field_info(field, field_name)
        ]
        result: dict[str, Any] = {}
        for env_name, env_val in env_vars.items():
            if not any(env_name.startswith(prefix) for prefix in prefixes):
                continue
            # we remove the prefix before splitting in case the prefix has characters in common with the delimiter
            env_name_without_prefix = env_name[self.env_prefix_len :]
            _, *keys, last_key = env_name_without_prefix.split(
                self.env_nested_delimiter
            )
            env_var = result
            target_field: Optional[fields.FieldInfo] = field
            for key in keys:
                target_field = self.next_field(target_field, key)
                env_var = env_var.setdefault(key, {})

            # get proper field with last_key
            target_field = self.next_field(target_field, last_key)

            # check if env_val maps to a complex field and if so, parse the env_val
            if target_field and env_val:
                is_complex, allow_json_failure = self._field_is_complex(target_field)
                if is_complex:
                    try:
                        env_val = self.decode_complex_value(env_val, target_field)
                    except ValueError as ex:
                        if not allow_json_failure:
                            raise ex
            env_var[last_key] = env_val

        return result

    @staticmethod
    def decode_complex_value(value: str, field: fields.FieldInfo) -> Any:
        """Decode the complex value for the given field"""
        try:
            return json.loads(value)
        except json.decoder.JSONDecodeError as ex:
            # getattr avoids an AttributeError for annotations without __origin__ (e.g. plain classes)
            if getattr(field.annotation, "__origin__", None) is list:
                return [str(x).strip() for x in value.split(",")]
            raise ex


class Settings(BaseSettings):
    list_of_urls: list[AnyHttpUrl]

    @classmethod
    def settings_customise_sources(
        cls,
        settings_cls: type[BaseSettings],
        init_settings: PydanticBaseSettingsSource,
        env_settings: PydanticBaseSettingsSource,
        dotenv_settings: PydanticBaseSettingsSource,
        file_secret_settings: PydanticBaseSettingsSource,
    ) -> tuple[PydanticBaseSettingsSource, ...]:
        return (init_settings, ModifiedEnvSettingsSource(settings_cls))


os.environ["LIST_OF_URLS"] = '["https://pydantic.dev","https://github.com"]'
settings_from_json = Settings()  # OK

os.environ["LIST_OF_URLS"] = "https://pydantic.dev,https://github.com"
settings_from_csl = Settings()  # OK

Proposed solution

Break the json.loads() call out of these methods into a dedicated decode_complex_value method, so that a custom settings source only needs to override that one method.

Example code...
import json
import os
from typing import Any

from pydantic import AnyHttpUrl, fields
from pydantic_settings import (
    BaseSettings,
    EnvSettingsSource,
    PydanticBaseSettingsSource,
)


class ModifiedEnvSettingsSource(EnvSettingsSource):
    """Environment variables settings source that uses custom complex object parser"""

    @staticmethod
    def decode_complex_value(value: str, field: fields.FieldInfo) -> Any:
        """Decode the complex value for the given field"""
        try:
            return json.loads(value)
        except json.decoder.JSONDecodeError as ex:
            # getattr avoids an AttributeError for annotations without __origin__ (e.g. plain classes)
            if getattr(field.annotation, "__origin__", None) is list:
                return [str(x).strip() for x in value.split(",")]
            raise ex


class Settings(BaseSettings):
    list_of_urls: list[AnyHttpUrl]

    @classmethod
    def settings_customise_sources(
        cls,
        settings_cls: type[BaseSettings],
        init_settings: PydanticBaseSettingsSource,
        env_settings: PydanticBaseSettingsSource,
        dotenv_settings: PydanticBaseSettingsSource,
        file_secret_settings: PydanticBaseSettingsSource,
    ) -> tuple[PydanticBaseSettingsSource, ...]:
        return (init_settings, ModifiedEnvSettingsSource(settings_cls))


os.environ["LIST_OF_URLS"] = '["https://pydantic.dev","https://github.com"]'
settings_from_json = Settings()  # OK

os.environ["LIST_OF_URLS"] = "https://pydantic.dev,https://github.com"
settings_from_csl = Settings()  # OK

Thanks for reading, I'd be happy to put together a PR for this if acceptable.

Selected Assignee: @hramezani

@hramezani
Member

Thanks @dbendall for this issue 🙏

You only need to override prepare_field_value.

import json
import os
from typing import Any

from pydantic import AnyHttpUrl
from pydantic.fields import FieldInfo
from pydantic_settings import (
    BaseSettings,
    EnvSettingsSource,
    PydanticBaseSettingsSource,
)

class ModifiedEnvSettingsSource(EnvSettingsSource):
    def prepare_field_value(self, field_name: str, field: FieldInfo, value: Any, value_is_complex: bool) -> Any:
        try:
            return json.loads(value)
        except ValueError:
            return value.split(',')


class Settings(BaseSettings):
    list_of_urls: list[AnyHttpUrl]

    @classmethod
    def settings_customise_sources(
        cls,
        settings_cls: type[BaseSettings],
        init_settings: PydanticBaseSettingsSource,
        env_settings: PydanticBaseSettingsSource,
        dotenv_settings: PydanticBaseSettingsSource,
        file_secret_settings: PydanticBaseSettingsSource,
    ) -> tuple[PydanticBaseSettingsSource, ...]:
        return (init_settings, ModifiedEnvSettingsSource(settings_cls))


os.environ["LIST_OF_URLS"] = '["https://pydantic.dev","https://github.com"]'
settings_from_json = Settings()  # OK

os.environ["LIST_OF_URLS"] = "https://pydantic.dev,https://github.com"
settings_from_csl = Settings()  # OK

@dbendall
Contributor Author

Thanks for checking this out. Although this works for the simple example I provided, it gets more complicated when you add non-complex fields or use nested settings models. For example:

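import os

from pydantic import AnyHttpUrl
from pydantic_settings import BaseSettings, SettingsConfigDict


class NestedSettings(BaseSettings):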
    another_list_of_urls: list[AnyHttpUrl]


class Settings(BaseSettings):
    single_str: str
    single_int: int
    list_of_urls: list[AnyHttpUrl]
    nested_model: NestedSettings

    model_config = SettingsConfigDict(env_nested_delimiter="__")


os.environ["SINGLE_STR"] = "example"
os.environ["SINGLE_INT"] = "99"

os.environ["LIST_OF_URLS"] = '["https://pydantic.dev","https://github.com"]'
os.environ["NESTED_MODEL__ANOTHER_LIST_OF_URLS"] = '["https://pydantic.dev","https://github.com"]'
settings_from_json = Settings()  # OK

os.environ["LIST_OF_URLS"] = "https://pydantic.dev,https://github.com"
os.environ["NESTED_MODEL__ANOTHER_LIST_OF_URLS"] = "https://pydantic.dev,https://github.com"
settings_from_csl = Settings()  # SettingsError from json.decoder.JSONDecodeError

For our use case, it's only the complex value decoding that we would like to customise. We would like to retain all the existing environment variable parsing, extraction, and preparation where possible, so that we only extend the specific bit of functionality we need to change.

Please let me know if I can provide any further examples. As I said, I'm happy to put together a PR to show how I think it can be done without changing the existing interface.
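
As a rough illustration of the shape being asked for, here is a minimal sketch of the overridable-hook pattern using only the standard library. `BaseSource`, `CommaSeparatedSource`, and `prepare_value` are hypothetical stand-ins, not the pydantic-settings classes or their signatures. The point is that the shared preparation logic stays in the base class and a subclass replaces only the decoding step.

```python
import json
from typing import Any, Optional, get_origin


class BaseSource:
    """Hypothetical stand-in for a settings source; not the pydantic-settings class."""

    def prepare_value(self, annotation: Any, value: Optional[str]) -> Any:
        # Shared logic stays here: only complex (list/dict) fields are decoded.
        if value is None or get_origin(annotation) not in (list, dict):
            return value
        return self.decode_complex_value(annotation, value)

    def decode_complex_value(self, annotation: Any, value: str) -> Any:
        # Default behaviour: complex values must be JSON.
        return json.loads(value)


class CommaSeparatedSource(BaseSource):
    def decode_complex_value(self, annotation: Any, value: str) -> Any:
        try:
            return super().decode_complex_value(annotation, value)
        except json.JSONDecodeError:
            if get_origin(annotation) is list:
                # Fall back to a comma-separated list for list fields only.
                return [item.strip() for item in value.split(",")]
            raise


source = CommaSeparatedSource()
urls = source.prepare_value(list[str], "https://pydantic.dev,https://github.com")
name = source.prepare_value(str, "example")
```

In this sketch the non-complex `str` value passes through untouched, and a `dict` field with a non-JSON value still raises, so the fallback is limited to list fields.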

@hramezani
Member

Thanks @dbendall for the explanation.
Happy to review a PR

@hramezani
Member

Closed in 7ebb3bf
