From 865563388b9e6515dc38d462a35669b09fda1925 Mon Sep 17 00:00:00 2001
From: alisalim17
Date: Fri, 26 Apr 2024 17:50:13 +0400
Subject: [PATCH] feat: add applying filters on queries for qdrant

---
 models/query.py     | 6 ++++--
 service/router.py   | 4 +++-
 vectordbs/base.py   | 5 ++++-
 vectordbs/qdrant.py | 6 +++++-
 4 files changed, 16 insertions(+), 5 deletions(-)

diff --git a/models/query.py b/models/query.py
index 9289ffbc..b935e66c 100644
--- a/models/query.py
+++ b/models/query.py
@@ -1,10 +1,10 @@
-from typing import List, Optional
-
 from pydantic import BaseModel
+from typing import List, Optional
 
 from models.document import BaseDocumentChunk
 from models.ingest import EncoderConfig
 from models.vector_database import VectorDatabase
+from qdrant_client.http.models import Filter
 
 
 class RequestPayload(BaseModel):
@@ -15,6 +15,8 @@ class RequestPayload(BaseModel):
     session_id: Optional[str] = None
     interpreter_mode: Optional[bool] = False
     exclude_fields: List[str] = None
+    # TODO: use our own Filter model
+    filter: Optional[Filter] = None
 
 
 class ResponseData(BaseModel):
diff --git a/service/router.py b/service/router.py
index 9b840fce..dad43acc 100644
--- a/service/router.py
+++ b/service/router.py
@@ -40,7 +40,9 @@ def create_route_layer() -> RouteLayer:
 async def get_documents(
     *, vector_service: BaseVectorDatabase, payload: RequestPayload
 ) -> list[BaseDocumentChunk]:
-    chunks = await vector_service.query(input=payload.input, top_k=5)
+    chunks = await vector_service.query(
+        input=payload.input, filter=payload.filter, top_k=5
+    )
     # filter out documents with empty content
     chunks = [chunk for chunk in chunks if chunk.content.strip()]
     if not len(chunks):
diff --git a/vectordbs/base.py b/vectordbs/base.py
index 0f3203c2..b4a7b920 100644
--- a/vectordbs/base.py
+++ b/vectordbs/base.py
@@ -7,6 +7,7 @@
 
 from models.delete import DeleteResponse
 from models.document import BaseDocumentChunk
+from models.query import Filter
 from utils.logger import logger
 
 
@@ -24,7 +25,9 @@ async def upsert(self, chunks: List[BaseDocumentChunk]):
         pass
 
     @abstractmethod
-    async def query(self, input: str, top_k: int = 25) -> List[BaseDocumentChunk]:
+    async def query(
+        self, input: str, filter: Filter, top_k: int = 25
+    ) -> List[BaseDocumentChunk]:
         pass
 
     @abstractmethod
diff --git a/vectordbs/qdrant.py b/vectordbs/qdrant.py
index 4c667468..ed4f6f90 100644
--- a/vectordbs/qdrant.py
+++ b/vectordbs/qdrant.py
@@ -7,6 +7,7 @@
 
 from models.delete import DeleteResponse
 from models.document import BaseDocumentChunk
+from models.query import Filter
 from vectordbs.base import BaseVectorDatabase
 
 MAX_QUERY_TOP_K = 5
@@ -69,11 +70,14 @@ async def upsert(self, chunks: List[BaseDocumentChunk]) -> None:
 
         self.client.upsert(collection_name=self.index_name, wait=True, points=points)
 
-    async def query(self, input: str, top_k: int = MAX_QUERY_TOP_K) -> List:
+    async def query(
+        self, input: str, filter: Filter, top_k: int = MAX_QUERY_TOP_K
+    ) -> List:
         vectors = await self._generate_vectors(input=input)
         search_result = self.client.search(
             collection_name=self.index_name,
             query_vector=("content", vectors[0]),
+            query_filter=filter,
             limit=top_k,
             with_payload=True,
         )