[ENH]: Add mlx embedding function #2275 #2295

Closed
wants to merge 2 commits into from
244 changes: 243 additions & 1 deletion chromadb/utils/embedding_functions.py
@@ -21,7 +21,7 @@
import os
import tarfile
import requests
from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Union, cast
from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Union, Tuple, cast
import numpy as np
import numpy.typing as npt
import importlib
@@ -1017,6 +1017,248 @@ def __call__(self, input: Documents) -> Embeddings:
        )


class MLXEmbeddingFunction(EmbeddingFunction[Documents]):
    def __load_model(
        self,
        bert_model: str,
        weights_path: str,
    ) -> Tuple[Any, Any]:
        # Returns a (Bert model, tokenizer) pair; the Bert class is defined locally below.
        try:
            import mlx.core as mx
            import mlx.nn as nn
        except ImportError:
            raise ValueError(
                "The mlx python package is not installed. Please install it with `pip install mlx`"
            )

        try:
            from transformers import AutoConfig, AutoTokenizer
        except ImportError:
            raise ValueError(
                "The transformers python package is not installed. Please install it with `pip install transformers`"
            )

        class TransformerEncoderLayer(nn.Module):
            """
            A transformer encoder layer with (the original BERT) post-normalization.
            """

            def __init__(
                self,
                dims: int,
                num_heads: int,
                mlp_dims: Optional[int] = None,
                layer_norm_eps: float = 1e-12,
            ):
                super().__init__()
                mlp_dims = mlp_dims or dims * 4
                self.attention = nn.MultiHeadAttention(dims, num_heads, bias=True)
                self.ln1 = nn.LayerNorm(dims, eps=layer_norm_eps)
                self.ln2 = nn.LayerNorm(dims, eps=layer_norm_eps)
                self.linear1 = nn.Linear(dims, mlp_dims)
                self.linear2 = nn.Linear(mlp_dims, dims)
                self.gelu = nn.GELU()

            def __call__(self, x, mask):
                attention_out = self.attention(x, x, x, mask)
                add_and_norm = self.ln1(x + attention_out)

                ff = self.linear1(add_and_norm)
                ff_gelu = self.gelu(ff)
                ff_out = self.linear2(ff_gelu)
                x = self.ln2(ff_out + add_and_norm)

                return x


        class TransformerEncoder(nn.Module):
            def __init__(
                self, num_layers: int, dims: int, num_heads: int, mlp_dims: Optional[int] = None
            ):
                super().__init__()
                self.layers = [
                    TransformerEncoderLayer(dims, num_heads, mlp_dims)
                    for _ in range(num_layers)
                ]

            def __call__(self, x, mask):
                for layer in self.layers:
                    x = layer(x, mask)

                return x


        class BertEmbeddings(nn.Module):
            def __init__(self, config):
                super().__init__()
                self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
                self.token_type_embeddings = nn.Embedding(
                    config.type_vocab_size, config.hidden_size
                )
                self.position_embeddings = nn.Embedding(
                    config.max_position_embeddings, config.hidden_size
                )
                self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

            def __call__(
                self, input_ids: mx.array, token_type_ids: mx.array = None
            ) -> mx.array:
                words = self.word_embeddings(input_ids)
                position = self.position_embeddings(
                    mx.broadcast_to(mx.arange(input_ids.shape[1]), input_ids.shape)
                )

                if token_type_ids is None:
                    # If token_type_ids is not provided, default to zeros
                    token_type_ids = mx.zeros_like(input_ids)

                token_types = self.token_type_embeddings(token_type_ids)

                embeddings = position + words + token_types
                return self.norm(embeddings)

        class Bert(nn.Module):
            def __init__(self, config):
                super().__init__()
                self.embeddings = BertEmbeddings(config)
                self.encoder = TransformerEncoder(
                    num_layers=config.num_hidden_layers,
                    dims=config.hidden_size,
                    num_heads=config.num_attention_heads,
                    mlp_dims=config.intermediate_size,
                )
                self.pooler = nn.Linear(config.hidden_size, config.hidden_size)

            def __call__(
                self,
                input_ids: mx.array,
                token_type_ids: mx.array = None,
                attention_mask: mx.array = None,
            ) -> Tuple[mx.array, mx.array]:
                x = self.embeddings(input_ids, token_type_ids)

                if attention_mask is not None:
                    # convert 0's to -infs, 1's to 0's, and make it broadcastable
                    attention_mask = mx.log(attention_mask)
                    attention_mask = mx.expand_dims(attention_mask, (1, 2))

                y = self.encoder(x, attention_mask)
                return y, mx.tanh(self.pooler(y[:, 0]))

        if not Path(weights_path).exists():
            raise ValueError(f"No model weights found in {weights_path}")

        config = AutoConfig.from_pretrained(bert_model)

        # create and update the model
        model = Bert(config)
        model.load_weights(weights_path)

        tokenizer = AutoTokenizer.from_pretrained(bert_model)

        return model, tokenizer

    @staticmethod
    def convert(bert_model: str, mlx_model: str) -> None:
        try:
            from transformers import AutoModel
        except ImportError:
            raise ValueError(
                "The transformers python package is not installed. Please install it with `pip install transformers`"
            )

        def replace_key(key: str) -> str:
            key = key.replace(".layer.", ".layers.")
            key = key.replace(".self.key.", ".key_proj.")
            key = key.replace(".self.query.", ".query_proj.")
            key = key.replace(".self.value.", ".value_proj.")
            key = key.replace(".attention.output.dense.", ".attention.out_proj.")
            key = key.replace(".attention.output.LayerNorm.", ".ln1.")
            key = key.replace(".output.LayerNorm.", ".ln2.")
            key = key.replace(".intermediate.dense.", ".linear1.")
            key = key.replace(".output.dense.", ".linear2.")
            key = key.replace(".LayerNorm.", ".norm.")
            key = key.replace("pooler.dense.", "pooler.")
            return key

        model = AutoModel.from_pretrained(bert_model)
        # save the tensors
        tensors = {
            replace_key(key): tensor.numpy()
            for key, tensor in model.state_dict().items()
        }
        np.savez(mlx_model, **tensors)

    def __init__(
        self,
        bert_model: str,
        weights_path: Optional[str] = None,
    ) -> None:
        """
        Use a local MLX model to get embeddings for a list of texts.

        Args:
            bert_model (str): The Hugging Face name or local path of the BERT model.
            weights_path (Optional[str]): The path to the converted MLX weights
                (a `.npz` file). If not provided, defaults to `model.npz` inside the
                model path and is created automatically when missing.

        The code is based on https://github.com/ml-explore/mlx-examples/tree/main/bert
        """
        model_path = Path(bert_model)
        npz_path = weights_path
        if weights_path is None:
            npz_path = model_path.joinpath("model.npz")
            if npz_path.exists():
                print("Model weights already exist. Skipping conversion.")
            else:
                print("Converting BERT model to MLX weights")
                self.convert(bert_model, npz_path.as_posix())
            npz_path = npz_path.as_posix()

        self.model, self.tokenizer = self.__load_model(bert_model, npz_path)

    def __call__(self, input: Documents) -> Embeddings:
        """
        Get the embeddings for a list of texts.

        Args:
            input (Documents): A list of texts to get embeddings for.

        Returns:
            Embeddings: The embeddings for the texts.

        Example:
            >>> mlx_ef = MLXEmbeddingFunction(bert_model="bert-base-uncased", weights_path="bert-base-uncased/model.npz")
            >>> texts = ["Hello, world!", "How are you?"]
            >>> embeddings = mlx_ef(texts)

        The code is based on https://github.com/ml-explore/mlx-examples/tree/main/bert
        """
        try:
            import mlx.core as mx
        except ImportError:
            raise ValueError(
                "The mlx python package is not installed. Please install it with `pip install mlx`"
            )

        tokens = self.tokenizer(input, return_tensors="np", padding=True)
        tokens = {key: mx.array(v) for key, v in tokens.items()}

        _, pooled = self.model(**tokens)

        return cast(Embeddings, pooled.tolist())


# List of all classes in this module
_classes = [
name
1 change: 1 addition & 0 deletions docs/docs.trychroma.com/pages/integrations/index.md
@@ -24,6 +24,7 @@ Chroma provides lightweight wrappers around popular embedding providers, making
| [Jina AI](/integrations/jinaai) | ✅ | ✅ |
| [Roboflow](/integrations/roboflow) | ✅ | ➖ |
| [Ollama Embeddings](/integrations/ollama) | ✅ | ✅ |
| [MLX Embeddings](/integrations/mlx) | ✅ | ➖ |


***
24 changes: 24 additions & 0 deletions docs/docs.trychroma.com/pages/integrations/mlx.md
@@ -0,0 +1,24 @@
---
title: MLX Embeddings
---
Chroma provides a convenient wrapper around the [MLX](https://github.com/ml-explore/mlx) framework to run BERT-architecture embedding models on Apple silicon. The model code is adapted from the [MLX examples repo](https://github.com/ml-explore/mlx-examples/).

To use the `MLXEmbeddingFunction`, provide the name (or local path) of a BERT model and, optionally, the path to its converted MLX weights.

A BERT model from Hugging Face needs to be converted to the MLX weight format (`.npz`) before it can be used; the conversion code is based on the [MLX BERT example](https://github.com/ml-explore/mlx-examples/tree/main/bert). If `weights_path` is not supplied and no converted weights are found next to the model, the embedding function performs the conversion automatically; you can also run it explicitly with the `convert` helper, as in the sketch below.
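
A minimal sketch of an explicit one-time conversion, assuming `transformers` (with PyTorch) is installed; the output path is just an example chosen to match the snippet further down:

```python
import os

from chromadb.utils.embedding_functions import MLXEmbeddingFunction

# One-time conversion: download "bert-base-uncased" from Hugging Face and
# save its weights as an MLX .npz file (the directory is created here only
# so the example output path exists).
os.makedirs("bert-base-uncased", exist_ok=True)
MLXEmbeddingFunction.convert("bert-base-uncased", "bert-base-uncased/model.npz")
```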


{% tabs group="code-lang" %}
{% tab label="Python" %}

```python
import chromadb.utils.embedding_functions as embedding_functions

mlx_ef = embedding_functions.MLXEmbeddingFunction(bert_model="bert-base-uncased", weights_path="bert-base-uncased/model.npz")
texts = ["Hello, world!", "How are you?"]

embeddings = mlx_ef(texts)
```

{% /tab %}
{% /tabs %}
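
The embedding function can then be passed to a Chroma collection like any other. A brief sketch, where the collection name and documents are placeholders:

```python
import chromadb
from chromadb.utils import embedding_functions

mlx_ef = embedding_functions.MLXEmbeddingFunction(
    bert_model="bert-base-uncased",
    weights_path="bert-base-uncased/model.npz",
)

client = chromadb.Client()
# "docs" is an arbitrary example collection name.
collection = client.create_collection(name="docs", embedding_function=mlx_ef)
collection.add(ids=["1", "2"], documents=["Hello, world!", "How are you?"])
```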