Skip to content

Commit

Permalink
add support for numpy2.0
Browse files Browse the repository at this point in the history
  • Loading branch information
spikechroma committed Sep 11, 2024
1 parent dc22570 commit b736380
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 10 deletions.
2 changes: 1 addition & 1 deletion chromadb/api/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def maybe_cast_one_to_many_document(target: OneOrMany[Document]) -> Documents:


# Images
ImageDType = Union[np.uint, np.int_, np.float_] # type: ignore[name-defined]
ImageDType = Union[np.uint, np.int_, np.float64]
Image = NDArray[ImageDType]
Images = List[Image]

Expand Down
10 changes: 5 additions & 5 deletions chromadb/test/ef/test_multimodal_ef.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
# then hashes them to a fixed dimension.
class hashing_multimodal_ef(EmbeddingFunction[Embeddable]):
def __init__(self) -> None:
self._hef = hashing_embedding_function(dim=10, dtype=np.float_)
self._hef = hashing_embedding_function(dim=10, dtype=np.float64)

def __call__(self, input: Embeddable) -> Embeddings:
to_texts = [str(i) for i in input]
Expand Down Expand Up @@ -82,7 +82,7 @@ def test_multimodal(

# get() should return all the documents and images
# ids corresponding to images should not have documents
get_result = multimodal_collection.get(include=["documents"])
get_result = multimodal_collection.get(include=["documents"]) # type: ignore[list-item]
assert len(get_result["ids"]) == len(document_ids) + len(image_ids)
for i, id in enumerate(get_result["ids"]):
assert id in document_ids or id in image_ids
Expand Down Expand Up @@ -124,14 +124,14 @@ def test_multimodal(

# Query with images
query_result = multimodal_collection.query(
query_images=[query_image], n_results=n_query_results, include=["documents"]
query_images=[query_image], n_results=n_query_results, include=["documents"] # type: ignore[list-item]
)

assert query_result["ids"][0] == nearest_image_neighbor_ids

# Query with documents
query_result = multimodal_collection.query(
query_texts=[query_document], n_results=n_query_results, include=["documents"]
query_texts=[query_document], n_results=n_query_results, include=["documents"] # type: ignore[list-item]
)

assert query_result["ids"][0] == nearest_document_neighbor_ids
Expand All @@ -152,6 +152,6 @@ def test_multimodal_update_with_image(

multimodal_collection.update(ids=id, images=image)

get_result = multimodal_collection.get(ids=id, include=["documents"])
get_result = multimodal_collection.get(ids=id, include=["documents"]) # type: ignore[list-item]
assert get_result["documents"] is not None
assert get_result["documents"][0] is None
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@ dependencies = [
'chroma-hnswlib==0.7.6',
'fastapi >= 0.95.2',
'uvicorn[standard] >= 0.18.3',
'numpy >= 1.22.5',
'numpy >= 2.0.0',
'posthog >= 2.4.0',
'typing_extensions >= 4.5.0',
'onnxruntime >= 1.14.1',
'onnxruntime >= 1.19.0',
'opentelemetry-api>=1.2.0',
'opentelemetry-exporter-otlp-proto-grpc>=1.2.0',
'opentelemetry-instrumentation-fastapi>=0.41b0',
Expand Down
4 changes: 2 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ httpx>=0.27.0
importlib-resources
kubernetes>=28.1.0
mmh3>=4.0.1
numpy>=1.22.5, <2.0.0
onnxruntime>=1.14.1
numpy>=2.0.0
onnxruntime>=1.19.0
opentelemetry-api>=1.2.0
opentelemetry-exporter-otlp-proto-grpc>=1.24.0
opentelemetry-instrumentation-fastapi>=0.41b0
Expand Down

0 comments on commit b736380

Please sign in to comment.