diff --git a/integrations/pgvector/src/haystack_integrations/document_stores/pgvector/document_store.py b/integrations/pgvector/src/haystack_integrations/document_stores/pgvector/document_store.py index 9e677924f..951c5b5b9 100644 --- a/integrations/pgvector/src/haystack_integrations/document_stores/pgvector/document_store.py +++ b/integrations/pgvector/src/haystack_integrations/document_stores/pgvector/document_store.py @@ -4,10 +4,12 @@ import logging from typing import Any, Dict, List, Literal, Optional -import psycopg from haystack.dataclasses.document import ByteStream, Document from haystack.document_stores.errors import DocumentStoreError, DuplicateDocumentError from haystack.document_stores.types import DuplicatePolicy +from psycopg import Error, connect +from psycopg.abc import Query +from psycopg.cursor import Cursor from psycopg.rows import dict_row from psycopg.sql import SQL, Identifier from psycopg.sql import Literal as SQLLiteral @@ -52,9 +54,36 @@ def __init__( recreate_table: bool = False, search_strategy: Literal["exact_nearest_neighbor", "hnsw"] = "exact_nearest_neighbor", hnsw_recreate_index_if_exists: bool = False, - hnsw_index_creation_kwargs: Optional[Dict[str, Any]] = None, + hnsw_index_creation_kwargs: Optional[Dict[str, int]] = None, hnsw_ef_search: Optional[int] = None, ): + """ + Creates a new PgvectorDocumentStore instance. + It is meant to be connected to a PostgreSQL database with the pgvector extension installed. + A specific table to store Haystack documents will be created if it doesn't exist yet. + + :param connection_string: The connection string to use to connect to the PostgreSQL database. + e.g. "postgresql://USER:PASSWORD@HOST:PORT/DB_NAME" + :param table_name: The name of the table to use to store Haystack documents. Defaults to "haystack_documents". + :param embedding_dimension: The dimension of the embedding. Defaults to 768. 
+ :param embedding_similarity_function: The similarity function to use when searching for similar embeddings. + Defaults to "cosine_distance". Set it to one of the following values: "cosine_distance", "max_inner_product", "l2_distance". + :type embedding_similarity_function: Literal["cosine_distance", "max_inner_product", "l2_distance"] + :param recreate_table: Whether to recreate the table if it already exists. Defaults to False. + :param search_strategy: The search strategy to use when searching for similar embeddings. + Defaults to "exact_nearest_neighbor". "hnsw" is an approximate nearest neighbor search strategy, + which trades off some accuracy for speed; it is recommended for large numbers of documents. + :type search_strategy: Literal["exact_nearest_neighbor", "hnsw"] + :param hnsw_recreate_index_if_exists: Whether to recreate the HNSW index if it already exists. + Defaults to False. Only used if search_strategy is set to "hnsw". + :param hnsw_index_creation_kwargs: Additional keyword arguments to pass to the HNSW index creation. + Only used if search_strategy is set to "hnsw". You can find the list of valid arguments in the + pgvector documentation: https://github.com/pgvector/pgvector?tab=readme-ov-file#hnsw + :param hnsw_ef_search: The ef_search parameter to use at query time. Only used if search_strategy is set to + "hnsw". 
You can find more information about this parameter in the pgvector documentation: + https://github.com/pgvector/pgvector?tab=readme-ov-file#hnsw + """ + self.connection_string = connection_string self.table_name = table_name self.embedding_dimension = embedding_dimension @@ -65,9 +94,11 @@ def __init__( self.hnsw_index_creation_kwargs = hnsw_index_creation_kwargs or {} self.hnsw_ef_search = hnsw_ef_search - connection = psycopg.connect(connection_string) + connection = connect(connection_string) connection.autocommit = True self._connection = connection + + # we create a generic cursor and another one that returns dictionaries self._cursor = connection.cursor() self._dict_cursor = connection.cursor(row_factory=dict_row) @@ -81,18 +112,33 @@ def __init__( if search_strategy == "hnsw": self._handle_hnsw() - def _execute_sql(self, sql, params: Optional[tuple] = None, error_msg="", cursor=None): + def _execute_sql( + self, sql_query: Query, params: Optional[tuple] = None, error_msg: str = "", cursor: Optional[Cursor] = None + ): + """ + Internal method to execute SQL statements and handle exceptions. + + :param sql_query: The SQL query to execute. + :param params: The parameters to pass to the SQL query. + :param error_msg: The error message to use if an exception is raised. + :param cursor: The cursor to use to execute the SQL query. Defaults to self._cursor. + """ + params = params or () cursor = cursor or self._cursor try: - result = cursor.execute(sql, params) - except psycopg.Error as e: + result = cursor.execute(sql_query, params) + except Error as e: self._connection.rollback() raise DocumentStoreError(error_msg) from e return result def _create_table_if_not_exists(self): + """ + Creates the table to store Haystack documents if it doesn't exist yet. 
+ """ + table_structure_str = ", ".join( f"{col[0]} {col[1]} {col[2]}" if len(col) == 3 else f"{col[0]} {col[1]}" # noqa: PLR2004 for col in TABLE_DEFINITION @@ -105,11 +151,20 @@ def _create_table_if_not_exists(self): self._execute_sql(create_sql, error_msg="Could not create table in PgvectorDocumentStore") def delete_table(self): + """ + Deletes the table used to store Haystack documents. + """ + delete_sql = SQL("DROP TABLE IF EXISTS {}").format(Identifier(self.table_name)) self._execute_sql(delete_sql, error_msg="Could not delete table in PgvectorDocumentStore") def _handle_hnsw(self): + """ + Internal method to handle the HNSW index creation. + It also sets the hnsw.ef_search parameter for queries if it is specified. + """ + if self.hnsw_ef_search: sql_set_hnsw_ef_search = SQL("SET hnsw.ef_search = {hnsw_ef_search}").format( hnsw_ef_search=SQLLiteral(self.hnsw_ef_search) @@ -137,6 +192,10 @@ def _handle_hnsw(self): self._create_hnsw_index() def _create_hnsw_index(self): + """ + Internal method to create the HNSW index. + """ + pg_ops = SIMILARITY_FUNCTION_TO_POSTGRESQL_OPS[self.embedding_similarity_function] effective_hnsw_index_creation_kwargs = { key: value @@ -163,6 +222,7 @@ def count_documents(self) -> int: """ Returns how many documents are present in the document store. 
""" + sql_count = SQL("SELECT COUNT(*) FROM {}").format(Identifier(self.table_name)) count = self._execute_sql(sql_count, error_msg="Could not count documents in PgvectorDocumentStore").fetchone()[ @@ -171,14 +231,15 @@ def count_documents(self) -> int: return count def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Document]: # noqa: ARG002 + # TODO: implement filters sql_get_docs = SQL("SELECT * FROM {table_name}").format(table_name=Identifier(self.table_name)) - self._execute_sql( + result = self._execute_sql( sql_get_docs, error_msg="Could not filter documents from PgvectorDocumentStore", cursor=self._dict_cursor ) # Fetch all the records - records = self._dict_cursor.fetchall() + records = result.fetchall() docs = self._from_pg_to_haystack_documents(records) return docs @@ -192,6 +253,7 @@ def write_documents(self, documents: List[Document], policy: DuplicatePolicy = D and the policy is set to DuplicatePolicy.FAIL (or not specified). :return: The number of documents written to the document store. """ + if len(documents) > 0: if not isinstance(documents[0], Document): msg = "param 'documents' must contain a list of objects of type Document" @@ -222,7 +284,7 @@ def write_documents(self, documents: List[Document], policy: DuplicatePolicy = D try: self._cursor.executemany(insert_statement, db_documents, returning=True) - except psycopg.Error as e: + except Error as e: self._connection.rollback() raise DuplicateDocumentError from e @@ -238,6 +300,11 @@ def write_documents(self, documents: List[Document], policy: DuplicatePolicy = D return written_docs def _from_haystack_to_pg_documents(self, documents: List[Document]) -> List[Dict[str, Any]]: + """ + Internal method to convert a list of Haystack Documents to a list of dictionaries that can be used to insert + documents into the PgvectorDocumentStore. 
+ """ + db_documents = [] for document in documents: db_document = document.to_dict(flatten=False) @@ -267,6 +334,10 @@ def _from_haystack_to_pg_documents(self, documents: List[Document]) -> List[Dict return db_documents def _from_pg_to_haystack_documents(self, documents: List[Dict[str, Any]]) -> List[Document]: + """ + Internal method to convert a list of dictionaries from pgvector to a list of Haystack Documents. + """ + haystack_documents = [] for document in documents: haystack_dict = dict(document) @@ -286,6 +357,12 @@ def _from_pg_to_haystack_documents(self, documents: List[Dict[str, Any]]) -> Lis return haystack_documents def delete_documents(self, document_ids: List[str]) -> None: + """ + Deletes all documents with matching document_ids from the document store. + + :param document_ids: the document ids to delete + """ + if not document_ids: return