Skip to content

Commit

Permalink
docstrings
Browse files Browse the repository at this point in the history
  • Loading branch information
anakin87 committed Jan 19, 2024
1 parent 442ab3c commit e42a7e7
Showing 1 changed file with 86 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,12 @@
import logging
from typing import Any, Dict, List, Literal, Optional

import psycopg
from haystack.dataclasses.document import ByteStream, Document
from haystack.document_stores.errors import DocumentStoreError, DuplicateDocumentError
from haystack.document_stores.types import DuplicatePolicy
from psycopg import Error, connect
from psycopg.abc import Query
from psycopg.cursor import Cursor
from psycopg.rows import dict_row
from psycopg.sql import SQL, Identifier
from psycopg.sql import Literal as SQLLiteral
Expand Down Expand Up @@ -52,9 +54,36 @@ def __init__(
recreate_table: bool = False,
search_strategy: Literal["exact_nearest_neighbor", "hnsw"] = "exact_nearest_neighbor",
hnsw_recreate_index_if_exists: bool = False,
hnsw_index_creation_kwargs: Optional[Dict[str, Any]] = None,
hnsw_index_creation_kwargs: Optional[Dict[str, int]] = None,
hnsw_ef_search: Optional[int] = None,
):
"""
Creates a new PgvectorDocumentStore instance.
It is meant to be connected to a PostgreSQL database with the pgvector extension installed.
A specific table to store Haystack documents will be created if it doesn't exist yet.
:param connection_string: The connection string to use to connect to the PostgreSQL database.
e.g. "postgresql://USER:PASSWORD@HOST:PORT/DB_NAME"
:param table_name: The name of the table to use to store Haystack documents. Defaults to "haystack_documents".
:param embedding_dimension: The dimension of the embedding. Defaults to 768.
:param embedding_similarity_function: The similarity function to use when searching for similar embeddings.
Defaults to "cosine_distance". Set it to one of the following values: "cosine_distance", "max_inner_product", "l2_distance".
:type embedding_similarity_function: Literal["cosine_distance", "max_inner_product", "l2_distance"]
:param recreate_table: Whether to recreate the table if it already exists. Defaults to False.
:param search_strategy: The search strategy to use when searching for similar embeddings.
Defaults to "exact_nearest_neighbor". "hnsw" is an approximate nearest neighbor search strategy,
which trades off some accuracy for speed; it is recommended for large numbers of documents.
:type search_strategy: Literal["exact_nearest_neighbor", "hnsw"]
:param hnsw_recreate_index_if_exists: Whether to recreate the HNSW index if it already exists.
Defaults to False. Only used if search_strategy is set to "hnsw".
:param hnsw_index_creation_kwargs: Additional keyword arguments to pass to the HNSW index creation.
Only used if search_strategy is set to "hnsw". You can find the list of valid arguments in the
pgvector documentation: https://github.com/pgvector/pgvector?tab=readme-ov-file#hnsw
:param hnsw_ef_search: The ef_search parameter to use at query time. Only used if search_strategy is set to
"hnsw". You can find more information about this parameter in the pgvector documentation:
https://github.com/pgvector/pgvector?tab=readme-ov-file#hnsw
"""

