diff --git a/HISTORY.md b/HISTORY.md index 0ef6f853..ed16136d 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -19,6 +19,8 @@ Our backwards-compatibility policy can be found [here](https://github.com/python ([#432](https://github.com/python-attrs/cattrs/issues/432) [#472](https://github.com/python-attrs/cattrs/pull/472)) - The default union handler now properly takes renamed fields into account. ([#472](https://github.com/python-attrs/cattrs/pull/472)) +- The default union handler now also handles dataclasses. + ([#](https://github.com/python-attrs/cattrs/pull/)) - Add support for [PEP 695](https://peps.python.org/pep-0695/) type aliases. ([#452](https://github.com/python-attrs/cattrs/pull/452)) - The `include_subclasses` strategy now fetches the member hooks from the converter (making use of converter defaults) if overrides are not provided, instead of generating new hooks with no overrides. diff --git a/docs/unions.md b/docs/unions.md index c564b9ec..c4c43d65 100644 --- a/docs/unions.md +++ b/docs/unions.md @@ -1,9 +1,9 @@ # Handling Unions -_cattrs_ is able to handle simple unions of _attrs_ classes [automatically](#default-union-strategy). +_cattrs_ is able to handle simple unions of _attrs_ classes and dataclasses [automatically](#default-union-strategy). More complex cases require converter customization (since there are many ways of handling unions). -_cattrs_ also comes with a number of strategies to help handle unions: +_cattrs_ also comes with a number of optional strategies to help handle unions: - [tagged unions strategy](strategies.md#tagged-unions-strategy) mentioned below - [union passthrough strategy](strategies.md#union-passthrough), which is preapplied to all the [preconfigured](preconf.md) converters @@ -12,10 +12,10 @@ _cattrs_ also comes with a number of strategies to help handle unions: For convenience, _cattrs_ includes a default union structuring strategy which is a little more opinionated. -Given a union of several _attrs_ classes, the default union strategy will attempt to handle it in several ways. +Given a union of several _attrs_ classes and/or dataclasses, the default union strategy will attempt to handle it in several ways. First, it will look for `Literal` fields. -If all members of the union contain a literal field, _cattrs_ will generate a disambiguation function based on the field. +If _all members_ of the union contain a literal field, _cattrs_ will generate a disambiguation function based on the field. ```python from typing import Literal @@ -68,6 +68,10 @@ The field `field_with_default` will not be considered since it has a default val Literals can now be potentially used to disambiguate. ``` +```{versionchanged} 24.1.0 +Dataclasses are now supported in addition to _attrs_ classes. +``` + ## Unstructuring Unions with Extra Metadata ```{note} diff --git a/src/cattrs/_compat.py b/src/cattrs/_compat.py index 5a3118ff..ee042c86 100644 --- a/src/cattrs/_compat.py +++ b/src/cattrs/_compat.py @@ -2,7 +2,7 @@ from collections import deque from collections.abc import MutableSet as AbcMutableSet from collections.abc import Set as AbcSet -from dataclasses import MISSING, is_dataclass +from dataclasses import MISSING, Field, is_dataclass from dataclasses import fields as dataclass_fields from typing import AbstractSet as TypingAbstractSet from typing import ( @@ -18,6 +18,7 @@ Protocol, Tuple, Type, + Union, get_args, get_origin, get_type_hints, @@ -31,9 +32,11 @@ from attrs import NOTHING, Attribute, Factory, resolve_types from attrs import fields as attrs_fields +from attrs import fields_dict as attrs_fields_dict __all__ = [ "adapted_fields", + "fields_dict", "ExceptionGroup", "ExtensionsTypedDict", "get_type_alias_base", @@ -119,6 +122,13 @@ def fields(type): raise Exception("Not an attrs or dataclass class.") from None +def fields_dict(type) -> Dict[str, Union[Attribute, Field]]: + """Return the fields_dict for attrs and dataclasses.""" + if is_dataclass(type): + return {f.name: f for f in dataclass_fields(type)} + return attrs_fields_dict(type) + + def adapted_fields(cl) -> List[Attribute]: """Return the attrs format of `fields()` for attrs and dataclasses.""" if is_dataclass(cl): diff --git a/src/cattrs/disambiguators.py b/src/cattrs/disambiguators.py index ad145f65..3a1e4391 100644 --- a/src/cattrs/disambiguators.py +++ b/src/cattrs/disambiguators.py @@ -2,13 +2,23 @@ from __future__ import annotations from collections import defaultdict +from dataclasses import MISSING from functools import reduce from operator import or_ from typing import TYPE_CHECKING, Any, Callable, Literal, Mapping, Union -from attrs import NOTHING, Attribute, AttrsInstance, fields, fields_dict - -from ._compat import NoneType, get_args, get_origin, has, is_literal, is_union_type +from attrs import NOTHING, Attribute, AttrsInstance + +from ._compat import ( + NoneType, + adapted_fields, + fields_dict, + get_args, + get_origin, + has, + is_literal, + is_union_type, +) from .gen import AttributeOverride if TYPE_CHECKING: @@ -31,13 +41,16 @@ def create_default_dis_func( overrides: dict[str, AttributeOverride] | Literal["from_converter"] = "from_converter", ) -> Callable[[Mapping[Any, Any]], type[Any] | None]: - """Given attrs classes, generate a disambiguation function. + """Given attrs classes or dataclasses, generate a disambiguation function. The function is based on unique fields without defaults or unique values. :param use_literals: Whether to try using fields annotated as literals for disambiguation. :param overrides: Attribute overrides to apply. + + .. versionchanged:: 24.1.0 + Dataclasses are now supported. """ if len(classes) < 2: raise ValueError("At least two classes required.") @@ -55,7 +68,11 @@ def create_default_dis_func( # (... TODO: a single fallback is OK) # - it must always be enumerated cls_candidates = [ - {at.name for at in fields(get_origin(cl) or cl) if is_literal(at.type)} + { + at.name + for at in adapted_fields(get_origin(cl) or cl) + if is_literal(at.type) + } for cl in classes ] @@ -128,10 +145,10 @@ def dis_func(data: Mapping[Any, Any]) -> type | None: uniq = cl_reqs - other_reqs # We want a unique attribute with no default. - cl_fields = fields(get_origin(cl) or cl) + cl_fields = fields_dict(get_origin(cl) or cl) for maybe_renamed_attr_name in uniq: orig_name = back_map[maybe_renamed_attr_name] - if getattr(cl_fields, orig_name).default is NOTHING: + if cl_fields[orig_name].default in (NOTHING, MISSING): break else: if fallback is None: @@ -173,13 +190,13 @@ def _overriden_name(at: Attribute, override: AttributeOverride | None) -> str: def _usable_attribute_names( - cl: type[AttrsInstance], overrides: dict[str, AttributeOverride] + cl: type[Any], overrides: dict[str, AttributeOverride] ) -> tuple[set[str], dict[str, str]]: """Return renamed fields and a mapping to original field names.""" res = set() mapping = {} - for at in fields(get_origin(cl) or cl): + for at in adapted_fields(get_origin(cl) or cl): res.add(n := _overriden_name(at, overrides.get(at.name))) mapping[n] = at.name diff --git a/tests/test_disambiguators.py b/tests/test_disambiguators.py index 508586cf..d9fc8d72 100644 --- a/tests/test_disambiguators.py +++ b/tests/test_disambiguators.py @@ -1,4 +1,5 @@ """Tests for auto-disambiguators.""" +from dataclasses import dataclass from functools import partial from typing import Literal, Union @@ -7,11 +8,7 @@ from hypothesis import HealthCheck, assume, given, settings from cattrs import Converter -from cattrs.disambiguators import ( - create_default_dis_func, - create_uniq_field_dis_func, - is_supported_union, -) +from cattrs.disambiguators import create_default_dis_func, is_supported_union from cattrs.gen import make_dict_structure_fn, override from .untyped import simple_classes @@ -27,7 +24,7 @@ class A: with pytest.raises(ValueError): # Can't generate for only one class. - create_uniq_field_dis_func(c, A) + create_default_dis_func(c, A) with pytest.raises(ValueError): create_default_dis_func(c, A) @@ -38,7 +35,7 @@ class B: with pytest.raises(TypeError): # No fields on either class. - create_uniq_field_dis_func(c, A, B) + create_default_dis_func(c, A, B) @define class C: @@ -50,7 +47,7 @@ class D: with pytest.raises(TypeError): # No unique fields on either class. - create_uniq_field_dis_func(c, C, D) + create_default_dis_func(c, C, D) with pytest.raises(TypeError): # No discriminator candidates @@ -66,7 +63,7 @@ class F: with pytest.raises(TypeError): # no usable non-default attributes - create_uniq_field_dis_func(c, E, F) + create_default_dis_func(c, E, F) @define class G: @@ -93,7 +90,7 @@ def test_fallback(cl_and_vals): class A: pass - fn = create_uniq_field_dis_func(c, A, cl) + fn = create_default_dis_func(c, A, cl) assert fn({}) is A assert fn(asdict(cl(*vals, **kwargs))) is cl @@ -124,7 +121,7 @@ def test_disambiguation(cl_and_vals_a, cl_and_vals_b): for attr_name in req_b - req_a: assume(getattr(fields(cl_b), attr_name).default is NOTHING) - fn = create_uniq_field_dis_func(c, cl_a, cl_b) + fn = create_default_dis_func(c, cl_a, cl_b) assert fn(asdict(cl_a(*vals_a, **kwargs_a))) is cl_a @@ -271,3 +268,33 @@ class B: assert converter.structure({"a": 1}, Union[A, B]) == A(1) assert converter.structure({"b": 1}, Union[A, B]) == B(1) + + +def test_dataclasses(converter): + """The default strategy works for dataclasses too.""" + + @define + class A: + a: int + + @dataclass + class B: + b: int + + assert converter.structure({"a": 1}, Union[A, B]) == A(1) + assert converter.structure({"b": 1}, Union[A, B]) == B(1) + + +def test_dataclasses_literals(converter): + """The default strategy works for dataclasses too.""" + + @define + class A: + a: Literal["a"] = "a" + + @dataclass + class B: + b: Literal["b"] + + assert converter.structure({"a": "a"}, Union[A, B]) == A() + assert converter.structure({"b": "b"}, Union[A, B]) == B("b")