Skip to content

Commit

Permalink
Replace ONNXMiniLM_L6_V2._init_model_and_tokenizer with tokenizer and model cached properties (#1194)
Browse files Browse the repository at this point in the history

## Description of changes

*Summarize the changes made by this PR.*
 - Improvements & Bug fixes
- Fixes #1193: race condition in
`ONNXMiniLM_L6_V2._init_model_and_tokenizer`

## Test plan
*How are these changes tested?*

- [x] Tests pass locally with `pytest` for python, `yarn test` for js
  • Loading branch information
gsakkis committed Jan 5, 2024
1 parent fca3426 commit bdec54a
Showing 1 changed file with 38 additions and 39 deletions.
77 changes: 38 additions & 39 deletions chromadb/utils/embedding_functions.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import hashlib
import logging
from functools import cached_property

from tenacity import stop_after_attempt, wait_random, retry, retry_if_exception

Expand All @@ -18,20 +19,23 @@
import os
import tarfile
import requests
from typing import Any, Dict, List, Mapping, Union, cast
from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Union, cast
import numpy as np
import numpy.typing as npt
import importlib
import inspect
import json
import sys
from typing import Optional

try:
from chromadb.is_thin_client import is_thin_client
except ImportError:
is_thin_client = False

if TYPE_CHECKING:
from onnxruntime import InferenceSession
from tokenizers import Tokenizer

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -361,8 +365,6 @@ class ONNXMiniLM_L6_V2(EmbeddingFunction[Documents]):
"https://chroma-onnx-models.s3.amazonaws.com/all-MiniLM-L6-v2/onnx.tar.gz"
)
_MODEL_SHA256 = "913d7300ceae3b2dbc2c50d1de4baacab4be7b9380491c27fab7418616a16ec3"
tokenizer = None
model = None

# https://github.com/python/mypy/issues/7291 mypy makes you type the constructor if
# no args
Expand Down Expand Up @@ -440,8 +442,6 @@ def _normalize(self, v: npt.NDArray) -> npt.NDArray: # type: ignore
# type: ignore
def _forward(self, documents: List[str], batch_size: int = 32) -> npt.NDArray:
# We need to cast to the correct type because the type checker doesn't know that init_model_and_tokenizer will set the values
self.tokenizer = cast(self.Tokenizer, self.tokenizer) # type: ignore
self.model = cast(self.ort.InferenceSession, self.model) # type: ignore
all_embeddings = []
for i in range(0, len(documents), batch_size):
batch = documents[i : i + batch_size]
Expand Down Expand Up @@ -469,46 +469,45 @@ def _forward(self, documents: List[str], batch_size: int = 32) -> npt.NDArray:
all_embeddings.append(embeddings)
return np.concatenate(all_embeddings)

@cached_property
def tokenizer(self) -> "Tokenizer":
    """Lazily load and configure the tokenizer on first access.

    Using ``functools.cached_property`` replaces the old
    ``_init_model_and_tokenizer`` "check attribute, then set it" pattern,
    which the commit message identifies as race-prone (#1193). The result
    is computed once and cached on the instance.

    Returns:
        Tokenizer: tokenizer loaded from the extracted model archive,
        configured for fixed-length (256) truncation and padding.
    """
    tokenizer = self.Tokenizer.from_file(
        os.path.join(
            self.DOWNLOAD_PATH, self.EXTRACTED_FOLDER_NAME, "tokenizer.json"
        )
    )
    # max_seq_length = 256: for some reason sentence-transformers uses 256
    # even though the HF config has a max length of 128
    # https://github.com/UKPLab/sentence-transformers/blob/3e1929fddef16df94f8bc6e3b10598a98f46e62d/docs/_static/html/models_en_sentence_embeddings.html#LL480
    tokenizer.enable_truncation(max_length=256)
    tokenizer.enable_padding(pad_id=0, pad_token="[PAD]", length=256)
    return tokenizer

@cached_property
def model(self) -> "InferenceSession":
    """Lazily create the ONNX inference session on first access.

    Cached on the instance via ``functools.cached_property``, replacing the
    old mutable-attribute initialization (see #1193 per the commit message).

    Returns:
        InferenceSession: session over the extracted ``model.onnx`` using
        the preferred (or all available) execution providers.

    Raises:
        ValueError: if explicitly requested providers are not a subset of
            the providers available in this onnxruntime build.
    """
    if self._preferred_providers is None or len(self._preferred_providers) == 0:
        if len(self.ort.get_available_providers()) > 0:
            logger.debug(
                f"WARNING: No ONNX providers provided, defaulting to available providers: "
                f"{self.ort.get_available_providers()}"
            )
            self._preferred_providers = self.ort.get_available_providers()
    elif not set(self._preferred_providers).issubset(
        set(self.ort.get_available_providers())
    ):
        raise ValueError(
            f"Preferred providers must be subset of available providers: {self.ort.get_available_providers()}"
        )
    return self.ort.InferenceSession(
        os.path.join(self.DOWNLOAD_PATH, self.EXTRACTED_FOLDER_NAME, "model.onnx"),
        # Since 1.9 onnx runtime requires providers to be specified when there are
        # multiple available - https://onnxruntime.ai/docs/api/python/api_summary.html
        # This is probably not ideal but will improve DX as no exceptions will be
        # raised in multi-provider envs
        providers=self._preferred_providers,
    )

def __call__(self, input: Documents) -> Embeddings:
    """Embed the given documents, downloading the model on first use.

    No explicit initialization call is needed: ``tokenizer`` and ``model``
    are cached properties that build themselves on first access inside
    ``_forward``.

    Args:
        input: documents to embed.

    Returns:
        Embeddings: one embedding (list of floats) per input document.
    """
    # Only download the model when it is actually used
    self._download_model_if_not_exists()
    return cast(Embeddings, self._forward(input).tolist())

def _download_model_if_not_exists(self) -> None:
onnx_files = [
Expand Down

0 comments on commit bdec54a

Please sign in to comment.