Skip to content

Commit

Permalink
feat: add universal sentence encoder embedding function
Browse files Browse the repository at this point in the history
  • Loading branch information
csbasil authored and atroyn committed Apr 3, 2024
1 parent 193988d commit e7730dc
Showing 1 changed file with 24 additions and 10 deletions.
34 changes: 24 additions & 10 deletions chromadb/utils/embedding_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -743,9 +743,7 @@ def __call__(self, input: Union[Documents, Images]) -> Embeddings:


class RoboflowEmbeddingFunction(EmbeddingFunction[Union[Documents, Images]]):
def __init__(
self, api_key: str = "", api_url = "https://infer.roboflow.com"
) -> None:
def __init__(self, api_key: str = "", api_url="https://infer.roboflow.com") -> None:
"""
Create a RoboflowEmbeddingFunction.
Expand All @@ -757,7 +755,7 @@ def __init__(
api_key = os.environ.get("ROBOFLOW_API_KEY")

self._api_url = api_url
self._api_key = api_key
self._api_key = api_key

try:
self._PILImage = importlib.import_module("PIL.Image")
Expand Down Expand Up @@ -789,10 +787,10 @@ def __call__(self, input: Union[Documents, Images]) -> Embeddings:
json=infer_clip_payload,
)

result = res.json()['embeddings']
result = res.json()["embeddings"]

embeddings.append(result[0])

elif is_document(item):
infer_clip_payload = {
"text": input,
Expand All @@ -803,13 +801,13 @@ def __call__(self, input: Union[Documents, Images]) -> Embeddings:
json=infer_clip_payload,
)

result = res.json()['embeddings']
result = res.json()["embeddings"]

embeddings.append(result[0])

return embeddings


class AmazonBedrockEmbeddingFunction(EmbeddingFunction[Documents]):
def __init__(
self,
Expand Down Expand Up @@ -900,6 +898,22 @@ def __call__(self, input: Documents) -> Embeddings:
)


class UniversalSentenceEncoderEmbeddingFunction(EmbeddingFunction[Documents]):
    """
    Embedding function that embeds documents with a model loaded from TensorFlow Hub.

    By default the model handle points at the Universal Sentence Encoder v4 URL.
    """

    def __init__(
        self, model_name: str = "https://tfhub.dev/google/universal-sentence-encoder/4"
    ) -> None:
        """
        Load the model used to embed documents.

        Args:
            model_name (str): Handle passed to `tensorflow_hub.load`. Defaults to
                the Universal Sentence Encoder v4 URL.

        Raises:
            ValueError: If the `tensorflow_hub` package cannot be imported.
        """
        try:
            # Imported lazily so the dependency is only required when this
            # embedding function is actually instantiated.
            import tensorflow_hub as hub
        except ImportError as err:
            # Chain the original error so the underlying import failure
            # (e.g. a missing or broken transitive dependency) stays visible
            # in the traceback instead of being silently replaced.
            raise ValueError(
                "The tensorflow_hub python package is not installed. Please install it with `pip install tensorflow_hub`"
            ) from err
        self._model = hub.load(model_name)

    def __call__(self, input: Documents) -> Embeddings:
        """
        Embed the given documents with the loaded model.

        Args:
            input (Documents): The documents to embed; passed to the model as-is.

        Returns:
            Embeddings: The model output converted to nested Python lists via
                `.numpy().tolist()`.
        """
        return cast(Embeddings, self._model(input).numpy().tolist())


def create_langchain_embedding(langchain_embdding_fn: Any): # type: ignore
try:
from langchain_core.embeddings import Embeddings as LangchainEmbeddings
Expand Down Expand Up @@ -962,7 +976,7 @@ def __call__(self, input: Documents) -> Embeddings: # type: ignore

return ChromaLangchainEmbeddingFunction(embedding_function=langchain_embdding_fn)


class OllamaEmbeddingFunction(EmbeddingFunction[Documents]):
"""
This class is used to generate embeddings for a list of texts using the Ollama Embedding API (https://github.com/ollama/ollama/blob/main/docs/api.md#generate-embeddings).
Expand Down Expand Up @@ -1018,7 +1032,7 @@ def __call__(self, input: Documents) -> Embeddings:
],
)


# List of all classes in this module
_classes = [
name
Expand Down

0 comments on commit e7730dc

Please sign in to comment.