Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] Support numpy >= 2.0 #2776

Closed
wants to merge 8 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions chromadb/api/types.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Optional, Union, TypeVar, List, Dict, Any, Tuple, cast
from numpy.typing import NDArray
from packaging import version
import numpy as np
from typing_extensions import TypedDict, Protocol, runtime_checkable
from enum import Enum
Expand Down Expand Up @@ -103,8 +104,13 @@ def maybe_cast_one_to_many_document(target: OneOrMany[Document]) -> Documents:


# Images
ImageDType = Union[np.uint, np.int_, np.float_] # type: ignore[name-defined]
Image = NDArray[ImageDType]
ImageDType = None
if version.parse(np.__version__) < version.parse("2.0.0"):
ImageDType = Union[np.uint, np.int_, np.float_] # type: ignore[attr-defined]
else:
ImageDType = Union[np.uint, np.int_, np.float64]

Image = NDArray[ImageDType] # type: ignore[valid-type]
Images = List[Image]


Expand Down
14 changes: 9 additions & 5 deletions chromadb/test/ef/test_multimodal_ef.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Generator, cast
import numpy as np
import pytest
from packaging import version
import chromadb
from chromadb.api.types import (
Embeddable,
Expand All @@ -17,7 +18,10 @@
# then hashes them to a fixed dimension.
class hashing_multimodal_ef(EmbeddingFunction[Embeddable]):
def __init__(self) -> None:
self._hef = hashing_embedding_function(dim=10, dtype=np.float_)
if version.parse(np.__version__) < version.parse("2.0.0"):
self._hef = hashing_embedding_function(dim=10, dtype=np.float_) # type: ignore[attr-defined]
else:
self._hef = hashing_embedding_function(dim=10, dtype=np.float64)

def __call__(self, input: Embeddable) -> Embeddings:
to_texts = [str(i) for i in input]
Expand Down Expand Up @@ -82,7 +86,7 @@ def test_multimodal(

# get() should return all the documents and images
# ids corresponding to images should not have documents
get_result = multimodal_collection.get(include=["documents"])
get_result = multimodal_collection.get(include=["documents"]) # type: ignore[list-item]
assert len(get_result["ids"]) == len(document_ids) + len(image_ids)
for i, id in enumerate(get_result["ids"]):
assert id in document_ids or id in image_ids
Expand Down Expand Up @@ -124,14 +128,14 @@ def test_multimodal(

# Query with images
query_result = multimodal_collection.query(
query_images=[query_image], n_results=n_query_results, include=["documents"]
query_images=[query_image], n_results=n_query_results, include=["documents"] # type: ignore[list-item]
)

assert query_result["ids"][0] == nearest_image_neighbor_ids

# Query with documents
query_result = multimodal_collection.query(
query_texts=[query_document], n_results=n_query_results, include=["documents"]
query_texts=[query_document], n_results=n_query_results, include=["documents"] # type: ignore[list-item]
)

assert query_result["ids"][0] == nearest_document_neighbor_ids
Expand All @@ -152,6 +156,6 @@ def test_multimodal_update_with_image(

multimodal_collection.update(ids=id, images=image)

get_result = multimodal_collection.get(ids=id, include=["documents"])
get_result = multimodal_collection.get(ids=id, include=["documents"]) # type: ignore[list-item]
assert get_result["documents"] is not None
assert get_result["documents"][0] is None
25 changes: 17 additions & 8 deletions chromadb/test/property/strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing_extensions import TypedDict
import uuid
import numpy as np
from packaging import version
import numpy.typing as npt
import chromadb.api.types as types
import re
Expand Down Expand Up @@ -148,7 +149,12 @@ def one_or_both(
"abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-_./+"
)

float_types = [np.float16, np.float32, np.float64]
float_types = None
if version.parse(np.__version__) < version.parse("2.0.0"):
float_types = [np.float16, np.float32, np.float_] # type: ignore[attr-defined]
else:
float_types = [np.float16, np.float32, np.float64]

int_types = [np.int16, np.int32, np.int64] # TODO: handle int types


Expand Down Expand Up @@ -194,7 +200,7 @@ def create_embeddings_ndarray(
dim: int,
count: int,
dtype: npt.DTypeLike,
) -> np.typing.NDArray[Any]:
) -> npt.NDArray[Any]:
return np.random.uniform(
low=-1.0,
high=1.0,
Expand Down Expand Up @@ -295,7 +301,7 @@ def collections(
name = draw(collection_name())
metadata = draw(collection_metadata)
dimension = draw(st.integers(min_value=2, max_value=2048))
dtype = draw(st.sampled_from(float_types))
dtype = draw(st.sampled_from(float_types)) # type: ignore[arg-type]

use_persistent_hnsw_params = draw(with_persistent_hnsw_params)

Expand Down Expand Up @@ -376,7 +382,10 @@ def collections(

@st.composite
def metadata(
draw: st.DrawFn, collection: Collection, min_size=0, max_size=None
draw: st.DrawFn,
collection: Collection,
min_size: int = 0,
max_size: Optional[int] = None,
) -> Optional[types.Metadata]:
"""Strategy for generating metadata that could be a part of the given collection"""
# First draw a random dictionary.
Expand Down Expand Up @@ -429,7 +438,7 @@ def document(draw: st.DrawFn, collection: Collection) -> types.Document:

# Blacklist certain unicode characters that affect sqlite processing.
# For example, the null (/x00) character makes sqlite stop processing a string.
blacklist_categories = ("Cc", "Cs")
blacklist_categories = ("Cc", "Cs") # type: ignore[assignment]
if collection.known_document_keywords:
known_words_st = st.sampled_from(collection.known_document_keywords)
else:
Expand Down Expand Up @@ -553,7 +562,7 @@ def where_clause(draw: st.DrawFn, collection: Collection) -> types.Where:
if not NOT_CLUSTER_ONLY:
legal_ops: List[Optional[str]] = [None, "$eq"]
else:
legal_ops: List[Optional[str]] = [None, "$eq", "$ne", "$in", "$nin"]
legal_ops: List[Optional[str]] = [None, "$eq", "$ne", "$in", "$nin"] # type: ignore[no-redef]

if not isinstance(value, str) and not isinstance(value, bool):
legal_ops.extend(["$gt", "$lt", "$lte", "$gte"])
Expand Down Expand Up @@ -605,10 +614,10 @@ def where_doc_clause(draw: st.DrawFn, collection: Collection) -> types.WhereDocu
else:
op = draw(st.sampled_from(["$contains", "$not_contains"]))

if op == "$contains":
if op == "$contains": # type: ignore[comparison-overlap]
return {"$contains": word}
else:
assert op == "$not_contains"
assert op == "$not_contains" # type: ignore[comparison-overlap]
return {"$not_contains": word}


Expand Down
11 changes: 8 additions & 3 deletions chromadb/test/property/test_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import pytest
import logging
import hypothesis
from packaging import version
import hypothesis.strategies as st
from hypothesis import given, settings, HealthCheck
from typing import Dict, Set, cast, Union, DefaultDict, Any, List
Expand Down Expand Up @@ -51,9 +52,13 @@ def print_traces() -> None:
print(f"{key}: {value}")


dtype_shared_st: st.SearchStrategy[
Union[np.float16, np.float32, np.float64]
] = st.shared(st.sampled_from(strategies.float_types), key="dtype")
SearchStrategyType = None
if version.parse(np.__version__) < version.parse("2.0.0"):
SearchStrategyType = Union[np.float16, np.float32, np.float_] # type: ignore[attr-defined]
else:
SearchStrategyType = Union[np.float16, np.float32, np.float64]

dtype_shared_st: SearchStrategyType = st.shared(st.sampled_from(strategies.float_types), key="dtype") # type: ignore[valid-type, arg-type]

dimension_shared_st: st.SearchStrategy[int] = st.shared(
st.integers(min_value=2, max_value=2048), key="dimension"
Expand Down
2 changes: 1 addition & 1 deletion clients/python/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ classifiers = [
"Operating System :: OS Independent",
]
dependencies = [
'numpy >= 1.22.5, < 2.0.0',
'numpy >= 1.22.5',
'opentelemetry-api>=1.2.0',
'opentelemetry-exporter-otlp-proto-grpc>=1.2.0',
'opentelemetry-sdk>=1.2.0',
Expand Down
2 changes: 1 addition & 1 deletion clients/python/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
httpx>=0.27.0
numpy >= 1.22.5, < 2.0.0
numpy >= 1.22.5
opentelemetry-api>=1.2.0
opentelemetry-exporter-otlp-proto-grpc>=1.2.0
opentelemetry-sdk>=1.2.0
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ dependencies = [
'chroma-hnswlib==0.7.6',
'fastapi >= 0.95.2',
'uvicorn[standard] >= 0.18.3',
'numpy >= 1.22.5, < 2.0.0',
'numpy >= 1.22.5',
'posthog >= 2.4.0',
'typing_extensions >= 4.5.0',
'onnxruntime >= 1.14.1',
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ httpx>=0.27.0
importlib-resources
kubernetes>=28.1.0
mmh3>=4.0.1
numpy>=1.22.5, <2.0.0
numpy>=1.22.5
onnxruntime>=1.14.1
opentelemetry-api>=1.2.0
opentelemetry-exporter-otlp-proto-grpc>=1.24.0
Expand Down
Loading