From a12d2476873fb45ae8f2b4f2e6772d6e2e34fa83 Mon Sep 17 00:00:00 2001 From: anakin87 Date: Mon, 14 Aug 2023 18:04:39 +0200 Subject: [PATCH 01/21] first draft --- .../preview/embedding_backends/__init__.py | 0 .../sentence_transformers_backend.py | 40 +++++++++++++++++++ 2 files changed, 40 insertions(+) create mode 100644 haystack/preview/embedding_backends/__init__.py create mode 100644 haystack/preview/embedding_backends/sentence_transformers_backend.py diff --git a/haystack/preview/embedding_backends/__init__.py b/haystack/preview/embedding_backends/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/haystack/preview/embedding_backends/sentence_transformers_backend.py b/haystack/preview/embedding_backends/sentence_transformers_backend.py new file mode 100644 index 0000000000..ef272f7512 --- /dev/null +++ b/haystack/preview/embedding_backends/sentence_transformers_backend.py @@ -0,0 +1,40 @@ +from typing import List, Optional, Union, Dict +import hashlib +import numpy as np + +from haystack.lazy_imports import LazyImport + +with LazyImport(message="Run 'pip install farm-haystack[inference]'") as sentence_transformers_import: + from sentence_transformers import SentenceTransformer + + +class SentenceTransformersEmbeddingBackend: + """ + Singleton class to manage SentenceTransformers embeddings. + """ + + _instances: Dict[str, "SentenceTransformersEmbeddingBackend"] = {} + + def __new__(cls, *args, **kwargs): + args_kwargs_str = str(args) + str(kwargs) + instance_id = hashlib.md5(args_kwargs_str.encode()) + + if instance_id in cls._instances: + return cls._instances[instance_id] + + instance = super().__new__(cls) + cls._instances[instance_id] = instance + return instance + + def __init__( + self, model_name_or_path: str, device: Optional[str] = None, use_auth_token: Union[bool, str, None] = None + ): + sentence_transformers_import.check() + if not hasattr(self, "model"): + self.model = SentenceTransformer( + model_name_or_path=model_name_or_path, device=device, use_auth_token=use_auth_token + ) + + def embed(self, data: List[str], *inference_params) -> np.ndarray: + embedding = self.model.encode(data, *inference_params) + return embedding From 0ba2f52a2d12728a0911ee8888bd4d2ebf0a2864 Mon Sep 17 00:00:00 2001 From: anakin87 Date: Thu, 17 Aug 2023 18:13:05 +0200 Subject: [PATCH 02/21] incorporate feedback --- .../sentence_transformers_backend.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/haystack/preview/embedding_backends/sentence_transformers_backend.py b/haystack/preview/embedding_backends/sentence_transformers_backend.py index ef272f7512..16b7cbb58d 100644 --- a/haystack/preview/embedding_backends/sentence_transformers_backend.py +++ b/haystack/preview/embedding_backends/sentence_transformers_backend.py @@ -2,23 +2,24 @@ import hashlib import numpy as np -from haystack.lazy_imports import LazyImport +from haystack.preview.lazy_imports import LazyImport with LazyImport(message="Run 'pip install farm-haystack[inference]'") as sentence_transformers_import: from sentence_transformers import SentenceTransformer -class SentenceTransformersEmbeddingBackend: +class _SentenceTransformersEmbeddingBackend: """ Singleton class to manage SentenceTransformers embeddings. """ - _instances: Dict[str, "SentenceTransformersEmbeddingBackend"] = {} - - def __new__(cls, *args, **kwargs): - args_kwargs_str = str(args) + str(kwargs) - instance_id = hashlib.md5(args_kwargs_str.encode()) + _instances: Dict[str, "_SentenceTransformersEmbeddingBackend"] = {} + def __new__( + cls, model_name_or_path: str, device: Optional[str] = None, use_auth_token: Union[bool, str, None] = None + ): + args_str = f"{model_name_or_path}{device}{use_auth_token}" + instance_id = hashlib.md5(args_str.encode()).hexdigest() if instance_id in cls._instances: return cls._instances[instance_id] @@ -35,6 +36,6 @@ def __init__( model_name_or_path=model_name_or_path, device=device, use_auth_token=use_auth_token ) - def embed(self, data: List[str], *inference_params) -> np.ndarray: - embedding = self.model.encode(data, *inference_params) + def embed(self, data: List[str], **kwargs) -> np.ndarray: + embedding = self.model.encode(data, **kwargs) return embedding From 83996d38ea01999805053950fe57c959f7df5557 Mon Sep 17 00:00:00 2001 From: anakin87 Date: Fri, 18 Aug 2023 13:13:09 +0200 Subject: [PATCH 03/21] some unit tests --- .../test_sentence_transformers.py | 36 +++++++++++++++++++ 1 file changed, 36 insertions(+) create mode 100644 test/preview/embedding_backends/test_sentence_transformers.py diff --git a/test/preview/embedding_backends/test_sentence_transformers.py b/test/preview/embedding_backends/test_sentence_transformers.py new file mode 100644 index 0000000000..5f2ddc6b16 --- /dev/null +++ b/test/preview/embedding_backends/test_sentence_transformers.py @@ -0,0 +1,36 @@ +from unittest.mock import Mock, patch +import pytest +from haystack.preview.embedding_backends.sentence_transformers_backend import _SentenceTransformersEmbeddingBackend +import numpy as np + + +@pytest.mark.unit +@patch("haystack.preview.embedding_backends.sentence_transformers_backend.SentenceTransformer") +def test_singleton_behavior(mock_sentence_transformer): + embedding_backend = _SentenceTransformersEmbeddingBackend(model_name_or_path="my_model", device="cpu") + same_embedding_backend = _SentenceTransformersEmbeddingBackend("my_model", "cpu") + another_embedding_backend = _SentenceTransformersEmbeddingBackend(model_name_or_path="another_model", device="cpu") + + assert same_embedding_backend is embedding_backend + assert another_embedding_backend is not embedding_backend + + +@pytest.mark.unit +@patch("haystack.preview.embedding_backends.sentence_transformers_backend.SentenceTransformer") +def test_model_initialization(mock_sentence_transformer): + _SentenceTransformersEmbeddingBackend(model_name_or_path="model", device="cpu") + mock_sentence_transformer.assert_called_once_with(model_name_or_path="model", device="cpu", use_auth_token=None) + + +@pytest.mark.unit +@patch("haystack.preview.embedding_backends.sentence_transformers_backend.SentenceTransformer") +def test_embedding_function_with_kwargs(mock_sentence_transformer): + embedding_backend = _SentenceTransformersEmbeddingBackend(model_name_or_path="model") + fake_embeddings = np.array([[0.1, 0.2], [0.3, 0.4]]) + embedding_backend.model.encode.return_value = fake_embeddings + + data = ["sentence1", "sentence2"] + result = embedding_backend.embed(data=data, normalize_embeddings=True) + + embedding_backend.model.encode.assert_called_once_with(data, normalize_embeddings=True) + np.testing.assert_array_equal(result, fake_embeddings) From 0a80333d18d3625feaa7bdb63fffc82b7759f4cf Mon Sep 17 00:00:00 2001 From: anakin87 Date: Fri, 18 Aug 2023 13:22:50 +0200 Subject: [PATCH 04/21] release notes --- ...rs-embedding-backend-69bd9410ede08c8f.yaml | 33 +++++++++++++++++++ 1 file changed, 33 insertions(+) create mode 100644 releasenotes/notes/sentence-transformers-embedding-backend-69bd9410ede08c8f.yaml diff --git a/releasenotes/notes/sentence-transformers-embedding-backend-69bd9410ede08c8f.yaml b/releasenotes/notes/sentence-transformers-embedding-backend-69bd9410ede08c8f.yaml new file mode 100644 index 0000000000..649daf2a04 --- /dev/null +++ b/releasenotes/notes/sentence-transformers-embedding-backend-69bd9410ede08c8f.yaml @@ -0,0 +1,33 @@ +--- +prelude: > + Replace this text with content to appear at the top of the section for this + release. This is equivalent to the "Highlights" section we used before. + The prelude might repeat some details that are also present in other notes + from the same release, that's ok. Not every release note requires a prelude, + use it only to describe major features or notable changes. +upgrade: + - | + List upgrade notes here, or remove this section. + Upgrade notes should be rare: only list known/potential breaking changes, + or major changes that require user action before the upgrade. + Notes here must include steps that users can follow to 1. know if they're + affected and 2. handle the change gracefully on their end. +features: + - | + List new features here, or remove this section. +enhancements: + - | + List new behavior that is too small to be + considered a new feature, or remove this section. +issues: + - | + List known issues here, or remove this section. For example, if some change is experimental or known to not work in some cases, it should be mentioned here. +deprecations: + - | + List deprecations notes here, or remove this section. Deprecations should not be used for something that is removed in the release, use upgrade section instead. Deprecation should allow time for users to make necessary changes for the removal to happen in a future release. +security: + - | + Add security notes here, or remove this section. +fixes: + - | + Add normal bug fixes here, or remove this section. From d76e91c54cb3f834c1969f7ffa3fd1c0132d8199 Mon Sep 17 00:00:00 2001 From: anakin87 Date: Fri, 18 Aug 2023 13:24:06 +0200 Subject: [PATCH 05/21] real release notes --- ...rs-embedding-backend-69bd9410ede08c8f.yaml | 34 ++----------------- 1 file changed, 3 insertions(+), 31 deletions(-) diff --git a/releasenotes/notes/sentence-transformers-embedding-backend-69bd9410ede08c8f.yaml b/releasenotes/notes/sentence-transformers-embedding-backend-69bd9410ede08c8f.yaml index 649daf2a04..b5e647fa7e 100644 --- a/releasenotes/notes/sentence-transformers-embedding-backend-69bd9410ede08c8f.yaml +++ b/releasenotes/notes/sentence-transformers-embedding-backend-69bd9410ede08c8f.yaml @@ -1,33 +1,5 @@ --- -prelude: > - Replace this text with content to appear at the top of the section for this - release. This is equivalent to the "Highlights" section we used before. - The prelude might repeat some details that are also present in other notes - from the same release, that's ok. Not every release note requires a prelude, - use it only to describe major features or notable changes. -upgrade: +preview: - | - List upgrade notes here, or remove this section. - Upgrade notes should be rare: only list known/potential breaking changes, - or major changes that require user action before the upgrade. - Notes here must include steps that users can follow to 1. know if they're - affected and 2. handle the change gracefully on their end. -features: - - | - List new features here, or remove this section. -enhancements: - - | - List new behavior that is too small to be - considered a new feature, or remove this section. -issues: - - | - List known issues here, or remove this section. For example, if some change is experimental or known to not work in some cases, it should be mentioned here. -deprecations: - - | - List deprecations notes here, or remove this section. Deprecations should not be used for something that is removed in the release, use upgrade section instead. Deprecation should allow time for users to make necessary changes for the removal to happen in a future release. -security: - - | - Add security notes here, or remove this section. -fixes: - - | - Add normal bug fixes here, or remove this section. + Add Sentence Transformers Embedding Backend. + It will be used by Embedder components and is responsible for computing embeddings. From dc302bc13d4cfbaa69a9ff140b631d1b85834920 Mon Sep 17 00:00:00 2001 From: anakin87 Date: Mon, 21 Aug 2023 10:51:06 +0200 Subject: [PATCH 06/21] refactored to use a factory class --- .../sentence_transformers_backend.py | 43 +++++++++++-------- .../test_sentence_transformers.py | 20 ++++++--- 2 files changed, 39 insertions(+), 24 deletions(-) diff --git a/haystack/preview/embedding_backends/sentence_transformers_backend.py b/haystack/preview/embedding_backends/sentence_transformers_backend.py index 16b7cbb58d..389ac4c102 100644 --- a/haystack/preview/embedding_backends/sentence_transformers_backend.py +++ b/haystack/preview/embedding_backends/sentence_transformers_backend.py @@ -8,34 +8,43 @@ from sentence_transformers import SentenceTransformer -class _SentenceTransformersEmbeddingBackend: +class SentenceTransformersEmbeddingBackendFactory: """ - Singleton class to manage SentenceTransformers embeddings. + Factory class to create instances of Sentence Transformers embedding backends. """ _instances: Dict[str, "_SentenceTransformersEmbeddingBackend"] = {} - def __new__( - cls, model_name_or_path: str, device: Optional[str] = None, use_auth_token: Union[bool, str, None] = None + @staticmethod + def get_embedding_backend( + model_name_or_path: str, device: Optional[str] = None, use_auth_token: Union[bool, str, None] = None ): - args_str = f"{model_name_or_path}{device}{use_auth_token}" - instance_id = hashlib.md5(args_str.encode()).hexdigest() - if instance_id in cls._instances: - return cls._instances[instance_id] + args_string = f"{model_name_or_path}{device}{use_auth_token}" + embedding_backend_id = hashlib.md5(args_string.encode()).hexdigest() + + if embedding_backend_id in SentenceTransformersEmbeddingBackendFactory._instances: + return SentenceTransformersEmbeddingBackendFactory._instances[embedding_backend_id] + + embedding_backend = _SentenceTransformersEmbeddingBackend( + model_name_or_path=model_name_or_path, device=device, use_auth_token=use_auth_token + ) + SentenceTransformersEmbeddingBackendFactory._instances[embedding_backend_id] = embedding_backend + return embedding_backend + - instance = super().__new__(cls) - cls._instances[instance_id] = instance - return instance +class _SentenceTransformersEmbeddingBackend: + """ + Class to manage SentenceTransformers embeddings. + """ def __init__( self, model_name_or_path: str, device: Optional[str] = None, use_auth_token: Union[bool, str, None] = None ): sentence_transformers_import.check() - if not hasattr(self, "model"): - self.model = SentenceTransformer( - model_name_or_path=model_name_or_path, device=device, use_auth_token=use_auth_token - ) + self.model = SentenceTransformer( + model_name_or_path=model_name_or_path, device=device, use_auth_token=use_auth_token + ) - def embed(self, data: List[str], **kwargs) -> np.ndarray: + def embed(self, data: List[str], **kwargs) -> List[np.ndarray]: embedding = self.model.encode(data, **kwargs) - return embedding + return list(embedding) diff --git a/test/preview/embedding_backends/test_sentence_transformers.py b/test/preview/embedding_backends/test_sentence_transformers.py index 5f2ddc6b16..385e61622f 100644 --- a/test/preview/embedding_backends/test_sentence_transformers.py +++ b/test/preview/embedding_backends/test_sentence_transformers.py @@ -1,15 +1,21 @@ from unittest.mock import Mock, patch import pytest -from haystack.preview.embedding_backends.sentence_transformers_backend import _SentenceTransformersEmbeddingBackend +from haystack.preview.embedding_backends.sentence_transformers_backend import ( + SentenceTransformersEmbeddingBackendFactory, +) import numpy as np @pytest.mark.unit @patch("haystack.preview.embedding_backends.sentence_transformers_backend.SentenceTransformer") def test_singleton_behavior(mock_sentence_transformer): - embedding_backend = _SentenceTransformersEmbeddingBackend(model_name_or_path="my_model", device="cpu") - same_embedding_backend = _SentenceTransformersEmbeddingBackend("my_model", "cpu") - another_embedding_backend = _SentenceTransformersEmbeddingBackend(model_name_or_path="another_model", device="cpu") + embedding_backend = SentenceTransformersEmbeddingBackendFactory.get_embedding_backend( + model_name_or_path="my_model", device="cpu" + ) + same_embedding_backend = SentenceTransformersEmbeddingBackendFactory.get_embedding_backend("my_model", "cpu") + another_embedding_backend = SentenceTransformersEmbeddingBackendFactory.get_embedding_backend( + model_name_or_path="another_model", device="cpu" + ) assert same_embedding_backend is embedding_backend assert another_embedding_backend is not embedding_backend @@ -18,15 +24,15 @@ def test_singleton_behavior(mock_sentence_transformer): @pytest.mark.unit @patch("haystack.preview.embedding_backends.sentence_transformers_backend.SentenceTransformer") def test_model_initialization(mock_sentence_transformer): - _SentenceTransformersEmbeddingBackend(model_name_or_path="model", device="cpu") + SentenceTransformersEmbeddingBackendFactory.get_embedding_backend(model_name_or_path="model", device="cpu") mock_sentence_transformer.assert_called_once_with(model_name_or_path="model", device="cpu", use_auth_token=None) @pytest.mark.unit @patch("haystack.preview.embedding_backends.sentence_transformers_backend.SentenceTransformer") def test_embedding_function_with_kwargs(mock_sentence_transformer): - embedding_backend = _SentenceTransformersEmbeddingBackend(model_name_or_path="model") - fake_embeddings = np.array([[0.1, 0.2], [0.3, 0.4]]) + embedding_backend = SentenceTransformersEmbeddingBackendFactory.get_embedding_backend(model_name_or_path="model") + fake_embeddings = [np.array([0.1, 0.2]), np.array([0.3, 0.4])] embedding_backend.model.encode.return_value = fake_embeddings data = ["sentence1", "sentence2"] From 021813f4dcdf791026469153b2a94394b9dcca2d Mon Sep 17 00:00:00 2001 From: anakin87 Date: Mon, 21 Aug 2023 12:02:49 +0200 Subject: [PATCH 07/21] allow forcing fresh instances --- .../sentence_transformers_backend.py | 11 +++++++++-- .../test_sentence_transformers.py | 15 ++++++++++++++- 2 files changed, 23 insertions(+), 3 deletions(-) diff --git a/haystack/preview/embedding_backends/sentence_transformers_backend.py b/haystack/preview/embedding_backends/sentence_transformers_backend.py index 389ac4c102..db82beec7a 100644 --- a/haystack/preview/embedding_backends/sentence_transformers_backend.py +++ b/haystack/preview/embedding_backends/sentence_transformers_backend.py @@ -17,14 +17,21 @@ class SentenceTransformersEmbeddingBackendFactory: @staticmethod def get_embedding_backend( - model_name_or_path: str, device: Optional[str] = None, use_auth_token: Union[bool, str, None] = None + model_name_or_path: str, + device: Optional[str] = None, + use_auth_token: Union[bool, str, None] = None, + force_fresh_instance: bool = False, ): + if force_fresh_instance is True: + return _SentenceTransformersEmbeddingBackend( + model_name_or_path=model_name_or_path, device=device, use_auth_token=use_auth_token + ) + args_string = f"{model_name_or_path}{device}{use_auth_token}" embedding_backend_id = hashlib.md5(args_string.encode()).hexdigest() if embedding_backend_id in SentenceTransformersEmbeddingBackendFactory._instances: return SentenceTransformersEmbeddingBackendFactory._instances[embedding_backend_id] - embedding_backend = _SentenceTransformersEmbeddingBackend( model_name_or_path=model_name_or_path, device=device, use_auth_token=use_auth_token ) diff --git a/test/preview/embedding_backends/test_sentence_transformers.py b/test/preview/embedding_backends/test_sentence_transformers.py index 385e61622f..73b96064b7 100644 --- a/test/preview/embedding_backends/test_sentence_transformers.py +++ b/test/preview/embedding_backends/test_sentence_transformers.py @@ -8,7 +8,7 @@ @pytest.mark.unit @patch("haystack.preview.embedding_backends.sentence_transformers_backend.SentenceTransformer") -def test_singleton_behavior(mock_sentence_transformer): +def test_factory_behavior(mock_sentence_transformer): embedding_backend = SentenceTransformersEmbeddingBackendFactory.get_embedding_backend( model_name_or_path="my_model", device="cpu" ) @@ -21,6 +21,19 @@ def test_singleton_behavior(mock_sentence_transformer): assert another_embedding_backend is not embedding_backend +@pytest.mark.unit +@patch("haystack.preview.embedding_backends.sentence_transformers_backend.SentenceTransformer") +def test_factory_force_fresh_instance(mock_sentence_transformer): + embedding_backend = SentenceTransformersEmbeddingBackendFactory.get_embedding_backend( + model_name_or_path="my_model", device="cpu" + ) + fresh_embedding_backend = SentenceTransformersEmbeddingBackendFactory.get_embedding_backend( + model_name_or_path="my_model", device="cpu", force_fresh_instance=True + ) + + assert fresh_embedding_backend is not embedding_backend + + @pytest.mark.unit @patch("haystack.preview.embedding_backends.sentence_transformers_backend.SentenceTransformer") def test_model_initialization(mock_sentence_transformer): From 6de808102fbb46495cfd2628bbfadc96cf815658 Mon Sep 17 00:00:00 2001 From: anakin87 Date: Mon, 21 Aug 2023 16:07:35 +0200 Subject: [PATCH 08/21] first draft --- ...sentence_transformers_document_embedder.py | 83 +++++++++++++++++++ ...rs-document-embedder-f1e8612b8eaf9b7f.yaml | 5 ++ test/preview/components/embedders/__init__.py | 0 ...sentence_transformers_document_embedder.py | 81 ++++++++++++++++++ 4 files changed, 169 insertions(+) create mode 100644 haystack/preview/components/embedders/sentence_transformers_document_embedder.py create mode 100644 releasenotes/notes/add-sentence-transformers-document-embedder-f1e8612b8eaf9b7f.yaml create mode 100644 test/preview/components/embedders/__init__.py create mode 100644 test/preview/components/embedders/test_sentence_transformers_document_embedder.py diff --git a/haystack/preview/components/embedders/sentence_transformers_document_embedder.py b/haystack/preview/components/embedders/sentence_transformers_document_embedder.py new file mode 100644 index 0000000000..2cfe706af7 --- /dev/null +++ b/haystack/preview/components/embedders/sentence_transformers_document_embedder.py @@ -0,0 +1,83 @@ +from typing import List, Optional, Union + +from haystack.preview import component +from haystack.preview import Document +from haystack.preview.embedding_backends.sentence_transformers_backend import ( + SentenceTransformersEmbeddingBackendFactory, +) + + +@component +class SentenceTransformersDocumentEmbedder: + """ + A component for computing Document embeddings using Sentence Transformers models. + The embedding of each Document is stored in the `embedding` field of the Document. + """ + + def __init__( + self, + model_name_or_path: str, + device: Optional[str] = None, + use_auth_token: Union[bool, str, None] = None, + batch_size: int = 32, + progress_bar: bool = True, + normalize_embeddings: bool = False, + ): + """ + Create a SentenceTransformersDocumentEmbedder component. + + :param model_name_or_path: Local path or name of model in Hugging Face's model hub such as ``'sentence-transformers/all-MiniLM-L6-v2'``. + :param device: Device (like 'cuda' / 'cpu') that should be used for computation. If None, checks if a GPU can be used. + :param use_auth_token: The API token used to download private models from Hugging Face. + If this parameter is set to `True`, then the token generated when running + `transformers-cli login` (stored in ~/.huggingface) will be used. + :param batch_size: Number of strings to encode at once. + :param progress_bar: If true displays progress bar during embedding. + :param normalize_embeddings: If set to true, returned vectors will have length 1. + """ + + self.model_name_or_path = model_name_or_path + # TODO: remove device parameter and use Haystack's device management once migrated + self.device = device + self.use_auth_token = use_auth_token + self.batch_size = batch_size + self.progress_bar = progress_bar + self.normalize_embeddings = normalize_embeddings + + def warm_up(self): + """ + Load the embedding backend. + """ + if not hasattr(self, "embedding_backend"): + self.embedding_backend = SentenceTransformersEmbeddingBackendFactory.get_embedding_backend( + model_name_or_path=self.model_name_or_path, device=self.device, use_auth_token=self.use_auth_token + ) + + @component.output_types(documents=List[Document]) + def run(self, documents: List[Document]): + """ + Embed a list of Documents. + The embedding of each Document is stored in the `embedding` field of the Document. + """ + self.warm_up() + + # TODO: once non textual Documents are properly supported, we should also prepare them for embedding here + + # TODO: we should find a proper strategy for supporting the embedding of meta fields, also supporting templates + # E.g.: This article talks about {{doc.meta["company"]}}, it was published on {{doc.meta["publication_date"]}}. Here is the article's content: {{doc.content}} + texts_to_embed = [doc.content for doc in documents] + + embeddings = self.embedding_backend.embed( + texts_to_embed, + batch_size=self.batch_size, + show_progress_bar=self.progress_bar, + normalize_embeddings=self.normalize_embeddings, + ) + + documents_with_embeddings = [] + for doc, emb in zip(documents, embeddings): + doc_as_dict = doc.to_dict() + doc_as_dict["embedding"] = emb + documents_with_embeddings.append(Document.from_dict(doc_as_dict)) + + return {"documents": documents_with_embeddings} diff --git a/releasenotes/notes/add-sentence-transformers-document-embedder-f1e8612b8eaf9b7f.yaml b/releasenotes/notes/add-sentence-transformers-document-embedder-f1e8612b8eaf9b7f.yaml new file mode 100644 index 0000000000..15689050b0 --- /dev/null +++ b/releasenotes/notes/add-sentence-transformers-document-embedder-f1e8612b8eaf9b7f.yaml @@ -0,0 +1,5 @@ +--- +preview: + - | + Add Sentence Transformers Document Embedder. + It computes embeddings of Documents. The embedding of each Document is stored in the `embedding` field of the Document. diff --git a/test/preview/components/embedders/__init__.py b/test/preview/components/embedders/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/test/preview/components/embedders/test_sentence_transformers_document_embedder.py b/test/preview/components/embedders/test_sentence_transformers_document_embedder.py new file mode 100644 index 0000000000..54a436c6ff --- /dev/null +++ b/test/preview/components/embedders/test_sentence_transformers_document_embedder.py @@ -0,0 +1,81 @@ +from unittest.mock import patch, MagicMock +import pytest + +from haystack.preview import Document +from haystack.preview.components.embedders.sentence_transformers_document_embedder import ( + SentenceTransformersDocumentEmbedder, +) + +from test.preview.components.base import BaseTestComponent + +import numpy as np + + +class TestSentenceTransformersDocumentEmbedder(BaseTestComponent): + # TODO: We're going to rework these tests when we'll remove BaseTestComponent. + + @pytest.mark.unit + def test_init_default(self): + embedder = SentenceTransformersDocumentEmbedder(model_name_or_path="model") + assert embedder.model_name_or_path == "model" + assert embedder.device is None + assert embedder.use_auth_token is None + assert embedder.batch_size == 32 + assert embedder.progress_bar is True + assert embedder.normalize_embeddings is False + + @pytest.mark.unit + def test_init_with_parameters(self): + embedder = SentenceTransformersDocumentEmbedder( + model_name_or_path="model", + device="cpu", + use_auth_token=True, + batch_size=64, + progress_bar=False, + normalize_embeddings=True, + ) + assert embedder.model_name_or_path == "model" + assert embedder.device == "cpu" + assert embedder.use_auth_token is True + assert embedder.batch_size == 64 + assert embedder.progress_bar is False + assert embedder.normalize_embeddings is True + + @pytest.mark.unit + @patch( + "haystack.preview.components.embedders.sentence_transformers_document_embedder.SentenceTransformersEmbeddingBackendFactory" + ) + def test_warmup(self, mocked_factory): + embedder = SentenceTransformersDocumentEmbedder(model_name_or_path="model") + mocked_factory.get_embedding_backend.assert_not_called() + embedder.warm_up() + mocked_factory.get_embedding_backend.assert_called_once_with( + model_name_or_path="model", device=None, use_auth_token=None + ) + + @pytest.mark.unit + @patch( + "haystack.preview.components.embedders.sentence_transformers_document_embedder.SentenceTransformersEmbeddingBackendFactory" + ) + def test_warmup_doesnt_reload(self, mocked_factory): + embedder = SentenceTransformersDocumentEmbedder(model_name_or_path="model") + mocked_factory.get_embedding_backend.assert_not_called() + embedder.warm_up() + embedder.warm_up() + mocked_factory.get_embedding_backend.assert_called_once() + + @pytest.mark.unit + def test_run(self): + embedder = SentenceTransformersDocumentEmbedder(model_name_or_path="model") + embedder.embedding_backend = MagicMock() + embedder.embedding_backend.embed = lambda x, **kwargs: list(np.random.rand(len(x), 16)) + + documents = [Document(content=f"document number {i}") for i in range(5)] + + result = embedder.run(documents=documents) + + assert isinstance(result["documents"], list) + assert len(result["documents"]) == len(documents) + for doc in result["documents"]: + assert isinstance(doc, Document) + assert isinstance(doc.embedding, np.ndarray) From 307089386710fc6e3b417bed0a0d3da457097a8f Mon Sep 17 00:00:00 2001 From: Stefano Fiorucci <44616784+anakin87@users.noreply.github.com> Date: Mon, 21 Aug 2023 22:56:37 +0200 Subject: [PATCH 09/21] Update haystack/preview/embedding_backends/sentence_transformers_backend.py Co-authored-by: Daria Fokina --- .../preview/embedding_backends/sentence_transformers_backend.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/haystack/preview/embedding_backends/sentence_transformers_backend.py b/haystack/preview/embedding_backends/sentence_transformers_backend.py index db82beec7a..8bef5b4e0b 100644 --- a/haystack/preview/embedding_backends/sentence_transformers_backend.py +++ b/haystack/preview/embedding_backends/sentence_transformers_backend.py @@ -41,7 +41,7 @@ def get_embedding_backend( class _SentenceTransformersEmbeddingBackend: """ - Class to manage SentenceTransformers embeddings. + Class to manage Sentence Transformers embeddings. """ def __init__( From bcba5bd80315962297a24d3028633d8b063024c1 Mon Sep 17 00:00:00 2001 From: anakin87 Date: Wed, 23 Aug 2023 11:57:10 +0200 Subject: [PATCH 10/21] simplify implementation and tests --- .../sentence_transformers_backend.py | 14 ++----------- .../test_sentence_transformers.py | 21 ++----------------- 2 files changed, 4 insertions(+), 31 deletions(-) diff --git a/haystack/preview/embedding_backends/sentence_transformers_backend.py b/haystack/preview/embedding_backends/sentence_transformers_backend.py index 8bef5b4e0b..7fd679cd8a 100644 --- a/haystack/preview/embedding_backends/sentence_transformers_backend.py +++ b/haystack/preview/embedding_backends/sentence_transformers_backend.py @@ -1,5 +1,4 @@ from typing import List, Optional, Union, Dict -import hashlib import numpy as np from haystack.preview.lazy_imports import LazyImport @@ -17,18 +16,9 @@ class SentenceTransformersEmbeddingBackendFactory: @staticmethod def get_embedding_backend( - model_name_or_path: str, - device: Optional[str] = None, - use_auth_token: Union[bool, str, None] = None, - force_fresh_instance: bool = False, + model_name_or_path: str, device: Optional[str] = None, use_auth_token: Union[bool, str, None] = None ): - if force_fresh_instance is True: - return _SentenceTransformersEmbeddingBackend( - model_name_or_path=model_name_or_path, device=device, use_auth_token=use_auth_token - ) - - args_string = f"{model_name_or_path}{device}{use_auth_token}" - embedding_backend_id = hashlib.md5(args_string.encode()).hexdigest() + embedding_backend_id = f"{model_name_or_path}{device}{use_auth_token}" if embedding_backend_id in SentenceTransformersEmbeddingBackendFactory._instances: return SentenceTransformersEmbeddingBackendFactory._instances[embedding_backend_id] diff --git a/test/preview/embedding_backends/test_sentence_transformers.py b/test/preview/embedding_backends/test_sentence_transformers.py index 73b96064b7..3a373c127a 100644 --- a/test/preview/embedding_backends/test_sentence_transformers.py +++ b/test/preview/embedding_backends/test_sentence_transformers.py @@ -1,9 +1,8 @@ -from unittest.mock import Mock, patch +from unittest.mock import patch import pytest from haystack.preview.embedding_backends.sentence_transformers_backend import ( SentenceTransformersEmbeddingBackendFactory, ) -import numpy as np @pytest.mark.unit @@ -21,19 +20,6 @@ def test_factory_behavior(mock_sentence_transformer): assert another_embedding_backend is not embedding_backend -@pytest.mark.unit -@patch("haystack.preview.embedding_backends.sentence_transformers_backend.SentenceTransformer") -def test_factory_force_fresh_instance(mock_sentence_transformer): - embedding_backend = SentenceTransformersEmbeddingBackendFactory.get_embedding_backend( - model_name_or_path="my_model", device="cpu" - ) - fresh_embedding_backend = SentenceTransformersEmbeddingBackendFactory.get_embedding_backend( - model_name_or_path="my_model", device="cpu", force_fresh_instance=True - ) - - assert fresh_embedding_backend is not embedding_backend - - @pytest.mark.unit @patch("haystack.preview.embedding_backends.sentence_transformers_backend.SentenceTransformer") def test_model_initialization(mock_sentence_transformer): @@ -45,11 +31,8 @@ def test_model_initialization(mock_sentence_transformer): @patch("haystack.preview.embedding_backends.sentence_transformers_backend.SentenceTransformer") def test_embedding_function_with_kwargs(mock_sentence_transformer): embedding_backend = SentenceTransformersEmbeddingBackendFactory.get_embedding_backend(model_name_or_path="model") - fake_embeddings = [np.array([0.1, 0.2]), np.array([0.3, 0.4])] - embedding_backend.model.encode.return_value = fake_embeddings data = ["sentence1", "sentence2"] - result = embedding_backend.embed(data=data, normalize_embeddings=True) + embedding_backend.embed(data=data, normalize_embeddings=True) embedding_backend.model.encode.assert_called_once_with(data, normalize_embeddings=True) - np.testing.assert_array_equal(result, fake_embeddings) From e2cc862413db2c4d8d0a1062d03b124d3e089833 Mon Sep 17 00:00:00 2001 From: anakin87 Date: Wed, 23 Aug 2023 17:04:12 +0200 Subject: [PATCH 11/21] add embed_meta_fields implementation --- .../sentence_transformers_document_embedder.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/haystack/preview/components/embedders/sentence_transformers_document_embedder.py b/haystack/preview/components/embedders/sentence_transformers_document_embedder.py index 2cfe706af7..b4dc7dd4d3 100644 --- a/haystack/preview/components/embedders/sentence_transformers_document_embedder.py +++ b/haystack/preview/components/embedders/sentence_transformers_document_embedder.py @@ -22,6 +22,8 @@ def __init__( batch_size: int = 32, progress_bar: bool = True, normalize_embeddings: bool = False, + embed_meta_fields: Optional[List[str]] = None, + embed_separator: str = "\n", ): """ Create a SentenceTransformersDocumentEmbedder component. @@ -34,6 +36,8 @@ def __init__( :param batch_size: Number of strings to encode at once. :param progress_bar: If true displays progress bar during embedding. :param normalize_embeddings: If set to true, returned vectors will have length 1. + :param embed_meta_fields: List of meta fields that should be embedded along with the Document content. + :param embed_separator: Separator used to concatenate the meta fields to the Document content. """ self.model_name_or_path = model_name_or_path @@ -43,6 +47,8 @@ def __init__( self.batch_size = batch_size self.progress_bar = progress_bar self.normalize_embeddings = normalize_embeddings + self.embed_meta_fields = embed_meta_fields or [] + self.embed_separator = embed_separator def warm_up(self): """ @@ -63,9 +69,13 @@ def run(self, documents: List[Document]): # TODO: once non textual Documents are properly supported, we should also prepare them for embedding here - # TODO: we should find a proper strategy for supporting the embedding of meta fields, also supporting templates - # E.g.: This article talks about {{doc.meta["company"]}}, it was published on {{doc.meta["publication_date"]}}. Here is the article's content: {{doc.content}} - texts_to_embed = [doc.content for doc in documents] + texts_to_embed = [] + for doc in documents: + meta_values_to_embed = [ + doc.metadata[key] for key in self.embed_meta_fields if key in doc.metadata and doc.metadata[key] + ] + text_to_embed = self.embed_separator.join(meta_values_to_embed + [doc.content]) + texts_to_embed.append(text_to_embed) embeddings = self.embedding_backend.embed( texts_to_embed, From 7107ff8f142d3e71d0bb1246475c67d1a092b2d0 Mon Sep 17 00:00:00 2001 From: Daria Fokina Date: Thu, 24 Aug 2023 11:56:10 +0200 Subject: [PATCH 12/21] lg update --- .../embedders/sentence_transformers_document_embedder.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/haystack/preview/components/embedders/sentence_transformers_document_embedder.py b/haystack/preview/components/embedders/sentence_transformers_document_embedder.py index b4dc7dd4d3..17f2ee8f3c 100644 --- a/haystack/preview/components/embedders/sentence_transformers_document_embedder.py +++ b/haystack/preview/components/embedders/sentence_transformers_document_embedder.py @@ -28,13 +28,13 @@ def __init__( """ Create a SentenceTransformersDocumentEmbedder component. - :param model_name_or_path: Local path or name of model in Hugging Face's model hub such as ``'sentence-transformers/all-MiniLM-L6-v2'``. + :param model_name_or_path: Local path or name of the model in Hugging Face's model hub, such as ``'sentence-transformers/all-MiniLM-L6-v2'``. :param device: Device (like 'cuda' / 'cpu') that should be used for computation. If None, checks if a GPU can be used. :param use_auth_token: The API token used to download private models from Hugging Face. If this parameter is set to `True`, then the token generated when running `transformers-cli login` (stored in ~/.huggingface) will be used. :param batch_size: Number of strings to encode at once. - :param progress_bar: If true displays progress bar during embedding. + :param progress_bar: If true, displays progress bar during embedding. :param normalize_embeddings: If set to true, returned vectors will have length 1. :param embed_meta_fields: List of meta fields that should be embedded along with the Document content. :param embed_separator: Separator used to concatenate the meta fields to the Document content. From 31233b2fb4ad9348f613ab8393a3efba0654fc8b Mon Sep 17 00:00:00 2001 From: anakin87 Date: Fri, 25 Aug 2023 11:34:42 +0200 Subject: [PATCH 13/21] improve meta data embedding; tests --- ...sentence_transformers_document_embedder.py | 21 +++++---- ...sentence_transformers_document_embedder.py | 44 +++++++++++++++++++ 2 files changed, 57 insertions(+), 8 deletions(-) diff --git a/haystack/preview/components/embedders/sentence_transformers_document_embedder.py b/haystack/preview/components/embedders/sentence_transformers_document_embedder.py index 17f2ee8f3c..49e38fa174 100644 --- a/haystack/preview/components/embedders/sentence_transformers_document_embedder.py +++ b/haystack/preview/components/embedders/sentence_transformers_document_embedder.py @@ -22,8 +22,8 @@ def __init__( batch_size: int = 32, progress_bar: bool = True, normalize_embeddings: bool = False, - embed_meta_fields: Optional[List[str]] = None, - embed_separator: str = "\n", + metadata_fields_to_embed: Optional[List[str]] = None, + embedding_separator: str = "\n", ): """ Create a SentenceTransformersDocumentEmbedder component. @@ -36,8 +36,8 @@ def __init__( :param batch_size: Number of strings to encode at once. :param progress_bar: If true, displays progress bar during embedding. :param normalize_embeddings: If set to true, returned vectors will have length 1. - :param embed_meta_fields: List of meta fields that should be embedded along with the Document content. - :param embed_separator: Separator used to concatenate the meta fields to the Document content. + :param metadata_fields_to_embed: List of meta fields that should be embedded along with the Document content. + :param embedding_separator: Separator used to concatenate the meta fields to the Document content. """ self.model_name_or_path = model_name_or_path @@ -47,8 +47,8 @@ def __init__( self.batch_size = batch_size self.progress_bar = progress_bar self.normalize_embeddings = normalize_embeddings - self.embed_meta_fields = embed_meta_fields or [] - self.embed_separator = embed_separator + self.metadata_fields_to_embed = metadata_fields_to_embed or [] + self.embedding_separator = embedding_separator def warm_up(self): """ @@ -65,6 +65,11 @@ def run(self, documents: List[Document]): Embed a list of Documents. The embedding of each Document is stored in the `embedding` field of the Document. """ + if not isinstance(documents, list) or not isinstance(documents[0], Document): + raise ValueError( + "SentenceTransformersDocumentEmbedder expects a list of Documents as input." + "In case you want to embed a list of strings, please use the SentenceTransformersTextEmbedder." + ) self.warm_up() # TODO: once non textual Documents are properly supported, we should also prepare them for embedding here @@ -72,9 +77,9 @@ def run(self, documents: List[Document]): texts_to_embed = [] for doc in documents: meta_values_to_embed = [ - doc.metadata[key] for key in self.embed_meta_fields if key in doc.metadata and doc.metadata[key] + doc.metadata[key] for key in self.metadata_fields_to_embed if key in doc.metadata and doc.metadata[key] ] - text_to_embed = self.embed_separator.join(meta_values_to_embed + [doc.content]) + text_to_embed = self.embedding_separator.join(meta_values_to_embed + [doc.content]) texts_to_embed.append(text_to_embed) embeddings = self.embedding_backend.embed( diff --git a/test/preview/components/embedders/test_sentence_transformers_document_embedder.py b/test/preview/components/embedders/test_sentence_transformers_document_embedder.py index 54a436c6ff..aff0039a43 100644 --- a/test/preview/components/embedders/test_sentence_transformers_document_embedder.py +++ b/test/preview/components/embedders/test_sentence_transformers_document_embedder.py @@ -79,3 +79,47 @@ def test_run(self): for doc in result["documents"]: assert isinstance(doc, Document) assert isinstance(doc.embedding, np.ndarray) + + @pytest.mark.unit + def test_run_wrong_input_format(self): + embedder = SentenceTransformersDocumentEmbedder(model_name_or_path="model") + + string_input = "text" + list_integers_input = [1, 2, 3] + + with pytest.raises( + ValueError, match="SentenceTransformersDocumentEmbedder expects a list of Documents as input" + ): + embedder.run(documents=string_input) + + with pytest.raises( + ValueError, match="SentenceTransformersDocumentEmbedder expects a list of Documents as input" + ): + embedder.run(documents=list_integers_input) + + @pytest.mark.unit + def test_embed_metadata(self): + embedder = SentenceTransformersDocumentEmbedder( + model_name_or_path="model", metadata_fields_to_embed=["meta_field"], embedding_separator="\n" + ) + embedder.embedding_backend = MagicMock() + # embedder.embedding_backend.embed = lambda x, **kwargs: list(np.random.rand(len(x), 16)) + + documents = [ + Document(content=f"document number {i}", metadata={"meta_field": f"meta_value {i}"}) for i in range(5) + ] + + embedder.run(documents=documents) + + embedder.embedding_backend.embed.assert_called_once_with( + [ + "meta_value 0\ndocument number 0", + "meta_value 1\ndocument number 1", + "meta_value 2\ndocument number 2", + "meta_value 3\ndocument number 3", + "meta_value 4\ndocument number 4", + ], + batch_size=32, + show_progress_bar=True, + normalize_embeddings=False, + ) From 059b3514280cd1d0dce5f5b4af8c6379fee1b7c0 Mon Sep 17 00:00:00 2001 From: anakin87 Date: Fri, 25 Aug 2023 11:41:21 +0200 Subject: [PATCH 14/21] support non-string metadata --- .../embedders/sentence_transformers_document_embedder.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/haystack/preview/components/embedders/sentence_transformers_document_embedder.py b/haystack/preview/components/embedders/sentence_transformers_document_embedder.py index 49e38fa174..c9f57b9b27 100644 --- a/haystack/preview/components/embedders/sentence_transformers_document_embedder.py +++ b/haystack/preview/components/embedders/sentence_transformers_document_embedder.py @@ -77,7 +77,9 @@ def run(self, documents: List[Document]): texts_to_embed = [] for doc in documents: meta_values_to_embed = [ - doc.metadata[key] for key in self.metadata_fields_to_embed if key in doc.metadata and doc.metadata[key] + str(doc.metadata[key]) + for key in self.metadata_fields_to_embed + if key in doc.metadata and doc.metadata[key] ] text_to_embed = self.embedding_separator.join(meta_values_to_embed + [doc.content]) texts_to_embed.append(text_to_embed) From d2cdaeb91fbebd33e6529f4e2c205943a609810d Mon Sep 17 00:00:00 2001 From: anakin87 Date: Fri, 25 Aug 2023 11:51:46 +0200 Subject: [PATCH 15/21] make factory private --- .../sentence_transformers_backend.py | 8 ++++---- .../embedding_backends/test_sentence_transformers.py | 12 ++++++------ 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/haystack/preview/embedding_backends/sentence_transformers_backend.py b/haystack/preview/embedding_backends/sentence_transformers_backend.py index 7fd679cd8a..7918211c5c 100644 --- a/haystack/preview/embedding_backends/sentence_transformers_backend.py +++ b/haystack/preview/embedding_backends/sentence_transformers_backend.py @@ -7,7 +7,7 @@ from sentence_transformers import SentenceTransformer -class SentenceTransformersEmbeddingBackendFactory: +class _SentenceTransformersEmbeddingBackendFactory: """ Factory class to create instances of Sentence Transformers embedding backends. """ @@ -20,12 +20,12 @@ def get_embedding_backend( ): embedding_backend_id = f"{model_name_or_path}{device}{use_auth_token}" - if embedding_backend_id in SentenceTransformersEmbeddingBackendFactory._instances: - return SentenceTransformersEmbeddingBackendFactory._instances[embedding_backend_id] + if embedding_backend_id in _SentenceTransformersEmbeddingBackendFactory._instances: + return _SentenceTransformersEmbeddingBackendFactory._instances[embedding_backend_id] embedding_backend = _SentenceTransformersEmbeddingBackend( model_name_or_path=model_name_or_path, device=device, use_auth_token=use_auth_token ) - SentenceTransformersEmbeddingBackendFactory._instances[embedding_backend_id] = embedding_backend + _SentenceTransformersEmbeddingBackendFactory._instances[embedding_backend_id] = embedding_backend return embedding_backend diff --git a/test/preview/embedding_backends/test_sentence_transformers.py b/test/preview/embedding_backends/test_sentence_transformers.py index 3a373c127a..00b3b0ef0e 100644 --- a/test/preview/embedding_backends/test_sentence_transformers.py +++ b/test/preview/embedding_backends/test_sentence_transformers.py @@ -1,18 +1,18 @@ from unittest.mock import patch import pytest from haystack.preview.embedding_backends.sentence_transformers_backend import ( - SentenceTransformersEmbeddingBackendFactory, + _SentenceTransformersEmbeddingBackendFactory, ) @pytest.mark.unit @patch("haystack.preview.embedding_backends.sentence_transformers_backend.SentenceTransformer") def test_factory_behavior(mock_sentence_transformer): - embedding_backend = SentenceTransformersEmbeddingBackendFactory.get_embedding_backend( + embedding_backend = _SentenceTransformersEmbeddingBackendFactory.get_embedding_backend( model_name_or_path="my_model", device="cpu" ) - same_embedding_backend = SentenceTransformersEmbeddingBackendFactory.get_embedding_backend("my_model", "cpu") - another_embedding_backend = SentenceTransformersEmbeddingBackendFactory.get_embedding_backend( + same_embedding_backend = _SentenceTransformersEmbeddingBackendFactory.get_embedding_backend("my_model", "cpu") + another_embedding_backend = _SentenceTransformersEmbeddingBackendFactory.get_embedding_backend( model_name_or_path="another_model", device="cpu" ) @@ -23,14 +23,14 @@ def test_factory_behavior(mock_sentence_transformer): @pytest.mark.unit @patch("haystack.preview.embedding_backends.sentence_transformers_backend.SentenceTransformer") def test_model_initialization(mock_sentence_transformer): - SentenceTransformersEmbeddingBackendFactory.get_embedding_backend(model_name_or_path="model", device="cpu") + _SentenceTransformersEmbeddingBackendFactory.get_embedding_backend(model_name_or_path="model", device="cpu") mock_sentence_transformer.assert_called_once_with(model_name_or_path="model", device="cpu", use_auth_token=None) @pytest.mark.unit @patch("haystack.preview.embedding_backends.sentence_transformers_backend.SentenceTransformer") def test_embedding_function_with_kwargs(mock_sentence_transformer): - embedding_backend = SentenceTransformersEmbeddingBackendFactory.get_embedding_backend(model_name_or_path="model") + embedding_backend = _SentenceTransformersEmbeddingBackendFactory.get_embedding_backend(model_name_or_path="model") data = ["sentence1", "sentence2"] embedding_backend.embed(data=data, normalize_embeddings=True) From ccfeb56e10ea752bc7996915e8edb95dea9e311e Mon Sep 17 00:00:00 2001 From: anakin87 Date: Fri, 25 Aug 2023 16:55:24 +0200 Subject: [PATCH 16/21] change return type; improve tests --- .../embedding_backends/sentence_transformers_backend.py | 4 ++-- .../embedding_backends/test_sentence_transformers.py | 8 ++++++-- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/haystack/preview/embedding_backends/sentence_transformers_backend.py b/haystack/preview/embedding_backends/sentence_transformers_backend.py index 7918211c5c..e0d68853e6 100644 --- a/haystack/preview/embedding_backends/sentence_transformers_backend.py +++ b/haystack/preview/embedding_backends/sentence_transformers_backend.py @@ -42,6 +42,6 @@ def __init__( model_name_or_path=model_name_or_path, device=device, use_auth_token=use_auth_token ) - def embed(self, data: List[str], **kwargs) -> List[np.ndarray]: + def embed(self, data: List[str], **kwargs) -> List[List[float]]: embedding = self.model.encode(data, **kwargs) - return list(embedding) + return embedding diff --git a/test/preview/embedding_backends/test_sentence_transformers.py b/test/preview/embedding_backends/test_sentence_transformers.py index 00b3b0ef0e..f9f98d0a02 100644 --- a/test/preview/embedding_backends/test_sentence_transformers.py +++ b/test/preview/embedding_backends/test_sentence_transformers.py @@ -23,8 +23,12 @@ def test_factory_behavior(mock_sentence_transformer): @pytest.mark.unit @patch("haystack.preview.embedding_backends.sentence_transformers_backend.SentenceTransformer") def test_model_initialization(mock_sentence_transformer): - _SentenceTransformersEmbeddingBackendFactory.get_embedding_backend(model_name_or_path="model", device="cpu") - mock_sentence_transformer.assert_called_once_with(model_name_or_path="model", device="cpu", use_auth_token=None) + _SentenceTransformersEmbeddingBackendFactory.get_embedding_backend( + model_name_or_path="model", device="cpu", use_auth_token="my_token" + ) + mock_sentence_transformer.assert_called_once_with( + model_name_or_path="model", device="cpu", use_auth_token="my_token" + ) @pytest.mark.unit From f999fc83588978b4fd0f0b36bc29b9e6009cb305 Mon Sep 17 00:00:00 2001 From: anakin87 Date: Fri, 25 Aug 2023 17:42:48 +0200 Subject: [PATCH 17/21] warm_up not called in run --- .../sentence_transformers_document_embedder.py | 13 +++++++------ ...test_sentence_transformers_document_embedder.py | 14 +++++++------- 2 files changed, 14 insertions(+), 13 deletions(-) diff --git a/haystack/preview/components/embedders/sentence_transformers_document_embedder.py b/haystack/preview/components/embedders/sentence_transformers_document_embedder.py index c9f57b9b27..993bcd7bab 100644 --- a/haystack/preview/components/embedders/sentence_transformers_document_embedder.py +++ b/haystack/preview/components/embedders/sentence_transformers_document_embedder.py @@ -3,7 +3,7 @@ from haystack.preview import component from haystack.preview import Document from haystack.preview.embedding_backends.sentence_transformers_backend import ( - SentenceTransformersEmbeddingBackendFactory, + _SentenceTransformersEmbeddingBackendFactory, ) @@ -16,7 +16,7 @@ class SentenceTransformersDocumentEmbedder: def __init__( self, - model_name_or_path: str, + model_name_or_path: str = "sentence-transformers/all-mpnet-base-v2", device: Optional[str] = None, use_auth_token: Union[bool, str, None] = None, batch_size: int = 32, @@ -28,7 +28,7 @@ def __init__( """ Create a SentenceTransformersDocumentEmbedder component. - :param model_name_or_path: Local path or name of the model in Hugging Face's model hub, such as ``'sentence-transformers/all-MiniLM-L6-v2'``. + :param model_name_or_path: Local path or name of the model in Hugging Face's model hub, such as ``'sentence-transformers/all-mpnet-base-v2'``. :param device: Device (like 'cuda' / 'cpu') that should be used for computation. If None, checks if a GPU can be used. :param use_auth_token: The API token used to download private models from Hugging Face. If this parameter is set to `True`, then the token generated when running @@ -55,7 +55,7 @@ def warm_up(self): Load the embedding backend. """ if not hasattr(self, "embedding_backend"): - self.embedding_backend = SentenceTransformersEmbeddingBackendFactory.get_embedding_backend( + self.embedding_backend = _SentenceTransformersEmbeddingBackendFactory.get_embedding_backend( model_name_or_path=self.model_name_or_path, device=self.device, use_auth_token=self.use_auth_token ) @@ -66,11 +66,12 @@ def run(self, documents: List[Document]): The embedding of each Document is stored in the `embedding` field of the Document. """ if not isinstance(documents, list) or not isinstance(documents[0], Document): - raise ValueError( + raise TypeError( "SentenceTransformersDocumentEmbedder expects a list of Documents as input." "In case you want to embed a list of strings, please use the SentenceTransformersTextEmbedder." ) - self.warm_up() + if not hasattr(self, "embedding_backend"): + raise RuntimeError("The embedding model has not been loaded. Please call warm_up() before running.") # TODO: once non textual Documents are properly supported, we should also prepare them for embedding here diff --git a/test/preview/components/embedders/test_sentence_transformers_document_embedder.py b/test/preview/components/embedders/test_sentence_transformers_document_embedder.py index aff0039a43..da82e52fcf 100644 --- a/test/preview/components/embedders/test_sentence_transformers_document_embedder.py +++ b/test/preview/components/embedders/test_sentence_transformers_document_embedder.py @@ -43,7 +43,7 @@ def test_init_with_parameters(self): @pytest.mark.unit @patch( - "haystack.preview.components.embedders.sentence_transformers_document_embedder.SentenceTransformersEmbeddingBackendFactory" + "haystack.preview.components.embedders.sentence_transformers_document_embedder._SentenceTransformersEmbeddingBackendFactory" ) def test_warmup(self, mocked_factory): embedder = SentenceTransformersDocumentEmbedder(model_name_or_path="model") @@ -55,7 +55,7 @@ def test_warmup(self, mocked_factory): @pytest.mark.unit @patch( - "haystack.preview.components.embedders.sentence_transformers_document_embedder.SentenceTransformersEmbeddingBackendFactory" + "haystack.preview.components.embedders.sentence_transformers_document_embedder._SentenceTransformersEmbeddingBackendFactory" ) def test_warmup_doesnt_reload(self, mocked_factory): embedder = SentenceTransformersDocumentEmbedder(model_name_or_path="model") @@ -68,7 +68,7 @@ def test_warmup_doesnt_reload(self, mocked_factory): def test_run(self): embedder = SentenceTransformersDocumentEmbedder(model_name_or_path="model") embedder.embedding_backend = MagicMock() - embedder.embedding_backend.embed = lambda x, **kwargs: list(np.random.rand(len(x), 16)) + embedder.embedding_backend.embed = lambda x, **kwargs: np.random.rand(len(x), 16).tolist() documents = [Document(content=f"document number {i}") for i in range(5)] @@ -78,7 +78,8 @@ def test_run(self): assert len(result["documents"]) == len(documents) for doc in result["documents"]: assert isinstance(doc, Document) - assert isinstance(doc.embedding, np.ndarray) + assert isinstance(doc.embedding, list) + assert isinstance(doc.embedding[0], float) @pytest.mark.unit def test_run_wrong_input_format(self): @@ -88,12 +89,12 @@ def test_run_wrong_input_format(self): list_integers_input = [1, 2, 3] with pytest.raises( - ValueError, match="SentenceTransformersDocumentEmbedder expects a list of Documents as input" + TypeError, match="SentenceTransformersDocumentEmbedder expects a list of Documents as input" ): embedder.run(documents=string_input) with pytest.raises( - ValueError, match="SentenceTransformersDocumentEmbedder expects a list of Documents as input" + TypeError, match="SentenceTransformersDocumentEmbedder expects a list of Documents as input" ): embedder.run(documents=list_integers_input) @@ -103,7 +104,6 @@ def test_embed_metadata(self): model_name_or_path="model", metadata_fields_to_embed=["meta_field"], embedding_separator="\n" ) embedder.embedding_backend = MagicMock() - # embedder.embedding_backend.embed = lambda x, **kwargs: list(np.random.rand(len(x), 16)) documents = [ Document(content=f"document number {i}", metadata={"meta_field": f"meta_value {i}"}) for i in range(5) From 44959b5646fe2872ccdc1d9331c24693f5e35cf3 Mon Sep 17 00:00:00 2001 From: anakin87 Date: Fri, 25 Aug 2023 17:50:58 +0200 Subject: [PATCH 18/21] fix typing --- .../embedding_backends/sentence_transformers_backend.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/haystack/preview/embedding_backends/sentence_transformers_backend.py b/haystack/preview/embedding_backends/sentence_transformers_backend.py index e0d68853e6..4038035ad5 100644 --- a/haystack/preview/embedding_backends/sentence_transformers_backend.py +++ b/haystack/preview/embedding_backends/sentence_transformers_backend.py @@ -43,5 +43,5 @@ def __init__( ) def embed(self, data: List[str], **kwargs) -> List[List[float]]: - embedding = self.model.encode(data, **kwargs) - return embedding + embeddings = self.model.encode(data, **kwargs).tolist() + return embeddings From 9079e91499535560ea131cbfac0c006ed65b76db Mon Sep 17 00:00:00 2001 From: anakin87 Date: Fri, 25 Aug 2023 18:02:29 +0200 Subject: [PATCH 19/21] rm unused import --- .../preview/embedding_backends/sentence_transformers_backend.py | 1 - 1 file changed, 1 deletion(-) diff --git a/haystack/preview/embedding_backends/sentence_transformers_backend.py b/haystack/preview/embedding_backends/sentence_transformers_backend.py index 4038035ad5..c04169ead8 100644 --- a/haystack/preview/embedding_backends/sentence_transformers_backend.py +++ b/haystack/preview/embedding_backends/sentence_transformers_backend.py @@ -1,5 +1,4 @@ from typing import List, Optional, Union, Dict -import numpy as np from haystack.preview.lazy_imports import LazyImport From 9af62ac6beb71393b4b4e3f9cc43d3e6dd5d7e08 Mon Sep 17 00:00:00 2001 From: ZanSara Date: Mon, 28 Aug 2023 12:03:54 +0200 Subject: [PATCH 20/21] Remove base test class --- .../test_sentence_transformers_document_embedder.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/test/preview/components/embedders/test_sentence_transformers_document_embedder.py b/test/preview/components/embedders/test_sentence_transformers_document_embedder.py index da82e52fcf..129014db03 100644 --- a/test/preview/components/embedders/test_sentence_transformers_document_embedder.py +++ b/test/preview/components/embedders/test_sentence_transformers_document_embedder.py @@ -1,18 +1,14 @@ from unittest.mock import patch, MagicMock import pytest +import numpy as np from haystack.preview import Document from haystack.preview.components.embedders.sentence_transformers_document_embedder import ( SentenceTransformersDocumentEmbedder, ) -from test.preview.components.base import BaseTestComponent - -import numpy as np - -class TestSentenceTransformersDocumentEmbedder(BaseTestComponent): - # TODO: We're going to rework these tests when we'll remove BaseTestComponent. +class TestSentenceTransformersDocumentEmbedder: @pytest.mark.unit def test_init_default(self): From d87f8c5e0897476908a58d01bf8f62ab0e0429f8 Mon Sep 17 00:00:00 2001 From: ZanSara Date: Mon, 28 Aug 2023 12:36:47 +0200 Subject: [PATCH 21/21] black --- .../embedders/test_sentence_transformers_document_embedder.py | 1 - 1 file changed, 1 deletion(-) diff --git a/test/preview/components/embedders/test_sentence_transformers_document_embedder.py b/test/preview/components/embedders/test_sentence_transformers_document_embedder.py index 129014db03..4bde42025d 100644 --- a/test/preview/components/embedders/test_sentence_transformers_document_embedder.py +++ b/test/preview/components/embedders/test_sentence_transformers_document_embedder.py @@ -9,7 +9,6 @@ class TestSentenceTransformersDocumentEmbedder: - @pytest.mark.unit def test_init_default(self): embedder = SentenceTransformersDocumentEmbedder(model_name_or_path="model")