diff --git a/haystack/preview/__init__.py b/haystack/preview/__init__.py
index 260da4e027..36f7de744f 100644
--- a/haystack/preview/__init__.py
+++ b/haystack/preview/__init__.py
@@ -1,2 +1,4 @@
 from canals import component, Pipeline
+from canals.serialization import default_from_dict, default_to_dict
+from canals.errors import DeserializationError
 from haystack.preview.dataclasses import *
diff --git a/haystack/preview/components/audio/whisper_local.py b/haystack/preview/components/audio/whisper_local.py
index bfa281aa07..35fdef229a 100644
--- a/haystack/preview/components/audio/whisper_local.py
+++ b/haystack/preview/components/audio/whisper_local.py
@@ -6,7 +6,7 @@
 import torch
 import whisper
 
-from haystack.preview import component, Document
+from haystack.preview import component, Document, default_to_dict, default_from_dict
 
 
 logger = logging.getLogger(__name__)
@@ -55,6 +55,21 @@ def warm_up(self) -> None:
         if not self._model:
             self._model = whisper.load_model(self.model_name, device=self.device)
 
+    def to_dict(self) -> Dict[str, Any]:
+        """
+        Serialize this component to a dictionary.
+        """
+        return default_to_dict(
+            self, model_name_or_path=self.model_name, device=str(self.device), whisper_params=self.whisper_params
+        )
+
+    @classmethod
+    def from_dict(cls, data: Dict[str, Any]) -> "LocalWhisperTranscriber":
+        """
+        Deserialize this component from a dictionary.
+        """
+        return default_from_dict(cls, data)
+
     @component.output_types(documents=List[Document])
     def run(self, audio_files: List[Path], whisper_params: Optional[Dict[str, Any]] = None):
         """
diff --git a/haystack/preview/components/audio/whisper_remote.py b/haystack/preview/components/audio/whisper_remote.py
index 6752ce419d..f771b59249 100644
--- a/haystack/preview/components/audio/whisper_remote.py
+++ b/haystack/preview/components/audio/whisper_remote.py
@@ -6,7 +6,7 @@
 from pathlib import Path
 
 from haystack.preview.utils import request_with_retry
-from haystack.preview import component, Document
+from haystack.preview import component, Document, default_to_dict, default_from_dict
 
 
 logger = logging.getLogger(__name__)
@@ -49,17 +49,29 @@ def __init__(
         if not api_key:
             raise ValueError("API key is None.")
 
+        self.model_name = model_name
         self.api_key = api_key
         self.api_base = api_base
         self.whisper_params = whisper_params or {}
-        self.model_name = model_name
-        self.init_parameters = {
-            "api_key": self.api_key,
-            "model_name": self.model_name,
-            "api_base": self.api_base,
-            "whisper_params": self.whisper_params,
-        }
 
+    def to_dict(self) -> Dict[str, Any]:
+        """
+        Serialize this component to a dictionary.
+        """
+        return default_to_dict(
+            self,
+            model_name=self.model_name,
+            api_key=self.api_key,
+            api_base=self.api_base,
+            whisper_params=self.whisper_params,
+        )
+
+    @classmethod
+    def from_dict(cls, data: Dict[str, Any]) -> "RemoteWhisperTranscriber":
+        """
+        Deserialize this component from a dictionary.
+        """
+        return default_from_dict(cls, data)
 
     @component.output_types(documents=List[Document])
     def run(self, audio_files: List[Path], whisper_params: Optional[Dict[str, Any]] = None):
diff --git a/haystack/preview/components/classifiers/file_classifier.py b/haystack/preview/components/classifiers/file_classifier.py
index 994f31c062..86ce63e667 100644
--- a/haystack/preview/components/classifiers/file_classifier.py
+++ b/haystack/preview/components/classifiers/file_classifier.py
@@ -2,9 +2,9 @@
 import mimetypes
 from collections import defaultdict
 from pathlib import Path
-from typing import List, Union, Optional
+from typing import List, Union, Optional, Dict, Any
 
-from haystack.preview import component
+from haystack.preview import component, default_from_dict, default_to_dict
 
 
 logger = logging.getLogger(__name__)
@@ -38,12 +38,22 @@ def __init__(self, mime_types: List[str]):
                     f"Unknown mime type: '{mime_type}'. Ensure you passed a list of strings in the 'mime_types' parameter"
                 )
 
-        # save the init parameters for serialization
-        self.init_parameters = {"mime_types": mime_types}
 
         component.set_output_types(self, unclassified=List[Path], **{mime_type: List[Path] for mime_type in mime_types})
         self.mime_types = mime_types
 
+    def to_dict(self) -> Dict[str, Any]:
+        """
+        Serialize this component to a dictionary.
+        """
+        return default_to_dict(self, mime_types=self.mime_types)
+
+    @classmethod
+    def from_dict(cls, data: Dict[str, Any]) -> "FileExtensionClassifier":
+        """
+        Deserialize this component from a dictionary.
+        """
+        return default_from_dict(cls, data)
+
     def run(self, paths: List[Union[str, Path]]):
         """
         Run the FileExtensionClassifier.
diff --git a/haystack/preview/components/embedders/sentence_transformers_document_embedder.py b/haystack/preview/components/embedders/sentence_transformers_document_embedder.py
index 993bcd7bab..07b5368f66 100644
--- a/haystack/preview/components/embedders/sentence_transformers_document_embedder.py
+++ b/haystack/preview/components/embedders/sentence_transformers_document_embedder.py
@@ -1,7 +1,6 @@
-from typing import List, Optional, Union
+from typing import List, Optional, Union, Dict, Any
 
-from haystack.preview import component
-from haystack.preview import Document
+from haystack.preview import component, Document, default_to_dict, default_from_dict
 from haystack.preview.embedding_backends.sentence_transformers_backend import (
     _SentenceTransformersEmbeddingBackendFactory,
 )
