From d17472ed8aad90e6e4dee8a668424d0ce8b89f10 Mon Sep 17 00:00:00 2001 From: Drew Kim Date: Tue, 17 Sep 2024 11:46:23 -0700 Subject: [PATCH] [BUG] Test that query result shapes are correct in invariants (#2807) ## Description of changes *Summarize the changes made by this PR.* - Improvements & Bug fixes - Discovered that there is an issue with `PersistentLocalHnswSegment` where `query_vectors` can return results with lengths different from `n_results`. This PR implements a test that both demonstrates and catches this issue. - Fixes an off-by-one issue in the hnsw/BF merge logic in both the single-node and distributed implementations ## Test plan *How are these changes tested?* - Added a test for the breaking case; it fails on main and passes with this fix - [x] Tests pass locally with `pytest` for python, `yarn test` for js, `cargo test` for rust --------- Co-authored-by: hammadb --- .../impl/vector/local_persistent_hnsw.py | 4 +- chromadb/test/property/invariants.py | 18 +++++-- chromadb/test/property/test_persist.py | 50 +++++++++++++++++++ .../execution/operators/merge_knn_results.rs | 2 +- 4 files changed, 69 insertions(+), 5 deletions(-) diff --git a/chromadb/segment/impl/vector/local_persistent_hnsw.py b/chromadb/segment/impl/vector/local_persistent_hnsw.py index aeae534c796..9741da0fe94 100644 --- a/chromadb/segment/impl/vector/local_persistent_hnsw.py +++ b/chromadb/segment/impl/vector/local_persistent_hnsw.py @@ -413,6 +413,8 @@ def query_vectors( # Overquery by updated and deleted elements layered on the index because they may # hide the real nearest neighbors in the hnsw index hnsw_k = k + self._curr_batch.update_count + self._curr_batch.delete_count + # self._id_to_label contains the ids of the elements in the hnsw index + # so its length is the number of elements in the hnsw index if hnsw_k > len(self._id_to_label): hnsw_k = len(self._id_to_label) hnsw_query = VectorQuery( @@ -472,7 +474,7 @@ def query_vectors( if remaining > 0 and hnsw_pointer < len(curr_hnsw_result): for i in range( 
hnsw_pointer, - min(len(curr_hnsw_result), hnsw_pointer + remaining + 1), + min(len(curr_hnsw_result), hnsw_pointer + remaining), ): id = curr_hnsw_result[i]["id"] if not self._brute_force_index.has_id(id): diff --git a/chromadb/test/property/invariants.py b/chromadb/test/property/invariants.py index 51c4d52016f..d7adb8ff673 100644 --- a/chromadb/test/property/invariants.py +++ b/chromadb/test/property/invariants.py @@ -4,9 +4,7 @@ from chromadb.db.base import get_sql from chromadb.db.impl.sqlite import SqliteDB from time import sleep - import psutil - from chromadb.test.property.strategies import NormalizedRecordSet, RecordSet from typing import Callable, Optional, Tuple, Union, List, TypeVar, cast from typing_extensions import Literal @@ -261,10 +259,14 @@ def ann_accuracy( include=["embeddings", "documents", "metadatas", "distances"], # type: ignore[list-item] ) + _query_results_are_correct_shape(query_results, n_results) + + # Assert fields are not None for type checking + assert query_results["ids"] is not None assert query_results["distances"] is not None + assert query_results["embeddings"] is not None assert query_results["documents"] is not None assert query_results["metadatas"] is not None - assert query_results["embeddings"] is not None # Dict of ids to indices id_to_index = {id: i for i, id in enumerate(normalized_record_set["ids"])} @@ -324,6 +326,16 @@ def ann_accuracy( assert np.allclose(np.sort(distance_result), distance_result) +def _query_results_are_correct_shape( + query_results: types.QueryResult, n_results: int +) -> None: + for result_type in ["distances", "embeddings", "documents", "metadatas"]: + assert query_results[result_type] is not None # type: ignore[literal-required] + assert all( + len(result) == n_results for result in query_results[result_type] # type: ignore[literal-required] + ) + + def _total_embedding_queue_log_size(sqlite: SqliteDB) -> int: t = Table("embeddings_queue") q = sqlite.querybuilder().from_(t) diff --git 
a/chromadb/test/property/test_persist.py b/chromadb/test/property/test_persist.py index 267692ce485..92f65b27714 100644 --- a/chromadb/test/property/test_persist.py +++ b/chromadb/test/property/test_persist.py @@ -30,6 +30,7 @@ import tempfile from chromadb.api.client import Client as ClientCreator from chromadb.utils.embedding_functions import DefaultEmbeddingFunction +import numpy as np CreatePersistAPI = Callable[[], ServerAPI] @@ -308,6 +309,55 @@ def test_persist_embeddings_state( ) # type: ignore +def test_delete_less_than_k( + caplog: pytest.LogCaptureFixture, settings: Settings +) -> None: + client = chromadb.Client(settings) + state = PersistEmbeddingsStateMachine(settings=settings, client=client) + state.initialize( + collection=strategies.Collection( + name="A00", + metadata={ + "hnsw:construction_ef": 128, + "hnsw:search_ef": 128, + "hnsw:M": 128, + "hnsw:sync_threshold": 3, + "hnsw:batch_size": 3, + }, + embedding_function=None, + id=UUID("2d3eddc7-2314-45f4-a951-47a9a8e099d2"), + dimension=2, + dtype=np.float16, + known_metadata_keys={}, + known_document_keywords=[], + has_documents=False, + has_embeddings=True, + ) + ) + state.ann_accuracy() + state.count() + state.fields_match() + state.log_size_below_max() + state.no_duplicates() + (embedding_ids_0,) = state.add_embeddings(record_set={"ids": ["0"], "embeddings": [[0.09765625, 0.430419921875]], "metadatas": [None], "documents": None}) # type: ignore + state.ann_accuracy() + # recall: 1.0, missing 0 out of 1, accuracy threshold 1e-06 + state.count() + state.fields_match() + state.log_size_below_max() + state.no_duplicates() + embedding_ids_1, embedding_ids_2 = state.add_embeddings(record_set={"ids": ["1", "2"], "embeddings": [[0.20556640625, 0.08978271484375], [-0.1527099609375, 0.291748046875]], "metadatas": [None, None], "documents": None}) # type: ignore + state.ann_accuracy() + # recall: 1.0, missing 0 out of 3, accuracy threshold 1e-06 + state.count() + state.fields_match() + 
state.log_size_below_max() + state.no_duplicates() + state.delete_by_ids(ids=[embedding_ids_2]) + state.ann_accuracy() + state.teardown() + + # Ideally this scenario would be exercised by Hypothesis, but most runs don't seem to trigger this particular state. def test_delete_add_after_persist(settings: Settings) -> None: client = chromadb.Client(settings) diff --git a/rust/worker/src/execution/operators/merge_knn_results.rs b/rust/worker/src/execution/operators/merge_knn_results.rs index 27e7a587d6d..b328de9530c 100644 --- a/rust/worker/src/execution/operators/merge_knn_results.rs +++ b/rust/worker/src/execution/operators/merge_knn_results.rs @@ -205,7 +205,7 @@ fn merge_results( let mut brute_force_index = 0; // TODO: This doesn't have to clone the user IDs, but it's easier for now - while (result_user_ids.len() <= k) + while (result_user_ids.len() < k) && (hnsw_index < hnsw_result_user_ids.len() || brute_force_index < brute_force_result_user_ids.len()) {