Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: extend pipeline.add_component to support stores #5261

Merged
merged 48 commits into from
Jul 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
48 commits
Select commit Hold shift + click to select a range
d967605
add protocol and adapt pipeline
ZanSara Jul 3, 2023
e061d3d
change API in pipeline.add_component
ZanSara Jul 3, 2023
9eb2734
adapt pipeline tests
ZanSara Jul 3, 2023
b8af512
adapt memoryretriever
ZanSara Jul 3, 2023
b0cec16
additional checks
ZanSara Jul 4, 2023
1703ef3
separate protocol and mixin
ZanSara Jul 4, 2023
7494028
review feedback & update tests
ZanSara Jul 5, 2023
85ca595
pylint
ZanSara Jul 5, 2023
c644675
Update haystack/preview/document_stores/protocols.py
ZanSara Jul 5, 2023
8b9d8ee
Update haystack/preview/document_stores/memory/document_store.py
ZanSara Jul 5, 2023
8cba631
docstring of Store
ZanSara Jul 5, 2023
68bbf1e
adapt memorydocumentstore
ZanSara Jul 5, 2023
1583d6f
fix tests
ZanSara Jul 5, 2023
0edfbf0
Merge branch 'v2-docstore-protocol' into v2-docstores-connections
ZanSara Jul 5, 2023
c518252
remove direct inheritance
ZanSara Jul 6, 2023
d54a953
pylint
ZanSara Jul 6, 2023
29cd817
Update haystack/preview/document_stores/mixins.py
ZanSara Jul 6, 2023
db0a792
Update test/preview/components/retrievers/test_memory_retriever.py
ZanSara Jul 6, 2023
a60a8b5
Update test/preview/components/retrievers/test_memory_retriever.py
ZanSara Jul 6, 2023
84b7b86
Update test/preview/components/retrievers/test_memory_retriever.py
ZanSara Jul 6, 2023
78992d3
Update test/preview/components/retrievers/test_memory_retriever.py
ZanSara Jul 6, 2023
53ee0c5
Update test/preview/components/retrievers/test_memory_retriever.py
ZanSara Jul 6, 2023
2b11bf5
test names
ZanSara Jul 6, 2023
254b120
revert suggestion
ZanSara Jul 6, 2023
2619d9c
private self._stores
ZanSara Jul 6, 2023
3e5d2f7
move asserts out
ZanSara Jul 6, 2023
7e0cc7b
remove protocols
ZanSara Jul 6, 2023
fa50bf4
review feedback
ZanSara Jul 7, 2023
34600bc
review feedback
ZanSara Jul 7, 2023
e2e807a
fix tests
ZanSara Jul 7, 2023
aeb96c0
Merge branch 'main' into v2-docstores-connections
ZanSara Jul 7, 2023
9a249b7
mypy
ZanSara Jul 7, 2023
357390f
review feedback
ZanSara Jul 7, 2023
1813804
fix tests & other details
ZanSara Jul 10, 2023
f315019
naming
ZanSara Jul 10, 2023
4cc2506
mypy
ZanSara Jul 10, 2023
84252d0
fix tests
ZanSara Jul 10, 2023
5d220af
typing
ZanSara Jul 10, 2023
9214b79
Merge branch 'main' into v2-docstores-connections
ZanSara Jul 11, 2023
6055455
partial review feedback
ZanSara Jul 12, 2023
53f624b
move .store to input dataclass
ZanSara Jul 13, 2023
3f585a6
Revert "move .store to input dataclass"
ZanSara Jul 13, 2023
a6c1b24
disable reusing components with stores
ZanSara Jul 13, 2023
0c466c6
disable sharing components with docstores
ZanSara Jul 13, 2023
e832084
Merge branch 'main' into v2-docstores-connections
ZanSara Jul 13, 2023
cf1d6dc
Update mixins.py
ZanSara Jul 13, 2023
e6f894f
black
ZanSara Jul 13, 2023
e5ccad9
upgrade canals & fix tests
ZanSara Jul 17, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 38 additions & 48 deletions haystack/preview/components/retrievers/memory.py
Original file line number Diff line number Diff line change
@@ -1,72 +1,67 @@
from typing import Dict, List, Any, Optional

from haystack.preview import component, Document
from haystack.preview.document_stores import MemoryDocumentStore
from haystack.preview.document_stores import MemoryDocumentStore, StoreAwareMixin


@component
class MemoryRetriever:
class MemoryRetriever(StoreAwareMixin):
"""
A component for retrieving documents from a MemoryDocumentStore using the BM25 algorithm.
"""

