Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BUG] Test that query result shapes are correct in invariants #2807

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion chromadb/segment/impl/vector/local_persistent_hnsw.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,6 +413,8 @@ def query_vectors(
# Overquery by updated and deleted elements layered on the index because they may
# hide the real nearest neighbors in the hnsw index
hnsw_k = k + self._curr_batch.update_count + self._curr_batch.delete_count
# self._id_to_label contains the ids of the elements in the hnsw index
# so its length is the number of elements in the hnsw index
if hnsw_k > len(self._id_to_label):
hnsw_k = len(self._id_to_label)
hnsw_query = VectorQuery(
Expand Down Expand Up @@ -472,7 +474,7 @@ def query_vectors(
if remaining > 0 and hnsw_pointer < len(curr_hnsw_result):
for i in range(
hnsw_pointer,
min(len(curr_hnsw_result), hnsw_pointer + remaining + 1),
min(len(curr_hnsw_result), hnsw_pointer + remaining),
):
id = curr_hnsw_result[i]["id"]
if not self._brute_force_index.has_id(id):
Expand Down
18 changes: 15 additions & 3 deletions chromadb/test/property/invariants.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,7 @@
from chromadb.db.base import get_sql
from chromadb.db.impl.sqlite import SqliteDB
from time import sleep

import psutil

from chromadb.test.property.strategies import NormalizedRecordSet, RecordSet
from typing import Callable, Optional, Tuple, Union, List, TypeVar, cast
from typing_extensions import Literal
Expand Down Expand Up @@ -261,10 +259,14 @@ def ann_accuracy(
include=["embeddings", "documents", "metadatas", "distances"], # type: ignore[list-item]
)

_query_results_are_correct_shape(query_results, n_results)

# Assert fields are not None for type checking
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

had to bring this back out for typechecker :(

assert query_results["ids"] is not None
assert query_results["distances"] is not None
assert query_results["embeddings"] is not None
assert query_results["documents"] is not None
assert query_results["metadatas"] is not None
assert query_results["embeddings"] is not None

# Dict of ids to indices
id_to_index = {id: i for i, id in enumerate(normalized_record_set["ids"])}
Expand Down Expand Up @@ -324,6 +326,16 @@ def ann_accuracy(
assert np.allclose(np.sort(distance_result), distance_result)


def _query_results_are_correct_shape(
query_results: types.QueryResult, n_results: int
) -> None:
for result_type in ["distances", "embeddings", "documents", "metadatas"]:
assert query_results[result_type] is not None # type: ignore[literal-required]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

pushed asserts down for debuggability

assert all(
len(result) == n_results for result in query_results[result_type] # type: ignore[literal-required]
)


def _total_embedding_queue_log_size(sqlite: SqliteDB) -> int:
t = Table("embeddings_queue")
q = sqlite.querybuilder().from_(t)
Expand Down
50 changes: 50 additions & 0 deletions chromadb/test/property/test_persist.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import tempfile
from chromadb.api.client import Client as ClientCreator
from chromadb.utils.embedding_functions import DefaultEmbeddingFunction
import numpy as np

CreatePersistAPI = Callable[[], ServerAPI]

Expand Down Expand Up @@ -308,6 +309,55 @@ def test_persist_embeddings_state(
) # type: ignore


def test_delete_less_than_k(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

todo: rename and clean up

caplog: pytest.LogCaptureFixture, settings: Settings
) -> None:
client = chromadb.Client(settings)
state = PersistEmbeddingsStateMachine(settings=settings, client=client)
state.initialize(
collection=strategies.Collection(
name="A00",
metadata={
"hnsw:construction_ef": 128,
"hnsw:search_ef": 128,
"hnsw:M": 128,
"hnsw:sync_threshold": 3,
"hnsw:batch_size": 3,
},
embedding_function=None,
id=UUID("2d3eddc7-2314-45f4-a951-47a9a8e099d2"),
dimension=2,
dtype=np.float16,
known_metadata_keys={},
known_document_keywords=[],
has_documents=False,
has_embeddings=True,
)
)
state.ann_accuracy()
state.count()
state.fields_match()
state.log_size_below_max()
state.no_duplicates()
(embedding_ids_0,) = state.add_embeddings(record_set={"ids": ["0"], "embeddings": [[0.09765625, 0.430419921875]], "metadatas": [None], "documents": None}) # type: ignore
state.ann_accuracy()
# recall: 1.0, missing 0 out of 1, accuracy threshold 1e-06
state.count()
state.fields_match()
state.log_size_below_max()
state.no_duplicates()
embedding_ids_1, embedding_ids_2 = state.add_embeddings(record_set={"ids": ["1", "2"], "embeddings": [[0.20556640625, 0.08978271484375], [-0.1527099609375, 0.291748046875]], "metadatas": [None, None], "documents": None}) # type: ignore
state.ann_accuracy()
# recall: 1.0, missing 0 out of 3, accuracy threshold 1e-06
state.count()
state.fields_match()
state.log_size_below_max()
state.no_duplicates()
state.delete_by_ids(ids=[embedding_ids_2])
state.ann_accuracy()
state.teardown()


# Ideally this scenario would be exercised by Hypothesis, but most runs don't seem to trigger this particular state.
def test_delete_add_after_persist(settings: Settings) -> None:
client = chromadb.Client(settings)
Expand Down
2 changes: 1 addition & 1 deletion rust/worker/src/execution/operators/merge_knn_results.rs
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ fn merge_results(
let mut brute_force_index = 0;

// TODO: This doesn't have to clone the user IDs, but it's easier for now
while (result_user_ids.len() <= k)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

eww

while (result_user_ids.len() < k)
&& (hnsw_index < hnsw_result_user_ids.len()
|| brute_force_index < brute_force_result_user_ids.len())
{
Expand Down
Loading