Update vector_store_component.py to add keyword based search alongside semantic search #1603

Open · wants to merge 1 commit into base: main
57 changes: 41 additions & 16 deletions private_gpt/components/vector_store/vector_store_component.py
@@ -1,19 +1,18 @@
import logging
import typing

from injector import inject, singleton
from llama_index import VectorStoreIndex
from llama_index.indices.vector_store import VectorIndexRetriever
from llama_index.vector_stores.types import VectorStore

from private_gpt.components.vector_store.batched_chroma import BatchedChromaVectorStore
from private_gpt.open_ai.extensions.context_filter import ContextFilter
from private_gpt.paths import local_data_path
from private_gpt.settings.settings import Settings
from langchain.retrievers import BM25Retriever, EnsembleRetriever
Collaborator

You shouldn't use langchain.retrievers.

import typing
from typing import List, Union

logger = logging.getLogger(__name__)


@typing.no_type_check
def _chromadb_doc_id_metadata_filter(
    context_filter: ContextFilter | None,
@@ -36,6 +35,8 @@ def _chromadb_doc_id_metadata_filter(
@singleton
class VectorStoreComponent:
    vector_store: VectorStore
    keyword_retriever: BM25Retriever | None = None
    ensemble_retriever: EnsembleRetriever | None = None

    @inject
    def __init__(self, settings: Settings) -> None:
@@ -97,22 +98,46 @@ def __init__(self, settings: Settings) -> None:
f"Vectorstore database {settings.vectorstore.database} not supported"
)

@staticmethod
        # Check if there are documents to retrieve from and handle the case where the list is empty
        documents = []  # Replace this with your actual documents list if available
        if documents:
            self.keyword_retriever = BM25Retriever.from_documents(documents)
        else:
            self.keyword_retriever = None  # Handle the case where there are no documents

        # Initialize the ensemble retriever only if keyword_retriever is not None
        if self.keyword_retriever:
            self.ensemble_retriever = EnsembleRetriever(
                retrievers=[self.vector_store, self.keyword_retriever],
                retriever_weights={
                    'vector_store': 0.5,
                    'keyword_retriever': 0.5
                },
                search_kwargs={'search_type': 'mmr'}
            )
        else:
            self.ensemble_retriever = None


    def get_retriever(
        self,
        index: VectorStoreIndex,
        context_filter: ContextFilter | None = None,
        similarity_top_k: int = 2,
    ) -> VectorIndexRetriever:
        # This way we support qdrant (using doc_ids) and chroma (using where clause)
        return VectorIndexRetriever(
            index=index,
            similarity_top_k=similarity_top_k,
            doc_ids=context_filter.docs_ids if context_filter else None,
            vector_store_kwargs={
                "where": _chromadb_doc_id_metadata_filter(context_filter)
            },
        )
        use_keyword_retriever: bool = False
    ) -> Union[VectorIndexRetriever, EnsembleRetriever]:
        if use_keyword_retriever and self.ensemble_retriever:
            return self.ensemble_retriever
        else:
            return VectorIndexRetriever(
                index=index,
                similarity_top_k=similarity_top_k,
                doc_ids=context_filter.docs_ids if context_filter else None,
                vector_store_kwargs={
                    'where': _chromadb_doc_id_metadata_filter(context_filter)
                }
            )

    def close(self) -> None:
        if hasattr(self.vector_store.client, "close"):
        if hasattr(self.vector_store, 'client') and hasattr(self.vector_store.client, "close"):
            self.vector_store.client.close()
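
Since the review comment above advises against langchain.retrievers, the same hybrid behaviour could likely be built from llama-index's own primitives instead. The following is a minimal sketch, not the project's implementation: it assumes a llama_index release that exposes BaseRetriever and BM25Retriever from llama_index.retrievers (BM25Retriever requires the rank_bm25 package), the HybridRetriever class and build_hybrid_retriever helper are illustrative names, and it further assumes the index docstore holds the ingested nodes, which may not be true for a purely vector-store-backed index such as Qdrant or Chroma in private-gpt.

```python
from llama_index import QueryBundle, VectorStoreIndex
from llama_index.retrievers import BaseRetriever, BM25Retriever
from llama_index.schema import NodeWithScore


class HybridRetriever(BaseRetriever):
    """Merge semantic (vector) and keyword (BM25) results, de-duplicated by node id."""

    def __init__(
        self, vector_retriever: BaseRetriever, keyword_retriever: BaseRetriever
    ) -> None:
        self._vector_retriever = vector_retriever
        self._keyword_retriever = keyword_retriever
        super().__init__()

    def _retrieve(self, query_bundle: QueryBundle) -> list[NodeWithScore]:
        vector_nodes = self._vector_retriever.retrieve(query_bundle)
        keyword_nodes = self._keyword_retriever.retrieve(query_bundle)
        # Keep the first-seen score for nodes returned by both retrievers.
        merged: dict[str, NodeWithScore] = {}
        for node_with_score in vector_nodes + keyword_nodes:
            merged.setdefault(node_with_score.node.node_id, node_with_score)
        return list(merged.values())


def build_hybrid_retriever(
    index: VectorStoreIndex, similarity_top_k: int = 2
) -> HybridRetriever:
    # Hypothetical helper: BM25 needs the raw nodes, assumed here to live in the docstore.
    keyword_retriever = BM25Retriever.from_defaults(
        docstore=index.docstore, similarity_top_k=similarity_top_k
    )
    vector_retriever = index.as_retriever(similarity_top_k=similarity_top_k)
    return HybridRetriever(vector_retriever, keyword_retriever)
```

Under those assumptions, get_retriever could return such a HybridRetriever when use_keyword_retriever is set, instead of a langchain EnsembleRetriever.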