Skip to content

Commit

Permalink
feat: add applying filters on queries for qdrant
Browse files Browse the repository at this point in the history
  • Loading branch information
elisalimli committed Apr 26, 2024
1 parent 6083cb7 commit 8655633
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 5 deletions.
6 changes: 4 additions & 2 deletions models/query.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from typing import List, Optional

from pydantic import BaseModel
from typing import List, Optional

from models.document import BaseDocumentChunk
from models.ingest import EncoderConfig
from models.vector_database import VectorDatabase
from qdrant_client.http.models import Filter


class RequestPayload(BaseModel):
Expand All @@ -15,6 +15,8 @@ class RequestPayload(BaseModel):
session_id: Optional[str] = None
interpreter_mode: Optional[bool] = False
exclude_fields: List[str] = None
# TODO: use our own Filter model
filter: Optional[Filter] = None


class ResponseData(BaseModel):
Expand Down
4 changes: 3 additions & 1 deletion service/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,9 @@ def create_route_layer() -> RouteLayer:
async def get_documents(
*, vector_service: BaseVectorDatabase, payload: RequestPayload
) -> list[BaseDocumentChunk]:
chunks = await vector_service.query(input=payload.input, top_k=5)
chunks = await vector_service.query(
input=payload.input, filter=payload.filter, top_k=5
)
# filter out documents with empty content
chunks = [chunk for chunk in chunks if chunk.content.strip()]
if not len(chunks):
Expand Down
5 changes: 4 additions & 1 deletion vectordbs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from models.delete import DeleteResponse
from models.document import BaseDocumentChunk
from models.query import Filter
from utils.logger import logger


Expand All @@ -24,7 +25,9 @@ async def upsert(self, chunks: List[BaseDocumentChunk]):
pass

@abstractmethod
async def query(self, input: str, top_k: int = 25) -> List[BaseDocumentChunk]:
async def query(
self, input: str, filter: Filter, top_k: int = 25
) -> List[BaseDocumentChunk]:
pass

@abstractmethod
Expand Down
6 changes: 5 additions & 1 deletion vectordbs/qdrant.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from models.delete import DeleteResponse
from models.document import BaseDocumentChunk
from models.query import Filter
from vectordbs.base import BaseVectorDatabase

MAX_QUERY_TOP_K = 5
Expand Down Expand Up @@ -69,11 +70,14 @@ async def upsert(self, chunks: List[BaseDocumentChunk]) -> None:

self.client.upsert(collection_name=self.index_name, wait=True, points=points)

async def query(self, input: str, top_k: int = MAX_QUERY_TOP_K) -> List:
async def query(
self, input: str, filter: Filter, top_k: int = MAX_QUERY_TOP_K
) -> List:
vectors = await self._generate_vectors(input=input)
search_result = self.client.search(
collection_name=self.index_name,
query_vector=("content", vectors[0]),
query_filter=filter,
limit=top_k,
with_payload=True,
)
Expand Down

0 comments on commit 8655633

Please sign in to comment.