Skip to content

Commit

Permalink
PR feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
vblagoje committed Oct 24, 2024
1 parent 2f3304c commit c028438
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 24 deletions.
27 changes: 12 additions & 15 deletions haystack/components/converters/docx.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ class DOCXMetadata:
version: str


class TableFormat(Enum):
class DOCXTableFormat(Enum):
"""
Supported formats for storing DOCX tabular data in a Document.
"""
Expand All @@ -75,11 +75,11 @@ def __str__(self):
return self.value

@staticmethod
def from_str(string: str) -> "TableFormat":
def from_str(string: str) -> "DOCXTableFormat":
"""
Convert a string to a TableFormat enum.
Convert a string to a DOCXTableFormat enum.
"""
enum_map = {e.value: e for e in TableFormat}
enum_map = {e.value: e for e in DOCXTableFormat}
table_format = enum_map.get(string.lower())
if table_format is None:
msg = f"Unknown table format '{string}'. Supported formats are: {list(enum_map.keys())}"
Expand All @@ -97,25 +97,25 @@ class DOCXToDocument:
Usage example:
```python
from haystack.components.converters.docx import DOCXToDocument, TableFormat
from haystack.components.converters.docx import DOCXToDocument, DOCXTableFormat
converter = DOCXToDocument(table_format=TableFormat.CSV)
converter = DOCXToDocument(table_format=DOCXTableFormat.CSV)
results = converter.run(sources=["sample.docx"], meta={"date_added": datetime.now().isoformat()})
documents = results["documents"]
print(documents[0].content)
# 'This is a text from the DOCX file.'
```
"""

def __init__(self, table_format: Union[str, TableFormat] = TableFormat.CSV):
def __init__(self, table_format: Union[str, DOCXTableFormat] = DOCXTableFormat.CSV):
"""
Create a DOCXToDocument component.
:param table_format: The format for table output. Can be either TableFormat.MARKDOWN,
TableFormat.CSV, "markdown", or "csv". Defaults to TableFormat.CSV.
:param table_format: The format for table output. Can be either DOCXTableFormat.MARKDOWN,
DOCXTableFormat.CSV, "markdown", or "csv". Defaults to DOCXTableFormat.CSV.
"""
docx_import.check()
self.table_format = TableFormat.from_str(table_format) if isinstance(table_format, str) else table_format
self.table_format = DOCXTableFormat.from_str(table_format) if isinstance(table_format, str) else table_format

def to_dict(self) -> Dict[str, Any]:
"""
Expand All @@ -136,10 +136,7 @@ def from_dict(cls, data: Dict[str, Any]) -> "DOCXToDocument":
:returns:
The deserialized component.
"""
# Convert the table_format string back to enum before passing to the constructor
if "init_parameters" in data and "table_format" in data["init_parameters"]:
data["init_parameters"]["table_format"] = TableFormat.from_str(data["init_parameters"]["table_format"])

data["init_parameters"]["table_format"] = DOCXTableFormat.from_str(data["init_parameters"]["table_format"])
return default_from_dict(cls, data)

@component.output_types(documents=List[Document])
Expand Down Expand Up @@ -213,7 +210,7 @@ def _extract_elements(self, document: "DocxDocument") -> List[str]:
table = docx.table.Table(element, document)
table_str = (
self._table_to_markdown(table)
if self.table_format == TableFormat.MARKDOWN
if self.table_format == DOCXTableFormat.MARKDOWN
else self._table_to_csv(table)
)
elements.append(table_str)
Expand Down
40 changes: 31 additions & 9 deletions test/components/converters/test_docx_file_to_document.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
import csv
from io import StringIO

from haystack import Document
from haystack.components.converters.docx import DOCXMetadata, DOCXToDocument, TableFormat
from haystack import Document, Pipeline
from haystack.components.converters.docx import DOCXMetadata, DOCXToDocument, DOCXTableFormat
from haystack.dataclasses import ByteStream


Expand All @@ -21,7 +21,7 @@ def test_init(self, docx_converter):
def test_init_with_string(self):
converter = DOCXToDocument(table_format="markdown")
assert isinstance(converter, DOCXToDocument)
assert converter.table_format == TableFormat.MARKDOWN
assert converter.table_format == DOCXTableFormat.MARKDOWN

def test_init_with_invalid_string(self):
with pytest.raises(ValueError, match="Unknown table format 'invalid_format'"):
Expand Down Expand Up @@ -50,32 +50,35 @@ def test_to_dict_custom_parameters(self):
"init_parameters": {"table_format": "csv"},
}

converter = DOCXToDocument(table_format=TableFormat.MARKDOWN)
converter = DOCXToDocument(table_format=DOCXTableFormat.MARKDOWN)
data = converter.to_dict()
assert data == {
"type": "haystack.components.converters.docx.DOCXToDocument",
"init_parameters": {"table_format": "markdown"},
}

converter = DOCXToDocument(table_format=TableFormat.CSV)
converter = DOCXToDocument(table_format=DOCXTableFormat.CSV)
data = converter.to_dict()
assert data == {
"type": "haystack.components.converters.docx.DOCXToDocument",
"init_parameters": {"table_format": "csv"},
}

def test_from_dict(self):
data = {"type": "haystack.components.converters.docx.DOCXToDocument", "init_parameters": {}}
data = {
"type": "haystack.components.converters.docx.DOCXToDocument",
"init_parameters": {"table_format": "csv"},
}
converter = DOCXToDocument.from_dict(data)
assert converter.table_format == TableFormat.CSV
assert converter.table_format == DOCXTableFormat.CSV

def test_from_dict_custom_parameters(self):
data = {
"type": "haystack.components.converters.docx.DOCXToDocument",
"init_parameters": {"table_format": "markdown"},
}
converter = DOCXToDocument.from_dict(data)
assert converter.table_format == TableFormat.MARKDOWN
assert converter.table_format == DOCXTableFormat.MARKDOWN

def test_from_dict_invalid_table_format(self):
data = {
Expand All @@ -85,6 +88,25 @@ def test_from_dict_invalid_table_format(self):
with pytest.raises(ValueError, match="Unknown table format 'invalid_format'"):
DOCXToDocument.from_dict(data)

def test_pipeline_serde(self):
    """
    Check that a pipeline holding a markdown-configured converter serializes to the
    expected dictionary and compares equal to itself after a to_dict/from_dict round trip.
    """
    # The enum member is expected to show up as its plain string value once serialized.
    expected = {
        "components": {
            "converter": {
                "init_parameters": {"table_format": "markdown"},
                "type": "haystack.components.converters.docx.DOCXToDocument",
            }
        },
        "connections": [],
        "max_runs_per_component": 100,
        "metadata": {},
    }

    pipeline = Pipeline()
    pipeline.add_component("converter", DOCXToDocument(table_format=DOCXTableFormat.MARKDOWN))
    assert pipeline.to_dict() == expected

    restored = Pipeline.from_dict(pipeline.to_dict())
    assert restored == pipeline

def test_run(self, test_files_path, docx_converter):
"""
Test if the component runs correctly
Expand Down Expand Up @@ -120,7 +142,7 @@ def test_run_with_table(self, test_files_path):
"""
Test if the component runs correctly
"""
docx_converter = DOCXToDocument(table_format=TableFormat.MARKDOWN)
docx_converter = DOCXToDocument(table_format=DOCXTableFormat.MARKDOWN)
paths = [test_files_path / "docx" / "sample_docx.docx"]
output = docx_converter.run(sources=paths)
docs = output["documents"]
Expand Down

0 comments on commit c028438

Please sign in to comment.