[ENH] simplify logic for when to persist index changes (re-apply with fix) (#2545)

History:

- #2539 originally introduced some of the changes here
- it introduced a bug not caught by our test suite due to limitations of our suite at the time
- changes were reverted with #2544
- and now this PR adds a test that [was previously reproducing the bug](https://github.com/chroma-core/chroma/actions/runs/10012532835/job/27678256604?pr=2545) before it was fixed here

Should not be merged until the hnswlib version is bumped (a bug around persisting a single item was recently fixed).
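
For orientation, the heart of the change is visible in local_persistent_hnsw.py below: the old code compared three running totals (added/updated/invalid) against their last-persisted snapshots, while the new code keeps two plain counters that are bumped once per log record and reset when a batch is applied or the index is persisted. A minimal sketch of the new trigger flow (the class name and bare-bones structure here are illustrative stand-ins, not the real segment implementation):

```python
class PersistTriggerSketch:
    """Illustrative stand-in for PersistentLocalHnswSegment's counter logic."""

    def __init__(self, batch_size: int, sync_threshold: int) -> None:
        self._batch_size = batch_size
        self._sync_threshold = sync_threshold
        self._num_log_records_since_last_batch = 0
        self._num_log_records_since_last_persist = 0

    def write_record(self) -> None:
        # Every log record counts toward both thresholds, including invalid
        # operations (previously tracked via separate counters).
        self._num_log_records_since_last_batch += 1
        self._num_log_records_since_last_persist += 1
        if self._num_log_records_since_last_batch >= self._batch_size:
            self._apply_batch()

    def _apply_batch(self) -> None:
        # ... apply the in-memory batch to the HNSW index ...
        if self._num_log_records_since_last_persist >= self._sync_threshold:
            self._persist()
        self._num_log_records_since_last_batch = 0

    def _persist(self) -> None:
        # ... write the index and metadata to disk ...
        self._num_log_records_since_last_persist = 0
```
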
codetheweb committed Jul 29, 2024
1 parent 48da264 commit 8907e53
Showing 6 changed files with 113 additions and 56 deletions.
3 changes: 2 additions & 1 deletion .github/workflows/_python-tests.yml
@@ -30,7 +30,8 @@ jobs:
"chromadb/test/property/test_cross_version_persist.py",
"chromadb/test/property/test_embeddings.py",
"chromadb/test/property/test_filtering.py",
-"chromadb/test/property/test_persist.py"]
+"chromadb/test/property/test_persist.py",
+"chromadb/test/property/test_restart_persist.py"]
include:
  - test-globs: "chromadb/test/property/test_embeddings.py"
    parallelized: true
8 changes: 0 additions & 8 deletions chromadb/segment/impl/vector/local_hnsw.py
@@ -43,8 +43,6 @@ class LocalHnswSegment(VectorReader):
    _index: Optional[hnswlib.Index]
    _dimensionality: Optional[int]
    _total_elements_added: int
-   _total_elements_updated: int
-   _total_invalid_operations: int
    _max_seq_id: SeqId

    _lock: ReadWriteLock

@@ -68,8 +66,6 @@ def __init__(self, system: System, segment: Segment):
        self._index = None
        self._dimensionality = None
        self._total_elements_added = 0
-       self._total_elements_updated = 0
-       self._total_invalid_operations = 0
        self._max_seq_id = self._consumer.min_seqid()

        self._id_to_seq_id = {}

@@ -279,7 +275,6 @@ def _apply_batch(self, batch: Batch) -> None:

        # If that succeeds, update the total count
        self._total_elements_added += batch.add_count
-       self._total_elements_updated += batch.update_count

        # If that succeeds, finally the seq ID
        self._max_seq_id = batch.max_seq_id

@@ -305,7 +300,6 @@ def _write_records(self, records: Sequence[LogRecord]) -> None:
                        batch.apply(record)
                    else:
                        logger.warning(f"Delete of nonexisting embedding ID: {id}")
-                       self._total_invalid_operations += 1

                elif op == Operation.UPDATE:
                    if record["record"]["embedding"] is not None:

@@ -315,13 +309,11 @@ def _write_records(self, records: Sequence[LogRecord]) -> None:
                        logger.warning(
                            f"Update of nonexisting embedding ID: {record['record']['id']}"
                        )
-                       self._total_invalid_operations += 1
                elif op == Operation.ADD:
                    if not label:
                        batch.apply(record, False)
                    else:
                        logger.warning(f"Add of existing embedding ID: {id}")
-                       self._total_invalid_operations += 1
                elif op == Operation.UPSERT:
                    batch.apply(record, label is not None)
56 changes: 13 additions & 43 deletions chromadb/segment/impl/vector/local_persistent_hnsw.py
@@ -2,7 +2,7 @@
import shutil
from overrides import override
import pickle
-from typing import Any, Dict, List, Optional, Sequence, Set, cast
+from typing import Dict, List, Optional, Sequence, Set, cast
from chromadb.config import System
from chromadb.segment.impl.vector.batch import Batch
from chromadb.segment.impl.vector.hnsw_params import PersistentHnswParams
@@ -41,8 +41,6 @@ class PersistentData:

    dimensionality: Optional[int]
    total_elements_added: int
-   total_elements_updated: int
-   total_invalid_operations: int
    max_seq_id: SeqId

    id_to_label: Dict[str, int]
@@ -53,28 +51,18 @@ def __init__(
        self,
        dimensionality: Optional[int],
        total_elements_added: int,
-       total_elements_updated: int,
-       total_invalid_operations: int,
        max_seq_id: int,
        id_to_label: Dict[str, int],
        label_to_id: Dict[int, str],
        id_to_seq_id: Dict[str, SeqId],
    ):
        self.dimensionality = dimensionality
        self.total_elements_added = total_elements_added
-       self.total_elements_updated = total_elements_updated
-       self.total_invalid_operations = total_invalid_operations
        self.max_seq_id = max_seq_id
        self.id_to_label = id_to_label
        self.label_to_id = label_to_id
        self.id_to_seq_id = id_to_seq_id

-   def __setstate__(self, state: Any) -> None:
-       # Fields were added after the initial implementation
-       self.total_elements_updated = 0
-       self.total_invalid_operations = 0
-       self.__dict__.update(state)
-
    @staticmethod
    def load_from_file(filename: str) -> "PersistentData":
        """Load persistent data from a file"""
@@ -100,6 +88,9 @@ class PersistentLocalHnswSegment(LocalHnswSegment):

    _opentelemtry_client: OpenTelemetryClient

+   _num_log_records_since_last_batch: int = 0
+   _num_log_records_since_last_persist: int = 0
+
    def __init__(self, system: System, segment: Segment):
        super().__init__(system, segment)

@@ -133,8 +124,6 @@ def __init__(self, system: System, segment: Segment):
            self._persist_data = PersistentData(
                self._dimensionality,
                self._total_elements_added,
-               self._total_elements_updated,
-               self._total_invalid_operations,
                self._max_seq_id,
                self._id_to_label,
                self._label_to_id,
@@ -209,7 +198,6 @@ def _persist(self) -> None:
        # Persist the metadata
        self._persist_data.dimensionality = self._dimensionality
        self._persist_data.total_elements_added = self._total_elements_added
-       self._persist_data.total_elements_updated = self._total_elements_updated
        self._persist_data.max_seq_id = self._max_seq_id

        # TODO: This should really be stored in sqlite, the index itself, or a better

@@ -221,30 +209,19 @@ def _persist(self) -> None:
        with open(self._get_metadata_file(), "wb") as metadata_file:
            pickle.dump(self._persist_data, metadata_file, pickle.HIGHEST_PROTOCOL)

+       self._num_log_records_since_last_persist = 0
+
    @trace_method(
        "PersistentLocalHnswSegment._apply_batch", OpenTelemetryGranularity.ALL
    )
    @override
    def _apply_batch(self, batch: Batch) -> None:
        super()._apply_batch(batch)
-       num_elements_added_since_last_persist = (
-           self._total_elements_added - self._persist_data.total_elements_added
-       )
-       num_elements_updated_since_last_persist = (
-           self._total_elements_updated - self._persist_data.total_elements_updated
-       )
-       num_invalid_operations_since_last_persist = (
-           self._total_invalid_operations - self._persist_data.total_invalid_operations
-       )

-       if (
-           num_elements_added_since_last_persist
-           + num_elements_updated_since_last_persist
-           + num_invalid_operations_since_last_persist
-           >= self._sync_threshold
-       ):
+       if self._num_log_records_since_last_persist >= self._sync_threshold:
            self._persist()

+       self._num_log_records_since_last_batch = 0
+
    @trace_method(
        "PersistentLocalHnswSegment._write_records", OpenTelemetryGranularity.ALL
    )
@@ -255,6 +232,9 @@ def _write_records(self, records: Sequence[LogRecord]) -> None:
            raise RuntimeError("Cannot add embeddings to stopped component")
        with WriteRWLock(self._lock):
            for record in records:
+               self._num_log_records_since_last_batch += 1
+               self._num_log_records_since_last_persist += 1
+
                if record["record"]["embedding"] is not None:
                    self._ensure_index(len(records), len(record["record"]["embedding"]))
                if not self._index_initialized:

@@ -291,12 +271,10 @@ def _write_records(self, records: Sequence[LogRecord]) -> None:
                        logger.warning(
                            f"Update of nonexisting embedding ID: {record['record']['id']}"
                        )
-                       self._total_invalid_operations += 1
                elif op == Operation.ADD:
                    if record["record"]["embedding"] is not None:
                        if exists_in_index and not id_is_pending_delete:
                            logger.warning(f"Add of existing embedding ID: {id}")
-                           self._total_invalid_operations += 1
                        else:
                            self._curr_batch.apply(record, not exists_in_index)
                            self._brute_force_index.upsert([record])

@@ -305,15 +283,7 @@ def _write_records(self, records: Sequence[LogRecord]) -> None:
                        self._curr_batch.apply(record, exists_in_index)
                        self._brute_force_index.upsert([record])

-               num_invalid_operations_since_last_persist = (
-                   self._total_invalid_operations
-                   - self._persist_data.total_invalid_operations
-               )
-
-               if (
-                   len(self._curr_batch) + num_invalid_operations_since_last_persist
-                   >= self._batch_size
-               ):
+               if self._num_log_records_since_last_batch >= self._batch_size:
                    self._apply_batch(self._curr_batch)
                    self._curr_batch = Batch()
                    self._brute_force_index.clear()
17 changes: 14 additions & 3 deletions chromadb/test/conftest.py
@@ -559,7 +559,7 @@ def sqlite() -> Generator[System, None, None]:
    system.stop()


-def sqlite_persistent() -> Generator[System, None, None]:
+def sqlite_persistent_fixture() -> Generator[System, None, None]:
    """Fixture generator for segment-based API using persistent Sqlite"""
    save_path = tempfile.TemporaryDirectory()
    settings = Settings(

@@ -590,8 +590,19 @@ def sqlite_persistent() -> Generator[System, None, None]:
        raise e


+@pytest.fixture(scope="module")
+def sqlite_persistent() -> Generator[System, None, None]:
+    yield from sqlite_persistent_fixture()
+
+
def system_fixtures() -> List[Callable[[], Generator[System, None, None]]]:
-   fixtures = [fastapi, async_fastapi, fastapi_persistent, sqlite, sqlite_persistent]
+   fixtures = [
+       fastapi,
+       async_fastapi,
+       fastapi_persistent,
+       sqlite,
+       sqlite_persistent_fixture,
+   ]
    if "CHROMA_INTEGRATION_TEST" in os.environ:
        fixtures.append(integration)
    if "CHROMA_INTEGRATION_TEST_ONLY" in os.environ:

@@ -605,7 +616,7 @@ def system_http_server_fixtures() -> List[Callable[[], Generator[System, None, None]]]:
    fixtures = [
        fixture
        for fixture in system_fixtures()
-       if fixture != sqlite and fixture != sqlite_persistent
+       if fixture != sqlite and fixture != sqlite_persistent_fixture
    ]
    return fixtures
2 changes: 1 addition & 1 deletion chromadb/test/property/invariants.py
@@ -304,7 +304,7 @@ def ann_accuracy(

    try:
        note(
-           f"recall: {recall}, missing {missing} out of {size}, accuracy threshold {accuracy_threshold}"
+           f"# recall: {recall}, missing {missing} out of {size}, accuracy threshold {accuracy_threshold}"
        )
    except InvalidArgument:
        pass  # it's ok if we're running outside hypothesis
83 changes: 83 additions & 0 deletions chromadb/test/property/test_restart_persist.py
@@ -0,0 +1,83 @@
from overrides import overrides
from chromadb.api.client import Client
from chromadb.config import System
import hypothesis.strategies as st
from hypothesis.stateful import (
    rule,
    run_state_machine_as_test,
    initialize,
)

from chromadb.test.property.test_embeddings import (
    EmbeddingStateMachineBase,
    EmbeddingStateMachineStates,
    trace,
)
import chromadb.test.property.strategies as strategies


collection_persistent_st = st.shared(
    strategies.collections(
        with_hnsw_params=True,
        with_persistent_hnsw_params=st.just(True),
        # Makes it more likely to find persist-related bugs (by default these are set to 2000).
        max_hnsw_batch_size=10,
        max_hnsw_sync_threshold=10,
    ),
    key="coll_persistent",
)


# This machine shares a lot of similarity with the machine in chromadb/test/property/test_persist.py.
# However, test_persist.py tests correctness under complete process isolation, and therefore can only
# check invariants on a freshly-created system. This machine does not have full process isolation
# between systems/clients, but after a restart it continues to exercise the state machine with the
# newly-created system.
class RestartablePersistedEmbeddingStateMachine(EmbeddingStateMachineBase):
    system: System

    def __init__(self, system: System) -> None:
        self.system = system
        client = Client.from_system(system)
        super().__init__(client)

    @initialize(collection=collection_persistent_st)  # type: ignore
    @overrides
    def initialize(self, collection: strategies.Collection):
        self.client.reset()

        self.collection = self.client.create_collection(
            name=collection.name,
            metadata=collection.metadata,  # type: ignore
            embedding_function=collection.embedding_function,
        )
        self.embedding_function = collection.embedding_function
        trace("init")
        self.on_state_change(EmbeddingStateMachineStates.initialize)

        self.record_set_state = strategies.StateMachineRecordSet(
            ids=[], metadatas=[], documents=[], embeddings=[]
        )

    @rule()
    def restart_system(self) -> None:
        # Simulates restarting chromadb
        self.system.stop()
        self.system = System(self.system.settings)
        self.system.start()
        self.client.clear_system_cache()
        self.client = Client.from_system(self.system)
        self.collection = self.client.get_collection(
            self.collection.name, embedding_function=self.embedding_function
        )

    @overrides
    def teardown(self) -> None:
        super().teardown()
        # Need to manually stop the system to clean up resources, because we may have created
        # a new system in the restart_system rule above. Normally we wouldn't have to worry
        # about this, as the system from the fixture is shared between state machine runs.
        # (This helps avoid a "too many open files" error.)
        self.system.stop()


def test_restart_persisted_client(sqlite_persistent: System) -> None:
    run_state_machine_as_test(
        lambda: RestartablePersistedEmbeddingStateMachine(sqlite_persistent),
    )  # type: ignore
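
As background for the test above: hypothesis's stateful testing drives a state machine by invoking its @rule methods in randomized sequences, which is how restart_system gets interleaved with ordinary embedding operations. A tiny self-contained sketch of the same pattern (the counter machine is a made-up example, unrelated to chromadb):

```python
import hypothesis.strategies as st
from hypothesis.stateful import (
    RuleBasedStateMachine,
    initialize,
    rule,
    run_state_machine_as_test,
)


class CounterMachine(RuleBasedStateMachine):
    """Hypothesis interleaves `increment` and `restart` in random orders,
    checking model/actual agreement after every step."""

    @initialize()
    def setup(self) -> None:
        self.model = 0
        self.actual = []

    @rule(n=st.integers(min_value=0, max_value=10))
    def increment(self, n: int) -> None:
        self.model += n
        self.actual.append(n)
        assert sum(self.actual) == self.model

    @rule()
    def restart(self) -> None:
        # Analogous to restart_system: rebuild state from "persisted" data,
        # then keep exercising the machine against the same model.
        self.actual = [self.model]
        assert sum(self.actual) == self.model


def test_counter_machine() -> None:
    run_state_machine_as_test(CounterMachine)
```

run_state_machine_as_test accepts any zero-argument callable that builds a fresh machine, which is why the chromadb test wraps construction in a lambda closing over the shared sqlite_persistent fixture.
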