class Input:
"""
Input data for the MemoryRetriever component.

:param query: The query string for the retriever.
:param filters: A dictionary with filters to narrow down the search space.
:param top_k: The maximum number of documents to return.
:param scale_score: Whether to scale the BM25 scores or not.
:param stores: A dictionary mapping document store names to instances.
"""
Needs to be connected to a MemoryDocumentStore to run.
"""

queries: List[str]
filters: Dict[str, Any]
top_k: int
scale_score: bool
stores: Dict[str, Any]
supported_stores = [MemoryDocumentStore]

class Output:
"""
Output data from the MemoryRetriever component.
@component.input
def input(self): # type: ignore
class Input:
"""
Input data for the MemoryRetriever component.

:param documents: The retrieved documents.
"""
:param query: The query string for the retriever.
:param filters: A dictionary with filters to narrow down the search space.
:param top_k: The maximum number of documents to return.
:param scale_score: Whether to scale the BM25 scores or not.
:param stores: A dictionary mapping document store names to instances.
"""

documents: List[List[Document]]
queries: List[str]
filters: Dict[str, Any]
top_k: int
scale_score: bool

@component.input
def input(self): # type: ignore
return MemoryRetriever.Input
return Input

@component.output
def output(self): # type: ignore
return MemoryRetriever.Output

def __init__(
self,
document_store_name: str,
filters: Optional[Dict[str, Any]] = None,
top_k: int = 10,
scale_score: bool = True,
):
class Output:
"""
Output data from the MemoryRetriever component.

:param documents: The retrieved documents.
"""

documents: List[List[Document]]

return Output

def __init__(self, filters: Optional[Dict[str, Any]] = None, top_k: int = 10, scale_score: bool = True):
"""
Create a MemoryRetriever component.

:param document_store_name: The name of the MemoryDocumentStore to retrieve documents from.
:param filters: A dictionary with filters to narrow down the search space (default is None).
:param top_k: The maximum number of documents to retrieve (default is 10).
:param scale_score: Whether to scale the BM25 score or not (default is True).

:raises ValueError: If the specified top_k is not > 0.
"""
self.document_store_name = document_store_name
if top_k <= 0:
raise ValueError(f"top_k must be > 0, but got {top_k}")
self.defaults = {"top_k": top_k, "scale_score": scale_score, "filters": filters or {}}

def run(self, data: Input) -> Output:
def run(self, data):
"""
Run the MemoryRetriever on the given input data.

Expand All @@ -75,19 +70,14 @@ def run(self, data: Input) -> Output:

:raises ValueError: If the specified document store is not found or is not a MemoryDocumentStore instance.
"""
if self.document_store_name not in data.stores:
raise ValueError(
f"MemoryRetriever's document store '{self.document_store_name}' not found "
f"in input stores {list(data.stores.keys())}"
)
document_store = data.stores[self.document_store_name]
if not isinstance(document_store, MemoryDocumentStore):
raise ValueError("MemoryRetriever can only be used with a MemoryDocumentStore instance.")
self.store: MemoryDocumentStore

if not self.store:
raise ValueError("MemoryRetriever needs a store to run: set the store instance to the self.store attribute")
docs = []
for query in data.queries:
docs.append(
document_store.bm25_retrieval(
self.store.bm25_retrieval(
query=query, filters=data.filters, top_k=data.top_k, scale_score=data.scale_score
)
)
Expand Down
1 change: 1 addition & 0 deletions haystack/preview/document_stores/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from haystack.preview.document_stores.protocols import Store, DuplicatePolicy
from haystack.preview.document_stores.mixins import StoreAwareMixin
from haystack.preview.document_stores.memory.document_store import MemoryDocumentStore
from haystack.preview.document_stores.errors import StoreError, DuplicateDocumentError, MissingDocumentError
31 changes: 31 additions & 0 deletions haystack/preview/document_stores/mixins.py
ZanSara marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from typing import List, Optional, Type


from haystack.preview.document_stores.protocols import Store


class StoreAwareMixin:
"""
Adds the capability of a component to use a single document store from the `self.store` property.

To use this mixin you must specify which document stores to support by setting a value to `supported_stores`.
To support any document store, set it to `[Store]`.
"""

_store: Optional[Store] = None
supported_stores: List[Type[Store]] # type: ignore # (see https://github.com/python/mypy/issues/4717)

ZanSara marked this conversation as resolved.
Show resolved Hide resolved
@property
def store(self) -> Optional[Store]:
return self._store

@store.setter
def store(self, store: Store):
if not isinstance(store, Store):
raise ValueError("'store' does not respect the Store Protocol.")
if not any(isinstance(store, type_) for type_ in type(self).supported_stores):
raise ValueError(
f"Store type '{type(store).__name__}' is not compatible with this component. "
f"Compatible store types: {[type_.__name__ for type_ in type(self).supported_stores]}"
)
self._store = store
3 changes: 2 additions & 1 deletion haystack/preview/document_stores/protocols.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Protocol, Optional, Dict, Any, List
from typing import Protocol, Optional, Dict, Any, List, runtime_checkable

import logging
from enum import Enum
Expand All @@ -15,6 +15,7 @@ class DuplicatePolicy(Enum):
FAIL = "fail"


@runtime_checkable
class Store(Protocol):
"""
Stores Documents to be used by the components of a Pipeline.
Expand Down
77 changes: 51 additions & 26 deletions haystack/preview/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,13 @@
load_pipelines as load_canals_pipelines,
save_pipelines as save_canals_pipelines,
)
from canals.pipeline.sockets import find_input_sockets

