Commit

fmt
anakin87 committed Jan 23, 2024
1 parent cc62733 commit 051e39b
Showing 3 changed files with 26 additions and 29 deletions.
@@ -19,7 +19,7 @@

from pgvector.psycopg import register_vector

from .filters import _build_where_clause, _normalize_filters
from .filters import _build_where_clause

logger = logging.getLogger(__name__)

@@ -162,8 +162,6 @@ def _execute_sql(
cursor = cursor or self._cursor

try:
print("***QUERY: " + sql_query.as_string(cursor))
print("***PARAMS: " + str(params))
result = cursor.execute(sql_query, params)
except Error as e:
self._connection.rollback()
@@ -329,8 +327,6 @@ def write_documents(self, documents: List[Document], policy: DuplicatePolicy = D

sql_insert += SQL(" RETURNING id")

print("sql_insert", sql_insert.as_string(self._cursor))

try:
self._cursor.executemany(sql_insert, db_documents, returning=True)
except IntegrityError as ie:
@@ -339,7 +335,10 @@ def write_documents(self, documents: List[Document], policy: DuplicatePolicy = D
except Error as e:
self._connection.rollback()
sql_query_str = sql_insert.as_string(self._cursor)
error_msg = f"Could not write documents to PgvectorDocumentStore. \nSQL query: {sql_query_str} \nParameters: {db_documents}"
error_msg = (
f"Could not write documents to PgvectorDocumentStore. \n"
f"SQL query: {sql_query_str} \nParameters: {db_documents}"
)
raise DocumentStoreError(error_msg) from e

# get the number of the inserted documents, inspired by psycopg3 docs
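(For context, not part of this diff: the collapsed code after this comment counts the inserted rows. The psycopg3 documentation pattern the comment refers to iterates the result sets produced by executemany(..., returning=True); a hypothetical sketch of that pattern follows — the written_docs name is an assumption, not necessarily the repository's actual identifier.)

# Hypothetical sketch of the psycopg3 docs pattern for counting rows returned
# by executemany(..., returning=True); not the repository's actual code.
written_docs = 0
while True:
    if self._cursor.fetchone() is not None:
        written_docs += 1
    # nextset() advances to the next result set, or returns None when done
    if not self._cursor.nextset():
        break
return written_docs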
@@ -2,18 +2,17 @@
#
# SPDX-License-Identifier: Apache-2.0
from datetime import datetime
from itertools import chain
from typing import Any, Dict

from haystack.errors import FilterError
from pandas import DataFrame
from psycopg.sql import SQL
from psycopg.types.json import Jsonb
from itertools import chain


def _build_where_clause(filters: Dict[str, Any], cursor) -> str:
def _build_where_clause(filters: Dict[str, Any]) -> str:
normalized_filters = _normalize_filters(filters)
print("normalized_filters", normalized_filters)

sql_query, params = normalized_filters
if isinstance(params, list):
@@ -51,9 +50,9 @@ def _build_where_clause(filters: Dict[str, Any], cursor) -> str:
# params = (params,)

actual_params = ()
for i, param in enumerate(params):
for param in params:
if param != "no_value":
actual_params = actual_params + (param,)
actual_params = (*actual_params, param)

return where_clause, actual_params

@@ -85,35 +84,27 @@ def _parse_logical_condition(condition: Dict[str, Any]) -> Dict[str, Any]:
query_parts = []
values = []
for c in conditions:
print("c0", c[0])

query_parts.append(c[0])
values.append(c[1])

# values = list(chain.from_iterable(values))

print("query_parts", query_parts)
# if isinstance(query_parts[0], list):
# query_parts = list(chain.from_iterable(query_parts))
# print("chained", query_parts)
sql_query_parts = [SQL(q) if isinstance(q, str) else q for q in query_parts]
if isinstance(values[0], list):
values = list(chain.from_iterable(values))
values = [list(chain.from_iterable(values))]

if operator == "AND":
sql_query = SQL("(") + SQL(" AND ").join(sql_query_parts)+ SQL(")")
sql_query = SQL("(") + SQL(" AND ").join(sql_query_parts) + SQL(")")

elif operator == "OR":
sql_query = SQL("(") + SQL(" OR ").join(sql_query_parts)+ SQL(")")
sql_query = SQL("(") + SQL(" OR ").join(sql_query_parts) + SQL(")")

elif operator == "NOT":
joined_query_parts = SQL(" AND ").join(sql_query_parts)
sql_query = SQL("NOT (") + joined_query_parts + SQL(")")

else:
msg = f"Unknown logical operator '{operator}'"
raise FilterError(msg)

return sql_query, values


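(For context, illustrative only and not part of this commit: after this change _build_where_clause takes just the filters dict — the cursor parameter is gone — and returns the WHERE clause together with its parameters, as the return statement above shows. A minimal usage sketch, assuming the package path haystack_integrations.document_stores.pgvector.filters and a filter shaped like the one in test_complex_filter below:)

# Minimal usage sketch; import path and filter values are assumptions.
from haystack_integrations.document_stores.pgvector.filters import _build_where_clause

filters = {
    "operator": "AND",
    "conditions": [
        {"field": "meta.number", "operator": "==", "value": 100},
        {"field": "meta.chapter", "operator": "==", "value": "intro"},
    ],
}

# Only the filters dict is needed now; the clause and its parameters can be
# appended to a base SELECT statement and executed with a psycopg cursor.
where_clause, params = _build_where_clause(filters)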
17 changes: 12 additions & 5 deletions integrations/pgvector/tests/test_filters.py
@@ -7,8 +7,11 @@

class TestFilters(FilterDocumentsTest):
def assert_documents_are_equal(self, received: List[Document], expected: List[Document]):
print("received", received)
print("expected", expected)
"""
This overrides the default assert_documents_are_equal from FilterDocumentsTest.
It is needed because the embeddings are not exactly the same when they are retrieved from Postgres.
"""

assert len(received) == len(expected)
received.sort(key=lambda x: x.id)
expected.sort(key=lambda x: x.id)
@@ -48,6 +51,10 @@ def test_complex_filter(self, document_store, filterable_docs):

self.assert_documents_are_equal(
result,
[d for d in filterable_docs if
(d.meta.get("number") == 100 and d.meta.get("chapter") == "intro")
or (d.meta.get("page") == "90" and d.meta.get("chapter") == "conclusion")])
[
d
for d in filterable_docs
if (d.meta.get("number") == 100 and d.meta.get("chapter") == "intro")
or (d.meta.get("page") == "90" and d.meta.get("chapter") == "conclusion")
],
)
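(A side note on the assert_documents_are_equal override above, whose body is only partly shown in this diff: one way to tolerate the floating-point drift in embeddings retrieved from Postgres is to compare every other field exactly and the embedding only approximately. A hypothetical sketch, not the code from this commit:)

import pytest

# Hypothetical comparison tolerating embedding float drift; not the actual override body.
for received_doc, expected_doc in zip(received, expected):
    # Non-embedding fields must match exactly.
    assert received_doc.id == expected_doc.id
    assert received_doc.content == expected_doc.content
    assert received_doc.meta == expected_doc.meta
    # Embeddings only need to match within floating-point tolerance.
    if expected_doc.embedding is not None:
        assert received_doc.embedding == pytest.approx(expected_doc.embedding)
    else:
        assert received_doc.embedding is None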
