Skip to content

Commit

Permalink
some logical filters
Browse files Browse the repository at this point in the history
  • Loading branch information
anakin87 committed Dec 21, 2023
1 parent c6600e9 commit 10fe827
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 10 deletions.
25 changes: 21 additions & 4 deletions integrations/pinecone/src/pinecone_haystack/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@

def _normalize_filters(filters: Dict[str, Any]) -> Dict[str, Any]:
"""
Converts Haystack filters in ElasticSearch compatible filters.
Converts Haystack filters into Pinecone-compatible filters.
Reference: https://docs.pinecone.io/docs/metadata-filtering
"""
if not isinstance(filters, dict):
msg = "Filters must be a dictionary"
Expand All @@ -20,8 +21,22 @@ def _normalize_filters(filters: Dict[str, Any]) -> Dict[str, Any]:
return _parse_logical_condition(filters)


def _parse_logical_condition(filters: Dict[str, Any]) -> Dict[str, Any]:
return filters
def _parse_logical_condition(condition: Dict[str, Any]) -> Dict[str, Any]:
    """
    Translates a Haystack logical condition into a Pinecone logical filter.

    The dictionary must contain an "operator" key and a "conditions" key. Each entry of
    "conditions" is passed to `_parse_comparison_condition`; the resulting list is returned
    under the key that `LOGICAL_OPERATORS` associates with the given operator.

    :param condition: The Haystack logical condition to translate.
    :returns: A single-key dictionary mapping the translated operator to the parsed conditions.
    :raises FilterError: If "operator" or "conditions" is missing, or if the operator is not
        a key of `LOGICAL_OPERATORS`.
    """
    # Both keys are mandatory; report the first one that is absent.
    if "operator" not in condition:
        msg = f"'operator' key missing in {condition}"
        raise FilterError(msg)
    if "conditions" not in condition:
        msg = f"'conditions' key missing in {condition}"
        raise FilterError(msg)

    logical_op = condition["operator"]
    # Sub-conditions are parsed before the operator is validated, so an error raised while
    # parsing a sub-condition takes precedence over an unknown-operator error.
    parsed = list(map(_parse_comparison_condition, condition["conditions"]))

    if logical_op not in LOGICAL_OPERATORS:
        msg = f"Unknown logical operator '{logical_op}'"
        raise FilterError(msg)

    return {LOGICAL_OPERATORS[logical_op]: parsed}


def _parse_comparison_condition(condition: Dict[str, Any]) -> Dict[str, Any]:
Expand Down Expand Up @@ -55,7 +70,7 @@ def _parse_comparison_condition(condition: Dict[str, Any]) -> Dict[str, Any]:
if isinstance(value, DataFrame):
value = value.to_json()

return {"$and": [COMPARISON_OPERATORS[operator](field, value)]}
return COMPARISON_OPERATORS[operator](field, value)


def _equal(field: str, value: Any) -> Dict[str, Any]:
Expand Down Expand Up @@ -174,3 +189,5 @@ def _in(field: str, value: Any) -> Dict[str, Any]:
"in": _in,
"not in": _not_in,
}

# Haystack logical operator names mapped to the keys used in the generated Pinecone filter.
LOGICAL_OPERATORS = {
    "AND": "$and",
    "OR": "$or",
}
11 changes: 5 additions & 6 deletions integrations/pinecone/tests/test_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,24 +13,23 @@ def assert_documents_are_equal(self, received: List[Document], expected: List[Do
The PineconeDocumentStore.filter_documents() method returns Documents with their score set.
We don't want to compare the score, so we set it to None before comparing the documents.
"""

for doc in received:
doc.score = None
# Pinecone seems to convert strings to datetime objects (undocumented behavior)
# We convert them back to strings to compare them
if "date" in doc.meta:
doc.meta["date"] = doc.meta["date"].isoformat()

# Pinecone seems to convert integers to floats (undocumented behavior)
# We convert them back to integers to compare them
if "number" in doc.meta:
doc.meta["number"] = int(doc.meta["number"])

# let's compare the documents
# Lists comparison
assert len(received) == len(expected)
for received_doc in received:
id_ = received_doc.id
expected_doc = next(filter(lambda x: x.id == id_, expected))

received.sort(key=lambda x: x.id)
expected.sort(key=lambda x: x.id)
for received_doc, expected_doc in zip(received, expected):
assert received_doc.meta == expected_doc.meta
assert received_doc.content == expected_doc.content
if received_doc.dataframe is None:
Expand Down

0 comments on commit 10fe827

Please sign in to comment.