[BUG] fix multi collection log purge
codetheweb committed Aug 2, 2024
1 parent 9472830 commit 26bfdee
Showing 4 changed files with 42 additions and 14 deletions.
7 changes: 5 additions & 2 deletions chromadb/db/mixins/embeddings_queue.py
@@ -130,7 +130,7 @@ def delete_log(self, collection_id: UUID) -> None:

     @trace_method("SqlEmbeddingsQueue.purge_log", OpenTelemetryGranularity.ALL)
     @override
-    def purge_log(self) -> None:
+    def purge_log(self, collection_id: UUID) -> None:
         segments_t = Table("segments")
         segment_ids_q = (
             self.querybuilder()
@@ -140,6 +140,9 @@ def purge_log(self) -> None:
             # - > 1 has never written to the max_seq_id table
             # In that case, we should not delete any WAL entries as we can't be sure that all segments are caught up.
             .select(functions.Coalesce(Table("max_seq_id").seq_id, -1))
+            .where(
+                segments_t.collection == ParameterValue(self.uuid_to_db(collection_id))
+            )
             .left_join(Table("max_seq_id"))
             .on(segments_t.id == Table("max_seq_id").segment_id)
         )
@@ -255,7 +258,7 @@ def submit_embeddings(
         self._notify_all(topic_name, embedding_records)

         if self.config.get_parameter("automatically_purge").value:
-            self.purge_log()
+            self.purge_log(collection_id)

         return seq_ids

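For orientation, the new where clause scopes the segment subquery to a single collection, so it now resolves to roughly the SQL below. This is a sketch assuming pypika-style rendering; the builder's exact output and placeholder syntax may differ.

# Hypothetical rendering of segment_ids_q after this change (assumption:
# pypika-style SQL; not taken verbatim from the builder). One row per
# segment of the target collection; COALESCE yields -1 for a segment that
# has never written to max_seq_id, which keeps its WAL entries from being
# purged until it catches up.
SEGMENT_SEQ_IDS_SQL = """
SELECT COALESCE(max_seq_id.seq_id, -1)
FROM segments
LEFT JOIN max_seq_id ON segments.id = max_seq_id.segment_id
WHERE segments.collection = ?
"""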
2 changes: 1 addition & 1 deletion chromadb/ingest/__init__.py
@@ -42,7 +42,7 @@ def delete_log(self, collection_id: UUID) -> None:
         pass

     @abstractmethod
-    def purge_log(self) -> None:
+    def purge_log(self, collection_id: UUID) -> None:
         """Truncates the log for the given collection, removing all seen records."""
         pass
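As a usage sketch of the new interface (a hypothetical caller, not code from this commit): purging is now scoped per collection, so each collection's UUID is passed explicitly.

# Hypothetical caller sketch. One collection's lagging segments no longer
# block purging the WAL entries of the others.
from uuid import UUID

def purge_all(producer, collection_ids: list[UUID]) -> None:
    for collection_id in collection_ids:
        producer.purge_log(collection_id)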
26 changes: 26 additions & 0 deletions chromadb/test/db/test_log_purge.py
@@ -0,0 +1,26 @@
+from chromadb.api.client import Client
+from chromadb.config import System
+from chromadb.test.property import invariants
+
+
+def test_log_purge(sqlite_persistent: System) -> None:
+    client = Client.from_system(sqlite_persistent)
+
+    first_collection = client.create_collection(
+        "first_collection", metadata={"hnsw:sync_threshold": 10, "hnsw:batch_size": 10}
+    )
+    second_collection = client.create_collection(
+        "second_collection", metadata={"hnsw:sync_threshold": 10, "hnsw:batch_size": 10}
+    )
+    collections = [first_collection, second_collection]
+
+    # (Does not trigger a purge)
+    for i in range(5):
+        first_collection.add(ids=str(i), embeddings=[i, i])
+
+    # (Should trigger a purge)
+    for i in range(100):
+        second_collection.add(ids=str(i), embeddings=[i, i])
+
+    # The purge of the second collection should not be blocked by the first
+    invariants.log_size_below_max(client._system, collections, True)
21 changes: 10 additions & 11 deletions chromadb/test/property/invariants.py
@@ -1,6 +1,5 @@
 import gc
 import math
-from chromadb.api.configuration import HNSWConfigurationInternal
 from chromadb.config import System
 from chromadb.db.base import get_sql
 from chromadb.db.impl.sqlite import SqliteDB
@@ -334,29 +333,29 @@ def _total_embedding_queue_log_size(sqlite: SqliteDB) -> int:


 def log_size_below_max(
-    system: System, collection: Collection, has_collection_mutated: bool
+    system: System, collections: List[Collection], has_collection_mutated: bool
 ) -> None:
     sqlite = system.instance(SqliteDB)

     if has_collection_mutated:
         # Must always keep one entry to avoid reusing seq_ids
         assert _total_embedding_queue_log_size(sqlite) >= 1

-        hnsw_config = cast(
-            HNSWConfigurationInternal,
-            collection.get_model()
-            .get_configuration()
-            .get_parameter("hnsw_configuration")
-            .value,
+        # We purge per-collection as the sync_threshold is a per-collection setting
+        sync_threshold_sum = sum(
+            collection.metadata.get("hnsw:sync_threshold", 1000)
+            for collection in collections
         )
-        sync_threshold = cast(int, hnsw_config.get_parameter("sync_threshold").value)
-        batch_size = cast(int, hnsw_config.get_parameter("batch_size").value)
+        batch_size_sum = sum(
+            collection.metadata.get("hnsw:batch_size", 1000)
+            for collection in collections
+        )

         # -1 is used because the queue is always at least 1 entry long, so deletion stops before the max ack'ed sequence ID.
         # And if the batch_size != sync_threshold, the queue can have up to batch_size - 1 more entries.
         assert (
             _total_embedding_queue_log_size(sqlite) - 1
-            <= sync_threshold + batch_size - 1
+            <= sync_threshold_sum + batch_size_sum - 1
         )
     else:
         assert _total_embedding_queue_log_size(sqlite) == 0
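As a worked instance of this bound, using the two collections from the new test above (each created with hnsw:sync_threshold=10 and hnsw:batch_size=10):

# Worked example of the invariant for the two test collections.
sync_threshold_sum = 10 + 10
batch_size_sum = 10 + 10

# The invariant is: log_size - 1 <= sync_threshold_sum + batch_size_sum - 1,
# so the embeddings queue may hold at most 40 entries in this scenario.
max_log_size = sync_threshold_sum + batch_size_sum  # 40
assert max_log_size - 1 <= sync_threshold_sum + batch_size_sum - 1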
