Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: introduce Store protocol (v2) #5259

Merged
merged 9 commits into from
Jul 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions haystack/preview/document_stores/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from haystack.preview.document_stores.protocols import Store, DuplicatePolicy
from haystack.preview.document_stores.memory.document_store import MemoryDocumentStore
from haystack.preview.document_stores.errors import StoreError, DuplicateDocumentError, MissingDocumentError
14 changes: 7 additions & 7 deletions haystack/preview/document_stores/memory/document_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@
from tqdm.auto import tqdm

from haystack.preview.dataclasses import Document
from haystack.preview.document_stores.protocols import DuplicatePolicy
from haystack.preview.document_stores.memory._filters import match
from haystack.preview.document_stores.errors import DuplicateDocumentError, MissingDocumentError
from haystack.utils.scipy_utils import expit

logger = logging.getLogger(__name__)
DuplicatePolicy = Literal["skip", "overwrite", "fail"]

# document scores are essentially unbounded and will be scaled to values between 0 and 1 if scale_score is set to
# True (default). Scaling uses the expit function (inverse of the logit function) after applying a SCALING_FACTOR. A
Expand Down Expand Up @@ -126,17 +126,17 @@ def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Doc
return [doc for doc in self.storage.values() if match(conditions=filters, document=doc)]
return list(self.storage.values())

def write_documents(self, documents: List[Document], duplicates: DuplicatePolicy = "fail") -> None:
def write_documents(self, documents: List[Document], policy: DuplicatePolicy = DuplicatePolicy.FAIL) -> None:
"""
Writes (or overwrites) documents into the store.

:param documents: a list of documents.
:param duplicates: documents with the same ID count as duplicates. When duplicates are met,
:param policy: documents with the same ID count as duplicates. When duplicates are met,
the store can:
- skip: keep the existing document and ignore the new one.
- overwrite: remove the old document and write the new one.
- fail: an error is raised
:raises DuplicateError: Exception trigger on duplicate document if `duplicates="fail"`
:raises DuplicateDocumentError: Exception raised on duplicate document if `policy=DuplicatePolicy.FAIL`
:return: None
"""
if (
Expand All @@ -147,10 +147,10 @@ def write_documents(self, documents: List[Document], duplicates: DuplicatePolicy
raise ValueError("Please provide a list of Documents.")

for document in documents:
if document.id in self.storage.keys():
if duplicates == "fail":
if policy != DuplicatePolicy.OVERWRITE and document.id in self.storage.keys():
if policy == DuplicatePolicy.FAIL:
raise DuplicateDocumentError(f"ID '{document.id}' already exists.")
if duplicates == "skip":
if policy == DuplicatePolicy.SKIP:
logger.warning("ID '%s' already exists", document.id)
self.storage[document.id] = document

Expand Down
126 changes: 126 additions & 0 deletions haystack/preview/document_stores/protocols.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
from typing import Protocol, Optional, Dict, Any, List

import logging
from enum import Enum

from haystack.preview.dataclasses import Document


logger = logging.getLogger(__name__)


class DuplicatePolicy(Enum):
    """
    Policy applied by a Store's `write_documents` when an incoming document's ID
    already exists in the store.
    """

    SKIP = "skip"  # keep the existing document and silently ignore the new one
    OVERWRITE = "overwrite"  # remove the old document and write the new one
    FAIL = "fail"  # raise an error on the duplicate document


class Store(Protocol):
    """
    Stores Documents to be used by the components of a Pipeline.

    Classes implementing this protocol often store the documents permanently and allow specialized components to
    perform retrieval on them, either by embedding, by keyword, hybrid, and so on, depending on the backend used.

    In order to retrieve documents, consider using a Retriever that supports the document store implementation that
    you're using.
    """

    def count_documents(self) -> int:
        """
        Returns the number of documents stored.
        """

    def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Document]:
        """
        Returns the documents that match the filters provided.

        Filters are defined as nested dictionaries. The keys of the dictionaries can be a logical operator (`"$and"`,
        `"$or"`, `"$not"`), a comparison operator (`"$eq"`, `"$ne"`, `"$in"`, `"$nin"`, `"$gt"`, `"$gte"`, `"$lt"`,
        `"$lte"`) or a metadata field name.

        Logical operator keys take a dictionary of metadata field names and/or logical operators as value. Metadata
        field names take a dictionary of comparison operators as value. Comparison operator keys take a single value or
        (in case of `"$in"`) a list of values as value. If no logical operator is provided, `"$and"` is used as default
        operation. If no comparison operator is provided, `"$eq"` (or `"$in"` if the comparison value is a list) is used
        as default operation.

        Example:

        ```python
        filters = {
            "$and": {
                "type": {"$eq": "article"},
                "date": {"$gte": "2015-01-01", "$lt": "2021-01-01"},
                "rating": {"$gte": 3},
                "$or": {
                    "genre": {"$in": ["economy", "politics"]},
                    "publisher": {"$eq": "nytimes"}
                }
            }
        }
        # or simpler using default operators
        filters = {
            "type": "article",
            "date": {"$gte": "2015-01-01", "$lt": "2021-01-01"},
            "rating": {"$gte": 3},
            "$or": {
                "genre": ["economy", "politics"],
                "publisher": "nytimes"
            }
        }
        ```

        To use the same logical operator multiple times on the same level, logical operators can take a list of
        dictionaries as value.

        Example:

        ```python
        filters = {
            "$or": [
                {
                    "$and": {
                        "Type": "News Paper",
                        "Date": {
                            "$lt": "2019-01-01"
                        }
                    }
                },
                {
                    "$and": {
                        "Type": "Blog Post",
                        "Date": {
                            "$gte": "2019-01-01"
                        }
                    }
                }
            ]
        }
        ```

        :param filters: the filters to apply to the document list.
        :return: a list of Documents that match the given filters.
        """

    def write_documents(self, documents: List[Document], policy: DuplicatePolicy = DuplicatePolicy.FAIL) -> None:
        """
        Writes (or overwrites) documents into the store.

        :param documents: a list of documents.
        :param policy: documents with the same ID count as duplicates. When duplicates are met,
            the store can:
            - skip: keep the existing document and ignore the new one.
            - overwrite: remove the old document and write the new one.
            - fail: an error is raised
        :raises DuplicateDocumentError: Exception raised on duplicate document if `policy=DuplicatePolicy.FAIL`
        :return: None
        """

    def delete_documents(self, document_ids: List[str]) -> None:
        """
        Deletes all documents with a matching ID from the document store.
        Fails with `MissingDocumentError` if no document with this id is present in the store.

        :param document_ids: the IDs of the documents to delete
        """
8 changes: 5 additions & 3 deletions haystack/preview/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
)
from canals.pipeline.sockets import find_input_sockets

from haystack.preview.document_stores.protocols import Store


class NoSuchStoreError(PipelineError):
pass
Expand All @@ -23,9 +25,9 @@ class Pipeline(CanalsPipeline):

def __init__(self):
super().__init__()
self.stores = {}
self.stores: Dict[str, Store] = {}

def add_store(self, name: str, store: object) -> None:
def add_store(self, name: str, store: Store) -> None:
"""
Make a store available to all nodes of this pipeline.

Expand All @@ -43,7 +45,7 @@ def list_stores(self) -> List[str]:
"""
return list(self.stores.keys())

def get_store(self, name: str) -> object:
def get_store(self, name: str) -> Store:
"""
Returns the store associated with the given name.

Expand Down
Loading