From 9ab01964a972b47254b9aaa35a30d9a89e9a0743 Mon Sep 17 00:00:00 2001 From: Hammad Bashir Date: Tue, 17 Sep 2024 11:45:24 -0700 Subject: [PATCH] [CLN] Support numpy >=2.0 (#2811) ## Description of changes *Summarize the changes made by this PR.* - Improvements & Bug fixes - Supports numpy 2.0 now that onnxruntime supports it by removing our limit of <2.0 - Change usage of `float_` to `float64`. These are supported in both < and >= 2.0. However `float_` is only supported by < 2.0. In < 2.0 the _ types are shorthand for the 64bit wide datatypes. I verified this. This makes the changes cleaner than #2776. Screenshot 2024-09-17 at 10 15 21 AM - New functionality - None ## Test plan *How are these changes tested?* Existing tests - [x] Tests pass locally with `pytest` for python, `yarn test` for js, `cargo test` for rust ## Documentation Changes None --- chromadb/api/types.py | 2 +- chromadb/test/ef/test_multimodal_ef.py | 12 ++++++------ clients/python/requirements.txt | 2 +- pyproject.toml | 2 +- requirements.txt | 2 +- 5 files changed, 10 insertions(+), 10 deletions(-) diff --git a/chromadb/api/types.py b/chromadb/api/types.py index f0ffc1e6ca0..f6cab2239b8 100644 --- a/chromadb/api/types.py +++ b/chromadb/api/types.py @@ -103,7 +103,7 @@ def maybe_cast_one_to_many_document(target: OneOrMany[Document]) -> Documents: # Images -ImageDType = Union[np.uint, np.int_, np.float_] # type: ignore[name-defined] +ImageDType = Union[np.uint, np.int64, np.float64] Image = NDArray[ImageDType] Images = List[Image] diff --git a/chromadb/test/ef/test_multimodal_ef.py b/chromadb/test/ef/test_multimodal_ef.py index 82f66fea33e..8953f4e495e 100644 --- a/chromadb/test/ef/test_multimodal_ef.py +++ b/chromadb/test/ef/test_multimodal_ef.py @@ -17,7 +17,7 @@ # then hashes them to a fixed dimension. class hashing_multimodal_ef(EmbeddingFunction[Embeddable]): def __init__(self) -> None: - self._hef = hashing_embedding_function(dim=10, dtype=np.float_) + self._hef = hashing_embedding_function(dim=10, dtype=np.float64) def __call__(self, input: Embeddable) -> Embeddings: to_texts = [str(i) for i in input] @@ -29,7 +29,7 @@ def __call__(self, input: Embeddable) -> Embeddings: def random_image() -> Image: - return np.random.randint(0, 255, size=(10, 10, 3), dtype=np.int32) + return np.random.randint(0, 255, size=(10, 10, 3), dtype=np.int64) def random_document() -> Document: @@ -82,7 +82,7 @@ def test_multimodal( # get() should return all the documents and images # ids corresponding to images should not have documents - get_result = multimodal_collection.get(include=["documents"]) + get_result = multimodal_collection.get(include=["documents"]) # type: ignore[list-item] assert len(get_result["ids"]) == len(document_ids) + len(image_ids) for i, id in enumerate(get_result["ids"]): assert id in document_ids or id in image_ids @@ -124,14 +124,14 @@ def test_multimodal( # Query with images query_result = multimodal_collection.query( - query_images=[query_image], n_results=n_query_results, include=["documents"] + query_images=[query_image], n_results=n_query_results, include=["documents"] # type: ignore[list-item] ) assert query_result["ids"][0] == nearest_image_neighbor_ids # Query with documents query_result = multimodal_collection.query( - query_texts=[query_document], n_results=n_query_results, include=["documents"] + query_texts=[query_document], n_results=n_query_results, include=["documents"] # type: ignore[list-item] ) assert query_result["ids"][0] == nearest_document_neighbor_ids @@ -152,6 +152,6 @@ def test_multimodal_update_with_image( multimodal_collection.update(ids=id, images=image) - get_result = multimodal_collection.get(ids=id, include=["documents"]) + get_result = multimodal_collection.get(ids=id, include=["documents"]) # type: ignore[list-item] assert get_result["documents"] is not None assert get_result["documents"][0] is None diff --git a/clients/python/requirements.txt b/clients/python/requirements.txt index 0c83eb72520..88afd9cda8a 100644 --- a/clients/python/requirements.txt +++ b/clients/python/requirements.txt @@ -1,5 +1,5 @@ httpx>=0.27.0 -numpy >= 1.22.5, < 2.0.0 +numpy >= 1.22.5 opentelemetry-api>=1.2.0 opentelemetry-exporter-otlp-proto-grpc>=1.2.0 opentelemetry-sdk>=1.2.0 diff --git a/pyproject.toml b/pyproject.toml index 21c1898b873..5ea5e43f38a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,7 +20,7 @@ dependencies = [ 'chroma-hnswlib==0.7.6', 'fastapi >= 0.95.2', 'uvicorn[standard] >= 0.18.3', - 'numpy >= 1.22.5, < 2.0.0', + 'numpy >= 1.22.5', 'posthog >= 2.4.0', 'typing_extensions >= 4.5.0', 'onnxruntime >= 1.14.1', diff --git a/requirements.txt b/requirements.txt index dfd08dccd8d..b7b621faf2a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,7 +7,7 @@ httpx>=0.27.0 importlib-resources kubernetes>=28.1.0 mmh3>=4.0.1 -numpy>=1.22.5, <2.0.0 +numpy>=1.22.5 onnxruntime>=1.14.1 opentelemetry-api>=1.2.0 opentelemetry-exporter-otlp-proto-grpc>=1.24.0