-
Notifications
You must be signed in to change notification settings - Fork 1.2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[BUG] Test that query result shapes are correct in invariants #2807
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,9 +4,7 @@ | |
from chromadb.db.base import get_sql | ||
from chromadb.db.impl.sqlite import SqliteDB | ||
from time import sleep | ||
|
||
import psutil | ||
|
||
from chromadb.test.property.strategies import NormalizedRecordSet, RecordSet | ||
from typing import Callable, Optional, Tuple, Union, List, TypeVar, cast | ||
from typing_extensions import Literal | ||
|
@@ -261,10 +259,14 @@ def ann_accuracy( | |
include=["embeddings", "documents", "metadatas", "distances"], # type: ignore[list-item] | ||
) | ||
|
||
_query_results_are_correct_shape(query_results, n_results) | ||
|
||
# Assert fields are not None for type checking | ||
assert query_results["ids"] is not None | ||
assert query_results["distances"] is not None | ||
assert query_results["embeddings"] is not None | ||
assert query_results["documents"] is not None | ||
assert query_results["metadatas"] is not None | ||
assert query_results["embeddings"] is not None | ||
|
||
# Dict of ids to indices | ||
id_to_index = {id: i for i, id in enumerate(normalized_record_set["ids"])} | ||
|
@@ -324,6 +326,16 @@ def ann_accuracy( | |
assert np.allclose(np.sort(distance_result), distance_result) | ||
|
||
|
||
def _query_results_are_correct_shape( | ||
query_results: types.QueryResult, n_results: int | ||
) -> None: | ||
for result_type in ["distances", "embeddings", "documents", "metadatas"]: | ||
assert query_results[result_type] is not None # type: ignore[literal-required] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. pushed asserts down for debuggability |
||
assert all( | ||
len(result) == n_results for result in query_results[result_type] # type: ignore[literal-required] | ||
) | ||
|
||
|
||
def _total_embedding_queue_log_size(sqlite: SqliteDB) -> int: | ||
t = Table("embeddings_queue") | ||
q = sqlite.querybuilder().from_(t) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -30,6 +30,7 @@ | |
import tempfile | ||
from chromadb.api.client import Client as ClientCreator | ||
from chromadb.utils.embedding_functions import DefaultEmbeddingFunction | ||
import numpy as np | ||
|
||
CreatePersistAPI = Callable[[], ServerAPI] | ||
|
||
|
@@ -308,6 +309,55 @@ def test_persist_embeddings_state( | |
) # type: ignore | ||
|
||
|
||
def test_delete_less_than_k( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. todo: rename and clean up |
||
caplog: pytest.LogCaptureFixture, settings: Settings | ||
) -> None: | ||
client = chromadb.Client(settings) | ||
state = PersistEmbeddingsStateMachine(settings=settings, client=client) | ||
state.initialize( | ||
collection=strategies.Collection( | ||
name="A00", | ||
metadata={ | ||
"hnsw:construction_ef": 128, | ||
"hnsw:search_ef": 128, | ||
"hnsw:M": 128, | ||
"hnsw:sync_threshold": 3, | ||
"hnsw:batch_size": 3, | ||
}, | ||
embedding_function=None, | ||
id=UUID("2d3eddc7-2314-45f4-a951-47a9a8e099d2"), | ||
dimension=2, | ||
dtype=np.float16, | ||
known_metadata_keys={}, | ||
known_document_keywords=[], | ||
has_documents=False, | ||
has_embeddings=True, | ||
) | ||
) | ||
state.ann_accuracy() | ||
state.count() | ||
state.fields_match() | ||
state.log_size_below_max() | ||
state.no_duplicates() | ||
(embedding_ids_0,) = state.add_embeddings(record_set={"ids": ["0"], "embeddings": [[0.09765625, 0.430419921875]], "metadatas": [None], "documents": None}) # type: ignore | ||
state.ann_accuracy() | ||
# recall: 1.0, missing 0 out of 1, accuracy threshold 1e-06 | ||
state.count() | ||
state.fields_match() | ||
state.log_size_below_max() | ||
state.no_duplicates() | ||
embedding_ids_1, embedding_ids_2 = state.add_embeddings(record_set={"ids": ["1", "2"], "embeddings": [[0.20556640625, 0.08978271484375], [-0.1527099609375, 0.291748046875]], "metadatas": [None, None], "documents": None}) # type: ignore | ||
state.ann_accuracy() | ||
# recall: 1.0, missing 0 out of 3, accuracy threshold 1e-06 | ||
state.count() | ||
state.fields_match() | ||
state.log_size_below_max() | ||
state.no_duplicates() | ||
state.delete_by_ids(ids=[embedding_ids_2]) | ||
state.ann_accuracy() | ||
state.teardown() | ||
|
||
|
||
# Ideally this scenario would be exercised by Hypothesis, but most runs don't seem to trigger this particular state. | ||
def test_delete_add_after_persist(settings: Settings) -> None: | ||
client = chromadb.Client(settings) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -205,7 +205,7 @@ fn merge_results( | |
let mut brute_force_index = 0; | ||
|
||
// TODO: This doesn't have to clone the user IDs, but it's easier for now | ||
while (result_user_ids.len() <= k) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. eww |
||
while (result_user_ids.len() < k) | ||
&& (hnsw_index < hnsw_result_user_ids.len() | ||
|| brute_force_index < brute_force_result_user_ids.len()) | ||
{ | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
had to bring this back out for typechecker :(