From 0a72ae614190b8a0fe09c5405b5131ba7662116a Mon Sep 17 00:00:00 2001
From: cloudrage999 <113858843+cloudrage999@users.noreply.github.com>
Date: Tue, 13 Feb 2024 09:38:12 +0330
Subject: [PATCH] Update vector_store_component.py

---
 .../vector_store/vector_store_component.py   | 57 +++++++++++++------
 1 file changed, 41 insertions(+), 16 deletions(-)

diff --git a/private_gpt/components/vector_store/vector_store_component.py b/private_gpt/components/vector_store/vector_store_component.py
index 09af692bf9..dfe6ab036f 100644
--- a/private_gpt/components/vector_store/vector_store_component.py
+++ b/private_gpt/components/vector_store/vector_store_component.py
@@ -1,19 +1,18 @@
 import logging
-import typing
-
 from injector import inject, singleton
 from llama_index import VectorStoreIndex
 from llama_index.indices.vector_store import VectorIndexRetriever
 from llama_index.vector_stores.types import VectorStore
-
 from private_gpt.components.vector_store.batched_chroma import BatchedChromaVectorStore
 from private_gpt.open_ai.extensions.context_filter import ContextFilter
 from private_gpt.paths import local_data_path
 from private_gpt.settings.settings import Settings
+from langchain.retrievers import BM25Retriever, EnsembleRetriever
+import typing
+from typing import List, Union
 
 logger = logging.getLogger(__name__)
 
-
 @typing.no_type_check
 def _chromadb_doc_id_metadata_filter(
     context_filter: ContextFilter | None,
@@ -36,6 +35,8 @@ def _chromadb_doc_id_metadata_filter(
 @singleton
 class VectorStoreComponent:
     vector_store: VectorStore
+    keyword_retriever: BM25Retriever | None = None
+    ensemble_retriever: EnsembleRetriever | None = None
 
     @inject
     def __init__(self, settings: Settings) -> None:
@@ -97,22 +98,46 @@ def __init__(self, settings: Settings) -> None:
                     f"Vectorstore database {settings.vectorstore.database} not supported"
                 )
 
-    @staticmethod
+        # Check if there are documents to retrieve from and handle the case where the list is empty
+        documents = []  # Replace this with your actual documents list if available
+        if documents:
+            self.keyword_retriever = BM25Retriever.from_documents(documents)
+        else:
+            self.keyword_retriever = None  # Handle the case where there are no documents
+
+        # Initialize the ensemble retriever only if keyword_retriever is not None
+        if self.keyword_retriever:
+            self.ensemble_retriever = EnsembleRetriever(
+                retrievers=[self.vector_store, self.keyword_retriever],
+                retriever_weights={
+                    'vector_store': 0.5,
+                    'keyword_retriever': 0.5
+                },
+                search_kwargs={'search_type': 'mmr'}
+            )
+        else:
+            self.ensemble_retriever = None
+
     def get_retriever(
+        self,
         index: VectorStoreIndex,
         context_filter: ContextFilter | None = None,
         similarity_top_k: int = 2,
-    ) -> VectorIndexRetriever:
-        # This way we support qdrant (using doc_ids) and chroma (using where clause)
-        return VectorIndexRetriever(
-            index=index,
-            similarity_top_k=similarity_top_k,
-            doc_ids=context_filter.docs_ids if context_filter else None,
-            vector_store_kwargs={
-                "where": _chromadb_doc_id_metadata_filter(context_filter)
-            },
-        )
+        use_keyword_retriever: bool = False
+    ) -> Union[VectorIndexRetriever, EnsembleRetriever]:
+        if use_keyword_retriever and self.ensemble_retriever:
+            return self.ensemble_retriever
+        else:
+            return VectorIndexRetriever(
+                index=index,
+                similarity_top_k=similarity_top_k,
+                doc_ids=context_filter.docs_ids if context_filter else None,
+                vector_store_kwargs={
+                    'where': _chromadb_doc_id_metadata_filter(context_filter)
+                }
+            )
 
     def close(self) -> None:
-        if hasattr(self.vector_store.client, "close"):
+        if hasattr(self.vector_store, 'client') and hasattr(self.vector_store.client, "close"):
             self.vector_store.client.close()