Skip to content

Commit

Permalink
New union unpack algorithm
Browse files Browse the repository at this point in the history
  • Loading branch information
Fatal1ty committed Mar 31, 2024
1 parent 4eefc4f commit 98ece87
Show file tree
Hide file tree
Showing 6 changed files with 95 additions and 27 deletions.
6 changes: 1 addition & 5 deletions benchmark/libs/mashumaro/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,13 @@
import pyperf

from benchmark.common import AbstractBenchmark
from mashumaro import field_options, pass_through
from mashumaro import field_options
from mashumaro.codecs import BasicDecoder, BasicEncoder
from mashumaro.dialect import Dialect


class DefaultDialect(Dialect):
serialize_by_alias = True
serialization_strategy = {
str: {"deserialize": str, "serialize": pass_through},
int: {"serialize": pass_through},
}


class IssueState(Enum):
Expand Down
16 changes: 10 additions & 6 deletions mashumaro/core/meta/code/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
NoneType,
ValueSpec,
clean_id,
expr_can_fail,
)
from mashumaro.core.meta.types.pack import PackerRegistry
from mashumaro.core.meta.types.unpack import (
Expand Down Expand Up @@ -1274,13 +1275,16 @@ def _try_set_value(
unpacked_value: str,
in_kwargs: bool,
) -> None:
with self.lines.indent("try:"):
if expr_can_fail(unpacked_value, "value") or False:
with self.lines.indent("try:"):
self._set_value(field_name, unpacked_value, in_kwargs)
with self.lines.indent("except:"):
self.lines.append(
"raise InvalidFieldValue("
f"'{field_name}',{field_type_name},value,cls)"
)
else:
self._set_value(field_name, unpacked_value, in_kwargs)
with self.lines.indent("except:"):
self.lines.append(
"raise InvalidFieldValue("
f"'{field_name}',{field_type_name},value,cls)"
)

def _set_value(
self, fname: str, unpacked_value: str, in_kwargs: bool = False
Expand Down
3 changes: 3 additions & 0 deletions mashumaro/core/meta/code/lines.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ def indent(
def as_text(self) -> str:
return "\n".join(self._lines)

def len(self) -> int:
return len(self._lines)

def reset(self) -> None:
self._lines = []
self._current_indent = ""
Expand Down
14 changes: 13 additions & 1 deletion mashumaro/core/meta/types/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,15 @@ def build(self, spec: ValueSpec) -> str:
lines = CodeLines()
method_name = self._add_definition(spec, lines)
with lines.indent():
self._add_body(spec, lines)
body_lines = lines.branch_off()
self._add_body(spec, body_lines)
if body_lines.len() == 1:
body_line = body_lines.as_text().strip()
if body_line.startswith("return ") and not expr_can_fail(
body_line[7:], spec.expression
):
return spec.expression
lines.extend(body_lines)
self._add_setattr(spec, method_name, lines)
self._compile(spec, lines)
return self._get_call_expr(spec, method_name)
Expand Down Expand Up @@ -308,3 +316,7 @@ def clean_id(value: str) -> str:
return "_"

return _PY_VALID_ID_RE.sub("_", value)


def expr_can_fail(expr: str, value: str) -> bool:
return expr not in (value, f"str({value})", f"bool({value})")
49 changes: 37 additions & 12 deletions mashumaro/core/meta/types/unpack.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@
ensure_generic_mapping,
expr_or_maybe_none,
random_hex,
expr_can_fail,
)
from mashumaro.exceptions import (
ThirdPartyModuleNotFoundError,
Expand Down Expand Up @@ -166,13 +167,27 @@ def get_method_prefix(self) -> str:
return "union"

def _add_body(self, spec: ValueSpec, lines: CodeLines) -> None:
for unpacker in (
UnpackerRegistry.get(spec.copy(type=type_arg, expression="value"))
for type_arg in self.union_args
):
complex_unpackers = []
simple_unpackers = []
for type_arg in self.union_args:
unpacker = UnpackerRegistry.get(
spec.copy(type=type_arg, expression="value")
)
if expr_can_fail(unpacker, spec.expression):
if unpacker not in complex_unpackers:
complex_unpackers.append(unpacker)
elif unpacker not in simple_unpackers:
simple_unpackers.append(unpacker)

for unpacker in complex_unpackers:
with lines.indent("try:"):
lines.append(f"return {unpacker}")
lines.append("except Exception: pass")
with lines.indent("except Exception:"):
lines.append("pass")
if simple_unpackers:
lines.append(f"return {simple_unpackers[0]}")
return

field_type = spec.builder.get_type_name_identifier(
typ=spec.type,
resolved_type_params=spec.builder.get_field_resolved_type_params(
Expand Down Expand Up @@ -220,7 +235,8 @@ def _add_body(self, spec: ValueSpec, lines: CodeLines) -> None:
with lines.indent("try:"):
with lines.indent(f"if {unpacker} == {literal_value!r}:"):
lines.append(f"return {literal_value!r}")
lines.append("except Exception: pass")
with lines.indent("except Exception:"):
lines.append("pass")
elif isinstance(
literal_value,
(int, str, bool, NoneType), # type: ignore
Expand Down Expand Up @@ -404,7 +420,8 @@ def _add_body(self, spec: ValueSpec, lines: CodeLines) -> None:
self._add_build_variant_unpacker(
spec, lines, variant_method_name, variant_method_call
)
lines.append("except Exception: pass")
with lines.indent("except Exception:"):
lines.append("pass")
lines.append(
f"raise SuitableVariantNotFoundError({variants_type_expr}) "
"from None"
Expand Down Expand Up @@ -449,7 +466,8 @@ def _add_build_variant_unpacker(
if not self.discriminator.field:
with lines.indent("try:"):
lines.append(f"return variant.{variant_method_call}")
lines.append("except Exception: pass")
with lines.indent("except Exception:"):
lines.append("pass")
else:
spec.builder.ensure_object_imported(AttrsHolder)
attrs = f"attrs_{random_hex()}"
Expand All @@ -467,7 +485,8 @@ def _add_build_variant_unpacker(
if not self.discriminator.field:
with lines.indent("try:"):
lines.append(f"return {attrs}.{variant_method_call}")
lines.append("except Exception: pass")
with lines.indent("except Exception:"):
lines.append("pass")

def _add_register_variant_tags(
self, lines: CodeLines, variant_tagger_expr: str
Expand Down Expand Up @@ -815,8 +834,14 @@ def unpack_number(spec: ValueSpec) -> Optional[Expression]:


@register
def unpack_bool_and_none(spec: ValueSpec) -> Optional[Expression]:
if spec.origin_type in (bool, NoneType, None):
def unpack_bool(spec: ValueSpec) -> Optional[Expression]:
if spec.origin_type is bool:
return f"bool({spec.expression})"


@register
def unpack_none(spec: ValueSpec) -> Optional[Expression]:
if spec.origin_type in (NoneType, None):
return spec.expression


Expand Down Expand Up @@ -1167,7 +1192,7 @@ def inner_expr(
spec.builder.ensure_object_imported(decodebytes)
return f"bytearray(decodebytes({spec.expression}.encode()))"
elif issubclass(spec.origin_type, str):
return spec.expression
return f"str({spec.expression})"
elif ensure_generic_collection_subclass(spec, List):
return f"[{inner_expr()} for value in {spec.expression}]"
elif ensure_generic_collection_subclass(spec, typing.Deque):
Expand Down
34 changes: 31 additions & 3 deletions tests/test_union.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,17 +27,45 @@ class UnionTestCase:
UnionTestCase(Union[str, List[str]], "abc", "abc"),
],
)
def test_union(test_case):
def test_union_encoding(test_case):
@dataclass
class DataClass(DataClassDictMixin):
x: test_case.type

instance = DataClass(x=test_case.loaded)
assert DataClass.from_dict({"x": test_case.dumped}) == instance
assert instance.to_dict() == {"x": test_case.dumped}


def test_union_encoding():
@pytest.mark.parametrize(
"test_case",
[
UnionTestCase(Union[str, int, bool, float], True, 1),
UnionTestCase(Union[int, str], 1, 1),
UnionTestCase(Union[str, int], 1, 1),
UnionTestCase(Union[int, str], "a", "a"),
UnionTestCase(Union[str, int], "a", "a"),
UnionTestCase(Union[Dict[int, int], List[int]], {1: 2}, {1: 2}),
UnionTestCase(Union[List[int], Dict[int, int]], {1: 2}, [1]),
UnionTestCase(Union[Dict[int, int], List[int]], [1], [1]),
UnionTestCase(Union[List[int], Dict[int, int]], [1], [1]),
UnionTestCase(Union[List[int], str], [1], [1]),
UnionTestCase(Union[str, List[int]], [1], [1]),
UnionTestCase(Union[str, List[str]], "abc", ["a", "b", "c"]),
UnionTestCase(Union[List[str], str], "abc", ["a", "b", "c"]),
],
)
def test_union_decoding(test_case):
@dataclass
class DataClass(DataClassDictMixin):
x: test_case.type

instance = DataClass(x=test_case.loaded)
assert same_types(
DataClass.from_dict({"x": test_case.dumped}).x, instance.x
)


def test_basic_types_permutations_union_encoding():
for variants in permutations((int, float, str, bool)):
for value in (1, 2.0, 3.1, "4", "5.0", True, False):
encoded = encode(value, Union[variants])
Expand Down

0 comments on commit 98ece87

Please sign in to comment.