diff --git a/haystack/preview/components/embedders/sentence_transformers_document_embedder.py b/haystack/preview/components/embedders/sentence_transformers_document_embedder.py
index 3a2685fac7..07b5368f66 100644
--- a/haystack/preview/components/embedders/sentence_transformers_document_embedder.py
+++ b/haystack/preview/components/embedders/sentence_transformers_document_embedder.py
@@ -1,7 +1,6 @@
 from typing import List, Optional, Union, Dict, Any
 
-from haystack.preview import component
-from haystack.preview import Document
+from haystack.preview import component, Document, default_to_dict, default_from_dict
 from haystack.preview.embedding_backends.sentence_transformers_backend import (
     _SentenceTransformersEmbeddingBackendFactory,
 )
@@ -42,7 +41,7 @@ def __init__(
 
         self.model_name_or_path = model_name_or_path
         # TODO: remove device parameter and use Haystack's device management once migrated
-        self.device = device
+        self.device = device or "cpu"
         self.use_auth_token = use_auth_token
         self.batch_size = batch_size
         self.progress_bar = progress_bar
@@ -54,14 +53,24 @@ def to_dict(self) -> Dict[str, Any]:
         """
         Serialize this component to a dictionary.
         """
-        # return default_to_dict(self, ...)
+        return default_to_dict(
+            self,
+            model_name_or_path=self.model_name_or_path,
+            device=self.device,
+            use_auth_token=self.use_auth_token,
+            batch_size=self.batch_size,
+            progress_bar=self.progress_bar,
+            normalize_embeddings=self.normalize_embeddings,
+            metadata_fields_to_embed=self.metadata_fields_to_embed,
+            embedding_separator=self.embedding_separator,
+        )
 
     @classmethod
     def from_dict(cls, data: Dict[str, Any]) -> "SentenceTransformersDocumentEmbedder":
         """
         Deserialize this component from a dictionary.
         """
-        # return default_from_dict(cls, data)
+        return default_from_dict(cls, data)
 
     def warm_up(self):
         """
diff --git a/test/preview/components/embedders/test_sentence_transformers_document_embedder.py b/test/preview/components/embedders/test_sentence_transformers_document_embedder.py
index 4bde42025d..c18ec928d4 100644
--- a/test/preview/components/embedders/test_sentence_transformers_document_embedder.py
+++ b/test/preview/components/embedders/test_sentence_transformers_document_embedder.py
@@ -18,6 +18,8 @@ def test_init_default(self):
         assert embedder.batch_size == 32
         assert embedder.progress_bar is True
         assert embedder.normalize_embeddings is False
+        assert embedder.metadata_fields_to_embed == []
+        assert embedder.embedding_separator == "\n"
 
     @pytest.mark.unit
     def test_init_with_parameters(self):
@@ -28,6 +30,8 @@ def test_init_with_parameters(self):
             batch_size=64,
             progress_bar=False,
             normalize_embeddings=True,
+            metadata_fields_to_embed=["test_field"],
+            embedding_separator=" | ",
         )
         assert embedder.model_name_or_path == "model"
         assert embedder.device == "cpu"
@@ -35,6 +39,78 @@ def test_init_with_parameters(self):
         assert embedder.batch_size == 64
         assert embedder.progress_bar is False
         assert embedder.normalize_embeddings is True
+        assert embedder.metadata_fields_to_embed == ["test_field"]
+        assert embedder.embedding_separator == " | "
+
+    @pytest.mark.unit
+    def test_to_dict(self):
+        component = SentenceTransformersDocumentEmbedder(model_name_or_path="model")
+        data = component.to_dict()
+        assert data == {
+            "type": "SentenceTransformersDocumentEmbedder",
+            "init_parameters": {
+                "model_name_or_path": "model",
+                "device": "cpu",
+                "use_auth_token": None,
+                "batch_size": 32,
+                "progress_bar": True,
+                "normalize_embeddings": False,
+                "embedding_separator": "\n",
+                "metadata_fields_to_embed": [],
+            },
+        }
+
+    @pytest.mark.unit
+    def test_to_dict_with_custom_init_parameters(self):
+        component = SentenceTransformersDocumentEmbedder(
+            model_name_or_path="model",
+            device="cpu",
+            use_auth_token="the-token",
+            batch_size=64,
+            progress_bar=False,
+            normalize_embeddings=True,
+            metadata_fields_to_embed=["meta_field"],
+            embedding_separator=" - ",
+        )
+        data = component.to_dict()
+        assert data == {
+            "type": "SentenceTransformersDocumentEmbedder",
+            "init_parameters": {
+                "model_name_or_path": "model",
+                "device": "cpu",
+                "use_auth_token": "the-token",
+                "batch_size": 64,
+                "progress_bar": False,
+                "normalize_embeddings": True,
+                "embedding_separator": " - ",
+                "metadata_fields_to_embed": ["meta_field"],
+            },
+        }
+
+    @pytest.mark.unit
+    def test_from_dict(self):
+        data = {
+            "type": "SentenceTransformersDocumentEmbedder",
+            "init_parameters": {
+                "model_name_or_path": "model",
+                "device": "cpu",
+                "use_auth_token": "the-token",
+                "batch_size": 64,
+                "progress_bar": False,
+                "normalize_embeddings": False,
+                "embedding_separator": " - ",
+                "metadata_fields_to_embed": ["meta_field"],
+            },
+        }
+        component = SentenceTransformersDocumentEmbedder.from_dict(data)
+        assert component.model_name_or_path == "model"
+        assert component.device == "cpu"
+        assert component.use_auth_token == "the-token"
+        assert component.batch_size == 64
+        assert component.progress_bar is False
+        assert component.normalize_embeddings is False
+        assert component.metadata_fields_to_embed == ["meta_field"]
+        assert component.embedding_separator == " - "
 
     @pytest.mark.unit
     @patch(