Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: SentenceTransformersDocumentEmbedder #5606

Merged
merged 29 commits into from
Aug 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
a12d247
first draft
anakin87 Aug 14, 2023
f0a0452
Merge branch 'main' into st-embedding-backend
anakin87 Aug 17, 2023
0ba2f52
incorporate feedback
anakin87 Aug 17, 2023
83996d3
some unit tests
anakin87 Aug 18, 2023
0a80333
release notes
anakin87 Aug 18, 2023
d76e91c
real release notes
anakin87 Aug 18, 2023
76c6069
Merge branch 'main' into st-embedding-backend
anakin87 Aug 18, 2023
dc302bc
refactored to use a factory class
anakin87 Aug 21, 2023
021813f
allow forcing fresh instances
anakin87 Aug 21, 2023
a5976d3
Merge branch 'main' into st-embedding-backend
anakin87 Aug 21, 2023
6de8081
first draft
anakin87 Aug 21, 2023
3070893
Update haystack/preview/embedding_backends/sentence_transformers_back…
anakin87 Aug 21, 2023
bcba5bd
simplify implementation and tests
anakin87 Aug 23, 2023
2414f20
Merge branch 'st-embedding-backend' into st-document-embedder
anakin87 Aug 23, 2023
e2cc862
add embed_meta_fields implementation
anakin87 Aug 23, 2023
7107ff8
lg update
dfokina Aug 24, 2023
31233b2
improve meta data embedding; tests
anakin87 Aug 25, 2023
059b351
support non-string metadata
anakin87 Aug 25, 2023
d2cdaeb
make factory private
anakin87 Aug 25, 2023
ccfeb56
change return type; improve tests
anakin87 Aug 25, 2023
0a504af
Merge branch 'st-embedding-backend' into st-document-embedder
anakin87 Aug 25, 2023
f999fc8
warm_up not called in run
anakin87 Aug 25, 2023
44959b5
fix typing
anakin87 Aug 25, 2023
9079e91
rm unused import
anakin87 Aug 25, 2023
292a413
Merge branch 'main' into st-embedding-backend
anakin87 Aug 25, 2023
2e18e29
Merge branch 'st-embedding-backend' into st-document-embedder
ZanSara Aug 28, 2023
9af62ac
Remove base test class
ZanSara Aug 28, 2023
d87f8c5
black
ZanSara Aug 28, 2023
d6ab276
Merge branch 'main' into st-document-embedder
ZanSara Aug 28, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
from typing import List, Optional, Union

from haystack.preview import component
from haystack.preview import Document
from haystack.preview.embedding_backends.sentence_transformers_backend import (
_SentenceTransformersEmbeddingBackendFactory,
)


