Skip to content

Commit

Permalink
making progress on index
Browse files Browse the repository at this point in the history
  • Loading branch information
anakin87 committed Jan 19, 2024
1 parent fe63ac9 commit 68be57a
Showing 1 changed file with 34 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from haystack.document_stores.types import DuplicatePolicy
from sqlalchemy import create_engine, delete, text
from sqlalchemy.dialects.postgresql import BYTEA, JSON, TEXT, VARCHAR, insert
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import DeclarativeBase, Session, mapped_column
from sqlalchemy.schema import Index
Expand All @@ -29,11 +30,11 @@

class _AbstractDBDocument(DeclarativeBase):
# __abstract__ = True means that this class does not correspond to a table in the database
# this allows setting dinamically the table name
# this allows setting dinamically the table name and the embedding dimension
__abstract__ = True

id = mapped_column(VARCHAR(64), primary_key=True)
embedding = mapped_column(Vector(None), nullable=True)
embedding = mapped_column(Vector(), nullable=True)
content = mapped_column(TEXT, nullable=True)
dataframe = mapped_column(JSON, nullable=True)
blob = mapped_column(BYTEA, nullable=True)
Expand All @@ -42,8 +43,12 @@ class _AbstractDBDocument(DeclarativeBase):
meta = mapped_column(JSON, nullable=True)


def _get_db_document(table_name):
return type("DBDocument", (_AbstractDBDocument,), {"__tablename__": table_name})
def _get_db_document(table_name, embedding_dimension):
return type(
"DBDocument",
(_AbstractDBDocument,),
{"__tablename__": table_name, "embedding": mapped_column(Vector(embedding_dimension), nullable=True)},
)


class PgvectorDocumentStore:
Expand All @@ -52,22 +57,23 @@ def __init__(
*,
connection_string: str,
table_name: str = "haystack_documents",
embedding_dimension: int = 768,
embedding_similarity_function: Literal[
"cosine_distance", "max_inner_product", "l2_distance"
] = "cosine_distance",
recreate_table: bool = False,
search_strategy: Literal["exact_nearest_neighbor", "hnsw"] = "exact_nearest_neighbor",
hnsw_recreate_index_if_exists: bool = False,
hnsw_index_creation_kwargs: Optional[Dict[str, Any]] = None,
hnsw_ef_search: Optional[int] = None,
):
engine = create_engine(connection_string)
with engine.connect() as conn:
conn.execute(text("CREATE EXTENSION IF NOT EXISTS vector"))
conn.commit()
self._session = Session(engine)

self._DBDocument = _get_db_document(table_name)
self._session.execute(text("CREATE EXTENSION IF NOT EXISTS vector"))
self._session.commit()

self._session = Session(engine)
self._DBDocument = _get_db_document(table_name, embedding_dimension)

if recreate_table:
self._DBDocument.__table__.drop(engine, checkfirst=True)
Expand All @@ -78,6 +84,9 @@ def __init__(
hnsw_index_creation_kwargs = hnsw_index_creation_kwargs or {}

if search_strategy == "hnsw":
if hnsw_ef_search:
self._session.execute(text("SET hnsw.ef_search = :ef_search"), ef_search=hnsw_ef_search)

effective_hnsw_index_creation_kwargs = {}
for key, value in hnsw_index_creation_kwargs.items():
if key in HNSW_INDEX_CREATION_VALID_KWARGS:
Expand All @@ -89,17 +98,30 @@ def __init__(
HNSW_INDEX_CREATION_VALID_KWARGS,
)

inspector = Inspector.from_engine(engine)
exists_index = next(
(index for index in inspector.get_indexes(table_name=table_name) if index["name"] == "hnsw_index"), None
)

if exists_index:
if not hnsw_recreate_index_if_exists:
logger.warning(
"HNSW index already exists and won't be recreated. "
"If you want to recreate it, set hnsw_recreate_index=True"
)
return
self._session.execute(text("DROP INDEX hnsw_index"))
self._session.commit()

index = Index(
"hnsw_index",
self._DBDocument.embedding,
postgresql_using="hnsw",
postgresql_with=effective_hnsw_index_creation_kwargs,
postgresql_ops={"embedding": SIMILARITY_FUNCTION_TO_POSTGRESQL_OPS[embedding_similarity_function]},
)
index.create(engine)

if hnsw_ef_search:
conn.execute(text("SET hnsw.ef_search = :ef_search"), ef_search=hnsw_ef_search)
index.create(engine)

def count_documents(self) -> int:
"""
Expand Down

0 comments on commit 68be57a

Please sign in to comment.