from typing import List, Optional, Union

from haystack.preview import component
from haystack.preview import Document
from haystack.preview.embedding_backends.sentence_transformers_backend import (
    _SentenceTransformersEmbeddingBackendFactory,
)


@component
class SentenceTransformersDocumentEmbedder:
    """
    A component for computing Document embeddings using Sentence Transformers models.
    The embedding of each Document is stored in the `embedding` field of the Document.
    """

    def __init__(
        self,
        model_name_or_path: str = "sentence-transformers/all-mpnet-base-v2",
        device: Optional[str] = None,
        use_auth_token: Union[bool, str, None] = None,
        batch_size: int = 32,
        progress_bar: bool = True,
        normalize_embeddings: bool = False,
        metadata_fields_to_embed: Optional[List[str]] = None,
        embedding_separator: str = "\n",
    ):
        """
        Create a SentenceTransformersDocumentEmbedder component.

        :param model_name_or_path: Local path or name of the model in Hugging Face's model hub,
            such as ``'sentence-transformers/all-mpnet-base-v2'``.
        :param device: Device (like 'cuda' / 'cpu') that should be used for computation.
            If None, checks if a GPU can be used.
        :param use_auth_token: The API token used to download private models from Hugging Face.
            If this parameter is set to `True`, then the token generated when running
            `transformers-cli login` (stored in ~/.huggingface) will be used.
        :param batch_size: Number of strings to encode at once.
        :param progress_bar: If true, displays progress bar during embedding.
        :param normalize_embeddings: If set to true, returned vectors will have length 1.
        :param metadata_fields_to_embed: List of meta fields that should be embedded along with the Document content.
        :param embedding_separator: Separator used to concatenate the meta fields to the Document content.
        """
        self.model_name_or_path = model_name_or_path
        # TODO: remove device parameter and use Haystack's device management once migrated
        self.device = device
        self.use_auth_token = use_auth_token
        self.batch_size = batch_size
        self.progress_bar = progress_bar
        self.normalize_embeddings = normalize_embeddings
        self.metadata_fields_to_embed = metadata_fields_to_embed or []
        self.embedding_separator = embedding_separator

    def warm_up(self):
        """
        Load the embedding backend.

        Idempotent: calling it more than once does not reload the model.
        """
        if not hasattr(self, "embedding_backend"):
            self.embedding_backend = _SentenceTransformersEmbeddingBackendFactory.get_embedding_backend(
                model_name_or_path=self.model_name_or_path, device=self.device, use_auth_token=self.use_auth_token
            )

    @component.output_types(documents=List[Document])
    def run(self, documents: List[Document]):
        """
        Embed a list of Documents.
        The embedding of each Document is stored in the `embedding` field of the Document.

        :param documents: Documents to embed. An empty list is valid input and yields an empty output list.
        :raises TypeError: If `documents` is not a list of Document objects.
        :raises RuntimeError: If `warm_up()` was not called before `run()`.
        """
        # Only the first element is type-checked (O(1) validation). The `documents and`
        # guard prevents an IndexError on an empty list, which is legitimate input.
        if not isinstance(documents, list) or (documents and not isinstance(documents[0], Document)):
            raise TypeError(
                "SentenceTransformersDocumentEmbedder expects a list of Documents as input."
                "In case you want to embed a list of strings, please use the SentenceTransformersTextEmbedder."
            )
        if not hasattr(self, "embedding_backend"):
            raise RuntimeError("The embedding model has not been loaded. Please call warm_up() before running.")

        # TODO: once non textual Documents are properly supported, we should also prepare them for embedding here

        # Prepend selected (non-empty) meta values to the content, joined by the separator,
        # so they contribute to the Document's embedding.
        texts_to_embed = []
        for doc in documents:
            meta_values_to_embed = [
                str(doc.metadata[key])
                for key in self.metadata_fields_to_embed
                if key in doc.metadata and doc.metadata[key]
            ]
            text_to_embed = self.embedding_separator.join(meta_values_to_embed + [doc.content])
            texts_to_embed.append(text_to_embed)

        embeddings = self.embedding_backend.embed(
            texts_to_embed,
            batch_size=self.batch_size,
            show_progress_bar=self.progress_bar,
            normalize_embeddings=self.normalize_embeddings,
        )

        # Documents are immutable here: rebuild each one via its dict form with the embedding attached.
        documents_with_embeddings = []
        for doc, emb in zip(documents, embeddings):
            doc_as_dict = doc.to_dict()
            doc_as_dict["embedding"] = emb
            documents_with_embeddings.append(Document.from_dict(doc_as_dict))

        return {"documents": documents_with_embeddings}
from unittest.mock import MagicMock, patch

import numpy as np
import pytest

