Skip to content

Commit

Permalink
Fix Document init when passing non existing fields (#6286)
Browse files Browse the repository at this point in the history
* Fix Document init when passing non existing fields

* Update releasenotes/notes/fix-document-init-09c1cbb14202be7d.yaml

Co-authored-by: Massimiliano Pippi <[email protected]>

* Fix linting

---------

Co-authored-by: Massimiliano Pippi <[email protected]>
  • Loading branch information
silvanocerza and masci authored Nov 13, 2023
1 parent bf637e9 commit 8e7ce20
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 57 deletions.
24 changes: 10 additions & 14 deletions haystack/preview/dataclasses/document.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import hashlib
import logging
from dataclasses import asdict, dataclass, field, fields
from typing import Any, Dict, List, Optional, Type, cast
from typing import Any, Dict, List, Optional

import numpy
import pandas
Expand Down Expand Up @@ -40,18 +40,6 @@ def __call__(cls, *args, **kwargs):
if "id_hash_keys" in kwargs:
del kwargs["id_hash_keys"]

if kwargs.get("meta") is None:
# This must be a flattened Document, so we treat all keys that are not
# Document fields as metadata.
meta = {}
field_names = [f.name for f in fields(cast(Type[Document], cls))]
keys = list(kwargs.keys()) # get a list of the keys as we'll modify the dict in the loop
for key in keys:
if key in field_names:
continue
meta[key] = kwargs.pop(key)
kwargs["meta"] = meta

return super().__call__(*args, **kwargs)


Expand Down Expand Up @@ -149,7 +137,15 @@ def from_dict(cls, data: Dict[str, Any]) -> "Document":
data["dataframe"] = pandas.read_json(io.StringIO(dataframe))
if blob := data.get("blob"):
data["blob"] = ByteStream(data=bytes(blob["data"]), mime_type=blob["mime_type"])
return cls(**data)
# Unflatten metadata if it was flattened
meta = {}
legacy_fields = ["content_type", "id_hash_keys"]
field_names = legacy_fields + [f.name for f in fields(cls)]
for key in list(data.keys()):
if key not in field_names:
meta[key] = data.pop(key)

return cls(**data, meta=meta)

@property
def content_type(self):
Expand Down
4 changes: 4 additions & 0 deletions releasenotes/notes/fix-document-init-09c1cbb14202be7d.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
---
preview:
- |
Make Document's constructor fail when is passed fields that are not present in the dataclass. An exception is made for "content_type" and "id_hash_keys": they are accepted in order to keep backward compatibility.
52 changes: 9 additions & 43 deletions test/preview/dataclasses/test_document.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,12 @@ def test_init():
assert doc.embedding == None


@pytest.mark.unit
def test_init_with_wrong_parameters():
with pytest.raises(TypeError):
Document(text="")


@pytest.mark.unit
def test_init_with_parameters():
blob_data = b"some bytes"
Expand Down Expand Up @@ -80,15 +86,14 @@ def test_init_with_legacy_fields():


@pytest.mark.unit
def test_init_with_legacy_field_and_flat_meta():
def test_init_with_legacy_field():
doc = Document(
content="test text",
content_type="text", # type: ignore
id_hash_keys=["content"], # type: ignore
score=0.812,
embedding=[0.1, 0.2, 0.3],
date="10-10-2023", # type: ignore
type="article", # type: ignore
meta={"date": "10-10-2023", "type": "article"},
)
assert doc.id == "a2c0321b34430cc675294611e55529fceb56140ca3202f1c59a43a8cecac1f43"
assert doc.content == "test text"
Expand All @@ -98,44 +103,6 @@ def test_init_with_legacy_field_and_flat_meta():
assert doc.embedding == [0.1, 0.2, 0.3]


@pytest.mark.unit
def test_init_with_flat_meta():
blob_data = b"some bytes"
doc = Document(
content="test text",
dataframe=pd.DataFrame([0]),
blob=ByteStream(data=blob_data, mime_type="text/markdown"),
score=0.812,
embedding=[0.1, 0.2, 0.3],
date="10-10-2023", # type: ignore
type="article", # type: ignore
)
assert doc.id == "c6212ad7bb513c572367e11dd12fd671911a1a5499e3d31e4fe3bda7e87c0641"
assert doc.content == "test text"
assert doc.dataframe is not None
assert doc.dataframe.equals(pd.DataFrame([0]))
assert doc.blob.data == blob_data
assert doc.blob.mime_type == "text/markdown"
assert doc.meta == {"date": "10-10-2023", "type": "article"}
assert doc.score == 0.812
assert doc.embedding == [0.1, 0.2, 0.3]


@pytest.mark.unit
def test_init_with_flat_and_non_flat_meta():
with pytest.raises(TypeError):
Document(
content="test text",
dataframe=pd.DataFrame([0]),
blob=ByteStream(data=b"some bytes", mime_type="text/markdown"),
score=0.812,
meta={"test": 10},
embedding=[0.1, 0.2, 0.3],
date="10-10-2023", # type: ignore
type="article", # type: ignore
)


@pytest.mark.unit
def test_basic_equality_type_mismatch():
doc = Document(content="test text")
Expand Down Expand Up @@ -286,8 +253,7 @@ def test_from_dict_with_legacy_field_and_flat_meta():
id_hash_keys=["content"], # type: ignore
score=0.812,
embedding=[0.1, 0.2, 0.3],
date="10-10-2023", # type: ignore
type="article", # type: ignore
meta={"date": "10-10-2023", "type": "article"},
)


Expand Down

0 comments on commit 8e7ce20

Please sign in to comment.