diff --git a/chromadb/db/mixins/embeddings_queue.py b/chromadb/db/mixins/embeddings_queue.py index 913e6dc347d..8f84e4b5d6e 100644 --- a/chromadb/db/mixins/embeddings_queue.py +++ b/chromadb/db/mixins/embeddings_queue.py @@ -190,7 +190,7 @@ def submit_embeddings( # submit_embeddings so we do not reorder the records before submitting them embedding_record = LogRecord( log_offset=seq_id, - operation_record=OperationRecord( + record=OperationRecord( id=id, embedding=submit_embedding_record["embedding"], encoding=submit_embedding_record["encoding"], @@ -321,7 +321,7 @@ def _backfill(self, subscription: Subscription) -> None: [ LogRecord( log_offset=row[0], - operation_record=OperationRecord( + record=OperationRecord( operation=_operation_codes_inv[row[1]], id=row[2], embedding=vector, diff --git a/chromadb/segment/impl/metadata/sqlite.py b/chromadb/segment/impl/metadata/sqlite.py index 23751c8954c..aa895445f9b 100644 --- a/chromadb/segment/impl/metadata/sqlite.py +++ b/chromadb/segment/impl/metadata/sqlite.py @@ -265,10 +265,10 @@ def _insert_record(self, cur: Cursor, record: LogRecord, upsert: bool) -> None: .into(t) .columns(t.segment_id, t.embedding_id, t.seq_id) .where(t.segment_id == ParameterValue(self._db.uuid_to_db(self._id))) - .where(t.embedding_id == ParameterValue(record["operation_record"]["id"])) + .where(t.embedding_id == ParameterValue(record["record"]["id"])) ).insert( ParameterValue(self._db.uuid_to_db(self._id)), - ParameterValue(record["operation_record"]["id"]), + ParameterValue(record["record"]["id"]), ParameterValue(_encode_seq_id(record["log_offset"])), ) sql, params = get_sql(q) @@ -278,18 +278,17 @@ def _insert_record(self, cur: Cursor, record: LogRecord, upsert: bool) -> None: except sqlite3.IntegrityError: # Can't use INSERT OR REPLACE here because it changes the primary key. if upsert: - # Cast here because the OpenTel decorators obfuscate the type - return cast(None, self._update_record(cur, record)) + return self._update_record(cur, record) else: logger.warning( - f"Insert of existing embedding ID: {record['operation_record']['id']}" + f"Insert of existing embedding ID: {record['record']['id']}" ) # We are trying to add for a record that already exists. Fail the call. # We don't throw an exception since this is in principal an async path return - if record["operation_record"]["metadata"]: - self._update_metadata(cur, id, record["operation_record"]["metadata"]) + if record["record"]["metadata"]: + self._update_metadata(cur, id, record["record"]["metadata"]) @trace_method( "SqliteMetadataSegment._update_metadata", OpenTelemetryGranularity.ALL @@ -411,7 +410,7 @@ def _delete_record(self, cur: Cursor, record: LogRecord) -> None: self._db.querybuilder() .from_(t) .where(t.segment_id == ParameterValue(self._db.uuid_to_db(self._id))) - .where(t.embedding_id == ParameterValue(record["operation_record"]["id"])) + .where(t.embedding_id == ParameterValue(record["record"]["id"])) .delete() ) q_fts = ( @@ -426,10 +425,7 @@ def _delete_record(self, cur: Cursor, record: LogRecord) -> None: .where( t.segment_id == ParameterValue(self._db.uuid_to_db(self._id)) ) - .where( - t.embedding_id - == ParameterValue(record["operation_record"]["id"]) - ) + .where(t.embedding_id == ParameterValue(record["record"]["id"])) ) ) ) @@ -439,7 +435,7 @@ def _delete_record(self, cur: Cursor, record: LogRecord) -> None: result = cur.execute(sql, params).fetchone() if result is None: logger.warning( - f"Delete of nonexisting embedding ID: {record['operation_record']['id']}" + f"Delete of nonexisting embedding ID: {record['record']['id']}" ) else: id = result[0] @@ -466,19 +462,19 @@ def _update_record(self, cur: Cursor, record: LogRecord) -> None: .update(t) .set(t.seq_id, ParameterValue(_encode_seq_id(record["log_offset"]))) .where(t.segment_id == ParameterValue(self._db.uuid_to_db(self._id))) - .where(t.embedding_id == ParameterValue(record["operation_record"]["id"])) + .where(t.embedding_id == ParameterValue(record["record"]["id"])) ) sql, params = get_sql(q) sql = sql + " RETURNING id" result = cur.execute(sql, params).fetchone() if result is None: logger.warning( - f"Update of nonexisting embedding ID: {record['operation_record']['id']}" + f"Update of nonexisting embedding ID: {record['record']['id']}" ) else: id = result[0] - if record["operation_record"]["metadata"]: - self._update_metadata(cur, id, record["operation_record"]["metadata"]) + if record["record"]["metadata"]: + self._update_metadata(cur, id, record["record"]["metadata"]) @trace_method("SqliteMetadataSegment._write_metadata", OpenTelemetryGranularity.ALL) def _write_metadata(self, records: Sequence[LogRecord]) -> None: @@ -498,14 +494,13 @@ def _write_metadata(self, records: Sequence[LogRecord]) -> None: sql, params = get_sql(q) sql = sql.replace("INSERT", "INSERT OR REPLACE") cur.execute(sql, params) - - if record["operation_record"]["operation"] == Operation.ADD: + if record["record"]["operation"] == Operation.ADD: self._insert_record(cur, record, False) - elif record["operation_record"]["operation"] == Operation.UPSERT: + elif record["record"]["operation"] == Operation.UPSERT: self._insert_record(cur, record, True) - elif record["operation_record"]["operation"] == Operation.DELETE: + elif record["record"]["operation"] == Operation.DELETE: self._delete_record(cur, record) - elif record["operation_record"]["operation"] == Operation.UPDATE: + elif record["record"]["operation"] == Operation.UPDATE: self._update_record(cur, record) @trace_method( diff --git a/chromadb/segment/impl/vector/batch.py b/chromadb/segment/impl/vector/batch.py index 43cbe886005..d1f660d8f3a 100644 --- a/chromadb/segment/impl/vector/batch.py +++ b/chromadb/segment/impl/vector/batch.py @@ -38,8 +38,7 @@ def get_written_ids(self) -> List[str]: def get_written_vectors(self, ids: List[str]) -> List[Vector]: """Get the list of vectors to write in this batch""" return [ - cast(Vector, self._ids_to_records[id]["operation_record"]["embedding"]) - for id in ids + cast(Vector, self._ids_to_records[id]["record"]["embedding"]) for id in ids ] def get_record(self, id: str) -> LogRecord: @@ -60,26 +59,21 @@ def apply(self, record: LogRecord, exists_already: bool = False) -> None: The exists_already flag should be set to True if the ID does exist in the index, and False otherwise. """ - id = record["operation_record"]["id"] - if record["operation_record"]["operation"] == Operation.DELETE: + id = record["record"]["id"] + if record["record"]["operation"] == Operation.DELETE: # If the ID was previously written, remove it from the written set # And update the add/update/delete counts if id in self._written_ids: self._written_ids.remove(id) - if ( - self._ids_to_records[id]["operation_record"]["operation"] - == Operation.ADD - ): + if self._ids_to_records[id]["record"]["operation"] == Operation.ADD: self.add_count -= 1 elif ( - self._ids_to_records[id]["operation_record"]["operation"] - == Operation.UPDATE + self._ids_to_records[id]["record"]["operation"] == Operation.UPDATE ): self.update_count -= 1 self._deleted_ids.add(id) elif ( - self._ids_to_records[id]["operation_record"]["operation"] - == Operation.UPSERT + self._ids_to_records[id]["record"]["operation"] == Operation.UPSERT ): if id in self._upsert_add_ids: self.add_count -= 1 @@ -104,15 +98,15 @@ def apply(self, record: LogRecord, exists_already: bool = False) -> None: self._deleted_ids.remove(id) # Update the add/update counts - if record["operation_record"]["operation"] == Operation.UPSERT: + if record["record"]["operation"] == Operation.UPSERT: if not exists_already: self.add_count += 1 self._upsert_add_ids.add(id) else: self.update_count += 1 - elif record["operation_record"]["operation"] == Operation.ADD: + elif record["record"]["operation"] == Operation.ADD: self.add_count += 1 - elif record["operation_record"]["operation"] == Operation.UPDATE: + elif record["record"]["operation"] == Operation.UPDATE: self.update_count += 1 self.max_seq_id = max(self.max_seq_id, record["log_offset"]) diff --git a/chromadb/segment/impl/vector/brute_force_index.py b/chromadb/segment/impl/vector/brute_force_index.py index 3eef8e043d5..530123a2223 100644 --- a/chromadb/segment/impl/vector/brute_force_index.py +++ b/chromadb/segment/impl/vector/brute_force_index.py @@ -68,8 +68,8 @@ def upsert(self, records: List[LogRecord]) -> None: ) for i, record in enumerate(records): - id = record["operation_record"]["id"] - vector = record["operation_record"]["embedding"] + id = record["record"]["id"] + vector = record["record"]["embedding"] self.id_to_seq_id[id] = record["log_offset"] if id in self.deleted_ids: self.deleted_ids.remove(id) @@ -88,14 +88,14 @@ def upsert(self, records: List[LogRecord]) -> None: def delete(self, records: List[LogRecord]) -> None: for record in records: - id = record["operation_record"]["id"] + id = record["record"]["id"] if id in self.id_to_index: index = self.id_to_index[id] self.deleted_ids.add(id) del self.id_to_index[id] del self.index_to_id[index] del self.id_to_seq_id[id] - self.vectors[index].fill(np.NaN) + self.vectors[index].fill(np.nan) self.free_indices.append(index) else: logger.warning(f"Delete of nonexisting embedding ID: {id}") diff --git a/chromadb/segment/impl/vector/local_hnsw.py b/chromadb/segment/impl/vector/local_hnsw.py index b055762af4c..7358270c7d6 100644 --- a/chromadb/segment/impl/vector/local_hnsw.py +++ b/chromadb/segment/impl/vector/local_hnsw.py @@ -74,7 +74,6 @@ def __init__(self, system: System, segment: Segment): self._lock = ReadWriteLock() self._opentelemtry_client = system.require(OpenTelemetryClient) - super().__init__(system, segment) @staticmethod @override @@ -292,8 +291,8 @@ def _write_records(self, records: Sequence[LogRecord]) -> None: for record in records: self._max_seq_id = max(self._max_seq_id, record["log_offset"]) - id = record["operation_record"]["id"] - op = record["operation_record"]["operation"] + id = record["record"]["id"] + op = record["record"]["operation"] label = self._id_to_label.get(id, None) if op == Operation.DELETE: @@ -303,12 +302,12 @@ def _write_records(self, records: Sequence[LogRecord]) -> None: logger.warning(f"Delete of nonexisting embedding ID: {id}") elif op == Operation.UPDATE: - if record["operation_record"]["embedding"] is not None: + if record["record"]["embedding"] is not None: if label is not None: batch.apply(record) else: logger.warning( - f"Update of nonexisting embedding ID: {record['operation_record']['id']}" + f"Update of nonexisting embedding ID: {record['record']['id']}" ) elif op == Operation.ADD: if not label: diff --git a/chromadb/segment/impl/vector/local_persistent_hnsw.py b/chromadb/segment/impl/vector/local_persistent_hnsw.py index 110276395cd..9dadc906e98 100644 --- a/chromadb/segment/impl/vector/local_persistent_hnsw.py +++ b/chromadb/segment/impl/vector/local_persistent_hnsw.py @@ -228,10 +228,8 @@ def _write_records(self, records: Sequence[LogRecord]) -> None: raise RuntimeError("Cannot add embeddings to stopped component") with WriteRWLock(self._lock): for record in records: - if record["operation_record"]["embedding"] is not None: - self._ensure_index( - len(records), len(record["operation_record"]["embedding"]) - ) + if record["record"]["embedding"] is not None: + self._ensure_index(len(records), len(record["record"]["embedding"])) if not self._index_initialized: # If the index is not initialized here, it means that we have # not yet added any records to the index. So we can just @@ -240,8 +238,8 @@ def _write_records(self, records: Sequence[LogRecord]) -> None: self._brute_force_index = cast(BruteForceIndex, self._brute_force_index) self._max_seq_id = max(self._max_seq_id, record["log_offset"]) - id = record["operation_record"]["id"] - op = record["operation_record"]["operation"] + id = record["record"]["id"] + op = record["record"]["operation"] exists_in_index = self._id_to_label.get( id, None ) is not None or self._brute_force_index.has_id(id) @@ -256,23 +254,23 @@ def _write_records(self, records: Sequence[LogRecord]) -> None: logger.warning(f"Delete of nonexisting embedding ID: {id}") elif op == Operation.UPDATE: - if record["operation_record"]["embedding"] is not None: + if record["record"]["embedding"] is not None: if exists_in_index: self._curr_batch.apply(record) self._brute_force_index.upsert([record]) else: logger.warning( - f"Update of nonexisting embedding ID: {record['operation_record']['id']}" + f"Update of nonexisting embedding ID: {record['record']['id']}" ) elif op == Operation.ADD: - if record["operation_record"]["embedding"] is not None: + if record["record"]["embedding"] is not None: if not exists_in_index: self._curr_batch.apply(record, not exists_in_index) self._brute_force_index.upsert([record]) else: logger.warning(f"Add of existing embedding ID: {id}") elif op == Operation.UPSERT: - if record["operation_record"]["embedding"] is not None: + if record["record"]["embedding"] is not None: self._curr_batch.apply(record, exists_in_index) self._brute_force_index.upsert([record]) if len(self._curr_batch) >= self._batch_size: diff --git a/chromadb/test/ingest/test_producer_consumer.py b/chromadb/test/ingest/test_producer_consumer.py index d44faca5894..15554b1e708 100644 --- a/chromadb/test/ingest/test_producer_consumer.py +++ b/chromadb/test/ingest/test_producer_consumer.py @@ -135,16 +135,14 @@ def assert_records_match( """Given a list of inserted and consumed records, make sure they match""" assert len(consumed_records) == len(inserted_records) for inserted, consumed in zip(inserted_records, consumed_records): - assert inserted["id"] == consumed["operation_record"]["id"] - assert inserted["operation"] == consumed["operation_record"]["operation"] - assert inserted["encoding"] == consumed["operation_record"]["encoding"] - assert inserted["metadata"] == consumed["operation_record"]["metadata"] + assert inserted["id"] == consumed["record"]["id"] + assert inserted["operation"] == consumed["record"]["operation"] + assert inserted["encoding"] == consumed["record"]["encoding"] + assert inserted["metadata"] == consumed["record"]["metadata"] if inserted["embedding"] is not None: - assert consumed["operation_record"]["embedding"] is not None - assert_approx_equal( - inserted["embedding"], consumed["operation_record"]["embedding"] - ) + assert consumed["record"]["embedding"] is not None + assert_approx_equal(inserted["embedding"], consumed["record"]["embedding"]) @pytest.mark.asyncio