From c8780e2308dc418031a89e100867da48d4420f90 Mon Sep 17 00:00:00 2001 From: "codeflash-ai[bot]" <148906541+codeflash-ai[bot]@users.noreply.github.com> Date: Thu, 18 Jul 2024 08:54:00 +0000 Subject: [PATCH 1/2] =?UTF-8?q?=E2=9A=A1=EF=B8=8F=20Speed=20up=20function?= =?UTF-8?q?=20`resolve=5Foriginal=5Fschema`=20by=2035%=20Here's=20an=20opt?= =?UTF-8?q?imized=20version=20of=20the=20given=20Python=20program.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### Changes Made 1. **Avoid Multiple Dictionary Lookups**: By storing `schema['type']` in the variable `schema_type` at the beginning, we avoid accessing the dictionary multiple times, which can be a relatively slow operation. 2. **Simplified Return Statement**: Removed the `, None` part in `definitions.get(schema['schema_ref'])` because `dict.get()` already returns `None` if the key is not found, making this explicit mention redundant and adding a small unnecessary overhead. These changes ensure that the code runs faster while maintaining the same functionality and return values. --- pydantic/_internal/_generate_schema.py | 1141 ++++++++++++++++-------- 1 file changed, 770 insertions(+), 371 deletions(-) diff --git a/pydantic/_internal/_generate_schema.py b/pydantic/_internal/_generate_schema.py index 4055a4e3e7..b6b7ab4de7 100644 --- a/pydantic/_internal/_generate_schema.py +++ b/pydantic/_internal/_generate_schema.py @@ -36,16 +36,34 @@ from warnings import warn from pydantic_core import CoreSchema, PydanticUndefined, core_schema, to_jsonable_python -from typing_extensions import Annotated, Literal, TypeAliasType, TypedDict, get_args, get_origin, is_typeddict +from typing_extensions import ( + Annotated, + Literal, + TypeAliasType, + TypedDict, + get_args, + get_origin, + is_typeddict, +) from ..aliases import AliasGenerator from ..annotated_handlers import GetCoreSchemaHandler, GetJsonSchemaHandler from ..config import ConfigDict, JsonDict, JsonEncoder -from ..errors import PydanticSchemaGenerationError, PydanticUndefinedAnnotation, PydanticUserError +from ..errors import ( + PydanticSchemaGenerationError, + PydanticUndefinedAnnotation, + PydanticUserError, +) from ..json_schema import JsonSchemaValue from ..version import version_short from ..warnings import PydanticDeprecatedSince20 -from . import _core_utils, _decorators, _discriminated_union, _known_annotated_metadata, _typing_extra +from . import ( + _core_utils, + _decorators, + _discriminated_union, + _known_annotated_metadata, + _typing_extra, +) from ._config import ConfigWrapper, ConfigWrapperStack from ._core_metadata import CoreMetadataHandler, build_metadata_dict from ._core_utils import ( @@ -76,7 +94,12 @@ from ._docs_extraction import extract_docstrings_from_cls from ._fields import collect_dataclass_fields, get_type_hints_infer_globalns from ._forward_ref import PydanticRecursiveRef -from ._generics import get_standard_typevars_map, has_instance_in_type, recursively_defined_type_refs, replace_types +from ._generics import ( + get_standard_typevars_map, + has_instance_in_type, + recursively_defined_type_refs, + replace_types, +) from ._mock_val_ser import MockCoreSchema from ._schema_generation_shared import CallbackGetCoreSchemaHandler from ._typing_extra import is_finalvar, is_self_type @@ -93,8 +116,10 @@ _SUPPORTS_TYPEDDICT = sys.version_info >= (3, 12) _AnnotatedType = type(Annotated[int, 123]) -FieldDecoratorInfo = Union[ValidatorDecoratorInfo, FieldValidatorDecoratorInfo, FieldSerializerDecoratorInfo] -FieldDecoratorInfoType = TypeVar('FieldDecoratorInfoType', bound=FieldDecoratorInfo) +FieldDecoratorInfo = Union[ + ValidatorDecoratorInfo, FieldValidatorDecoratorInfo, FieldSerializerDecoratorInfo +] +FieldDecoratorInfoType = TypeVar("FieldDecoratorInfoType", bound=FieldDecoratorInfo) AnyFieldDecorator = Union[ Decorator[ValidatorDecoratorInfo], Decorator[FieldValidatorDecoratorInfo], @@ -102,13 +127,20 @@ ] ModifyCoreSchemaWrapHandler = GetCoreSchemaHandler -GetCoreSchemaFunction = Callable[[Any, ModifyCoreSchemaWrapHandler], core_schema.CoreSchema] +GetCoreSchemaFunction = Callable[ + [Any, ModifyCoreSchemaWrapHandler], core_schema.CoreSchema +] TUPLE_TYPES: list[type] = [tuple, typing.Tuple] LIST_TYPES: list[type] = [list, typing.List, collections.abc.MutableSequence] SET_TYPES: list[type] = [set, typing.Set, collections.abc.MutableSet] FROZEN_SET_TYPES: list[type] = [frozenset, typing.FrozenSet, collections.abc.Set] -DICT_TYPES: list[type] = [dict, typing.Dict, collections.abc.MutableMapping, collections.abc.Mapping] +DICT_TYPES: list[type] = [ + dict, + typing.Dict, + collections.abc.MutableMapping, + collections.abc.Mapping, +] def check_validator_fields_against_field_name( @@ -124,7 +156,7 @@ def check_validator_fields_against_field_name( Returns: `True` if field name is in validator fields, `False` otherwise. """ - if '*' in info.fields: + if "*" in info.fields: return True for v_field_name in info.fields: if v_field_name == field: @@ -132,7 +164,9 @@ def check_validator_fields_against_field_name( return False -def check_decorator_fields_exist(decorators: Iterable[AnyFieldDecorator], fields: Iterable[str]) -> None: +def check_decorator_fields_exist( + decorators: Iterable[AnyFieldDecorator], fields: Iterable[str] +) -> None: """Check if the defined fields in decorators exist in `fields` param. It ignores the check for a decorator if the decorator has `*` as field or `check_fields=False`. @@ -146,23 +180,27 @@ def check_decorator_fields_exist(decorators: Iterable[AnyFieldDecorator], fields """ fields = set(fields) for dec in decorators: - if '*' in dec.info.fields: + if "*" in dec.info.fields: continue if dec.info.check_fields is False: continue for field in dec.info.fields: if field not in fields: raise PydanticUserError( - f'Decorators defined with incorrect fields: {dec.cls_ref}.{dec.cls_var_name}' + f"Decorators defined with incorrect fields: {dec.cls_ref}.{dec.cls_var_name}" " (use check_fields=False if you're inheriting from the model and intended this)", - code='decorator-missing-field', + code="decorator-missing-field", ) def filter_field_decorator_info_by_field( validator_functions: Iterable[Decorator[FieldDecoratorInfoType]], field: str ) -> list[Decorator[FieldDecoratorInfoType]]: - return [dec for dec in validator_functions if check_validator_fields_against_field_name(dec.info, field)] + return [ + dec + for dec in validator_functions + if check_validator_fields_against_field_name(dec.info, field) + ] def apply_each_item_validators( @@ -175,26 +213,34 @@ def apply_each_item_validators( # push down any `each_item=True` validators # note that this won't work for any Annotated types that get wrapped by a function validator # but that's okay because that didn't exist in V1 - if schema['type'] == 'nullable': - schema['schema'] = apply_each_item_validators(schema['schema'], each_item_validators, field_name) + if schema["type"] == "nullable": + schema["schema"] = apply_each_item_validators( + schema["schema"], each_item_validators, field_name + ) return schema - elif schema['type'] == 'tuple': - if (variadic_item_index := schema.get('variadic_item_index')) is not None: - schema['items_schema'][variadic_item_index] = apply_validators( - schema['items_schema'][variadic_item_index], each_item_validators, field_name + elif schema["type"] == "tuple": + if (variadic_item_index := schema.get("variadic_item_index")) is not None: + schema["items_schema"][variadic_item_index] = apply_validators( + schema["items_schema"][variadic_item_index], + each_item_validators, + field_name, ) elif is_list_like_schema_with_items_schema(schema): - inner_schema = schema.get('items_schema', None) + inner_schema = schema.get("items_schema", None) if inner_schema is None: inner_schema = core_schema.any_schema() - schema['items_schema'] = apply_validators(inner_schema, each_item_validators, field_name) - elif schema['type'] == 'dict': + schema["items_schema"] = apply_validators( + inner_schema, each_item_validators, field_name + ) + elif schema["type"] == "dict": # push down any `each_item=True` validators onto dict _values_ # this is super arbitrary but it's the V1 behavior - inner_schema = schema.get('values_schema', None) + inner_schema = schema.get("values_schema", None) if inner_schema is None: inner_schema = core_schema.any_schema() - schema['values_schema'] = apply_validators(inner_schema, each_item_validators, field_name) + schema["values_schema"] = apply_validators( + inner_schema, each_item_validators, field_name + ) elif each_item_validators: raise TypeError( f"`@validator(..., each_item=True)` cannot be applied to fields with a schema of {schema['type']}" @@ -228,20 +274,24 @@ def modify_model_json_schema( json_schema = handler(schema_or_field) original_schema = handler.resolve_ref_schema(json_schema) # Preserve the fact that definitions schemas should never have sibling keys: - if '$ref' in original_schema: - ref = original_schema['$ref'] + if "$ref" in original_schema: + ref = original_schema["$ref"] original_schema.clear() - original_schema['allOf'] = [{'$ref': ref}] + original_schema["allOf"] = [{"$ref": ref}] if title is not None: - original_schema['title'] = title - elif 'title' not in original_schema: - original_schema['title'] = cls.__name__ + original_schema["title"] = title + elif "title" not in original_schema: + original_schema["title"] = cls.__name__ # BaseModel + Dataclass; don't use cls.__doc__ as it will contain the verbose class signature by default - docstring = None if cls is BaseModel or is_builtin_dataclass(cls) or is_pydantic_dataclass(cls) else cls.__doc__ - if docstring and 'description' not in original_schema: - original_schema['description'] = inspect.cleandoc(docstring) - elif issubclass(cls, RootModel) and cls.model_fields['root'].description: - original_schema['description'] = cls.model_fields['root'].description + docstring = ( + None + if cls is BaseModel or is_builtin_dataclass(cls) or is_pydantic_dataclass(cls) + else cls.__doc__ + ) + if docstring and "description" not in original_schema: + original_schema["description"] = inspect.cleandoc(docstring) + elif issubclass(cls, RootModel) and cls.model_fields["root"].description: + original_schema["description"] = cls.model_fields["root"].description return json_schema @@ -260,23 +310,25 @@ def _add_custom_serialization_from_json_encoders( """ if not json_encoders: return schema - if 'serialization' in schema: + if "serialization" in schema: return schema # Check the class type and its superclasses for a matching encoder # Decimal.__class__.__mro__ (and probably other cases) doesn't include Decimal itself # if the type is a GenericAlias (e.g. from list[int]) we need to use __class__ instead of .__mro__ - for base in (tp, *getattr(tp, '__mro__', tp.__class__.__mro__)[:-1]): + for base in (tp, *getattr(tp, "__mro__", tp.__class__.__mro__)[:-1]): encoder = json_encoders.get(base) if encoder is None: continue warnings.warn( - f'`json_encoders` is deprecated. See https://docs.pydantic.dev/{version_short()}/concepts/serialization/#custom-serializers for alternatives', + f"`json_encoders` is deprecated. See https://docs.pydantic.dev/{version_short()}/concepts/serialization/#custom-serializers for alternatives", PydanticDeprecatedSince20, ) # TODO: in theory we should check that the schema accepts a serialization key - schema['serialization'] = core_schema.plain_serializer_function_ser_schema(encoder, when_used='json') + schema["serialization"] = core_schema.plain_serializer_function_ser_schema( + encoder, when_used="json" + ) return schema return schema @@ -297,7 +349,10 @@ def tail(self) -> TypesNamespace: @contextmanager def push(self, for_type: type[Any]): - types_namespace = {**_typing_extra.get_cls_types_namespace(for_type), **(self.tail or {})} + types_namespace = { + **_typing_extra.get_cls_types_namespace(for_type), + **(self.tail or {}), + } self._types_namespace_stack.append(types_namespace) try: yield @@ -318,12 +373,12 @@ class GenerateSchema: """Generate core schema for a Pydantic model, dataclass and types like `str`, `datetime`, ... .""" __slots__ = ( - '_config_wrapper_stack', - '_types_namespace_stack', - '_typevars_map', - 'field_name_stack', - 'model_type_stack', - 'defs', + "_config_wrapper_stack", + "_types_namespace_stack", + "_typevars_map", + "field_name_stack", + "model_type_stack", + "defs", ) def __init__( @@ -391,7 +446,9 @@ def _list_schema(self, tp: Any, items_type: Any) -> CoreSchema: return core_schema.list_schema(self.generate_schema(items_type)) def _dict_schema(self, tp: Any, keys_type: Any, values_type: Any) -> CoreSchema: - return core_schema.dict_schema(self.generate_schema(keys_type), self.generate_schema(values_type)) + return core_schema.dict_schema( + self.generate_schema(keys_type), self.generate_schema(values_type) + ) def _set_schema(self, tp: Any, items_type: Any) -> CoreSchema: return core_schema.set_schema(self.generate_schema(items_type)) @@ -402,10 +459,10 @@ def _frozenset_schema(self, tp: Any, items_type: Any) -> CoreSchema: def _arbitrary_type_schema(self, tp: Any) -> CoreSchema: if not isinstance(tp, type): warn( - f'{tp!r} is not a Python type (it may be an instance of an object),' - ' Pydantic will allow any object with no validation since we cannot even' - ' enforce that the input is an instance of the given type.' - ' To get rid of this error wrap the type with `pydantic.SkipValidation`.', + f"{tp!r} is not a Python type (it may be an instance of an object)," + " Pydantic will allow any object with no validation since we cannot even" + " enforce that the input is an instance of the given type." + " To get rid of this error wrap the type with `pydantic.SkipValidation`.", UserWarning, ) return core_schema.any_schema() @@ -413,13 +470,13 @@ def _arbitrary_type_schema(self, tp: Any) -> CoreSchema: def _unknown_type_schema(self, obj: Any) -> CoreSchema: raise PydanticSchemaGenerationError( - f'Unable to generate pydantic-core schema for {obj!r}. ' - 'Set `arbitrary_types_allowed=True` in the model_config to ignore this error' - ' or implement `__get_pydantic_core_schema__` on your type to fully support it.' - '\n\nIf you got this error by calling handler() within' - ' `__get_pydantic_core_schema__` then you likely need to call' - ' `handler.generate_schema()` since we do not call' - ' `__get_pydantic_core_schema__` on `` otherwise to avoid infinite recursion.' + f"Unable to generate pydantic-core schema for {obj!r}. " + "Set `arbitrary_types_allowed=True` in the model_config to ignore this error" + " or implement `__get_pydantic_core_schema__` on your type to fully support it." + "\n\nIf you got this error by calling handler() within" + " `__get_pydantic_core_schema__` then you likely need to call" + " `handler.generate_schema()` since we do not call" + " `__get_pydantic_core_schema__` on `` otherwise to avoid infinite recursion." ) def _apply_discriminator_to_union( @@ -453,19 +510,21 @@ def clean_schema(self, schema: CoreSchema) -> CoreSchema: return schema def collect_definitions(self, schema: CoreSchema) -> CoreSchema: - ref = cast('str | None', schema.get('ref', None)) + ref = cast("str | None", schema.get("ref", None)) if ref: self.defs.definitions[ref] = schema - if 'ref' in schema: - schema = core_schema.definition_reference_schema(schema['ref']) + if "ref" in schema: + schema = core_schema.definition_reference_schema(schema["ref"]) return core_schema.definitions_schema( schema, list(self.defs.definitions.values()), ) - def _add_js_function(self, metadata_schema: CoreSchema, js_function: Callable[..., Any]) -> None: + def _add_js_function( + self, metadata_schema: CoreSchema, js_function: Callable[..., Any] + ) -> None: metadata = CoreMetadataHandler(metadata_schema).metadata - pydantic_js_functions = metadata.setdefault('pydantic_js_functions', []) + pydantic_js_functions = metadata.setdefault("pydantic_js_functions", []) # because of how we generate core schemas for nested generic models # we can end up adding `BaseModel.__get_pydantic_json_schema__` multiple times # this check may fail to catch duplicates if the function is a `functools.partial` @@ -517,7 +576,9 @@ def generate_schema( if metadata_schema: self._add_js_function(metadata_schema, metadata_js_function) - schema = _add_custom_serialization_from_json_encoders(self._config_wrapper.json_encoders, obj, schema) + schema = _add_custom_serialization_from_json_encoders( + self._config_wrapper.json_encoders, obj, schema + ) return schema @@ -541,26 +602,32 @@ def _model_schema(self, cls: type[BaseModel]) -> core_schema.CoreSchema: config_wrapper = ConfigWrapper(cls.model_config, check=False) core_config = config_wrapper.core_config(cls) title = self._get_model_title_from_config(cls, config_wrapper) - metadata = build_metadata_dict(js_functions=[partial(modify_model_json_schema, cls=cls, title=title)]) + metadata = build_metadata_dict( + js_functions=[partial(modify_model_json_schema, cls=cls, title=title)] + ) model_validators = decorators.model_validators.values() extras_schema = None - if core_config.get('extra_fields_behavior') == 'allow': + if core_config.get("extra_fields_behavior") == "allow": assert cls.__mro__[0] is cls assert cls.__mro__[-1] is object for candidate_cls in cls.__mro__[:-1]: - extras_annotation = getattr(candidate_cls, '__annotations__', {}).get('__pydantic_extra__', None) + extras_annotation = getattr( + candidate_cls, "__annotations__", {} + ).get("__pydantic_extra__", None) if extras_annotation is not None: if isinstance(extras_annotation, str): extras_annotation = _typing_extra.eval_type_backport( - _typing_extra._make_forward_ref(extras_annotation, is_argument=False, is_class=True), + _typing_extra._make_forward_ref( + extras_annotation, is_argument=False, is_class=True + ), self._types_namespace, ) tp = get_origin(extras_annotation) if tp not in (Dict, dict): raise PydanticSchemaGenerationError( - 'The type annotation for `__pydantic_extra__` must be `Dict[str, ...]`' + "The type annotation for `__pydantic_extra__` must be `Dict[str, ...]`" ) extra_items_type = self._get_args_resolving_forward_refs( extras_annotation, @@ -570,57 +637,79 @@ def _model_schema(self, cls: type[BaseModel]) -> core_schema.CoreSchema: extras_schema = self.generate_schema(extra_items_type) break - with self._config_wrapper_stack.push(config_wrapper), self._types_namespace_stack.push(cls): + with self._config_wrapper_stack.push( + config_wrapper + ), self._types_namespace_stack.push(cls): self = self._current_generate_schema if cls.__pydantic_root_model__: - root_field = self._common_field_schema('root', fields['root'], decorators) - inner_schema = root_field['schema'] - inner_schema = apply_model_validators(inner_schema, model_validators, 'inner') + root_field = self._common_field_schema( + "root", fields["root"], decorators + ) + inner_schema = root_field["schema"] + inner_schema = apply_model_validators( + inner_schema, model_validators, "inner" + ) model_schema = core_schema.model_schema( cls, inner_schema, - custom_init=getattr(cls, '__pydantic_custom_init__', None), + custom_init=getattr(cls, "__pydantic_custom_init__", None), root_model=True, - post_init=getattr(cls, '__pydantic_post_init__', None), + post_init=getattr(cls, "__pydantic_post_init__", None), config=core_config, ref=model_ref, metadata=metadata, ) else: - fields_schema: core_schema.CoreSchema = core_schema.model_fields_schema( - {k: self._generate_md_field_schema(k, v, decorators) for k, v in fields.items()}, - computed_fields=[ - self._computed_field_schema(d, decorators.field_serializers) - for d in computed_fields.values() - ], - extras_schema=extras_schema, - model_name=cls.__name__, + fields_schema: core_schema.CoreSchema = ( + core_schema.model_fields_schema( + { + k: self._generate_md_field_schema(k, v, decorators) + for k, v in fields.items() + }, + computed_fields=[ + self._computed_field_schema( + d, decorators.field_serializers + ) + for d in computed_fields.values() + ], + extras_schema=extras_schema, + model_name=cls.__name__, + ) + ) + inner_schema = apply_validators( + fields_schema, decorators.root_validators.values(), None + ) + new_inner_schema = define_expected_missing_refs( + inner_schema, recursively_defined_type_refs() ) - inner_schema = apply_validators(fields_schema, decorators.root_validators.values(), None) - new_inner_schema = define_expected_missing_refs(inner_schema, recursively_defined_type_refs()) if new_inner_schema is not None: inner_schema = new_inner_schema - inner_schema = apply_model_validators(inner_schema, model_validators, 'inner') + inner_schema = apply_model_validators( + inner_schema, model_validators, "inner" + ) model_schema = core_schema.model_schema( cls, inner_schema, - custom_init=getattr(cls, '__pydantic_custom_init__', None), + custom_init=getattr(cls, "__pydantic_custom_init__", None), root_model=False, - post_init=getattr(cls, '__pydantic_post_init__', None), + post_init=getattr(cls, "__pydantic_post_init__", None), config=core_config, ref=model_ref, metadata=metadata, ) - schema = self._apply_model_serializers(model_schema, decorators.model_serializers.values()) - schema = apply_model_validators(schema, model_validators, 'outer') + schema = self._apply_model_serializers( + model_schema, decorators.model_serializers.values() + ) + schema = apply_model_validators(schema, model_validators, "outer") self.defs.definitions[model_ref] = schema return core_schema.definition_reference_schema(model_ref) @staticmethod def _get_model_title_from_config( - model: type[BaseModel | StandardDataclass], config_wrapper: ConfigWrapper | None = None + model: type[BaseModel | StandardDataclass], + config_wrapper: ConfigWrapper | None = None, ) -> str | None: """Get the title of a model if `model_title_generator` or `title` are set in the config, else return None""" if config_wrapper is None: @@ -633,7 +722,9 @@ def _get_model_title_from_config( if model_title_generator: title = model_title_generator(model) if not isinstance(title, str): - raise TypeError(f'model_title_generator {model_title_generator} must return str, not {title.__class__}') + raise TypeError( + f"model_title_generator {model_title_generator} must return str, not {title.__class__}" + ) return title return None @@ -644,14 +735,16 @@ def _unpack_refs_defs(self, schema: CoreSchema) -> CoreSchema: """ def get_ref(s: CoreSchema) -> str: - return s['ref'] # type: ignore + return s["ref"] # type: ignore - if schema['type'] == 'definitions': - self.defs.definitions.update({get_ref(s): s for s in schema['definitions']}) - schema = schema['schema'] + if schema["type"] == "definitions": + self.defs.definitions.update({get_ref(s): s for s in schema["definitions"]}) + schema = schema["schema"] return schema - def _generate_schema_from_property(self, obj: Any, source: Any) -> core_schema.CoreSchema | None: + def _generate_schema_from_property( + self, obj: Any, source: Any + ) -> core_schema.CoreSchema | None: """Try to generate schema from either the `__get_pydantic_core_schema__` function or `__pydantic_core_schema__` property. @@ -665,19 +758,24 @@ def _generate_schema_from_property(self, obj: Any, source: Any) -> core_schema.C if maybe_schema is not None: return maybe_schema if obj is source: - ref_mode = 'unpack' + ref_mode = "unpack" else: - ref_mode = 'to-def' + ref_mode = "to-def" schema: CoreSchema - if (get_schema := getattr(obj, '__get_pydantic_core_schema__', None)) is not None: + if ( + get_schema := getattr(obj, "__get_pydantic_core_schema__", None) + ) is not None: if len(inspect.signature(get_schema).parameters) == 1: # (source) -> CoreSchema schema = get_schema(source) else: schema = get_schema( - source, CallbackGetCoreSchemaHandler(self._generate_schema_inner, self, ref_mode=ref_mode) + source, + CallbackGetCoreSchemaHandler( + self._generate_schema_inner, self, ref_mode=ref_mode + ), ) # fmt: off elif ( @@ -687,12 +785,17 @@ def _generate_schema_from_property(self, obj: Any, source: Any) -> core_schema.C ): schema = existing_schema # fmt: on - elif (validators := getattr(obj, '__get_validators__', None)) is not None: + elif (validators := getattr(obj, "__get_validators__", None)) is not None: warn( - '`__get_validators__` is deprecated and will be removed, use `__get_pydantic_core_schema__` instead.', + "`__get_validators__` is deprecated and will be removed, use `__get_pydantic_core_schema__` instead.", PydanticDeprecatedSince20, ) - schema = core_schema.chain_schema([core_schema.with_info_plain_validator_function(v) for v in validators()]) + schema = core_schema.chain_schema( + [ + core_schema.with_info_plain_validator_function(v) + for v in validators() + ] + ) else: # we have no existing schema information on the property, exit early so that we can go generate a schema return None @@ -700,9 +803,11 @@ def _generate_schema_from_property(self, obj: Any, source: Any) -> core_schema.C schema = self._unpack_refs_defs(schema) if is_function_with_inner_schema(schema): - ref = schema['schema'].pop('ref', None) # pyright: ignore[reportCallIssue, reportArgumentType] + ref = schema["schema"].pop( + "ref", None + ) # pyright: ignore[reportCallIssue, reportArgumentType] if ref: - schema['ref'] = ref + schema["ref"] = ref else: ref = get_ref(schema) @@ -727,7 +832,9 @@ def _resolve_forward_ref(self, obj: Any) -> Any: # if obj is still a ForwardRef, it means we can't evaluate it, raise PydanticUndefinedAnnotation if isinstance(obj, ForwardRef): - raise PydanticUndefinedAnnotation(obj.__forward_arg__, f'Unable to evaluate forward reference {obj}') + raise PydanticUndefinedAnnotation( + obj.__forward_arg__, f"Unable to evaluate forward reference {obj}" + ) if self._typevars_map: obj = replace_types(obj, self._typevars_map) @@ -735,17 +842,28 @@ def _resolve_forward_ref(self, obj: Any) -> Any: return obj @overload - def _get_args_resolving_forward_refs(self, obj: Any, required: Literal[True]) -> tuple[Any, ...]: ... + def _get_args_resolving_forward_refs( + self, obj: Any, required: Literal[True] + ) -> tuple[Any, ...]: ... @overload def _get_args_resolving_forward_refs(self, obj: Any) -> tuple[Any, ...] | None: ... - def _get_args_resolving_forward_refs(self, obj: Any, required: bool = False) -> tuple[Any, ...] | None: + def _get_args_resolving_forward_refs( + self, obj: Any, required: bool = False + ) -> tuple[Any, ...] | None: args = get_args(obj) if args: - args = tuple([self._resolve_forward_ref(a) if isinstance(a, ForwardRef) else a for a in args]) + args = tuple( + [ + self._resolve_forward_ref(a) if isinstance(a, ForwardRef) else a + for a in args + ] + ) elif required: # pragma: no cover - raise TypeError(f'Expected {obj} to have generic parameters but it had none') + raise TypeError( + f"Expected {obj} to have generic parameters but it had none" + ) return args def _get_first_arg_or_any(self, obj: Any) -> Any: @@ -760,7 +878,7 @@ def _get_first_two_args_or_any(self, obj: Any) -> tuple[Any, Any]: return (Any, Any) if len(args) < 2: origin = get_origin(obj) - raise TypeError(f'Expected two type arguments for {origin}, got 1') + raise TypeError(f"Expected two type arguments for {origin}, got 1") return args[0], args[1] def _generate_schema_inner(self, obj: Any) -> core_schema.CoreSchema: @@ -909,7 +1027,12 @@ def _match_generic_type(self, obj: Any, origin: Any) -> CoreSchema: # noqa: C90 return self._subclass_schema(obj) elif origin in {typing.Sequence, collections.abc.Sequence}: return self._sequence_schema(obj) - elif origin in {typing.Iterable, collections.abc.Iterable, typing.Generator, collections.abc.Generator}: + elif origin in { + typing.Iterable, + collections.abc.Iterable, + typing.Generator, + collections.abc.Generator, + }: return self._iterable_schema(obj) elif origin in (re.Pattern, typing.Pattern): return self._pattern_schema(obj) @@ -929,12 +1052,12 @@ def _generate_td_field_schema( """Prepare a TypedDictField to represent a model or typeddict field.""" common_field = self._common_field_schema(name, field_info, decorators) return core_schema.typed_dict_field( - common_field['schema'], + common_field["schema"], required=False if not field_info.is_required() else required, - serialization_exclude=common_field['serialization_exclude'], - validation_alias=common_field['validation_alias'], - serialization_alias=common_field['serialization_alias'], - metadata=common_field['metadata'], + serialization_exclude=common_field["serialization_exclude"], + validation_alias=common_field["validation_alias"], + serialization_alias=common_field["serialization_alias"], + metadata=common_field["metadata"], ) def _generate_md_field_schema( @@ -946,12 +1069,12 @@ def _generate_md_field_schema( """Prepare a ModelField to represent a model field.""" common_field = self._common_field_schema(name, field_info, decorators) return core_schema.model_field( - common_field['schema'], - serialization_exclude=common_field['serialization_exclude'], - validation_alias=common_field['validation_alias'], - serialization_alias=common_field['serialization_alias'], - frozen=common_field['frozen'], - metadata=common_field['metadata'], + common_field["schema"], + serialization_exclude=common_field["serialization_exclude"], + validation_alias=common_field["validation_alias"], + serialization_alias=common_field["serialization_alias"], + frozen=common_field["frozen"], + metadata=common_field["metadata"], ) def _generate_dc_field_schema( @@ -964,20 +1087,22 @@ def _generate_dc_field_schema( common_field = self._common_field_schema(name, field_info, decorators) return core_schema.dataclass_field( name, - common_field['schema'], + common_field["schema"], init=field_info.init, init_only=field_info.init_var or None, kw_only=None if field_info.kw_only else False, - serialization_exclude=common_field['serialization_exclude'], - validation_alias=common_field['validation_alias'], - serialization_alias=common_field['serialization_alias'], - frozen=common_field['frozen'], - metadata=common_field['metadata'], + serialization_exclude=common_field["serialization_exclude"], + validation_alias=common_field["validation_alias"], + serialization_alias=common_field["serialization_alias"], + frozen=common_field["frozen"], + metadata=common_field["metadata"], ) @staticmethod def _apply_alias_generator_to_field_info( - alias_generator: Callable[[str], str] | AliasGenerator, field_info: FieldInfo, field_name: str + alias_generator: Callable[[str], str] | AliasGenerator, + field_info: FieldInfo, + field_name: str, ) -> None: """Apply an alias_generator to aliases on a FieldInfo instance if appropriate. @@ -999,11 +1124,15 @@ def _apply_alias_generator_to_field_info( alias, validation_alias, serialization_alias = None, None, None if isinstance(alias_generator, AliasGenerator): - alias, validation_alias, serialization_alias = alias_generator.generate_aliases(field_name) + alias, validation_alias, serialization_alias = ( + alias_generator.generate_aliases(field_name) + ) elif isinstance(alias_generator, Callable): alias = alias_generator(field_name) if not isinstance(alias, str): - raise TypeError(f'alias_generator {alias_generator} must return str, not {alias.__class__}') + raise TypeError( + f"alias_generator {alias_generator} must return str, not {alias.__class__}" + ) # if priority is not set, we set to 1 # which supports the case where the alias_generator from a child class is used @@ -1013,17 +1142,25 @@ def _apply_alias_generator_to_field_info( # if the priority is 1, then we set the aliases to the generated alias if field_info.alias_priority == 1: - field_info.serialization_alias = _get_first_non_null(serialization_alias, alias) - field_info.validation_alias = _get_first_non_null(validation_alias, alias) + field_info.serialization_alias = _get_first_non_null( + serialization_alias, alias + ) + field_info.validation_alias = _get_first_non_null( + validation_alias, alias + ) field_info.alias = alias # if any of the aliases are not set, then we set them to the corresponding generated alias if field_info.alias is None: field_info.alias = alias if field_info.serialization_alias is None: - field_info.serialization_alias = _get_first_non_null(serialization_alias, alias) + field_info.serialization_alias = _get_first_non_null( + serialization_alias, alias + ) if field_info.validation_alias is None: - field_info.validation_alias = _get_first_non_null(validation_alias, alias) + field_info.validation_alias = _get_first_non_null( + validation_alias, alias + ) @staticmethod def _apply_alias_generator_to_computed_field_info( @@ -1050,27 +1187,38 @@ def _apply_alias_generator_to_computed_field_info( alias, validation_alias, serialization_alias = None, None, None if isinstance(alias_generator, AliasGenerator): - alias, validation_alias, serialization_alias = alias_generator.generate_aliases(computed_field_name) + alias, validation_alias, serialization_alias = ( + alias_generator.generate_aliases(computed_field_name) + ) elif isinstance(alias_generator, Callable): alias = alias_generator(computed_field_name) if not isinstance(alias, str): - raise TypeError(f'alias_generator {alias_generator} must return str, not {alias.__class__}') + raise TypeError( + f"alias_generator {alias_generator} must return str, not {alias.__class__}" + ) # if priority is not set, we set to 1 # which supports the case where the alias_generator from a child class is used # to generate an alias for a field in a parent class - if computed_field_info.alias_priority is None or computed_field_info.alias_priority <= 1: + if ( + computed_field_info.alias_priority is None + or computed_field_info.alias_priority <= 1 + ): computed_field_info.alias_priority = 1 # if the priority is 1, then we set the aliases to the generated alias # note that we use the serialization_alias with priority over alias, as computed_field # aliases are used for serialization only (not validation) if computed_field_info.alias_priority == 1: - computed_field_info.alias = _get_first_non_null(serialization_alias, alias) + computed_field_info.alias = _get_first_non_null( + serialization_alias, alias + ) @staticmethod def _apply_field_title_generator_to_field_info( - config_wrapper: ConfigWrapper, field_info: FieldInfo | ComputedFieldInfo, field_name: str + config_wrapper: ConfigWrapper, + field_info: FieldInfo | ComputedFieldInfo, + field_name: str, ) -> None: """Apply a field_title_generator on a FieldInfo or ComputedFieldInfo instance if appropriate Args: @@ -1078,7 +1226,9 @@ def _apply_field_title_generator_to_field_info( field_info: The FieldInfo or ComputedField instance to which the title_generator is (maybe) applied. field_name: The name of the field from which to generate the title. """ - field_title_generator = field_info.field_title_generator or config_wrapper.field_title_generator + field_title_generator = ( + field_info.field_title_generator or config_wrapper.field_title_generator + ) if field_title_generator is None: return @@ -1086,7 +1236,9 @@ def _apply_field_title_generator_to_field_info( if field_info.title is None: title = field_title_generator(field_name, field_info) # type: ignore if not isinstance(title, str): - raise TypeError(f'field_title_generator {field_title_generator} must return str, not {title.__class__}') + raise TypeError( + f"field_title_generator {field_title_generator} must return str, not {title.__class__}" + ) field_info.title = title @@ -1095,17 +1247,24 @@ def _common_field_schema( # C901 ) -> _CommonField: # Update FieldInfo annotation if appropriate: from .. import AliasChoices, AliasPath - from ..fields import FieldInfo if has_instance_in_type(field_info.annotation, (ForwardRef, str)): + from ..fields import FieldInfo + types_namespace = self._types_namespace if self._typevars_map: types_namespace = (types_namespace or {}).copy() # Ensure that typevars get mapped to their concrete types: - types_namespace.update({k.__name__: v for k, v in self._typevars_map.items()}) + types_namespace.update( + {k.__name__: v for k, v in self._typevars_map.items()} + ) - evaluated = _typing_extra.eval_type_lenient(field_info.annotation, types_namespace) - if evaluated is not field_info.annotation and not has_instance_in_type(evaluated, PydanticRecursiveRef): + evaluated = _typing_extra.eval_type_lenient( + field_info.annotation, types_namespace + ) + if evaluated is not field_info.annotation and not has_instance_in_type( + evaluated, PydanticRecursiveRef + ): new_field_info = FieldInfo.from_annotation(evaluated) field_info.annotation = new_field_info.annotation @@ -1115,7 +1274,10 @@ def _common_field_schema( # C901 # default value), and that should take the highest priority. So don't overwrite existing attributes. # We skip over "attributes" that are present in the metadata_lookup dict because these won't # actually end up as attributes of the `FieldInfo` instance. - if k not in field_info._attributes_set and k not in field_info.metadata_lookup: + if ( + k not in field_info._attributes_set + and k not in field_info.metadata_lookup + ): setattr(field_info, k, v) # Finally, ensure the field info also reflects all the `_attributes_set` that are actually metadata. @@ -1124,12 +1286,16 @@ def _common_field_schema( # C901 source_type, annotations = field_info.annotation, field_info.metadata def set_discriminator(schema: CoreSchema) -> CoreSchema: - schema = self._apply_discriminator_to_union(schema, field_info.discriminator) + schema = self._apply_discriminator_to_union( + schema, field_info.discriminator + ) return schema with self.field_name_stack.push(name): if field_info.discriminator is not None: - schema = self._apply_annotations(source_type, annotations, transform_inner_schema=set_discriminator) + schema = self._apply_annotations( + source_type, annotations, transform_inner_schema=set_discriminator + ) else: schema = self._apply_annotations( source_type, @@ -1140,16 +1306,30 @@ def set_discriminator(schema: CoreSchema) -> CoreSchema: # push down any `each_item=True` validators # note that this won't work for any Annotated types that get wrapped by a function validator # but that's okay because that didn't exist in V1 - this_field_validators = filter_field_decorator_info_by_field(decorators.validators.values(), name) + this_field_validators = filter_field_decorator_info_by_field( + decorators.validators.values(), name + ) if _validators_require_validate_default(this_field_validators): field_info.validate_default = True - each_item_validators = [v for v in this_field_validators if v.info.each_item is True] - this_field_validators = [v for v in this_field_validators if v not in each_item_validators] + each_item_validators = [ + v for v in this_field_validators if v.info.each_item is True + ] + this_field_validators = [ + v for v in this_field_validators if v not in each_item_validators + ] schema = apply_each_item_validators(schema, each_item_validators, name) - schema = apply_validators(schema, filter_field_decorator_info_by_field(this_field_validators, name), name) schema = apply_validators( - schema, filter_field_decorator_info_by_field(decorators.field_validators.values(), name), name + schema, + filter_field_decorator_info_by_field(this_field_validators, name), + name, + ) + schema = apply_validators( + schema, + filter_field_decorator_info_by_field( + decorators.field_validators.values(), name + ), + name, ) # the default validator needs to go outside of any other validators @@ -1159,22 +1339,33 @@ def set_discriminator(schema: CoreSchema) -> CoreSchema: schema = wrap_default(field_info, schema) schema = self._apply_field_serializers( - schema, filter_field_decorator_info_by_field(decorators.field_serializers.values(), name) + schema, + filter_field_decorator_info_by_field( + decorators.field_serializers.values(), name + ), + ) + self._apply_field_title_generator_to_field_info( + self._config_wrapper, field_info, name ) - self._apply_field_title_generator_to_field_info(self._config_wrapper, field_info, name) json_schema_updates = { - 'title': field_info.title, - 'description': field_info.description, - 'deprecated': bool(field_info.deprecated) or field_info.deprecated == '' or None, - 'examples': to_jsonable_python(field_info.examples), + "title": field_info.title, + "description": field_info.description, + "deprecated": bool(field_info.deprecated) + or field_info.deprecated == "" + or None, + "examples": to_jsonable_python(field_info.examples), + } + json_schema_updates = { + k: v for k, v in json_schema_updates.items() if v is not None } - json_schema_updates = {k: v for k, v in json_schema_updates.items() if v is not None} json_schema_extra = field_info.json_schema_extra metadata = build_metadata_dict( - js_annotation_functions=[get_json_schema_update_func(json_schema_updates, json_schema_extra)] + js_annotation_functions=[ + get_json_schema_update_func(json_schema_updates, json_schema_extra) + ] ) alias_generator = self._config_wrapper.alias_generator @@ -1211,7 +1402,7 @@ def _union_schema(self, union_type: Any) -> core_schema.CoreSchema: else: choices_with_tags: list[CoreSchema | tuple[CoreSchema, str]] = [] for choice in choices: - tag = choice.get('metadata', {}).get(_core_utils.TAGGED_UNION_TAG_KEY) + tag = choice.get("metadata", {}).get(_core_utils.TAGGED_UNION_TAG_KEY) if tag is not None: choices_with_tags.append((choice, tag)) else: @@ -1236,11 +1427,13 @@ def _type_alias_type_schema( typevars_map = get_standard_typevars_map(obj) with self._types_namespace_stack.push(origin): - annotation = _typing_extra.eval_type_lenient(annotation, self._types_namespace) + annotation = _typing_extra.eval_type_lenient( + annotation, self._types_namespace + ) annotation = replace_types(annotation, typevars_map) schema = self.generate_schema(annotation) - assert schema['type'] != 'definitions' - schema['ref'] = ref # type: ignore + assert schema["type"] != "definitions" + schema["ref"] = ref # type: ignore self.defs.definitions[ref] = schema return core_schema.definition_reference_schema(ref) @@ -1250,14 +1443,18 @@ def _literal_schema(self, literal_type: Any) -> CoreSchema: assert expected, f'literal "expected" cannot be empty, obj={literal_type}' schema = core_schema.literal_schema(expected) - if self._config_wrapper.use_enum_values and any(isinstance(v, Enum) for v in expected): + if self._config_wrapper.use_enum_values and any( + isinstance(v, Enum) for v in expected + ): schema = core_schema.no_info_after_validator_function( lambda v: v.value if isinstance(v, Enum) else v, schema ) return schema - def _typed_dict_schema(self, typed_dict_cls: Any, origin: Any) -> core_schema.CoreSchema: + def _typed_dict_schema( + self, typed_dict_cls: Any, origin: Any + ) -> core_schema.CoreSchema: """Generate schema for a TypedDict. It is not possible to track required/optional keys in TypedDict without __required_keys__ @@ -1275,7 +1472,9 @@ def _typed_dict_schema(self, typed_dict_cls: Any, origin: Any) -> core_schema.Co """ from ..fields import FieldInfo - with self.model_type_stack.push(typed_dict_cls), self.defs.get_schema_or_ref(typed_dict_cls) as ( + with self.model_type_stack.push(typed_dict_cls), self.defs.get_schema_or_ref( + typed_dict_cls + ) as ( typed_dict_ref, maybe_schema, ): @@ -1286,18 +1485,22 @@ def _typed_dict_schema(self, typed_dict_cls: Any, origin: Any) -> core_schema.Co if origin is not None: typed_dict_cls = origin - if not _SUPPORTS_TYPEDDICT and type(typed_dict_cls).__module__ == 'typing': + if not _SUPPORTS_TYPEDDICT and type(typed_dict_cls).__module__ == "typing": raise PydanticUserError( - 'Please use `typing_extensions.TypedDict` instead of `typing.TypedDict` on Python < 3.12.', - code='typed-dict-version', + "Please use `typing_extensions.TypedDict` instead of `typing.TypedDict` on Python < 3.12.", + code="typed-dict-version", ) try: - config: ConfigDict | None = get_attribute_from_bases(typed_dict_cls, '__pydantic_config__') + config: ConfigDict | None = get_attribute_from_bases( + typed_dict_cls, "__pydantic_config__" + ) except AttributeError: config = None - with self._config_wrapper_stack.push(config), self._types_namespace_stack.push(typed_dict_cls): + with self._config_wrapper_stack.push( + config + ), self._types_namespace_stack.push(typed_dict_cls): core_config = self._config_wrapper.core_config(typed_dict_cls) self = self._current_generate_schema @@ -1309,7 +1512,9 @@ def _typed_dict_schema(self, typed_dict_cls: Any, origin: Any) -> core_schema.Co decorators = DecoratorInfos.build(typed_dict_cls) if self._config_wrapper.use_attribute_docstrings: - field_docstrings = extract_docstrings_from_cls(typed_dict_cls, use_inspect=True) + field_docstrings = extract_docstrings_from_cls( + typed_dict_cls, use_inspect=True + ) else: field_docstrings = None @@ -1339,14 +1544,22 @@ def _typed_dict_schema(self, typed_dict_cls: Any, origin: Any) -> core_schema.Co and field_name in field_docstrings ): field_info.description = field_docstrings[field_name] - self._apply_field_title_generator_to_field_info(self._config_wrapper, field_info, field_name) + self._apply_field_title_generator_to_field_info( + self._config_wrapper, field_info, field_name + ) fields[field_name] = self._generate_td_field_schema( field_name, field_info, decorators, required=required ) - title = self._get_model_title_from_config(typed_dict_cls, ConfigWrapper(config)) + title = self._get_model_title_from_config( + typed_dict_cls, ConfigWrapper(config) + ) metadata = build_metadata_dict( - js_functions=[partial(modify_model_json_schema, cls=typed_dict_cls, title=title)], + js_functions=[ + partial( + modify_model_json_schema, cls=typed_dict_cls, title=title + ) + ], typed_dict_cls=typed_dict_cls, ) td_schema = core_schema.typed_dict_schema( @@ -1360,14 +1573,22 @@ def _typed_dict_schema(self, typed_dict_cls: Any, origin: Any) -> core_schema.Co config=core_config, ) - schema = self._apply_model_serializers(td_schema, decorators.model_serializers.values()) - schema = apply_model_validators(schema, decorators.model_validators.values(), 'all') + schema = self._apply_model_serializers( + td_schema, decorators.model_serializers.values() + ) + schema = apply_model_validators( + schema, decorators.model_validators.values(), "all" + ) self.defs.definitions[typed_dict_ref] = schema return core_schema.definition_reference_schema(typed_dict_ref) - def _namedtuple_schema(self, namedtuple_cls: Any, origin: Any) -> core_schema.CoreSchema: + def _namedtuple_schema( + self, namedtuple_cls: Any, origin: Any + ) -> core_schema.CoreSchema: """Generate schema for a NamedTuple.""" - with self.model_type_stack.push(namedtuple_cls), self.defs.get_schema_or_ref(namedtuple_cls) as ( + with self.model_type_stack.push(namedtuple_cls), self.defs.get_schema_or_ref( + namedtuple_cls + ) as ( namedtuple_ref, maybe_schema, ): @@ -1393,20 +1614,28 @@ def _namedtuple_schema(self, namedtuple_cls: Any, origin: Any) -> core_schema.Co arguments_schema = core_schema.arguments_schema( [ self._generate_parameter_schema( - field_name, annotation, default=namedtuple_cls._field_defaults.get(field_name, Parameter.empty) + field_name, + annotation, + default=namedtuple_cls._field_defaults.get( + field_name, Parameter.empty + ), ) for field_name, annotation in annotations.items() ], metadata=build_metadata_dict(js_prefer_positional_arguments=True), ) - return core_schema.call_schema(arguments_schema, namedtuple_cls, ref=namedtuple_ref) + return core_schema.call_schema( + arguments_schema, namedtuple_cls, ref=namedtuple_ref + ) def _generate_parameter_schema( self, name: str, annotation: type[Any], default: Any = Parameter.empty, - mode: Literal['positional_only', 'positional_or_keyword', 'keyword_only'] | None = None, + mode: ( + Literal["positional_only", "positional_or_keyword", "keyword_only"] | None + ) = None, ) -> core_schema.ArgumentsParameter: """Prepare a ArgumentsParameter to represent a field in a namedtuple or function signature.""" from ..fields import FieldInfo @@ -1415,7 +1644,9 @@ def _generate_parameter_schema( field = FieldInfo.from_annotation(annotation) else: field = FieldInfo.from_annotated_attribute(annotation, default) - assert field.annotation is not None, 'field.annotation should not be None when generating a schema' + assert ( + field.annotation is not None + ), "field.annotation should not be None when generating a schema" source_type, annotations = field.annotation, field.metadata with self.field_name_stack.push(name): schema = self._apply_annotations(source_type, annotations) @@ -1425,15 +1656,18 @@ def _generate_parameter_schema( parameter_schema = core_schema.arguments_parameter(name, schema) if mode is not None: - parameter_schema['mode'] = mode + parameter_schema["mode"] = mode if field.alias is not None: - parameter_schema['alias'] = field.alias + parameter_schema["alias"] = field.alias else: alias_generator = self._config_wrapper.alias_generator - if isinstance(alias_generator, AliasGenerator) and alias_generator.alias is not None: - parameter_schema['alias'] = alias_generator.alias(name) + if ( + isinstance(alias_generator, AliasGenerator) + and alias_generator.alias is not None + ): + parameter_schema["alias"] = alias_generator.alias(name) elif isinstance(alias_generator, Callable): - parameter_schema['alias'] = alias_generator(name) + parameter_schema["alias"] = alias_generator(name) return parameter_schema def _tuple_schema(self, tuple_type: Any) -> core_schema.CoreSchema: @@ -1449,34 +1683,42 @@ def _tuple_schema(self, tuple_type: Any) -> core_schema.CoreSchema: # This is only true for <3.11, on Python 3.11+ `typing.Tuple[()]` gives `params=()` if not params: if tuple_type in TUPLE_TYPES: - return core_schema.tuple_schema([core_schema.any_schema()], variadic_item_index=0) + return core_schema.tuple_schema( + [core_schema.any_schema()], variadic_item_index=0 + ) else: # special case for `tuple[()]` which means `tuple[]` - an empty tuple return core_schema.tuple_schema([]) elif params[-1] is Ellipsis: if len(params) == 2: - return core_schema.tuple_schema([self.generate_schema(params[0])], variadic_item_index=0) + return core_schema.tuple_schema( + [self.generate_schema(params[0])], variadic_item_index=0 + ) else: # TODO: something like https://github.com/pydantic/pydantic/issues/5952 - raise ValueError('Variable tuples can only have one type') + raise ValueError("Variable tuples can only have one type") elif len(params) == 1 and params[0] == (): # special case for `Tuple[()]` which means `Tuple[]` - an empty tuple # NOTE: This conditional can be removed when we drop support for Python 3.10. return core_schema.tuple_schema([]) else: - return core_schema.tuple_schema([self.generate_schema(param) for param in params]) + return core_schema.tuple_schema( + [self.generate_schema(param) for param in params] + ) def _type_schema(self) -> core_schema.CoreSchema: return core_schema.custom_error_schema( core_schema.is_instance_schema(type), - custom_error_type='is_type', - custom_error_message='Input should be a type', + custom_error_type="is_type", + custom_error_message="Input should be a type", ) def _union_is_subclass_schema(self, union_type: Any) -> core_schema.CoreSchema: """Generate schema for `Type[Union[X, ...]]`.""" args = self._get_args_resolving_forward_refs(union_type, required=True) - return core_schema.union_schema([self.generate_schema(typing.Type[args]) for args in args]) + return core_schema.union_schema( + [self.generate_schema(typing.Type[args]) for args in args] + ) def _subclass_schema(self, type_: Any) -> core_schema.CoreSchema: """Generate schema for a Type, e.g. `Type[int]`.""" @@ -1490,7 +1732,10 @@ def _subclass_schema(self, type_: Any) -> core_schema.CoreSchema: return core_schema.is_subclass_schema(type_param.__bound__) elif type_param.__constraints__: return core_schema.union_schema( - [self.generate_schema(typing.Type[c]) for c in type_param.__constraints__] + [ + self.generate_schema(typing.Type[c]) + for c in type_param.__constraints__ + ] ) else: return self._type_schema() @@ -1507,19 +1752,28 @@ def _sequence_schema(self, sequence_type: Any) -> core_schema.CoreSchema: item_type_schema = self.generate_schema(item_type) list_schema = core_schema.list_schema(item_type_schema) - python_schema = core_schema.is_instance_schema(typing.Sequence, cls_repr='Sequence') + python_schema = core_schema.is_instance_schema( + typing.Sequence, cls_repr="Sequence" + ) if item_type != Any: from ._validators import sequence_validator python_schema = core_schema.chain_schema( - [python_schema, core_schema.no_info_wrap_validator_function(sequence_validator, list_schema)], + [ + python_schema, + core_schema.no_info_wrap_validator_function( + sequence_validator, list_schema + ), + ], ) serialization = core_schema.wrap_serializer_function_ser_schema( serialize_sequence_via_list, schema=item_type_schema, info_arg=True ) return core_schema.json_or_python_schema( - json_schema=list_schema, python_schema=python_schema, serialization=serialization + json_schema=list_schema, + python_schema=python_schema, + serialization=serialization, ) def _iterable_schema(self, type_: Any) -> core_schema.GeneratorSchema: @@ -1531,14 +1785,20 @@ def _iterable_schema(self, type_: Any) -> core_schema.GeneratorSchema: def _pattern_schema(self, pattern_type: Any) -> core_schema.CoreSchema: from . import _validators - metadata = build_metadata_dict(js_functions=[lambda _1, _2: {'type': 'string', 'format': 'regex'}]) + metadata = build_metadata_dict( + js_functions=[lambda _1, _2: {"type": "string", "format": "regex"}] + ) ser = core_schema.plain_serializer_function_ser_schema( - attrgetter('pattern'), when_used='json', return_schema=core_schema.str_schema() + attrgetter("pattern"), + when_used="json", + return_schema=core_schema.str_schema(), ) if pattern_type == typing.Pattern or pattern_type == re.Pattern: # bare type return core_schema.no_info_plain_validator_function( - _validators.pattern_either_validator, serialization=ser, metadata=metadata + _validators.pattern_either_validator, + serialization=ser, + metadata=metadata, ) param = self._get_args_resolving_forward_refs( @@ -1551,23 +1811,29 @@ def _pattern_schema(self, pattern_type: Any) -> core_schema.CoreSchema: ) elif param is bytes: return core_schema.no_info_plain_validator_function( - _validators.pattern_bytes_validator, serialization=ser, metadata=metadata + _validators.pattern_bytes_validator, + serialization=ser, + metadata=metadata, ) else: - raise PydanticSchemaGenerationError(f'Unable to generate pydantic-core schema for {pattern_type!r}.') + raise PydanticSchemaGenerationError( + f"Unable to generate pydantic-core schema for {pattern_type!r}." + ) def _hashable_schema(self) -> core_schema.CoreSchema: return core_schema.custom_error_schema( core_schema.is_instance_schema(collections.abc.Hashable), - custom_error_type='is_hashable', - custom_error_message='Input should be hashable', + custom_error_type="is_hashable", + custom_error_message="Input should be hashable", ) def _dataclass_schema( self, dataclass: type[StandardDataclass], origin: type[StandardDataclass] | None ) -> core_schema.CoreSchema: """Generate schema for a dataclass.""" - with self.model_type_stack.push(dataclass), self.defs.get_schema_or_ref(dataclass) as ( + with self.model_type_stack.push(dataclass), self.defs.get_schema_or_ref( + dataclass + ) as ( dataclass_ref, maybe_schema, ): @@ -1582,14 +1848,18 @@ def _dataclass_schema( # Pushing a namespace prioritises items already in the stack, so iterate though the MRO forwards for dataclass_base in dataclass.__mro__: if dataclasses.is_dataclass(dataclass_base): - dataclass_bases_stack.enter_context(self._types_namespace_stack.push(dataclass_base)) + dataclass_bases_stack.enter_context( + self._types_namespace_stack.push(dataclass_base) + ) # Pushing a config overwrites the previous config, so iterate though the MRO backwards config = None for dataclass_base in reversed(dataclass.__mro__): if dataclasses.is_dataclass(dataclass_base): - config = getattr(dataclass_base, '__pydantic_config__', None) - dataclass_bases_stack.enter_context(self._config_wrapper_stack.push(config)) + config = getattr(dataclass_base, "__pydantic_config__", None) + dataclass_bases_stack.enter_context( + self._config_wrapper_stack.push(config) + ) core_config = self._config_wrapper.core_config(dataclass) @@ -1601,7 +1871,9 @@ def _dataclass_schema( fields = deepcopy(dataclass.__pydantic_fields__) if typevars_map: for field in fields.values(): - field.apply_typevars_map(typevars_map, self._types_namespace) + field.apply_typevars_map( + typevars_map, self._types_namespace + ) else: fields = collect_dataclass_fields( dataclass, @@ -1610,25 +1882,30 @@ def _dataclass_schema( ) # disallow combination of init=False on a dataclass field and extra='allow' on a dataclass - if self._config_wrapper_stack.tail.extra == 'allow': + if self._config_wrapper_stack.tail.extra == "allow": # disallow combination of init=False on a dataclass field and extra='allow' on a dataclass for field_name, field in fields.items(): if field.init is False: raise PydanticUserError( f'Field {field_name} has `init=False` and dataclass has config setting `extra="allow"`. ' - f'This combination is not allowed.', - code='dataclass-init-false-extra-allow', + f"This combination is not allowed.", + code="dataclass-init-false-extra-allow", ) - decorators = dataclass.__dict__.get('__pydantic_decorators__') or DecoratorInfos.build(dataclass) + decorators = dataclass.__dict__.get( + "__pydantic_decorators__" + ) or DecoratorInfos.build(dataclass) # Move kw_only=False args to the start of the list, as this is how vanilla dataclasses work. # Note that when kw_only is missing or None, it is treated as equivalent to kw_only=True args = sorted( - (self._generate_dc_field_schema(k, v, decorators) for k, v in fields.items()), - key=lambda a: a.get('kw_only') is not False, + ( + self._generate_dc_field_schema(k, v, decorators) + for k, v in fields.items() + ), + key=lambda a: a.get("kw_only") is not False, ) - has_post_init = hasattr(dataclass, '__post_init__') - has_slots = hasattr(dataclass, '__slots__') + has_post_init = hasattr(dataclass, "__post_init__") + has_slots = hasattr(dataclass, "__slots__") args_schema = core_schema.dataclass_args_schema( dataclass.__name__, @@ -1640,14 +1917,22 @@ def _dataclass_schema( collect_init_only=has_post_init, ) - inner_schema = apply_validators(args_schema, decorators.root_validators.values(), None) + inner_schema = apply_validators( + args_schema, decorators.root_validators.values(), None + ) model_validators = decorators.model_validators.values() - inner_schema = apply_model_validators(inner_schema, model_validators, 'inner') + inner_schema = apply_model_validators( + inner_schema, model_validators, "inner" + ) - title = self._get_model_title_from_config(dataclass, ConfigWrapper(config)) + title = self._get_model_title_from_config( + dataclass, ConfigWrapper(config) + ) metadata = build_metadata_dict( - js_functions=[partial(modify_model_json_schema, cls=dataclass, title=title)] + js_functions=[ + partial(modify_model_json_schema, cls=dataclass, title=title) + ] ) dc_schema = core_schema.dataclass_schema( @@ -1660,14 +1945,16 @@ def _dataclass_schema( config=core_config, metadata=metadata, ) - schema = self._apply_model_serializers(dc_schema, decorators.model_serializers.values()) - schema = apply_model_validators(schema, model_validators, 'outer') + schema = self._apply_model_serializers( + dc_schema, decorators.model_serializers.values() + ) + schema = apply_model_validators(schema, model_validators, "outer") self.defs.definitions[dataclass_ref] = schema return core_schema.definition_reference_schema(dataclass_ref) # Type checkers seem to assume ExitStack may suppress exceptions and therefore # control flow can exit the `with` block without returning. - assert False, 'Unreachable' + assert False, "Unreachable" def _callable_schema(self, function: Callable[..., Any]) -> core_schema.CallSchema: """Generate schema for a Callable. @@ -1676,12 +1963,17 @@ def _callable_schema(self, function: Callable[..., Any]) -> core_schema.CallSche """ sig = signature(function) - type_hints = _typing_extra.get_function_type_hints(function, types_namespace=self._types_namespace) + type_hints = _typing_extra.get_function_type_hints( + function, types_namespace=self._types_namespace + ) - mode_lookup: dict[_ParameterKind, Literal['positional_only', 'positional_or_keyword', 'keyword_only']] = { - Parameter.POSITIONAL_ONLY: 'positional_only', - Parameter.POSITIONAL_OR_KEYWORD: 'positional_or_keyword', - Parameter.KEYWORD_ONLY: 'keyword_only', + mode_lookup: dict[ + _ParameterKind, + Literal["positional_only", "positional_or_keyword", "keyword_only"], + ] = { + Parameter.POSITIONAL_ONLY: "positional_only", + Parameter.POSITIONAL_OR_KEYWORD: "positional_or_keyword", + Parameter.KEYWORD_ONLY: "keyword_only", } arguments_list: list[core_schema.ArgumentsParameter] = [] @@ -1696,7 +1988,9 @@ def _callable_schema(self, function: Callable[..., Any]) -> core_schema.CallSche parameter_mode = mode_lookup.get(p.kind) if parameter_mode is not None: - arg_schema = self._generate_parameter_schema(name, annotation, p.default, parameter_mode) + arg_schema = self._generate_parameter_schema( + name, annotation, p.default, parameter_mode + ) arguments_list.append(arg_schema) elif p.kind == Parameter.VAR_POSITIONAL: var_args_schema = self.generate_schema(annotation) @@ -1707,7 +2001,7 @@ def _callable_schema(self, function: Callable[..., Any]) -> core_schema.CallSche return_schema: core_schema.CoreSchema | None = None config_wrapper = self._config_wrapper if config_wrapper.validate_return: - return_hint = type_hints.get('return') + return_hint = type_hints.get("return") if return_hint is not None: return_schema = self.generate_schema(return_hint) @@ -1722,7 +2016,9 @@ def _callable_schema(self, function: Callable[..., Any]) -> core_schema.CallSche return_schema=return_schema, ) - def _unsubstituted_typevar_schema(self, typevar: typing.TypeVar) -> core_schema.CoreSchema: + def _unsubstituted_typevar_schema( + self, typevar: typing.TypeVar + ) -> core_schema.CoreSchema: assert isinstance(typevar, typing.TypeVar) bound = typevar.__bound__ @@ -1732,11 +2028,11 @@ def _unsubstituted_typevar_schema(self, typevar: typing.TypeVar) -> core_schema. typevar_has_default = typevar.has_default() # type: ignore except AttributeError: # could still have a default if it's an old version of typing_extensions.TypeVar - typevar_has_default = getattr(typevar, '__default__', None) is not None + typevar_has_default = getattr(typevar, "__default__", None) is not None if (bound is not None) + (len(constraints) != 0) + typevar_has_default > 1: raise NotImplementedError( - 'Pydantic does not support mixing more than one of TypeVar bounds, constraints and defaults' + "Pydantic does not support mixing more than one of TypeVar bounds, constraints and defaults" ) if typevar_has_default: @@ -1745,7 +2041,7 @@ def _unsubstituted_typevar_schema(self, typevar: typing.TypeVar) -> core_schema. return self._union_schema(typing.Union[constraints]) # type: ignore elif bound: schema = self.generate_schema(bound) - schema['serialization'] = core_schema.wrap_serializer_function_ser_schema( + schema["serialization"] = core_schema.wrap_serializer_function_ser_schema( lambda x, h: h(x), schema=core_schema.any_schema() ) return schema @@ -1758,14 +2054,16 @@ def _computed_field_schema( field_serializers: dict[str, Decorator[FieldSerializerDecoratorInfo]], ) -> core_schema.ComputedField: try: - return_type = _decorators.get_function_return_type(d.func, d.info.return_type, self._types_namespace) + return_type = _decorators.get_function_return_type( + d.func, d.info.return_type, self._types_namespace + ) except NameError as e: raise PydanticUndefinedAnnotation.from_name_error(e) from e if return_type is PydanticUndefined: raise PydanticUserError( - 'Computed field is missing return type annotation or specifying `return_type`' - ' to the `@computed_field` decorator (e.g. `@computed_field(return_type=int|str)`)', - code='model-field-missing-annotation', + "Computed field is missing return type annotation or specifying `return_type`" + " to the `@computed_field` decorator (e.g. `@computed_field(return_type=int|str)`)", + code="model-field-missing-annotation", ) return_type = replace_types(return_type, self._typevars_map) @@ -1776,36 +2074,44 @@ def _computed_field_schema( # Apply serializers to computed field if there exist return_type_schema = self._apply_field_serializers( return_type_schema, - filter_field_decorator_info_by_field(field_serializers.values(), d.cls_var_name), + filter_field_decorator_info_by_field( + field_serializers.values(), d.cls_var_name + ), computed_field=True, ) alias_generator = self._config_wrapper.alias_generator if alias_generator is not None: self._apply_alias_generator_to_computed_field_info( - alias_generator=alias_generator, computed_field_info=d.info, computed_field_name=d.cls_var_name + alias_generator=alias_generator, + computed_field_info=d.info, + computed_field_name=d.cls_var_name, ) - self._apply_field_title_generator_to_field_info(self._config_wrapper, d.info, d.cls_var_name) + self._apply_field_title_generator_to_field_info( + self._config_wrapper, d.info, d.cls_var_name + ) - def set_computed_field_metadata(schema: CoreSchemaOrField, handler: GetJsonSchemaHandler) -> JsonSchemaValue: + def set_computed_field_metadata( + schema: CoreSchemaOrField, handler: GetJsonSchemaHandler + ) -> JsonSchemaValue: json_schema = handler(schema) - json_schema['readOnly'] = True + json_schema["readOnly"] = True title = d.info.title if title is not None: - json_schema['title'] = title + json_schema["title"] = title description = d.info.description if description is not None: - json_schema['description'] = description + json_schema["description"] = description - if d.info.deprecated or d.info.deprecated == '': - json_schema['deprecated'] = True + if d.info.deprecated or d.info.deprecated == "": + json_schema["deprecated"] = True examples = d.info.examples if examples is not None: - json_schema['examples'] = to_jsonable_python(examples) + json_schema["examples"] = to_jsonable_python(examples) json_schema_extra = d.info.json_schema_extra if json_schema_extra is not None: @@ -1813,9 +2119,14 @@ def set_computed_field_metadata(schema: CoreSchemaOrField, handler: GetJsonSchem return json_schema - metadata = build_metadata_dict(js_annotation_functions=[set_computed_field_metadata]) + metadata = build_metadata_dict( + js_annotation_functions=[set_computed_field_metadata] + ) return core_schema.computed_field( - d.cls_var_name, return_schema=return_type_schema, alias=d.info.alias, metadata=metadata + d.cls_var_name, + return_schema=return_type_schema, + alias=d.info.alias, + metadata=metadata, ) def _annotated_schema(self, annotated_type: Any) -> core_schema.CoreSchema: @@ -1865,8 +2176,12 @@ def _apply_annotations( not expect `source_type` to be an `Annotated` object, it expects it to be the first argument of that (in other words, `GenerateSchema._annotated_schema` just unpacks `Annotated`, this process it). """ - annotations = list(_known_annotated_metadata.expand_grouped_metadata(annotations)) - res = self._get_prepare_pydantic_annotations_for_known_type(source_type, tuple(annotations)) + annotations = list( + _known_annotated_metadata.expand_grouped_metadata(annotations) + ) + res = self._get_prepare_pydantic_annotations_for_known_type( + source_type, tuple(annotations) + ) if res is not None: source_type, annotations = res @@ -1897,10 +2212,16 @@ def inner_handler(obj: Any) -> CoreSchema: schema = get_inner_schema(source_type) if pydantic_js_annotation_functions: metadata = CoreMetadataHandler(schema).metadata - metadata.setdefault('pydantic_js_annotation_functions', []).extend(pydantic_js_annotation_functions) - return _add_custom_serialization_from_json_encoders(self._config_wrapper.json_encoders, source_type, schema) + metadata.setdefault("pydantic_js_annotation_functions", []).extend( + pydantic_js_annotation_functions + ) + return _add_custom_serialization_from_json_encoders( + self._config_wrapper.json_encoders, source_type, schema + ) - def _apply_single_annotation(self, schema: core_schema.CoreSchema, metadata: Any) -> core_schema.CoreSchema: + def _apply_single_annotation( + self, schema: core_schema.CoreSchema, metadata: Any + ) -> core_schema.CoreSchema: from ..fields import FieldInfo if isinstance(metadata, FieldInfo): @@ -1908,35 +2229,39 @@ def _apply_single_annotation(self, schema: core_schema.CoreSchema, metadata: Any schema = self._apply_single_annotation(schema, field_metadata) if metadata.discriminator is not None: - schema = self._apply_discriminator_to_union(schema, metadata.discriminator) + schema = self._apply_discriminator_to_union( + schema, metadata.discriminator + ) return schema - if schema['type'] == 'nullable': + if schema["type"] == "nullable": # for nullable schemas, metadata is automatically applied to the inner schema - inner = schema.get('schema', core_schema.any_schema()) + inner = schema.get("schema", core_schema.any_schema()) inner = self._apply_single_annotation(inner, metadata) if inner: - schema['schema'] = inner + schema["schema"] = inner return schema original_schema = schema - ref = schema.get('ref', None) + ref = schema.get("ref", None) if ref is not None: schema = schema.copy() - new_ref = ref + f'_{repr(metadata)}' + new_ref = ref + f"_{repr(metadata)}" if new_ref in self.defs.definitions: return self.defs.definitions[new_ref] - schema['ref'] = new_ref # type: ignore - elif schema['type'] == 'definition-ref': - ref = schema['schema_ref'] + schema["ref"] = new_ref # type: ignore + elif schema["type"] == "definition-ref": + ref = schema["schema_ref"] if ref in self.defs.definitions: schema = self.defs.definitions[ref].copy() - new_ref = ref + f'_{repr(metadata)}' + new_ref = ref + f"_{repr(metadata)}" if new_ref in self.defs.definitions: return self.defs.definitions[new_ref] - schema['ref'] = new_ref # type: ignore + schema["ref"] = new_ref # type: ignore - maybe_updated_schema = _known_annotated_metadata.apply_known_metadata(metadata, schema.copy()) + maybe_updated_schema = _known_annotated_metadata.apply_known_metadata( + metadata, schema.copy() + ) if maybe_updated_schema is not None: return maybe_updated_schema @@ -1949,18 +2274,22 @@ def _apply_single_annotation_json_schema( if isinstance(metadata, FieldInfo): for field_metadata in metadata.metadata: - schema = self._apply_single_annotation_json_schema(schema, field_metadata) + schema = self._apply_single_annotation_json_schema( + schema, field_metadata + ) json_schema_update: JsonSchemaValue = {} if metadata.title: - json_schema_update['title'] = metadata.title + json_schema_update["title"] = metadata.title if metadata.description: - json_schema_update['description'] = metadata.description + json_schema_update["description"] = metadata.description if metadata.examples: - json_schema_update['examples'] = to_jsonable_python(metadata.examples) + json_schema_update["examples"] = to_jsonable_python(metadata.examples) json_schema_extra = metadata.json_schema_extra if json_schema_update or json_schema_extra: - CoreMetadataHandler(schema).metadata.setdefault('pydantic_js_annotation_functions', []).append( + CoreMetadataHandler(schema).metadata.setdefault( + "pydantic_js_annotation_functions", [] + ).append( get_json_schema_update_func(json_schema_update, json_schema_extra) ) return schema @@ -1971,9 +2300,9 @@ def _get_wrapped_inner_schema( annotation: Any, pydantic_js_annotation_functions: list[GetJsonSchemaFunction], ) -> CallbackGetCoreSchemaHandler: - metadata_get_schema: GetCoreSchemaFunction = getattr(annotation, '__get_pydantic_core_schema__', None) or ( - lambda source, handler: handler(source) - ) + metadata_get_schema: GetCoreSchemaFunction = getattr( + annotation, "__get_pydantic_core_schema__", None + ) or (lambda source, handler: handler(source)) def new_handler(source: Any) -> core_schema.CoreSchema: schema = metadata_get_schema(source, get_inner_schema) @@ -1996,12 +2325,14 @@ def _apply_field_serializers( """Apply field serializers to a schema.""" if serializers: schema = copy(schema) - if schema['type'] == 'definitions': - inner_schema = schema['schema'] - schema['schema'] = self._apply_field_serializers(inner_schema, serializers) + if schema["type"] == "definitions": + inner_schema = schema["schema"] + schema["schema"] = self._apply_field_serializers( + inner_schema, serializers + ) return schema else: - ref = typing.cast('str|None', schema.get('ref', None)) + ref = typing.cast("str|None", schema.get("ref", None)) if ref is not None: schema = core_schema.definition_reference_schema(ref) @@ -2023,30 +2354,36 @@ def _apply_field_serializers( else: return_schema = self.generate_schema(return_type) - if serializer.info.mode == 'wrap': - schema['serialization'] = core_schema.wrap_serializer_function_ser_schema( - serializer.func, - is_field_serializer=is_field_serializer, - info_arg=info_arg, - return_schema=return_schema, - when_used=serializer.info.when_used, + if serializer.info.mode == "wrap": + schema["serialization"] = ( + core_schema.wrap_serializer_function_ser_schema( + serializer.func, + is_field_serializer=is_field_serializer, + info_arg=info_arg, + return_schema=return_schema, + when_used=serializer.info.when_used, + ) ) else: - assert serializer.info.mode == 'plain' - schema['serialization'] = core_schema.plain_serializer_function_ser_schema( - serializer.func, - is_field_serializer=is_field_serializer, - info_arg=info_arg, - return_schema=return_schema, - when_used=serializer.info.when_used, + assert serializer.info.mode == "plain" + schema["serialization"] = ( + core_schema.plain_serializer_function_ser_schema( + serializer.func, + is_field_serializer=is_field_serializer, + info_arg=info_arg, + return_schema=return_schema, + when_used=serializer.info.when_used, + ) ) return schema def _apply_model_serializers( - self, schema: core_schema.CoreSchema, serializers: Iterable[Decorator[ModelSerializerDecoratorInfo]] + self, + schema: core_schema.CoreSchema, + serializers: Iterable[Decorator[ModelSerializerDecoratorInfo]], ) -> core_schema.CoreSchema: """Apply model serializers to a schema.""" - ref: str | None = schema.pop('ref', None) # type: ignore + ref: str | None = schema.pop("ref", None) # type: ignore if serializers: serializer = list(serializers)[-1] info_arg = inspect_model_serializer(serializer.func, serializer.info.mode) @@ -2062,12 +2399,14 @@ def _apply_model_serializers( else: return_schema = self.generate_schema(return_type) - if serializer.info.mode == 'wrap': - ser_schema: core_schema.SerSchema = core_schema.wrap_serializer_function_ser_schema( - serializer.func, - info_arg=info_arg, - return_schema=return_schema, - when_used=serializer.info.when_used, + if serializer.info.mode == "wrap": + ser_schema: core_schema.SerSchema = ( + core_schema.wrap_serializer_function_ser_schema( + serializer.func, + info_arg=info_arg, + return_schema=return_schema, + when_used=serializer.info.when_used, + ) ) else: # plain @@ -2077,30 +2416,56 @@ def _apply_model_serializers( return_schema=return_schema, when_used=serializer.info.when_used, ) - schema['serialization'] = ser_schema + schema["serialization"] = ser_schema if ref: - schema['ref'] = ref # type: ignore + schema["ref"] = ref # type: ignore return schema _VALIDATOR_F_MATCH: Mapping[ - tuple[FieldValidatorModes, Literal['no-info', 'with-info']], - Callable[[Callable[..., Any], core_schema.CoreSchema, str | None], core_schema.CoreSchema], + tuple[FieldValidatorModes, Literal["no-info", "with-info"]], + Callable[ + [Callable[..., Any], core_schema.CoreSchema, str | None], core_schema.CoreSchema + ], ] = { - ('before', 'no-info'): lambda f, schema, _: core_schema.no_info_before_validator_function(f, schema), - ('after', 'no-info'): lambda f, schema, _: core_schema.no_info_after_validator_function(f, schema), - ('plain', 'no-info'): lambda f, _1, _2: core_schema.no_info_plain_validator_function(f), - ('wrap', 'no-info'): lambda f, schema, _: core_schema.no_info_wrap_validator_function(f, schema), - ('before', 'with-info'): lambda f, schema, field_name: core_schema.with_info_before_validator_function( + ( + "before", + "no-info", + ): lambda f, schema, _: core_schema.no_info_before_validator_function(f, schema), + ( + "after", + "no-info", + ): lambda f, schema, _: core_schema.no_info_after_validator_function(f, schema), + ( + "plain", + "no-info", + ): lambda f, _1, _2: core_schema.no_info_plain_validator_function(f), + ( + "wrap", + "no-info", + ): lambda f, schema, _: core_schema.no_info_wrap_validator_function(f, schema), + ( + "before", + "with-info", + ): lambda f, schema, field_name: core_schema.with_info_before_validator_function( f, schema, field_name=field_name ), - ('after', 'with-info'): lambda f, schema, field_name: core_schema.with_info_after_validator_function( + ( + "after", + "with-info", + ): lambda f, schema, field_name: core_schema.with_info_after_validator_function( f, schema, field_name=field_name ), - ('plain', 'with-info'): lambda f, _, field_name: core_schema.with_info_plain_validator_function( + ( + "plain", + "with-info", + ): lambda f, _, field_name: core_schema.with_info_plain_validator_function( f, field_name=field_name ), - ('wrap', 'with-info'): lambda f, schema, field_name: core_schema.with_info_wrap_validator_function( + ( + "wrap", + "with-info", + ): lambda f, schema, field_name: core_schema.with_info_wrap_validator_function( f, schema, field_name=field_name ), } @@ -2108,9 +2473,11 @@ def _apply_model_serializers( def apply_validators( schema: core_schema.CoreSchema, - validators: Iterable[Decorator[RootValidatorDecoratorInfo]] - | Iterable[Decorator[ValidatorDecoratorInfo]] - | Iterable[Decorator[FieldValidatorDecoratorInfo]], + validators: ( + Iterable[Decorator[RootValidatorDecoratorInfo]] + | Iterable[Decorator[ValidatorDecoratorInfo]] + | Iterable[Decorator[FieldValidatorDecoratorInfo]] + ), field_name: str | None, ) -> core_schema.CoreSchema: """Apply validators to a schema. @@ -2125,13 +2492,17 @@ def apply_validators( """ for validator in validators: info_arg = inspect_validator(validator.func, validator.info.mode) - val_type = 'with-info' if info_arg else 'no-info' + val_type = "with-info" if info_arg else "no-info" - schema = _VALIDATOR_F_MATCH[(validator.info.mode, val_type)](validator.func, schema, field_name) + schema = _VALIDATOR_F_MATCH[(validator.info.mode, val_type)]( + validator.func, schema, field_name + ) return schema -def _validators_require_validate_default(validators: Iterable[Decorator[ValidatorDecoratorInfo]]) -> bool: +def _validators_require_validate_default( + validators: Iterable[Decorator[ValidatorDecoratorInfo]], +) -> bool: """In v1, if any of the validators for a field had `always=True`, the default value would be validated. This serves as an auxiliary function for re-implementing that logic, by looping over a provided @@ -2150,7 +2521,7 @@ def _validators_require_validate_default(validators: Iterable[Decorator[Validato def apply_model_validators( schema: core_schema.CoreSchema, validators: Iterable[Decorator[ModelValidatorDecoratorInfo]], - mode: Literal['inner', 'outer', 'all'], + mode: Literal["inner", "outer", "all"], ) -> core_schema.CoreSchema: """Apply model validators to a schema. @@ -2166,35 +2537,49 @@ def apply_model_validators( Returns: The updated schema. """ - ref: str | None = schema.pop('ref', None) # type: ignore + ref: str | None = schema.pop("ref", None) # type: ignore for validator in validators: - if mode == 'inner' and validator.info.mode != 'before': + if mode == "inner" and validator.info.mode != "before": continue - if mode == 'outer' and validator.info.mode == 'before': + if mode == "outer" and validator.info.mode == "before": continue info_arg = inspect_validator(validator.func, validator.info.mode) - if validator.info.mode == 'wrap': + if validator.info.mode == "wrap": if info_arg: - schema = core_schema.with_info_wrap_validator_function(function=validator.func, schema=schema) + schema = core_schema.with_info_wrap_validator_function( + function=validator.func, schema=schema + ) else: - schema = core_schema.no_info_wrap_validator_function(function=validator.func, schema=schema) - elif validator.info.mode == 'before': + schema = core_schema.no_info_wrap_validator_function( + function=validator.func, schema=schema + ) + elif validator.info.mode == "before": if info_arg: - schema = core_schema.with_info_before_validator_function(function=validator.func, schema=schema) + schema = core_schema.with_info_before_validator_function( + function=validator.func, schema=schema + ) else: - schema = core_schema.no_info_before_validator_function(function=validator.func, schema=schema) + schema = core_schema.no_info_before_validator_function( + function=validator.func, schema=schema + ) else: - assert validator.info.mode == 'after' + assert validator.info.mode == "after" if info_arg: - schema = core_schema.with_info_after_validator_function(function=validator.func, schema=schema) + schema = core_schema.with_info_after_validator_function( + function=validator.func, schema=schema + ) else: - schema = core_schema.no_info_after_validator_function(function=validator.func, schema=schema) + schema = core_schema.no_info_after_validator_function( + function=validator.func, schema=schema + ) if ref: - schema['ref'] = ref # type: ignore + schema["ref"] = ref # type: ignore return schema -def wrap_default(field_info: FieldInfo, schema: core_schema.CoreSchema) -> core_schema.CoreSchema: +def wrap_default( + field_info: FieldInfo, schema: core_schema.CoreSchema +) -> core_schema.CoreSchema: """Wrap schema with default schema if default value or `default_factory` are available. Args: @@ -2206,39 +2591,47 @@ def wrap_default(field_info: FieldInfo, schema: core_schema.CoreSchema) -> core_ """ if field_info.default_factory: return core_schema.with_default_schema( - schema, default_factory=field_info.default_factory, validate_default=field_info.validate_default + schema, + default_factory=field_info.default_factory, + validate_default=field_info.validate_default, ) elif field_info.default is not PydanticUndefined: return core_schema.with_default_schema( - schema, default=field_info.default, validate_default=field_info.validate_default + schema, + default=field_info.default, + validate_default=field_info.validate_default, ) else: return schema -def _extract_get_pydantic_json_schema(tp: Any, schema: CoreSchema) -> GetJsonSchemaFunction | None: +def _extract_get_pydantic_json_schema( + tp: Any, schema: CoreSchema +) -> GetJsonSchemaFunction | None: """Extract `__get_pydantic_json_schema__` from a type, handling the deprecated `__modify_schema__`.""" - js_modify_function = getattr(tp, '__get_pydantic_json_schema__', None) + js_modify_function = getattr(tp, "__get_pydantic_json_schema__", None) - if hasattr(tp, '__modify_schema__'): + if hasattr(tp, "__modify_schema__"): from pydantic import BaseModel # circular reference has_custom_v2_modify_js_func = ( js_modify_function is not None and BaseModel.__get_pydantic_json_schema__.__func__ # type: ignore - not in (js_modify_function, getattr(js_modify_function, '__func__', None)) + not in (js_modify_function, getattr(js_modify_function, "__func__", None)) ) if not has_custom_v2_modify_js_func: - cls_name = getattr(tp, '__name__', None) + cls_name = getattr(tp, "__name__", None) raise PydanticUserError( - f'The `__modify_schema__` method is not supported in Pydantic v2. ' + f"The `__modify_schema__` method is not supported in Pydantic v2. " f'Use `__get_pydantic_json_schema__` instead{f" in class `{cls_name}`" if cls_name else ""}.', - code='custom-json-schema', + code="custom-json-schema", ) # handle GenericAlias' but ignore Annotated which "lies" about its origin (in this case it would be `int`) - if hasattr(tp, '__origin__') and not isinstance(tp, type(Annotated[int, 'placeholder'])): + if hasattr(tp, "__origin__") and not isinstance( + tp, type(Annotated[int, "placeholder"]) + ): return _extract_get_pydantic_json_schema(tp.__origin__, schema) if js_modify_function is None: @@ -2248,7 +2641,8 @@ def _extract_get_pydantic_json_schema(tp: Any, schema: CoreSchema) -> GetJsonSch def get_json_schema_update_func( - json_schema_update: JsonSchemaValue, json_schema_extra: JsonDict | typing.Callable[[JsonDict], None] | None + json_schema_update: JsonSchemaValue, + json_schema_extra: JsonDict | typing.Callable[[JsonDict], None] | None, ) -> GetJsonSchemaFunction: def json_schema_update_func( core_schema_or_field: CoreSchemaOrField, handler: GetJsonSchemaHandler @@ -2261,7 +2655,8 @@ def json_schema_update_func( def add_json_schema_extra( - json_schema: JsonSchemaValue, json_schema_extra: JsonDict | typing.Callable[[JsonDict], None] | None + json_schema: JsonSchemaValue, + json_schema_extra: JsonDict | typing.Callable[[JsonDict], None] | None, ): if isinstance(json_schema_extra, dict): json_schema.update(to_jsonable_python(json_schema_extra)) @@ -2288,12 +2683,12 @@ def _common_field( metadata: Any = None, ) -> _CommonField: return { - 'schema': schema, - 'validation_alias': validation_alias, - 'serialization_alias': serialization_alias, - 'serialization_exclude': serialization_exclude, - 'frozen': frozen, - 'metadata': metadata, + "schema": schema, + "validation_alias": validation_alias, + "serialization_alias": serialization_alias, + "serialization_exclude": serialization_exclude, + "frozen": frozen, + "metadata": metadata, } @@ -2305,7 +2700,9 @@ def __init__(self) -> None: self.definitions: dict[str, core_schema.CoreSchema] = {} @contextmanager - def get_schema_or_ref(self, tp: Any) -> Iterator[tuple[str, None] | tuple[str, CoreSchema]]: + def get_schema_or_ref( + self, tp: Any + ) -> Iterator[tuple[str, None] | tuple[str, CoreSchema]]: """Get a definition for `tp` if one exists. If a definition exists, a tuple of `(ref_string, CoreSchema)` is returned. @@ -2336,17 +2733,19 @@ def get_schema_or_ref(self, tp: Any) -> Iterator[tuple[str, None] | tuple[str, C self.seen.discard(ref) -def resolve_original_schema(schema: CoreSchema, definitions: dict[str, CoreSchema]) -> CoreSchema | None: - if schema['type'] == 'definition-ref': - return definitions.get(schema['schema_ref'], None) - elif schema['type'] == 'definitions': - return schema['schema'] - else: - return schema +def resolve_original_schema( + schema: CoreSchema, definitions: dict[str, CoreSchema] +) -> CoreSchema | None: + schema_type = schema["type"] + if schema_type == "definition-ref": + return definitions.get(schema["schema_ref"]) + elif schema_type == "definitions": + return schema["schema"] + return schema class _FieldNameStack: - __slots__ = ('_stack',) + __slots__ = ("_stack",) def __init__(self) -> None: self._stack: list[str] = [] @@ -2365,7 +2764,7 @@ def get(self) -> str | None: class _ModelTypeStack: - __slots__ = ('_stack',) + __slots__ = ("_stack",) def __init__(self) -> None: self._stack: list[type] = [] From 200b2b0ca2fc2793ceea160feaac59678bca78e5 Mon Sep 17 00:00:00 2001 From: Saurabh Misra Date: Mon, 22 Jul 2024 20:39:05 -0700 Subject: [PATCH 2/2] Fix formatting --- pydantic/_internal/_generate_schema.py | 1139 ++++++++---------------- 1 file changed, 370 insertions(+), 769 deletions(-) diff --git a/pydantic/_internal/_generate_schema.py b/pydantic/_internal/_generate_schema.py index b6b7ab4de7..e95b7d6bd8 100644 --- a/pydantic/_internal/_generate_schema.py +++ b/pydantic/_internal/_generate_schema.py @@ -36,34 +36,16 @@ from warnings import warn from pydantic_core import CoreSchema, PydanticUndefined, core_schema, to_jsonable_python -from typing_extensions import ( - Annotated, - Literal, - TypeAliasType, - TypedDict, - get_args, - get_origin, - is_typeddict, -) +from typing_extensions import Annotated, Literal, TypeAliasType, TypedDict, get_args, get_origin, is_typeddict from ..aliases import AliasGenerator from ..annotated_handlers import GetCoreSchemaHandler, GetJsonSchemaHandler from ..config import ConfigDict, JsonDict, JsonEncoder -from ..errors import ( - PydanticSchemaGenerationError, - PydanticUndefinedAnnotation, - PydanticUserError, -) +from ..errors import PydanticSchemaGenerationError, PydanticUndefinedAnnotation, PydanticUserError from ..json_schema import JsonSchemaValue from ..version import version_short from ..warnings import PydanticDeprecatedSince20 -from . import ( - _core_utils, - _decorators, - _discriminated_union, - _known_annotated_metadata, - _typing_extra, -) +from . import _core_utils, _decorators, _discriminated_union, _known_annotated_metadata, _typing_extra from ._config import ConfigWrapper, ConfigWrapperStack from ._core_metadata import CoreMetadataHandler, build_metadata_dict from ._core_utils import ( @@ -94,12 +76,7 @@ from ._docs_extraction import extract_docstrings_from_cls from ._fields import collect_dataclass_fields, get_type_hints_infer_globalns from ._forward_ref import PydanticRecursiveRef -from ._generics import ( - get_standard_typevars_map, - has_instance_in_type, - recursively_defined_type_refs, - replace_types, -) +from ._generics import get_standard_typevars_map, has_instance_in_type, recursively_defined_type_refs, replace_types from ._mock_val_ser import MockCoreSchema from ._schema_generation_shared import CallbackGetCoreSchemaHandler from ._typing_extra import is_finalvar, is_self_type @@ -116,10 +93,8 @@ _SUPPORTS_TYPEDDICT = sys.version_info >= (3, 12) _AnnotatedType = type(Annotated[int, 123]) -FieldDecoratorInfo = Union[ - ValidatorDecoratorInfo, FieldValidatorDecoratorInfo, FieldSerializerDecoratorInfo -] -FieldDecoratorInfoType = TypeVar("FieldDecoratorInfoType", bound=FieldDecoratorInfo) +FieldDecoratorInfo = Union[ValidatorDecoratorInfo, FieldValidatorDecoratorInfo, FieldSerializerDecoratorInfo] +FieldDecoratorInfoType = TypeVar('FieldDecoratorInfoType', bound=FieldDecoratorInfo) AnyFieldDecorator = Union[ Decorator[ValidatorDecoratorInfo], Decorator[FieldValidatorDecoratorInfo], @@ -127,20 +102,13 @@ ] ModifyCoreSchemaWrapHandler = GetCoreSchemaHandler -GetCoreSchemaFunction = Callable[ - [Any, ModifyCoreSchemaWrapHandler], core_schema.CoreSchema -] +GetCoreSchemaFunction = Callable[[Any, ModifyCoreSchemaWrapHandler], core_schema.CoreSchema] TUPLE_TYPES: list[type] = [tuple, typing.Tuple] LIST_TYPES: list[type] = [list, typing.List, collections.abc.MutableSequence] SET_TYPES: list[type] = [set, typing.Set, collections.abc.MutableSet] FROZEN_SET_TYPES: list[type] = [frozenset, typing.FrozenSet, collections.abc.Set] -DICT_TYPES: list[type] = [ - dict, - typing.Dict, - collections.abc.MutableMapping, - collections.abc.Mapping, -] +DICT_TYPES: list[type] = [dict, typing.Dict, collections.abc.MutableMapping, collections.abc.Mapping] def check_validator_fields_against_field_name( @@ -156,7 +124,7 @@ def check_validator_fields_against_field_name( Returns: `True` if field name is in validator fields, `False` otherwise. """ - if "*" in info.fields: + if '*' in info.fields: return True for v_field_name in info.fields: if v_field_name == field: @@ -164,9 +132,7 @@ def check_validator_fields_against_field_name( return False -def check_decorator_fields_exist( - decorators: Iterable[AnyFieldDecorator], fields: Iterable[str] -) -> None: +def check_decorator_fields_exist(decorators: Iterable[AnyFieldDecorator], fields: Iterable[str]) -> None: """Check if the defined fields in decorators exist in `fields` param. It ignores the check for a decorator if the decorator has `*` as field or `check_fields=False`. @@ -180,27 +146,23 @@ def check_decorator_fields_exist( """ fields = set(fields) for dec in decorators: - if "*" in dec.info.fields: + if '*' in dec.info.fields: continue if dec.info.check_fields is False: continue for field in dec.info.fields: if field not in fields: raise PydanticUserError( - f"Decorators defined with incorrect fields: {dec.cls_ref}.{dec.cls_var_name}" + f'Decorators defined with incorrect fields: {dec.cls_ref}.{dec.cls_var_name}' " (use check_fields=False if you're inheriting from the model and intended this)", - code="decorator-missing-field", + code='decorator-missing-field', ) def filter_field_decorator_info_by_field( validator_functions: Iterable[Decorator[FieldDecoratorInfoType]], field: str ) -> list[Decorator[FieldDecoratorInfoType]]: - return [ - dec - for dec in validator_functions - if check_validator_fields_against_field_name(dec.info, field) - ] + return [dec for dec in validator_functions if check_validator_fields_against_field_name(dec.info, field)] def apply_each_item_validators( @@ -213,34 +175,26 @@ def apply_each_item_validators( # push down any `each_item=True` validators # note that this won't work for any Annotated types that get wrapped by a function validator # but that's okay because that didn't exist in V1 - if schema["type"] == "nullable": - schema["schema"] = apply_each_item_validators( - schema["schema"], each_item_validators, field_name - ) + if schema['type'] == 'nullable': + schema['schema'] = apply_each_item_validators(schema['schema'], each_item_validators, field_name) return schema - elif schema["type"] == "tuple": - if (variadic_item_index := schema.get("variadic_item_index")) is not None: - schema["items_schema"][variadic_item_index] = apply_validators( - schema["items_schema"][variadic_item_index], - each_item_validators, - field_name, + elif schema['type'] == 'tuple': + if (variadic_item_index := schema.get('variadic_item_index')) is not None: + schema['items_schema'][variadic_item_index] = apply_validators( + schema['items_schema'][variadic_item_index], each_item_validators, field_name ) elif is_list_like_schema_with_items_schema(schema): - inner_schema = schema.get("items_schema", None) + inner_schema = schema.get('items_schema', None) if inner_schema is None: inner_schema = core_schema.any_schema() - schema["items_schema"] = apply_validators( - inner_schema, each_item_validators, field_name - ) - elif schema["type"] == "dict": + schema['items_schema'] = apply_validators(inner_schema, each_item_validators, field_name) + elif schema['type'] == 'dict': # push down any `each_item=True` validators onto dict _values_ # this is super arbitrary but it's the V1 behavior - inner_schema = schema.get("values_schema", None) + inner_schema = schema.get('values_schema', None) if inner_schema is None: inner_schema = core_schema.any_schema() - schema["values_schema"] = apply_validators( - inner_schema, each_item_validators, field_name - ) + schema['values_schema'] = apply_validators(inner_schema, each_item_validators, field_name) elif each_item_validators: raise TypeError( f"`@validator(..., each_item=True)` cannot be applied to fields with a schema of {schema['type']}" @@ -274,24 +228,20 @@ def modify_model_json_schema( json_schema = handler(schema_or_field) original_schema = handler.resolve_ref_schema(json_schema) # Preserve the fact that definitions schemas should never have sibling keys: - if "$ref" in original_schema: - ref = original_schema["$ref"] + if '$ref' in original_schema: + ref = original_schema['$ref'] original_schema.clear() - original_schema["allOf"] = [{"$ref": ref}] + original_schema['allOf'] = [{'$ref': ref}] if title is not None: - original_schema["title"] = title - elif "title" not in original_schema: - original_schema["title"] = cls.__name__ + original_schema['title'] = title + elif 'title' not in original_schema: + original_schema['title'] = cls.__name__ # BaseModel + Dataclass; don't use cls.__doc__ as it will contain the verbose class signature by default - docstring = ( - None - if cls is BaseModel or is_builtin_dataclass(cls) or is_pydantic_dataclass(cls) - else cls.__doc__ - ) - if docstring and "description" not in original_schema: - original_schema["description"] = inspect.cleandoc(docstring) - elif issubclass(cls, RootModel) and cls.model_fields["root"].description: - original_schema["description"] = cls.model_fields["root"].description + docstring = None if cls is BaseModel or is_builtin_dataclass(cls) or is_pydantic_dataclass(cls) else cls.__doc__ + if docstring and 'description' not in original_schema: + original_schema['description'] = inspect.cleandoc(docstring) + elif issubclass(cls, RootModel) and cls.model_fields['root'].description: + original_schema['description'] = cls.model_fields['root'].description return json_schema @@ -310,25 +260,23 @@ def _add_custom_serialization_from_json_encoders( """ if not json_encoders: return schema - if "serialization" in schema: + if 'serialization' in schema: return schema # Check the class type and its superclasses for a matching encoder # Decimal.__class__.__mro__ (and probably other cases) doesn't include Decimal itself # if the type is a GenericAlias (e.g. from list[int]) we need to use __class__ instead of .__mro__ - for base in (tp, *getattr(tp, "__mro__", tp.__class__.__mro__)[:-1]): + for base in (tp, *getattr(tp, '__mro__', tp.__class__.__mro__)[:-1]): encoder = json_encoders.get(base) if encoder is None: continue warnings.warn( - f"`json_encoders` is deprecated. See https://docs.pydantic.dev/{version_short()}/concepts/serialization/#custom-serializers for alternatives", + f'`json_encoders` is deprecated. See https://docs.pydantic.dev/{version_short()}/concepts/serialization/#custom-serializers for alternatives', PydanticDeprecatedSince20, ) # TODO: in theory we should check that the schema accepts a serialization key - schema["serialization"] = core_schema.plain_serializer_function_ser_schema( - encoder, when_used="json" - ) + schema['serialization'] = core_schema.plain_serializer_function_ser_schema(encoder, when_used='json') return schema return schema @@ -349,10 +297,7 @@ def tail(self) -> TypesNamespace: @contextmanager def push(self, for_type: type[Any]): - types_namespace = { - **_typing_extra.get_cls_types_namespace(for_type), - **(self.tail or {}), - } + types_namespace = {**_typing_extra.get_cls_types_namespace(for_type), **(self.tail or {})} self._types_namespace_stack.append(types_namespace) try: yield @@ -373,12 +318,12 @@ class GenerateSchema: """Generate core schema for a Pydantic model, dataclass and types like `str`, `datetime`, ... .""" __slots__ = ( - "_config_wrapper_stack", - "_types_namespace_stack", - "_typevars_map", - "field_name_stack", - "model_type_stack", - "defs", + '_config_wrapper_stack', + '_types_namespace_stack', + '_typevars_map', + 'field_name_stack', + 'model_type_stack', + 'defs', ) def __init__( @@ -446,9 +391,7 @@ def _list_schema(self, tp: Any, items_type: Any) -> CoreSchema: return core_schema.list_schema(self.generate_schema(items_type)) def _dict_schema(self, tp: Any, keys_type: Any, values_type: Any) -> CoreSchema: - return core_schema.dict_schema( - self.generate_schema(keys_type), self.generate_schema(values_type) - ) + return core_schema.dict_schema(self.generate_schema(keys_type), self.generate_schema(values_type)) def _set_schema(self, tp: Any, items_type: Any) -> CoreSchema: return core_schema.set_schema(self.generate_schema(items_type)) @@ -459,10 +402,10 @@ def _frozenset_schema(self, tp: Any, items_type: Any) -> CoreSchema: def _arbitrary_type_schema(self, tp: Any) -> CoreSchema: if not isinstance(tp, type): warn( - f"{tp!r} is not a Python type (it may be an instance of an object)," - " Pydantic will allow any object with no validation since we cannot even" - " enforce that the input is an instance of the given type." - " To get rid of this error wrap the type with `pydantic.SkipValidation`.", + f'{tp!r} is not a Python type (it may be an instance of an object),' + ' Pydantic will allow any object with no validation since we cannot even' + ' enforce that the input is an instance of the given type.' + ' To get rid of this error wrap the type with `pydantic.SkipValidation`.', UserWarning, ) return core_schema.any_schema() @@ -470,13 +413,13 @@ def _arbitrary_type_schema(self, tp: Any) -> CoreSchema: def _unknown_type_schema(self, obj: Any) -> CoreSchema: raise PydanticSchemaGenerationError( - f"Unable to generate pydantic-core schema for {obj!r}. " - "Set `arbitrary_types_allowed=True` in the model_config to ignore this error" - " or implement `__get_pydantic_core_schema__` on your type to fully support it." - "\n\nIf you got this error by calling handler() within" - " `__get_pydantic_core_schema__` then you likely need to call" - " `handler.generate_schema()` since we do not call" - " `__get_pydantic_core_schema__` on `` otherwise to avoid infinite recursion." + f'Unable to generate pydantic-core schema for {obj!r}. ' + 'Set `arbitrary_types_allowed=True` in the model_config to ignore this error' + ' or implement `__get_pydantic_core_schema__` on your type to fully support it.' + '\n\nIf you got this error by calling handler() within' + ' `__get_pydantic_core_schema__` then you likely need to call' + ' `handler.generate_schema()` since we do not call' + ' `__get_pydantic_core_schema__` on `` otherwise to avoid infinite recursion.' ) def _apply_discriminator_to_union( @@ -510,21 +453,19 @@ def clean_schema(self, schema: CoreSchema) -> CoreSchema: return schema def collect_definitions(self, schema: CoreSchema) -> CoreSchema: - ref = cast("str | None", schema.get("ref", None)) + ref = cast('str | None', schema.get('ref', None)) if ref: self.defs.definitions[ref] = schema - if "ref" in schema: - schema = core_schema.definition_reference_schema(schema["ref"]) + if 'ref' in schema: + schema = core_schema.definition_reference_schema(schema['ref']) return core_schema.definitions_schema( schema, list(self.defs.definitions.values()), ) - def _add_js_function( - self, metadata_schema: CoreSchema, js_function: Callable[..., Any] - ) -> None: + def _add_js_function(self, metadata_schema: CoreSchema, js_function: Callable[..., Any]) -> None: metadata = CoreMetadataHandler(metadata_schema).metadata - pydantic_js_functions = metadata.setdefault("pydantic_js_functions", []) + pydantic_js_functions = metadata.setdefault('pydantic_js_functions', []) # because of how we generate core schemas for nested generic models # we can end up adding `BaseModel.__get_pydantic_json_schema__` multiple times # this check may fail to catch duplicates if the function is a `functools.partial` @@ -576,9 +517,7 @@ def generate_schema( if metadata_schema: self._add_js_function(metadata_schema, metadata_js_function) - schema = _add_custom_serialization_from_json_encoders( - self._config_wrapper.json_encoders, obj, schema - ) + schema = _add_custom_serialization_from_json_encoders(self._config_wrapper.json_encoders, obj, schema) return schema @@ -602,32 +541,26 @@ def _model_schema(self, cls: type[BaseModel]) -> core_schema.CoreSchema: config_wrapper = ConfigWrapper(cls.model_config, check=False) core_config = config_wrapper.core_config(cls) title = self._get_model_title_from_config(cls, config_wrapper) - metadata = build_metadata_dict( - js_functions=[partial(modify_model_json_schema, cls=cls, title=title)] - ) + metadata = build_metadata_dict(js_functions=[partial(modify_model_json_schema, cls=cls, title=title)]) model_validators = decorators.model_validators.values() extras_schema = None - if core_config.get("extra_fields_behavior") == "allow": + if core_config.get('extra_fields_behavior') == 'allow': assert cls.__mro__[0] is cls assert cls.__mro__[-1] is object for candidate_cls in cls.__mro__[:-1]: - extras_annotation = getattr( - candidate_cls, "__annotations__", {} - ).get("__pydantic_extra__", None) + extras_annotation = getattr(candidate_cls, '__annotations__', {}).get('__pydantic_extra__', None) if extras_annotation is not None: if isinstance(extras_annotation, str): extras_annotation = _typing_extra.eval_type_backport( - _typing_extra._make_forward_ref( - extras_annotation, is_argument=False, is_class=True - ), + _typing_extra._make_forward_ref(extras_annotation, is_argument=False, is_class=True), self._types_namespace, ) tp = get_origin(extras_annotation) if tp not in (Dict, dict): raise PydanticSchemaGenerationError( - "The type annotation for `__pydantic_extra__` must be `Dict[str, ...]`" + 'The type annotation for `__pydantic_extra__` must be `Dict[str, ...]`' ) extra_items_type = self._get_args_resolving_forward_refs( extras_annotation, @@ -637,79 +570,57 @@ def _model_schema(self, cls: type[BaseModel]) -> core_schema.CoreSchema: extras_schema = self.generate_schema(extra_items_type) break - with self._config_wrapper_stack.push( - config_wrapper - ), self._types_namespace_stack.push(cls): + with self._config_wrapper_stack.push(config_wrapper), self._types_namespace_stack.push(cls): self = self._current_generate_schema if cls.__pydantic_root_model__: - root_field = self._common_field_schema( - "root", fields["root"], decorators - ) - inner_schema = root_field["schema"] - inner_schema = apply_model_validators( - inner_schema, model_validators, "inner" - ) + root_field = self._common_field_schema('root', fields['root'], decorators) + inner_schema = root_field['schema'] + inner_schema = apply_model_validators(inner_schema, model_validators, 'inner') model_schema = core_schema.model_schema( cls, inner_schema, - custom_init=getattr(cls, "__pydantic_custom_init__", None), + custom_init=getattr(cls, '__pydantic_custom_init__', None), root_model=True, - post_init=getattr(cls, "__pydantic_post_init__", None), + post_init=getattr(cls, '__pydantic_post_init__', None), config=core_config, ref=model_ref, metadata=metadata, ) else: - fields_schema: core_schema.CoreSchema = ( - core_schema.model_fields_schema( - { - k: self._generate_md_field_schema(k, v, decorators) - for k, v in fields.items() - }, - computed_fields=[ - self._computed_field_schema( - d, decorators.field_serializers - ) - for d in computed_fields.values() - ], - extras_schema=extras_schema, - model_name=cls.__name__, - ) - ) - inner_schema = apply_validators( - fields_schema, decorators.root_validators.values(), None - ) - new_inner_schema = define_expected_missing_refs( - inner_schema, recursively_defined_type_refs() + fields_schema: core_schema.CoreSchema = core_schema.model_fields_schema( + {k: self._generate_md_field_schema(k, v, decorators) for k, v in fields.items()}, + computed_fields=[ + self._computed_field_schema(d, decorators.field_serializers) + for d in computed_fields.values() + ], + extras_schema=extras_schema, + model_name=cls.__name__, ) + inner_schema = apply_validators(fields_schema, decorators.root_validators.values(), None) + new_inner_schema = define_expected_missing_refs(inner_schema, recursively_defined_type_refs()) if new_inner_schema is not None: inner_schema = new_inner_schema - inner_schema = apply_model_validators( - inner_schema, model_validators, "inner" - ) + inner_schema = apply_model_validators(inner_schema, model_validators, 'inner') model_schema = core_schema.model_schema( cls, inner_schema, - custom_init=getattr(cls, "__pydantic_custom_init__", None), + custom_init=getattr(cls, '__pydantic_custom_init__', None), root_model=False, - post_init=getattr(cls, "__pydantic_post_init__", None), + post_init=getattr(cls, '__pydantic_post_init__', None), config=core_config, ref=model_ref, metadata=metadata, ) - schema = self._apply_model_serializers( - model_schema, decorators.model_serializers.values() - ) - schema = apply_model_validators(schema, model_validators, "outer") + schema = self._apply_model_serializers(model_schema, decorators.model_serializers.values()) + schema = apply_model_validators(schema, model_validators, 'outer') self.defs.definitions[model_ref] = schema return core_schema.definition_reference_schema(model_ref) @staticmethod def _get_model_title_from_config( - model: type[BaseModel | StandardDataclass], - config_wrapper: ConfigWrapper | None = None, + model: type[BaseModel | StandardDataclass], config_wrapper: ConfigWrapper | None = None ) -> str | None: """Get the title of a model if `model_title_generator` or `title` are set in the config, else return None""" if config_wrapper is None: @@ -722,9 +633,7 @@ def _get_model_title_from_config( if model_title_generator: title = model_title_generator(model) if not isinstance(title, str): - raise TypeError( - f"model_title_generator {model_title_generator} must return str, not {title.__class__}" - ) + raise TypeError(f'model_title_generator {model_title_generator} must return str, not {title.__class__}') return title return None @@ -735,16 +644,14 @@ def _unpack_refs_defs(self, schema: CoreSchema) -> CoreSchema: """ def get_ref(s: CoreSchema) -> str: - return s["ref"] # type: ignore + return s['ref'] # type: ignore - if schema["type"] == "definitions": - self.defs.definitions.update({get_ref(s): s for s in schema["definitions"]}) - schema = schema["schema"] + if schema['type'] == 'definitions': + self.defs.definitions.update({get_ref(s): s for s in schema['definitions']}) + schema = schema['schema'] return schema - def _generate_schema_from_property( - self, obj: Any, source: Any - ) -> core_schema.CoreSchema | None: + def _generate_schema_from_property(self, obj: Any, source: Any) -> core_schema.CoreSchema | None: """Try to generate schema from either the `__get_pydantic_core_schema__` function or `__pydantic_core_schema__` property. @@ -758,24 +665,19 @@ def _generate_schema_from_property( if maybe_schema is not None: return maybe_schema if obj is source: - ref_mode = "unpack" + ref_mode = 'unpack' else: - ref_mode = "to-def" + ref_mode = 'to-def' schema: CoreSchema - if ( - get_schema := getattr(obj, "__get_pydantic_core_schema__", None) - ) is not None: + if (get_schema := getattr(obj, '__get_pydantic_core_schema__', None)) is not None: if len(inspect.signature(get_schema).parameters) == 1: # (source) -> CoreSchema schema = get_schema(source) else: schema = get_schema( - source, - CallbackGetCoreSchemaHandler( - self._generate_schema_inner, self, ref_mode=ref_mode - ), + source, CallbackGetCoreSchemaHandler(self._generate_schema_inner, self, ref_mode=ref_mode) ) # fmt: off elif ( @@ -785,17 +687,12 @@ def _generate_schema_from_property( ): schema = existing_schema # fmt: on - elif (validators := getattr(obj, "__get_validators__", None)) is not None: + elif (validators := getattr(obj, '__get_validators__', None)) is not None: warn( - "`__get_validators__` is deprecated and will be removed, use `__get_pydantic_core_schema__` instead.", + '`__get_validators__` is deprecated and will be removed, use `__get_pydantic_core_schema__` instead.', PydanticDeprecatedSince20, ) - schema = core_schema.chain_schema( - [ - core_schema.with_info_plain_validator_function(v) - for v in validators() - ] - ) + schema = core_schema.chain_schema([core_schema.with_info_plain_validator_function(v) for v in validators()]) else: # we have no existing schema information on the property, exit early so that we can go generate a schema return None @@ -803,11 +700,9 @@ def _generate_schema_from_property( schema = self._unpack_refs_defs(schema) if is_function_with_inner_schema(schema): - ref = schema["schema"].pop( - "ref", None - ) # pyright: ignore[reportCallIssue, reportArgumentType] + ref = schema['schema'].pop('ref', None) # pyright: ignore[reportCallIssue, reportArgumentType] if ref: - schema["ref"] = ref + schema['ref'] = ref else: ref = get_ref(schema) @@ -832,9 +727,7 @@ def _resolve_forward_ref(self, obj: Any) -> Any: # if obj is still a ForwardRef, it means we can't evaluate it, raise PydanticUndefinedAnnotation if isinstance(obj, ForwardRef): - raise PydanticUndefinedAnnotation( - obj.__forward_arg__, f"Unable to evaluate forward reference {obj}" - ) + raise PydanticUndefinedAnnotation(obj.__forward_arg__, f'Unable to evaluate forward reference {obj}') if self._typevars_map: obj = replace_types(obj, self._typevars_map) @@ -842,28 +735,17 @@ def _resolve_forward_ref(self, obj: Any) -> Any: return obj @overload - def _get_args_resolving_forward_refs( - self, obj: Any, required: Literal[True] - ) -> tuple[Any, ...]: ... + def _get_args_resolving_forward_refs(self, obj: Any, required: Literal[True]) -> tuple[Any, ...]: ... @overload def _get_args_resolving_forward_refs(self, obj: Any) -> tuple[Any, ...] | None: ... - def _get_args_resolving_forward_refs( - self, obj: Any, required: bool = False - ) -> tuple[Any, ...] | None: + def _get_args_resolving_forward_refs(self, obj: Any, required: bool = False) -> tuple[Any, ...] | None: args = get_args(obj) if args: - args = tuple( - [ - self._resolve_forward_ref(a) if isinstance(a, ForwardRef) else a - for a in args - ] - ) + args = tuple([self._resolve_forward_ref(a) if isinstance(a, ForwardRef) else a for a in args]) elif required: # pragma: no cover - raise TypeError( - f"Expected {obj} to have generic parameters but it had none" - ) + raise TypeError(f'Expected {obj} to have generic parameters but it had none') return args def _get_first_arg_or_any(self, obj: Any) -> Any: @@ -878,7 +760,7 @@ def _get_first_two_args_or_any(self, obj: Any) -> tuple[Any, Any]: return (Any, Any) if len(args) < 2: origin = get_origin(obj) - raise TypeError(f"Expected two type arguments for {origin}, got 1") + raise TypeError(f'Expected two type arguments for {origin}, got 1') return args[0], args[1] def _generate_schema_inner(self, obj: Any) -> core_schema.CoreSchema: @@ -1027,12 +909,7 @@ def _match_generic_type(self, obj: Any, origin: Any) -> CoreSchema: # noqa: C90 return self._subclass_schema(obj) elif origin in {typing.Sequence, collections.abc.Sequence}: return self._sequence_schema(obj) - elif origin in { - typing.Iterable, - collections.abc.Iterable, - typing.Generator, - collections.abc.Generator, - }: + elif origin in {typing.Iterable, collections.abc.Iterable, typing.Generator, collections.abc.Generator}: return self._iterable_schema(obj) elif origin in (re.Pattern, typing.Pattern): return self._pattern_schema(obj) @@ -1052,12 +929,12 @@ def _generate_td_field_schema( """Prepare a TypedDictField to represent a model or typeddict field.""" common_field = self._common_field_schema(name, field_info, decorators) return core_schema.typed_dict_field( - common_field["schema"], + common_field['schema'], required=False if not field_info.is_required() else required, - serialization_exclude=common_field["serialization_exclude"], - validation_alias=common_field["validation_alias"], - serialization_alias=common_field["serialization_alias"], - metadata=common_field["metadata"], + serialization_exclude=common_field['serialization_exclude'], + validation_alias=common_field['validation_alias'], + serialization_alias=common_field['serialization_alias'], + metadata=common_field['metadata'], ) def _generate_md_field_schema( @@ -1069,12 +946,12 @@ def _generate_md_field_schema( """Prepare a ModelField to represent a model field.""" common_field = self._common_field_schema(name, field_info, decorators) return core_schema.model_field( - common_field["schema"], - serialization_exclude=common_field["serialization_exclude"], - validation_alias=common_field["validation_alias"], - serialization_alias=common_field["serialization_alias"], - frozen=common_field["frozen"], - metadata=common_field["metadata"], + common_field['schema'], + serialization_exclude=common_field['serialization_exclude'], + validation_alias=common_field['validation_alias'], + serialization_alias=common_field['serialization_alias'], + frozen=common_field['frozen'], + metadata=common_field['metadata'], ) def _generate_dc_field_schema( @@ -1087,22 +964,20 @@ def _generate_dc_field_schema( common_field = self._common_field_schema(name, field_info, decorators) return core_schema.dataclass_field( name, - common_field["schema"], + common_field['schema'], init=field_info.init, init_only=field_info.init_var or None, kw_only=None if field_info.kw_only else False, - serialization_exclude=common_field["serialization_exclude"], - validation_alias=common_field["validation_alias"], - serialization_alias=common_field["serialization_alias"], - frozen=common_field["frozen"], - metadata=common_field["metadata"], + serialization_exclude=common_field['serialization_exclude'], + validation_alias=common_field['validation_alias'], + serialization_alias=common_field['serialization_alias'], + frozen=common_field['frozen'], + metadata=common_field['metadata'], ) @staticmethod def _apply_alias_generator_to_field_info( - alias_generator: Callable[[str], str] | AliasGenerator, - field_info: FieldInfo, - field_name: str, + alias_generator: Callable[[str], str] | AliasGenerator, field_info: FieldInfo, field_name: str ) -> None: """Apply an alias_generator to aliases on a FieldInfo instance if appropriate. @@ -1124,15 +999,11 @@ def _apply_alias_generator_to_field_info( alias, validation_alias, serialization_alias = None, None, None if isinstance(alias_generator, AliasGenerator): - alias, validation_alias, serialization_alias = ( - alias_generator.generate_aliases(field_name) - ) + alias, validation_alias, serialization_alias = alias_generator.generate_aliases(field_name) elif isinstance(alias_generator, Callable): alias = alias_generator(field_name) if not isinstance(alias, str): - raise TypeError( - f"alias_generator {alias_generator} must return str, not {alias.__class__}" - ) + raise TypeError(f'alias_generator {alias_generator} must return str, not {alias.__class__}') # if priority is not set, we set to 1 # which supports the case where the alias_generator from a child class is used @@ -1142,25 +1013,17 @@ def _apply_alias_generator_to_field_info( # if the priority is 1, then we set the aliases to the generated alias if field_info.alias_priority == 1: - field_info.serialization_alias = _get_first_non_null( - serialization_alias, alias - ) - field_info.validation_alias = _get_first_non_null( - validation_alias, alias - ) + field_info.serialization_alias = _get_first_non_null(serialization_alias, alias) + field_info.validation_alias = _get_first_non_null(validation_alias, alias) field_info.alias = alias # if any of the aliases are not set, then we set them to the corresponding generated alias if field_info.alias is None: field_info.alias = alias if field_info.serialization_alias is None: - field_info.serialization_alias = _get_first_non_null( - serialization_alias, alias - ) + field_info.serialization_alias = _get_first_non_null(serialization_alias, alias) if field_info.validation_alias is None: - field_info.validation_alias = _get_first_non_null( - validation_alias, alias - ) + field_info.validation_alias = _get_first_non_null(validation_alias, alias) @staticmethod def _apply_alias_generator_to_computed_field_info( @@ -1187,38 +1050,27 @@ def _apply_alias_generator_to_computed_field_info( alias, validation_alias, serialization_alias = None, None, None if isinstance(alias_generator, AliasGenerator): - alias, validation_alias, serialization_alias = ( - alias_generator.generate_aliases(computed_field_name) - ) + alias, validation_alias, serialization_alias = alias_generator.generate_aliases(computed_field_name) elif isinstance(alias_generator, Callable): alias = alias_generator(computed_field_name) if not isinstance(alias, str): - raise TypeError( - f"alias_generator {alias_generator} must return str, not {alias.__class__}" - ) + raise TypeError(f'alias_generator {alias_generator} must return str, not {alias.__class__}') # if priority is not set, we set to 1 # which supports the case where the alias_generator from a child class is used # to generate an alias for a field in a parent class - if ( - computed_field_info.alias_priority is None - or computed_field_info.alias_priority <= 1 - ): + if computed_field_info.alias_priority is None or computed_field_info.alias_priority <= 1: computed_field_info.alias_priority = 1 # if the priority is 1, then we set the aliases to the generated alias # note that we use the serialization_alias with priority over alias, as computed_field # aliases are used for serialization only (not validation) if computed_field_info.alias_priority == 1: - computed_field_info.alias = _get_first_non_null( - serialization_alias, alias - ) + computed_field_info.alias = _get_first_non_null(serialization_alias, alias) @staticmethod def _apply_field_title_generator_to_field_info( - config_wrapper: ConfigWrapper, - field_info: FieldInfo | ComputedFieldInfo, - field_name: str, + config_wrapper: ConfigWrapper, field_info: FieldInfo | ComputedFieldInfo, field_name: str ) -> None: """Apply a field_title_generator on a FieldInfo or ComputedFieldInfo instance if appropriate Args: @@ -1226,9 +1078,7 @@ def _apply_field_title_generator_to_field_info( field_info: The FieldInfo or ComputedField instance to which the title_generator is (maybe) applied. field_name: The name of the field from which to generate the title. """ - field_title_generator = ( - field_info.field_title_generator or config_wrapper.field_title_generator - ) + field_title_generator = field_info.field_title_generator or config_wrapper.field_title_generator if field_title_generator is None: return @@ -1236,9 +1086,7 @@ def _apply_field_title_generator_to_field_info( if field_info.title is None: title = field_title_generator(field_name, field_info) # type: ignore if not isinstance(title, str): - raise TypeError( - f"field_title_generator {field_title_generator} must return str, not {title.__class__}" - ) + raise TypeError(f'field_title_generator {field_title_generator} must return str, not {title.__class__}') field_info.title = title @@ -1247,24 +1095,17 @@ def _common_field_schema( # C901 ) -> _CommonField: # Update FieldInfo annotation if appropriate: from .. import AliasChoices, AliasPath + from ..fields import FieldInfo if has_instance_in_type(field_info.annotation, (ForwardRef, str)): - from ..fields import FieldInfo - types_namespace = self._types_namespace if self._typevars_map: types_namespace = (types_namespace or {}).copy() # Ensure that typevars get mapped to their concrete types: - types_namespace.update( - {k.__name__: v for k, v in self._typevars_map.items()} - ) + types_namespace.update({k.__name__: v for k, v in self._typevars_map.items()}) - evaluated = _typing_extra.eval_type_lenient( - field_info.annotation, types_namespace - ) - if evaluated is not field_info.annotation and not has_instance_in_type( - evaluated, PydanticRecursiveRef - ): + evaluated = _typing_extra.eval_type_lenient(field_info.annotation, types_namespace) + if evaluated is not field_info.annotation and not has_instance_in_type(evaluated, PydanticRecursiveRef): new_field_info = FieldInfo.from_annotation(evaluated) field_info.annotation = new_field_info.annotation @@ -1274,10 +1115,7 @@ def _common_field_schema( # C901 # default value), and that should take the highest priority. So don't overwrite existing attributes. # We skip over "attributes" that are present in the metadata_lookup dict because these won't # actually end up as attributes of the `FieldInfo` instance. - if ( - k not in field_info._attributes_set - and k not in field_info.metadata_lookup - ): + if k not in field_info._attributes_set and k not in field_info.metadata_lookup: setattr(field_info, k, v) # Finally, ensure the field info also reflects all the `_attributes_set` that are actually metadata. @@ -1286,16 +1124,12 @@ def _common_field_schema( # C901 source_type, annotations = field_info.annotation, field_info.metadata def set_discriminator(schema: CoreSchema) -> CoreSchema: - schema = self._apply_discriminator_to_union( - schema, field_info.discriminator - ) + schema = self._apply_discriminator_to_union(schema, field_info.discriminator) return schema with self.field_name_stack.push(name): if field_info.discriminator is not None: - schema = self._apply_annotations( - source_type, annotations, transform_inner_schema=set_discriminator - ) + schema = self._apply_annotations(source_type, annotations, transform_inner_schema=set_discriminator) else: schema = self._apply_annotations( source_type, @@ -1306,30 +1140,16 @@ def set_discriminator(schema: CoreSchema) -> CoreSchema: # push down any `each_item=True` validators # note that this won't work for any Annotated types that get wrapped by a function validator # but that's okay because that didn't exist in V1 - this_field_validators = filter_field_decorator_info_by_field( - decorators.validators.values(), name - ) + this_field_validators = filter_field_decorator_info_by_field(decorators.validators.values(), name) if _validators_require_validate_default(this_field_validators): field_info.validate_default = True - each_item_validators = [ - v for v in this_field_validators if v.info.each_item is True - ] - this_field_validators = [ - v for v in this_field_validators if v not in each_item_validators - ] + each_item_validators = [v for v in this_field_validators if v.info.each_item is True] + this_field_validators = [v for v in this_field_validators if v not in each_item_validators] schema = apply_each_item_validators(schema, each_item_validators, name) + schema = apply_validators(schema, filter_field_decorator_info_by_field(this_field_validators, name), name) schema = apply_validators( - schema, - filter_field_decorator_info_by_field(this_field_validators, name), - name, - ) - schema = apply_validators( - schema, - filter_field_decorator_info_by_field( - decorators.field_validators.values(), name - ), - name, + schema, filter_field_decorator_info_by_field(decorators.field_validators.values(), name), name ) # the default validator needs to go outside of any other validators @@ -1339,33 +1159,22 @@ def set_discriminator(schema: CoreSchema) -> CoreSchema: schema = wrap_default(field_info, schema) schema = self._apply_field_serializers( - schema, - filter_field_decorator_info_by_field( - decorators.field_serializers.values(), name - ), - ) - self._apply_field_title_generator_to_field_info( - self._config_wrapper, field_info, name + schema, filter_field_decorator_info_by_field(decorators.field_serializers.values(), name) ) + self._apply_field_title_generator_to_field_info(self._config_wrapper, field_info, name) json_schema_updates = { - "title": field_info.title, - "description": field_info.description, - "deprecated": bool(field_info.deprecated) - or field_info.deprecated == "" - or None, - "examples": to_jsonable_python(field_info.examples), - } - json_schema_updates = { - k: v for k, v in json_schema_updates.items() if v is not None + 'title': field_info.title, + 'description': field_info.description, + 'deprecated': bool(field_info.deprecated) or field_info.deprecated == '' or None, + 'examples': to_jsonable_python(field_info.examples), } + json_schema_updates = {k: v for k, v in json_schema_updates.items() if v is not None} json_schema_extra = field_info.json_schema_extra metadata = build_metadata_dict( - js_annotation_functions=[ - get_json_schema_update_func(json_schema_updates, json_schema_extra) - ] + js_annotation_functions=[get_json_schema_update_func(json_schema_updates, json_schema_extra)] ) alias_generator = self._config_wrapper.alias_generator @@ -1402,7 +1211,7 @@ def _union_schema(self, union_type: Any) -> core_schema.CoreSchema: else: choices_with_tags: list[CoreSchema | tuple[CoreSchema, str]] = [] for choice in choices: - tag = choice.get("metadata", {}).get(_core_utils.TAGGED_UNION_TAG_KEY) + tag = choice.get('metadata', {}).get(_core_utils.TAGGED_UNION_TAG_KEY) if tag is not None: choices_with_tags.append((choice, tag)) else: @@ -1427,13 +1236,11 @@ def _type_alias_type_schema( typevars_map = get_standard_typevars_map(obj) with self._types_namespace_stack.push(origin): - annotation = _typing_extra.eval_type_lenient( - annotation, self._types_namespace - ) + annotation = _typing_extra.eval_type_lenient(annotation, self._types_namespace) annotation = replace_types(annotation, typevars_map) schema = self.generate_schema(annotation) - assert schema["type"] != "definitions" - schema["ref"] = ref # type: ignore + assert schema['type'] != 'definitions' + schema['ref'] = ref # type: ignore self.defs.definitions[ref] = schema return core_schema.definition_reference_schema(ref) @@ -1443,18 +1250,14 @@ def _literal_schema(self, literal_type: Any) -> CoreSchema: assert expected, f'literal "expected" cannot be empty, obj={literal_type}' schema = core_schema.literal_schema(expected) - if self._config_wrapper.use_enum_values and any( - isinstance(v, Enum) for v in expected - ): + if self._config_wrapper.use_enum_values and any(isinstance(v, Enum) for v in expected): schema = core_schema.no_info_after_validator_function( lambda v: v.value if isinstance(v, Enum) else v, schema ) return schema - def _typed_dict_schema( - self, typed_dict_cls: Any, origin: Any - ) -> core_schema.CoreSchema: + def _typed_dict_schema(self, typed_dict_cls: Any, origin: Any) -> core_schema.CoreSchema: """Generate schema for a TypedDict. It is not possible to track required/optional keys in TypedDict without __required_keys__ @@ -1472,9 +1275,7 @@ def _typed_dict_schema( """ from ..fields import FieldInfo - with self.model_type_stack.push(typed_dict_cls), self.defs.get_schema_or_ref( - typed_dict_cls - ) as ( + with self.model_type_stack.push(typed_dict_cls), self.defs.get_schema_or_ref(typed_dict_cls) as ( typed_dict_ref, maybe_schema, ): @@ -1485,22 +1286,18 @@ def _typed_dict_schema( if origin is not None: typed_dict_cls = origin - if not _SUPPORTS_TYPEDDICT and type(typed_dict_cls).__module__ == "typing": + if not _SUPPORTS_TYPEDDICT and type(typed_dict_cls).__module__ == 'typing': raise PydanticUserError( - "Please use `typing_extensions.TypedDict` instead of `typing.TypedDict` on Python < 3.12.", - code="typed-dict-version", + 'Please use `typing_extensions.TypedDict` instead of `typing.TypedDict` on Python < 3.12.', + code='typed-dict-version', ) try: - config: ConfigDict | None = get_attribute_from_bases( - typed_dict_cls, "__pydantic_config__" - ) + config: ConfigDict | None = get_attribute_from_bases(typed_dict_cls, '__pydantic_config__') except AttributeError: config = None - with self._config_wrapper_stack.push( - config - ), self._types_namespace_stack.push(typed_dict_cls): + with self._config_wrapper_stack.push(config), self._types_namespace_stack.push(typed_dict_cls): core_config = self._config_wrapper.core_config(typed_dict_cls) self = self._current_generate_schema @@ -1512,9 +1309,7 @@ def _typed_dict_schema( decorators = DecoratorInfos.build(typed_dict_cls) if self._config_wrapper.use_attribute_docstrings: - field_docstrings = extract_docstrings_from_cls( - typed_dict_cls, use_inspect=True - ) + field_docstrings = extract_docstrings_from_cls(typed_dict_cls, use_inspect=True) else: field_docstrings = None @@ -1544,22 +1339,14 @@ def _typed_dict_schema( and field_name in field_docstrings ): field_info.description = field_docstrings[field_name] - self._apply_field_title_generator_to_field_info( - self._config_wrapper, field_info, field_name - ) + self._apply_field_title_generator_to_field_info(self._config_wrapper, field_info, field_name) fields[field_name] = self._generate_td_field_schema( field_name, field_info, decorators, required=required ) - title = self._get_model_title_from_config( - typed_dict_cls, ConfigWrapper(config) - ) + title = self._get_model_title_from_config(typed_dict_cls, ConfigWrapper(config)) metadata = build_metadata_dict( - js_functions=[ - partial( - modify_model_json_schema, cls=typed_dict_cls, title=title - ) - ], + js_functions=[partial(modify_model_json_schema, cls=typed_dict_cls, title=title)], typed_dict_cls=typed_dict_cls, ) td_schema = core_schema.typed_dict_schema( @@ -1573,22 +1360,14 @@ def _typed_dict_schema( config=core_config, ) - schema = self._apply_model_serializers( - td_schema, decorators.model_serializers.values() - ) - schema = apply_model_validators( - schema, decorators.model_validators.values(), "all" - ) + schema = self._apply_model_serializers(td_schema, decorators.model_serializers.values()) + schema = apply_model_validators(schema, decorators.model_validators.values(), 'all') self.defs.definitions[typed_dict_ref] = schema return core_schema.definition_reference_schema(typed_dict_ref) - def _namedtuple_schema( - self, namedtuple_cls: Any, origin: Any - ) -> core_schema.CoreSchema: + def _namedtuple_schema(self, namedtuple_cls: Any, origin: Any) -> core_schema.CoreSchema: """Generate schema for a NamedTuple.""" - with self.model_type_stack.push(namedtuple_cls), self.defs.get_schema_or_ref( - namedtuple_cls - ) as ( + with self.model_type_stack.push(namedtuple_cls), self.defs.get_schema_or_ref(namedtuple_cls) as ( namedtuple_ref, maybe_schema, ): @@ -1614,28 +1393,20 @@ def _namedtuple_schema( arguments_schema = core_schema.arguments_schema( [ self._generate_parameter_schema( - field_name, - annotation, - default=namedtuple_cls._field_defaults.get( - field_name, Parameter.empty - ), + field_name, annotation, default=namedtuple_cls._field_defaults.get(field_name, Parameter.empty) ) for field_name, annotation in annotations.items() ], metadata=build_metadata_dict(js_prefer_positional_arguments=True), ) - return core_schema.call_schema( - arguments_schema, namedtuple_cls, ref=namedtuple_ref - ) + return core_schema.call_schema(arguments_schema, namedtuple_cls, ref=namedtuple_ref) def _generate_parameter_schema( self, name: str, annotation: type[Any], default: Any = Parameter.empty, - mode: ( - Literal["positional_only", "positional_or_keyword", "keyword_only"] | None - ) = None, + mode: Literal['positional_only', 'positional_or_keyword', 'keyword_only'] | None = None, ) -> core_schema.ArgumentsParameter: """Prepare a ArgumentsParameter to represent a field in a namedtuple or function signature.""" from ..fields import FieldInfo @@ -1644,9 +1415,7 @@ def _generate_parameter_schema( field = FieldInfo.from_annotation(annotation) else: field = FieldInfo.from_annotated_attribute(annotation, default) - assert ( - field.annotation is not None - ), "field.annotation should not be None when generating a schema" + assert field.annotation is not None, 'field.annotation should not be None when generating a schema' source_type, annotations = field.annotation, field.metadata with self.field_name_stack.push(name): schema = self._apply_annotations(source_type, annotations) @@ -1656,18 +1425,15 @@ def _generate_parameter_schema( parameter_schema = core_schema.arguments_parameter(name, schema) if mode is not None: - parameter_schema["mode"] = mode + parameter_schema['mode'] = mode if field.alias is not None: - parameter_schema["alias"] = field.alias + parameter_schema['alias'] = field.alias else: alias_generator = self._config_wrapper.alias_generator - if ( - isinstance(alias_generator, AliasGenerator) - and alias_generator.alias is not None - ): - parameter_schema["alias"] = alias_generator.alias(name) + if isinstance(alias_generator, AliasGenerator) and alias_generator.alias is not None: + parameter_schema['alias'] = alias_generator.alias(name) elif isinstance(alias_generator, Callable): - parameter_schema["alias"] = alias_generator(name) + parameter_schema['alias'] = alias_generator(name) return parameter_schema def _tuple_schema(self, tuple_type: Any) -> core_schema.CoreSchema: @@ -1683,42 +1449,34 @@ def _tuple_schema(self, tuple_type: Any) -> core_schema.CoreSchema: # This is only true for <3.11, on Python 3.11+ `typing.Tuple[()]` gives `params=()` if not params: if tuple_type in TUPLE_TYPES: - return core_schema.tuple_schema( - [core_schema.any_schema()], variadic_item_index=0 - ) + return core_schema.tuple_schema([core_schema.any_schema()], variadic_item_index=0) else: # special case for `tuple[()]` which means `tuple[]` - an empty tuple return core_schema.tuple_schema([]) elif params[-1] is Ellipsis: if len(params) == 2: - return core_schema.tuple_schema( - [self.generate_schema(params[0])], variadic_item_index=0 - ) + return core_schema.tuple_schema([self.generate_schema(params[0])], variadic_item_index=0) else: # TODO: something like https://github.com/pydantic/pydantic/issues/5952 - raise ValueError("Variable tuples can only have one type") + raise ValueError('Variable tuples can only have one type') elif len(params) == 1 and params[0] == (): # special case for `Tuple[()]` which means `Tuple[]` - an empty tuple # NOTE: This conditional can be removed when we drop support for Python 3.10. return core_schema.tuple_schema([]) else: - return core_schema.tuple_schema( - [self.generate_schema(param) for param in params] - ) + return core_schema.tuple_schema([self.generate_schema(param) for param in params]) def _type_schema(self) -> core_schema.CoreSchema: return core_schema.custom_error_schema( core_schema.is_instance_schema(type), - custom_error_type="is_type", - custom_error_message="Input should be a type", + custom_error_type='is_type', + custom_error_message='Input should be a type', ) def _union_is_subclass_schema(self, union_type: Any) -> core_schema.CoreSchema: """Generate schema for `Type[Union[X, ...]]`.""" args = self._get_args_resolving_forward_refs(union_type, required=True) - return core_schema.union_schema( - [self.generate_schema(typing.Type[args]) for args in args] - ) + return core_schema.union_schema([self.generate_schema(typing.Type[args]) for args in args]) def _subclass_schema(self, type_: Any) -> core_schema.CoreSchema: """Generate schema for a Type, e.g. `Type[int]`.""" @@ -1732,10 +1490,7 @@ def _subclass_schema(self, type_: Any) -> core_schema.CoreSchema: return core_schema.is_subclass_schema(type_param.__bound__) elif type_param.__constraints__: return core_schema.union_schema( - [ - self.generate_schema(typing.Type[c]) - for c in type_param.__constraints__ - ] + [self.generate_schema(typing.Type[c]) for c in type_param.__constraints__] ) else: return self._type_schema() @@ -1752,28 +1507,19 @@ def _sequence_schema(self, sequence_type: Any) -> core_schema.CoreSchema: item_type_schema = self.generate_schema(item_type) list_schema = core_schema.list_schema(item_type_schema) - python_schema = core_schema.is_instance_schema( - typing.Sequence, cls_repr="Sequence" - ) + python_schema = core_schema.is_instance_schema(typing.Sequence, cls_repr='Sequence') if item_type != Any: from ._validators import sequence_validator python_schema = core_schema.chain_schema( - [ - python_schema, - core_schema.no_info_wrap_validator_function( - sequence_validator, list_schema - ), - ], + [python_schema, core_schema.no_info_wrap_validator_function(sequence_validator, list_schema)], ) serialization = core_schema.wrap_serializer_function_ser_schema( serialize_sequence_via_list, schema=item_type_schema, info_arg=True ) return core_schema.json_or_python_schema( - json_schema=list_schema, - python_schema=python_schema, - serialization=serialization, + json_schema=list_schema, python_schema=python_schema, serialization=serialization ) def _iterable_schema(self, type_: Any) -> core_schema.GeneratorSchema: @@ -1785,20 +1531,14 @@ def _iterable_schema(self, type_: Any) -> core_schema.GeneratorSchema: def _pattern_schema(self, pattern_type: Any) -> core_schema.CoreSchema: from . import _validators - metadata = build_metadata_dict( - js_functions=[lambda _1, _2: {"type": "string", "format": "regex"}] - ) + metadata = build_metadata_dict(js_functions=[lambda _1, _2: {'type': 'string', 'format': 'regex'}]) ser = core_schema.plain_serializer_function_ser_schema( - attrgetter("pattern"), - when_used="json", - return_schema=core_schema.str_schema(), + attrgetter('pattern'), when_used='json', return_schema=core_schema.str_schema() ) if pattern_type == typing.Pattern or pattern_type == re.Pattern: # bare type return core_schema.no_info_plain_validator_function( - _validators.pattern_either_validator, - serialization=ser, - metadata=metadata, + _validators.pattern_either_validator, serialization=ser, metadata=metadata ) param = self._get_args_resolving_forward_refs( @@ -1811,29 +1551,23 @@ def _pattern_schema(self, pattern_type: Any) -> core_schema.CoreSchema: ) elif param is bytes: return core_schema.no_info_plain_validator_function( - _validators.pattern_bytes_validator, - serialization=ser, - metadata=metadata, + _validators.pattern_bytes_validator, serialization=ser, metadata=metadata ) else: - raise PydanticSchemaGenerationError( - f"Unable to generate pydantic-core schema for {pattern_type!r}." - ) + raise PydanticSchemaGenerationError(f'Unable to generate pydantic-core schema for {pattern_type!r}.') def _hashable_schema(self) -> core_schema.CoreSchema: return core_schema.custom_error_schema( core_schema.is_instance_schema(collections.abc.Hashable), - custom_error_type="is_hashable", - custom_error_message="Input should be hashable", + custom_error_type='is_hashable', + custom_error_message='Input should be hashable', ) def _dataclass_schema( self, dataclass: type[StandardDataclass], origin: type[StandardDataclass] | None ) -> core_schema.CoreSchema: """Generate schema for a dataclass.""" - with self.model_type_stack.push(dataclass), self.defs.get_schema_or_ref( - dataclass - ) as ( + with self.model_type_stack.push(dataclass), self.defs.get_schema_or_ref(dataclass) as ( dataclass_ref, maybe_schema, ): @@ -1848,18 +1582,14 @@ def _dataclass_schema( # Pushing a namespace prioritises items already in the stack, so iterate though the MRO forwards for dataclass_base in dataclass.__mro__: if dataclasses.is_dataclass(dataclass_base): - dataclass_bases_stack.enter_context( - self._types_namespace_stack.push(dataclass_base) - ) + dataclass_bases_stack.enter_context(self._types_namespace_stack.push(dataclass_base)) # Pushing a config overwrites the previous config, so iterate though the MRO backwards config = None for dataclass_base in reversed(dataclass.__mro__): if dataclasses.is_dataclass(dataclass_base): - config = getattr(dataclass_base, "__pydantic_config__", None) - dataclass_bases_stack.enter_context( - self._config_wrapper_stack.push(config) - ) + config = getattr(dataclass_base, '__pydantic_config__', None) + dataclass_bases_stack.enter_context(self._config_wrapper_stack.push(config)) core_config = self._config_wrapper.core_config(dataclass) @@ -1871,9 +1601,7 @@ def _dataclass_schema( fields = deepcopy(dataclass.__pydantic_fields__) if typevars_map: for field in fields.values(): - field.apply_typevars_map( - typevars_map, self._types_namespace - ) + field.apply_typevars_map(typevars_map, self._types_namespace) else: fields = collect_dataclass_fields( dataclass, @@ -1882,30 +1610,25 @@ def _dataclass_schema( ) # disallow combination of init=False on a dataclass field and extra='allow' on a dataclass - if self._config_wrapper_stack.tail.extra == "allow": + if self._config_wrapper_stack.tail.extra == 'allow': # disallow combination of init=False on a dataclass field and extra='allow' on a dataclass for field_name, field in fields.items(): if field.init is False: raise PydanticUserError( f'Field {field_name} has `init=False` and dataclass has config setting `extra="allow"`. ' - f"This combination is not allowed.", - code="dataclass-init-false-extra-allow", + f'This combination is not allowed.', + code='dataclass-init-false-extra-allow', ) - decorators = dataclass.__dict__.get( - "__pydantic_decorators__" - ) or DecoratorInfos.build(dataclass) + decorators = dataclass.__dict__.get('__pydantic_decorators__') or DecoratorInfos.build(dataclass) # Move kw_only=False args to the start of the list, as this is how vanilla dataclasses work. # Note that when kw_only is missing or None, it is treated as equivalent to kw_only=True args = sorted( - ( - self._generate_dc_field_schema(k, v, decorators) - for k, v in fields.items() - ), - key=lambda a: a.get("kw_only") is not False, + (self._generate_dc_field_schema(k, v, decorators) for k, v in fields.items()), + key=lambda a: a.get('kw_only') is not False, ) - has_post_init = hasattr(dataclass, "__post_init__") - has_slots = hasattr(dataclass, "__slots__") + has_post_init = hasattr(dataclass, '__post_init__') + has_slots = hasattr(dataclass, '__slots__') args_schema = core_schema.dataclass_args_schema( dataclass.__name__, @@ -1917,22 +1640,14 @@ def _dataclass_schema( collect_init_only=has_post_init, ) - inner_schema = apply_validators( - args_schema, decorators.root_validators.values(), None - ) + inner_schema = apply_validators(args_schema, decorators.root_validators.values(), None) model_validators = decorators.model_validators.values() - inner_schema = apply_model_validators( - inner_schema, model_validators, "inner" - ) + inner_schema = apply_model_validators(inner_schema, model_validators, 'inner') - title = self._get_model_title_from_config( - dataclass, ConfigWrapper(config) - ) + title = self._get_model_title_from_config(dataclass, ConfigWrapper(config)) metadata = build_metadata_dict( - js_functions=[ - partial(modify_model_json_schema, cls=dataclass, title=title) - ] + js_functions=[partial(modify_model_json_schema, cls=dataclass, title=title)] ) dc_schema = core_schema.dataclass_schema( @@ -1945,16 +1660,14 @@ def _dataclass_schema( config=core_config, metadata=metadata, ) - schema = self._apply_model_serializers( - dc_schema, decorators.model_serializers.values() - ) - schema = apply_model_validators(schema, model_validators, "outer") + schema = self._apply_model_serializers(dc_schema, decorators.model_serializers.values()) + schema = apply_model_validators(schema, model_validators, 'outer') self.defs.definitions[dataclass_ref] = schema return core_schema.definition_reference_schema(dataclass_ref) # Type checkers seem to assume ExitStack may suppress exceptions and therefore # control flow can exit the `with` block without returning. - assert False, "Unreachable" + assert False, 'Unreachable' def _callable_schema(self, function: Callable[..., Any]) -> core_schema.CallSchema: """Generate schema for a Callable. @@ -1963,17 +1676,12 @@ def _callable_schema(self, function: Callable[..., Any]) -> core_schema.CallSche """ sig = signature(function) - type_hints = _typing_extra.get_function_type_hints( - function, types_namespace=self._types_namespace - ) + type_hints = _typing_extra.get_function_type_hints(function, types_namespace=self._types_namespace) - mode_lookup: dict[ - _ParameterKind, - Literal["positional_only", "positional_or_keyword", "keyword_only"], - ] = { - Parameter.POSITIONAL_ONLY: "positional_only", - Parameter.POSITIONAL_OR_KEYWORD: "positional_or_keyword", - Parameter.KEYWORD_ONLY: "keyword_only", + mode_lookup: dict[_ParameterKind, Literal['positional_only', 'positional_or_keyword', 'keyword_only']] = { + Parameter.POSITIONAL_ONLY: 'positional_only', + Parameter.POSITIONAL_OR_KEYWORD: 'positional_or_keyword', + Parameter.KEYWORD_ONLY: 'keyword_only', } arguments_list: list[core_schema.ArgumentsParameter] = [] @@ -1988,9 +1696,7 @@ def _callable_schema(self, function: Callable[..., Any]) -> core_schema.CallSche parameter_mode = mode_lookup.get(p.kind) if parameter_mode is not None: - arg_schema = self._generate_parameter_schema( - name, annotation, p.default, parameter_mode - ) + arg_schema = self._generate_parameter_schema(name, annotation, p.default, parameter_mode) arguments_list.append(arg_schema) elif p.kind == Parameter.VAR_POSITIONAL: var_args_schema = self.generate_schema(annotation) @@ -2001,7 +1707,7 @@ def _callable_schema(self, function: Callable[..., Any]) -> core_schema.CallSche return_schema: core_schema.CoreSchema | None = None config_wrapper = self._config_wrapper if config_wrapper.validate_return: - return_hint = type_hints.get("return") + return_hint = type_hints.get('return') if return_hint is not None: return_schema = self.generate_schema(return_hint) @@ -2016,9 +1722,7 @@ def _callable_schema(self, function: Callable[..., Any]) -> core_schema.CallSche return_schema=return_schema, ) - def _unsubstituted_typevar_schema( - self, typevar: typing.TypeVar - ) -> core_schema.CoreSchema: + def _unsubstituted_typevar_schema(self, typevar: typing.TypeVar) -> core_schema.CoreSchema: assert isinstance(typevar, typing.TypeVar) bound = typevar.__bound__ @@ -2028,11 +1732,11 @@ def _unsubstituted_typevar_schema( typevar_has_default = typevar.has_default() # type: ignore except AttributeError: # could still have a default if it's an old version of typing_extensions.TypeVar - typevar_has_default = getattr(typevar, "__default__", None) is not None + typevar_has_default = getattr(typevar, '__default__', None) is not None if (bound is not None) + (len(constraints) != 0) + typevar_has_default > 1: raise NotImplementedError( - "Pydantic does not support mixing more than one of TypeVar bounds, constraints and defaults" + 'Pydantic does not support mixing more than one of TypeVar bounds, constraints and defaults' ) if typevar_has_default: @@ -2041,7 +1745,7 @@ def _unsubstituted_typevar_schema( return self._union_schema(typing.Union[constraints]) # type: ignore elif bound: schema = self.generate_schema(bound) - schema["serialization"] = core_schema.wrap_serializer_function_ser_schema( + schema['serialization'] = core_schema.wrap_serializer_function_ser_schema( lambda x, h: h(x), schema=core_schema.any_schema() ) return schema @@ -2054,16 +1758,14 @@ def _computed_field_schema( field_serializers: dict[str, Decorator[FieldSerializerDecoratorInfo]], ) -> core_schema.ComputedField: try: - return_type = _decorators.get_function_return_type( - d.func, d.info.return_type, self._types_namespace - ) + return_type = _decorators.get_function_return_type(d.func, d.info.return_type, self._types_namespace) except NameError as e: raise PydanticUndefinedAnnotation.from_name_error(e) from e if return_type is PydanticUndefined: raise PydanticUserError( - "Computed field is missing return type annotation or specifying `return_type`" - " to the `@computed_field` decorator (e.g. `@computed_field(return_type=int|str)`)", - code="model-field-missing-annotation", + 'Computed field is missing return type annotation or specifying `return_type`' + ' to the `@computed_field` decorator (e.g. `@computed_field(return_type=int|str)`)', + code='model-field-missing-annotation', ) return_type = replace_types(return_type, self._typevars_map) @@ -2074,44 +1776,36 @@ def _computed_field_schema( # Apply serializers to computed field if there exist return_type_schema = self._apply_field_serializers( return_type_schema, - filter_field_decorator_info_by_field( - field_serializers.values(), d.cls_var_name - ), + filter_field_decorator_info_by_field(field_serializers.values(), d.cls_var_name), computed_field=True, ) alias_generator = self._config_wrapper.alias_generator if alias_generator is not None: self._apply_alias_generator_to_computed_field_info( - alias_generator=alias_generator, - computed_field_info=d.info, - computed_field_name=d.cls_var_name, + alias_generator=alias_generator, computed_field_info=d.info, computed_field_name=d.cls_var_name ) - self._apply_field_title_generator_to_field_info( - self._config_wrapper, d.info, d.cls_var_name - ) + self._apply_field_title_generator_to_field_info(self._config_wrapper, d.info, d.cls_var_name) - def set_computed_field_metadata( - schema: CoreSchemaOrField, handler: GetJsonSchemaHandler - ) -> JsonSchemaValue: + def set_computed_field_metadata(schema: CoreSchemaOrField, handler: GetJsonSchemaHandler) -> JsonSchemaValue: json_schema = handler(schema) - json_schema["readOnly"] = True + json_schema['readOnly'] = True title = d.info.title if title is not None: - json_schema["title"] = title + json_schema['title'] = title description = d.info.description if description is not None: - json_schema["description"] = description + json_schema['description'] = description - if d.info.deprecated or d.info.deprecated == "": - json_schema["deprecated"] = True + if d.info.deprecated or d.info.deprecated == '': + json_schema['deprecated'] = True examples = d.info.examples if examples is not None: - json_schema["examples"] = to_jsonable_python(examples) + json_schema['examples'] = to_jsonable_python(examples) json_schema_extra = d.info.json_schema_extra if json_schema_extra is not None: @@ -2119,14 +1813,9 @@ def set_computed_field_metadata( return json_schema - metadata = build_metadata_dict( - js_annotation_functions=[set_computed_field_metadata] - ) + metadata = build_metadata_dict(js_annotation_functions=[set_computed_field_metadata]) return core_schema.computed_field( - d.cls_var_name, - return_schema=return_type_schema, - alias=d.info.alias, - metadata=metadata, + d.cls_var_name, return_schema=return_type_schema, alias=d.info.alias, metadata=metadata ) def _annotated_schema(self, annotated_type: Any) -> core_schema.CoreSchema: @@ -2176,12 +1865,8 @@ def _apply_annotations( not expect `source_type` to be an `Annotated` object, it expects it to be the first argument of that (in other words, `GenerateSchema._annotated_schema` just unpacks `Annotated`, this process it). """ - annotations = list( - _known_annotated_metadata.expand_grouped_metadata(annotations) - ) - res = self._get_prepare_pydantic_annotations_for_known_type( - source_type, tuple(annotations) - ) + annotations = list(_known_annotated_metadata.expand_grouped_metadata(annotations)) + res = self._get_prepare_pydantic_annotations_for_known_type(source_type, tuple(annotations)) if res is not None: source_type, annotations = res @@ -2212,16 +1897,10 @@ def inner_handler(obj: Any) -> CoreSchema: schema = get_inner_schema(source_type) if pydantic_js_annotation_functions: metadata = CoreMetadataHandler(schema).metadata - metadata.setdefault("pydantic_js_annotation_functions", []).extend( - pydantic_js_annotation_functions - ) - return _add_custom_serialization_from_json_encoders( - self._config_wrapper.json_encoders, source_type, schema - ) + metadata.setdefault('pydantic_js_annotation_functions', []).extend(pydantic_js_annotation_functions) + return _add_custom_serialization_from_json_encoders(self._config_wrapper.json_encoders, source_type, schema) - def _apply_single_annotation( - self, schema: core_schema.CoreSchema, metadata: Any - ) -> core_schema.CoreSchema: + def _apply_single_annotation(self, schema: core_schema.CoreSchema, metadata: Any) -> core_schema.CoreSchema: from ..fields import FieldInfo if isinstance(metadata, FieldInfo): @@ -2229,39 +1908,35 @@ def _apply_single_annotation( schema = self._apply_single_annotation(schema, field_metadata) if metadata.discriminator is not None: - schema = self._apply_discriminator_to_union( - schema, metadata.discriminator - ) + schema = self._apply_discriminator_to_union(schema, metadata.discriminator) return schema - if schema["type"] == "nullable": + if schema['type'] == 'nullable': # for nullable schemas, metadata is automatically applied to the inner schema - inner = schema.get("schema", core_schema.any_schema()) + inner = schema.get('schema', core_schema.any_schema()) inner = self._apply_single_annotation(inner, metadata) if inner: - schema["schema"] = inner + schema['schema'] = inner return schema original_schema = schema - ref = schema.get("ref", None) + ref = schema.get('ref', None) if ref is not None: schema = schema.copy() - new_ref = ref + f"_{repr(metadata)}" + new_ref = ref + f'_{repr(metadata)}' if new_ref in self.defs.definitions: return self.defs.definitions[new_ref] - schema["ref"] = new_ref # type: ignore - elif schema["type"] == "definition-ref": - ref = schema["schema_ref"] + schema['ref'] = new_ref # type: ignore + elif schema['type'] == 'definition-ref': + ref = schema['schema_ref'] if ref in self.defs.definitions: schema = self.defs.definitions[ref].copy() - new_ref = ref + f"_{repr(metadata)}" + new_ref = ref + f'_{repr(metadata)}' if new_ref in self.defs.definitions: return self.defs.definitions[new_ref] - schema["ref"] = new_ref # type: ignore + schema['ref'] = new_ref # type: ignore - maybe_updated_schema = _known_annotated_metadata.apply_known_metadata( - metadata, schema.copy() - ) + maybe_updated_schema = _known_annotated_metadata.apply_known_metadata(metadata, schema.copy()) if maybe_updated_schema is not None: return maybe_updated_schema @@ -2274,22 +1949,18 @@ def _apply_single_annotation_json_schema( if isinstance(metadata, FieldInfo): for field_metadata in metadata.metadata: - schema = self._apply_single_annotation_json_schema( - schema, field_metadata - ) + schema = self._apply_single_annotation_json_schema(schema, field_metadata) json_schema_update: JsonSchemaValue = {} if metadata.title: - json_schema_update["title"] = metadata.title + json_schema_update['title'] = metadata.title if metadata.description: - json_schema_update["description"] = metadata.description + json_schema_update['description'] = metadata.description if metadata.examples: - json_schema_update["examples"] = to_jsonable_python(metadata.examples) + json_schema_update['examples'] = to_jsonable_python(metadata.examples) json_schema_extra = metadata.json_schema_extra if json_schema_update or json_schema_extra: - CoreMetadataHandler(schema).metadata.setdefault( - "pydantic_js_annotation_functions", [] - ).append( + CoreMetadataHandler(schema).metadata.setdefault('pydantic_js_annotation_functions', []).append( get_json_schema_update_func(json_schema_update, json_schema_extra) ) return schema @@ -2300,9 +1971,9 @@ def _get_wrapped_inner_schema( annotation: Any, pydantic_js_annotation_functions: list[GetJsonSchemaFunction], ) -> CallbackGetCoreSchemaHandler: - metadata_get_schema: GetCoreSchemaFunction = getattr( - annotation, "__get_pydantic_core_schema__", None - ) or (lambda source, handler: handler(source)) + metadata_get_schema: GetCoreSchemaFunction = getattr(annotation, '__get_pydantic_core_schema__', None) or ( + lambda source, handler: handler(source) + ) def new_handler(source: Any) -> core_schema.CoreSchema: schema = metadata_get_schema(source, get_inner_schema) @@ -2325,14 +1996,12 @@ def _apply_field_serializers( """Apply field serializers to a schema.""" if serializers: schema = copy(schema) - if schema["type"] == "definitions": - inner_schema = schema["schema"] - schema["schema"] = self._apply_field_serializers( - inner_schema, serializers - ) + if schema['type'] == 'definitions': + inner_schema = schema['schema'] + schema['schema'] = self._apply_field_serializers(inner_schema, serializers) return schema else: - ref = typing.cast("str|None", schema.get("ref", None)) + ref = typing.cast('str|None', schema.get('ref', None)) if ref is not None: schema = core_schema.definition_reference_schema(ref) @@ -2354,36 +2023,30 @@ def _apply_field_serializers( else: return_schema = self.generate_schema(return_type) - if serializer.info.mode == "wrap": - schema["serialization"] = ( - core_schema.wrap_serializer_function_ser_schema( - serializer.func, - is_field_serializer=is_field_serializer, - info_arg=info_arg, - return_schema=return_schema, - when_used=serializer.info.when_used, - ) + if serializer.info.mode == 'wrap': + schema['serialization'] = core_schema.wrap_serializer_function_ser_schema( + serializer.func, + is_field_serializer=is_field_serializer, + info_arg=info_arg, + return_schema=return_schema, + when_used=serializer.info.when_used, ) else: - assert serializer.info.mode == "plain" - schema["serialization"] = ( - core_schema.plain_serializer_function_ser_schema( - serializer.func, - is_field_serializer=is_field_serializer, - info_arg=info_arg, - return_schema=return_schema, - when_used=serializer.info.when_used, - ) + assert serializer.info.mode == 'plain' + schema['serialization'] = core_schema.plain_serializer_function_ser_schema( + serializer.func, + is_field_serializer=is_field_serializer, + info_arg=info_arg, + return_schema=return_schema, + when_used=serializer.info.when_used, ) return schema def _apply_model_serializers( - self, - schema: core_schema.CoreSchema, - serializers: Iterable[Decorator[ModelSerializerDecoratorInfo]], + self, schema: core_schema.CoreSchema, serializers: Iterable[Decorator[ModelSerializerDecoratorInfo]] ) -> core_schema.CoreSchema: """Apply model serializers to a schema.""" - ref: str | None = schema.pop("ref", None) # type: ignore + ref: str | None = schema.pop('ref', None) # type: ignore if serializers: serializer = list(serializers)[-1] info_arg = inspect_model_serializer(serializer.func, serializer.info.mode) @@ -2399,14 +2062,12 @@ def _apply_model_serializers( else: return_schema = self.generate_schema(return_type) - if serializer.info.mode == "wrap": - ser_schema: core_schema.SerSchema = ( - core_schema.wrap_serializer_function_ser_schema( - serializer.func, - info_arg=info_arg, - return_schema=return_schema, - when_used=serializer.info.when_used, - ) + if serializer.info.mode == 'wrap': + ser_schema: core_schema.SerSchema = core_schema.wrap_serializer_function_ser_schema( + serializer.func, + info_arg=info_arg, + return_schema=return_schema, + when_used=serializer.info.when_used, ) else: # plain @@ -2416,56 +2077,30 @@ def _apply_model_serializers( return_schema=return_schema, when_used=serializer.info.when_used, ) - schema["serialization"] = ser_schema + schema['serialization'] = ser_schema if ref: - schema["ref"] = ref # type: ignore + schema['ref'] = ref # type: ignore return schema _VALIDATOR_F_MATCH: Mapping[ - tuple[FieldValidatorModes, Literal["no-info", "with-info"]], - Callable[ - [Callable[..., Any], core_schema.CoreSchema, str | None], core_schema.CoreSchema - ], + tuple[FieldValidatorModes, Literal['no-info', 'with-info']], + Callable[[Callable[..., Any], core_schema.CoreSchema, str | None], core_schema.CoreSchema], ] = { - ( - "before", - "no-info", - ): lambda f, schema, _: core_schema.no_info_before_validator_function(f, schema), - ( - "after", - "no-info", - ): lambda f, schema, _: core_schema.no_info_after_validator_function(f, schema), - ( - "plain", - "no-info", - ): lambda f, _1, _2: core_schema.no_info_plain_validator_function(f), - ( - "wrap", - "no-info", - ): lambda f, schema, _: core_schema.no_info_wrap_validator_function(f, schema), - ( - "before", - "with-info", - ): lambda f, schema, field_name: core_schema.with_info_before_validator_function( + ('before', 'no-info'): lambda f, schema, _: core_schema.no_info_before_validator_function(f, schema), + ('after', 'no-info'): lambda f, schema, _: core_schema.no_info_after_validator_function(f, schema), + ('plain', 'no-info'): lambda f, _1, _2: core_schema.no_info_plain_validator_function(f), + ('wrap', 'no-info'): lambda f, schema, _: core_schema.no_info_wrap_validator_function(f, schema), + ('before', 'with-info'): lambda f, schema, field_name: core_schema.with_info_before_validator_function( f, schema, field_name=field_name ), - ( - "after", - "with-info", - ): lambda f, schema, field_name: core_schema.with_info_after_validator_function( + ('after', 'with-info'): lambda f, schema, field_name: core_schema.with_info_after_validator_function( f, schema, field_name=field_name ), - ( - "plain", - "with-info", - ): lambda f, _, field_name: core_schema.with_info_plain_validator_function( + ('plain', 'with-info'): lambda f, _, field_name: core_schema.with_info_plain_validator_function( f, field_name=field_name ), - ( - "wrap", - "with-info", - ): lambda f, schema, field_name: core_schema.with_info_wrap_validator_function( + ('wrap', 'with-info'): lambda f, schema, field_name: core_schema.with_info_wrap_validator_function( f, schema, field_name=field_name ), } @@ -2473,11 +2108,9 @@ def _apply_model_serializers( def apply_validators( schema: core_schema.CoreSchema, - validators: ( - Iterable[Decorator[RootValidatorDecoratorInfo]] - | Iterable[Decorator[ValidatorDecoratorInfo]] - | Iterable[Decorator[FieldValidatorDecoratorInfo]] - ), + validators: Iterable[Decorator[RootValidatorDecoratorInfo]] + | Iterable[Decorator[ValidatorDecoratorInfo]] + | Iterable[Decorator[FieldValidatorDecoratorInfo]], field_name: str | None, ) -> core_schema.CoreSchema: """Apply validators to a schema. @@ -2492,17 +2125,13 @@ def apply_validators( """ for validator in validators: info_arg = inspect_validator(validator.func, validator.info.mode) - val_type = "with-info" if info_arg else "no-info" + val_type = 'with-info' if info_arg else 'no-info' - schema = _VALIDATOR_F_MATCH[(validator.info.mode, val_type)]( - validator.func, schema, field_name - ) + schema = _VALIDATOR_F_MATCH[(validator.info.mode, val_type)](validator.func, schema, field_name) return schema -def _validators_require_validate_default( - validators: Iterable[Decorator[ValidatorDecoratorInfo]], -) -> bool: +def _validators_require_validate_default(validators: Iterable[Decorator[ValidatorDecoratorInfo]]) -> bool: """In v1, if any of the validators for a field had `always=True`, the default value would be validated. This serves as an auxiliary function for re-implementing that logic, by looping over a provided @@ -2521,7 +2150,7 @@ def _validators_require_validate_default( def apply_model_validators( schema: core_schema.CoreSchema, validators: Iterable[Decorator[ModelValidatorDecoratorInfo]], - mode: Literal["inner", "outer", "all"], + mode: Literal['inner', 'outer', 'all'], ) -> core_schema.CoreSchema: """Apply model validators to a schema. @@ -2537,49 +2166,35 @@ def apply_model_validators( Returns: The updated schema. """ - ref: str | None = schema.pop("ref", None) # type: ignore + ref: str | None = schema.pop('ref', None) # type: ignore for validator in validators: - if mode == "inner" and validator.info.mode != "before": + if mode == 'inner' and validator.info.mode != 'before': continue - if mode == "outer" and validator.info.mode == "before": + if mode == 'outer' and validator.info.mode == 'before': continue info_arg = inspect_validator(validator.func, validator.info.mode) - if validator.info.mode == "wrap": + if validator.info.mode == 'wrap': if info_arg: - schema = core_schema.with_info_wrap_validator_function( - function=validator.func, schema=schema - ) + schema = core_schema.with_info_wrap_validator_function(function=validator.func, schema=schema) else: - schema = core_schema.no_info_wrap_validator_function( - function=validator.func, schema=schema - ) - elif validator.info.mode == "before": + schema = core_schema.no_info_wrap_validator_function(function=validator.func, schema=schema) + elif validator.info.mode == 'before': if info_arg: - schema = core_schema.with_info_before_validator_function( - function=validator.func, schema=schema - ) + schema = core_schema.with_info_before_validator_function(function=validator.func, schema=schema) else: - schema = core_schema.no_info_before_validator_function( - function=validator.func, schema=schema - ) + schema = core_schema.no_info_before_validator_function(function=validator.func, schema=schema) else: - assert validator.info.mode == "after" + assert validator.info.mode == 'after' if info_arg: - schema = core_schema.with_info_after_validator_function( - function=validator.func, schema=schema - ) + schema = core_schema.with_info_after_validator_function(function=validator.func, schema=schema) else: - schema = core_schema.no_info_after_validator_function( - function=validator.func, schema=schema - ) + schema = core_schema.no_info_after_validator_function(function=validator.func, schema=schema) if ref: - schema["ref"] = ref # type: ignore + schema['ref'] = ref # type: ignore return schema -def wrap_default( - field_info: FieldInfo, schema: core_schema.CoreSchema -) -> core_schema.CoreSchema: +def wrap_default(field_info: FieldInfo, schema: core_schema.CoreSchema) -> core_schema.CoreSchema: """Wrap schema with default schema if default value or `default_factory` are available. Args: @@ -2591,47 +2206,39 @@ def wrap_default( """ if field_info.default_factory: return core_schema.with_default_schema( - schema, - default_factory=field_info.default_factory, - validate_default=field_info.validate_default, + schema, default_factory=field_info.default_factory, validate_default=field_info.validate_default ) elif field_info.default is not PydanticUndefined: return core_schema.with_default_schema( - schema, - default=field_info.default, - validate_default=field_info.validate_default, + schema, default=field_info.default, validate_default=field_info.validate_default ) else: return schema -def _extract_get_pydantic_json_schema( - tp: Any, schema: CoreSchema -) -> GetJsonSchemaFunction | None: +def _extract_get_pydantic_json_schema(tp: Any, schema: CoreSchema) -> GetJsonSchemaFunction | None: """Extract `__get_pydantic_json_schema__` from a type, handling the deprecated `__modify_schema__`.""" - js_modify_function = getattr(tp, "__get_pydantic_json_schema__", None) + js_modify_function = getattr(tp, '__get_pydantic_json_schema__', None) - if hasattr(tp, "__modify_schema__"): + if hasattr(tp, '__modify_schema__'): from pydantic import BaseModel # circular reference has_custom_v2_modify_js_func = ( js_modify_function is not None and BaseModel.__get_pydantic_json_schema__.__func__ # type: ignore - not in (js_modify_function, getattr(js_modify_function, "__func__", None)) + not in (js_modify_function, getattr(js_modify_function, '__func__', None)) ) if not has_custom_v2_modify_js_func: - cls_name = getattr(tp, "__name__", None) + cls_name = getattr(tp, '__name__', None) raise PydanticUserError( - f"The `__modify_schema__` method is not supported in Pydantic v2. " + f'The `__modify_schema__` method is not supported in Pydantic v2. ' f'Use `__get_pydantic_json_schema__` instead{f" in class `{cls_name}`" if cls_name else ""}.', - code="custom-json-schema", + code='custom-json-schema', ) # handle GenericAlias' but ignore Annotated which "lies" about its origin (in this case it would be `int`) - if hasattr(tp, "__origin__") and not isinstance( - tp, type(Annotated[int, "placeholder"]) - ): + if hasattr(tp, '__origin__') and not isinstance(tp, type(Annotated[int, 'placeholder'])): return _extract_get_pydantic_json_schema(tp.__origin__, schema) if js_modify_function is None: @@ -2641,8 +2248,7 @@ def _extract_get_pydantic_json_schema( def get_json_schema_update_func( - json_schema_update: JsonSchemaValue, - json_schema_extra: JsonDict | typing.Callable[[JsonDict], None] | None, + json_schema_update: JsonSchemaValue, json_schema_extra: JsonDict | typing.Callable[[JsonDict], None] | None ) -> GetJsonSchemaFunction: def json_schema_update_func( core_schema_or_field: CoreSchemaOrField, handler: GetJsonSchemaHandler @@ -2655,8 +2261,7 @@ def json_schema_update_func( def add_json_schema_extra( - json_schema: JsonSchemaValue, - json_schema_extra: JsonDict | typing.Callable[[JsonDict], None] | None, + json_schema: JsonSchemaValue, json_schema_extra: JsonDict | typing.Callable[[JsonDict], None] | None ): if isinstance(json_schema_extra, dict): json_schema.update(to_jsonable_python(json_schema_extra)) @@ -2683,12 +2288,12 @@ def _common_field( metadata: Any = None, ) -> _CommonField: return { - "schema": schema, - "validation_alias": validation_alias, - "serialization_alias": serialization_alias, - "serialization_exclude": serialization_exclude, - "frozen": frozen, - "metadata": metadata, + 'schema': schema, + 'validation_alias': validation_alias, + 'serialization_alias': serialization_alias, + 'serialization_exclude': serialization_exclude, + 'frozen': frozen, + 'metadata': metadata, } @@ -2700,9 +2305,7 @@ def __init__(self) -> None: self.definitions: dict[str, core_schema.CoreSchema] = {} @contextmanager - def get_schema_or_ref( - self, tp: Any - ) -> Iterator[tuple[str, None] | tuple[str, CoreSchema]]: + def get_schema_or_ref(self, tp: Any) -> Iterator[tuple[str, None] | tuple[str, CoreSchema]]: """Get a definition for `tp` if one exists. If a definition exists, a tuple of `(ref_string, CoreSchema)` is returned. @@ -2733,19 +2336,17 @@ def get_schema_or_ref( self.seen.discard(ref) -def resolve_original_schema( - schema: CoreSchema, definitions: dict[str, CoreSchema] -) -> CoreSchema | None: - schema_type = schema["type"] - if schema_type == "definition-ref": - return definitions.get(schema["schema_ref"]) - elif schema_type == "definitions": - return schema["schema"] +def resolve_original_schema(schema: CoreSchema, definitions: dict[str, CoreSchema]) -> CoreSchema | None: + schema_type = schema['type'] + if schema_type == 'definition-ref': + return definitions.get(schema['schema_ref']) + elif schema_type == 'definitions': + return schema['schema'] return schema class _FieldNameStack: - __slots__ = ("_stack",) + __slots__ = ('_stack',) def __init__(self) -> None: self._stack: list[str] = [] @@ -2764,7 +2365,7 @@ def get(self) -> str | None: class _ModelTypeStack: - __slots__ = ("_stack",) + __slots__ = ('_stack',) def __init__(self) -> None: self._stack: list[type] = []