From 7da962078357914cef08680648385a3721262d1f Mon Sep 17 00:00:00 2001 From: Mihir1003 <32996071+Mihir1003@users.noreply.github.com> Date: Wed, 3 Apr 2024 04:43:12 +0530 Subject: [PATCH] [ENH] Support langchain embedding functions with chroma (#1880) *Summarize the changes made by this PR.* - New functionality - Adding a function to create a chroma langchain embedding interface. This interface acts as a bridge between the langchain embedding function and the chroma custom embedding function. - Native Langchain multimodal support: The PR adds a Passthrough data loader that lets langchain users use OpenClip and other multi-modal embedding functions from langchain with chroma without having to handle storing images themselves. *How are these changes tested?* - installing chroma as an editable package locally and passing langchain integration tests - pytest test_api.py test_client.py succeeds *Are all docstrings for user-facing APIs updated if required? Do we need to make documentation changes in the [docs repository](https://github.com/chroma-core/docs)?* Co-authored-by: Anton Troynikov --- chromadb/utils/data_loaders.py | 11 ++++- chromadb/utils/embedding_functions.py | 66 +++++++++++++++++++++++++++ 2 files changed, 75 insertions(+), 2 deletions(-) diff --git a/chromadb/utils/data_loaders.py b/chromadb/utils/data_loaders.py index 60057e0e584..82ea894aa9a 100644 --- a/chromadb/utils/data_loaders.py +++ b/chromadb/utils/data_loaders.py @@ -1,8 +1,8 @@ import importlib import multiprocessing -from typing import Optional, Sequence, List +from typing import Optional, Sequence, List, Tuple import numpy as np -from chromadb.api.types import URI, DataLoader, Image +from chromadb.api.types import URI, DataLoader, Image, URIs from concurrent.futures import ThreadPoolExecutor @@ -22,3 +22,10 @@ def _load_image(self, uri: Optional[URI]) -> Optional[Image]: def __call__(self, uris: Sequence[Optional[URI]]) -> List[Optional[Image]]: with ThreadPoolExecutor(max_workers=self._max_workers) as executor: return list(executor.map(self._load_image, uris)) + + +class ChromaLangchainPassthroughDataLoader(DataLoader[List[Optional[Image]]]): + # This is a simple pass through data loader that just returns the input data with "images" + # flag which lets the langchain embedding function know that the data is image uris + def __call__(self, uris: URIs) -> Tuple[str, URIs]: # type: ignore + return ("images", uris) diff --git a/chromadb/utils/embedding_functions.py b/chromadb/utils/embedding_functions.py index 4942ee950a6..7de90548990 100644 --- a/chromadb/utils/embedding_functions.py +++ b/chromadb/utils/embedding_functions.py @@ -914,6 +914,71 @@ def __init__( def __call__(self, input: Documents) -> Embeddings: return cast(Embeddings, self._model(input).numpy().tolist()) + + +def create_langchain_embedding(langchain_embdding_fn: Any): # type: ignore + try: + from langchain_core.embeddings import Embeddings as LangchainEmbeddings + except ImportError: + raise ValueError( + "The langchain_core python package is not installed. Please install it with `pip install langchain-core`" + ) + + class ChromaLangchainEmbeddingFunction( + LangchainEmbeddings, EmbeddingFunction[Union[Documents, Images]] # type: ignore + ): + """ + This class is used as bridge between langchain embedding functions and custom chroma embedding functions. + """ + + def __init__(self, embedding_function: LangchainEmbeddings) -> None: + """ + Initialize the ChromaLangchainEmbeddingFunction + + Args: + embedding_function : The embedding function implementing Embeddings from langchain_core. + """ + self.embedding_function = embedding_function + + def embed_documents(self, documents: Documents) -> List[List[float]]: + return self.embedding_function.embed_documents(documents) # type: ignore + + def embed_query(self, query: str) -> List[float]: + return self.embedding_function.embed_query(query) # type: ignore + + def embed_image(self, uris: List[str]) -> List[List[float]]: + if hasattr(self.embedding_function, "embed_image"): + return self.embedding_function.embed_image(uris) # type: ignore + else: + raise ValueError( + "The provided embedding function does not support image embeddings." + ) + + def __call__(self, input: Documents) -> Embeddings: # type: ignore + """ + Get the embeddings for a list of texts or images. + + Args: + input (Documents | Images): A list of texts or images to get embeddings for. + Images should be provided as a list of URIs passed through the langchain data loader + + Returns: + Embeddings: The embeddings for the texts or images. + + Example: + >>> langchain_embedding = ChromaLangchainEmbeddingFunction(embedding_function=OpenAIEmbeddings(model="text-embedding-3-large")) + >>> texts = ["Hello, world!", "How are you?"] + >>> embeddings = langchain_embedding(texts) + """ + # Due to langchain quirks, the dataloader returns a tuple if the input is uris of images + if input[0] == "images": + return self.embed_image(list(input[1])) # type: ignore + + return self.embed_documents(list(input)) # type: ignore + + return ChromaLangchainEmbeddingFunction(embedding_function=langchain_embdding_fn) + + class OllamaEmbeddingFunction(EmbeddingFunction[Documents]): """ This class is used to generate embeddings for a list of texts using the Ollama Embedding API (https://github.com/ollama/ollama/blob/main/docs/api.md#generate-embeddings). @@ -969,6 +1034,7 @@ def __call__(self, input: Documents) -> Embeddings: ], ) + # List of all classes in this module _classes = [ name