[ENH] Handle metadata deletes + fix bugs related to Updates/deletes in the metadata writer (#2344)

## Description of changes

*Summarize the changes made by this PR.*
 - Improvements & Bug fixes
    - Handles metadata deletes.
    - The full text writer was adding a (token, freq) pair even when the frequency is 0. Fixes this (see the sketch below).
    - The full text writer was not removing the postings lists of documents that have been deleted. Fixes this.
    - Fix for test_query_without_add.
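
For context, the sketch below shows the frequency bookkeeping the second bullet refers to, using hypothetical simplified types rather than the actual writer: frequencies are tracked per token as an (old, new) pair, a document delete decrements the new count, and at commit time a pair whose count has dropped to 0 is skipped instead of being written.

```rust
use std::collections::HashMap;

/// Simplified stand-in for the writer's uncommitted frequency map:
/// token -> (old_freq, new_freq). The old frequency is kept because
/// (token, freq) is the blockfile key, so the stale key must be
/// deleted whenever the frequency changes.
struct FrequencyTracker {
    uncommitted: HashMap<String, (i32, i32)>,
}

impl FrequencyTracker {
    fn delete_token_occurrence(&mut self, token: &str) {
        if let Some(freqs) = self.uncommitted.get_mut(token) {
            freqs.1 -= 1; // one fewer document contains this token
        }
    }

    /// The (token, freq) pairs that should actually be written on commit.
    fn pairs_to_write(&self) -> Vec<(String, u32)> {
        self.uncommitted
            .iter()
            // The fix: deleting the last document containing a token
            // drives its frequency to 0, and a (token, 0) key must be
            // skipped rather than written.
            .filter(|(_, freqs)| freqs.1 > 0)
            .map(|(token, freqs)| (token.clone(), freqs.1 as u32))
            .collect()
    }
}

fn main() {
    let mut tracker = FrequencyTracker {
        uncommitted: HashMap::from([("hello".to_string(), (1, 1))]),
    };
    // Delete the only document containing "hello": nothing gets written.
    tracker.delete_token_occurrence("hello");
    assert!(tracker.pairs_to_write().is_empty());
}
```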

## Test plan
- [x] Tests pass locally with `pytest` for python, `yarn test` for js,
`cargo test` for rust

## Documentation Changes
None
sanketkedia authored Jun 18, 2024
1 parent f4d45f0 commit 7684d61
Showing 11 changed files with 657 additions and 369 deletions.
1 change: 1 addition & 0 deletions chromadb/proto/convert.py
@@ -189,6 +189,7 @@ def to_proto_metadata_update_value(
return proto.UpdateMetadataValue(int_value=value)
elif isinstance(value, float):
return proto.UpdateMetadataValue(float_value=value)
# None is used to delete the metadata key.
elif value is None:
return proto.UpdateMetadataValue()
else:
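A minimal sketch of the convention the new comment documents, as seen from the consuming side (the types and function here are hypothetical stand-ins, not the actual proto handling in the worker): a metadata value that carries no variant on the wire is interpreted as a request to delete the key.

```rust
/// Hypothetical stand-in for proto::UpdateMetadataValue: when no
/// variant is set, the message is "empty", which mirrors sending
/// `proto.UpdateMetadataValue()` for a `None` value in Python.
enum UpdateMetadataValue {
    Int(i64),
    Float(f64),
    Str(String),
    /// No value set on the wire.
    Unset,
}

enum MetadataOp {
    Set(String), // placeholder: store the rendered value
    DeleteKey,
}

fn to_op(value: UpdateMetadataValue) -> MetadataOp {
    match value {
        UpdateMetadataValue::Int(i) => MetadataOp::Set(i.to_string()),
        UpdateMetadataValue::Float(f) => MetadataOp::Set(f.to_string()),
        UpdateMetadataValue::Str(s) => MetadataOp::Set(s),
        // An empty value carries nothing: the key is being deleted.
        UpdateMetadataValue::Unset => MetadataOp::DeleteKey,
    }
}

fn main() {
    // An empty value on the wire means "delete this metadata key".
    assert!(matches!(
        to_op(UpdateMetadataValue::Unset),
        MetadataOp::DeleteKey
    ));
}
```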
17 changes: 9 additions & 8 deletions chromadb/test/property/test_embeddings.py
@@ -26,6 +26,7 @@
)
from collections import defaultdict
import chromadb.test.property.invariants as invariants
from chromadb.test.conftest import reset
import numpy as np


@@ -75,7 +76,7 @@ def __init__(self, api: ServerAPI):

@initialize(collection=collection_st) # type: ignore
def initialize(self, collection: strategies.Collection):
self.api.reset()
reset(self.api)
self.collection = self.api.create_collection(
name=collection.name,
metadata=collection.metadata,
@@ -306,7 +307,7 @@ def test_embeddings_state(caplog: pytest.LogCaptureFixture, api: ServerAPI) -> None:


def test_multi_add(api: ServerAPI) -> None:
api.reset()
reset(api)
coll = api.create_collection(name="foo")
coll.add(ids=["a"], embeddings=[[0.0]])
assert coll.count() == 1
@@ -325,7 +326,7 @@ def test_multi_add(api: ServerAPI) -> None:


def test_dup_add(api: ServerAPI) -> None:
api.reset()
reset(api)
coll = api.create_collection(name="foo")
with pytest.raises(errors.DuplicateIDError):
coll.add(ids=["a", "a"], embeddings=[[0.0], [1.1]])
@@ -334,7 +335,7 @@ def test_dup_add(api: ServerAPI) -> None:


def test_query_without_add(api: ServerAPI) -> None:
api.reset()
reset(api)
coll = api.create_collection(name="foo")
fields: Include = ["documents", "metadatas", "embeddings", "distances"]
N = np.random.randint(1, 2000)
@@ -349,7 +350,7 @@ def test_query_without_add(api: ServerAPI) -> None:


def test_get_non_existent(api: ServerAPI) -> None:
api.reset()
reset(api)
coll = api.create_collection(name="foo")
result = coll.get(ids=["a"], include=["documents", "metadatas", "embeddings"])
assert len(result["ids"]) == 0
@@ -361,7 +362,7 @@ def test_get_non_existent(api: ServerAPI) -> None:
# TODO: Use SQL escaping correctly internally
@pytest.mark.xfail(reason="We don't properly escape SQL internally, causing problems")
def test_escape_chars_in_ids(api: ServerAPI) -> None:
api.reset()
reset(api)
id = "\x1f"
coll = api.create_collection(name="foo")
coll.add(ids=[id], embeddings=[[0.0]])
@@ -381,7 +382,7 @@ def test_escape_chars_in_ids(api: ServerAPI) -> None:
],
)
def test_delete_empty_fails(api: ServerAPI, kwargs: dict):
api.reset()
reset(api)
coll = api.create_collection(name="foo")
with pytest.raises(Exception) as e:
coll.delete(**kwargs)
@@ -404,7 +405,7 @@ def test_delete_empty_fails(api: ServerAPI, kwargs: dict):
],
)
def test_delete_success(api: ServerAPI, kwargs: dict):
api.reset()
reset(api)
coll = api.create_collection(name="foo")
# Should not raise
coll.delete(**kwargs)
2 changes: 1 addition & 1 deletion rust/worker/src/blockstore/arrow/blockfile.rs
@@ -523,7 +523,7 @@ mod tests {
log::config::{self, GrpcLogConfig},
segment::DataRecord,
storage::{local::LocalStorage, Storage},
types::{update_metdata_to_metdata, MetadataValue},
types::MetadataValue,
};
use arrow::array::Int32Array;
use proptest::prelude::*;
2 changes: 2 additions & 0 deletions rust/worker/src/blockstore/positional_posting_list_value.rs
@@ -113,6 +113,8 @@ impl PositionalPostingListBuilder {
return Err(PositionalPostingListBuilderError::DocIdDoesNotExist);
}

// Safe to unwrap here since this is only called from the second
// time onwards that a token appears in the document.
self.positions.get_mut(&doc_id).unwrap().extend(positions);
Ok(())
}
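A minimal sketch of the invariant this new comment relies on, with a hypothetical simplified builder: the first occurrence of a doc id creates its entry in the positions map, so later occurrences can unwrap the entry safely.

```rust
use std::collections::{HashMap, HashSet};

/// Hypothetical simplified builder. The real
/// PositionalPostingListBuilder keeps more state; this only shows why
/// the unwrap on the "extend" path cannot panic.
struct Builder {
    doc_ids: HashSet<i32>,
    positions: HashMap<i32, Vec<u32>>,
}

impl Builder {
    fn add_doc_id_and_positions(&mut self, doc_id: i32, positions: Vec<u32>) {
        if self.doc_ids.insert(doc_id) {
            // First occurrence of this doc id: create the entry.
            self.positions.insert(doc_id, positions);
        } else {
            // Second or later occurrence: the entry must already exist,
            // so the unwrap is safe.
            self.positions.get_mut(&doc_id).unwrap().extend(positions);
        }
    }
}

fn main() {
    let mut b = Builder {
        doc_ids: HashSet::new(),
        positions: HashMap::new(),
    };
    b.add_doc_id_and_positions(1, vec![0, 4]);
    b.add_doc_id_and_positions(1, vec![9]); // takes the unwrap path safely
    assert_eq!(b.positions[&1], vec![0, 4, 9]);
}
```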
7 changes: 2 additions & 5 deletions rust/worker/src/execution/operators/merge_metadata_results.rs
@@ -6,11 +6,8 @@ use crate::{
record_segment::{RecordSegmentReader, RecordSegmentReaderCreationError},
LogMaterializer, LogMaterializerError,
},
types::{
update_metdata_to_metdata, LogRecord, Metadata, MetadataValueConversionError, Operation,
Segment,
},
utils::{merge_sorted_vecs_conjunction, merge_sorted_vecs_disjunction},
types::{LogRecord, Metadata, MetadataValueConversionError, Operation, Segment},
utils::merge_sorted_vecs_conjunction,
};
use async_trait::async_trait;
use std::{
29 changes: 29 additions & 0 deletions rust/worker/src/execution/orchestration/hnsw.rs
@@ -429,6 +429,28 @@ impl HnswQueryOrchestrator {
}
}

fn terminate_with_empty_response(&mut self, ctx: &ComponentContext<Self>) {
let result_channel = self
.result_channel
.take()
.expect("Invariant violation. Result channel is not set.");
let mut empty_resp = vec![];
for _ in 0..self.query_vectors.len() {
empty_resp.push(vec![]);
}
match result_channel.send(Ok(empty_resp)) {
Ok(_) => (),
Err(e) => {
// Log an error - this implies the listener was dropped
tracing::error!(
"[HnswQueryOrchestrator] Result channel dropped before sending empty response"
);
}
}
// Cancel the orchestrator so it stops processing
ctx.cancellation_token.cancel();
}

fn terminate_with_error(&mut self, error: Box<dyn ChromaError>, ctx: &ComponentContext<Self>) {
let result_channel = self
.result_channel
@@ -501,6 +523,13 @@ impl Component for HnswQueryOrchestrator {
}
};

// If segment is uninitialized and dimension is not set then we assume
// that this is a query before any add so return empty response.
if hnsw_segment.file_path.len() <= 0 && collection.dimension.is_none() {
self.terminate_with_empty_response(ctx);
return;
}

// Validate that the collection has a dimension set. Downstream steps will rely on this
// so that they can unwrap the dimension without checking for None
if collection.dimension.is_none() {
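A minimal sketch of the new short-circuit behavior (hypothetical signatures, not the orchestrator's actual API): when the HNSW segment has no files and the collection has no dimension, nothing has ever been added, so the query answers with one empty result list per query vector instead of flowing into stages that assume a dimension is set.

```rust
/// One (possibly empty) result list per query vector, mirroring the
/// shape terminate_with_empty_response sends on the result channel.
fn empty_response<T>(num_query_vectors: usize) -> Vec<Vec<T>> {
    (0..num_query_vectors).map(|_| Vec::new()).collect()
}

fn should_short_circuit(hnsw_file_count: usize, collection_dimension: Option<usize>) -> bool {
    // Uninitialized segment (no files flushed yet) plus no dimension
    // means no add has ever happened: the query is valid but empty.
    hnsw_file_count == 0 && collection_dimension.is_none()
}

fn main() {
    if should_short_circuit(0, None) {
        let resp: Vec<Vec<f32>> = empty_response(3);
        assert_eq!(resp.len(), 3);
        assert!(resp.iter().all(|r| r.is_empty()));
    }
}
```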
111 changes: 84 additions & 27 deletions rust/worker/src/index/fulltext/types.rs
@@ -19,8 +19,6 @@ use super::tokenizer::ChromaTokenStream;

#[derive(Error, Debug)]
pub enum FullTextIndexError {
#[error("Multiple tokens found in frequencies blockfile")]
MultipleTokenFrequencies,
#[error("Empty value in positional posting list")]
EmptyValueInPositionalPostingList,
#[error("Invariant violation")]
@@ -37,9 +35,24 @@ impl ChromaError for FullTextIndexError {
}
}

pub(crate) struct UncommittedPostings {
// token -> {doc -> [start positions]}
positional_postings: HashMap<String, PositionalPostingListBuilder>,
// (token, doc) pairs that should be deleted from storage.
deleted_token_doc_pairs: Vec<(String, i32)>,
}

impl UncommittedPostings {
pub(crate) fn new() -> Self {
Self {
positional_postings: HashMap::new(),
deleted_token_doc_pairs: Vec::new(),
}
}
}

#[derive(Clone)]
pub(crate) struct FullTextIndexWriter<'me> {
// We use this to implement updates which require read-then-write semantics.
full_text_index_reader: Option<FullTextIndexReader<'me>>,
posting_lists_blockfile_writer: BlockfileWriter,
frequencies_blockfile_writer: BlockfileWriter,
@@ -49,11 +62,12 @@ pub(crate) struct FullTextIndexWriter<'me> {
// a lightweight lock instead. This is needed currently to
// keep holding the lock across an await point.
// term -> positional posting list builder for that term
uncommitted: Arc<tokio::sync::Mutex<HashMap<String, PositionalPostingListBuilder>>>,
uncommitted_postings: Arc<tokio::sync::Mutex<UncommittedPostings>>,
// TODO(Sanket): Move off this tokio::sync::mutex and use
// a lightweight lock instead. This is needed currently to
// keep holding the lock across an await point.
// Value of this map is a tuple because we also need to keep the old frequency
// Value of this map is a tuple (old freq and new freq)
// because we also need to keep the old frequency
// around. The reason is (token, freq) is the key in the blockfile hence
// when freq changes, we need to delete the old (token, freq) key.
uncommitted_frequencies: Arc<tokio::sync::Mutex<HashMap<String, (i32, i32)>>>,
@@ -71,7 +85,7 @@ impl<'me> FullTextIndexWriter<'me> {
posting_lists_blockfile_writer,
frequencies_blockfile_writer,
tokenizer: Arc::new(Mutex::new(tokenizer)),
uncommitted: Arc::new(tokio::sync::Mutex::new(HashMap::new())),
uncommitted_postings: Arc::new(tokio::sync::Mutex::new(UncommittedPostings::new())),
uncommitted_frequencies: Arc::new(tokio::sync::Mutex::new(HashMap::new())),
}
}
@@ -85,18 +99,21 @@ impl<'me> FullTextIndexWriter<'me> {
Some(_) => return Ok(()),
None => {
let frequency = match &self.full_text_index_reader {
// Readers are uninitialized until the first compaction finishes
// so there is a case when this is none hence not an error.
None => 0,
Some(reader) => match reader.get_frequencies_for_token(token).await {
Ok(frequency) => frequency,
// New token so start with frequency of 0.
Err(_) => 0,
},
};
uncommitted_frequencies
.insert(token.to_string(), (frequency as i32, frequency as i32));
}
}
let mut uncommitted = self.uncommitted.lock().await;
match uncommitted.get(token) {
let mut uncommitted_postings = self.uncommitted_postings.lock().await;
match uncommitted_postings.positional_postings.get(token) {
Some(_) => {
// This should never happen -- if uncommitted has the token, then
// uncommitted_frequencies should have had it as well.
@@ -108,9 +125,12 @@ impl<'me> FullTextIndexWriter<'me> {
None => {
let mut builder = PositionalPostingListBuilder::new();
let results = match &self.full_text_index_reader {
// Readers are uninitialized until the first compaction finishes
// so there is a case when this is none hence not an error.
None => vec![],
Some(reader) => match reader.get_all_results_for_token(token).await {
Ok(results) => results,
// New token so start with empty postings list.
Err(_) => vec![],
},
};
@@ -123,7 +143,9 @@ impl<'me> FullTextIndexWriter<'me> {
}
}
}
uncommitted.insert(token.to_string(), builder);
uncommitted_postings
.positional_postings
.insert(token.to_string(), builder);
}
}
Ok(())
@@ -145,11 +167,16 @@ impl<'me> FullTextIndexWriter<'me> {
self.populate_frequencies_and_posting_lists_from_previous_version(token.text.as_str())
.await?;
let mut uncommitted_frequencies = self.uncommitted_frequencies.lock().await;
// The entry should always exist because self.populate_frequencies_and_posting_lists_from_previous_version
// will have created it if this token is new to the system.
uncommitted_frequencies
.entry(token.text.to_string())
.and_modify(|e| (*e).0 += 1);
let mut uncommitted = self.uncommitted.lock().await;
let builder = uncommitted
let mut uncommitted_postings = self.uncommitted_postings.lock().await;
// For a new token, the uncommitted list will not contain any entry so insert
// an empty builder in that case.
let builder = uncommitted_postings
.positional_postings
.entry(token.text.to_string())
.or_insert(PositionalPostingListBuilder::new());

Expand Down Expand Up @@ -198,10 +225,19 @@ impl<'me> FullTextIndexWriter<'me> {
return Err(FullTextIndexError::InvariantViolation);
}
}
let mut uncommitted = self.uncommitted.lock().await;
match uncommitted.get_mut(token.text.as_str()) {
let mut uncommitted_postings = self.uncommitted_postings.lock().await;
match uncommitted_postings
.positional_postings
.get_mut(token.text.as_str())
{
Some(builder) => match builder.delete_doc_id(offset_id as i32) {
Ok(_) => {}
Ok(_) => {
// Track all the deleted (token, doc) pairs. This is needed
// to remove the old postings list for this pair from storage.
uncommitted_postings
.deleted_token_doc_pairs
.push((token.text.clone(), offset_id as i32));
}
Err(e) => {
// This is a fatal invariant violation: we've been asked to
// delete a document which doesn't appear in the positional posting list.
@@ -234,10 +270,24 @@ impl<'me> FullTextIndexWriter<'me> {
Ok(())
}

// TODO(Sanket): Handle document and metadata deletes.
pub async fn write_to_blockfiles(&mut self) -> Result<(), FullTextIndexError> {
let mut uncommitted = self.uncommitted.lock().await;
for (key, mut value) in uncommitted.drain() {
let mut uncommitted_postings = self.uncommitted_postings.lock().await;
// Delete (token, doc) pairs from blockfile first. Note that the ordering is
// important here i.e. we need to delete before inserting the new postings
// list otherwise we could incorrectly delete posting lists that shouldn't be deleted.
for (token, offset_id) in uncommitted_postings.deleted_token_doc_pairs.drain(..) {
match self
.posting_lists_blockfile_writer
.delete::<u32, &Int32Array>(token.as_str(), offset_id as u32)
.await
{
Ok(_) => {}
Err(e) => {
return Err(FullTextIndexError::BlockfileWriteError(e));
}
}
}
for (key, mut value) in uncommitted_postings.positional_postings.drain() {
let built_list = value.build();
for doc_id in built_list.doc_ids.iter() {
match doc_id {
@@ -275,15 +325,19 @@
}
}
// Insert the new frequency.
// Add only if the frequency is not zero. This can happen in case of document
// deletes.
// TODO we just have token -> frequency here. Should frequency be the key or should we use an empty key and make it the value?
match self
.frequencies_blockfile_writer
.set(key.as_str(), value.0 as u32, 0)
.await
{
Ok(_) => {}
Err(e) => {
return Err(FullTextIndexError::BlockfileWriteError(e));
if value.0 > 0 {
match self
.frequencies_blockfile_writer
.set(key.as_str(), value.0 as u32, 0)
.await
{
Ok(_) => {}
Err(e) => {
return Err(FullTextIndexError::BlockfileWriteError(e));
}
}
}
}
@@ -385,9 +439,12 @@ impl<'me> FullTextIndexReader<'me> {
return Ok(vec![]);
}
if res.len() > 1 {
return Err(FullTextIndexError::MultipleTokenFrequencies);
panic!("Invariant violation. Multiple frequency values found for a token.");
}
let res = res[0];
if res.1 <= 0 {
panic!("Invariant violation. Zero frequency token found.");
}
// Throw away the "value" since we store frequencies in the keys.
token_frequencies.push((token.text.to_string(), res.1));
}
@@ -509,7 +566,7 @@ impl<'me> FullTextIndexReader<'me> {
return Ok(0);
}
if res.len() > 1 {
return Err(FullTextIndexError::MultipleTokenFrequencies);
panic!("Invariant violation. Multiple frequency values found for a token.");
}
Ok(res[0].1)
}
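A minimal sketch of the commit ordering introduced in write_to_blockfiles, against a toy in-memory store (the types here are hypothetical, not the actual blockfile writer): deletes of stale (token, doc) pairs run first, then the new postings are written.

```rust
use std::collections::HashMap;

/// Toy stand-in for the posting lists blockfile, keyed by
/// (token, doc_id) with a list of positions as the value.
struct ToyBlockfile {
    postings: HashMap<(String, u32), Vec<i32>>,
}

fn commit(
    store: &mut ToyBlockfile,
    deleted_pairs: Vec<(String, u32)>,
    new_postings: Vec<((String, u32), Vec<i32>)>,
) {
    // 1. Deletes first: these refer to stale entries from the previous
    //    version of the index.
    for pair in deleted_pairs {
        store.postings.remove(&pair);
    }
    // 2. Then inserts the up-to-date posting lists.
    for (pair, positions) in new_postings {
        store.postings.insert(pair, positions);
    }
}

fn main() {
    let mut store = ToyBlockfile {
        postings: HashMap::new(),
    };
    store.postings.insert(("hello".into(), 7), vec![0]);
    // Document 7 was deleted and re-added with "hello" at a new
    // position within the same batch: the stale pair is removed first,
    // then the new list is written.
    commit(
        &mut store,
        vec![("hello".into(), 7)],
        vec![(("hello".into(), 7), vec![3])],
    );
    assert_eq!(store.postings[&("hello".into(), 7)], vec![3]);
}
```

Running the deletes second would wrongly remove the posting list of any document that was deleted and re-added within the same batch, which is why the real writer drains deleted_token_doc_pairs before draining positional_postings.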
(4 additional changed files not shown.)
