diff --git a/README.md b/README.md index 26b51310..009f2dda 100644 --- a/README.md +++ b/README.md @@ -64,6 +64,7 @@ Table of contents * [`dialect` config option](#dialect-config-option) * [`orjson_options`](#orjson_options-config-option) * [`discriminator` config option](#discriminator-config-option) + * [`lazy_compilation` config option](#lazy_compilation-config-option) * [Passing field values as is](#passing-field-values-as-is) * [Extending existing types](#extending-existing-types) * [Dialects](#dialects) @@ -297,6 +298,8 @@ field types on every call of parsing or building at runtime. These specific parsers and builders are presented by the corresponding `from_*` and `to_*` methods. They are compiled during import time (or at runtime in some cases) and are set as attributes to your dataclasses. +To minimize the import time, you can explicitly enable +[lazy compilation](#lazy_compilation-config-option). Benchmark -------------------------------------------------------------------------------- @@ -1344,6 +1347,23 @@ assert MyClass({1: 2}).to_json() == {"1": 2} This option is described in the [Class level discriminator](#class-level-discriminator) section. +#### `lazy_compilation` config option + +By using this option, the compilation of the `from_*` and `to_*` methods will +be deferred until they are called for the first time. This will reduce the import time +and, in certain instances, may enhance the speed of deserialization +by leveraging the data that is accessible after the class has been created. + +> **Warning** +> +> If you need to save a reference to a `from_*` or `to_*` method, you should +> do it after the method is compiled. 
To be safe, you can always use lambda +> function: +> ```python +> from_dict = lambda x: MyModel.from_dict(x) +> to_dict = lambda x: x.to_dict() +> ``` + ### Passing field values as is In some cases it's needed to pass a field value as is without any changes diff --git a/mashumaro/config.py b/mashumaro/config.py index 3555f687..63272497 100644 --- a/mashumaro/config.py +++ b/mashumaro/config.py @@ -49,3 +49,4 @@ class BaseConfig: orjson_options: Optional[int] = 0 json_schema: Dict[str, Any] = {} discriminator: Optional[Discriminator] = None + lazy_compilation: bool = False diff --git a/mashumaro/core/meta/code/builder.py b/mashumaro/core/meta/code/builder.py index d726f786..e2a571dd 100644 --- a/mashumaro/core/meta/code/builder.py +++ b/mashumaro/core/meta/code/builder.py @@ -5,9 +5,15 @@ from contextlib import contextmanager # noinspection PyProtectedMember -from dataclasses import _FIELDS, MISSING, Field, is_dataclass # type: ignore +from dataclasses import _FIELDS # type: ignore +from dataclasses import MISSING, Field, is_dataclass from functools import lru_cache +try: + from dataclasses import KW_ONLY # type: ignore +except ImportError: + KW_ONLY = object() # type: ignore + import typing_extensions from mashumaro.config import ( @@ -32,7 +38,6 @@ is_dialect_subclass, is_init_var, is_literal, - is_named_tuple, is_optional, is_type_var_any, resolve_type_params, @@ -131,7 +136,7 @@ def __get_field_types( name = get_name_error_name(e) raise UnresolvedTypeReferenceError(self.cls, name) from None for fname, ftype in field_type_hints.items(): - if is_class_var(ftype) or is_init_var(ftype): + if is_class_var(ftype) or is_init_var(ftype) or ftype is KW_ONLY: continue if recursive or fname in self.annotations: fields[fname] = ftype @@ -261,8 +266,31 @@ def get_declared_hook(self, method_name: str) -> typing.Any: if cls is not None and not is_dataclass_dict_mixin(cls): return cls.__dict__[method_name] + def _add_unpack_method_lines_lazy(self, method_name: str) -> None: + 
if self.default_dialect is not None: + self.add_type_modules(self.default_dialect) + self.add_line( + f"CodeBuilder(" + f"cls," + f"first_method='{method_name}'," + f"allow_postponed_evaluation=False," + f"format_name='{self.format_name}'," + f"decoder={type_name(self.decoder)}," # type: ignore + f"default_dialect={type_name(self.default_dialect)}" + f").add_unpack_method()" + ) + unpacker_args = [ + "d", + self.get_unpack_method_flags(pass_decoder=True), + ] + unpacker_args_s = ", ".join(filter(None, unpacker_args)) + self.add_line(f"return cls.{method_name}({unpacker_args_s})") + def _add_unpack_method_lines(self, method_name: str) -> None: config = self.get_config() + if config.lazy_compilation and self.allow_postponed_evaluation: + self._add_unpack_method_lines_lazy(method_name) + return try: field_types = self.get_field_types(include_extras=True) except UnresolvedTypeReferenceError: @@ -271,24 +299,7 @@ def _add_unpack_method_lines(self, method_name: str) -> None: or not config.allow_postponed_evaluation ): raise - if self.default_dialect is not None: - self.add_type_modules(self.default_dialect) - self.add_line( - f"CodeBuilder(" - f"cls," - f"first_method='{method_name}'," - f"allow_postponed_evaluation=False," - f"format_name='{self.format_name}'," - f"decoder={type_name(self.decoder)}," # type: ignore - f"default_dialect={type_name(self.default_dialect)}" - f").add_unpack_method()" - ) - unpacker_args = [ - "d", - self.get_unpack_method_flags(pass_decoder=True), - ] - unpacker_args_s = ", ".join(filter(None, unpacker_args)) - self.add_line(f"return cls.{method_name}({unpacker_args_s})") + self._add_unpack_method_lines_lazy(method_name) else: if self.decoder is not None: self.add_line("d = decoder(d)") @@ -324,16 +335,36 @@ def _add_unpack_method_lines(self, method_name: str) -> None: ) else: self.add_line(f"d = cls.{__PRE_DESERIALIZE__}(d)") + post_deserialize = self.get_declared_hook(__POST_DESERIALIZE__) + if post_deserialize: + if not 
isinstance(post_deserialize, classmethod): + raise BadHookSignature( + f"`{__POST_DESERIALIZE__}` must be a class method " + f"with Callable[[{type_name(self.cls)}], " + f"{type_name(self.cls)}] signature" + ) filtered_fields = [] + kwargs_only = post_deserialize is not None + pos_args = [] + kw_args = [] + can_be_kwargs = False for fname, ftype in field_types.items(): field = self.dataclass_fields.get(fname) # type: ignore # https://github.com/python/mypy/issues/1362 if field and not field.init: continue + if self.get_field_default(fname) is MISSING: + if field and not getattr(field, "kw_only", True): + pos_args.append(fname) + else: + kw_args.append(fname) + else: + can_be_kwargs = True filtered_fields.append((fname, ftype)) if filtered_fields: with self.indent("try:"): - self.add_line("kwargs = {}") + if kwargs_only or can_be_kwargs: + self.add_line("kwargs = {}") for fname, ftype in filtered_fields: self.add_type_modules(ftype) metadata = self.metadatas.get(fname, {}) @@ -341,7 +372,11 @@ def _add_unpack_method_lines(self, method_name: str) -> None: if alias is None: alias = config.aliases.get(fname) self._unpack_method_set_value( - fname, ftype, metadata, alias + fname, + ftype, + metadata, + alias=alias, + kwargs_only=kwargs_only, ) with self.indent("except TypeError:"): with self.indent("if not isinstance(d, dict):"): @@ -354,20 +389,17 @@ def _add_unpack_method_lines(self, method_name: str) -> None: self.add_line("raise") else: self.add_line("kwargs = {}") - post_deserialize = self.get_declared_hook(__POST_DESERIALIZE__) if post_deserialize: - if not isinstance(post_deserialize, classmethod): - raise BadHookSignature( - f"`{__POST_DESERIALIZE__}` must be a class method " - f"with Callable[[{type_name(self.cls)}], " - f"{type_name(self.cls)}] signature" - ) - else: - self.add_line( - f"return cls.{__POST_DESERIALIZE__}(cls(**kwargs))" - ) + self.add_line( + f"return cls.{__POST_DESERIALIZE__}(cls(**kwargs))" + ) else: - self.add_line("return cls(**kwargs)") 
+ args = [f"__{f}" for f in pos_args] + for kw_arg in kw_args: + args.append(f"{kw_arg}=__{kw_arg}") + if can_be_kwargs: + args.append("**kwargs") + self.add_line(f"return cls({', '.join(args)})") def _add_unpack_method_with_dialect_lines(self, method_name: str) -> None: if self.decoder is not None: @@ -439,75 +471,123 @@ def _unpack_method_set_value( fname: str, ftype: typing.Type, metadata: typing.Mapping, + *, alias: typing.Optional[str] = None, + kwargs_only: bool = False, ) -> None: - with self.indent("try:"): - could_be_none = False - if is_named_tuple(ftype): - self.add_line(f"value = d['{alias or fname}']") - packed_value = "value" - else: - packed_value = f"d['{alias or fname}']" - could_be_none = ( - ftype in (typing.Any, type(None), None) - or is_type_var_any(self._get_real_type(fname, ftype)) - or is_optional( - ftype, self.get_field_resolved_type_params(fname) - ) - or self.get_field_default(fname) is None - ) - if could_be_none: - self.add_line(f"value = {packed_value}") - packed_value = "value" - unpacked_value = UnpackerRegistry.get( - ValueSpec( - type=ftype, - expression=packed_value, - builder=self, - field_ctx=FieldContext( - name=fname, - metadata=metadata, - ), - could_be_none=False if could_be_none else True, - ) - ) - if could_be_none: - with self.indent("if value is not None:"): - self.add_line(f"kwargs['{fname}'] = {unpacked_value}") - with self.indent("else:"): - self.add_line(f"kwargs['{fname}'] = None") - else: - self.add_line(f"kwargs['{fname}'] = {unpacked_value}") - with self.indent("except KeyError as e:"): - field_type = type_name( - ftype, - resolved_type_params=self.get_field_resolved_type_params( - fname + default = self.get_field_default(fname) + has_default = default is not MISSING + field_type = type_name( + ftype, + resolved_type_params=self.get_field_resolved_type_params(fname), + ) + could_be_none = ( + ftype in (typing.Any, type(None), None) + or is_type_var_any(self._get_real_type(fname, ftype)) + or is_optional(ftype, 
self.get_field_resolved_type_params(fname)) + or default is None + ) + unpacked_value = UnpackerRegistry.get( + ValueSpec( + type=ftype, + expression="value", + builder=self, + field_ctx=FieldContext( + name=fname, + metadata=metadata, ), + could_be_none=False if could_be_none else True, ) - if self.get_field_default(fname) is MISSING: - with self.indent("if e.__traceback__.tb_next is None:"): - self.add_line( - f"raise MissingField('{fname}',{field_type},cls) " - f"from None" - ) - with self.indent("else:"): - self.add_line( - f"raise InvalidFieldValue(" - f"'{fname}',{field_type},{packed_value},cls)" + ) + if unpacked_value != "value": + self.add_line(f"value = d.get('{alias or fname}', MISSING)") + packed_value = "value" + elif has_default: + self.add_line(f"value = d.get('{alias or fname}', MISSING)") + packed_value = "value" + else: + self.add_line(f"__{fname} = d.get('{alias or fname}', MISSING)") + packed_value = f"__{fname}" + unpacked_value = packed_value + if not has_default: + with self.indent(f"if {packed_value} is MISSING:"): + self.add_line( + f"raise MissingField('{fname}',{field_type},cls) " + f"from None" + ) + if packed_value != unpacked_value: + if could_be_none: + with self.indent(f"if {packed_value} is not None:"): + self.__unpack_try_set_value( + fname, + field_type, + unpacked_value, + kwargs_only, + has_default, + ) + with self.indent("else:"): + self.__unpack_set_value( + fname, "None", kwargs_only or has_default + ) + else: + self.__unpack_try_set_value( + fname, + field_type, + unpacked_value, + kwargs_only, + has_default, ) - else: - with self.indent("if e.__traceback__.tb_next is not None:"): - self.add_line( - f"raise InvalidFieldValue(" - f"'{fname}',{field_type},{packed_value},cls)" + else: + with self.indent(f"if {packed_value} is not MISSING:"): + if could_be_none: + with self.indent(f"if {packed_value} is not None:"): + self.__unpack_try_set_value( + fname, + field_type, + unpacked_value, + kwargs_only, + has_default, + ) + if 
default is not None: + with self.indent("else:"): + self.__unpack_set_value( + fname, "None", kwargs_only or has_default + ) + else: + self.__unpack_try_set_value( + fname, + field_type, + unpacked_value, + kwargs_only, + has_default, ) - with self.indent("except Exception:"): + + def __unpack_try_set_value( + self, + field_name: str, + field_type_name: str, + unpacked_value: str, + kwargs_only: bool, + has_default: bool, + ) -> None: + with self.indent("try:"): + self.__unpack_set_value( + field_name, unpacked_value, kwargs_only or has_default + ) + with self.indent("except:"): self.add_line( f"raise InvalidFieldValue(" - f"'{fname}',{field_type},{packed_value},cls)" + f"'{field_name}',{field_type_name},value,cls)" ) + def __unpack_set_value( + self, fname: str, unpacked_value: str, kwargs_only: bool + ) -> None: + if kwargs_only: + self.add_line(f"kwargs['{fname}'] = {unpacked_value}") + else: + self.add_line(f"__{fname} = {unpacked_value}") + @lru_cache() @typing.no_type_check def get_config( @@ -690,8 +770,28 @@ def get_pack_method_name( method_name += f"_{hash_type_args(type_args)}" return method_name + def _add_pack_method_lines_lazy(self, method_name: str) -> None: + if self.default_dialect is not None: + self.add_type_modules(self.default_dialect) + self.add_line( + f"CodeBuilder(" + f"self.__class__," + f"first_method='{method_name}'," + f"allow_postponed_evaluation=False," + f"format_name='{self.format_name}'," + f"encoder={type_name(self.encoder)}," + f"encoder_kwargs={self._get_encoder_kwargs()}," + f"default_dialect={type_name(self.default_dialect)}" + f").add_pack_method()" + ) + packer_args = self.get_pack_method_flags(pass_encoder=True) + self.add_line(f"return self.{method_name}({packer_args})") + def _add_pack_method_lines(self, method_name: str) -> None: config = self.get_config() + if config.lazy_compilation and self.allow_postponed_evaluation: + self._add_pack_method_lines_lazy(method_name) + return try: field_types = 
self.get_field_types(include_extras=True) except UnresolvedTypeReferenceError: @@ -700,21 +800,7 @@ def _add_pack_method_lines(self, method_name: str) -> None: or not config.allow_postponed_evaluation ): raise - if self.default_dialect is not None: - self.add_type_modules(self.default_dialect) - self.add_line( - f"CodeBuilder(" - f"self.__class__," - f"first_method='{method_name}'," - f"allow_postponed_evaluation=False," - f"format_name='{self.format_name}'," - f"encoder={type_name(self.encoder)}," - f"encoder_kwargs={self._get_encoder_kwargs()}," - f"default_dialect={type_name(self.default_dialect)}" - f").add_pack_method()" - ) - packer_args = self.get_pack_method_flags(pass_encoder=True) - self.add_line(f"return self.{method_name}({packer_args})") + self._add_pack_method_lines_lazy(method_name) else: pre_serialize = self.get_declared_hook(__PRE_SERIALIZE__) if pre_serialize: diff --git a/mashumaro/core/meta/types/pack.py b/mashumaro/core/meta/types/pack.py index 3440fe1c..fdaaa67f 100644 --- a/mashumaro/core/meta/types/pack.py +++ b/mashumaro/core/meta/types/pack.py @@ -661,7 +661,7 @@ def inner_expr( ) ) - if issubclass(spec.origin_type, typing.ByteString): + if issubclass(spec.origin_type, typing.ByteString): # type: ignore spec.builder.ensure_object_imported(encodebytes) return f"encodebytes({spec.expression}).decode()" elif issubclass(spec.origin_type, str): diff --git a/mashumaro/core/meta/types/unpack.py b/mashumaro/core/meta/types/unpack.py index 579bdfb7..674d8a98 100644 --- a/mashumaro/core/meta/types/unpack.py +++ b/mashumaro/core/meta/types/unpack.py @@ -604,7 +604,9 @@ def unpack_dataclass_dict_mixin_subclass( ), ) ) - return f"{type_name(spec.origin_type)}.{method_name}({method_args})" + cls_alias = clean_id(type_name(spec.origin_type)) + spec.builder.ensure_object_imported(spec.origin_type, cls_alias) + return f"{cls_alias}.{method_name}({method_args})" @register @@ -1043,7 +1045,7 @@ def inner_expr( ) ) - if issubclass(spec.origin_type, 
typing.ByteString): + if issubclass(spec.origin_type, typing.ByteString): # type: ignore if spec.origin_type is bytes: spec.builder.ensure_object_imported(decodebytes) return f"decodebytes({spec.expression}.encode())" diff --git a/mashumaro/jsonschema/schema.py b/mashumaro/jsonschema/schema.py index 10b658a1..90fa19f3 100644 --- a/mashumaro/jsonschema/schema.py +++ b/mashumaro/jsonschema/schema.py @@ -657,7 +657,7 @@ def on_collection(instance: Instance, ctx: Context) -> Optional[JSONSchema]: args = get_args(instance.type) - if issubclass(instance.origin_type, typing.ByteString): + if issubclass(instance.origin_type, typing.ByteString): # type: ignore return JSONSchema( type=JSONSchemaInstanceType.STRING, format=JSONSchemaInstanceFormatExtension.BASE64, diff --git a/tests/test_common.py b/tests/test_common.py index 059ed78b..017267c0 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -1,8 +1,11 @@ -from dataclasses import dataclass +import dataclasses +from dataclasses import dataclass, field import msgpack import pytest +from mashumaro.config import BaseConfig +from mashumaro.core.const import PY_310_MIN from mashumaro.mixins.dict import DataClassDictMixin from mashumaro.mixins.json import DataClassJSONMixin from mashumaro.mixins.msgpack import DataClassMessagePackMixin @@ -55,6 +58,63 @@ class EntityBWrapperMessagePack(DataClassMessagePackMixin): entity2wrapper: EntityB2WrapperMessagePack +if PY_310_MIN: + + @dataclass(kw_only=True) + class DataClassKwOnly1(DataClassDictMixin): + x: int + y: int + + @dataclass + class DataClassKwOnly2(DataClassDictMixin): + x: int = field(kw_only=True) + y: int + + @dataclass(kw_only=True) + class DataClassKwOnly3(DataClassDictMixin): + x: int + y: int = field(kw_only=False) + + @dataclass + class DataClassKwOnly4(DataClassDictMixin): + x: int + _: dataclasses.KW_ONLY + y: int + + @dataclass(kw_only=True) + class LazyDataClassKwOnly1(DataClassDictMixin): + x: int + y: int + + class Config(BaseConfig): + 
lazy_compilation = True + + @dataclass + class LazyDataClassKwOnly2(DataClassDictMixin): + x: int = field(kw_only=True) + y: int + + class Config(BaseConfig): + lazy_compilation = True + + @dataclass(kw_only=True) + class LazyDataClassKwOnly3(DataClassDictMixin): + x: int + y: int = field(kw_only=False) + + class Config(BaseConfig): + lazy_compilation = True + + @dataclass + class LazyDataClassKwOnly4(DataClassDictMixin): + x: int + _: dataclasses.KW_ONLY + y: int + + class Config(BaseConfig): + lazy_compilation = True + + def test_slots(): @dataclass class RegularDataClass: @@ -126,3 +186,21 @@ def test_compiled_mixin_with_inheritance_2(): ) assert wrapper.to_msgpack() == data assert EntityBWrapperMessagePack.from_msgpack(data) == wrapper + + +@pytest.mark.skipif(not PY_310_MIN, reason="requires python 3.10+") +def test_kw_only_dataclasses(): + data = {"x": "1", "y": "2"} + for cls in ( + DataClassKwOnly1, + DataClassKwOnly2, + DataClassKwOnly3, + DataClassKwOnly4, + LazyDataClassKwOnly1, + LazyDataClassKwOnly2, + LazyDataClassKwOnly3, + LazyDataClassKwOnly4, + ): + obj = cls.from_dict(data) + assert obj.x == 1 + assert obj.y == 2 diff --git a/tests/test_config.py b/tests/test_config.py index 05e23c90..859d7823 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -19,6 +19,14 @@ ) +@dataclass +class LazyCompilationDataClass(DataClassDictMixin): + x: int + + class Config(BaseConfig): + lazy_compilation = True + + def test_debug_true_option(mocker): mocked_print = mocker.patch("builtins.print") @@ -230,3 +238,9 @@ class Config(BaseConfig): assert DataClass().to_dict() == {} assert DataClass().to_dict(omit_none=True) == {} assert DataClass().to_dict(omit_none=False) == {"x": None} + + +def test_lazy_compilation(): + obj = LazyCompilationDataClass(42) + assert LazyCompilationDataClass.from_dict({"x": "42"}) == obj + assert obj.to_dict() == {"x": 42} diff --git a/tests/test_meta.py b/tests/test_meta.py index 82aa5acd..acc34d08 100644 --- a/tests/test_meta.py 
+++ b/tests/test_meta.py @@ -37,6 +37,7 @@ is_generic, is_init_var, is_literal, + is_named_tuple, is_new_type, is_optional, is_self, @@ -65,8 +66,10 @@ MyGenericList, MyIntEnum, MyIntFlag, + MyNamedTuple, MyNativeStrEnum, MyStrEnum, + MyUntypedNamedTuple, T, TAny, TInt, @@ -510,6 +513,12 @@ def test_not_non_type_arg(): assert not_none_type_arg((NoneType,)) is None +def test_is_named_tuple(): + assert is_named_tuple(MyNamedTuple) + assert is_named_tuple(MyUntypedNamedTuple) + assert not is_named_tuple(object()) + + def test_is_new_type(): assert is_new_type(typing.NewType("MyNewType", int)) assert not is_new_type(int)