diff --git a/haystack/preview/components/retrievers/memory.py b/haystack/preview/components/retrievers/memory.py index 1e3171b53a..63de5b0690 100644 --- a/haystack/preview/components/retrievers/memory.py +++ b/haystack/preview/components/retrievers/memory.py @@ -1,72 +1,67 @@ from typing import Dict, List, Any, Optional from haystack.preview import component, Document -from haystack.preview.document_stores import MemoryDocumentStore +from haystack.preview.document_stores import MemoryDocumentStore, StoreAwareMixin @component -class MemoryRetriever: +class MemoryRetriever(StoreAwareMixin): """ A component for retrieving documents from a MemoryDocumentStore using the BM25 algorithm. - """ - - class Input: - """ - Input data for the MemoryRetriever component. - :param query: The query string for the retriever. - :param filters: A dictionary with filters to narrow down the search space. - :param top_k: The maximum number of documents to return. - :param scale_score: Whether to scale the BM25 scores or not. - :param stores: A dictionary mapping document store names to instances. - """ + Needs to be connected to a MemoryDocumentStore to run. + """ - queries: List[str] - filters: Dict[str, Any] - top_k: int - scale_score: bool - stores: Dict[str, Any] + supported_stores = [MemoryDocumentStore] - class Output: - """ - Output data from the MemoryRetriever component. + @component.input + def input(self): # type: ignore + class Input: + """ + Input data for the MemoryRetriever component. - :param documents: The retrieved documents. - """ + :param query: The query string for the retriever. + :param filters: A dictionary with filters to narrow down the search space. + :param top_k: The maximum number of documents to return. + :param scale_score: Whether to scale the BM25 scores or not. + :param stores: A dictionary mapping document store names to instances. + """ - documents: List[List[Document]] + queries: List[str] + filters: Dict[str, Any] + top_k: int + scale_score: bool - @component.input - def input(self): # type: ignore - return MemoryRetriever.Input + return Input @component.output def output(self): # type: ignore - return MemoryRetriever.Output - - def __init__( - self, - document_store_name: str, - filters: Optional[Dict[str, Any]] = None, - top_k: int = 10, - scale_score: bool = True, - ): + class Output: + """ + Output data from the MemoryRetriever component. + + :param documents: The retrieved documents. + """ + + documents: List[List[Document]] + + return Output + + def __init__(self, filters: Optional[Dict[str, Any]] = None, top_k: int = 10, scale_score: bool = True): """ Create a MemoryRetriever component. - :param document_store_name: The name of the MemoryDocumentStore to retrieve documents from. :param filters: A dictionary with filters to narrow down the search space (default is None). :param top_k: The maximum number of documents to retrieve (default is 10). :param scale_score: Whether to scale the BM25 score or not (default is True). :raises ValueError: If the specified top_k is not > 0. """ - self.document_store_name = document_store_name if top_k <= 0: raise ValueError(f"top_k must be > 0, but got {top_k}") self.defaults = {"top_k": top_k, "scale_score": scale_score, "filters": filters or {}} - def run(self, data: Input) -> Output: + def run(self, data): """ Run the MemoryRetriever on the given input data. @@ -75,19 +70,14 @@ def run(self, data: Input) -> Output: :raises ValueError: If the specified document store is not found or is not a MemoryDocumentStore instance. """ - if self.document_store_name not in data.stores: - raise ValueError( - f"MemoryRetriever's document store '{self.document_store_name}' not found " - f"in input stores {list(data.stores.keys())}" - ) - document_store = data.stores[self.document_store_name] - if not isinstance(document_store, MemoryDocumentStore): - raise ValueError("MemoryRetriever can only be used with a MemoryDocumentStore instance.") + self.store: MemoryDocumentStore + if not self.store: + raise ValueError("MemoryRetriever needs a store to run: set the store instance to the self.store attribute") docs = [] for query in data.queries: docs.append( - document_store.bm25_retrieval( + self.store.bm25_retrieval( query=query, filters=data.filters, top_k=data.top_k, scale_score=data.scale_score ) ) diff --git a/haystack/preview/document_stores/__init__.py b/haystack/preview/document_stores/__init__.py index 19ba0ecd2c..834fb553b2 100644 --- a/haystack/preview/document_stores/__init__.py +++ b/haystack/preview/document_stores/__init__.py @@ -1,3 +1,4 @@ from haystack.preview.document_stores.protocols import Store, DuplicatePolicy +from haystack.preview.document_stores.mixins import StoreAwareMixin from haystack.preview.document_stores.memory.document_store import MemoryDocumentStore from haystack.preview.document_stores.errors import StoreError, DuplicateDocumentError, MissingDocumentError diff --git a/haystack/preview/document_stores/mixins.py b/haystack/preview/document_stores/mixins.py new file mode 100644 index 0000000000..93c7956fc8 --- /dev/null +++ b/haystack/preview/document_stores/mixins.py @@ -0,0 +1,31 @@ +from typing import List, Optional, Type + + +from haystack.preview.document_stores.protocols import Store + + +class StoreAwareMixin: + """ + Adds the capability of a component to use a single document store from the `self.store` property. + + To use this mixin you must specify which document stores to support by setting a value to `supported_stores`. + To support any document store, set it to `[Store]`. + """ + + _store: Optional[Store] = None + supported_stores: List[Type[Store]] # type: ignore # (see https://github.com/python/mypy/issues/4717) + + @property + def store(self) -> Optional[Store]: + return self._store + + @store.setter + def store(self, store: Store): + if not isinstance(store, Store): + raise ValueError("'store' does not respect the Store Protocol.") + if not any(isinstance(store, type_) for type_ in type(self).supported_stores): + raise ValueError( + f"Store type '{type(store).__name__}' is not compatible with this component. " + f"Compatible store types: {[type_.__name__ for type_ in type(self).supported_stores]}" + ) + self._store = store diff --git a/haystack/preview/document_stores/protocols.py b/haystack/preview/document_stores/protocols.py index 1c269351fb..ecd734a0e4 100644 --- a/haystack/preview/document_stores/protocols.py +++ b/haystack/preview/document_stores/protocols.py @@ -1,4 +1,4 @@ -from typing import Protocol, Optional, Dict, Any, List +from typing import Protocol, Optional, Dict, Any, List, runtime_checkable import logging from enum import Enum @@ -15,6 +15,7 @@ class DuplicatePolicy(Enum): FAIL = "fail" +@runtime_checkable class Store(Protocol): """ Stores Documents to be used by the components of a Pipeline. diff --git a/haystack/preview/pipeline.py b/haystack/preview/pipeline.py index c2cd75e33e..f7c30d9d69 100644 --- a/haystack/preview/pipeline.py +++ b/haystack/preview/pipeline.py @@ -8,9 +8,13 @@ load_pipelines as load_canals_pipelines, save_pipelines as save_canals_pipelines, ) -from canals.pipeline.sockets import find_input_sockets from haystack.preview.document_stores.protocols import Store +from haystack.preview.document_stores.mixins import StoreAwareMixin + + +class NotAStoreError(PipelineError): + pass class NoSuchStoreError(PipelineError): @@ -24,7 +28,7 @@ class Pipeline(CanalsPipeline): def __init__(self): super().__init__() - self.stores: Dict[str, Store] = {} + self._stores: Dict[str, Store] = {} def add_store(self, name: str, store: Store) -> None: """ @@ -34,7 +38,12 @@ def add_store(self, name: str, store: Store) -> None: :param store: the store object. :returns: None """ - self.stores[name] = store + if not isinstance(store, Store): + raise NotAStoreError( + f"This object ({store}) does not respect the Store Protocol, " + "so it can't be added to the pipeline with Pipeline.add_store()." + ) + self._stores[name] = store def list_stores(self) -> List[str]: """ @@ -42,7 +51,7 @@ def list_stores(self) -> List[str]: :returns: a dictionary with all the stores attached to this Pipeline. """ - return list(self.stores.keys()) + return list(self._stores.keys()) def get_store(self, name: str) -> Store: """ @@ -52,33 +61,49 @@ def get_store(self, name: str) -> Store: :returns: the store """ try: - return self.stores[name] + return self._stores[name] except KeyError as e: raise NoSuchStoreError(f"No store named '{name}' is connected to this pipeline.") from e - def run(self, data: Dict[str, Any], debug: bool = False) -> Dict[str, Any]: + def add_component(self, name: str, instance: Any, store: Optional[str] = None) -> None: """ - Wrapper on top of Canals Pipeline.run(). Adds the `stores` parameter to all nodes. - - :params data: the inputs to give to the input components of the Pipeline. - :params parameters: a dictionary with all the parameters of all the components, namespaced by component. - :params debug: whether to collect and return debug information. - :returns A dictionary with the outputs of the output components of the Pipeline. + Make this component available to the pipeline. Components are not connected to anything by default: + use `Pipeline.connect()` to connect components together. + + Component names must be unique, but component instances can be reused if needed. + + If `store` has a value, the pipeline will also connect this component to the requested document store. + Note that only components that inherit from StoreAwareMixin can be connected to stores. + + :param name: the name of the component. + :param instance: the component instance. + :param store: the store this component needs access to, if any. + :raises ValueError: if: + - a component with the same name already exists + - a component requiring a store didn't receive it + - a component that didn't expect a store received it + :raises PipelineValidationError: if the given instance is not a component + :raises NoSuchStoreError: if the given store name is not known to the pipeline """ - # Get all nodes in this pipelines instance - for node_name in self.graph.nodes: - # Get node inputs - node = self.graph.nodes[node_name]["instance"] - input_params = find_input_sockets(node) - - # If the node needs a store, adds the list of stores to its default inputs - if "stores" in input_params: - if not hasattr(node, "defaults"): - setattr(node, "defaults", {}) - node.defaults["stores"] = self.stores - - # Run the pipeline - return super().run(data=data, debug=debug) + if isinstance(instance, StoreAwareMixin): + if not store: + raise ValueError(f"Component '{name}' needs a store.") + + if store not in self._stores: + raise NoSuchStoreError( + f"Store named '{store}' not found. " + f"Add it with 'pipeline.add_store('{store}', )'." + ) + + if instance.store: + raise ValueError("Reusing components with stores is not supported (yet). Create a separate instance.") + + instance.store = self._stores[store] + + elif store: + raise ValueError(f"Component '{name}' doesn't support stores.") + + super().add_component(name, instance) def load_pipelines(path: Path, _reader: Optional[Callable[..., Any]] = None): diff --git a/pyproject.toml b/pyproject.toml index a9ca80f621..89efaad47c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -79,7 +79,7 @@ dependencies = [ "jsonschema", # Preview - "canals>=0.3,<0.4", + "canals==0.3.2", # Agent events "events", diff --git a/test/preview/components/retrievers/test_memory_retriever.py b/test/preview/components/retrievers/test_memory_retriever.py index 727b10e471..31a800c21d 100644 --- a/test/preview/components/retrievers/test_memory_retriever.py +++ b/test/preview/components/retrievers/test_memory_retriever.py @@ -1,14 +1,16 @@ -from typing import Dict, Any, List +from typing import Dict, Any, List, Optional import pytest from haystack.preview import Pipeline from haystack.preview.components.retrievers.memory import MemoryRetriever from haystack.preview.dataclasses import Document -from haystack.preview.document_stores import MemoryDocumentStore +from haystack.preview.document_stores import Store, MemoryDocumentStore from test.preview.components.base import BaseTestComponent +from haystack.preview.document_stores.protocols import DuplicatePolicy + @pytest.fixture() def mock_docs(): @@ -21,43 +23,39 @@ def mock_docs(): ] -class Test_MemoryRetriever(BaseTestComponent): +class TestMemoryRetriever(BaseTestComponent): @pytest.mark.unit def test_save_load(self, tmp_path): - self.assert_can_be_saved_and_loaded_in_pipeline(MemoryRetriever(document_store_name="memory"), tmp_path) + self.assert_can_be_saved_and_loaded_in_pipeline(MemoryRetriever(), tmp_path) @pytest.mark.unit def test_save_load_with_parameters(self, tmp_path): - self.assert_can_be_saved_and_loaded_in_pipeline( - MemoryRetriever(document_store_name="memory", top_k=5, scale_score=False), tmp_path - ) + self.assert_can_be_saved_and_loaded_in_pipeline(MemoryRetriever(top_k=5, scale_score=False), tmp_path) @pytest.mark.unit def test_init_default(self): - retriever = MemoryRetriever(document_store_name="memory") - assert retriever.document_store_name == "memory" + retriever = MemoryRetriever() assert retriever.defaults == {"filters": {}, "top_k": 10, "scale_score": True} @pytest.mark.unit def test_init_with_parameters(self): - retriever = MemoryRetriever(document_store_name="memory-test", top_k=5, scale_score=False) - assert retriever.document_store_name == "memory-test" + retriever = MemoryRetriever(top_k=5, scale_score=False) assert retriever.defaults == {"filters": {}, "top_k": 5, "scale_score": False} @pytest.mark.unit def test_init_with_invalid_top_k_parameter(self): with pytest.raises(ValueError, match="top_k must be > 0, but got -2"): - MemoryRetriever(document_store_name="memory-test", top_k=-2, scale_score=False) + MemoryRetriever(top_k=-2, scale_score=False) @pytest.mark.unit def test_valid_run(self, mock_docs): top_k = 5 ds = MemoryDocumentStore() ds.write_documents(mock_docs) - mr = MemoryRetriever(document_store_name="memory", top_k=top_k) - result: MemoryRetriever.Output = mr.run( - data=MemoryRetriever.Input(queries=["PHP", "Java"], stores={"memory": ds}) - ) + + retriever = MemoryRetriever(top_k=top_k) + retriever.store = ds + result = retriever.run(data=retriever.input(queries=["PHP", "Java"])) assert getattr(result, "documents") assert len(result.documents) == 2 @@ -67,26 +65,42 @@ def test_valid_run(self, mock_docs): assert result.documents[1][0].content == "Java is a popular programming language" @pytest.mark.unit - def test_invalid_run_wrong_store_name(self): - # Test invalid run with wrong store name - ds = MemoryDocumentStore() - mr = MemoryRetriever(document_store_name="memory") - with pytest.raises(ValueError, match=r"MemoryRetriever's document store 'memory' not found"): - invalid_input_data = MemoryRetriever.Input( - queries=["test"], top_k=10, scale_score=True, stores={"invalid_store": ds} - ) - mr.run(invalid_input_data) + def test_invalid_run_no_store(self): + retriever = MemoryRetriever() + with pytest.raises( + ValueError, match="MemoryRetriever needs a store to run: set the store instance to the self.store attribute" + ): + retriever.run(retriever.input(queries=["test"])) + + @pytest.mark.unit + def test_invalid_run_not_a_store(self): + class MockStore: + ... + + retriever = MemoryRetriever() + with pytest.raises(ValueError, match="does not respect the Store Protocol"): + retriever.store = MockStore() @pytest.mark.unit def test_invalid_run_wrong_store_type(self): - # Test invalid run with wrong store type - ds = MemoryDocumentStore() - mr = MemoryRetriever(document_store_name="memory") - with pytest.raises(ValueError, match=r"MemoryRetriever can only be used with a MemoryDocumentStore instance."): - invalid_input_data = MemoryRetriever.Input( - queries=["test"], top_k=10, scale_score=True, stores={"memory": "not a MemoryDocumentStore"} - ) - mr.run(invalid_input_data) + class MockStore: + def count_documents(self) -> int: + return 0 + + def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Document]: + return [] + + def write_documents( + self, documents: List[Document], policy: DuplicatePolicy = DuplicatePolicy.FAIL + ) -> None: + return None + + def delete_documents(self, document_ids: List[str]) -> None: + return None + + retriever = MemoryRetriever() + with pytest.raises(ValueError, match="is not compatible with this component"): + retriever.store = MockStore() @pytest.mark.integration @pytest.mark.parametrize( @@ -99,12 +113,12 @@ def test_invalid_run_wrong_store_type(self): def test_run_with_pipeline(self, mock_docs, query: str, query_result: str): ds = MemoryDocumentStore() ds.write_documents(mock_docs) - mr = MemoryRetriever(document_store_name="memory") + retriever = MemoryRetriever() pipeline = Pipeline() - pipeline.add_component("retriever", mr) pipeline.add_store("memory", ds) - result: Dict[str, Any] = pipeline.run(data={"retriever": MemoryRetriever.Input(queries=[query])}) + pipeline.add_component("retriever", retriever, store="memory") + result: Dict[str, Any] = pipeline.run(data={"retriever": retriever.input(queries=[query])}) assert result assert "retriever" in result @@ -124,12 +138,12 @@ def test_run_with_pipeline(self, mock_docs, query: str, query_result: str): def test_run_with_pipeline_and_top_k(self, mock_docs, query: str, query_result: str, top_k: int): ds = MemoryDocumentStore() ds.write_documents(mock_docs) - mr = MemoryRetriever(document_store_name="memory") + retriever = MemoryRetriever() pipeline = Pipeline() - pipeline.add_component("retriever", mr) pipeline.add_store("memory", ds) - result: Dict[str, Any] = pipeline.run(data={"retriever": MemoryRetriever.Input(queries=[query], top_k=top_k)}) + pipeline.add_component("retriever", retriever, store="memory") + result: Dict[str, Any] = pipeline.run(data={"retriever": retriever.input(queries=[query], top_k=top_k)}) assert result assert "retriever" in result diff --git a/test/preview/pipeline/test_pipeline.py b/test/preview/pipeline/test_pipeline.py index 2a4be6291d..dfdc866417 100644 --- a/test/preview/pipeline/test_pipeline.py +++ b/test/preview/pipeline/test_pipeline.py @@ -1,16 +1,49 @@ -from typing import Dict, Any +from typing import Any, Optional, Dict, List import pytest -from haystack.preview import Pipeline, component, NoSuchStoreError +from haystack.preview import Pipeline, component, NoSuchStoreError, Document +from haystack.preview.pipeline import NotAStoreError +from haystack.preview.document_stores import StoreAwareMixin, DuplicatePolicy, Store +# Note: we're using a real class instead of a mock because mocks don't play too well with protocols. class MockStore: - ... + def count_documents(self) -> int: + return 0 + + def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Document]: + return [] + + def write_documents(self, documents: List[Document], policy: DuplicatePolicy = DuplicatePolicy.FAIL) -> None: + return None + + def delete_documents(self, document_ids: List[str]) -> None: + return None + + +@pytest.mark.unit +def test_add_store(): + store_1 = MockStore() + store_2 = MockStore() + pipe = Pipeline() + + pipe.add_store(name="first_store", store=store_1) + pipe.add_store(name="second_store", store=store_2) + assert pipe._stores.get("first_store") == store_1 + assert pipe._stores.get("second_store") == store_2 + + +@pytest.mark.unit +def test_add_store_wrong_object(): + pipe = Pipeline() + + with pytest.raises(NotAStoreError, match="does not respect the Store Protocol"): + pipe.add_store(name="store", store="I'm surely not a Store object!") @pytest.mark.unit -def test_pipeline_store_api(): +def test_list_stores(): store_1 = MockStore() store_2 = MockStore() pipe = Pipeline() @@ -20,22 +53,115 @@ def test_pipeline_store_api(): assert pipe.list_stores() == ["first_store", "second_store"] + +@pytest.mark.unit +def test_get_store(): + store_1 = MockStore() + store_2 = MockStore() + pipe = Pipeline() + + pipe.add_store(name="first_store", store=store_1) + pipe.add_store(name="second_store", store=store_2) + assert pipe.get_store("first_store") == store_1 assert pipe.get_store("second_store") == store_2 + + +@pytest.mark.unit +def test_get_store_wrong_name(): + store_1 = MockStore() + pipe = Pipeline() + + with pytest.raises(NoSuchStoreError): + pipe.get_store("first_store") + + pipe.add_store(name="first_store", store=store_1) + assert pipe.get_store("first_store") == store_1 + with pytest.raises(NoSuchStoreError): pipe.get_store("third_store") @pytest.mark.unit -def test_pipeline_stores_in_params(): +def test_add_component_store_aware_component_receives_one_docstore(): + store_1 = MockStore() + store_2 = MockStore() + + @component + class MockComponent(StoreAwareMixin): + supported_stores = [Store] + + class Input: + value: int + + class Output: + value: int + + @component.input + def input(self): + return MockComponent.Input + + @component.output + def output(self): + return MockComponent.Output + + def run(self, data: Input) -> Output: + return MockComponent.Output(value=data.value) + + mock = MockComponent() + pipe = Pipeline() + pipe.add_store(name="first_store", store=store_1) + pipe.add_store(name="second_store", store=store_2) + pipe.add_component("component", mock, store="first_store") + assert mock.store == store_1 + assert pipe.run(data={"component": MockComponent.Input(value=1)}) == {"component": MockComponent.Output(value=1)} + + +@pytest.mark.unit +def test_add_component_store_aware_component_receives_no_docstore(): + store_1 = MockStore() + store_2 = MockStore() + + @component + class MockComponent(StoreAwareMixin): + supported_stores = [Store] + + class Input: + value: int + + class Output: + value: int + + @component.input + def input(self): + return MockComponent.Input + + @component.output + def output(self): + return MockComponent.Output + + def run(self, data: Input) -> Output: + return MockComponent.Output(value=data.value) + + pipe = Pipeline() + pipe.add_store(name="first_store", store=store_1) + pipe.add_store(name="second_store", store=store_2) + + with pytest.raises(ValueError, match="Component 'component' needs a store."): + pipe.add_component("component", MockComponent()) + + +@pytest.mark.unit +def test_non_store_aware_component_receives_one_docstore(): store_1 = MockStore() store_2 = MockStore() @component class MockComponent: + supported_stores = [Store] + class Input: value: int - stores: Dict[str, Any] class Output: value: int @@ -49,13 +175,240 @@ def output(self): return MockComponent.Output def run(self, data: Input) -> Output: - assert data.stores == {"first_store": store_1, "second_store": store_2} return MockComponent.Output(value=data.value) pipe = Pipeline() - pipe.add_component("component", MockComponent()) + pipe.add_store(name="first_store", store=store_1) + pipe.add_store(name="second_store", store=store_2) + + with pytest.raises(ValueError, match="Component 'component' doesn't support stores."): + pipe.add_component("component", MockComponent(), store="first_store") + + +@pytest.mark.unit +def test_add_component_store_aware_component_receives_wrong_docstore_name(): + store_1 = MockStore() + store_2 = MockStore() + + @component + class MockComponent(StoreAwareMixin): + supported_stores = [Store] + + class Input: + value: int + + class Output: + value: int + + @component.input + def input(self): + return MockComponent.Input + + @component.output + def output(self): + return MockComponent.Output + def run(self, data: Input) -> Output: + return MockComponent.Output(value=data.value) + + pipe = Pipeline() pipe.add_store(name="first_store", store=store_1) pipe.add_store(name="second_store", store=store_2) - assert pipe.run(data={"component": MockComponent.Input(value=1)}) == {"component": MockComponent.Output(value=1)} + with pytest.raises(NoSuchStoreError, match="Store named 'wrong_store' not found."): + pipe.add_component("component", MockComponent(), store="wrong_store") + + +@pytest.mark.unit +def test_add_component_store_aware_component_receives_correct_docstore_type(): + store_1 = MockStore() + store_2 = MockStore() + + @component + class MockComponent(StoreAwareMixin): + supported_stores = [MockStore] + + class Input: + value: int + + class Output: + value: int + + @component.input + def input(self): + return MockComponent.Input + + @component.output + def output(self): + return MockComponent.Output + + def run(self, data: Input) -> Output: + return MockComponent.Output(value=data.value) + + mock = MockComponent() + pipe = Pipeline() + pipe.add_store(name="first_store", store=store_1) + pipe.add_store(name="second_store", store=store_2) + + pipe.add_component("component", mock, store="second_store") + assert mock.store == store_2 + + +@pytest.mark.unit +def test_add_component_store_aware_component_is_reused(): + store_1 = MockStore() + store_2 = MockStore() + + @component + class MockComponent(StoreAwareMixin): + supported_stores = [MockStore] + + class Input: + value: int + + class Output: + value: int + + @component.input + def input(self): + return MockComponent.Input + + @component.output + def output(self): + return MockComponent.Output + + def run(self, data: Input) -> Output: + return MockComponent.Output(value=data.value) + + mock = MockComponent() + pipe = Pipeline() + pipe.add_store(name="first_store", store=store_1) + pipe.add_store(name="second_store", store=store_2) + + pipe.add_component("component", mock, store="second_store") + + with pytest.raises(ValueError, match="Reusing components with stores is not supported"): + pipe.add_component("component2", mock, store="second_store") + + with pytest.raises(ValueError, match="Reusing components with stores is not supported"): + pipe.add_component("component2", mock, store="first_store") + + assert mock.store == store_2 + + +@pytest.mark.unit +def test_add_component_store_aware_component_receives_subclass_of_correct_docstore_type(): + class MockStoreSubclass(MockStore): + ... + + store_1 = MockStoreSubclass() + store_2 = MockStore() + + @component + class MockComponent(StoreAwareMixin): + supported_stores = [MockStore] + + class Input: + value: int + + class Output: + value: int + + @component.input + def input(self): + return MockComponent.Input + + @component.output + def output(self): + return MockComponent.Output + + def run(self, data: Input) -> Output: + return MockComponent.Output(value=data.value) + + mock = MockComponent() + mock2 = MockComponent() + pipe = Pipeline() + pipe.add_store(name="first_store", store=store_1) + pipe.add_store(name="second_store", store=store_2) + + pipe.add_component("component", mock, store="first_store") + assert mock.store == store_1 + pipe.add_component("component2", mock2, store="second_store") + assert mock2.store == store_2 + + +@pytest.mark.unit +def test_add_component_store_aware_component_does_not_check_supported_stores(): + class SomethingElse: + ... + + @component + class MockComponent(StoreAwareMixin): + supported_stores = [SomethingElse] + + class Input: + value: int + + class Output: + value: int + + @component.input + def input(self): + return MockComponent.Input + + @component.output + def output(self): + return MockComponent.Output + + def run(self, data: Input) -> Output: + return MockComponent.Output(value=data.value) + + MockComponent() + + +@pytest.mark.unit +def test_add_component_store_aware_component_receives_wrong_docstore_type(): + store_1 = MockStore() + store_2 = MockStore() + + class MockStore2: + def count_documents(self) -> int: + return 0 + + def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Document]: + return [] + + def write_documents(self, documents: List[Document], policy: DuplicatePolicy = DuplicatePolicy.FAIL) -> None: + return None + + def delete_documents(self, document_ids: List[str]) -> None: + return None + + @component + class MockComponent(StoreAwareMixin): + supported_stores = [MockStore2] + + class Input: + value: int + + class Output: + value: int + + @component.input + def input(self): + return MockComponent.Input + + @component.output + def output(self): + return MockComponent.Output + + def run(self, data: Input) -> Output: + return MockComponent.Output(value=data.value) + + mock = MockComponent() + pipe = Pipeline() + pipe.add_store(name="first_store", store=store_1) + pipe.add_store(name="second_store", store=store_2) + + with pytest.raises(ValueError, match="is not compatible with this component"): + pipe.add_component("component", mock, store="second_store")