diff --git a/libs/community/langchain_community/vectorstores/azuresearch.py b/libs/community/langchain_community/vectorstores/azuresearch.py
index 5d7a6fc8edc8a..6fd12db4322a6 100644
--- a/libs/community/langchain_community/vectorstores/azuresearch.py
+++ b/libs/community/langchain_community/vectorstores/azuresearch.py
@@ -428,6 +428,12 @@ def add_texts(
             logger.debug("Nothing to insert, skipping.")
             return []
 
+        # when `keys` are not passed in and there is `ids` in kwargs, use those instead
+        # base class expects `ids` passed in rather than `keys`
+        # https://github.com/langchain-ai/langchain/blob/4cdaca67dc51dba887289f56c6fead3c1a52f97d/libs/core/langchain_core/vectorstores/base.py#L65
+        if (not keys) and ("ids" in kwargs) and (len(kwargs["ids"]) == len(embeddings)):
+            keys = kwargs["ids"]
+
         return self.add_embeddings(zip(texts, embeddings), metadatas, keys=keys)
 
     async def aadd_texts(
@@ -452,6 +458,12 @@ async def aadd_texts(
             logger.debug("Nothing to insert, skipping.")
             return []
 
+        # when `keys` are not passed in and there is `ids` in kwargs, use those instead
+        # base class expects `ids` passed in rather than `keys`
+        # https://github.com/langchain-ai/langchain/blob/4cdaca67dc51dba887289f56c6fead3c1a52f97d/libs/core/langchain_core/vectorstores/base.py#L65
+        if (not keys) and ("ids" in kwargs) and (len(kwargs["ids"]) == len(embeddings)):
+            keys = kwargs["ids"]
+
         return await self.aadd_embeddings(zip(texts, embeddings), metadatas, keys=keys)
 
     def add_embeddings(
@@ -468,9 +480,13 @@ def add_embeddings(
         data = []
         for i, (text, embedding) in enumerate(text_embeddings):
             # Use provided key otherwise use default key
-            key = keys[i] if keys else str(uuid.uuid4())
-            # Encoding key for Azure Search valid characters
-            key = base64.urlsafe_b64encode(bytes(key, "utf-8")).decode("ascii")
+            if keys:
+                key = keys[i]
+            else:
+                key = str(uuid.uuid4())
+                # Encoding key for Azure Search valid characters
+                key = base64.urlsafe_b64encode(bytes(key, "utf-8")).decode("ascii")
+
             metadata = metadatas[i] if metadatas else {}
             # Add data to index
             # Additional metadata to fields mapping
diff --git a/libs/community/tests/unit_tests/vectorstores/test_azure_search.py b/libs/community/tests/unit_tests/vectorstores/test_azure_search.py
index a06fbfd151b0b..255105f32894b 100644
--- a/libs/community/tests/unit_tests/vectorstores/test_azure_search.py
+++ b/libs/community/tests/unit_tests/vectorstores/test_azure_search.py
@@ -188,3 +188,40 @@ def mock_create_index() -> None:
     )
     assert vector_store.client is not None
     assert vector_store.client._api_version == "test"
+
+
+@pytest.mark.requires("azure.search.documents")
+def test_ids_used_correctly() -> None:
+    """Check whether vector store uses the document ids when provided with them."""
+    from azure.search.documents import SearchClient
+    from azure.search.documents.indexes import SearchIndexClient
+    from langchain_core.documents import Document
+
+    class Response:
+        def __init__(self) -> None:
+            self.succeeded: bool = True
+
+    def mock_upload_documents(self, documents: List[object]) -> List[Response]:  # type: ignore[no-untyped-def]
+        # assume all documents uploaded successfully
+        response = [Response() for _ in documents]
+        return response
+
+    documents = [
+        Document(
+            page_content="page zero Lorem Ipsum",
+            metadata={"source": "document.pdf", "page": 0, "id": "ID-document-1"},
+        ),
+        Document(
+            page_content="page one Lorem Ipsum",
+            metadata={"source": "document.pdf", "page": 1, "id": "ID-document-2"},
+        ),
+    ]
+    ids_provided = [i.metadata.get("id") for i in documents]
+
+    with patch.object(
+        SearchClient, "upload_documents", mock_upload_documents
+    ), patch.object(SearchIndexClient, "get_index", mock_default_index):
+        vector_store = create_vector_store()
+        ids_used_at_upload = vector_store.add_documents(documents, ids=ids_provided)
+        assert len(ids_provided) == len(ids_used_at_upload)
+        assert ids_provided == ids_used_at_upload