Skip to content

Commit

Permalink
fix: fix discriminator bug in recursive type (#596)
Browse files Browse the repository at this point in the history
Fixes #567
By the way, discriminator feature will be reworked in next version,
and the current API will be deprecated
  • Loading branch information
wyfo authored Oct 18, 2023
1 parent f86ff5e commit c4755d5
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 40 deletions.
38 changes: 7 additions & 31 deletions apischema/deserialization/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import inspect
import re
from collections import defaultdict
from contextlib import contextmanager
from enum import Enum
from functools import lru_cache, partial
from typing import (
Expand Down Expand Up @@ -254,7 +253,6 @@ def __init__(
aliaser: Aliaser,
coercer: Optional[Coercer],
default_conversion: DefaultConversion,
discriminator: Optional[str],
fall_back_on_default: bool,
no_copy: bool,
pass_through: CollectionOrPredicate[type],
Expand All @@ -263,7 +261,6 @@ def __init__(
self.additional_properties = additional_properties
self.aliaser = aliaser
self.coercer = coercer
self._discriminator = discriminator
self.fall_back_on_default = fall_back_on_default
self.no_copy = no_copy
self.pass_through = pass_through
Expand All @@ -287,28 +284,17 @@ def visit_not_recursive(self, tp: AnyType) -> DeserializationMethodFactory:
self.coercer,
self._conversion,
self.default_conversion,
self._discriminator,
self.fall_back_on_default,
self.no_copy,
self.pass_through,
)

@contextmanager
def _discriminate(self, discriminator: Optional[str]):
discriminator_save = self._discriminator
self._discriminator = discriminator
try:
yield
finally:
self._discriminator = discriminator_save

def discriminate(
self, discriminator: Discriminator, types: Sequence[AnyType]
) -> DeserializationMethodFactory:
mapping = {}
for key, tp in discriminator.get_mapping(types).items():
with self._discriminate(self.aliaser(discriminator.alias)):
mapping[key] = self.visit(tp)
mapping[key] = self.visit(tp)

def factory(constraints: Optional[Constraints], _) -> DeserializationMethod:
from apischema import settings
Expand Down Expand Up @@ -433,13 +419,12 @@ def object(
self, tp: Type, fields: Sequence[ObjectField]
) -> DeserializationMethodFactory:
cls = get_origin_or_type(tp)
with self._discriminate(None):
field_factories = [
self.visit_with_conv(f.type, f.deserialization).merge(
get_constraints(f.schema), f.validators
)
for f in fields
]
field_factories = [
self.visit_with_conv(f.type, f.deserialization).merge(
get_constraints(f.schema), f.validators
)
for f in fields
]

def factory(
constraints: Optional[Constraints], validators: Sequence[Validator]
Expand Down Expand Up @@ -528,11 +513,6 @@ def factory(
and not flattened_fields
and not pattern_fields
and not additional_field
and (
self._discriminator is None
or self._discriminator in all_alliases
or is_typed_dict(cls)
)
and (is_typed_dict(cls) == self.additional_properties)
and (not is_typed_dict(cls) or self.no_copy)
and not validators
Expand Down Expand Up @@ -572,7 +552,6 @@ def factory(
self.aliaser,
settings.errors.missing_property,
settings.errors.unexpected_property,
self._discriminator,
)

return self._factory(factory, dict, validation=False)
Expand Down Expand Up @@ -736,7 +715,6 @@ def deserialization_method_factory(
coercer: Optional[Coercer],
conversion: Optional[AnyConversion],
default_conversion: DefaultConversion,
discriminator: Optional[str],
fall_back_on_default: bool,
no_copy: bool,
pass_through: CollectionOrPredicate[type],
Expand All @@ -746,7 +724,6 @@ def deserialization_method_factory(
aliaser,
coercer,
default_conversion,
discriminator,
fall_back_on_default,
no_copy,
pass_through,
Expand Down Expand Up @@ -821,7 +798,6 @@ def deserialization_method(
coercer,
conversion,
opt_or(default_conversion, settings.deserialization.default_conversion),
None,
opt_or(fall_back_on_default, settings.deserialization.fall_back_on_default),
opt_or(no_copy, settings.deserialization.no_copy),
pass_through, # type: ignore
Expand Down
44 changes: 35 additions & 9 deletions apischema/deserialization/methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,8 +488,15 @@ class SimpleObjectMethod(DeserializationMethod):
unexpected: str

def deserialize(self, data: Any) -> Any:
discriminator: Optional[str] = None
if not isinstance(data, dict):
raise bad_type(data, dict)
if isinstance(data, Discriminated):
discriminator = data.discriminator
data = data.data
if not isinstance(data, dict):
raise bad_type(data, dict)
else:
raise bad_type(data, dict)
fields_count: int = 0
field_errors: Optional[dict] = None
for field in self.fields:
Expand All @@ -504,13 +511,20 @@ def deserialize(self, data: Any) -> Any:
field_errors = set_child_error(
field_errors, field.alias, ValidationError(self.missing)
)
has_discriminator = False
if len(data) != fields_count and not self.typed_dict:
for key in data.keys() - self.all_aliases:
field_errors = set_child_error(
field_errors, key, ValidationError(self.unexpected)
)
if key == discriminator:
has_discriminator = True
else:
field_errors = set_child_error(
field_errors, key, ValidationError(self.unexpected)
)
if field_errors:
raise ValidationError([], field_errors)
if has_discriminator:
data = data.copy()
del data[discriminator]
return self.constructor.construct(data)


Expand Down Expand Up @@ -552,7 +566,6 @@ class ObjectMethod(DeserializationMethod):
aliaser: Aliaser
missing: str
unexpected: str
discriminator: Optional[str]
aggregate_fields: bool = field(init=False)

def __post_init__(self):
Expand All @@ -563,8 +576,15 @@ def __post_init__(self):
)

def deserialize(self, data: Any) -> Any:
discriminator: Optional[str] = None
if not isinstance(data, dict):
raise bad_type(data, dict)
if isinstance(data, Discriminated):
discriminator = data.discriminator
data = data.data
if not isinstance(data, dict):
raise bad_type(data, dict)
else:
raise bad_type(data, dict)
values: dict = {}
fields_count: int = 0
errors: Optional[list] = None
Expand Down Expand Up @@ -640,7 +660,7 @@ def deserialize(self, data: Any) -> Any:
elif remain:
if not self.additional_properties:
for key in remain:
if key != self.discriminator:
if key != discriminator:
field_errors = set_child_error(
field_errors, key, ValidationError(self.unexpected)
)
Expand All @@ -650,7 +670,7 @@ def deserialize(self, data: Any) -> Any:
elif len(data) != fields_count:
if not self.additional_properties:
for key in data.keys() - self.all_aliases:
if key != self.discriminator:
if key != discriminator:
field_errors = set_child_error(
field_errors, key, ValidationError(self.unexpected)
)
Expand Down Expand Up @@ -895,6 +915,12 @@ def deserialize(self, data: Any) -> Any:
raise error


@dataclass
class Discriminated:
discriminator: str
data: Any


@dataclass
class DiscriminatorMethod(DeserializationMethod):
alias: str
Expand All @@ -919,4 +945,4 @@ def deserialize(self, data: Any):
},
)
else:
return method.deserialize(data)
return method.deserialize(Discriminated(self.alias, data))

0 comments on commit c4755d5

Please sign in to comment.