from haystack.preview.document_stores.protocols import Store
from haystack.preview.document_stores.mixins import StoreAwareMixin


class NotAStoreError(PipelineError):
pass


class NoSuchStoreError(PipelineError):
Expand All @@ -24,7 +28,7 @@ class Pipeline(CanalsPipeline):

def __init__(self):
super().__init__()
self.stores: Dict[str, Store] = {}
self._stores: Dict[str, Store] = {}

def add_store(self, name: str, store: Store) -> None:
"""
Expand All @@ -34,15 +38,20 @@ def add_store(self, name: str, store: Store) -> None:
:param store: the store object.
:returns: None
"""
self.stores[name] = store
if not isinstance(store, Store):
raise NotAStoreError(
f"This object ({store}) does not respect the Store Protocol, "
"so it can't be added to the pipeline with Pipeline.add_store()."
)
self._stores[name] = store

def list_stores(self) -> List[str]:
"""
Returns a dictionary with all the stores that are attached to this Pipeline.

:returns: a dictionary with all the stores attached to this Pipeline.
"""
return list(self.stores.keys())
return list(self._stores.keys())

def get_store(self, name: str) -> Store:
"""
Expand All @@ -52,33 +61,49 @@ def get_store(self, name: str) -> Store:
:returns: the store
"""
try:
return self.stores[name]
return self._stores[name]
except KeyError as e:
raise NoSuchStoreError(f"No store named '{name}' is connected to this pipeline.") from e

def run(self, data: Dict[str, Any], debug: bool = False) -> Dict[str, Any]:
def add_component(self, name: str, instance: Any, store: Optional[str] = None) -> None:
"""
Wrapper on top of Canals Pipeline.run(). Adds the `stores` parameter to all nodes.

:params data: the inputs to give to the input components of the Pipeline.
:params parameters: a dictionary with all the parameters of all the components, namespaced by component.
:params debug: whether to collect and return debug information.
:returns A dictionary with the outputs of the output components of the Pipeline.
Make this component available to the pipeline. Components are not connected to anything by default:
use `Pipeline.connect()` to connect components together.

Component names must be unique, but component instances can be reused if needed.

If `store` has a value, the pipeline will also connect this component to the requested document store.
Note that only components that inherit from StoreAwareMixin can be connected to stores.

:param name: the name of the component.
:param instance: the component instance.
:param store: the store this component needs access to, if any.
:raises ValueError: if:
- a component with the same name already exists
- a component requiring a store didn't receive it
- a component that didn't expect a store received it
:raises PipelineValidationError: if the given instance is not a component
:raises NoSuchStoreError: if the given store name is not known to the pipeline
"""
# Get all nodes in this pipelines instance
for node_name in self.graph.nodes:
# Get node inputs
node = self.graph.nodes[node_name]["instance"]
input_params = find_input_sockets(node)

# If the node needs a store, adds the list of stores to its default inputs
if "stores" in input_params:
if not hasattr(node, "defaults"):
setattr(node, "defaults", {})
node.defaults["stores"] = self.stores

# Run the pipeline
return super().run(data=data, debug=debug)
if isinstance(instance, StoreAwareMixin):
ZanSara marked this conversation as resolved.
Show resolved Hide resolved
if not store:
raise ValueError(f"Component '{name}' needs a store.")

if store not in self._stores:
raise NoSuchStoreError(
f"Store named '{store}' not found. "
f"Add it with 'pipeline.add_store('{store}', <the docstore instance>)'."
)

if instance.store:
raise ValueError("Reusing components with stores is not supported (yet). Create a separate instance.")

instance.store = self._stores[store]

elif store:
raise ValueError(f"Component '{name}' doesn't support stores.")

super().add_component(name, instance)


def load_pipelines(path: Path, _reader: Optional[Callable[..., Any]] = None):
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ dependencies = [
"jsonschema",

# Preview
"canals>=0.3,<0.4",
"canals==0.3.2",

# Agent events
"events",
Expand Down
Loading