[ENH] Handle metadata deletes + fix bugs related to Updates/deletes in the metadata writer #2344

Merged
merged 4 commits on Jun 18, 2024
1 change: 1 addition & 0 deletions chromadb/proto/convert.py
@@ -189,6 +189,7 @@ def to_proto_metadata_update_value(
return proto.UpdateMetadataValue(int_value=value)
elif isinstance(value, float):
return proto.UpdateMetadataValue(float_value=value)
# None is used to delete the metadata key.
elif value is None:
return proto.UpdateMetadataValue()
else:
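For context on the proto change above: mapping `None` to an empty `UpdateMetadataValue` is what carries a per-key metadata delete from the client through to the metadata writer. A minimal usage sketch with the Python client (collection name and data are illustrative; the exact output shape may differ):

```python
import chromadb

client = chromadb.Client()
coll = client.create_collection(name="metadata_delete_demo")  # illustrative name

# One record with two metadata keys.
coll.add(ids=["a"], embeddings=[[0.0]], metadatas=[{"keep": 1, "drop": "x"}])

# A None value in an update is converted to an empty UpdateMetadataValue,
# which the metadata writer interprets as "delete this key".
coll.update(ids=["a"], metadatas=[{"drop": None}])

# After the update is applied, only the surviving key should remain.
print(coll.get(ids=["a"], include=["metadatas"])["metadatas"])
```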
17 changes: 9 additions & 8 deletions chromadb/test/property/test_embeddings.py
@@ -26,6 +26,7 @@
)
from collections import defaultdict
import chromadb.test.property.invariants as invariants
from chromadb.test.conftest import reset
import numpy as np


@@ -75,7 +76,7 @@ def __init__(self, api: ServerAPI):

@initialize(collection=collection_st) # type: ignore
def initialize(self, collection: strategies.Collection):
self.api.reset()
reset(self.api)
self.collection = self.api.create_collection(
name=collection.name,
metadata=collection.metadata,
@@ -306,7 +307,7 @@ def test_embeddings_state(caplog: pytest.LogCaptureFixture, api: ServerAPI) -> N


def test_multi_add(api: ServerAPI) -> None:
api.reset()
reset(api)
coll = api.create_collection(name="foo")
coll.add(ids=["a"], embeddings=[[0.0]])
assert coll.count() == 1
@@ -325,7 +326,7 @@ def test_multi_add(api: ServerAPI) -> None:


def test_dup_add(api: ServerAPI) -> None:
api.reset()
reset(api)
coll = api.create_collection(name="foo")
with pytest.raises(errors.DuplicateIDError):
coll.add(ids=["a", "a"], embeddings=[[0.0], [1.1]])
@@ -334,7 +335,7 @@ def test_dup_add(api: ServerAPI) -> None:


def test_query_without_add(api: ServerAPI) -> None:
api.reset()
reset(api)
coll = api.create_collection(name="foo")
fields: Include = ["documents", "metadatas", "embeddings", "distances"]
N = np.random.randint(1, 2000)
@@ -349,7 +350,7 @@ def test_query_without_add(api: ServerAPI) -> None:


def test_get_non_existent(api: ServerAPI) -> None:
api.reset()
reset(api)
coll = api.create_collection(name="foo")
result = coll.get(ids=["a"], include=["documents", "metadatas", "embeddings"])
assert len(result["ids"]) == 0
@@ -361,7 +362,7 @@ def test_get_non_existent(api: ServerAPI) -> None:
# TODO: Use SQL escaping correctly internally
@pytest.mark.xfail(reason="We don't properly escape SQL internally, causing problems")
def test_escape_chars_in_ids(api: ServerAPI) -> None:
api.reset()
reset(api)
id = "\x1f"
coll = api.create_collection(name="foo")
coll.add(ids=[id], embeddings=[[0.0]])
@@ -381,7 +382,7 @@ def test_escape_chars_in_ids(api: ServerAPI) -> None:
],
)
def test_delete_empty_fails(api: ServerAPI, kwargs: dict):
api.reset()
reset(api)
coll = api.create_collection(name="foo")
with pytest.raises(Exception) as e:
coll.delete(**kwargs)
@@ -404,7 +405,7 @@ def test_delete_empty_fails(api: ServerAPI, kwargs: dict):
],
)
def test_delete_success(api: ServerAPI, kwargs: dict):
api.reset()
reset(api)
coll = api.create_collection(name="foo")
# Should not raise
coll.delete(**kwargs)
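The test changes above swap direct `api.reset()` calls for a shared `reset` helper imported from `chromadb/test/conftest.py`. The real helper is not shown in this diff; the sketch below only illustrates the shape of such a wrapper (the body is hypothetical):

```python
from chromadb.api import ServerAPI


def reset(api: ServerAPI) -> None:
    """Hypothetical stand-in for chromadb.test.conftest.reset.

    Centralizing the reset call gives the suite one place to add retries,
    waits, or backend-specific handling instead of calling api.reset()
    directly in every test.
    """
    api.reset()
```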
2 changes: 1 addition & 1 deletion rust/worker/src/blockstore/arrow/blockfile.rs
@@ -523,7 +523,7 @@ mod tests {
log::config::{self, GrpcLogConfig},
segment::DataRecord,
storage::{local::LocalStorage, Storage},
types::{update_metdata_to_metdata, MetadataValue},
types::MetadataValue,
};
use arrow::array::Int32Array;
use proptest::prelude::*;
2 changes: 2 additions & 0 deletions rust/worker/src/blockstore/positional_posting_list_value.rs
@@ -113,6 +113,8 @@ impl PositionalPostingListBuilder {
return Err(PositionalPostingListBuilderError::DocIdDoesNotExist);
}

// Safe to unwrap here since this is only called when the token already
// exists in the document (i.e. from its second occurrence onwards).
self.positions.get_mut(&doc_id).unwrap().extend(positions);
Ok(())
}
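The "safe to unwrap" comment above relies on an invariant of the builder: positions are only extended for a doc id that has already been registered. A toy Python analogue of that invariant (method names are simplified, not the actual Rust API):

```python
class ToyPositionalPostingListBuilder:
    """Simplified analogue of PositionalPostingListBuilder."""

    def __init__(self) -> None:
        self.doc_ids: set[int] = set()
        self.positions: dict[int, list[int]] = {}

    def add_doc_id_and_positions(self, doc_id: int, positions: list[int]) -> None:
        # First occurrence of the token in this document.
        self.doc_ids.add(doc_id)
        self.positions[doc_id] = list(positions)

    def add_positions_for_existing_doc_id(self, doc_id: int, positions: list[int]) -> None:
        if doc_id not in self.doc_ids:
            raise KeyError("DocIdDoesNotExist")
        # Safe lookup: the membership check above guarantees the key exists,
        # mirroring the unwrap in the Rust builder.
        self.positions[doc_id].extend(positions)
```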
7 changes: 2 additions & 5 deletions rust/worker/src/execution/operators/merge_metadata_results.rs
@@ -6,11 +6,8 @@ use crate::{
record_segment::{RecordSegmentReader, RecordSegmentReaderCreationError},
LogMaterializer, LogMaterializerError,
},
types::{
update_metdata_to_metdata, LogRecord, Metadata, MetadataValueConversionError, Operation,
Segment,
},
utils::{merge_sorted_vecs_conjunction, merge_sorted_vecs_disjunction},
types::{LogRecord, Metadata, MetadataValueConversionError, Operation, Segment},
utils::merge_sorted_vecs_conjunction,
};
use async_trait::async_trait;
use std::{
29 changes: 29 additions & 0 deletions rust/worker/src/execution/orchestration/hnsw.rs
@@ -429,6 +429,28 @@ impl HnswQueryOrchestrator {
}
}

fn terminate_with_empty_response(&mut self, ctx: &ComponentContext<Self>) {
let result_channel = self
.result_channel
.take()
.expect("Invariant violation. Result channel is not set.");
let mut empty_resp = vec![];
for _ in 0..self.query_vectors.len() {
empty_resp.push(vec![]);
}
match result_channel.send(Ok(empty_resp)) {
Ok(_) => (),
Err(e) => {
// Log an error - this implies the listener was dropped
tracing::error!(
"[HnswQueryOrchestrator] Result channel dropped before sending empty response"
);
}
}
// Cancel the orchestrator so it stops processing
ctx.cancellation_token.cancel();
}

fn terminate_with_error(&mut self, error: Box<dyn ChromaError>, ctx: &ComponentContext<Self>) {
let result_channel = self
.result_channel
@@ -501,6 +523,13 @@ impl Component for HnswQueryOrchestrator {
}
};

// If the segment is uninitialized and the dimension is not set, we assume
// this is a query issued before any add, so return an empty response.
if hnsw_segment.file_path.len() <= 0 && collection.dimension.is_none() {
self.terminate_with_empty_response(ctx);
return;
}

// Validate that the collection has a dimension set. Downstream steps will rely on this
// so that they can unwrap the dimension without checking for None
if collection.dimension.is_none() {
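The new guard above changes observable behavior for a freshly created collection: a query issued before any add (no HNSW segment files, no dimension) should come back as one empty result list per query vector instead of an error. A rough end-to-end sketch with the Python client (names and the exact assertions are illustrative):

```python
import chromadb

client = chromadb.Client()
coll = client.create_collection(name="query_before_add_demo")  # illustrative name

# No add() has happened, so the collection has no dimension and the HNSW
# segment has no files; the orchestrator short-circuits with an empty
# response rather than failing dimension validation.
res = coll.query(query_embeddings=[[0.1, 0.2], [0.3, 0.4]], n_results=5)

assert len(res["ids"]) == 2                      # one entry per query vector
assert all(len(ids) == 0 for ids in res["ids"])  # each entry is empty
```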
111 changes: 84 additions & 27 deletions rust/worker/src/index/fulltext/types.rs
@@ -19,8 +19,6 @@ use super::tokenizer::ChromaTokenStream;

#[derive(Error, Debug)]
pub enum FullTextIndexError {
#[error("Multiple tokens found in frequencies blockfile")]
MultipleTokenFrequencies,
#[error("Empty value in positional posting list")]
EmptyValueInPositionalPostingList,
#[error("Invariant violation")]
@@ -37,9 +35,24 @@ impl ChromaError for FullTextIndexError {
}
}

pub(crate) struct UncommittedPostings {
// token -> {doc -> [start positions]}
positional_postings: HashMap<String, PositionalPostingListBuilder>,
// (token, doc) pairs that should be deleted from storage.
deleted_token_doc_pairs: Vec<(String, i32)>,
}

impl UncommittedPostings {
pub(crate) fn new() -> Self {
Self {
positional_postings: HashMap::new(),
deleted_token_doc_pairs: Vec::new(),
}
}
}

#[derive(Clone)]
pub(crate) struct FullTextIndexWriter<'me> {
// We use this to implement updates which require read-then-write semantics.
full_text_index_reader: Option<FullTextIndexReader<'me>>,
posting_lists_blockfile_writer: BlockfileWriter,
frequencies_blockfile_writer: BlockfileWriter,
@@ -49,11 +62,12 @@ pub(crate) struct FullTextIndexWriter<'me> {
// a lightweight lock instead. This is needed currently to
// keep holding the lock across an await point.
// term -> positional posting list builder for that term
uncommitted: Arc<tokio::sync::Mutex<HashMap<String, PositionalPostingListBuilder>>>,
uncommitted_postings: Arc<tokio::sync::Mutex<UncommittedPostings>>,
// TODO(Sanket): Move off this tokio::sync::mutex and use
// a lightweight lock instead. This is needed currently to
// keep holding the lock across an await point.
// Value of this map is a tuple because we also need to keep the old frequency
// Value of this map is a tuple (old freq and new freq)
// because we also need to keep the old frequency
// around. The reason is (token, freq) is the key in the blockfile hence
// when freq changes, we need to delete the old (token, freq) key.
uncommitted_frequencies: Arc<tokio::sync::Mutex<HashMap<String, (i32, i32)>>>,
@@ -71,7 +85,7 @@ impl<'me> FullTextIndexWriter<'me> {
posting_lists_blockfile_writer,
frequencies_blockfile_writer,
tokenizer: Arc::new(Mutex::new(tokenizer)),
uncommitted: Arc::new(tokio::sync::Mutex::new(HashMap::new())),
uncommitted_postings: Arc::new(tokio::sync::Mutex::new(UncommittedPostings::new())),
uncommitted_frequencies: Arc::new(tokio::sync::Mutex::new(HashMap::new())),
}
}
@@ -85,18 +99,21 @@ impl<'me> FullTextIndexWriter<'me> {
Some(_) => return Ok(()),
None => {
let frequency = match &self.full_text_index_reader {
// Readers are uninitialized until the first compaction finishes,
// so this can legitimately be None; it is not an error.
None => 0,
Some(reader) => match reader.get_frequencies_for_token(token).await {
Ok(frequency) => frequency,
// New token so start with frequency of 0.
Err(_) => 0,
},
};
uncommitted_frequencies
.insert(token.to_string(), (frequency as i32, frequency as i32));
}
}
let mut uncommitted = self.uncommitted.lock().await;
match uncommitted.get(token) {
let mut uncommitted_postings = self.uncommitted_postings.lock().await;
match uncommitted_postings.positional_postings.get(token) {
Some(_) => {
// This should never happen -- if uncommitted has the token, then
// uncommitted_frequencies should have had it as well.
@@ -108,9 +125,12 @@ impl<'me> FullTextIndexWriter<'me> {
None => {
let mut builder = PositionalPostingListBuilder::new();
let results = match &self.full_text_index_reader {
// Readers are uninitialized until the first compaction finishes,
// so this can legitimately be None; it is not an error.
None => vec![],
Some(reader) => match reader.get_all_results_for_token(token).await {
Ok(results) => results,
// New token so start with empty postings list.
Err(_) => vec![],
},
};
@@ -123,7 +143,9 @@ impl<'me> FullTextIndexWriter<'me> {
}
}
}
uncommitted.insert(token.to_string(), builder);
uncommitted_postings
.positional_postings
.insert(token.to_string(), builder);
}
}
Ok(())
@@ -145,11 +167,16 @@ impl<'me> FullTextIndexWriter<'me> {
self.populate_frequencies_and_posting_lists_from_previous_version(token.text.as_str())
.await?;
let mut uncommitted_frequencies = self.uncommitted_frequencies.lock().await;
// The entry should always exist because self.populate_frequencies_and_posting_lists_from_previous_version
// will have created it if this token is new to the system.
uncommitted_frequencies
.entry(token.text.to_string())
.and_modify(|e| (*e).0 += 1);
let mut uncommitted = self.uncommitted.lock().await;
let builder = uncommitted
let mut uncommitted_postings = self.uncommitted_postings.lock().await;
// For a new token, the uncommitted list will not contain any entry so insert
// an empty builder in that case.
let builder = uncommitted_postings
.positional_postings
.entry(token.text.to_string())
.or_insert(PositionalPostingListBuilder::new());

@@ -198,10 +225,19 @@ impl<'me> FullTextIndexWriter<'me> {
return Err(FullTextIndexError::InvariantViolation);
}
}
let mut uncommitted = self.uncommitted.lock().await;
match uncommitted.get_mut(token.text.as_str()) {
let mut uncommitted_postings = self.uncommitted_postings.lock().await;
match uncommitted_postings
.positional_postings
.get_mut(token.text.as_str())
{
Some(builder) => match builder.delete_doc_id(offset_id as i32) {
Ok(_) => {}
Ok(_) => {
// Track all the deleted (token, doc) pairs. This is needed
// to remove the old postings list for this pair from storage.
uncommitted_postings
.deleted_token_doc_pairs
.push((token.text.clone(), offset_id as i32));
}
Err(e) => {
// This is a fatal invariant violation: we've been asked to
// delete a document which doesn't appear in the positional posting list.
@@ -234,10 +270,24 @@ impl<'me> FullTextIndexWriter<'me> {
Ok(())
}

// TODO(Sanket): Handle document and metadata deletes.
pub async fn write_to_blockfiles(&mut self) -> Result<(), FullTextIndexError> {
let mut uncommitted = self.uncommitted.lock().await;
for (key, mut value) in uncommitted.drain() {
let mut uncommitted_postings = self.uncommitted_postings.lock().await;
// Delete (token, doc) pairs from the blockfile first. The ordering matters:
// we must delete before inserting the rebuilt postings lists, otherwise we
// could incorrectly delete postings lists that should be kept.
for (token, offset_id) in uncommitted_postings.deleted_token_doc_pairs.drain(..) {
match self
.posting_lists_blockfile_writer
.delete::<u32, &Int32Array>(token.as_str(), offset_id as u32)
.await
{
Ok(_) => {}
Err(e) => {
return Err(FullTextIndexError::BlockfileWriteError(e));
}
}
}
for (key, mut value) in uncommitted_postings.positional_postings.drain() {
let built_list = value.build();
for doc_id in built_list.doc_ids.iter() {
match doc_id {
@@ -275,15 +325,19 @@ impl<'me> FullTextIndexWriter<'me> {
}
}
// Insert the new frequency.
// Add only if the frequency is not zero; the frequency can drop to zero
// when documents are deleted.
// TODO we just have token -> frequency here. Should frequency be the key or should we use an empty key and make it the value?
match self
.frequencies_blockfile_writer
.set(key.as_str(), value.0 as u32, 0)
.await
{
Ok(_) => {}
Err(e) => {
return Err(FullTextIndexError::BlockfileWriteError(e));
if value.0 > 0 {
match self
.frequencies_blockfile_writer
.set(key.as_str(), value.0 as u32, 0)
.await
{
Ok(_) => {}
Err(e) => {
return Err(FullTextIndexError::BlockfileWriteError(e));
}
}
}
}
@@ -385,9 +439,12 @@ impl<'me> FullTextIndexReader<'me> {
return Ok(vec![]);
}
if res.len() > 1 {
return Err(FullTextIndexError::MultipleTokenFrequencies);
panic!("Invariant violation. Multiple frequency values found for a token.");
}
let res = res[0];
if res.1 <= 0 {
panic!("Invariant violation. Zero frequency token found.");
}
// Throw away the "value" since we store frequencies in the keys.
token_frequencies.push((token.text.to_string(), res.1));
}
@@ -509,7 +566,7 @@ impl<'me> FullTextIndexReader<'me> {
return Ok(0);
}
if res.len() > 1 {
return Err(FullTextIndexError::MultipleTokenFrequencies);
panic!("Invariant violation. Multiple frequency values found for a token.");
}
Ok(res[0].1)
}
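The write_to_blockfiles change above hinges on ordering: deletes for (token, doc) pairs must be applied to the postings blockfile before the rebuilt postings are inserted, because an updated document reuses the same (token, offset_id) key. A toy Python model of the failure mode (a plain dict standing in for the blockfile; not the actual Rust writer API):

```python
# Key: (token, offset_id) -> value: list of token positions in that document.
def commit(store: dict, deletes: list, inserts: dict, delete_first: bool) -> dict:
    store = dict(store)
    if delete_first:
        for key in deletes:
            store.pop(key, None)
    store.update(inserts)
    if not delete_first:
        for key in deletes:
            store.pop(key, None)
    return store


old = {("hello", 1): [0, 7]}   # existing postings for doc 1
deletes = [("hello", 1)]       # doc 1 was updated, so its old entry must go
inserts = {("hello", 1): [3]}  # rebuilt postings reuse the same key

print(commit(old, deletes, inserts, delete_first=True))   # {('hello', 1): [3]} -- correct
print(commit(old, deletes, inserts, delete_first=False))  # {} -- fresh entry wrongly removed
```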