From 38813d7090294c0c96d4963a2a230db4fef5e37e Mon Sep 17 00:00:00 2001 From: Jean-Baptiste dlb <45666468+JeanBaptiste-dlb@users.noreply.github.com> Date: Wed, 6 Dec 2023 18:12:54 +0100 Subject: [PATCH] Qdrant metadata payload keys (#13001) - **Description:** Allows passing a list of keys as the Qdrant `content_payload_key` to retrieve multiple payload fields (the generated document will contain the dictionary {field: value} as a string). - **Issue:** Previously, only one field could be retrieved from the vector database when running a search. - **Dependencies:** - **Tag maintainer:** - **Twitter handle:** @jb_dlb --------- Co-authored-by: Jean Baptiste De La Broise --- .../langchain/vectorstores/qdrant.py | 106 +++++++++++++----- 1 file changed, 78 insertions(+), 28 deletions(-) diff --git a/libs/langchain/langchain/vectorstores/qdrant.py b/libs/langchain/langchain/vectorstores/qdrant.py index 09cba48911f78..4d6f3170c8142 100644 --- a/libs/langchain/langchain/vectorstores/qdrant.py +++ b/libs/langchain/langchain/vectorstores/qdrant.py @@ -82,8 +82,8 @@ class Qdrant(VectorStore): qdrant = Qdrant(client, collection_name, embedding_function) """ - CONTENT_KEY = "page_content" - METADATA_KEY = "metadata" + CONTENT_KEY = ["page_content"] + METADATA_KEY = ["metadata"] VECTOR_NAME = None def __init__( @@ -91,8 +91,8 @@ def __init__( client: Any, collection_name: str, embeddings: Optional[Embeddings] = None, - content_payload_key: str = CONTENT_KEY, - metadata_payload_key: str = METADATA_KEY, + content_payload_key: Union[list, str] = CONTENT_KEY, + metadata_payload_key: Union[list, str] = METADATA_KEY, distance_strategy: str = "COSINE", vector_name: Optional[str] = VECTOR_NAME, embedding_function: Optional[Callable] = None, # deprecated @@ -112,6 +112,12 @@ def __init__( f"got {type(client)}" ) + if isinstance(content_payload_key, str): # Ensuring Backward compatibility + content_payload_key = [content_payload_key] + + if isinstance(metadata_payload_key, str): # Ensuring 
Backward compatibility + metadata_payload_key = [metadata_payload_key] + if embeddings is None and embedding_function is None: raise ValueError( "`embeddings` value can't be None. Pass `Embeddings` instance." @@ -127,8 +133,14 @@ def __init__( self._embeddings_function = embedding_function self.client: qdrant_client.QdrantClient = client self.collection_name = collection_name - self.content_payload_key = content_payload_key or self.CONTENT_KEY - self.metadata_payload_key = metadata_payload_key or self.METADATA_KEY + self.content_payload_key = ( + content_payload_key if content_payload_key is not None else self.CONTENT_KEY + ) + self.metadata_payload_key = ( + metadata_payload_key + if metadata_payload_key is not None + else self.METADATA_KEY + ) self.vector_name = vector_name or self.VECTOR_NAME if embedding_function is not None: @@ -1178,8 +1190,8 @@ def from_texts( path: Optional[str] = None, collection_name: Optional[str] = None, distance_func: str = "Cosine", - content_payload_key: str = CONTENT_KEY, - metadata_payload_key: str = METADATA_KEY, + content_payload_key: List[str] = CONTENT_KEY, + metadata_payload_key: List[str] = METADATA_KEY, vector_name: Optional[str] = VECTOR_NAME, batch_size: int = 64, shard_number: Optional[int] = None, @@ -1354,8 +1366,8 @@ async def afrom_texts( path: Optional[str] = None, collection_name: Optional[str] = None, distance_func: str = "Cosine", - content_payload_key: str = CONTENT_KEY, - metadata_payload_key: str = METADATA_KEY, + content_payload_key: List[str] = CONTENT_KEY, + metadata_payload_key: List[str] = METADATA_KEY, vector_name: Optional[str] = VECTOR_NAME, batch_size: int = 64, shard_number: Optional[int] = None, @@ -1527,8 +1539,8 @@ def construct_instance( path: Optional[str] = None, collection_name: Optional[str] = None, distance_func: str = "Cosine", - content_payload_key: str = CONTENT_KEY, - metadata_payload_key: str = METADATA_KEY, + content_payload_key: List[str] = CONTENT_KEY, + metadata_payload_key: 
List[str] = METADATA_KEY, vector_name: Optional[str] = VECTOR_NAME, shard_number: Optional[int] = None, replication_factor: Optional[int] = None, @@ -1691,8 +1703,8 @@ async def aconstruct_instance( path: Optional[str] = None, collection_name: Optional[str] = None, distance_func: str = "Cosine", - content_payload_key: str = CONTENT_KEY, - metadata_payload_key: str = METADATA_KEY, + content_payload_key: List[str] = CONTENT_KEY, + metadata_payload_key: List[str] = METADATA_KEY, vector_name: Optional[str] = VECTOR_NAME, shard_number: Optional[int] = None, replication_factor: Optional[int] = None, @@ -1888,11 +1900,11 @@ def _similarity_search_with_relevance_scores( @classmethod def _build_payloads( - cls, + cls: Type[Qdrant], texts: Iterable[str], metadatas: Optional[List[dict]], - content_payload_key: str, - metadata_payload_key: str, + content_payload_key: list[str], + metadata_payload_key: list[str], ) -> List[dict]: payloads = [] for i, text in enumerate(texts): @@ -1913,29 +1925,67 @@ def _build_payloads( @classmethod def _document_from_scored_point( - cls, + cls: Type[Qdrant], scored_point: Any, - content_payload_key: str, - metadata_payload_key: str, + content_payload_key: list[str], + metadata_payload_key: list[str], ) -> Document: - return Document( - page_content=scored_point.payload.get(content_payload_key), - metadata=scored_point.payload.get(metadata_payload_key) or {}, + payload = scored_point.payload + return Qdrant._document_from_payload( + payload=payload, + content_payload_key=content_payload_key, + metadata_payload_key=metadata_payload_key, ) @classmethod def _document_from_scored_point_grpc( - cls, + cls: Type[Qdrant], scored_point: Any, - content_payload_key: str, - metadata_payload_key: str, + content_payload_key: list[str], + metadata_payload_key: list[str], ) -> Document: from qdrant_client.conversions.conversion import grpc_to_payload payload = grpc_to_payload(scored_point.payload) + return Qdrant._document_from_payload( + payload=payload, + 
content_payload_key=content_payload_key, + metadata_payload_key=metadata_payload_key, + ) + + @classmethod + def _document_from_payload( + cls: Type[Qdrant], + payload: Any, + content_payload_key: list[str], + metadata_payload_key: list[str], + ) -> Document: + if len(content_payload_key) == 1: + content = payload.get( + content_payload_key[0] + ) # Single key: return the raw value (backward compatible) + elif len(content_payload_key) > 1: + content = { + content_key: payload.get(content_key) + for content_key in content_payload_key + } + content = str(content) # Ensuring str type output + else: + content = "" + if len(metadata_payload_key) == 1: + metadata = payload.get( + metadata_payload_key[0] + ) or {} # Single key: raw value, {} when missing (backward compatible) + elif len(metadata_payload_key) > 1: + metadata = { + metadata_key: payload.get(metadata_key) + for metadata_key in metadata_payload_key + } + else: + metadata = {} return Document( - page_content=payload[content_payload_key], - metadata=payload.get(metadata_payload_key) or {}, + page_content=content, + metadata=metadata, ) def _build_condition(self, key: str, value: Any) -> List[rest.FieldCondition]: