Skip to content

Commit

Permalink
feat: extend pipeline.add_component to support stores (#5261)
Browse files Browse the repository at this point in the history
* add protocol and adapt pipeline

* change API in pipeline.add_component

* adapt pipeline tests

* adapt memoryretriever

* additional checks

* separate protocol and mixin

* review feedback & update tests

* pylint

* Update haystack/preview/document_stores/protocols.py

Co-authored-by: Silvano Cerza <[email protected]>

* Update haystack/preview/document_stores/memory/document_store.py

Co-authored-by: Silvano Cerza <[email protected]>

* docstring of Store

* adapt memorydocumentstore

* fix tests

* remove direct inheritance

* pylint

* Update haystack/preview/document_stores/mixins.py

Co-authored-by: Silvano Cerza <[email protected]>

* Update test/preview/components/retrievers/test_memory_retriever.py

Co-authored-by: Silvano Cerza <[email protected]>

* Update test/preview/components/retrievers/test_memory_retriever.py

Co-authored-by: Silvano Cerza <[email protected]>

* Update test/preview/components/retrievers/test_memory_retriever.py

Co-authored-by: Silvano Cerza <[email protected]>

* Update test/preview/components/retrievers/test_memory_retriever.py

Co-authored-by: Silvano Cerza <[email protected]>

* Update test/preview/components/retrievers/test_memory_retriever.py

Co-authored-by: Silvano Cerza <[email protected]>

* test names

* revert suggestion

* private self._stores

* move asserts out

* remove protocols

* review feedback

* review feedback

* fix tests

* mypy

* review feedback

* fix tests & other details

* naming

* mypy

* fix tests

* typing

* partial review feedback

* move .store to input dataclass

* Revert "move .store to input dataclass"

This reverts commit 53f624b.

* disable reusing components with stores

* disable sharing components with docstores

* Update mixins.py

* black

* upgrade canals & fix tests

---------

Co-authored-by: Silvano Cerza <[email protected]>
  • Loading branch information
ZanSara and silvanocerza authored Jul 17, 2023
1 parent adfabdd commit 8f3fe85
Show file tree
Hide file tree
Showing 8 changed files with 539 additions and 124 deletions.
86 changes: 38 additions & 48 deletions haystack/preview/components/retrievers/memory.py
Original file line number Diff line number Diff line change
@@ -1,72 +1,67 @@
from typing import Dict, List, Any, Optional

from haystack.preview import component, Document
from haystack.preview.document_stores import MemoryDocumentStore
from haystack.preview.document_stores import MemoryDocumentStore, StoreAwareMixin


@component
class MemoryRetriever:
class MemoryRetriever(StoreAwareMixin):
"""
A component for retrieving documents from a MemoryDocumentStore using the BM25 algorithm.
"""

class Input:
"""
Input data for the MemoryRetriever component.
:param query: The query string for the retriever.
:param filters: A dictionary with filters to narrow down the search space.
:param top_k: The maximum number of documents to return.
:param scale_score: Whether to scale the BM25 scores or not.
:param stores: A dictionary mapping document store names to instances.
"""
Needs to be connected to a MemoryDocumentStore to run.
"""

queries: List[str]
filters: Dict[str, Any]
top_k: int
scale_score: bool
stores: Dict[str, Any]
supported_stores = [MemoryDocumentStore]

class Output:
"""
Output data from the MemoryRetriever component.
@component.input
def input(self): # type: ignore
class Input:
"""
Input data for the MemoryRetriever component.
:param documents: The retrieved documents.
"""
:param queries: The list of query strings for the retriever.
:param filters: A dictionary with filters to narrow down the search space.
:param top_k: The maximum number of documents to return.
:param scale_score: Whether to scale the BM25 scores or not.
"""

documents: List[List[Document]]
queries: List[str]
filters: Dict[str, Any]
top_k: int
scale_score: bool

@component.input
def input(self): # type: ignore
return MemoryRetriever.Input
return Input

@component.output
def output(self): # type: ignore
return MemoryRetriever.Output

def __init__(
self,
document_store_name: str,
filters: Optional[Dict[str, Any]] = None,
top_k: int = 10,
scale_score: bool = True,
):
class Output:
"""
Output data from the MemoryRetriever component.
:param documents: The retrieved documents.
"""

documents: List[List[Document]]

return Output

def __init__(self, filters: Optional[Dict[str, Any]] = None, top_k: int = 10, scale_score: bool = True):
"""
Create a MemoryRetriever component.
:param document_store_name: The name of the MemoryDocumentStore to retrieve documents from.
:param filters: A dictionary with filters to narrow down the search space (default is None).
:param top_k: The maximum number of documents to retrieve (default is 10).
:param scale_score: Whether to scale the BM25 score or not (default is True).
:raises ValueError: If the specified top_k is not > 0.
"""
self.document_store_name = document_store_name
if top_k <= 0:
raise ValueError(f"top_k must be > 0, but got {top_k}")
self.defaults = {"top_k": top_k, "scale_score": scale_score, "filters": filters or {}}

def run(self, data: Input) -> Output:
def run(self, data):
"""
Run the MemoryRetriever on the given input data.
Expand All @@ -75,19 +70,14 @@ def run(self, data: Input) -> Output:
:raises ValueError: If no document store has been assigned to this component through `self.store`.
"""
if self.document_store_name not in data.stores:
raise ValueError(
f"MemoryRetriever's document store '{self.document_store_name}' not found "
f"in input stores {list(data.stores.keys())}"
)
document_store = data.stores[self.document_store_name]
if not isinstance(document_store, MemoryDocumentStore):
raise ValueError("MemoryRetriever can only be used with a MemoryDocumentStore instance.")
self.store: MemoryDocumentStore

if not self.store:
raise ValueError("MemoryRetriever needs a store to run: set the store instance to the self.store attribute")
docs = []
for query in data.queries:
docs.append(
document_store.bm25_retrieval(
self.store.bm25_retrieval(
query=query, filters=data.filters, top_k=data.top_k, scale_score=data.scale_score
)
)
Expand Down
1 change: 1 addition & 0 deletions haystack/preview/document_stores/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from haystack.preview.document_stores.protocols import Store, DuplicatePolicy
from haystack.preview.document_stores.mixins import StoreAwareMixin
from haystack.preview.document_stores.memory.document_store import MemoryDocumentStore
from haystack.preview.document_stores.errors import StoreError, DuplicateDocumentError, MissingDocumentError
31 changes: 31 additions & 0 deletions haystack/preview/document_stores/mixins.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from typing import List, Optional, Type


from haystack.preview.document_stores.protocols import Store


class StoreAwareMixin:
    """
    Mixin that lets a component hold one document store, reachable through the `store` property.

    Classes using this mixin are expected to define `supported_stores`, the list of store types
    the component accepts. Set it to `[Store]` to accept any object that follows the Store protocol.
    """

    _store: Optional[Store] = None
    supported_stores: List[Type[Store]]  # type: ignore # (see https://github.com/python/mypy/issues/4717)

    @property
    def store(self) -> Optional[Store]:
        """
        The store assigned to this component, or None when no store has been set yet.
        """
        return self._store

    @store.setter
    def store(self, store: Store):
        # First gate: the object must follow the Store protocol at all.
        if not isinstance(store, Store):
            raise ValueError("'store' does not respect the Store Protocol.")

        # Second gate: the object must be an instance of one of the types declared on the class.
        accepted_types = type(self).supported_stores
        is_compatible = False
        for candidate in accepted_types:
            if isinstance(store, candidate):
                is_compatible = True
                break

        if not is_compatible:
            accepted_names = [candidate.__name__ for candidate in accepted_types]
            raise ValueError(
                f"Store type '{type(store).__name__}' is not compatible with this component. "
                f"Compatible store types: {accepted_names}"
            )

        self._store = store
3 changes: 2 additions & 1 deletion haystack/preview/document_stores/protocols.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Protocol, Optional, Dict, Any, List
from typing import Protocol, Optional, Dict, Any, List, runtime_checkable

import logging
from enum import Enum
Expand All @@ -15,6 +15,7 @@ class DuplicatePolicy(Enum):
FAIL = "fail"


@runtime_checkable
class Store(Protocol):
"""
Stores Documents to be used by the components of a Pipeline.
Expand Down
77 changes: 51 additions & 26 deletions haystack/preview/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,13 @@
load_pipelines as load_canals_pipelines,
save_pipelines as save_canals_pipelines,
)
from canals.pipeline.sockets import find_input_sockets

