Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] CloudflareWorkersAIEmbeddingFunction #1271

Closed
72 changes: 72 additions & 0 deletions chromadb/test/ef/test_cloudflare_ef.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import os

import pytest

from chromadb.utils.embedding_functions.cloudflare_workers_ai_embedding_function import (
CloudflareWorkersAIEmbeddingFunction,
)


@pytest.mark.skipif(
"CF_API_TOKEN" not in os.environ,
reason="CF_API_TOKEN and CF_ACCOUNT_ID not set, skipping test.",
)
def test_cf_ef_token_and_account() -> None:
ef = CloudflareWorkersAIEmbeddingFunction(
api_token=os.environ.get("CF_API_TOKEN", ""),
account_id=os.environ.get("CF_ACCOUNT_ID"),
)
embeddings = ef(["test doc"])
assert embeddings is not None
assert len(embeddings) == 1
assert len(embeddings[0]) > 0


@pytest.mark.skipif(
"CF_API_TOKEN" not in os.environ,
reason="CF_API_TOKEN and CF_ACCOUNT_ID not set, skipping test.",
)
def test_cf_ef_gateway() -> None:
ef = CloudflareWorkersAIEmbeddingFunction(
api_token=os.environ.get("CF_API_TOKEN", ""),
gateway_url=os.environ.get("CF_GATEWAY_ENDPOINT"),
)
embeddings = ef(["test doc"])
assert embeddings is not None
assert len(embeddings) == 1
assert len(embeddings[0]) > 0


@pytest.mark.skipif(
"CF_API_TOKEN" not in os.environ,
reason="CF_API_TOKEN and CF_ACCOUNT_ID not set, skipping test.",
)
def test_cf_ef_large_batch() -> None:
ef = CloudflareWorkersAIEmbeddingFunction(api_token="dummy", account_id="dummy")
with pytest.raises(ValueError, match="Batch too large"):
ef(["test doc"] * 101)


@pytest.mark.skipif(
"CF_API_TOKEN" not in os.environ,
reason="CF_API_TOKEN and CF_ACCOUNT_ID not set, skipping test.",
)
def test_cf_ef_missing_account_or_gateway() -> None:
with pytest.raises(
ValueError, match="Please provide either an account_id or a gateway_url"
):
CloudflareWorkersAIEmbeddingFunction(api_token="dummy")


@pytest.mark.skipif(
"CF_API_TOKEN" not in os.environ,
reason="CF_API_TOKEN and CF_ACCOUNT_ID not set, skipping test.",
)
def test_cf_ef_with_account_or_gateway() -> None:
with pytest.raises(
ValueError,
match="Please provide either an account_id or a gateway_url, not both",
):
CloudflareWorkersAIEmbeddingFunction(
api_token="dummy", account_id="dummy", gateway_url="dummy"
)
1 change: 1 addition & 0 deletions chromadb/test/ef/test_ef.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def test_get_builtins_holds() -> None:
"SentenceTransformerEmbeddingFunction",
"Text2VecEmbeddingFunction",
"ChromaLangchainEmbeddingFunction",
"CloudflareWorkersAIEmbeddingFunction",
}

assert expected_builtins == embedding_functions.get_builtins()
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import logging
from typing import Optional, Dict, cast

import httpx

from chromadb.api.types import Documents, EmbeddingFunction, Embeddings

logger = logging.getLogger(__name__)


class CloudflareWorkersAIEmbeddingFunction(EmbeddingFunction[Documents]):
# Follow API Quickstart for Cloudflare Workers AI
# https://developers.cloudflare.com/workers-ai/
# Information about the text embedding modules in Google Vertex AI
# https://developers.cloudflare.com/workers-ai/models/embedding/
def __init__(
self,
api_token: str,
account_id: Optional[str] = None,
model_name: Optional[str] = "@cf/baai/bge-base-en-v1.5",
gateway_url: Optional[
str
] = None, # use Cloudflare AI Gateway instead of the usual endpoint
# right now endpoint schema supports up to 100 docs at a time
# https://developers.cloudflare.com/workers-ai/models/bge-small-en-v1.5/#api-schema (Input JSON Schema)
max_batch_size: Optional[int] = 100,
headers: Optional[Dict[str, str]] = None,
):
if not gateway_url and not account_id:
raise ValueError("Please provide either an account_id or a gateway_url.")
if gateway_url and account_id:
raise ValueError(
"Please provide either an account_id or a gateway_url, not both."
)
if gateway_url is not None and not gateway_url.endswith("/"):
gateway_url += "/"
self._api_url = (
f"{gateway_url}{model_name}"
if gateway_url is not None
else f"https://api.cloudflare.com/client/v4/accounts/{account_id}/ai/run/{model_name}"
)
self._session = httpx.Client()
self._session.headers.update(headers or {})
self._session.headers.update({"Authorization": f"Bearer {api_token}"})
self._max_batch_size = max_batch_size

