
[BUG] Fix bug when receiving embeddings from the Vertex API
1. Change the way the key is checked in the response.
Albert Li committed Apr 12, 2024
1 parent 3dd01d2 commit ff8827a
Showing 1 changed file with 32 additions and 20 deletions.
52 changes: 32 additions & 20 deletions chromadb/utils/embedding_functions.py
@@ -109,7 +109,8 @@ def __init__(self, model_name: str = "shibing624/text2vec-base-chinese"):

def __call__(self, input: Documents) -> Embeddings:
return cast(
-            Embeddings, self._model.encode(list(input), convert_to_numpy=True).tolist()
+            Embeddings, self._model.encode(
+                list(input), convert_to_numpy=True).tolist()
) # noqa E501


@@ -222,7 +223,8 @@ def __call__(self, input: Documents) -> Embeddings:

# Return just the embeddings
return cast(
-            Embeddings, [result["embedding"] for result in sorted_embeddings]
+            Embeddings, [result["embedding"]
+                         for result in sorted_embeddings]
)


@@ -431,7 +433,8 @@ def __init__(self, preferred_providers: Optional[List[str]] = None) -> None:
reraise=True,
stop=stop_after_attempt(3),
wait=wait_random(min=1, max=3),
-        retry=retry_if_exception(lambda e: "does not match expected SHA256" in str(e)),
+        retry=retry_if_exception(
+            lambda e: "does not match expected SHA256" in str(e)),
)
def _download(self, url: str, fname: str, chunk_size: int = 1024) -> None:
resp = requests.get(url, stream=True)
@@ -466,7 +469,7 @@ def _forward(self, documents: List[str], batch_size: int = 32) -> npt.NDArray:
self.model = cast(self.ort.InferenceSession, self.model)
all_embeddings = []
for i in range(0, len(documents), batch_size):
-            batch = documents[i : i + batch_size]
+            batch = documents[i: i + batch_size]
encoded = [self.tokenizer.encode(d) for d in batch]
input_ids = np.array([e.ids for e in encoded])
attention_mask = np.array([e.attention_mask for e in encoded])
@@ -525,7 +528,8 @@ def model(self) -> "InferenceSession":
so.log_severity_level = 3

return self.ort.InferenceSession(
-            os.path.join(self.DOWNLOAD_PATH, self.EXTRACTED_FOLDER_NAME, "model.onnx"),
+            os.path.join(self.DOWNLOAD_PATH,
+                         self.EXTRACTED_FOLDER_NAME, "model.onnx"),
# Since 1.9 onnyx runtime requires providers to be specified when there are multiple available - https://onnxruntime.ai/docs/api/python/api_summary.html
# This is probably not ideal but will improve DX as no exceptions will be raised in multi-provider envs
providers=self._preferred_providers,
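
The comment in this hunk refers to onnxruntime's requirement (since 1.9) that execution providers be listed explicitly when more than one is available. A minimal standalone sketch of the same call outside the class, assuming a local model.onnx path and CPU-only execution, not chromadb's cache layout:

import onnxruntime as ort

# CPUExecutionProvider is always available; GPU providers only if installed.
session = ort.InferenceSession(
    "model.onnx",                        # assumed local path for illustration
    providers=["CPUExecutionProvider"],
)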
@@ -546,7 +550,8 @@ def _download_model_if_not_exists(self) -> None:
"tokenizer.json",
"vocab.txt",
]
-        extracted_folder = os.path.join(self.DOWNLOAD_PATH, self.EXTRACTED_FOLDER_NAME)
+        extracted_folder = os.path.join(
+            self.DOWNLOAD_PATH, self.EXTRACTED_FOLDER_NAME)
onnx_files_exist = True
for f in onnx_files:
if not os.path.exists(os.path.join(extracted_folder, f)):
@@ -563,7 +568,8 @@ def _download_model_if_not_exists(self) -> None:
):
self._download(
url=self.MODEL_DOWNLOAD_URL,
-                fname=os.path.join(self.DOWNLOAD_PATH, self.ARCHIVE_FILENAME),
+                fname=os.path.join(self.DOWNLOAD_PATH,
+                                   self.ARCHIVE_FILENAME),
)
with tarfile.open(
name=os.path.join(self.DOWNLOAD_PATH, self.ARCHIVE_FILENAME),
@@ -676,8 +682,9 @@ def __call__(self, input: Documents) -> Embeddings:
self._api_url, json={"instances": [{"content": text}]}
).json()

if "predictions" in response:
for prediction in response["predictions"]:
predictions = response.get("predictions")
if predictions:
for prediction in predictions:
embeddings.append(prediction["embeddings"]["values"])

return embeddings
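
The hunk above is the actual fix: the bare membership test is replaced with response.get() plus a truthiness check. A small self-contained sketch of the behavioural difference, using a made-up response dict rather than a real Vertex AI payload:

# Hypothetical response where the key exists but holds no usable value.
response = {"predictions": None}

# Old path: the membership test passes, then iterating None raises TypeError.
if "predictions" in response:
    try:
        for prediction in response["predictions"]:
            pass
    except TypeError as exc:
        print("old check blows up:", exc)

# New path: .get() plus an if-check skips None and empty lists quietly.
predictions = response.get("predictions")
if predictions:
    for prediction in predictions:
        pass  # never reached for None or []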
@@ -745,7 +752,7 @@ def __call__(self, input: Union[Documents, Images]) -> Embeddings:

class RoboflowEmbeddingFunction(EmbeddingFunction[Union[Documents, Images]]):
def __init__(
-        self, api_key: str = "", api_url = "https://infer.roboflow.com"
+        self, api_key: str = "", api_url="https://infer.roboflow.com"
) -> None:
"""
Create a RoboflowEmbeddingFunction.
@@ -758,7 +765,7 @@ def __init__(
api_key = os.environ.get("ROBOFLOW_API_KEY")

self._api_url = api_url
-        self._api_key = api_key
+        self._api_key = api_key

try:
self._PILImage = importlib.import_module("PIL.Image")
@@ -776,7 +783,8 @@ def __call__(self, input: Union[Documents, Images]) -> Embeddings:

buffer = BytesIO()
image.save(buffer, format="JPEG")
-                base64_image = base64.b64encode(buffer.getvalue()).decode("utf-8")
+                base64_image = base64.b64encode(
+                    buffer.getvalue()).decode("utf-8")

infer_clip_payload = {
"image": {
@@ -793,7 +801,7 @@ def __call__(self, input: Union[Documents, Images]) -> Embeddings:
result = res.json()['embeddings']

embeddings.append(result[0])

elif is_document(item):
infer_clip_payload = {
"text": input,
@@ -810,7 +818,7 @@ def __call__(self, input: Union[Documents, Images]) -> Embeddings:

return embeddings


class AmazonBedrockEmbeddingFunction(EmbeddingFunction[Documents]):
def __init__(
self,
@@ -897,7 +905,8 @@ def __call__(self, input: Documents) -> Embeddings:
"""
# Call HuggingFace Embedding Server API for each document
return cast(
-            Embeddings, self._session.post(self._api_url, json={"inputs": input}).json()
+            Embeddings, self._session.post(
+                self._api_url, json={"inputs": input}).json()
)


@@ -910,7 +919,8 @@ def create_langchain_embedding(langchain_embdding_fn: Any):  # type: ignore
)

class ChromaLangchainEmbeddingFunction(
-        LangchainEmbeddings, EmbeddingFunction[Union[Documents, Images]]  # type: ignore
+        # type: ignore
+        LangchainEmbeddings, EmbeddingFunction[Union[Documents, Images]]
):
"""
This class is used as bridge between langchain embedding functions and custom chroma embedding functions.
@@ -926,14 +936,16 @@ def __init__(self, embedding_function: LangchainEmbeddings) -> None:
self.embedding_function = embedding_function

def embed_documents(self, documents: Documents) -> List[List[float]]:
-            return self.embedding_function.embed_documents(documents)  # type: ignore
+            # type: ignore
+            return self.embedding_function.embed_documents(documents)

def embed_query(self, query: str) -> List[float]:
return self.embedding_function.embed_query(query) # type: ignore

def embed_image(self, uris: List[str]) -> List[List[float]]:
if hasattr(self.embedding_function, "embed_image"):
-                return self.embedding_function.embed_image(uris)  # type: ignore
+                # type: ignore
+                return self.embedding_function.embed_image(uris)
else:
raise ValueError(
"The provided embedding function does not support image embeddings."
@@ -963,7 +975,7 @@ def __call__(self, input: Documents) -> Embeddings:  # type: ignore

return ChromaLangchainEmbeddingFunction(embedding_function=langchain_embdding_fn)


class OllamaEmbeddingFunction(EmbeddingFunction[Documents]):
"""
This class is used to generate embeddings for a list of texts using the Ollama Embedding API (https://github.com/ollama/ollama/blob/main/docs/api.md#generate-embeddings).
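
For context, a rough sketch of what the wrapped Ollama endpoint looks like when called directly with requests; the URL, payload shape, and response key follow the linked Ollama API docs, the model name is just an example, and a locally running Ollama server is assumed:

import requests

OLLAMA_URL = "http://localhost:11434/api/embeddings"  # assumed default host/port

def embed_with_ollama(texts, model="nomic-embed-text"):
    # POST one prompt at a time and collect the returned vectors.
    embeddings = []
    for text in texts:
        resp = requests.post(OLLAMA_URL, json={"model": model, "prompt": text})
        resp.raise_for_status()
        embeddings.append(resp.json()["embedding"])
    return embeddings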
@@ -1019,7 +1031,7 @@ def __call__(self, input: Documents) -> Embeddings:
],
)


# List of all classes in this module
_classes = [
name