from haystack.preview import Document
from haystack.preview.components.embedders.sentence_transformers_document_embedder import (
    SentenceTransformersDocumentEmbedder,
)


class TestSentenceTransformersDocumentEmbedder:
    @pytest.mark.unit
    def test_init_default(self):
        """All constructor defaults are applied when only the model name is given."""
        doc_embedder = SentenceTransformersDocumentEmbedder(model_name_or_path="model")

        assert doc_embedder.model_name_or_path == "model"
        assert doc_embedder.device is None
        assert doc_embedder.use_auth_token is None
        assert doc_embedder.batch_size == 32
        assert doc_embedder.progress_bar is True
        assert doc_embedder.normalize_embeddings is False

    @pytest.mark.unit
    def test_init_with_parameters(self):
        """Explicit constructor arguments override every default."""
        doc_embedder = SentenceTransformersDocumentEmbedder(
            model_name_or_path="model",
            device="cpu",
            use_auth_token=True,
            batch_size=64,
            progress_bar=False,
            normalize_embeddings=True,
        )

        assert doc_embedder.model_name_or_path == "model"
        assert doc_embedder.device == "cpu"
        assert doc_embedder.use_auth_token is True
        assert doc_embedder.batch_size == 64
        assert doc_embedder.progress_bar is False
        assert doc_embedder.normalize_embeddings is True

    @pytest.mark.unit
    @patch(
        "haystack.preview.components.embedders.sentence_transformers_document_embedder._SentenceTransformersEmbeddingBackendFactory"
    )
    def test_warmup(self, mocked_factory):
        """The backend is created lazily, only when warm_up() is called."""
        doc_embedder = SentenceTransformersDocumentEmbedder(model_name_or_path="model")
        mocked_factory.get_embedding_backend.assert_not_called()

        doc_embedder.warm_up()

        mocked_factory.get_embedding_backend.assert_called_once_with(
            model_name_or_path="model", device=None, use_auth_token=None
        )

    @pytest.mark.unit
    @patch(
        "haystack.preview.components.embedders.sentence_transformers_document_embedder._SentenceTransformersEmbeddingBackendFactory"
    )
    def test_warmup_doesnt_reload(self, mocked_factory):
        """Repeated warm_up() calls must not reload the model."""
        doc_embedder = SentenceTransformersDocumentEmbedder(model_name_or_path="model")
        mocked_factory.get_embedding_backend.assert_not_called()

        doc_embedder.warm_up()
        doc_embedder.warm_up()

        mocked_factory.get_embedding_backend.assert_called_once()

    @pytest.mark.unit
    def test_run(self):
        """run() returns one embedded Document per input Document."""
        doc_embedder = SentenceTransformersDocumentEmbedder(model_name_or_path="model")
        doc_embedder.embedding_backend = MagicMock()

        def fake_embed(texts, **kwargs):
            return np.random.rand(len(texts), 16).tolist()

        doc_embedder.embedding_backend.embed = fake_embed

        docs = [Document(content=f"document number {i}") for i in range(5)]

        result = doc_embedder.run(documents=docs)

        assert isinstance(result["documents"], list)
        assert len(result["documents"]) == len(docs)
        for embedded_doc in result["documents"]:
            assert isinstance(embedded_doc, Document)
            assert isinstance(embedded_doc.embedding, list)
            assert isinstance(embedded_doc.embedding[0], float)

    @pytest.mark.unit
    def test_run_wrong_input_format(self):
        """A string or a list of non-Documents must be rejected with TypeError."""
        doc_embedder = SentenceTransformersDocumentEmbedder(model_name_or_path="model")

        with pytest.raises(
            TypeError, match="SentenceTransformersDocumentEmbedder expects a list of Documents as input"
        ):
            doc_embedder.run(documents="text")

        with pytest.raises(
            TypeError, match="SentenceTransformersDocumentEmbedder expects a list of Documents as input"
        ):
            doc_embedder.run(documents=[1, 2, 3])

    @pytest.mark.unit
    def test_embed_metadata(self):
        """Selected meta fields are prepended to the content, joined by the separator."""
        doc_embedder = SentenceTransformersDocumentEmbedder(
            model_name_or_path="model", metadata_fields_to_embed=["meta_field"], embedding_separator="\n"
        )
        doc_embedder.embedding_backend = MagicMock()

        docs = [
            Document(content=f"document number {i}", metadata={"meta_field": f"meta_value {i}"}) for i in range(5)
        ]

        doc_embedder.run(documents=docs)

        expected_texts = [f"meta_value {i}\ndocument number {i}" for i in range(5)]
        doc_embedder.embedding_backend.embed.assert_called_once_with(
            expected_texts,
            batch_size=32,
            show_progress_bar=True,
            normalize_embeddings=False,
        )