Skip to content

Commit

Permalink
MemoryEmbeddingRetriever (2.0) (#5726)
Browse files Browse the repository at this point in the history
* MemoryDocumentStore - Embedding retrieval draft

* add release notes

* fix mypy

* better comment

* improve return_embeddings handling

* MemoryEmbeddingRetriever - first draft

* address PR comments

* release note

* update docstrings

* update docstrings

* incorporated feedback

* add return_embedding to __init__

* rm leftover docstring

---------

Co-authored-by: Daria Fokina <[email protected]>
  • Loading branch information
anakin87 and dfokina authored Sep 8, 2023
1 parent d860a5c commit 2edf85f
Show file tree
Hide file tree
Showing 4 changed files with 269 additions and 45 deletions.
4 changes: 2 additions & 2 deletions haystack/preview/components/retrievers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from haystack.preview.components.retrievers.memory import MemoryRetriever
from haystack.preview.components.retrievers.memory import MemoryBM25Retriever, MemoryEmbeddingRetriever

__all__ = ["MemoryRetriever"]
__all__ = ["MemoryBM25Retriever", "MemoryEmbeddingRetriever"]
131 changes: 123 additions & 8 deletions haystack/preview/components/retrievers/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@


@component
class MemoryRetriever:
class MemoryBM25Retriever:
"""
A component for retrieving documents from a MemoryDocumentStore using the BM25 algorithm.
Expand All @@ -20,12 +20,12 @@ def __init__(
scale_score: bool = True,
):
"""
Create a MemoryRetriever component.
Create a MemoryBM25Retriever component.
:param document_store: An instance of MemoryDocumentStore.
:param filters: A dictionary with filters to narrow down the search space (default is None).
:param top_k: The maximum number of documents to retrieve (default is 10).
:param scale_score: Whether to scale the BM25 score or not (default is True).
:param filters: A dictionary with filters to narrow down the search space. Default is None.
:param top_k: The maximum number of documents to retrieve. Default is 10.
:param scale_score: Whether to scale the BM25 score or not. Default is True.
:raises ValueError: If the specified top_k is not > 0.
"""
Expand All @@ -51,7 +51,7 @@ def to_dict(self) -> Dict[str, Any]:
)

@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "MemoryRetriever":
def from_dict(cls, data: Dict[str, Any]) -> "MemoryBM25Retriever":
"""
Deserialize this component from a dictionary.
"""
Expand All @@ -77,13 +77,12 @@ def run(
scale_score: Optional[bool] = None,
):
"""
Run the MemoryRetriever on the given input data.
Run the MemoryBM25Retriever on the given input data.
:param query: The query string for the retriever.
:param filters: A dictionary with filters to narrow down the search space.
:param top_k: The maximum number of documents to return.
:param scale_score: Whether to scale the BM25 scores or not.
:param document_stores: A dictionary mapping DocumentStore names to instances.
:return: The retrieved documents.
:raises ValueError: If the specified DocumentStore is not found or is not a MemoryDocumentStore instance.
Expand All @@ -101,3 +100,119 @@ def run(
self.document_store.bm25_retrieval(query=query, filters=filters, top_k=top_k, scale_score=scale_score)
)
return {"documents": docs}


@component
class MemoryEmbeddingRetriever:
    """
    A component for retrieving documents from a MemoryDocumentStore using a vector similarity metric.

    Needs to be connected to a MemoryDocumentStore to run.
    """

    def __init__(
        self,
        document_store: MemoryDocumentStore,
        filters: Optional[Dict[str, Any]] = None,
        top_k: int = 10,
        scale_score: bool = True,
        return_embedding: bool = False,
    ):
        """
        Create a MemoryEmbeddingRetriever component.

        :param document_store: An instance of MemoryDocumentStore.
        :param filters: A dictionary with filters to narrow down the search space. Default is None.
        :param top_k: The maximum number of documents to retrieve. Default is 10.
        :param scale_score: Whether to scale the scores of the retrieved documents or not. Default is True.
        :param return_embedding: Whether to return the embedding of the retrieved Documents. Default is False.

        :raises ValueError: If document_store is not a MemoryDocumentStore instance,
            or if the specified top_k is not > 0.
        """
        if not isinstance(document_store, MemoryDocumentStore):
            raise ValueError("document_store must be an instance of MemoryDocumentStore")

        self.document_store = document_store

        if top_k <= 0:
            raise ValueError(f"top_k must be > 0, but got {top_k}")

        self.filters = filters
        self.top_k = top_k
        self.scale_score = scale_score
        self.return_embedding = return_embedding

    def to_dict(self) -> Dict[str, Any]:
        """
        Serialize this component to a dictionary.

        The document store is serialized through its own `to_dict` method and
        nested under the `document_store` init parameter.
        """
        docstore = self.document_store.to_dict()
        return default_to_dict(
            self,
            document_store=docstore,
            filters=self.filters,
            top_k=self.top_k,
            scale_score=self.scale_score,
            return_embedding=self.return_embedding,
        )

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "MemoryEmbeddingRetriever":
        """
        Deserialize this component from a dictionary.

        :param data: The serialized component. Its `init_parameters` must contain a
            `document_store` entry carrying a `type` registered in `document_store.registry`.
        :return: The deserialized MemoryEmbeddingRetriever.

        :raises DeserializationError: If `document_store` is missing, has no `type`,
            or its `type` is not a registered DocumentStore.
        """
        init_params = data.get("init_parameters", {})
        if "document_store" not in init_params:
            raise DeserializationError("Missing 'document_store' in serialization data")

        docstore_data = init_params["document_store"]
        if "type" not in docstore_data:
            raise DeserializationError("Missing 'type' in document store's serialization data")

        docstore_type = docstore_data["type"]
        if docstore_type not in document_store.registry:
            raise DeserializationError(f"DocumentStore type '{docstore_type}' not found")

        docstore = document_store.registry[docstore_type].from_dict(docstore_data)

        # Build a fresh payload instead of writing the live instance back into the
        # caller's dictionary, so the same serialized data can be deserialized again.
        resolved = {**data, "init_parameters": {**init_params, "document_store": docstore}}
        return default_from_dict(cls, resolved)

    @component.output_types(documents=List[List[Document]])
    def run(
        self,
        queries_embeddings: List[List[float]],
        filters: Optional[Dict[str, Any]] = None,
        top_k: Optional[int] = None,
        scale_score: Optional[bool] = None,
        return_embedding: Optional[bool] = None,
    ):
        """
        Run the MemoryEmbeddingRetriever on the given input data.

        Any argument left as None falls back to the value given at initialization.

        :param queries_embeddings: Embeddings of the queries.
        :param filters: A dictionary with filters to narrow down the search space.
        :param top_k: The maximum number of documents to return.
        :param scale_score: Whether to scale the scores of the retrieved documents or not.
        :param return_embedding: Whether to return the embedding of the retrieved Documents.
        :return: A dictionary whose `documents` key holds one result list per query
            embedding, in the same order as `queries_embeddings`.
        """
        if filters is None:
            filters = self.filters
        if top_k is None:
            top_k = self.top_k
        if scale_score is None:
            scale_score = self.scale_score
        if return_embedding is None:
            return_embedding = self.return_embedding

        docs = [
            self.document_store.embedding_retrieval(
                query_embedding=query_embedding,
                filters=filters,
                top_k=top_k,
                scale_score=scale_score,
                return_embedding=return_embedding,
            )
            for query_embedding in queries_embeddings
        ]
        return {"documents": docs}
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
---
preview:
- |
Rename `MemoryRetriever` to `MemoryBM25Retriever`.
Add `MemoryEmbeddingRetriever`, which takes as input a query embedding and
retrieves the most relevant Documents from a `MemoryDocumentStore`.
Loading

0 comments on commit 2edf85f

Please sign in to comment.