Skip to content

Commit

Permalink
Qdrant metadata payload keys (#13001)
Browse files Browse the repository at this point in the history
- **Description:** Qdrant now accepts a list of keys as
`content_payload_key` in order to retrieve multiple fields (the generated
document will contain the dictionary {field: value} as a string).
- **Issue:** Previously, only one field could be retrieved from the
vector database when running a search.
  - **Dependencies:** 
  - **Tag maintainer:** 
  - **Twitter handle:** @jb_dlb

---------

Co-authored-by: Jean Baptiste De La Broise <[email protected]>
  • Loading branch information
JeanBaptiste-dlb and jbdlb authored Dec 6, 2023
1 parent ad6dfb6 commit 38813d7
Showing 1 changed file with 78 additions and 28 deletions.
106 changes: 78 additions & 28 deletions libs/langchain/langchain/vectorstores/qdrant.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,17 +82,17 @@ class Qdrant(VectorStore):
qdrant = Qdrant(client, collection_name, embedding_function)
"""

CONTENT_KEY = "page_content"
METADATA_KEY = "metadata"
CONTENT_KEY = ["page_content"]
METADATA_KEY = ["metadata"]
VECTOR_NAME = None

def __init__(
self,
client: Any,
collection_name: str,
embeddings: Optional[Embeddings] = None,
content_payload_key: str = CONTENT_KEY,
metadata_payload_key: str = METADATA_KEY,
content_payload_key: Union[list, str] = CONTENT_KEY,
metadata_payload_key: Union[list, str] = METADATA_KEY,
distance_strategy: str = "COSINE",
vector_name: Optional[str] = VECTOR_NAME,
embedding_function: Optional[Callable] = None, # deprecated
Expand All @@ -112,6 +112,12 @@ def __init__(
f"got {type(client)}"
)

if isinstance(content_payload_key, str): # Ensuring Backward compatibility
content_payload_key = [content_payload_key]

if isinstance(metadata_payload_key, str): # Ensuring Backward compatibility
metadata_payload_key = [metadata_payload_key]

if embeddings is None and embedding_function is None:
raise ValueError(
"`embeddings` value can't be None. Pass `Embeddings` instance."
Expand All @@ -127,8 +133,14 @@ def __init__(
self._embeddings_function = embedding_function
self.client: qdrant_client.QdrantClient = client
self.collection_name = collection_name
self.content_payload_key = content_payload_key or self.CONTENT_KEY
self.metadata_payload_key = metadata_payload_key or self.METADATA_KEY
self.content_payload_key = (
content_payload_key if content_payload_key is not None else self.CONTENT_KEY
)
self.metadata_payload_key = (
metadata_payload_key
if metadata_payload_key is not None
else self.METADATA_KEY
)
self.vector_name = vector_name or self.VECTOR_NAME

if embedding_function is not None:
Expand Down Expand Up @@ -1178,8 +1190,8 @@ def from_texts(
path: Optional[str] = None,
collection_name: Optional[str] = None,
distance_func: str = "Cosine",
content_payload_key: str = CONTENT_KEY,
metadata_payload_key: str = METADATA_KEY,
content_payload_key: List[str] = CONTENT_KEY,
metadata_payload_key: List[str] = METADATA_KEY,
vector_name: Optional[str] = VECTOR_NAME,
batch_size: int = 64,
shard_number: Optional[int] = None,
Expand Down Expand Up @@ -1354,8 +1366,8 @@ async def afrom_texts(
path: Optional[str] = None,
collection_name: Optional[str] = None,
distance_func: str = "Cosine",
content_payload_key: str = CONTENT_KEY,
metadata_payload_key: str = METADATA_KEY,
content_payload_key: List[str] = CONTENT_KEY,
metadata_payload_key: List[str] = METADATA_KEY,
vector_name: Optional[str] = VECTOR_NAME,
batch_size: int = 64,
shard_number: Optional[int] = None,
Expand Down Expand Up @@ -1527,8 +1539,8 @@ def construct_instance(
path: Optional[str] = None,
collection_name: Optional[str] = None,
distance_func: str = "Cosine",
content_payload_key: str = CONTENT_KEY,
metadata_payload_key: str = METADATA_KEY,
content_payload_key: List[str] = CONTENT_KEY,
metadata_payload_key: List[str] = METADATA_KEY,
vector_name: Optional[str] = VECTOR_NAME,
shard_number: Optional[int] = None,
replication_factor: Optional[int] = None,
Expand Down Expand Up @@ -1691,8 +1703,8 @@ async def aconstruct_instance(
path: Optional[str] = None,
collection_name: Optional[str] = None,
distance_func: str = "Cosine",
content_payload_key: str = CONTENT_KEY,
metadata_payload_key: str = METADATA_KEY,
content_payload_key: List[str] = CONTENT_KEY,
metadata_payload_key: List[str] = METADATA_KEY,
vector_name: Optional[str] = VECTOR_NAME,
shard_number: Optional[int] = None,
replication_factor: Optional[int] = None,
Expand Down Expand Up @@ -1888,11 +1900,11 @@ def _similarity_search_with_relevance_scores(

@classmethod
def _build_payloads(
cls,
cls: Type[Qdrant],
texts: Iterable[str],
metadatas: Optional[List[dict]],
content_payload_key: str,
metadata_payload_key: str,
content_payload_key: list[str],
metadata_payload_key: list[str],
) -> List[dict]:
payloads = []
for i, text in enumerate(texts):
Expand All @@ -1913,29 +1925,67 @@ def _build_payloads(

@classmethod
def _document_from_scored_point(
cls,
cls: Type[Qdrant],
scored_point: Any,
content_payload_key: str,
metadata_payload_key: str,
content_payload_key: list[str],
metadata_payload_key: list[str],
) -> Document:
return Document(
page_content=scored_point.payload.get(content_payload_key),
metadata=scored_point.payload.get(metadata_payload_key) or {},
payload = scored_point.payload
return Qdrant._document_from_payload(
payload=payload,
content_payload_key=content_payload_key,
metadata_payload_key=metadata_payload_key,
)

@classmethod
def _document_from_scored_point_grpc(
cls,
cls: Type[Qdrant],
scored_point: Any,
content_payload_key: str,
metadata_payload_key: str,
content_payload_key: list[str],
metadata_payload_key: list[str],
) -> Document:
from qdrant_client.conversions.conversion import grpc_to_payload

payload = grpc_to_payload(scored_point.payload)
return Qdrant._document_from_payload(
payload=payload,
content_payload_key=content_payload_key,
metadata_payload_key=metadata_payload_key,
)

@classmethod
def _document_from_payload(
    cls: Type[Qdrant],
    payload: Any,
    content_payload_key: list[str],
    metadata_payload_key: list[str],
) -> Document:
    """Build a ``Document`` from a point payload.

    Args:
        payload: Mapping of payload field names to values; only its
            ``.get`` method is used.
        content_payload_key: Payload field name(s) that make up the page
            content. A bare string is treated as a one-element list.
        metadata_payload_key: Payload field name(s) that make up the
            metadata. A bare string is treated as a one-element list.

    Returns:
        A ``Document`` whose ``page_content`` is:

        * the raw value of the field when exactly one content key is given,
        * ``str({key: value, ...})`` when several content keys are given,
        * ``""`` when no content key is given;

        and whose ``metadata`` is:

        * the value of the field (or ``{}`` when it is missing or falsy)
          when exactly one metadata key is given,
        * ``{key: value, ...}`` when several metadata keys are given,
        * ``{}`` when no metadata key is given.
    """
    # Accept a plain string as well: without this, ``len("page_content")``
    # would be > 1 and the string would be iterated character by character.
    if isinstance(content_payload_key, str):
        content_payload_key = [content_payload_key]
    if isinstance(metadata_payload_key, str):
        metadata_payload_key = [metadata_payload_key]

    if len(content_payload_key) == 1:
        # Look up the single key itself; passing the list to ``.get`` would
        # raise ``TypeError: unhashable type: 'list'``.
        content = payload.get(content_payload_key[0])
    elif len(content_payload_key) > 1:
        # Several fields: serialise the selected fields so that the page
        # content is still a string.
        content = str({key: payload.get(key) for key in content_payload_key})
    else:
        content = ""

    if len(metadata_payload_key) == 1:
        # Fall back to an empty dict when the field is missing or falsy so
        # the metadata is never ``None``.
        metadata = payload.get(metadata_payload_key[0]) or {}
    elif len(metadata_payload_key) > 1:
        metadata = {key: payload.get(key) for key in metadata_payload_key}
    else:
        metadata = {}

    return Document(
        page_content=content,
        metadata=metadata,
    )

def _build_condition(self, key: str, value: Any) -> List[rest.FieldCondition]:
Expand Down

0 comments on commit 38813d7

Please sign in to comment.