Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add SentenceTransformersDiversityRanker #7095

Merged
merged 11 commits into from
Mar 11, 2024
8 changes: 7 additions & 1 deletion haystack/components/rankers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
from haystack.components.rankers.lost_in_the_middle import LostInTheMiddleRanker
from haystack.components.rankers.meta_field import MetaFieldRanker
from haystack.components.rankers.sentence_transformers_diversity import SentenceTransformersDiversityRanker
from haystack.components.rankers.transformers_similarity import TransformersSimilarityRanker

__all__ = ["LostInTheMiddleRanker", "MetaFieldRanker", "TransformersSimilarityRanker"]
__all__ = [
"LostInTheMiddleRanker",
"MetaFieldRanker",
"SentenceTransformersDiversityRanker",
"TransformersSimilarityRanker",
]
237 changes: 237 additions & 0 deletions haystack/components/rankers/sentence_transformers_diversity.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,237 @@
from typing import Any, Dict, List, Literal, Optional

from haystack import ComponentError, Document, component, default_from_dict, default_to_dict, logging
from haystack.lazy_imports import LazyImport
from haystack.utils import ComponentDevice, Secret, deserialize_secrets_inplace

logger = logging.getLogger(__name__)


with LazyImport(message="Run 'pip install \"sentence-transformers>=2.2.0\"'") as torch_and_sentence_transformers_import:
import torch
from sentence_transformers import SentenceTransformer


@component
class SentenceTransformersDiversityRanker:
"""
Implements a document ranking algorithm that orders documents in such a way as to maximize the overall diversity
of the documents.

This component provides functionality to rank a list of documents based on their similarity with respect to the
query to maximize the overall diversity. It uses a pre-trained Sentence Transformers model to embed the query and
the Documents.

Usage example:
```python
from haystack import Document
from haystack.components.rankers import SentenceTransformersDiversityRanker

ranker = SentenceTransformersDiversityRanker(model="sentence-transformers/all-MiniLM-L6-v2", similarity="cosine")
ranker.warm_up()

docs = [Document(content="Paris"), Document(content="Berlin")]
query = "What is the capital of germany?"
output = ranker.run(query=query, documents=docs)
docs = output["documents"]
```
"""

def __init__(
self,
model: str = "sentence-transformers/all-MiniLM-L6-v2",
top_k: int = 10,
device: Optional[ComponentDevice] = None,
token: Optional[Secret] = Secret.from_env_var("HF_API_TOKEN", strict=False),
similarity: Literal["dot_product", "cosine"] = "cosine",
query_prefix: str = "",
query_suffix: str = "",
document_prefix: str = "",
document_suffix: str = "",
meta_fields_to_embed: Optional[List[str]] = None,
embedding_separator: str = "\n",
):
"""
Initialize a SentenceTransformersDiversityRanker.

:param model: Local path or name of the model in Hugging Face's model hub,
such as `'sentence-transformers/all-MiniLM-L6-v2'`.
:param top_k: The maximum number of Documents to return per query.
:param device: The device on which the model is loaded. If `None`, the default device is automatically
selected.
:param token: The API token used to download private models from Hugging Face.
:param similarity: Similarity metric for comparing embeddings. Can be set to "dot_product" (default) or
"cosine".
:param query_prefix: A string to add to the beginning of the query text before ranking.
Can be used to prepend the text with an instruction, as required by some embedding models,
such as E5 and BGE.
:param query_suffix: A string to add to the end of the query text before ranking.
:param document_prefix: A string to add to the beginning of each Document text before ranking.
Can be used to prepend the text with an instruction, as required by some embedding models,
such as E5 and BGE.
:param document_suffix: A string to add to the end of each Document text before ranking.
:param meta_fields_to_embed: List of meta fields that should be embedded along with the Document content.
:param embedding_separator: Separator used to concatenate the meta fields to the Document content.
"""
torch_and_sentence_transformers_import.check()

self.model_name_or_path = model
if top_k is None or top_k <= 0:
raise ValueError(f"top_k must be > 0, but got {top_k}")
self.top_k = top_k
self.device = ComponentDevice.resolve_device(device)
self.token = token
self.model = None
if similarity not in ["dot_product", "cosine"]:
raise ValueError(f"Similarity must be one of 'dot_product' or 'cosine', but got {similarity}.")
awinml marked this conversation as resolved.
Show resolved Hide resolved
self.similarity = similarity
self.query_prefix = query_prefix
self.document_prefix = document_prefix
self.query_suffix = query_suffix
self.document_suffix = document_suffix
self.meta_fields_to_embed = meta_fields_to_embed or []
self.embedding_separator = embedding_separator

def warm_up(self):
"""
Initializes the component.
"""
if self.model is None:
self.model = SentenceTransformer(
model_name_or_path=self.model_name_or_path,
device=self.device.to_torch_str(),
use_auth_token=self.token.resolve_value() if self.token else None,
)

