Skip to content

Commit

Permalink
Add unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
jacoblee93 committed Dec 6, 2023
1 parent 7b4f473 commit 80d52c0
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 2 deletions.
4 changes: 2 additions & 2 deletions libs/langchain/langchain/retrievers/multi_vector.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from enum import Enum
-from typing import List, Optional, Any
+from typing import Any, List, Optional

from langchain_core.documents import Document
from langchain_core.pydantic_v1 import Field, validator
Expand Down Expand Up @@ -38,7 +38,7 @@ class MultiVectorRetriever(BaseRetriever):

@validator("docstore", pre=True, always=True)
def shim_docstore(
-cls, docstore: Optional[BaseStore[str, Document]], values: any
+cls, docstore: Optional[BaseStore[str, Document]], values: Any
) -> BaseStore[str, Document]:
base_store = values.get("base_store")
if base_store is not None:
Expand Down
30 changes: 30 additions & 0 deletions libs/langchain/tests/unit_tests/retrievers/test_multi_vector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from typing import Any, List

from langchain_core.documents import Document

from langchain.retrievers.multi_vector import MultiVectorRetriever
from langchain.storage import InMemoryStore
from tests.unit_tests.indexes.test_indexing import InMemoryVectorStore


class InMemoryVectorstoreWithSearch(InMemoryVectorStore):
    """Test double whose ``similarity_search`` is a plain key lookup."""

    def similarity_search(
        self, query: str, k: int = 4, **kwargs: Any
    ) -> List[Document]:
        """Return the entry stored under ``query``, wrapped in a list.

        ``query`` is used directly as the key into ``self.store``. ``k`` and
        ``kwargs`` are accepted but not used. An empty list is returned when
        nothing is stored under that key.
        """
        match = self.store.get(query)
        return [] if match is None else [match]


def test_multi_vector_retriever() -> None:
    """Check that a document stored under "1" comes back from ``invoke("1")``."""
    doc = Document(page_content="test document", metadata={"doc_id": "1"})
    retriever = MultiVectorRetriever(
        vectorstore=InMemoryVectorstoreWithSearch(),
        docstore=InMemoryStore(),
        doc_id="doc_id",
    )

    # Register the same document in both stores under the key "1".
    retriever.vectorstore.add_documents([doc], ids=["1"])
    retriever.docstore.mset([("1", doc)])

    found = retriever.invoke("1")
    assert len(found) > 0
    assert found[0].page_content == "test document"
37 changes: 37 additions & 0 deletions libs/langchain/tests/unit_tests/retrievers/test_parent_document.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from typing import Any, List, Sequence

from langchain_core.documents import Document

from langchain.retrievers import ParentDocumentRetriever
from langchain.storage import InMemoryStore
from langchain.text_splitter import RecursiveCharacterTextSplitter
from tests.unit_tests.indexes.test_indexing import InMemoryVectorStore


class InMemoryVectorstoreWithSearch(InMemoryVectorStore):
    """Test double: key-lookup search, and every add uses the fixed id "1"."""

    def similarity_search(
        self, query: str, k: int = 4, **kwargs: Any
    ) -> List[Document]:
        """Return the entry stored under ``query``, wrapped in a list.

        ``query`` is used directly as the key into ``self.store``. ``k`` and
        ``kwargs`` are accepted but not used. An empty list is returned when
        nothing is stored under that key.
        """
        match = self.store.get(query)
        return [] if match is None else [match]

    def add_documents(self, documents: Sequence) -> None:
        """Forward ``documents`` to the parent class with ``ids=["1"]``."""
        fixed_ids = ["1"]
        return super().add_documents(documents, ids=fixed_ids)


def test_parent_document_retriever() -> None:
    """Check that a document added with id "1" comes back from ``invoke("1")``."""
    retriever = ParentDocumentRetriever(
        vectorstore=InMemoryVectorstoreWithSearch(),
        docstore=InMemoryStore(),
        child_splitter=RecursiveCharacterTextSplitter(chunk_size=400),
    )
    source_doc = Document(page_content="test document", metadata={"doc_id": "1"})

    retriever.add_documents([source_doc], ids=["1"])

    found = retriever.invoke("1")
    assert len(found) > 0
    assert found[0].page_content == "test document"

0 comments on commit 80d52c0

Please sign in to comment.