@component
class SentenceTransformersDocumentEmbedder:
    """
    A component for computing Document embeddings using Sentence Transformers models.
    The embedding of each Document is stored in the `embedding` field of the Document.
    """

    def __init__(
        self,
        model_name_or_path: str = "sentence-transformers/all-mpnet-base-v2",
        device: Optional[str] = None,
        use_auth_token: Union[bool, str, None] = None,
        batch_size: int = 32,
        progress_bar: bool = True,
        normalize_embeddings: bool = False,
        metadata_fields_to_embed: Optional[List[str]] = None,
        embedding_separator: str = "\n",
    ):
        """
        Create a SentenceTransformersDocumentEmbedder component.

        :param model_name_or_path: Local path or name of the model in Hugging Face's model hub, such as ``'sentence-transformers/all-mpnet-base-v2'``.
        :param device: Device (like 'cuda' / 'cpu') that should be used for computation. If None, checks if a GPU can be used.
        :param use_auth_token: The API token used to download private models from Hugging Face.
            If this parameter is set to `True`, then the token generated when running
            `transformers-cli login` (stored in ~/.huggingface) will be used.
        :param batch_size: Number of strings to encode at once.
        :param progress_bar: If true, displays progress bar during embedding.
        :param normalize_embeddings: If set to true, returned vectors will have length 1.
        :param metadata_fields_to_embed: List of meta fields that should be embedded along with the Document content.
        :param embedding_separator: Separator used to concatenate the meta fields to the Document content.
        """

        self.model_name_or_path = model_name_or_path
        # TODO: remove device parameter and use Haystack's device management once migrated
        self.device = device
        self.use_auth_token = use_auth_token
        self.batch_size = batch_size
        self.progress_bar = progress_bar
        self.normalize_embeddings = normalize_embeddings
        self.metadata_fields_to_embed = metadata_fields_to_embed or []
        self.embedding_separator = embedding_separator

    def warm_up(self):
        """
        Load the embedding backend.

        Idempotent: the backend is created only on the first call; subsequent
        calls reuse the existing instance (the factory may also cache models).
        """
        if not hasattr(self, "embedding_backend"):
            self.embedding_backend = _SentenceTransformersEmbeddingBackendFactory.get_embedding_backend(
                model_name_or_path=self.model_name_or_path, device=self.device, use_auth_token=self.use_auth_token
            )

    @component.output_types(documents=List[Document])
    def run(self, documents: List[Document]):
        """
        Embed a list of Documents.
        The embedding of each Document is stored in the `embedding` field of the Document.

        :param documents: Documents to embed. An empty list is allowed and yields an empty result.
        :raises TypeError: If the input is not a list of Documents.
        :raises RuntimeError: If `warm_up()` has not been called before running.
        :return: A dict with key "documents" holding new Document objects carrying embeddings.
        """
        # Guard the first-element type check so an empty list does not raise IndexError.
        if not isinstance(documents, list) or (documents and not isinstance(documents[0], Document)):
            raise TypeError(
                "SentenceTransformersDocumentEmbedder expects a list of Documents as input. "
                "In case you want to embed a list of strings, please use the SentenceTransformersTextEmbedder."
            )
        if not hasattr(self, "embedding_backend"):
            raise RuntimeError("The embedding model has not been loaded. Please call warm_up() before running.")

        # TODO: once non textual Documents are properly supported, we should also prepare them for embedding here

        texts_to_embed = []
        for doc in documents:
            # Prepend the selected (truthy) metadata values to the content, joined by the separator.
            meta_values_to_embed = [
                str(doc.metadata[key])
                for key in self.metadata_fields_to_embed
                if key in doc.metadata and doc.metadata[key]
            ]
            text_to_embed = self.embedding_separator.join(meta_values_to_embed + [doc.content])
            texts_to_embed.append(text_to_embed)

        embeddings = self.embedding_backend.embed(
            texts_to_embed,
            batch_size=self.batch_size,
            show_progress_bar=self.progress_bar,
            normalize_embeddings=self.normalize_embeddings,
        )

        # Documents are immutable-by-convention here: rebuild each one with its embedding
        # rather than mutating the inputs in place.
        documents_with_embeddings = []
        for doc, emb in zip(documents, embeddings):
            doc_as_dict = doc.to_dict()
            doc_as_dict["embedding"] = emb
            documents_with_embeddings.append(Document.from_dict(doc_as_dict))

        return {"documents": documents_with_embeddings}
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
preview:
- |
Add Sentence Transformers Document Embedder.
It computes embeddings of Documents. The embedding of each Document is stored in the `embedding` field of the Document.
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
from unittest.mock import patch, MagicMock
import pytest
import numpy as np

from haystack.preview import Document
from haystack.preview.components.embedders.sentence_transformers_document_embedder import (
SentenceTransformersDocumentEmbedder,
)


