diff --git a/chromadb/test/ef/test_spacy_ef.py b/chromadb/test/ef/test_spacy_ef.py
new file mode 100644
index 00000000000..155fde78b68
--- /dev/null
+++ b/chromadb/test/ef/test_spacy_ef.py
@@ -0,0 +1,40 @@
+import pytest
+
+from chromadb.utils.embedding_functions import SpacyEmbeddingFunction
+
+spacy = pytest.importorskip("spacy", reason="spacy not installed")
+
+input_list = ["great work by the guy", "Super man is that guy"]
+model_name = "en_core_web_md"
+unknown_model = "unknown_model"
+
+
+@pytest.fixture(scope="module")
+def spacy_emb_fn():
+    """Load the spaCy model once and share it across this module's tests."""
+    return SpacyEmbeddingFunction(model_name)
+
+
+def test_spacyembeddingfunction_isnotnone_wheninputisnotnone(spacy_emb_fn):
+    assert spacy_emb_fn(input_list) is not None
+
+
+def test_spacyembeddingfunction_throwserror_whenmodel_notfound():
+    # Match only the stable prefix so the test does not depend on the exact
+    # wording or whitespace of the rest of the message.
+    with pytest.raises(ValueError, match="spacy models are not downloaded yet"):
+        SpacyEmbeddingFunction(unknown_model)
+
+
+def test_spacyembeddingfunction_isembedding_wheninput_islist(spacy_emb_fn):
+    embeddings = spacy_emb_fn(input_list)
+    assert isinstance(embeddings, list)
+    assert len(embeddings) == len(input_list)
+
+
+def test_spacyembeddingfunction_returnslistoflistsofloats(spacy_emb_fn):
+    embeddings = spacy_emb_fn(input_list)
+    # Exact type checks on purpose: numpy.float64 subclasses float, and the
+    # embedding function must hand back plain Python floats.
+    assert type(embeddings[0]) is list
+    assert type(embeddings[0][0]) is float
diff --git a/chromadb/utils/embedding_functions.py b/chromadb/utils/embedding_functions.py
index cc779865675..65b1ff6a6a0 100644
--- a/chromadb/utils/embedding_functions.py
+++ b/chromadb/utils/embedding_functions.py
@@ -372,6 +372,64 @@ def __call__(self, input: Documents) -> Embeddings:
         return cast(Embeddings, self._model.encode(texts_with_instructions).tolist())
 
 
+class SpacyEmbeddingFunction(EmbeddingFunction[Documents]):
+    """Embed documents with the document vectors of a spaCy pipeline."""
+
+    def __init__(self, model_name: str = "en_core_web_lg"):
+        """
+        Load the spaCy pipeline that produces the embeddings.
+
+        Args:
+            model_name (str): Name of an installed spaCy pipeline package.
+
+        Raises:
+            ValueError: If spacy is not installed or the pipeline cannot be
+                loaded.
+        """
+        try:
+            import spacy
+        except ImportError as err:
+            raise ValueError(
+                "The spacy python package is not installed. Please install it with `pip install spacy`"
+            ) from err
+        self._model_name = model_name
+
+        try:
+            self._nlp = spacy.load(self._model_name)
+        except OSError as err:
+            message = (
+                "spacy models are not downloaded yet, please download them using "
+                "`python -m spacy download {model}`. Please checkout the list of "
+                "models at https://spacy.io/usage/models. The default model is "
+                "en_core_web_lg, which optimizes accuracy and has embeddings "
+                "built in; use en_core_web_md to prioritize efficiency over "
+                "accuracy. The same applies to models for other languages. The "
+                "*_sm and *_trf models do not have pre-trained embeddings."
+            ).format(model=self._model_name)
+            # Chain the OSError so spaCy's own diagnostic stays in the traceback.
+            raise ValueError(message) from err
+
+    def __call__(self, input: Documents) -> Embeddings:
+        """
+        Get the embeddings for a list of texts.
+
+        Args:
+            input (Documents): A list of texts to get embeddings for.
+
+        Returns:
+            Embeddings: One list of Python floats per text, in input order.
+
+        Example:
+            >>> spacy_fn = SpacyEmbeddingFunction(model_name="en_core_web_md")
+            >>> input = ["Hello, world!", "How are you?"]
+            >>> embeddings = spacy_fn(input)
+        """
+        # pipe() runs the texts through the pipeline in batches; tolist()
+        # yields plain Python floats rather than numpy scalars.
+        docs = self._nlp.pipe(input)
+        return cast(Embeddings, [doc.vector.tolist() for doc in docs])
+
+
 # In order to remove dependencies on sentence-transformers, which in turn depends on
 # pytorch and sentence-piece we have created a default ONNX embedding function that
 # implements the same functionality as "all-MiniLM-L6-v2" from sentence-transformers.