def __call__(self, texts: Documents) -> Embeddings:
# Endpoint accepts up to 100 items at a time. We'll reject anything larger.
# It would be up to the user to split the input into smaller batches.
if self._max_batch_size and len(texts) > self._max_batch_size:
raise ValueError(
f"Batch too large {len(texts)} > {self._max_batch_size} (maximum batch size)."
)

print("URI", self._api_url)

response = self._session.post(f"{self._api_url}", json={"text": texts})
response.raise_for_status()
_json = response.json()
if "result" in _json and "data" in _json["result"]:
return cast(Embeddings, _json["result"]["data"])
else:
raise ValueError(f"Error calling Cloudflare Workers AI: {response.text}")
82 changes: 82 additions & 0 deletions clients/js/src/embeddings/CloudflareWorkersAIEmbeddingFunction.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
import { IEmbeddingFunction } from "./IEmbeddingFunction";

export class CloudflareWorkersAIEmbeddingFunction
implements IEmbeddingFunction
{
private apiUrl: string;
private headers: { [key: string]: string };
private maxBatchSize: number;

constructor({
apiToken,
model,
accountId,
gatewayUrl,
maxBatchSize,
headers,
}: {
apiToken: string;
model?: string;
accountId?: string;
gatewayUrl?: string;
maxBatchSize?: number;
headers?: { [key: string]: string };
}) {
model = model || "@cf/baai/bge-base-en-v1.5";
this.maxBatchSize = maxBatchSize || 100;
if (accountId === undefined && gatewayUrl === undefined) {
throw new Error("Please provide either an accountId or a gatewayUrl.");
}
if (accountId !== undefined && gatewayUrl !== undefined) {
throw new Error(
"Please provide either an accountId or a gatewayUrl, not both.",
);
}
if (gatewayUrl !== undefined && !gatewayUrl.endsWith("/")) {
gatewayUrl += "/" + model;
}
this.apiUrl =
gatewayUrl ||
`https://api.cloudflare.com/client/v4/accounts/${accountId}/ai/run/${model}`;
this.headers = headers || {};
this.headers["Authorization"] = `Bearer ${apiToken}`;
}

public async generate(texts: string[]) {
if (texts.length === 0) {
return [];
}
if (texts.length > this.maxBatchSize) {
throw new Error(
`Batch too large ${texts.length} > ${this.maxBatchSize} (maximum batch size).`,
);
}
try {
const response = await fetch(this.apiUrl, {
method: "POST",
headers: this.headers,
body: JSON.stringify({
text: texts,
}),
});

const data = (await response.json()) as {
success?: boolean;
messages: any[];
errors?: any[];
result: { shape: any[]; data: number[][] };
};
if (data.success === false) {
throw new Error(`${JSON.stringify(data.errors)}`);
}
return data.result.data;
} catch (error) {
console.error(error);
if (error instanceof Error) {
throw new Error(`Error calling CF API: ${error}`);
} else {
throw new Error(`Error calling CF API: ${error}`);
}
}
}
}
1 change: 1 addition & 0 deletions clients/js/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ export { HuggingFaceEmbeddingServerFunction } from "./embeddings/HuggingFaceEmbe
export { JinaEmbeddingFunction } from "./embeddings/JinaEmbeddingFunction";
export { GoogleGenerativeAiEmbeddingFunction } from "./embeddings/GoogleGeminiEmbeddingFunction";
export { OllamaEmbeddingFunction } from "./embeddings/OllamaEmbeddingFunction";
export { CloudflareWorkersAIEmbeddingFunction } from "./embeddings/CloudflareWorkersAIEmbeddingFunction";

