Skip to content

Commit

Permalink
Propagate exclude to marshmallow
Browse files Browse the repository at this point in the history
* Propagate exclude function as metadata on schema field
* Add a post-dump hook to exclude fields as indicated

Signed-off-by: Aidan Jensen <[email protected]>
  • Loading branch information
artificial-aidan committed Aug 7, 2024
1 parent dc63902 commit a071506
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 1 deletion.
29 changes: 28 additions & 1 deletion dataclasses_json/mm.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

from typing_inspect import is_union_type # type: ignore

from marshmallow import fields, Schema, post_load # type: ignore
from marshmallow import fields, Schema, post_dump, post_load # type: ignore
from marshmallow.exceptions import ValidationError # type: ignore

from dataclasses_json.core import (_is_supported_generic, _decode_dataclass,
Expand Down Expand Up @@ -326,6 +326,11 @@ def schema(cls, mixin, infer_missing):
if metadata.letter_case is not None:
options['data_key'] = metadata.letter_case(field.name)

if metadata.exclude:
options.setdefault("metadata", {}).setdefault("dataclasses_json", {})[
"exclude"
] = metadata.exclude

t = build_type(type_, options, mixin, field, cls)
if field.metadata.get('dataclasses_json', {}).get('decoder'):
# If the field defines a custom decoder, it should completely replace the Marshmallow field's conversion
Expand Down Expand Up @@ -368,6 +373,27 @@ def dumps(self, *args, **kwargs):

return Schema.dumps(self, *args, **kwargs)

@post_dump(pass_many=True)
def _exclude_fields(self, data, many, **kwargs):
def exclude(k, v) -> bool:
if f := self.fields.get(k):
if f.metadata.get("dataclasses_json", {}).get(
"exclude", lambda x: False
)(v):
return True
return False

if many:
for i, _data in enumerate(data):
for k in list(_data.keys()):
if exclude(k, _data[k]):
del _data[k]
else:
for k in list(data.keys()):
if exclude(k, data[k]):
del data[k]
return data

def dump(self, obj, *, many=None):
many = self.many if many is None else bool(many)
dumped = Schema.dump(self, obj, many=many)
Expand All @@ -392,6 +418,7 @@ def dump(self, obj, *, many=None):
(Schema,),
{'Meta': Meta,
f'make_{cls.__name__.lower()}': make_instance,
f"_exclude_fields": _exclude_fields,
'dumps': dumps,
'dump': dump,
**schema_})
Expand Down
19 changes: 19 additions & 0 deletions tests/test_exclude.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,22 @@ def test_custom_action_excluded():
dclass = EncodeCustom(public_field="public", sensitive_field="secret")
encoded = dclass.to_dict()
assert "sensitive_field" not in encoded


def test_marshmallow_exclude_dump():
dclass = EncodeExclude(public_field="public", private_field="private")
encoded = EncodeExclude.schema().dump(dclass)
assert "public_field" in encoded
assert "private_field" not in encoded


def test_marshmallow_exclude_dumps():
dclass = EncodeExclude(public_field="public", private_field="private")
encoded = EncodeExclude.schema().dumps(dclass)
assert "public_field" in encoded
assert "private_field" not in encoded

def test_schema():
dclass = EncodeExclude(public_field="public", private_field="private")
schema = EncodeExclude.schema()
print(schema)

0 comments on commit a071506

Please sign in to comment.