Skip to content

Commit

Permalink
feat: make DocumentWriter return the actual number of documents written (#6366)
Browse files Browse the repository at this point in the history

* make DocumentWriter return the actual number of documents written

* add/improve tests
  • Loading branch information
anakin87 authored Nov 21, 2023
1 parent ec35580 commit 4569022
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 8 deletions.
4 changes: 2 additions & 2 deletions haystack/preview/components/writers/document_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,5 +64,5 @@ def run(self, documents: List[Document], policy: Optional[DuplicatePolicy] = Non
if policy is None:
policy = self.policy

self.document_store.write_documents(documents=documents, policy=policy)
return {"documents_written": len(documents)}
documents_written = self.document_store.write_documents(documents=documents, policy=policy)
return {"documents_written": documents_written}
26 changes: 20 additions & 6 deletions test/preview/components/writers/test_document_writer.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
from unittest.mock import MagicMock

import pytest

from haystack.preview import Document, DeserializationError
from haystack.preview.testing.factory import document_store_class
from haystack.preview.components.writers.document_writer import DocumentWriter
from haystack.preview.document_stores import DuplicatePolicy
from haystack.preview.document_stores.in_memory import InMemoryDocumentStore


class TestDocumentWriter:
Expand Down Expand Up @@ -81,12 +80,27 @@ def test_from_dict_nonexisting_docstore(self):

# NOTE(review): this hunk is shown without +/- diff markers, so the replaced
# MagicMock-based lines and their InMemoryDocumentStore replacements appear
# to be interleaved below — confirm against the actual commit before reuse.
@pytest.mark.unit
def test_run(self):
mocked_document_store = MagicMock()
writer = DocumentWriter(mocked_document_store)
document_store = InMemoryDocumentStore()
writer = DocumentWriter(document_store)
documents = [
Document(content="This is the text of a document."),
Document(content="This is the text of another document."),
]

writer.run(documents=documents)
mocked_document_store.write_documents.assert_called_once_with(documents=documents, policy=DuplicatePolicy.FAIL)
# Checks the count that run() reports for the two documents passed in.
result = writer.run(documents=documents)
assert result["documents_written"] == 2

# Unit test: runs the same documents twice through a DocumentWriter configured
# with DuplicatePolicy.SKIP and checks the reported "documents_written" count.
@pytest.mark.unit
def test_run_skip_policy(self):
document_store = InMemoryDocumentStore()
writer = DocumentWriter(document_store, policy=DuplicatePolicy.SKIP)
documents = [
Document(content="This is the text of a document."),
Document(content="This is the text of another document."),
]

# First run against the freshly constructed store: both documents are expected to be counted.
result = writer.run(documents=documents)
assert result["documents_written"] == 2

# Second run with the identical documents: the expected count is zero
# (presumably because SKIP leaves already-stored documents untouched — confirm in the store).
result = writer.run(documents=documents)
assert result["documents_written"] == 0

0 comments on commit 4569022

Please sign in to comment.