Skip to content

Commit

Permalink
Fix decoding of nested collections
Browse files Browse the repository at this point in the history
  • Loading branch information
dax committed Apr 27, 2022
1 parent 3264a00 commit ac614fc
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 12 deletions.
28 changes: 17 additions & 11 deletions dataclasses_json/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,10 +252,10 @@ def _decode_generic(type_, value, infer_missing):
# a mapping type has `.keys()` and `.values()`
# (see collections.abc)
ks = _decode_dict_keys(k_type, value.keys(), infer_missing)
vs = _decode_items(v_type, value.values(), infer_missing)
vs = _decode_items([v_type], value.values(), infer_missing)
xs = zip(ks, vs)
else:
xs = _decode_items(type_.__args__[0], value, infer_missing)
xs = _decode_items(type_.__args__, value, infer_missing)

# get the constructor if using corresponding generic type in `typing`
# otherwise fallback on constructing using type_ itself
Expand Down Expand Up @@ -300,10 +300,10 @@ def _decode_dict_keys(key_type, xs, infer_missing):
decode_function = tuple
key_type = key_type

return map(decode_function, _decode_items(key_type, xs, infer_missing))
return map(decode_function, _decode_items([key_type], xs, infer_missing))


def _decode_items(type_arg, xs, infer_missing):
def _decode_items(types_arg, xs, infer_missing):
"""
This is a tricky situation where we need to check both the annotated
type info (which is usually a type from `typing`) and check the
Expand All @@ -312,14 +312,20 @@ def _decode_items(type_arg, xs, infer_missing):
If the type_arg is a generic we can use the annotated type, but if the
type_arg is a typevar we need to extract the reified type information
hence the check of `is_dataclass(vs)`
length of types_arg may be > 1 for tuples so that we have to iterate
while iterating on items
"""
if is_dataclass(type_arg) or is_dataclass(xs):
items = (_decode_dataclass(type_arg, x, infer_missing)
for x in xs)
elif _is_supported_generic(type_arg):
items = (_decode_generic(type_arg, x, infer_missing) for x in xs)
else:
items = xs
items = []
for i, x in enumerate(xs):
type_arg = types_arg[i] if len(types_arg) > i else types_arg[0]
if is_dataclass(type_arg) or is_dataclass(xs):
item = _decode_dataclass(type_arg, x, infer_missing)
elif _is_supported_generic(type_arg):
item = _decode_generic(type_arg, x, infer_missing)
else:
item = _support_extended_types(type_arg, x)
items.append(item)
return items


Expand Down
11 changes: 11 additions & 0 deletions tests/entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,3 +273,14 @@ class DataClassWithNestedOptional:
@dataclass
class DataClassWithNestedDictWithTupleKeys:
a: Dict[Tuple[int], int]


@dataclass_json
@dataclass
class DataClassWithListUuid:
xs: List[UUID]

@dataclass_json
@dataclass
class DataClassWithListNestedTupleUuid:
xs: List[Tuple[UUID, int]]
24 changes: 23 additions & 1 deletion tests/test_collections.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from uuid import uuid4
from collections import deque

from tests.entities import (DataClassIntImmutableDefault,
Expand All @@ -8,7 +9,8 @@
DataClassWithListStr, DataClassWithMyCollection,
DataClassWithOptional, DataClassWithOptionalStr,
DataClassWithSet, DataClassWithTuple,
DataClassWithUnionIntNone, MyCollection)
DataClassWithUnionIntNone, MyCollection,
DataClassWithListUuid, DataClassWithListNestedTupleUuid)


class TestEncoder:
Expand Down Expand Up @@ -62,6 +64,16 @@ def test_mutable_default_list(self):
def test_mutable_default_dict(self):
assert DataClassMutableDefaultDict().to_json() == '{"xs": {}}'

def test_list_uuid(self):
uuid = uuid4()
assert (DataClassWithListUuid([uuid]).to_json()
== '{"xs": ["' + str(uuid) + '"]}')

def test_list_nested_tuple_uuid(self):
uuid = uuid4()
assert (DataClassWithListNestedTupleUuid([(uuid, 1)]).to_json()
== '{"xs": [["' + str(uuid) + '", 1]]}')


class TestDecoder:
def test_list(self):
Expand Down Expand Up @@ -131,3 +143,13 @@ def test_mutable_default_dict(self):
== DataClassMutableDefaultDict())
assert (DataClassMutableDefaultDict.from_json('{}', infer_missing=True)
== DataClassMutableDefaultDict())

def test_list_uuid(self):
uuid = uuid4()
assert (DataClassWithListUuid.from_json('{"xs": ["' + str(uuid) + '"]}')
== DataClassWithListUuid([uuid]))

def test_list_nested_tuple_uuid(self):
uuid = uuid4()
assert (DataClassWithListNestedTupleUuid.from_json('{"xs": [["' + str(uuid) + '", 1]]}')
== DataClassWithListNestedTupleUuid([(uuid, 1)]))

0 comments on commit ac614fc

Please sign in to comment.