class TestSentenceTransformersDocumentEmbedder:
    @pytest.mark.unit
    def test_init_default(self):
        # Constructing with only the model name keeps every other option at its default.
        doc_embedder = SentenceTransformersDocumentEmbedder(model_name_or_path="model")
        assert doc_embedder.model_name_or_path == "model"
        assert doc_embedder.device is None
        assert doc_embedder.use_auth_token is None
        assert doc_embedder.batch_size == 32
        assert doc_embedder.progress_bar is True
        assert doc_embedder.normalize_embeddings is False

    @pytest.mark.unit
    def test_init_with_parameters(self):
        # Every explicitly passed option must be stored unchanged on the instance.
        doc_embedder = SentenceTransformersDocumentEmbedder(
            model_name_or_path="model",
            device="cpu",
            use_auth_token=True,
            batch_size=64,
            progress_bar=False,
            normalize_embeddings=True,
        )
        assert doc_embedder.model_name_or_path == "model"
        assert doc_embedder.device == "cpu"
        assert doc_embedder.use_auth_token is True
        assert doc_embedder.batch_size == 64
        assert doc_embedder.progress_bar is False
        assert doc_embedder.normalize_embeddings is True

    @pytest.mark.unit
    @patch(
        "haystack.preview.components.embedders.sentence_transformers_document_embedder._SentenceTransformersEmbeddingBackendFactory"
    )
    def test_warmup(self, mocked_factory):
        # The backend factory must not be touched at construction time, only on warm_up().
        doc_embedder = SentenceTransformersDocumentEmbedder(model_name_or_path="model")
        mocked_factory.get_embedding_backend.assert_not_called()
        doc_embedder.warm_up()
        mocked_factory.get_embedding_backend.assert_called_once_with(
            model_name_or_path="model", device=None, use_auth_token=None
        )

    @pytest.mark.unit
    @patch(
        "haystack.preview.components.embedders.sentence_transformers_document_embedder._SentenceTransformersEmbeddingBackendFactory"
    )
    def test_warmup_doesnt_reload(self, mocked_factory):
        # Calling warm_up() twice must create the backend exactly once.
        doc_embedder = SentenceTransformersDocumentEmbedder(model_name_or_path="model")
        mocked_factory.get_embedding_backend.assert_not_called()
        doc_embedder.warm_up()
        doc_embedder.warm_up()
        mocked_factory.get_embedding_backend.assert_called_once()

    @pytest.mark.unit
    def test_run(self):
        doc_embedder = SentenceTransformersDocumentEmbedder(model_name_or_path="model")
        # Stand in for the real backend: produce one random 16-dim vector per input text.
        doc_embedder.embedding_backend = MagicMock()
        doc_embedder.embedding_backend.embed.side_effect = (
            lambda texts, **kwargs: np.random.rand(len(texts), 16).tolist()
        )

        input_docs = [Document(content=f"document number {i}") for i in range(5)]

        output = doc_embedder.run(documents=input_docs)

        assert isinstance(output["documents"], list)
        assert len(output["documents"]) == len(input_docs)
        for embedded_doc in output["documents"]:
            assert isinstance(embedded_doc, Document)
            assert isinstance(embedded_doc.embedding, list)
            assert isinstance(embedded_doc.embedding[0], float)

    @pytest.mark.unit
    def test_run_wrong_input_format(self):
        doc_embedder = SentenceTransformersDocumentEmbedder(model_name_or_path="model")

        # Neither a bare string nor a list of non-Documents is acceptable input.
        for bad_input in ("text", [1, 2, 3]):
            with pytest.raises(
                TypeError, match="SentenceTransformersDocumentEmbedder expects a list of Documents as input"
            ):
                doc_embedder.run(documents=bad_input)

    @pytest.mark.unit
    def test_embed_metadata(self):
        doc_embedder = SentenceTransformersDocumentEmbedder(
            model_name_or_path="model", metadata_fields_to_embed=["meta_field"], embedding_separator="\n"
        )
        doc_embedder.embedding_backend = MagicMock()

        input_docs = [
            Document(content=f"document number {i}", metadata={"meta_field": f"meta_value {i}"}) for i in range(5)
        ]

        doc_embedder.run(documents=input_docs)

        # The metadata value must be prepended to the content, joined with the separator.
        expected_texts = [f"meta_value {i}\ndocument number {i}" for i in range(5)]
        doc_embedder.embedding_backend.embed.assert_called_once_with(
            expected_texts,
            batch_size=32,
            show_progress_bar=True,
            normalize_embeddings=False,
        )