from haystack.preview.document_stores.protocols import Store
from haystack.preview.document_stores.mixins import StoreAwareMixin


class NotAStoreError(PipelineError):
    """
    Raised when an object that does not respect the Store protocol is added to a Pipeline as a store.
    """


class NoSuchStoreError(PipelineError):
Expand All @@ -24,7 +28,7 @@ class Pipeline(CanalsPipeline):

def __init__(self):
super().__init__()
self.stores: Dict[str, Store] = {}
self._stores: Dict[str, Store] = {}

def add_store(self, name: str, store: Store) -> None:
"""
Expand All @@ -34,15 +38,20 @@ def add_store(self, name: str, store: Store) -> None:
:param store: the store object.
:returns: None
"""
self.stores[name] = store
if not isinstance(store, Store):
raise NotAStoreError(
f"This object ({store}) does not respect the Store Protocol, "
"so it can't be added to the pipeline with Pipeline.add_store()."
)
self._stores[name] = store

def list_stores(self) -> List[str]:
"""
Returns a list with the names of all the stores that are attached to this Pipeline.
:returns: a list with the names of all the stores attached to this Pipeline.
"""
return list(self.stores.keys())
return list(self._stores.keys())

def get_store(self, name: str) -> Store:
"""
Expand All @@ -52,33 +61,49 @@ def get_store(self, name: str) -> Store:
:returns: the store
"""
try:
return self.stores[name]
return self._stores[name]
except KeyError as e:
raise NoSuchStoreError(f"No store named '{name}' is connected to this pipeline.") from e

