From 8907e538627ecdb1475937342ac9416c3c135dfb Mon Sep 17 00:00:00 2001 From: Max Isom Date: Mon, 29 Jul 2024 14:11:12 -0700 Subject: [PATCH] [ENH] simplify logic for when to persist index changes (re-apply with fix) (#2545) History: - https://github.com/chroma-core/chroma/pull/2539 originally introduced some of the changes here - it introduced a bug not caught by our test suite due to limitations of our suite at the time - changes were reverted with https://github.com/chroma-core/chroma/pull/2544 - and now this PR adds a test that [was previously re-producing the bug](https://github.com/chroma-core/chroma/actions/runs/10012532835/job/27678256604?pr=2545) before it was fixed here Should not be merged until hnswlib version is bumped (bug around persisting a single item was recently fixed). --- .github/workflows/_python-tests.yml | 3 +- chromadb/segment/impl/vector/local_hnsw.py | 8 -- .../impl/vector/local_persistent_hnsw.py | 56 +++---------- chromadb/test/conftest.py | 17 +++- chromadb/test/property/invariants.py | 2 +- .../test/property/test_restart_persist.py | 83 +++++++++++++++++++ 6 files changed, 113 insertions(+), 56 deletions(-) create mode 100644 chromadb/test/property/test_restart_persist.py diff --git a/.github/workflows/_python-tests.yml b/.github/workflows/_python-tests.yml index c2ddd4973c7..f78c8f1ab92 100644 --- a/.github/workflows/_python-tests.yml +++ b/.github/workflows/_python-tests.yml @@ -30,7 +30,8 @@ jobs: "chromadb/test/property/test_cross_version_persist.py", "chromadb/test/property/test_embeddings.py", "chromadb/test/property/test_filtering.py", - "chromadb/test/property/test_persist.py"] + "chromadb/test/property/test_persist.py", + "chromadb/test/property/test_restart_persist.py"] include: - test-globs: "chromadb/test/property/test_embeddings.py" parallelized: true diff --git a/chromadb/segment/impl/vector/local_hnsw.py b/chromadb/segment/impl/vector/local_hnsw.py index 215f04dcf7d..7358270c7d6 100644 --- a/chromadb/segment/impl/vector/local_hnsw.py +++ b/chromadb/segment/impl/vector/local_hnsw.py @@ -43,8 +43,6 @@ class LocalHnswSegment(VectorReader): _index: Optional[hnswlib.Index] _dimensionality: Optional[int] _total_elements_added: int - _total_elements_updated: int - _total_invalid_operations: int _max_seq_id: SeqId _lock: ReadWriteLock @@ -68,8 +66,6 @@ def __init__(self, system: System, segment: Segment): self._index = None self._dimensionality = None self._total_elements_added = 0 - self._total_elements_updated = 0 - self._total_invalid_operations = 0 self._max_seq_id = self._consumer.min_seqid() self._id_to_seq_id = {} @@ -279,7 +275,6 @@ def _apply_batch(self, batch: Batch) -> None: # If that succeeds, update the total count self._total_elements_added += batch.add_count - self._total_elements_updated += batch.update_count # If that succeeds, finally the seq ID self._max_seq_id = batch.max_seq_id @@ -305,7 +300,6 @@ def _write_records(self, records: Sequence[LogRecord]) -> None: batch.apply(record) else: logger.warning(f"Delete of nonexisting embedding ID: {id}") - self._total_invalid_operations += 1 elif op == Operation.UPDATE: if record["record"]["embedding"] is not None: @@ -315,13 +309,11 @@ def _write_records(self, records: Sequence[LogRecord]) -> None: logger.warning( f"Update of nonexisting embedding ID: {record['record']['id']}" ) - self._total_invalid_operations += 1 elif op == Operation.ADD: if not label: batch.apply(record, False) else: logger.warning(f"Add of existing embedding ID: {id}") - self._total_invalid_operations += 1 elif op == Operation.UPSERT: batch.apply(record, label is not None) diff --git a/chromadb/segment/impl/vector/local_persistent_hnsw.py b/chromadb/segment/impl/vector/local_persistent_hnsw.py index b8406c1180b..59d42c34c95 100644 --- a/chromadb/segment/impl/vector/local_persistent_hnsw.py +++ b/chromadb/segment/impl/vector/local_persistent_hnsw.py @@ -2,7 +2,7 @@ import shutil from overrides import override import pickle -from typing import Any, Dict, List, Optional, Sequence, Set, cast +from typing import Dict, List, Optional, Sequence, Set, cast from chromadb.config import System from chromadb.segment.impl.vector.batch import Batch from chromadb.segment.impl.vector.hnsw_params import PersistentHnswParams @@ -41,8 +41,6 @@ class PersistentData: dimensionality: Optional[int] total_elements_added: int - total_elements_updated: int - total_invalid_operations: int max_seq_id: SeqId id_to_label: Dict[str, int] @@ -53,8 +51,6 @@ def __init__( self, dimensionality: Optional[int], total_elements_added: int, - total_elements_updated: int, - total_invalid_operations: int, max_seq_id: int, id_to_label: Dict[str, int], label_to_id: Dict[int, str], @@ -62,19 +58,11 @@ def __init__( ): self.dimensionality = dimensionality self.total_elements_added = total_elements_added - self.total_elements_updated = total_elements_updated - self.total_invalid_operations = total_invalid_operations self.max_seq_id = max_seq_id self.id_to_label = id_to_label self.label_to_id = label_to_id self.id_to_seq_id = id_to_seq_id - def __setstate__(self, state: Any) -> None: - # Fields were added after the initial implementation - self.total_elements_updated = 0 - self.total_invalid_operations = 0 - self.__dict__.update(state) - @staticmethod def load_from_file(filename: str) -> "PersistentData": """Load persistent data from a file""" @@ -100,6 +88,9 @@ class PersistentLocalHnswSegment(LocalHnswSegment): _opentelemtry_client: OpenTelemetryClient + _num_log_records_since_last_batch: int = 0 + _num_log_records_since_last_persist: int = 0 + def __init__(self, system: System, segment: Segment): super().__init__(system, segment) @@ -133,8 +124,6 @@ def __init__(self, system: System, segment: Segment): self._persist_data = PersistentData( self._dimensionality, self._total_elements_added, - self._total_elements_updated, - self._total_invalid_operations, self._max_seq_id, self._id_to_label, self._label_to_id, @@ -209,7 +198,6 @@ def _persist(self) -> None: # Persist the metadata self._persist_data.dimensionality = self._dimensionality self._persist_data.total_elements_added = self._total_elements_added - self._persist_data.total_elements_updated = self._total_elements_updated self._persist_data.max_seq_id = self._max_seq_id # TODO: This should really be stored in sqlite, the index itself, or a better @@ -221,30 +209,19 @@ def _persist(self) -> None: with open(self._get_metadata_file(), "wb") as metadata_file: pickle.dump(self._persist_data, metadata_file, pickle.HIGHEST_PROTOCOL) + self._num_log_records_since_last_persist = 0 + @trace_method( "PersistentLocalHnswSegment._apply_batch", OpenTelemetryGranularity.ALL ) @override def _apply_batch(self, batch: Batch) -> None: super()._apply_batch(batch) - num_elements_added_since_last_persist = ( - self._total_elements_added - self._persist_data.total_elements_added - ) - num_elements_updated_since_last_persist = ( - self._total_elements_updated - self._persist_data.total_elements_updated - ) - num_invalid_operations_since_last_persist = ( - self._total_invalid_operations - self._persist_data.total_invalid_operations - ) - - if ( - num_elements_added_since_last_persist - + num_elements_updated_since_last_persist - + num_invalid_operations_since_last_persist - >= self._sync_threshold - ): + if self._num_log_records_since_last_persist >= self._sync_threshold: self._persist() + self._num_log_records_since_last_batch = 0 + @trace_method( "PersistentLocalHnswSegment._write_records", OpenTelemetryGranularity.ALL ) @@ -255,6 +232,9 @@ def _write_records(self, records: Sequence[LogRecord]) -> None: raise RuntimeError("Cannot add embeddings to stopped component") with WriteRWLock(self._lock): for record in records: + self._num_log_records_since_last_batch += 1 + self._num_log_records_since_last_persist += 1 + if record["record"]["embedding"] is not None: self._ensure_index(len(records), len(record["record"]["embedding"])) if not self._index_initialized: @@ -291,12 +271,10 @@ def _write_records(self, records: Sequence[LogRecord]) -> None: logger.warning( f"Update of nonexisting embedding ID: {record['record']['id']}" ) - self._total_invalid_operations += 1 elif op == Operation.ADD: if record["record"]["embedding"] is not None: if exists_in_index and not id_is_pending_delete: logger.warning(f"Add of existing embedding ID: {id}") - self._total_invalid_operations += 1 else: self._curr_batch.apply(record, not exists_in_index) self._brute_force_index.upsert([record]) @@ -305,15 +283,7 @@ def _write_records(self, records: Sequence[LogRecord]) -> None: self._curr_batch.apply(record, exists_in_index) self._brute_force_index.upsert([record]) - num_invalid_operations_since_last_persist = ( - self._total_invalid_operations - - self._persist_data.total_invalid_operations - ) - - if ( - len(self._curr_batch) + num_invalid_operations_since_last_persist - >= self._batch_size - ): + if self._num_log_records_since_last_batch >= self._batch_size: self._apply_batch(self._curr_batch) self._curr_batch = Batch() self._brute_force_index.clear() diff --git a/chromadb/test/conftest.py b/chromadb/test/conftest.py index 8f832747fab..b4d5d864105 100644 --- a/chromadb/test/conftest.py +++ b/chromadb/test/conftest.py @@ -559,7 +559,7 @@ def sqlite() -> Generator[System, None, None]: system.stop() -def sqlite_persistent() -> Generator[System, None, None]: +def sqlite_persistent_fixture() -> Generator[System, None, None]: """Fixture generator for segment-based API using persistent Sqlite""" save_path = tempfile.TemporaryDirectory() settings = Settings( @@ -590,8 +590,19 @@ def sqlite_persistent() -> Generator[System, None, None]: raise e +@pytest.fixture(scope="module") +def sqlite_persistent() -> Generator[System, None, None]: + yield from sqlite_persistent_fixture() + + def system_fixtures() -> List[Callable[[], Generator[System, None, None]]]: - fixtures = [fastapi, async_fastapi, fastapi_persistent, sqlite, sqlite_persistent] + fixtures = [ + fastapi, + async_fastapi, + fastapi_persistent, + sqlite, + sqlite_persistent_fixture, + ] if "CHROMA_INTEGRATION_TEST" in os.environ: fixtures.append(integration) if "CHROMA_INTEGRATION_TEST_ONLY" in os.environ: @@ -605,7 +616,7 @@ def system_http_server_fixtures() -> List[Callable[[], Generator[System, None, N fixtures = [ fixture for fixture in system_fixtures() - if fixture != sqlite and fixture != sqlite_persistent + if fixture != sqlite and fixture != sqlite_persistent_fixture ] return fixtures diff --git a/chromadb/test/property/invariants.py b/chromadb/test/property/invariants.py index a7bd2c6cade..d53fb0b5b4e 100644 --- a/chromadb/test/property/invariants.py +++ b/chromadb/test/property/invariants.py @@ -304,7 +304,7 @@ def ann_accuracy( try: note( - f"recall: {recall}, missing {missing} out of {size}, accuracy threshold {accuracy_threshold}" + f"# recall: {recall}, missing {missing} out of {size}, accuracy threshold {accuracy_threshold}" ) except InvalidArgument: pass # it's ok if we're running outside hypothesis diff --git a/chromadb/test/property/test_restart_persist.py b/chromadb/test/property/test_restart_persist.py new file mode 100644 index 00000000000..0397387d013 --- /dev/null +++ b/chromadb/test/property/test_restart_persist.py @@ -0,0 +1,83 @@ +from overrides import overrides +from chromadb.api.client import Client +from chromadb.config import System +import hypothesis.strategies as st +from hypothesis.stateful import ( + rule, + run_state_machine_as_test, + initialize, +) + +from chromadb.test.property.test_embeddings import ( + EmbeddingStateMachineBase, + EmbeddingStateMachineStates, + trace, +) +import chromadb.test.property.strategies as strategies + + +collection_persistent_st = st.shared( + strategies.collections( + with_hnsw_params=True, + with_persistent_hnsw_params=st.just(True), + # Makes it more likely to find persist-related bugs (by default these are set to 2000). + max_hnsw_batch_size=10, + max_hnsw_sync_threshold=10, + ), + key="coll_persistent", +) + + +# This machine shares a lot of similarity with the machine in chromadb/test/property/test_persist.py. +# However, test_persist.py tests correctness under complete process isolation and therefore can only check invariants on a new system--whereas this machine does not have full process isolation between systems/clients but after a restart continues to exercise the state machine with the newly-created system. +class RestartablePersistedEmbeddingStateMachine(EmbeddingStateMachineBase): + system: System + + def __init__(self, system: System) -> None: + self.system = system + client = Client.from_system(system) + super().__init__(client) + + @initialize(collection=collection_persistent_st) # type: ignore + @overrides + def initialize(self, collection: strategies.Collection): + self.client.reset() + + self.collection = self.client.create_collection( + name=collection.name, + metadata=collection.metadata, # type: ignore + embedding_function=collection.embedding_function, + ) + self.embedding_function = collection.embedding_function + trace("init") + self.on_state_change(EmbeddingStateMachineStates.initialize) + + self.record_set_state = strategies.StateMachineRecordSet( + ids=[], metadatas=[], documents=[], embeddings=[] + ) + + @rule() + def restart_system(self) -> None: + # Simulates restarting chromadb + self.system.stop() + self.system = System(self.system.settings) + self.system.start() + self.client.clear_system_cache() + self.client = Client.from_system(self.system) + self.collection = self.client.get_collection( + self.collection.name, embedding_function=self.embedding_function + ) + + @overrides + def teardown(self) -> None: + super().teardown() + # Need to manually stop the system to cleanup resources because we may have created a new system (above rule). + # Normally, we wouldn't have to worry about this as the system from the fixture is shared between state machine runs. + # (This helps avoid a "too many open files" error.) + self.system.stop() + + +def test_restart_persisted_client(sqlite_persistent: System) -> None: + run_state_machine_as_test( + lambda: RestartablePersistedEmbeddingStateMachine(sqlite_persistent), + ) # type: ignore