def to_dict(self) -> Dict[str, Any]:
"""
Serializes the component to a dictionary.

:returns:
Dictionary with serialized data.
"""
return default_to_dict(
self,
model=self.model_name_or_path,
device=self.device.to_dict(),
token=self.token.to_dict() if self.token else None,
top_k=self.top_k,
similarity=self.similarity,
query_prefix=self.query_prefix,
document_prefix=self.document_prefix,
query_suffix=self.query_suffix,
document_suffix=self.document_suffix,
meta_fields_to_embed=self.meta_fields_to_embed,
embedding_separator=self.embedding_separator,
)

@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "SentenceTransformersDiversityRanker":
"""
Deserializes the component from a dictionary.

:param data:
The dictionary to deserialize from.
:returns:
The deserialized component.
"""
serialized_device = data["init_parameters"]["device"]
data["init_parameters"]["device"] = ComponentDevice.from_dict(serialized_device)

deserialize_secrets_inplace(data["init_parameters"], keys=["token"])
return default_from_dict(cls, data)

def _greedy_diversity_order(self, query: str, documents: List[Document]) -> List[Document]:
awinml marked this conversation as resolved.
Show resolved Hide resolved
"""
Orders the given list of documents to maximize diversity.

The algorithm first calculates embeddings for each document and the query. It starts by selecting the document
that is semantically closest to the query. Then, for each remaining document, it selects the one that, on
average, is least similar to the already selected documents. This process continues until all documents are
selected, resulting in a list where each subsequent document contributes the most to the overall diversity of
the selected set.

:param query: The search query.
:param documents: The list of Document objects to be ranked.

:return: A list of documents ordered to maximize diversity.
"""

texts_to_embed = []
for doc in documents:
meta_values_to_embed = [
str(doc.meta[key]) for key in self.meta_fields_to_embed if key in doc.meta and doc.meta[key]
]
text_to_embed = (
self.document_prefix
+ self.embedding_separator.join(meta_values_to_embed + [doc.content or ""])
+ self.document_suffix
)
texts_to_embed.append(text_to_embed)

# Calculate embeddings
doc_embeddings = self.model.encode(texts_to_embed, convert_to_tensor=True) # type: ignore[attr-defined]
query_embedding = self.model.encode([self.query_prefix + query + self.query_suffix], convert_to_tensor=True) # type: ignore[attr-defined]

# Normalize embeddings to unit length for computing cosine similarity
if self.similarity == "cosine":
doc_embeddings /= torch.norm(doc_embeddings, p=2, dim=-1).unsqueeze(-1)
query_embedding /= torch.norm(query_embedding, p=2, dim=-1).unsqueeze(-1)

n = len(documents)
selected: List[int] = []

# Compute the similarity vector between the query and documents
query_doc_sim = query_embedding @ doc_embeddings.T

# Start with the document with the highest similarity to the query
selected.append(int(torch.argmax(query_doc_sim).item()))

selected_sum = doc_embeddings[selected[0]] / n

while len(selected) < n:
# Compute mean of dot products of all selected documents and all other documents
similarities = selected_sum @ doc_embeddings.T
# Mask documents that are already selected
similarities[selected] = torch.inf
# Select the document with the lowest total similarity score
index_unselected = int(torch.argmin(similarities).item())
selected.append(index_unselected)
# It's enough just to add to the selected vectors because dot product is distributive
# It's divided by n for numerical stability
selected_sum += doc_embeddings[index_unselected] / n

ranked_docs: List[Document] = [documents[i] for i in selected]

return ranked_docs

@component.output_types(documents=List[Document])
def run(self, query: str, documents: List[Document], top_k: Optional[int] = None):
"""
Rank the documents based on their diversity.

:param query: The search query.
:param documents: List of Document objects to be ranker.
:param top_k: Optional. An integer to override the top_k set during initialization.

:returns: A dictionary with the following key:
- `documents`: List of Document objects that have been selected based on the diversity ranking.

:raises ValueError: If the top_k value is less than or equal to 0.
"""
if not documents:
return {"documents": []}
awinml marked this conversation as resolved.
Show resolved Hide resolved

if top_k is None:
top_k = self.top_k
elif top_k <= 0:
raise ValueError(f"top_k must be > 0, but got {top_k}")
awinml marked this conversation as resolved.
Show resolved Hide resolved

if self.model is None:
raise ComponentError(
f"The component {self.__class__.__name__} wasn't warmed up. Run 'warm_up()' before calling 'run()'."
)

diversity_sorted = self._greedy_diversity_order(query=query, documents=documents)

return {"documents": diversity_sorted[:top_k]}
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
---
features:
- |
Add `DiversityRanker`. Diversity Ranker orders documents in such a way as to maximize the overall diversity of the given documents. The ranker leverages sentence-transformer models to calculate semantic embeddings for each document and the query.
Loading