self.connection_string = connection_string
self.table_name = table_name
self.embedding_dimension = embedding_dimension
Expand All @@ -65,9 +94,11 @@ def __init__(
self.hnsw_index_creation_kwargs = hnsw_index_creation_kwargs or {}
self.hnsw_ef_search = hnsw_ef_search

connection = psycopg.connect(connection_string)
connection = connect(connection_string)
connection.autocommit = True
self._connection = connection

# we create a generic cursor and another one that returns dictionaries
self._cursor = connection.cursor()
self._dict_cursor = connection.cursor(row_factory=dict_row)

Expand All @@ -81,18 +112,33 @@ def __init__(
if search_strategy == "hnsw":
self._handle_hnsw()

def _execute_sql(
    self, sql_query: Query, params: Optional[tuple] = None, error_msg: str = "", cursor: Optional[Cursor] = None
):
    """
    Internal method to execute SQL statements and handle exceptions.

    :param sql_query: The SQL query to execute.
    :param params: The parameters to pass to the SQL query. Defaults to an empty tuple.
    :param error_msg: The error message to use if an exception is raised.
    :param cursor: The cursor to use to execute the SQL query. Defaults to self._cursor.
    :return: Whatever `cursor.execute` returns, so that callers can chain fetch calls on it.
    :raises DocumentStoreError: If the cursor raises a psycopg `Error`. The connection is rolled back
        before the error is re-raised, and the original exception is chained as the cause.
    """

    params = params or ()
    cursor = cursor or self._cursor

    try:
        result = cursor.execute(sql_query, params)
    except Error as e:
        # Roll back before surfacing the failure as a DocumentStoreError.
        self._connection.rollback()
        raise DocumentStoreError(error_msg) from e
    return result

def _create_table_if_not_exists(self):
"""
Creates the table to store Haystack documents if it doesn't exist yet.
"""

table_structure_str = ", ".join(
f"{col[0]} {col[1]} {col[2]}" if len(col) == 3 else f"{col[0]} {col[1]}" # noqa: PLR2004
for col in TABLE_DEFINITION
Expand All @@ -105,11 +151,20 @@ def _create_table_if_not_exists(self):
self._execute_sql(create_sql, error_msg="Could not create table in PgvectorDocumentStore")

def delete_table(self):
    """
    Drops the table used to store Haystack documents, if it exists.
    """

    table_identifier = Identifier(self.table_name)
    drop_statement = SQL("DROP TABLE IF EXISTS {}").format(table_identifier)
    self._execute_sql(drop_statement, error_msg="Could not delete table in PgvectorDocumentStore")

def _handle_hnsw(self):
"""
Internal method to handle the HNSW index creation.
It also sets the hnsw.ef_search parameter for queries if it is specified.
"""

if self.hnsw_ef_search:
sql_set_hnsw_ef_search = SQL("SET hnsw.ef_search = {hnsw_ef_search}").format(
hnsw_ef_search=SQLLiteral(self.hnsw_ef_search)
Expand Down Expand Up @@ -137,6 +192,10 @@ def _handle_hnsw(self):
self._create_hnsw_index()

def _create_hnsw_index(self):
"""
Internal method to create the HNSW index.
"""

pg_ops = SIMILARITY_FUNCTION_TO_POSTGRESQL_OPS[self.embedding_similarity_function]
effective_hnsw_index_creation_kwargs = {
key: value
Expand All @@ -163,6 +222,7 @@ def count_documents(self) -> int:
"""
Returns how many documents are present in the document store.
"""

sql_count = SQL("SELECT COUNT(*) FROM {}").format(Identifier(self.table_name))

count = self._execute_sql(sql_count, error_msg="Could not count documents in PgvectorDocumentStore").fetchone()[
Expand All @@ -171,14 +231,15 @@ def count_documents(self) -> int:
return count

def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Document]:  # noqa: ARG002
    """
    Returns the documents stored in the table as Haystack Documents.

    :param filters: Currently ignored: filtering is not implemented yet, so every row is returned.
    :return: A list of Documents built from all the rows of the table.
    :raises DocumentStoreError: If the SELECT statement fails (raised by `_execute_sql`).
    """

    # TODO: implement filters
    sql_get_docs = SQL("SELECT * FROM {table_name}").format(table_name=Identifier(self.table_name))

    # The dict cursor is used so that each record maps column names to values.
    result = self._execute_sql(
        sql_get_docs, error_msg="Could not filter documents from PgvectorDocumentStore", cursor=self._dict_cursor
    )

    # Fetch all the records from the value returned by the execution
    records = result.fetchall()
    docs = self._from_pg_to_haystack_documents(records)
    return docs

Expand All @@ -192,6 +253,7 @@ def write_documents(self, documents: List[Document], policy: DuplicatePolicy = D
and the policy is set to DuplicatePolicy.FAIL (or not specified).
:return: The number of documents written to the document store.
"""

if len(documents) > 0:
if not isinstance(documents[0], Document):
msg = "param 'documents' must contain a list of objects of type Document"
Expand Down Expand Up @@ -222,7 +284,7 @@ def write_documents(self, documents: List[Document], policy: DuplicatePolicy = D

try:
self._cursor.executemany(insert_statement, db_documents, returning=True)
except psycopg.Error as e:
except Error as e:
self._connection.rollback()
raise DuplicateDocumentError from e

Expand All @@ -238,6 +300,11 @@ def write_documents(self, documents: List[Document], policy: DuplicatePolicy = D
return written_docs

def _from_haystack_to_pg_documents(self, documents: List[Document]) -> List[Dict[str, Any]]:
"""
Internal method to convert a list of Haystack Documents to a list of dictionaries that can be used to insert
documents into the PgvectorDocumentStore.
"""

db_documents = []
for document in documents:
db_document = document.to_dict(flatten=False)
Expand Down Expand Up @@ -267,6 +334,10 @@ def _from_haystack_to_pg_documents(self, documents: List[Document]) -> List[Dict
return db_documents

def _from_pg_to_haystack_documents(self, documents: List[Dict[str, Any]]) -> List[Document]:
"""
Internal method to convert a list of dictionaries from pgvector to a list of Haystack Documents.
"""

haystack_documents = []
for document in documents:
haystack_dict = dict(document)
Expand All @@ -286,6 +357,12 @@ def _from_pg_to_haystack_documents(self, documents: List[Dict[str, Any]]) -> Lis
return haystack_documents

def delete_documents(self, document_ids: List[str]) -> None:
"""
Deletes all documents whose ids are listed in document_ids from the document store.
:param document_ids: the document ids to delete
"""

if not document_ids:
return

Expand Down

0 comments on commit e42a7e7

Please sign in to comment.