@@ -42,7 +41,7 @@ def __init__(
 
         self.model_name_or_path = model_name_or_path
         # TODO: remove device parameter and use Haystack's device management once migrated
-        self.device = device
+        self.device = device or "cpu"
         self.use_auth_token = use_auth_token
         self.batch_size = batch_size
         self.progress_bar = progress_bar
@@ -50,6 +49,29 @@ def __init__(
         self.metadata_fields_to_embed = metadata_fields_to_embed or []
         self.embedding_separator = embedding_separator
 
+    def to_dict(self) -> Dict[str, Any]:
+        """
+        Serialize this component to a dictionary.
+        """
+        return default_to_dict(
+            self,
+            model_name_or_path=self.model_name_or_path,
+            device=self.device,
+            use_auth_token=self.use_auth_token,
+            batch_size=self.batch_size,
+            progress_bar=self.progress_bar,
+            normalize_embeddings=self.normalize_embeddings,
+            metadata_fields_to_embed=self.metadata_fields_to_embed,
+            embedding_separator=self.embedding_separator,
+        )
+
+    @classmethod
+    def from_dict(cls, data: Dict[str, Any]) -> "SentenceTransformersDocumentEmbedder":
+        """
+        Deserialize this component from a dictionary.
+        """
+        return default_from_dict(cls, data)
+
     def warm_up(self):
         """
         Load the embedding backend.
diff --git a/haystack/preview/components/embedders/sentence_transformers_text_embedder.py b/haystack/preview/components/embedders/sentence_transformers_text_embedder.py
index b2352da77c..08ced5a8ce 100644
--- a/haystack/preview/components/embedders/sentence_transformers_text_embedder.py
+++ b/haystack/preview/components/embedders/sentence_transformers_text_embedder.py
@@ -1,6 +1,6 @@
-from typing import List, Optional, Union
+from typing import List, Optional, Union, Dict, Any
 
-from haystack.preview import component
+from haystack.preview import component, default_to_dict, default_from_dict
 from haystack.preview.embedding_backends.sentence_transformers_backend import (
     _SentenceTransformersEmbeddingBackendFactory,
 )
@@ -40,7 +40,7 @@ def __init__(
 
         self.model_name_or_path = model_name_or_path
         # TODO: remove device parameter and use Haystack's device management once migrated
-        self.device = device
+        self.device = device or "cpu"
         self.use_auth_token = use_auth_token
         self.prefix = prefix
         self.suffix = suffix
@@ -48,6 +48,29 @@ def __init__(
         self.progress_bar = progress_bar
         self.normalize_embeddings = normalize_embeddings
 
+    def to_dict(self) -> Dict[str, Any]:
+        """
+        Serialize this component to a dictionary.
+        """
+        return default_to_dict(
+            self,
+            model_name_or_path=self.model_name_or_path,
+            device=self.device,
+            use_auth_token=self.use_auth_token,
+            prefix=self.prefix,
+            suffix=self.suffix,
+            batch_size=self.batch_size,
+            progress_bar=self.progress_bar,
+            normalize_embeddings=self.normalize_embeddings,
+        )
+
+    @classmethod
+    def from_dict(cls, data: Dict[str, Any]) -> "SentenceTransformersTextEmbedder":
+        """
+        Deserialize this component from a dictionary.
+        """
+        return default_from_dict(cls, data)
+
     def warm_up(self):
         """
         Load the embedding backend.
diff --git a/haystack/preview/components/file_converters/txt.py b/haystack/preview/components/file_converters/txt.py
index 508ff83ff1..b62d98087b 100644
--- a/haystack/preview/components/file_converters/txt.py
+++ b/haystack/preview/components/file_converters/txt.py
@@ -1,12 +1,12 @@
 import logging
 from pathlib import Path
-from typing import Optional, List, Union, Dict
+from typing import Optional, List, Union, Dict, Any
 
 from canals.errors import PipelineRuntimeError
 from tqdm import tqdm
 
 from haystack.preview.lazy_imports import LazyImport
-from haystack.preview import Document, component
+from haystack.preview import Document, component, default_to_dict, default_from_dict
 
 with LazyImport("Run 'pip install farm-haystack[preprocessing]'") as langdetect_import:
     import langdetect
@@ -61,6 +61,27 @@ def __init__(
         self.id_hash_keys = id_hash_keys or []
         self.progress_bar = progress_bar
 
+    def to_dict(self) -> Dict[str, Any]:
+        """
+        Serialize this component to a dictionary.
+        """
+        return default_to_dict(
+            self,
+            encoding=self.encoding,
+            remove_numeric_tables=self.remove_numeric_tables,
+            numeric_row_threshold=self.numeric_row_threshold,
+            valid_languages=self.valid_languages,
+            id_hash_keys=self.id_hash_keys,
+            progress_bar=self.progress_bar,
+        )
+
+    @classmethod
+    def from_dict(cls, data: Dict[str, Any]) -> "TextFileToDocument":
+        """
+        Deserialize this component from a dictionary.
+        """
+        return default_from_dict(cls, data)
+
     @component.output_types(documents=List[Document])
     def run(
         self,
diff --git a/haystack/preview/components/retrievers/memory.py b/haystack/preview/components/retrievers/memory.py
index e34e244625..e1792057c1 100644
--- a/haystack/preview/components/retrievers/memory.py
+++ b/haystack/preview/components/retrievers/memory.py
@@ -1,7 +1,7 @@
 from typing import Dict, List, Any, Optional
 
-from haystack.preview import component, Document
-from haystack.preview.document_stores import MemoryDocumentStore
+from haystack.preview import component, Document, default_to_dict, default_from_dict, DeserializationError
+from haystack.preview.document_stores import MemoryDocumentStore, document_store
 
 
 @component
@@ -41,6 +41,33 @@ def __init__(
         self.top_k = top_k
         self.scale_score = scale_score
 
+    def to_dict(self) -> Dict[str, Any]:
+        """
+        Serialize this component to a dictionary.
+        """
+        docstore = self.document_store.to_dict()
+        return default_to_dict(
+            self, document_store=docstore, filters=self.filters, top_k=self.top_k, scale_score=self.scale_score
+        )
+
+    @classmethod
+    def from_dict(cls, data: Dict[str, Any]) -> "MemoryRetriever":
+        """
+        Deserialize this component from a dictionary.
+        """
+        init_params = data.get("init_parameters", {})
+        if "document_store" not in init_params:
+            raise DeserializationError("Missing 'document_store' in serialization data")
+        if "type" not in init_params["document_store"]:
+            raise DeserializationError("Missing 'type' in document store's serialization data")
+        if init_params["document_store"]["type"] not in document_store.registry:
+            raise DeserializationError(f"DocumentStore type '{init_params['document_store']['type']}' not found")
+
+        docstore_class = document_store.registry[init_params["document_store"]["type"]]
+        docstore = docstore_class.from_dict(init_params["document_store"])
+        data["init_parameters"]["document_store"] = docstore
+        return default_from_dict(cls, data)
+
     @component.output_types(documents=List[List[Document]])
     def run(
         self,
diff --git a/haystack/preview/components/writers/document_writer.py b/haystack/preview/components/writers/document_writer.py
index b59642fe5c..aef0d823b3 100644
--- a/haystack/preview/components/writers/document_writer.py
+++ b/haystack/preview/components/writers/document_writer.py
@@ -1,7 +1,7 @@
-from typing import List, Optional
+from typing import List, Optional, Dict, Any
 
-from haystack.preview import component, Document
-from haystack.preview.document_stores import DocumentStore, DuplicatePolicy
+from haystack.preview import component, Document, default_from_dict, default_to_dict, DeserializationError
+from haystack.preview.document_stores import DocumentStore, DuplicatePolicy, document_store
 
 
 @component
@@ -19,6 +19,31 @@ def __init__(self, document_store: DocumentStore, policy: DuplicatePolicy = Dupl
         self.document_store = document_store
         self.policy = policy
 
+    def to_dict(self) -> Dict[str, Any]:
+        """
+        Serialize this component to a dictionary.
+        """
+        return default_to_dict(self, document_store=self.document_store.to_dict(), policy=self.policy.name)
+
+    @classmethod
+    def from_dict(cls, data: Dict[str, Any]) -> "DocumentWriter":
+        """
+        Deserialize this component from a dictionary.
+        """
+        init_params = data.get("init_parameters", {})
+        if "document_store" not in init_params:
+            raise DeserializationError("Missing 'document_store' in serialization data")
+        if "type" not in init_params["document_store"]:
+            raise DeserializationError("Missing 'type' in document store's serialization data")
+        if init_params["document_store"]["type"] not in document_store.registry:
+            raise DeserializationError(f"DocumentStore of type '{init_params['document_store']['type']}' not found.")
+        docstore_class = document_store.registry[init_params["document_store"]["type"]]
+        docstore = docstore_class.from_dict(init_params["document_store"])
+
+        data["init_parameters"]["document_store"] = docstore
+        data["init_parameters"]["policy"] = DuplicatePolicy[data["init_parameters"]["policy"]]
+        return default_from_dict(cls, data)
+
     def run(self, documents: List[Document], policy: Optional[DuplicatePolicy] = None):
         """
         Run DocumentWriter on the given input data.
diff --git a/haystack/preview/document_stores/decorator.py b/haystack/preview/document_stores/decorator.py
index 30283d4089..01e78a42b3 100644
--- a/haystack/preview/document_stores/decorator.py
+++ b/haystack/preview/document_stores/decorator.py
@@ -1,9 +1,5 @@
-from typing import Dict, Any, Type
 import logging
 
-from haystack.preview.document_stores.protocols import DocumentStore
-from haystack.preview.document_stores.errors import DocumentStoreDeserializationError
 
 logger = logging.getLogger(__name__)
@@ -40,30 +36,3 @@ def __call__(self, cls=None):
 
 
 document_store = _DocumentStore()
-
-
-def default_document_store_to_dict(store_: DocumentStore) -> Dict[str, Any]:
-    """
-    Default DocumentStore serializer.
-    Serializes a DocumentStore to a dictionary.
-    """
-    return {
-        "hash": id(store_),
-        "type": store_.__class__.__name__,
-        "init_parameters": getattr(store_, "init_parameters", {}),
-    }
-
-
-def default_document_store_from_dict(cls: Type[DocumentStore], data: Dict[str, Any]) -> DocumentStore:
-    """
-    Default DocumentStore deserializer.
-    The "type" field in `data` must match the class that is being deserialized into.
-    """
-    init_params = data.get("init_parameters", {})
-    if "type" not in data:
-        raise DocumentStoreDeserializationError("Missing 'type' in DocumentStore serialization data")
-    if data["type"] != cls.__name__:
-        raise DocumentStoreDeserializationError(
-            f"DocumentStore '{data['type']}' can't be deserialized as '{cls.__name__}'"
-        )
-    return cls(**init_params)
diff --git a/haystack/preview/document_stores/errors.py b/haystack/preview/document_stores/errors.py
index 0500ae2b6a..85830be5cf 100644
--- a/haystack/preview/document_stores/errors.py
+++ b/haystack/preview/document_stores/errors.py
@@ -12,7 +12,3 @@ class DuplicateDocumentError(DocumentStoreError):
 
 class MissingDocumentError(DocumentStoreError):
     pass
-
-
-class DocumentStoreDeserializationError(DocumentStoreError):
-    pass
diff --git a/haystack/preview/document_stores/memory/document_store.py b/haystack/preview/document_stores/memory/document_store.py
index 87a928674c..af190de95d 100644
--- a/haystack/preview/document_stores/memory/document_store.py
+++ b/haystack/preview/document_stores/memory/document_store.py
@@ -7,11 +7,8 @@
 import rank_bm25
 from tqdm.auto import tqdm
 
-from haystack.preview.document_stores.decorator import (
-    document_store,
-    default_document_store_to_dict,
-    default_document_store_from_dict,
-)
+from haystack.preview import default_from_dict, default_to_dict
+from haystack.preview.document_stores.decorator import document_store
 from haystack.preview.dataclasses import Document
 from haystack.preview.document_stores.protocols import DuplicatePolicy, DocumentStore
 from haystack.preview.document_stores.memory._filters import match
@@ -44,6 +41,7 @@ def __init__(
         Initializes the DocumentStore.
         """
         self.storage: Dict[str, Document] = {}
+        self._bm25_tokenization_regex = bm25_tokenization_regex
         self.tokenizer = re.compile(bm25_tokenization_regex).findall
         algorithm_class = getattr(rank_bm25, bm25_algorithm)
         if algorithm_class is None:
@@ -51,25 +49,23 @@ def __init__(
         self.bm25_algorithm = algorithm_class
         self.bm25_parameters = bm25_parameters or {}
 
-        # Used to convert this instance to a dictionary for serialization
-        self.init_parameters = {
-            "bm25_tokenization_regex": bm25_tokenization_regex,
-            "bm25_algorithm": bm25_algorithm,
-            "bm25_parameters": self.bm25_parameters,
-        }
-
     def to_dict(self) -> Dict[str, Any]:
         """
         Serializes this store to a dictionary.
        """
-        return default_document_store_to_dict(self)
+        return default_to_dict(
+            self,
+            bm25_tokenization_regex=self._bm25_tokenization_regex,
+            bm25_algorithm=self.bm25_algorithm.__name__,
+            bm25_parameters=self.bm25_parameters,
+        )
 
     @classmethod
     def from_dict(cls, data: Dict[str, Any]) -> "DocumentStore":
         """
         Deserializes the store from a dictionary.
         """
-        return default_document_store_from_dict(cls, data)
+        return default_from_dict(cls, data)
 
     def count_documents(self) -> int:
         """
diff --git a/haystack/preview/testing/factory.py b/haystack/preview/testing/factory.py
index 7494823d5b..52bfab18dc 100644
--- a/haystack/preview/testing/factory.py
+++ b/haystack/preview/testing/factory.py
@@ -1,5 +1,6 @@
 from typing import Any, Dict, Optional, Tuple, Type, List, Union
 
+from haystack.preview import default_to_dict, default_from_dict
 from haystack.preview.dataclasses import Document
 from haystack.preview.document_stores import document_store, DocumentStore, DuplicatePolicy
@@ -96,11 +97,16 @@ def write_documents(self, documents: List[Document], policy: DuplicatePolicy = D
     def delete_documents(self, document_ids: List[str]) -> None:
         return
 
+    def to_dict(self) -> Dict[str, Any]:
+        return default_to_dict(self)
+
     fields = {
         "count_documents": count_documents,
         "filter_documents": filter_documents,
         "write_documents": write_documents,
         "delete_documents": delete_documents,
+        "to_dict": to_dict,
+        "from_dict": classmethod(default_from_dict),
     }
 
     if extra_fields is not None:
diff --git a/pyproject.toml b/pyproject.toml
index 093af53bac..9f8d08c2db 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -79,7 +79,7 @@ dependencies = [
   "jsonschema",
 
   # Preview
-  "canals==0.5.0",
+  "canals==0.7.0",
 
   # Agent events
   "events",
diff --git a/releasenotes/notes/default-to-from-dict-7f7d89b6c36e2ab8.yaml b/releasenotes/notes/default-to-from-dict-7f7d89b6c36e2ab8.yaml
new file mode 100644
index 0000000000..3592483606
--- /dev/null
+++ b/releasenotes/notes/default-to-from-dict-7f7d89b6c36e2ab8.yaml
@@ -0,0 +1,4 @@
+---
+preview:
+  - Migrate all components to Canals==0.7.0
+  - Add serialization and deserialization methods for all Haystack components
diff --git a/test/preview/components/audio/test_whisper_local.py b/test/preview/components/audio/test_whisper_local.py
index 6dfcef8c1e..6132c2cd32 100644
--- a/test/preview/components/audio/test_whisper_local.py
+++ b/test/preview/components/audio/test_whisper_local.py
@@ -26,6 +26,47 @@ def test_init_wrong_model(self):
         with pytest.raises(ValueError, match="Model name 'whisper-1' not recognized"):
             LocalWhisperTranscriber(model_name_or_path="whisper-1")
 
+    @pytest.mark.unit
+    def test_to_dict(self):
+        transcriber = LocalWhisperTranscriber()
+        data = transcriber.to_dict()
+        assert data == {
+            "type": "LocalWhisperTranscriber",
+            "init_parameters": {"model_name_or_path": "large", "device": "cpu", "whisper_params": {}},
+        }
+
+    @pytest.mark.unit
+    def test_to_dict_with_custom_init_parameters(self):
+        transcriber = LocalWhisperTranscriber(
+            model_name_or_path="tiny",
+            device="cuda",
+            whisper_params={"return_segments": True, "temperature": [0.1, 0.6, 0.8]},
+        )
+        data = transcriber.to_dict()
+        assert data == {
+            "type": "LocalWhisperTranscriber",
+            "init_parameters": {
+                "model_name_or_path": "tiny",
+                "device": "cuda",
+                "whisper_params": {"return_segments": True, "temperature": [0.1, 0.6, 0.8]},
+            },
+        }
+
+    @pytest.mark.unit
+    def test_from_dict(self):
+        data = {
+            "type": "LocalWhisperTranscriber",
+            "init_parameters": {
+                "model_name_or_path": "tiny",
+                "device": "cuda",
+                "whisper_params": {"return_segments": True, "temperature": [0.1, 0.6, 0.8]},
+            },
+        }
+        transcriber = LocalWhisperTranscriber.from_dict(data)
+        assert transcriber.model_name == "tiny"
+        assert transcriber.device == torch.device("cuda")
+        assert transcriber.whisper_params == {"return_segments": True, "temperature": [0.1, 0.6, 0.8]}
+
     @pytest.mark.unit
     def test_warmup(self):
         with patch("haystack.preview.components.audio.whisper_local.whisper") as mocked_whisper:
diff --git a/test/preview/components/audio/test_whisper_remote.py b/test/preview/components/audio/test_whisper_remote.py
index d3d84a3ad7..e9288fa683 100644
--- a/test/preview/components/audio/test_whisper_remote.py
+++ b/test/preview/components/audio/test_whisper_remote.py
@@ -1,3 +1,4 @@
+from typing import Literal
 from unittest.mock import MagicMock, patch
 
 import pytest
@@ -24,6 +25,56 @@ def test_init_no_key(self):
         with pytest.raises(ValueError, match="API key is None"):
             RemoteWhisperTranscriber(api_key=None)
 
+    @pytest.mark.unit
+    def test_to_dict(self):
+        transcriber = RemoteWhisperTranscriber(api_key="test")
+        data = transcriber.to_dict()
+        assert data == {
+            "type": "RemoteWhisperTranscriber",
+            "init_parameters": {
+                "model_name": "whisper-1",
+                "api_key": "test",
+                "api_base": "https://api.openai.com/v1",
+                "whisper_params": {},
+            },
+        }
+
+    @pytest.mark.unit
+    def test_to_dict_with_custom_init_parameters(self):
+        transcriber = RemoteWhisperTranscriber(
+            api_key="test",
+            model_name="whisper-1",
+            api_base="https://my.api.base/something_else/v3",
+            whisper_params={"return_segments": True, "temperature": [0.1, 0.6, 0.8]},
+        )
+        data = transcriber.to_dict()
+        assert data == {
+            "type": "RemoteWhisperTranscriber",
+            "init_parameters": {
+                "model_name": "whisper-1",
+                "api_key": "test",
+                "api_base": "https://my.api.base/something_else/v3",
+                "whisper_params": {"return_segments": True, "temperature": [0.1, 0.6, 0.8]},
+            },
+        }
+
+    @pytest.mark.unit
+    def test_from_dict(self):
+        data = {
+            "type": "RemoteWhisperTranscriber",
+            "init_parameters": {
+                "model_name": "whisper-1",
+                "api_key": "test",
+                "api_base": "https://my.api.base/something_else/v3",
+                "whisper_params": {"return_segments": True, "temperature": [0.1, 0.6, 0.8]},
+            },
+        }
+        transcriber = RemoteWhisperTranscriber.from_dict(data)
+        assert transcriber.model_name == "whisper-1"
+        assert transcriber.api_key == "test"
+        assert transcriber.api_base == "https://my.api.base/something_else/v3"
+        assert transcriber.whisper_params == {"return_segments": True, "temperature": [0.1, 0.6, 0.8]}
+
     @pytest.mark.unit
     def test_run_with_path(self, preview_samples_path):
         mock_response = MagicMock()
diff --git a/test/preview/components/classifiers/test_file_classifier.py b/test/preview/components/classifiers/test_file_classifier.py
index 7a45171690..833b138480 100644
--- a/test/preview/components/classifiers/test_file_classifier.py
+++ b/test/preview/components/classifiers/test_file_classifier.py
@@ -10,6 +10,24 @@
     reason="Can't run on Windows Github CI, need access to registry to get mime types",
 )
 class TestFileExtensionClassifier:
+    @pytest.mark.unit
+    def test_to_dict(self):
+        component = FileExtensionClassifier(mime_types=["text/plain", "audio/x-wav", "image/jpeg"])
+        data = component.to_dict()
+        assert data == {
+            "type": "FileExtensionClassifier",
+            "init_parameters": {"mime_types": ["text/plain", "audio/x-wav", "image/jpeg"]},
+        }
+
+    @pytest.mark.unit
+    def test_from_dict(self):
+        data = {
+            "type": "FileExtensionClassifier",
+            "init_parameters": {"mime_types": ["text/plain", "audio/x-wav", "image/jpeg"]},
+        }
+        component = FileExtensionClassifier.from_dict(data)
+        assert component.mime_types == ["text/plain", "audio/x-wav", "image/jpeg"]
+
     @pytest.mark.unit
     def test_run(self, preview_samples_path):
         """
diff --git a/test/preview/components/embedders/test_sentence_transformers_document_embedder.py b/test/preview/components/embedders/test_sentence_transformers_document_embedder.py
index 4bde42025d..1fd782d2b9 100644
--- a/test/preview/components/embedders/test_sentence_transformers_document_embedder.py
+++ b/test/preview/components/embedders/test_sentence_transformers_document_embedder.py
@@ -13,28 +13,104 @@ class TestSentenceTransformersDocumentEmbedder:
     def test_init_default(self):
         embedder = SentenceTransformersDocumentEmbedder(model_name_or_path="model")
         assert embedder.model_name_or_path == "model"
-        assert embedder.device is None
+        assert embedder.device == "cpu"
         assert embedder.use_auth_token is None
         assert embedder.batch_size == 32
         assert embedder.progress_bar is True
         assert embedder.normalize_embeddings is False
+        assert embedder.metadata_fields_to_embed == []
+        assert embedder.embedding_separator == "\n"
 
     @pytest.mark.unit
     def test_init_with_parameters(self):
         embedder = SentenceTransformersDocumentEmbedder(
             model_name_or_path="model",
-            device="cpu",
+            device="cuda",
             use_auth_token=True,
             batch_size=64,
             progress_bar=False,
             normalize_embeddings=True,
+            metadata_fields_to_embed=["test_field"],
+            embedding_separator=" | ",
         )
         assert embedder.model_name_or_path == "model"
-        assert embedder.device == "cpu"
+        assert embedder.device == "cuda"
         assert embedder.use_auth_token is True
         assert embedder.batch_size == 64
         assert embedder.progress_bar is False
         assert embedder.normalize_embeddings is True
+        assert embedder.metadata_fields_to_embed == ["test_field"]
+        assert embedder.embedding_separator == " | "
+
+    @pytest.mark.unit
+    def test_to_dict(self):
+        component = SentenceTransformersDocumentEmbedder(model_name_or_path="model")
+        data = component.to_dict()
+        assert data == {
+            "type": "SentenceTransformersDocumentEmbedder",
+            "init_parameters": {
+                "model_name_or_path": "model",
+                "device": "cpu",
+                "use_auth_token": None,
+                "batch_size": 32,
+                "progress_bar": True,
+                "normalize_embeddings": False,
+                "embedding_separator": "\n",
+                "metadata_fields_to_embed": [],
+            },
+        }
+
+    @pytest.mark.unit
+    def test_to_dict_with_custom_init_parameters(self):
+        component = SentenceTransformersDocumentEmbedder(
+            model_name_or_path="model",
+            device="cuda",
+            use_auth_token="the-token",
+            batch_size=64,
+            progress_bar=False,
+            normalize_embeddings=True,
+            metadata_fields_to_embed=["meta_field"],
+            embedding_separator=" - ",
+        )
+        data = component.to_dict()
+        assert data == {
+            "type": "SentenceTransformersDocumentEmbedder",
+            "init_parameters": {
+                "model_name_or_path": "model",
+                "device": "cuda",
+                "use_auth_token": "the-token",
+                "batch_size": 64,
+                "progress_bar": False,
+                "normalize_embeddings": True,
+                "embedding_separator": " - ",
+                "metadata_fields_to_embed": ["meta_field"],
+            },
+        }
+
+    @pytest.mark.unit
+    def test_from_dict(self):
+        data = {
+            "type": "SentenceTransformersDocumentEmbedder",
+            "init_parameters": {
+                "model_name_or_path": "model",
+                "device": "cuda",
+                "use_auth_token": "the-token",
+                "batch_size": 64,
+                "progress_bar": False,
+                "normalize_embeddings": False,
+                "embedding_separator": " - ",
+                "metadata_fields_to_embed": ["meta_field"],
+            },
+        }
+        component = SentenceTransformersDocumentEmbedder.from_dict(data)
+        assert component.model_name_or_path == "model"
+        assert component.device == "cuda"
+        assert component.use_auth_token == "the-token"
+        assert component.batch_size == 64
+        assert component.progress_bar is False
+        assert component.normalize_embeddings is False
+        assert component.metadata_fields_to_embed == ["meta_field"]
+        assert component.embedding_separator == " - "
 
     @pytest.mark.unit
     @patch(
@@ -45,7 +121,7 @@ def test_warmup(self, mocked_factory):
         mocked_factory.get_embedding_backend.assert_not_called()
         embedder.warm_up()
         mocked_factory.get_embedding_backend.assert_called_once_with(
-            model_name_or_path="model", device=None, use_auth_token=None
+            model_name_or_path="model", device="cpu", use_auth_token=None
         )
 
     @pytest.mark.unit
diff --git a/test/preview/components/embedders/test_sentence_transformers_text_embedder.py b/test/preview/components/embedders/test_sentence_transformers_text_embedder.py
index 771b687cc1..9aa20696a2 100644
--- a/test/preview/components/embedders/test_sentence_transformers_text_embedder.py
+++ b/test/preview/components/embedders/test_sentence_transformers_text_embedder.py
@@ -11,7 +11,7 @@ class TestSentenceTransformersTextEmbedder:
     def test_init_default(self):
         embedder = SentenceTransformersTextEmbedder(model_name_or_path="model")
         assert embedder.model_name_or_path == "model"
-        assert embedder.device is None
+        assert embedder.device == "cpu"
         assert embedder.use_auth_token is None
         assert embedder.prefix == ""
         assert embedder.suffix == ""
@@ -23,7 +23,7 @@ def test_init_with_parameters(self):
         embedder = SentenceTransformersTextEmbedder(
             model_name_or_path="model",
-            device="cpu",
+            device="cuda",
             use_auth_token=True,
             prefix="prefix",
             suffix="suffix",
@@ -32,7 +32,7 @@ def test_init_with_parameters(self):
             normalize_embeddings=True,
         )
         assert embedder.model_name_or_path == "model"
-        assert embedder.device == "cpu"
+        assert embedder.device == "cuda"
         assert embedder.use_auth_token is True
         assert embedder.prefix == "prefix"
         assert embedder.suffix == "suffix"
@@ -40,6 +40,76 @@ def test_init_with_parameters(self):
         assert embedder.progress_bar is False
         assert embedder.normalize_embeddings is True
 
+    @pytest.mark.unit
+    def test_to_dict(self):
+        component = SentenceTransformersTextEmbedder(model_name_or_path="model")
+        data = component.to_dict()
+        assert data == {
+            "type": "SentenceTransformersTextEmbedder",
+            "init_parameters": {
+                "model_name_or_path": "model",
+                "device": "cpu",
+                "use_auth_token": None,
+                "prefix": "",
+                "suffix": "",
+                "batch_size": 32,
+                "progress_bar": True,
+                "normalize_embeddings": False,
+            },
+        }
+
+    @pytest.mark.unit
+    def test_to_dict_with_custom_init_parameters(self):
+        component = SentenceTransformersTextEmbedder(
+            model_name_or_path="model",
+            device="cuda",
+            use_auth_token=True,
+            prefix="prefix",
+            suffix="suffix",
+            batch_size=64,
+            progress_bar=False,
+            normalize_embeddings=True,
+        )
+        data = component.to_dict()
+        assert data == {
+            "type": "SentenceTransformersTextEmbedder",
+            "init_parameters": {
+                "model_name_or_path": "model",
+                "device": "cuda",
+                "use_auth_token": True,
+                "prefix": "prefix",
+                "suffix": "suffix",
+                "batch_size": 64,
+                "progress_bar": False,
+                "normalize_embeddings": True,
+            },
+        }
+
+    @pytest.mark.unit
+    def test_from_dict(self):
+        data = {
+            "type": "SentenceTransformersTextEmbedder",
+            "init_parameters": {
+                "model_name_or_path": "model",
+                "device": "cuda",
+                "use_auth_token": True,
+                "prefix": "prefix",
+                "suffix": "suffix",
+                "batch_size": 64,
+                "progress_bar": False,
+                "normalize_embeddings": True,
+            },
+        }
+        component = SentenceTransformersTextEmbedder.from_dict(data)
+        assert component.model_name_or_path == "model"
+        assert component.device == "cuda"
+        assert component.use_auth_token is True
+        assert component.prefix == "prefix"
+        assert component.suffix == "suffix"
+        assert component.batch_size == 64
+        assert component.progress_bar is False
+        assert component.normalize_embeddings is True
+
     @pytest.mark.unit
     @patch(
         "haystack.preview.components.embedders.sentence_transformers_text_embedder._SentenceTransformersEmbeddingBackendFactory"
     )
@@ -49,7 +119,7 @@ def test_warmup(self, mocked_factory):
         mocked_factory.get_embedding_backend.assert_not_called()
         embedder.warm_up()
         mocked_factory.get_embedding_backend.assert_called_once_with(
-            model_name_or_path="model", device=None, use_auth_token=None
+            model_name_or_path="model", device="cpu", use_auth_token=None
         )
 
     @pytest.mark.unit
diff --git a/test/preview/components/file_converters/test_textfile_to_document.py b/test/preview/components/file_converters/test_textfile_to_document.py
index 08e9d99f90..f84f46cece 100644
--- a/test/preview/components/file_converters/test_textfile_to_document.py
+++ b/test/preview/components/file_converters/test_textfile_to_document.py
@@ -11,6 +11,66 @@ class TestTextfileToDocument:
+    @pytest.mark.unit
+    def test_to_dict(self):
+        component = TextFileToDocument()
+        data = component.to_dict()
+        assert data == {
+            "type": "TextFileToDocument",
+            "init_parameters": {
+                "encoding": "utf-8",
+                "remove_numeric_tables": False,
+                "numeric_row_threshold": 0.4,
+                "valid_languages": [],
+                "id_hash_keys": [],
+                "progress_bar": True,
+            },
+        }
+
+    @pytest.mark.unit
+    def test_to_dict_with_custom_init_parameters(self):
+        component = TextFileToDocument(
+            encoding="latin-1",
+            remove_numeric_tables=True,
+            numeric_row_threshold=0.7,
+            valid_languages=["en", "de"],
+            id_hash_keys=["name"],
+            progress_bar=False,
+        )
+        data = component.to_dict()
+        assert data == {
+            "type": "TextFileToDocument",
+            "init_parameters": {
+                "encoding": "latin-1",
+                "remove_numeric_tables": True,
+                "numeric_row_threshold": 0.7,
+                "valid_languages": ["en", "de"],
+                "id_hash_keys": ["name"],
+                "progress_bar": False,
+            },
+        }
+
+    @pytest.mark.unit
+    def test_from_dict(self):
+        data = {
+            "type": "TextFileToDocument",
+            "init_parameters": {
+                "encoding": "latin-1",
+                "remove_numeric_tables": True,
+                "numeric_row_threshold": 0.7,
+                "valid_languages": ["en", "de"],
+                "id_hash_keys": ["name"],
+                "progress_bar": False,
+            },
+        }
+        component = TextFileToDocument.from_dict(data)
+        assert component.encoding == "latin-1"
+        assert component.remove_numeric_tables
+        assert component.numeric_row_threshold == 0.7
+        assert component.valid_languages == ["en", "de"]
+        assert component.id_hash_keys == ["name"]
+        assert not component.progress_bar
+
     @pytest.mark.unit
     def test_run(self, preview_samples_path):
         """
diff --git a/test/preview/components/retrievers/test_memory_retriever.py b/test/preview/components/retrievers/test_memory_retriever.py
index 3dd4721e27..11752711a9 100644
--- a/test/preview/components/retrievers/test_memory_retriever.py
+++ b/test/preview/components/retrievers/test_memory_retriever.py
@@ -1,8 +1,9 @@
 from typing import Dict, Any
+from unittest.mock import MagicMock, patch
 
 import pytest
 
-from haystack.preview import Pipeline
+from haystack.preview import Pipeline, DeserializationError
 from haystack.preview.testing.factory import document_store_class
 from haystack.preview.components.retrievers.memory import MemoryRetriever
 from haystack.preview.dataclasses import Document
@@ -40,6 +41,81 @@ def test_init_with_invalid_top_k_parameter(self):
         with pytest.raises(ValueError, match="top_k must be > 0, but got -2"):
             MemoryRetriever(MemoryDocumentStore(), top_k=-2, scale_score=False)
 
+    @pytest.mark.unit
+    def test_to_dict(self):
+        MyFakeStore = document_store_class("MyFakeStore", bases=(MemoryDocumentStore,))
+        document_store = MyFakeStore()
+        document_store.to_dict = lambda: {"type": "MyFakeStore", "init_parameters": {}}
+        component = MemoryRetriever(document_store=document_store)
+
+        data = component.to_dict()
+        assert data == {
+            "type": "MemoryRetriever",
+            "init_parameters": {
+                "document_store": {"type": "MyFakeStore", "init_parameters": {}},
+                "filters": None,
+                "top_k": 10,
+                "scale_score": True,
+            },
+        }
+
+    @pytest.mark.unit
+    def test_to_dict_with_custom_init_parameters(self):
+        MyFakeStore = document_store_class("MyFakeStore", bases=(MemoryDocumentStore,))
+        document_store = MyFakeStore()
+        document_store.to_dict = lambda: {"type": "MyFakeStore", "init_parameters": {}}
+        component = MemoryRetriever(
+            document_store=document_store, filters={"name": "test.txt"}, top_k=5, scale_score=False
+        )
+        data = component.to_dict()
+        assert data == {
+            "type": "MemoryRetriever",
+            "init_parameters": {
+                "document_store": {"type": "MyFakeStore", "init_parameters": {}},
+                "filters": {"name": "test.txt"},
+                "top_k": 5,
+                "scale_score": False,
+            },
+        }
+
+    @pytest.mark.unit
+    def test_from_dict(self):
+        document_store_class("MyFakeStore", bases=(MemoryDocumentStore,))
+        data = {
+            "type": "MemoryRetriever",
+            "init_parameters": {
+                "document_store": {"type": "MyFakeStore", "init_parameters": {}},
+                "filters": {"name": "test.txt"},
+                "top_k": 5,
+            },
+        }
+        component = MemoryRetriever.from_dict(data)
+        assert isinstance(component.document_store, MemoryDocumentStore)
+        assert component.filters == {"name": "test.txt"}
+        assert component.top_k == 5
+        assert component.scale_score
+
+    @pytest.mark.unit
+    def test_from_dict_without_docstore(self):
+        data = {"type": "MemoryRetriever", "init_parameters": {}}
+        with pytest.raises(DeserializationError, match="Missing 'document_store' in serialization data"):
+            MemoryRetriever.from_dict(data)
+
+    @pytest.mark.unit
+    def test_from_dict_without_docstore_type(self):
+        data = {"type": "MemoryRetriever", "init_parameters": {"document_store": {"init_parameters": {}}}}
+        with pytest.raises(DeserializationError, match="Missing 'type' in document store's serialization data"):
+            MemoryRetriever.from_dict(data)
+
+    @pytest.mark.unit
+    def test_from_dict_nonexisting_docstore(self):
+        data = {
+            "type": "MemoryRetriever",
+            "init_parameters": {"document_store": {"type": "NonexistingDocstore", "init_parameters": {}}},
+        }
+        with pytest.raises(DeserializationError, match="DocumentStore type 'NonexistingDocstore' not found"):
+            MemoryRetriever.from_dict(data)
+
     @pytest.mark.unit
     def test_valid_run(self, mock_docs):
         top_k = 5
diff --git a/test/preview/components/writers/document_writer.py b/test/preview/components/writers/document_writer.py
deleted file mode 100644
index 11927f5e31..0000000000
--- a/test/preview/components/writers/document_writer.py
+++ /dev/null
@@ -1,21 +0,0 @@
-from unittest.mock import MagicMock
-
-import pytest
-
-from haystack.preview import Document
-from haystack.preview.components.writers.document_writer import DocumentWriter
-from haystack.preview.document_stores import DuplicatePolicy
-
-
-class TestDocumentWriter:
-    @pytest.mark.unit
-    def test_run(self):
-        mocked_document_store = MagicMock()
-        writer = DocumentWriter(mocked_document_store)
-        documents = [
-            Document(content="This is the text of a document."),
-            Document(content="This is the text of another document."),
-        ]
-
-        writer.run(documents=documents)
-        mocked_document_store.write_documents.assert_called_once_with(documents=documents, policy=DuplicatePolicy.FAIL)
diff --git a/test/preview/components/writers/test_document_writer.py b/test/preview/components/writers/test_document_writer.py
new file mode 100644
index 0000000000..66931ef0c8
--- /dev/null
+++ b/test/preview/components/writers/test_document_writer.py
@@ -0,0 +1,83 @@
+from unittest.mock import MagicMock
+
+import pytest
+
+from haystack.preview import Document, DeserializationError
+from haystack.preview.testing.factory import document_store_class
+from haystack.preview.components.writers.document_writer import DocumentWriter
+from haystack.preview.document_stores import DuplicatePolicy
+
+
+class TestDocumentWriter:
+    @pytest.mark.unit
+    def test_to_dict(self):
+        mocked_docstore_class = document_store_class("MockedDocumentStore")
+        component = DocumentWriter(document_store=mocked_docstore_class())
+        data = component.to_dict()
+        assert data == {
+            "type": "DocumentWriter",
+            "init_parameters": {
+                "document_store": {"type": "MockedDocumentStore", "init_parameters": {}},
+                "policy": "FAIL",
+            },
+        }
+
+    @pytest.mark.unit
+    def test_to_dict_with_custom_init_parameters(self):
+        mocked_docstore_class = document_store_class("MockedDocumentStore")
+        component = DocumentWriter(document_store=mocked_docstore_class(), policy=DuplicatePolicy.SKIP)
+        data = component.to_dict()
+        assert data == {
+            "type": "DocumentWriter",
+            "init_parameters": {
+                "document_store": {"type": "MockedDocumentStore", "init_parameters": {}},
+                "policy": "SKIP",
+            },
+        }
+
+    @pytest.mark.unit
+    def test_from_dict(self):
+        mocked_docstore_class = document_store_class("MockedDocumentStore")
+        data = {
+            "type": "DocumentWriter",
+            "init_parameters": {
+                "document_store": {"type": "MockedDocumentStore", "init_parameters": {}},
+                "policy": "SKIP",
+            },
+        }
+        component = DocumentWriter.from_dict(data)
+        assert isinstance(component.document_store, mocked_docstore_class)
+        assert component.policy == DuplicatePolicy.SKIP
+
+    @pytest.mark.unit
+    def test_from_dict_without_docstore(self):
+        data = {"type": "DocumentWriter", "init_parameters": {}}
+        with pytest.raises(DeserializationError, match="Missing 'document_store' in serialization data"):
+            DocumentWriter.from_dict(data)
+
+    @pytest.mark.unit
+    def test_from_dict_without_docstore_type(self):
+        data = {"type": "DocumentWriter", "init_parameters": {"document_store": {"init_parameters": {}}}}
+        with pytest.raises(DeserializationError, match="Missing 'type' in document store's serialization data"):
+            DocumentWriter.from_dict(data)
+
+    @pytest.mark.unit
+    def test_from_dict_nonexisting_docstore(self):
+        data = {
+            "type": "DocumentWriter",
+            "init_parameters": {"document_store": {"type": "NonexistingDocumentStore", "init_parameters": {}}},
+        }
+        with pytest.raises(DeserializationError, match="DocumentStore of type 'NonexistingDocumentStore' not found."):
+            DocumentWriter.from_dict(data)
+
+    @pytest.mark.unit
+    def test_run(self):
+        mocked_document_store = MagicMock()
+        writer = DocumentWriter(mocked_document_store)
+        documents = [
+            Document(content="This is the text of a document."),
+            Document(content="This is the text of another document."),
+        ]
+
+        writer.run(documents=documents)
+        mocked_document_store.write_documents.assert_called_once_with(documents=documents, policy=DuplicatePolicy.FAIL)
diff --git a/test/preview/document_stores/test_decorator.py b/test/preview/document_stores/test_decorator.py
deleted file mode 100644
index 23f4eaa809..0000000000
--- a/test/preview/document_stores/test_decorator.py
+++ /dev/null
@@ -1,61 +0,0 @@
-from unittest.mock import Mock
-
-import pytest
-
-from haystack.preview.testing.factory import document_store_class
-from haystack.preview.document_stores.decorator import default_document_store_to_dict, default_document_store_from_dict
-from haystack.preview.document_stores.errors import DocumentStoreDeserializationError
-
-
-@pytest.mark.unit
-def test_default_store_to_dict():
-    MyStore = document_store_class("MyStore")
-    comp = MyStore()
-    res = default_document_store_to_dict(comp)
-    assert res == {"hash": id(comp), "type": "MyStore", "init_parameters": {}}
-
-
-@pytest.mark.unit
-def test_default_store_to_dict_with_custom_init_parameters():
-    extra_fields = {"init_parameters": {"custom_param": True}}
-    MyStore = document_store_class("MyStore", extra_fields=extra_fields)
-    comp = MyStore()
-    res = default_document_store_to_dict(comp)
-    assert res == {"hash": id(comp), "type": "MyStore", "init_parameters": {"custom_param": True}}
-
-
-@pytest.mark.unit
-def test_default_store_from_dict():
-    MyStore = document_store_class("MyStore")
-    comp = default_document_store_from_dict(MyStore, {"type": "MyStore"})
-    assert isinstance(comp, MyStore)
-
-
-@pytest.mark.unit
-def test_default_store_from_dict_with_custom_init_parameters():
-    def store_init(self, custom_param: int):
-        self.custom_param = custom_param
-
-    extra_fields = {"__init__": store_init}
-    MyStore = document_store_class("MyStore", extra_fields=extra_fields)
-    comp = default_document_store_from_dict(MyStore, {"type": "MyStore", "init_parameters": {"custom_param": 100}})
-    assert isinstance(comp, MyStore)
-    assert comp.custom_param == 100
-
-
-@pytest.mark.unit
-def test_default_store_from_dict_without_type():
-    with pytest.raises(DocumentStoreDeserializationError, match="Missing 'type' in DocumentStore serialization data"):
-        default_document_store_from_dict(Mock, {})
-
-
-@pytest.mark.unit
-def test_default_store_from_dict_unregistered_store(request):
-    # We use the test function name as store name to make sure it's not registered.
-    # Since the registry is global we risk to have a store with the same name registered in another test.
-    store_name = request.node.name
-
-    with pytest.raises(
-        DocumentStoreDeserializationError, match=f"DocumentStore '{store_name}' can't be deserialized as 'Mock'"
-    ):
-        default_document_store_from_dict(Mock, {"type": store_name})
diff --git a/test/preview/document_stores/test_memory.py b/test/preview/document_stores/test_memory.py
index d6bb85b82b..75adf972d7 100644
--- a/test/preview/document_stores/test_memory.py
+++ b/test/preview/document_stores/test_memory.py
@@ -24,7 +24,6 @@ def test_to_dict(self):
         store = MemoryDocumentStore()
         data = store.to_dict()
         assert data == {
-            "hash": id(store),
             "type": "MemoryDocumentStore",
             "init_parameters": {
                 "bm25_tokenization_regex": r"(?u)\b\w\w+\b",
@@ -40,7 +39,6 @@ def test_to_dict_with_custom_init_parameters(self):
         )
         data = store.to_dict()
         assert data == {
-            "hash": id(store),
             "type": "MemoryDocumentStore",
             "init_parameters": {
                 "bm25_tokenization_regex": "custom_regex",
diff --git a/test/preview/testing/test_factory.py b/test/preview/testing/test_factory.py
index 6cea7bd444..80ccf33031 100644
--- a/test/preview/testing/test_factory.py
+++ b/test/preview/testing/test_factory.py
@@ -13,6 +13,15 @@ def test_document_store_class_default():
     assert store.filter_documents() == []
     assert store.write_documents([]) is None
     assert store.delete_documents([]) is None
+    assert store.to_dict() == {"type": "MyStore", "init_parameters": {}}
+
+
+@pytest.mark.unit
+def test_document_store_from_dict():
+    MyStore = document_store_class("MyStore")
+
+    store = MyStore.from_dict({"type": "MyStore", "init_parameters": {}})
+    assert isinstance(store, MyStore)
 
 
 @pytest.mark.unit