[ENH] simplify logic for when to persist index changes #2539
@@ -2,7 +2,7 @@
 import shutil
 from overrides import override
 import pickle
-from typing import Any, Dict, List, Optional, Sequence, Set, cast
+from typing import Dict, List, Optional, Sequence, Set, cast
 from chromadb.config import System
 from chromadb.segment.impl.vector.batch import Batch
 from chromadb.segment.impl.vector.hnsw_params import PersistentHnswParams
@@ -40,9 +40,6 @@ class PersistentData:
     """Stores the data and metadata needed for a PersistentLocalHnswSegment"""

     dimensionality: Optional[int]
-    total_elements_added: int
-    total_elements_updated: int
-    total_invalid_operations: int
     max_seq_id: SeqId

     id_to_label: Dict[str, int]
@@ -52,29 +49,17 @@ class PersistentData:
     def __init__(
         self,
         dimensionality: Optional[int],
-        total_elements_added: int,
-        total_elements_updated: int,
-        total_invalid_operations: int,
         max_seq_id: int,
         id_to_label: Dict[str, int],
         label_to_id: Dict[int, str],
         id_to_seq_id: Dict[str, SeqId],
     ):
         self.dimensionality = dimensionality
-        self.total_elements_added = total_elements_added
-        self.total_elements_updated = total_elements_updated
-        self.total_invalid_operations = total_invalid_operations
         self.max_seq_id = max_seq_id
         self.id_to_label = id_to_label
         self.label_to_id = label_to_id
         self.id_to_seq_id = id_to_seq_id

-    def __setstate__(self, state: Any) -> None:
-        # Fields were added after the initial implementation
-        self.total_elements_updated = 0
-        self.total_invalid_operations = 0
-        self.__dict__.update(state)
-
     @staticmethod
     def load_from_file(filename: str) -> "PersistentData":
         """Load persistent data from a file"""
@@ -100,6 +85,9 @@ class PersistentLocalHnswSegment(LocalHnswSegment):

     _opentelemtry_client: OpenTelemetryClient

+    _num_log_records_since_last_batch: int = 0
+    _num_log_records_since_last_persist: int = 0
+
     def __init__(self, system: System, segment: Segment):
         super().__init__(system, segment)
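Aside on the two new counters: they are declared as class-level attributes with a default of 0. A standalone sketch (hypothetical `CounterDemo` class, not Chroma code) of why that works as per-instance state: augmented assignment through `self` reads the class default and then binds an instance attribute, so each segment instance gets independent counters.

```python
# Standalone illustration, not Chroma's segment class: a class-level `int = 0`
# default acts as a per-instance counter once it is bumped through `self`,
# because `self.x += 1` reads the class attribute and writes an instance attribute.
class CounterDemo:
    _num_log_records_since_last_batch: int = 0

    def record(self) -> None:
        self._num_log_records_since_last_batch += 1


a, b = CounterDemo(), CounterDemo()
a.record()
a.record()
b.record()
assert a._num_log_records_since_last_batch == 2
assert b._num_log_records_since_last_batch == 1
assert CounterDemo._num_log_records_since_last_batch == 0  # class default unchanged
```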
@@ -120,7 +108,6 @@ def __init__(self, system: System, segment: Segment):
                 self._get_metadata_file()
             )
             self._dimensionality = self._persist_data.dimensionality
-            self._total_elements_added = self._persist_data.total_elements_added
             self._max_seq_id = self._persist_data.max_seq_id
             self._id_to_label = self._persist_data.id_to_label
             self._label_to_id = self._persist_data.label_to_id
@@ -132,9 +119,6 @@ def __init__(self, system: System, segment: Segment):
         else:
             self._persist_data = PersistentData(
                 self._dimensionality,
-                self._total_elements_added,
-                self._total_elements_updated,
-                self._total_invalid_operations,
                 self._max_seq_id,
                 self._id_to_label,
                 self._label_to_id,

Review thread on the removed `self._total_elements_added,` line:
- confirming this was not used anywhere else
- confirmed
- turns out this was the issue, I should have fully reasoned through it :/
@@ -208,8 +192,6 @@ def _persist(self) -> None:

         # Persist the metadata
         self._persist_data.dimensionality = self._dimensionality
-        self._persist_data.total_elements_added = self._total_elements_added
-        self._persist_data.total_elements_updated = self._total_elements_updated
         self._persist_data.max_seq_id = self._max_seq_id

         # TODO: This should really be stored in sqlite, the index itself, or a better
@@ -221,30 +203,19 @@ def _persist(self) -> None:
         with open(self._get_metadata_file(), "wb") as metadata_file:
             pickle.dump(self._persist_data, metadata_file, pickle.HIGHEST_PROTOCOL)

+        self._num_log_records_since_last_persist = 0
+
     @trace_method(
         "PersistentLocalHnswSegment._apply_batch", OpenTelemetryGranularity.ALL
     )
     @override
     def _apply_batch(self, batch: Batch) -> None:
         super()._apply_batch(batch)
-        num_elements_added_since_last_persist = (
-            self._total_elements_added - self._persist_data.total_elements_added
-        )
-        num_elements_updated_since_last_persist = (
-            self._total_elements_updated - self._persist_data.total_elements_updated
-        )
-        num_invalid_operations_since_last_persist = (
-            self._total_invalid_operations - self._persist_data.total_invalid_operations
-        )
-
-        if (
-            num_elements_added_since_last_persist
-            + num_elements_updated_since_last_persist
-            + num_invalid_operations_since_last_persist
-            >= self._sync_threshold
-        ):
+        if self._num_log_records_since_last_persist >= self._sync_threshold:
             self._persist()

+        self._num_log_records_since_last_batch = 0
+
     @trace_method(
         "PersistentLocalHnswSegment._write_records", OpenTelemetryGranularity.ALL
     )
@@ -255,6 +226,9 @@ def _write_records(self, records: Sequence[LogRecord]) -> None:
             raise RuntimeError("Cannot add embeddings to stopped component")
         with WriteRWLock(self._lock):
             for record in records:
+                self._num_log_records_since_last_batch += 1
+                self._num_log_records_since_last_persist += 1
+
                 if record["record"]["embedding"] is not None:
                     self._ensure_index(len(records), len(record["record"]["embedding"]))
                 if not self._index_initialized:
@@ -305,15 +279,7 @@ def _write_records(self, records: Sequence[LogRecord]) -> None:
                     self._curr_batch.apply(record, exists_in_index)
                     self._brute_force_index.upsert([record])

-                num_invalid_operations_since_last_persist = (
-                    self._total_invalid_operations
-                    - self._persist_data.total_invalid_operations
-                )
-
-                if (
-                    len(self._curr_batch) + num_invalid_operations_since_last_persist
-                    >= self._batch_size
-                ):
+                if self._num_log_records_since_last_batch >= self._batch_size:
                     self._apply_batch(self._curr_batch)
                     self._curr_batch = Batch()
                     self._brute_force_index.clear()
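Taken together, the diff replaces the old bookkeeping (deriving per-persist deltas from `total_elements_added`, `total_elements_updated`, and `total_invalid_operations`) with two plain counters: every record seen in `_write_records` bumps both, the batch counter gates `_apply_batch` against `_batch_size`, and the persist counter gates `_persist` against `_sync_threshold`, each reset where its work is done. A minimal toy model of that control flow (hypothetical `ToyHnswSegment`, not the real segment class; only the counter logic from the diff is reproduced):

```python
# Toy model of the simplified bookkeeping: each incoming log record bumps both
# counters; the batch counter triggers _apply_batch and is reset there, while
# the persist counter triggers _persist and is reset there.
class ToyHnswSegment:
    def __init__(self, batch_size: int, sync_threshold: int) -> None:
        self._batch_size = batch_size
        self._sync_threshold = sync_threshold
        self._num_log_records_since_last_batch = 0
        self._num_log_records_since_last_persist = 0
        self.batches_applied = 0
        self.persists = 0

    def write_records(self, records: list) -> None:
        for _ in records:
            self._num_log_records_since_last_batch += 1
            self._num_log_records_since_last_persist += 1
            if self._num_log_records_since_last_batch >= self._batch_size:
                self._apply_batch()

    def _apply_batch(self) -> None:
        self.batches_applied += 1
        if self._num_log_records_since_last_persist >= self._sync_threshold:
            self._persist()
        self._num_log_records_since_last_batch = 0

    def _persist(self) -> None:
        self.persists += 1
        self._num_log_records_since_last_persist = 0


seg = ToyHnswSegment(batch_size=10, sync_threshold=30)
seg.write_records(list(range(100)))
assert seg.batches_applied == 10  # one batch application per 10 records
assert seg.persists == 3          # one persist per 30 records
```

With these toy thresholds, 100 records yield ten batch applications and three persists, which is the cadence the simplified checks are meant to produce.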
Review comment:
I tested with pickle and it seems to be fine if you remove a field from a class def
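For reference, a minimal standalone check of the behavior this comment describes (illustrative `Meta` class, not Chroma's `PersistentData`): unpickling an object whose class no longer defines one of the pickled fields does not raise, because the default `__setstate__` simply updates the instance `__dict__`, so the stale attribute is carried along and ignored.

```python
import pickle


class Meta:  # "old" definition: includes a field that will later be removed
    def __init__(self) -> None:
        self.dimensionality = 3
        self.total_elements_added = 7


payload = pickle.dumps(Meta())


class Meta:  # "new" definition: the field no longer exists on the class
    def __init__(self) -> None:
        self.dimensionality = 3


restored = pickle.loads(payload)           # loads without error
assert restored.dimensionality == 3
assert restored.total_elements_added == 7  # stale attribute is kept on the instance
```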