def run(self, data: Dict[str, Any], debug: bool = False) -> Dict[str, Any]:
def add_component(self, name: str, instance: Any, store: Optional[str] = None) -> None:
"""
Wrapper on top of Canals Pipeline.run(). Adds the `stores` parameter to all nodes.
:params data: the inputs to give to the input components of the Pipeline.
:params parameters: a dictionary with all the parameters of all the components, namespaced by component.
:params debug: whether to collect and return debug information.
:returns A dictionary with the outputs of the output components of the Pipeline.
Make this component available to the pipeline. Components are not connected to anything by default:
use `Pipeline.connect()` to connect components together.
Component names must be unique. Component instances can be reused if needed, unless the instance is already connected to a store: in that case, create a separate instance.
If `store` has a value, the pipeline will also connect this component to the requested document store.
Note that only components that inherit from StoreAwareMixin can be connected to stores.
:param name: the name of the component.
:param instance: the component instance.
:param store: the store this component needs access to, if any.
:raises ValueError: if:
- a component with the same name already exists
- a component requiring a store didn't receive it
- a component that didn't expect a store received it
:raises PipelineValidationError: if the given instance is not a component
:raises NoSuchStoreError: if the given store name is not known to the pipeline
"""
# Get all nodes in this pipelines instance
for node_name in self.graph.nodes:
# Get node inputs
node = self.graph.nodes[node_name]["instance"]
input_params = find_input_sockets(node)

# If the node needs a store, adds the list of stores to its default inputs
if "stores" in input_params:
if not hasattr(node, "defaults"):
setattr(node, "defaults", {})
node.defaults["stores"] = self.stores

# Run the pipeline
return super().run(data=data, debug=debug)
if isinstance(instance, StoreAwareMixin):
if not store:
raise ValueError(f"Component '{name}' needs a store.")

if store not in self._stores:
raise NoSuchStoreError(
f"Store named '{store}' not found. "
f"Add it with 'pipeline.add_store('{store}', <the docstore instance>)'."
)

if instance.store:
raise ValueError("Reusing components with stores is not supported (yet). Create a separate instance.")

instance.store = self._stores[store]

elif store:
raise ValueError(f"Component '{name}' doesn't support stores.")

super().add_component(name, instance)


def load_pipelines(path: Path, _reader: Optional[Callable[..., Any]] = None):
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ dependencies = [
"jsonschema",

# Preview
"canals>=0.3,<0.4",
"canals==0.3.2",

# Agent events
"events",
Expand Down
Loading

0 comments on commit 8f3fe85

Please sign in to comment.