Skip to content

Commit

Permalink
[ENH]: Ollama embedding function (#1813)
Browse files Browse the repository at this point in the history
## Description of changes

*Summarize the changes made by this PR.*
 - New functionality
	 - New Ollama embedding function (Python and JS)
	 - Example of how to run Ollama with the embedding function

## Test plan
*How are these changes tested?*

- [x] Tests pass locally with `pytest` for python, `yarn test` for js

## Documentation Changes

chroma-core/docs#222
  • Loading branch information
tazarov authored Mar 28, 2024
1 parent 5f3f141 commit 1d77f99
Show file tree
Hide file tree
Showing 6 changed files with 198 additions and 3 deletions.
34 changes: 34 additions & 0 deletions chromadb/test/ef/test_ollama_ef.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import os

import pytest
import requests
from requests import HTTPError
from requests.exceptions import ConnectionError

from chromadb.utils.embedding_functions import OllamaEmbeddingFunction


def test_ollama() -> None:
    """
    To set up the Ollama server, follow instructions at: https://github.com/ollama/ollama?tab=readme-ov-file
    Export the OLLAMA_SERVER_URL and OLLAMA_MODEL environment variables.
    """
    server_url = os.environ.get("OLLAMA_SERVER_URL")
    model_name = os.environ.get("OLLAMA_MODEL")
    if server_url is None or model_name is None:
        pytest.skip(
            "OLLAMA_SERVER_URL or OLLAMA_MODEL environment variable not set. Skipping test."
        )
    try:
        # Probe the server first; a timeout keeps the suite from hanging when
        # the host is unreachable rather than refusing connections.
        response = requests.get(server_url, timeout=5)
        # If the response was successful, no Exception will be raised
        response.raise_for_status()
    except (HTTPError, ConnectionError):
        pytest.skip("Ollama server not running. Skipping test.")
    # Both env vars are guaranteed non-None past the skip guard above.
    ef = OllamaEmbeddingFunction(
        model_name=model_name,
        url=f"{server_url}/embeddings",
    )
    embeddings = ef(["Here is an article about llamas...", "this is another article"])
    assert len(embeddings) == 2
62 changes: 60 additions & 2 deletions chromadb/utils/embedding_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def __init__(
model_name: str = "all-MiniLM-L6-v2",
device: str = "cpu",
normalize_embeddings: bool = False,
**kwargs: Any
**kwargs: Any,
):
"""Initialize SentenceTransformerEmbeddingFunction.
Expand All @@ -78,7 +78,9 @@ def __init__(
raise ValueError(
"The sentence_transformers python package is not installed. Please install it with `pip install sentence_transformers`"
)
self.models[model_name] = SentenceTransformer(model_name, device=device, **kwargs)
self.models[model_name] = SentenceTransformer(
model_name, device=device, **kwargs
)
self._model = self.models[model_name]
self._normalize_embeddings = normalize_embeddings

Expand Down Expand Up @@ -828,6 +830,62 @@ def __call__(self, input: Documents) -> Embeddings:
)


class OllamaEmbeddingFunction(EmbeddingFunction[Documents]):
    """
    This class is used to generate embeddings for a list of texts using the Ollama Embedding API (https://github.com/ollama/ollama/blob/main/docs/api.md#generate-embeddings).
    """

    def __init__(self, url: str, model_name: str) -> None:
        """
        Initialize the Ollama Embedding Function.

        Args:
            url (str): The URL of the Ollama Server.
            model_name (str): The name of the model to use for text embeddings. E.g. "nomic-embed-text" (see https://ollama.com/library for available models).
        """
        try:
            import requests
        except ImportError:
            raise ValueError(
                "The requests python package is not installed. Please install it with `pip install requests`"
            )
        self._api_url = url
        self._model_name = model_name
        # One session for all requests so repeated calls reuse the connection pool.
        self._session = requests.Session()

    def __call__(self, input: Documents) -> Embeddings:
        """
        Get the embeddings for a list of texts.

        Args:
            input (Documents): A list of texts to get embeddings for.

        Returns:
            Embeddings: The embeddings for the texts, one per input document,
                in the same order as the input.

        Raises:
            requests.HTTPError: If the Ollama server returns an error status.
            ValueError: If a server response does not contain an "embedding" key.

        Example:
            >>> ollama_ef = OllamaEmbeddingFunction(url="http://localhost:11434/api/embeddings", model_name="nomic-embed-text")
            >>> texts = ["Hello, world!", "How are you?"]
            >>> embeddings = ollama_ef(texts)
        """
        # Call Ollama Server API for each document
        texts = input if isinstance(input, list) else [input]
        embeddings = []
        for text in texts:
            response = self._session.post(
                self._api_url, json={"model": self._model_name, "prompt": text}
            )
            # Fail loudly on server errors or malformed payloads instead of
            # silently dropping entries, which would misalign the returned
            # embeddings with the input documents.
            response.raise_for_status()
            payload = response.json()
            if "embedding" not in payload:
                raise ValueError(
                    f"Failed to get embedding from Ollama server for a document: {payload}"
                )
            embeddings.append(payload["embedding"])
        return cast(Embeddings, embeddings)


# List of all classes in this module
_classes = [
name
Expand Down
34 changes: 34 additions & 0 deletions clients/js/src/embeddings/OllamaEmbeddingFunction.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import { IEmbeddingFunction } from "./IEmbeddingFunction";

export class OllamaEmbeddingFunction implements IEmbeddingFunction {
    private readonly url: string;
    private readonly model: string;

    /**
     * @param url - Ollama embeddings endpoint, e.g. "http://127.0.0.1:11434/api/embeddings".
     * @param model - Name of the model to embed with, e.g. "nomic-embed-text".
     */
    constructor({ url, model }: { url: string, model: string }) {
        this.url = url;
        this.model = model;
    }

    /**
     * Generate one embedding per input text by POSTing each text to the
     * Ollama server sequentially.
     * @throws Error when the server responds with a non-2xx status.
     */
    public async generate(texts: string[]): Promise<number[][]> {
        const embeddings: number[][] = [];
        for (const text of texts) {
            const response = await fetch(this.url, {
                method: 'POST',
                headers: {
                    'Content-Type': 'application/json'
                },
                body: JSON.stringify({ 'model': this.model, 'prompt': text })
            });

            if (!response.ok) {
                throw new Error(`Failed to generate embeddings: ${response.status} (${response.statusText})`);
            }
            const finalResponse = await response.json();
            embeddings.push(finalResponse['embedding']);
        }
        return embeddings;
    }
}
3 changes: 2 additions & 1 deletion clients/js/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ export { ChromaClient } from "./ChromaClient";
export { AdminClient } from "./AdminClient";
export { CloudClient } from "./CloudClient";
export { Collection } from "./Collection";

export { IEmbeddingFunction } from "./embeddings/IEmbeddingFunction";
export { OpenAIEmbeddingFunction } from "./embeddings/OpenAIEmbeddingFunction";
export { CohereEmbeddingFunction } from "./embeddings/CohereEmbeddingFunction";
Expand All @@ -11,6 +10,8 @@ export { DefaultEmbeddingFunction } from "./embeddings/DefaultEmbeddingFunction"
export { HuggingFaceEmbeddingServerFunction } from "./embeddings/HuggingFaceEmbeddingServerFunction";
export { JinaEmbeddingFunction } from "./embeddings/JinaEmbeddingFunction";
export { GoogleGenerativeAiEmbeddingFunction } from "./embeddings/GoogleGeminiEmbeddingFunction";
export { OllamaEmbeddingFunction } from './embeddings/OllamaEmbeddingFunction';


export {
IncludeEnum,
Expand Down
28 changes: 28 additions & 0 deletions clients/js/test/add.collections.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import { METADATAS } from "./data";
import { IncludeEnum } from "../src/types";
import { OpenAIEmbeddingFunction } from "../src/embeddings/OpenAIEmbeddingFunction";
import { CohereEmbeddingFunction } from "../src/embeddings/CohereEmbeddingFunction";
import { OllamaEmbeddingFunction } from "../src/embeddings/OllamaEmbeddingFunction";
test("it should add single embeddings to a collection", async () => {
await chroma.reset();
const collection = await chroma.createCollection({ name: "test" });
Expand Down Expand Up @@ -120,3 +121,30 @@ test("should error on empty embedding", async () => {
expect(e.message).toMatch("got empty embedding at pos");
}
});

// Only run the Ollama integration test when a server is configured.
if (!process.env.OLLAMA_SERVER_URL) {
  test.skip("it should use ollama EF, OLLAMA_SERVER_URL not defined", async () => {});
} else {
  test("it should use ollama EF", async () => {
    await chroma.reset();
    const embedder = new OllamaEmbeddingFunction({
      url:
        process.env.OLLAMA_SERVER_URL ||
        "http://127.0.0.1:11434/api/embeddings",
      model: "nomic-embed-text",
    });
    const collection = await chroma.createCollection({
      name: "test",
      embeddingFunction: embedder,
    });
    const embeddings = await embedder.generate(DOCUMENTS);
    await collection.add({ ids: IDS, embeddings: embeddings });
    const count = await collection.count();
    expect(count).toBe(3);
    // Fetch the records back by id and confirm the stored embeddings
    // round-trip unchanged.
    const res = await collection.get({
      ids: IDS,
      include: [IncludeEnum.Embeddings],
    });
    expect(res.embeddings).toEqual(embeddings);
  });
}
40 changes: 40 additions & 0 deletions examples/use_with/ollama.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# Ollama

First let's run a local docker container with Ollama. We'll pull the `nomic-embed-text` model:

```bash
docker run -d -v ./ollama:/root/.ollama -p 11434:11434 --name ollama ollama/ollama
docker exec -it ollama ollama run nomic-embed-text # press Ctrl+D to exit after model downloads successfully
# test it
curl http://localhost:11434/api/embeddings -d '{"model": "nomic-embed-text","prompt": "Here is an article about llamas..."}'
```

Now let's configure the Python `OllamaEmbeddingFunction` with the default Ollama endpoint:

```python
import chromadb
from chromadb.utils.embedding_functions import OllamaEmbeddingFunction

client = chromadb.PersistentClient(path="ollama")

# create EF with custom endpoint
ef = OllamaEmbeddingFunction(
model_name="nomic-embed-text",
url="http://127.0.0.1:11434/api/embeddings",
)

print(ef(["Here is an article about llamas..."]))
```

For JS users, you can use the `OllamaEmbeddingFunction` class to create embeddings:

```javascript
const {OllamaEmbeddingFunction} = require('chromadb');
const embedder = new OllamaEmbeddingFunction({
url: "http://127.0.0.1:11434/api/embeddings",
    model: "nomic-embed-text"
})

// use directly (generate is async, so await its result)
const embeddings = await embedder.generate(["Here is an article about llamas..."])
```

0 comments on commit 1d77f99

Please sign in to comment.