export {
IncludeEnum,
Expand Down
99 changes: 99 additions & 0 deletions clients/js/test/embeddings/cloudflare.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
import { expect, test } from "@jest/globals";
import { DOCUMENTS } from "../data";
import { CloudflareWorkersAIEmbeddingFunction } from "../../src";

if (!process.env.CF_API_TOKEN) {
test.skip("it should generate Cloudflare embeddings with apiToken and AccountId", async () => {});
} else {
test("it should generate Cloudflare embeddings with apiToken and AccountId", async () => {
const embedder = new CloudflareWorkersAIEmbeddingFunction({
apiToken: process.env.CF_API_TOKEN as string,
accountId: process.env.CF_ACCOUNT_ID,
});
const embeddings = await embedder.generate(DOCUMENTS);
expect(embeddings).toBeDefined();
expect(embeddings.length).toBe(DOCUMENTS.length);
});
}

if (!process.env.CF_API_TOKEN) {
test.skip("it should generate Cloudflare embeddings with apiToken and AccountId and model", async () => {});
} else {
test("it should generate Cloudflare embeddings with apiToken and AccountId and model", async () => {
const embedder = new CloudflareWorkersAIEmbeddingFunction({
apiToken: process.env.CF_API_TOKEN as string,
accountId: process.env.CF_ACCOUNT_ID,
model: "@cf/baai/bge-small-en-v1.5",
});
const embeddings = await embedder.generate(DOCUMENTS);
expect(embeddings).toBeDefined();
expect(embeddings.length).toBe(DOCUMENTS.length);
});
}

if (!process.env.CF_API_TOKEN) {
test.skip("it should generate Cloudflare embeddings with apiToken and gateway", async () => {});
} else {
test("it should generate Cloudflare embeddings with apiToken and gateway", async () => {
const embedder = new CloudflareWorkersAIEmbeddingFunction({
apiToken: process.env.CF_API_TOKEN as string,
gatewayUrl: process.env.CF_GATEWAY_ENDPOINT,
});
const embeddings = await embedder.generate(DOCUMENTS);
expect(embeddings).toBeDefined();
expect(embeddings.length).toBe(DOCUMENTS.length);
});
}

if (!process.env.CF_API_TOKEN) {
test.skip("it should fail when batch too large", async () => {});
} else {
test("it should fail when batch too large", async () => {
const embedder = new CloudflareWorkersAIEmbeddingFunction({
apiToken: process.env.CF_API_TOKEN as string,
gatewayUrl: process.env.CF_GATEWAY_ENDPOINT,
});
const largeBatch = Array(100)
.fill([...DOCUMENTS])
.flat();
try {
await embedder.generate(largeBatch);
} catch (e: any) {
expect(e.message).toMatch("Batch too large");
}
});
}

if (!process.env.CF_API_TOKEN) {
test.skip("it should fail when gateway endpoint and account id are both provided", async () => {});
} else {
test("it should fail when gateway endpoint and account id are both provided", async () => {
try {
new CloudflareWorkersAIEmbeddingFunction({
apiToken: process.env.CF_API_TOKEN as string,
accountId: process.env.CF_ACCOUNT_ID,
gatewayUrl: process.env.CF_GATEWAY_ENDPOINT,
});
} catch (e: any) {
expect(e.message).toMatch(
"Please provide either an accountId or a gatewayUrl, not both.",
);
}
});
}

if (!process.env.CF_API_TOKEN) {
test.skip("it should fail when neither gateway endpoint nor account id are provided", async () => {});
} else {
test("it should fail when neither gateway endpoint nor account id are provided", async () => {
try {
new CloudflareWorkersAIEmbeddingFunction({
apiToken: process.env.CF_API_TOKEN as string,
});
} catch (e: any) {
expect(e.message).toMatch(
"Please provide either an accountId or a gatewayUrl.",
);
}
});
}
17 changes: 9 additions & 8 deletions docs/docs.trychroma.com/pages/guides/embeddings.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,16 @@ Chroma provides lightweight wrappers around popular embedding providers, making
{% special_table %}
{% /special_table %}

| | Python | JS |
|--------------|-----------|---------------|
| [OpenAI](/integrations/openai) | ✅ | ✅ |
| [Google Generative AI](/integrations/google-gemini) | ✅ | ✅ |
| [Cohere](/integrations/cohere) | ✅ | ✅ |
| [Hugging Face](/integrations/hugging-face) | ✅ | ➖ |
| [Instructor](/integrations/instructor) | ✅ | ➖ |
| | Python | JS |
|--------------------------------------------------------------------|-----------|---------------|
| [OpenAI](/integrations/openai) | ✅ | ✅ |
| [Google Generative AI](/integrations/google-gemini) | ✅ | ✅ |
| [Cohere](/integrations/cohere) | ✅ | ✅ |
| [Hugging Face](/integrations/hugging-face) | ✅ | ➖ |
| [Instructor](/integrations/instructor) | ✅ | ➖ |
| [Hugging Face Embedding Server](/integrations/hugging-face-server) | ✅ | ✅ |
| [Jina AI](/integrations/jinaai) | ✅ | ✅ |
| [Jina AI](/integrations/jinaai) | ✅ | ✅ |
| [Cloudflare Workers AI](/integrations/cloudflare) | ✅ | ✅ |

We welcome pull requests to add new Embedding Functions to the community.

Expand Down
1 change: 1 addition & 0 deletions docs/docs.trychroma.com/pages/integrations/_sidenav.js
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ export const items = [
{ href: '/integrations/jinaai', children: 'JinaAI' },
{ href: '/integrations/roboflow', children: 'Roboflow' },
{ href: '/integrations/ollama', children: 'Ollama Embeddings' },
{ href: '/integrations/cloudflare', children: 'Cloudflare Workers AI Embeddings' },
]
},
{
Expand Down
Loading
Loading