Skip to content

Commit

Permalink
[CLN] LogRecord TypedDict attr name incorrect (#2486)
Browse files Browse the repository at this point in the history
## Description of changes

*Summarize the changes made by this PR.*
 - Improvements & Bug fixes
- Fixes a typed dict related mypy error where we propagated an incorrect
name
	 - Fixes or ignores other type errors
 - New functionality
	 - none

## Test plan
*How are these changes tested?*
Existing tests
- [x] Tests pass locally with `pytest` for python, `yarn test` for js,
`cargo test` for rust

## Documentation Changes
None
  • Loading branch information
HammadB committed Jul 10, 2024
1 parent b601a3b commit 27ba303
Show file tree
Hide file tree
Showing 7 changed files with 50 additions and 66 deletions.
4 changes: 2 additions & 2 deletions chromadb/db/mixins/embeddings_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ def submit_embeddings(
# submit_embeddings so we do not reorder the records before submitting them
embedding_record = LogRecord(
log_offset=seq_id,
operation_record=OperationRecord(
record=OperationRecord(
id=id,
embedding=submit_embedding_record["embedding"],
encoding=submit_embedding_record["encoding"],
Expand Down Expand Up @@ -321,7 +321,7 @@ def _backfill(self, subscription: Subscription) -> None:
[
LogRecord(
log_offset=row[0],
operation_record=OperationRecord(
record=OperationRecord(
operation=_operation_codes_inv[row[1]],
id=row[2],
embedding=vector,
Expand Down
39 changes: 17 additions & 22 deletions chromadb/segment/impl/metadata/sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,10 +265,10 @@ def _insert_record(self, cur: Cursor, record: LogRecord, upsert: bool) -> None:
.into(t)
.columns(t.segment_id, t.embedding_id, t.seq_id)
.where(t.segment_id == ParameterValue(self._db.uuid_to_db(self._id)))
.where(t.embedding_id == ParameterValue(record["operation_record"]["id"]))
.where(t.embedding_id == ParameterValue(record["record"]["id"]))
).insert(
ParameterValue(self._db.uuid_to_db(self._id)),
ParameterValue(record["operation_record"]["id"]),
ParameterValue(record["record"]["id"]),
ParameterValue(_encode_seq_id(record["log_offset"])),
)
sql, params = get_sql(q)
Expand All @@ -278,18 +278,17 @@ def _insert_record(self, cur: Cursor, record: LogRecord, upsert: bool) -> None:
except sqlite3.IntegrityError:
# Can't use INSERT OR REPLACE here because it changes the primary key.
if upsert:
# Cast here because the OpenTel decorators obfuscate the type
return cast(None, self._update_record(cur, record))
return self._update_record(cur, record)
else:
logger.warning(
f"Insert of existing embedding ID: {record['operation_record']['id']}"
f"Insert of existing embedding ID: {record['record']['id']}"
)
# We are trying to add for a record that already exists. Fail the call.
# We don't throw an exception since this is in principal an async path
return

if record["operation_record"]["metadata"]:
self._update_metadata(cur, id, record["operation_record"]["metadata"])
if record["record"]["metadata"]:
self._update_metadata(cur, id, record["record"]["metadata"])

@trace_method(
"SqliteMetadataSegment._update_metadata", OpenTelemetryGranularity.ALL
Expand Down Expand Up @@ -411,7 +410,7 @@ def _delete_record(self, cur: Cursor, record: LogRecord) -> None:
self._db.querybuilder()
.from_(t)
.where(t.segment_id == ParameterValue(self._db.uuid_to_db(self._id)))
.where(t.embedding_id == ParameterValue(record["operation_record"]["id"]))
.where(t.embedding_id == ParameterValue(record["record"]["id"]))
.delete()
)
q_fts = (
Expand All @@ -426,10 +425,7 @@ def _delete_record(self, cur: Cursor, record: LogRecord) -> None:
.where(
t.segment_id == ParameterValue(self._db.uuid_to_db(self._id))
)
.where(
t.embedding_id
== ParameterValue(record["operation_record"]["id"])
)
.where(t.embedding_id == ParameterValue(record["record"]["id"]))
)
)
)
Expand All @@ -439,7 +435,7 @@ def _delete_record(self, cur: Cursor, record: LogRecord) -> None:
result = cur.execute(sql, params).fetchone()
if result is None:
logger.warning(
f"Delete of nonexisting embedding ID: {record['operation_record']['id']}"
f"Delete of nonexisting embedding ID: {record['record']['id']}"
)
else:
id = result[0]
Expand All @@ -466,19 +462,19 @@ def _update_record(self, cur: Cursor, record: LogRecord) -> None:
.update(t)
.set(t.seq_id, ParameterValue(_encode_seq_id(record["log_offset"])))
.where(t.segment_id == ParameterValue(self._db.uuid_to_db(self._id)))
.where(t.embedding_id == ParameterValue(record["operation_record"]["id"]))
.where(t.embedding_id == ParameterValue(record["record"]["id"]))
)
sql, params = get_sql(q)
sql = sql + " RETURNING id"
result = cur.execute(sql, params).fetchone()
if result is None:
logger.warning(
f"Update of nonexisting embedding ID: {record['operation_record']['id']}"
f"Update of nonexisting embedding ID: {record['record']['id']}"
)
else:
id = result[0]
if record["operation_record"]["metadata"]:
self._update_metadata(cur, id, record["operation_record"]["metadata"])
if record["record"]["metadata"]:
self._update_metadata(cur, id, record["record"]["metadata"])

@trace_method("SqliteMetadataSegment._write_metadata", OpenTelemetryGranularity.ALL)
def _write_metadata(self, records: Sequence[LogRecord]) -> None:
Expand All @@ -498,14 +494,13 @@ def _write_metadata(self, records: Sequence[LogRecord]) -> None:
sql, params = get_sql(q)
sql = sql.replace("INSERT", "INSERT OR REPLACE")
cur.execute(sql, params)

if record["operation_record"]["operation"] == Operation.ADD:
if record["record"]["operation"] == Operation.ADD:
self._insert_record(cur, record, False)
elif record["operation_record"]["operation"] == Operation.UPSERT:
elif record["record"]["operation"] == Operation.UPSERT:
self._insert_record(cur, record, True)
elif record["operation_record"]["operation"] == Operation.DELETE:
elif record["record"]["operation"] == Operation.DELETE:
self._delete_record(cur, record)
elif record["operation_record"]["operation"] == Operation.UPDATE:
elif record["record"]["operation"] == Operation.UPDATE:
self._update_record(cur, record)

@trace_method(
Expand Down
24 changes: 9 additions & 15 deletions chromadb/segment/impl/vector/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,7 @@ def get_written_ids(self) -> List[str]:
def get_written_vectors(self, ids: List[str]) -> List[Vector]:
"""Get the list of vectors to write in this batch"""
return [
cast(Vector, self._ids_to_records[id]["operation_record"]["embedding"])
for id in ids
cast(Vector, self._ids_to_records[id]["record"]["embedding"]) for id in ids
]

def get_record(self, id: str) -> LogRecord:
Expand All @@ -60,26 +59,21 @@ def apply(self, record: LogRecord, exists_already: bool = False) -> None:
The exists_already flag should be set to True if the ID does exist in the index, and False otherwise.
"""

id = record["operation_record"]["id"]
if record["operation_record"]["operation"] == Operation.DELETE:
id = record["record"]["id"]
if record["record"]["operation"] == Operation.DELETE:
# If the ID was previously written, remove it from the written set
# And update the add/update/delete counts
if id in self._written_ids:
self._written_ids.remove(id)
if (
self._ids_to_records[id]["operation_record"]["operation"]
== Operation.ADD
):
if self._ids_to_records[id]["record"]["operation"] == Operation.ADD:
self.add_count -= 1
elif (
self._ids_to_records[id]["operation_record"]["operation"]
== Operation.UPDATE
self._ids_to_records[id]["record"]["operation"] == Operation.UPDATE
):
self.update_count -= 1
self._deleted_ids.add(id)
elif (
self._ids_to_records[id]["operation_record"]["operation"]
== Operation.UPSERT
self._ids_to_records[id]["record"]["operation"] == Operation.UPSERT
):
if id in self._upsert_add_ids:
self.add_count -= 1
Expand All @@ -104,15 +98,15 @@ def apply(self, record: LogRecord, exists_already: bool = False) -> None:
self._deleted_ids.remove(id)

# Update the add/update counts
if record["operation_record"]["operation"] == Operation.UPSERT:
if record["record"]["operation"] == Operation.UPSERT:
if not exists_already:
self.add_count += 1
self._upsert_add_ids.add(id)
else:
self.update_count += 1
elif record["operation_record"]["operation"] == Operation.ADD:
elif record["record"]["operation"] == Operation.ADD:
self.add_count += 1
elif record["operation_record"]["operation"] == Operation.UPDATE:
elif record["record"]["operation"] == Operation.UPDATE:
self.update_count += 1

self.max_seq_id = max(self.max_seq_id, record["log_offset"])
8 changes: 4 additions & 4 deletions chromadb/segment/impl/vector/brute_force_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,8 @@ def upsert(self, records: List[LogRecord]) -> None:
)

for i, record in enumerate(records):
id = record["operation_record"]["id"]
vector = record["operation_record"]["embedding"]
id = record["record"]["id"]
vector = record["record"]["embedding"]
self.id_to_seq_id[id] = record["log_offset"]
if id in self.deleted_ids:
self.deleted_ids.remove(id)
Expand All @@ -88,14 +88,14 @@ def upsert(self, records: List[LogRecord]) -> None:

def delete(self, records: List[LogRecord]) -> None:
for record in records:
id = record["operation_record"]["id"]
id = record["record"]["id"]
if id in self.id_to_index:
index = self.id_to_index[id]
self.deleted_ids.add(id)
del self.id_to_index[id]
del self.index_to_id[index]
del self.id_to_seq_id[id]
self.vectors[index].fill(np.NaN)
self.vectors[index].fill(np.nan)
self.free_indices.append(index)
else:
logger.warning(f"Delete of nonexisting embedding ID: {id}")
Expand Down
9 changes: 4 additions & 5 deletions chromadb/segment/impl/vector/local_hnsw.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,6 @@ def __init__(self, system: System, segment: Segment):

self._lock = ReadWriteLock()
self._opentelemtry_client = system.require(OpenTelemetryClient)
super().__init__(system, segment)

@staticmethod
@override
Expand Down Expand Up @@ -292,8 +291,8 @@ def _write_records(self, records: Sequence[LogRecord]) -> None:

for record in records:
self._max_seq_id = max(self._max_seq_id, record["log_offset"])
id = record["operation_record"]["id"]
op = record["operation_record"]["operation"]
id = record["record"]["id"]
op = record["record"]["operation"]
label = self._id_to_label.get(id, None)

if op == Operation.DELETE:
Expand All @@ -303,12 +302,12 @@ def _write_records(self, records: Sequence[LogRecord]) -> None:
logger.warning(f"Delete of nonexisting embedding ID: {id}")

elif op == Operation.UPDATE:
if record["operation_record"]["embedding"] is not None:
if record["record"]["embedding"] is not None:
if label is not None:
batch.apply(record)
else:
logger.warning(
f"Update of nonexisting embedding ID: {record['operation_record']['id']}"
f"Update of nonexisting embedding ID: {record['record']['id']}"
)
elif op == Operation.ADD:
if not label:
Expand Down
18 changes: 8 additions & 10 deletions chromadb/segment/impl/vector/local_persistent_hnsw.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,10 +228,8 @@ def _write_records(self, records: Sequence[LogRecord]) -> None:
raise RuntimeError("Cannot add embeddings to stopped component")
with WriteRWLock(self._lock):
for record in records:
if record["operation_record"]["embedding"] is not None:
self._ensure_index(
len(records), len(record["operation_record"]["embedding"])
)
if record["record"]["embedding"] is not None:
self._ensure_index(len(records), len(record["record"]["embedding"]))
if not self._index_initialized:
# If the index is not initialized here, it means that we have
# not yet added any records to the index. So we can just
Expand All @@ -240,8 +238,8 @@ def _write_records(self, records: Sequence[LogRecord]) -> None:
self._brute_force_index = cast(BruteForceIndex, self._brute_force_index)

self._max_seq_id = max(self._max_seq_id, record["log_offset"])
id = record["operation_record"]["id"]
op = record["operation_record"]["operation"]
id = record["record"]["id"]
op = record["record"]["operation"]
exists_in_index = self._id_to_label.get(
id, None
) is not None or self._brute_force_index.has_id(id)
Expand All @@ -256,23 +254,23 @@ def _write_records(self, records: Sequence[LogRecord]) -> None:
logger.warning(f"Delete of nonexisting embedding ID: {id}")

elif op == Operation.UPDATE:
if record["operation_record"]["embedding"] is not None:
if record["record"]["embedding"] is not None:
if exists_in_index:
self._curr_batch.apply(record)
self._brute_force_index.upsert([record])
else:
logger.warning(
f"Update of nonexisting embedding ID: {record['operation_record']['id']}"
f"Update of nonexisting embedding ID: {record['record']['id']}"
)
elif op == Operation.ADD:
if record["operation_record"]["embedding"] is not None:
if record["record"]["embedding"] is not None:
if not exists_in_index:
self._curr_batch.apply(record, not exists_in_index)
self._brute_force_index.upsert([record])
else:
logger.warning(f"Add of existing embedding ID: {id}")
elif op == Operation.UPSERT:
if record["operation_record"]["embedding"] is not None:
if record["record"]["embedding"] is not None:
self._curr_batch.apply(record, exists_in_index)
self._brute_force_index.upsert([record])
if len(self._curr_batch) >= self._batch_size:
Expand Down
14 changes: 6 additions & 8 deletions chromadb/test/ingest/test_producer_consumer.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,16 +135,14 @@ def assert_records_match(
"""Given a list of inserted and consumed records, make sure they match"""
assert len(consumed_records) == len(inserted_records)
for inserted, consumed in zip(inserted_records, consumed_records):
assert inserted["id"] == consumed["operation_record"]["id"]
assert inserted["operation"] == consumed["operation_record"]["operation"]
assert inserted["encoding"] == consumed["operation_record"]["encoding"]
assert inserted["metadata"] == consumed["operation_record"]["metadata"]
assert inserted["id"] == consumed["record"]["id"]
assert inserted["operation"] == consumed["record"]["operation"]
assert inserted["encoding"] == consumed["record"]["encoding"]
assert inserted["metadata"] == consumed["record"]["metadata"]

if inserted["embedding"] is not None:
assert consumed["operation_record"]["embedding"] is not None
assert_approx_equal(
inserted["embedding"], consumed["operation_record"]["embedding"]
)
assert consumed["record"]["embedding"] is not None
assert_approx_equal(inserted["embedding"], consumed["record"]["embedding"])


@pytest.mark.asyncio
Expand Down

0 comments on commit 27ba303

Please sign in to comment.