Skip to content

Commit

Permalink
feat: remove tombstones from the IVF remapping process by shrinking t…
Browse files Browse the repository at this point in the history
…he index (#1397)

I also simplify the remap tasks a bit by removing some unnecessary
traits.
  • Loading branch information
westonpace authored Oct 12, 2023
1 parent da14c0f commit f291c74
Show file tree
Hide file tree
Showing 4 changed files with 95 additions and 147 deletions.
1 change: 1 addition & 0 deletions rust/lance/src/index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ async fn remap_index(
&field.name,
index_id,
&new_id,
matched,
row_id_map,
)
.await?;
Expand Down
8 changes: 8 additions & 0 deletions rust/lance/src/index/vector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,7 @@ pub(crate) async fn remap_vector_index(
column: &str,
old_uuid: &Uuid,
new_uuid: &Uuid,
old_metadata: &crate::format::Index,
mapping: &HashMap<u64, Option<u64>>,
) -> Result<()> {
let old_index = open_index(dataset.clone(), column, &old_uuid.to_string()).await?;
Expand All @@ -290,8 +291,15 @@ pub(crate) async fn remap_vector_index(
dataset.as_ref(),
&old_uuid.to_string(),
&new_uuid.to_string(),
old_metadata.dataset_version,
ivf_index,
mapping,
old_metadata.name.clone(),
column.to_string(),
// We can safely assume there are no transforms today. We assert above that the
// top stage is IVF and IVF does not support transforms between IVF and PQ. This
// will be fixed in the future.
vec![],
)
.await?;
Ok(())
Expand Down
103 changes: 57 additions & 46 deletions rust/lance/src/index/vector/ivf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ use log::info;
use rand::{rngs::SmallRng, SeedableRng};
use serde::Serialize;
use snafu::{location, Location};
use tokio::io::AsyncWriteExt;
use tracing::{instrument, span, Level};

#[cfg(feature = "opq")]
Expand Down Expand Up @@ -737,21 +736,6 @@ pub async fn build_ivf_pq_index(
.await
}

#[async_trait]
trait RemapWriteTask: Send {
async fn write(self: Box<Self>, writer: &mut ObjectWriter) -> Result<()>;
}

#[async_trait]
trait RemapLoadTask: Send {
async fn load_and_remap(
self: Box<Self>,
reader: &dyn ObjectReader,
index: &IVFIndex,
mapping: &HashMap<u64, Option<u64>>,
) -> Result<Box<dyn RemapWriteTask>>;
}

struct RemapPageTask {
offset: usize,
length: u32,
Expand All @@ -768,14 +752,13 @@ impl RemapPageTask {
}
}

#[async_trait]
impl RemapLoadTask for RemapPageTask {
impl RemapPageTask {
async fn load_and_remap(
mut self: Box<Self>,
mut self,
reader: &dyn ObjectReader,
index: &IVFIndex,
mapping: &HashMap<u64, Option<u64>>,
) -> Result<Box<dyn RemapWriteTask>> {
) -> Result<Self> {
let mut page = index
.sub_index
.load(reader, self.offset, self.length as usize)
Expand All @@ -784,16 +767,16 @@ impl RemapLoadTask for RemapPageTask {
self.page = Some(page);
Ok(self)
}
}

#[async_trait]
impl RemapWriteTask for RemapPageTask {
async fn write(self: Box<Self>, writer: &mut ObjectWriter) -> Result<()> {
async fn write(self, writer: &mut ObjectWriter, ivf: &mut Ivf) -> Result<()> {
let page = self.page.as_ref().expect("Load was not called");
let page: &PQIndex = page
.as_any()
.downcast_ref()
.expect("Generic index writing not supported yet");
ivf.offsets.push(writer.tell());
ivf.lengths
.push(page.row_ids.as_ref().unwrap().len() as u32);
writer
.write_plain_encoded_array(&[page.code.as_ref().unwrap().as_ref()])
.await?;
Expand All @@ -804,25 +787,27 @@ impl RemapWriteTask for RemapPageTask {
}
}

fn generate_remap_tasks(
offsets: &Vec<usize>,
lengths: &[u32],
) -> Result<Vec<Box<dyn RemapLoadTask>>> {
let mut tasks: Vec<Box<dyn RemapLoadTask>> = Vec::with_capacity(offsets.len() * 2 + 1);
fn generate_remap_tasks(offsets: &Vec<usize>, lengths: &[u32]) -> Result<Vec<RemapPageTask>> {
let mut tasks: Vec<RemapPageTask> = Vec::with_capacity(offsets.len() * 2 + 1);

for (offset, length) in offsets.iter().zip(lengths.iter()) {
tasks.push(Box::new(RemapPageTask::new(*offset, *length)));
tasks.push(RemapPageTask::new(*offset, *length));
}

Ok(tasks)
}

#[allow(clippy::too_many_arguments)]
pub(crate) async fn remap_index_file(
dataset: &Dataset,
old_uuid: &str,
new_uuid: &str,
old_version: u64,
index: &IVFIndex,
mapping: &HashMap<u64, Option<u64>>,
name: String,
column: String,
transforms: Vec<pb::Transform>,
) -> Result<()> {
let object_store = dataset.object_store();
let old_path = dataset.indices_dir().child(old_uuid).child(INDEX_FILE_NAME);
Expand All @@ -834,26 +819,42 @@ pub(crate) async fn remap_index_file(
let tasks = generate_remap_tasks(&index.ivf.offsets, &index.ivf.lengths)?;

let mut task_stream = stream::iter(tasks.into_iter())
.then(|task| task.load_and_remap(reader.as_ref(), index, mapping));
// The below doesn't work today because of a bogus higher-rank lifetime error
// .map(|task| task.load_and_remap(reader.as_ref(), index, mapping))
// .buffered(num_cpus::get());
.map(|task| task.load_and_remap(reader.as_ref(), index, mapping))
.buffered(num_cpus::get());

let mut ivf = Ivf {
centroids: index.ivf.centroids.clone(),
offsets: Vec::with_capacity(index.ivf.offsets.len()),
lengths: Vec::with_capacity(index.ivf.lengths.len()),
};
while let Some(write_task) = task_stream.try_next().await? {
write_task.write(&mut writer).await?;
write_task.write(&mut writer, &mut ivf).await?;
}

// The index doesn't currently write down where the pages end. So we wait
// for the remapping to finish and then see how many bytes we copied. This
// will tell us since the remapped file should be the same size. Now we just
// copy the remaining bytes
let pos = writer.tell();
let file_size = reader.size().await?;
let remaining = reader.get_range(pos..file_size).await?;
let pq_sub_index = index
.sub_index
.as_any()
.downcast_ref::<PQIndex>()
.ok_or_else(|| Error::NotSupported {
source: "Remapping a non-pq sub-index".into(),
})?;

writer.write_all(&remaining).await?;
let metadata = IvfPQIndexMetadata {
name,
column,
dimension: index.ivf.dimension() as u32,
dataset_version: old_version,
ivf,
metric_type: index.metric_type,
pq: pq_sub_index.pq.clone(),
transforms,
};

let metadata = pb::Index::try_from(&metadata)?;
let pos = writer.write_protobuf(&metadata).await?;
writer.write_magics(pos).await?;
writer.shutdown().await?;

Ok(())
}

Expand Down Expand Up @@ -1360,9 +1361,19 @@ mod tests {
let new_uuid = Uuid::new_v4();
let new_uuid_str = new_uuid.to_string();

remap_index_file(&dataset, &uuid_str, &new_uuid_str, ivf_index, &mapping)
.await
.unwrap();
remap_index_file(
&dataset,
&uuid_str,
&new_uuid_str,
dataset.version().version,
ivf_index,
&mapping,
INDEX_NAME.to_string(),
WellKnownIvfPqData::COLUMN.to_string(),
vec![],
)
.await
.unwrap();

let remapped = open_index(
dataset.clone(),
Expand Down
130 changes: 29 additions & 101 deletions rust/lance/src/index/vector/pq.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,17 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use std::any::Any;
use std::collections::HashMap;
use std::sync::Arc;
use std::{any::Any, collections::BinaryHeap};

use arrow::datatypes::Float32Type;
use arrow_arith::aggregate::min;
use arrow_array::{
builder::Float32Builder, cast::as_primitive_array, Array, ArrayRef, FixedSizeListArray,
Float32Array, RecordBatch, UInt64Array, UInt8Array,
};
use arrow_ord::sort::sort_to_indices;
use arrow_schema::{DataType, Field as ArrowField, Schema as ArrowSchema};
use arrow_select::take::take;
use async_trait::async_trait;
Expand All @@ -35,12 +36,12 @@ use rand::SeedableRng;
use serde::Serialize;

use super::{MetricType, Query, VectorIndex};
use crate::arrow::*;
use crate::dataset::ROW_ID;
use crate::index::prefilter::PreFilter;
use crate::index::Index;
use crate::index::{pb, vector::kmeans::train_kmeans, vector::DIST_COL};
use crate::io::object_reader::{read_fixed_stride_array, ObjectReader};
use crate::{arrow::*, format::RowAddress};
use crate::{Error, Result};

/// Product Quantization Index.
Expand Down Expand Up @@ -250,69 +251,6 @@ impl Index for PQIndex {
}
}

// Helper struct for zipped distance + row id that sorts by distance
struct DistanceRowId {
distance: f32,
row_id: u64,
}

impl DistanceRowId {
fn new(distance: f32, row_id: u64) -> Self {
Self { distance, row_id }
}
}

impl DistanceRowId {
fn distance(&self) -> f32 {
self.distance
}
fn row_id(&self) -> u64 {
self.row_id
}
}

impl PartialEq for DistanceRowId {
fn eq(&self, other: &Self) -> bool {
// Note: we don't use == here, which is only PartialEq, since Ord
// requires Eq and so we use total_cmp which does satisfy Eq. That
// being said, I don't know if this matters. So feel free to just go
// back to == if this method is affecting perf
matches!(
self.distance().total_cmp(&other.distance()),
std::cmp::Ordering::Equal
)
}
}

impl Eq for DistanceRowId {}

impl PartialOrd for DistanceRowId {
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
Some(self.cmp(other))
}
}

impl Ord for DistanceRowId {
// Note: this implementation of cmp is reversed. This is because BinaryHeap gives us
// the max items and we want the min items
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
other.distance().total_cmp(&self.distance())
}
}

// Helper function to take the top K items from an iterator in O(N) + K*log(N) time
fn top_k<I: Iterator>(iter: I, k: usize) -> Vec<I::Item>
where
I::Item: Ord,
{
let mut heap = BinaryHeap::from_iter(iter);
let mut items = Vec::with_capacity(k);
while !heap.is_empty() && items.len() < k {
items.push(heap.pop().unwrap());
}
items
}

#[async_trait]
impl VectorIndex for PQIndex {
/// Search top-k nearest neighbors for `key` within one PQ partition.
Expand Down Expand Up @@ -341,28 +279,10 @@ impl VectorIndex for PQIndex {

debug_assert_eq!(distances.len(), row_ids.len());

// Remove any tombstone rows from consideration when sorting
let distance_ids = distances
.as_any()
.downcast_ref::<Float32Array>()
.unwrap()
.values()
.iter()
.copied()
.zip(row_ids.values().iter().copied())
.filter(|(_, row_id)| *row_id != RowAddress::TOMBSTONE_ROW)
.map(|(distance, row_id)| DistanceRowId::new(distance, row_id));

let limit = query.k * query.refine_factor.unwrap_or(1) as usize;

let top_distance_ids = top_k(distance_ids, limit);

let distances = Arc::new(Float32Array::from_iter_values(
top_distance_ids.iter().map(|dist_id| dist_id.distance()),
));
let row_ids = Arc::new(UInt64Array::from_iter_values(
top_distance_ids.iter().map(|dist_id| dist_id.row_id()),
));
let indices = sort_to_indices(&distances, None, Some(limit))?;
let distances = take(&distances, &indices, None)?;
let row_ids = take(row_ids.as_ref(), &indices, None)?;

let schema = Arc::new(ArrowSchema::new(vec![
ArrowField::new(DIST_COL, DataType::Float32, false),
Expand Down Expand Up @@ -406,21 +326,29 @@ impl VectorIndex for PQIndex {
}

fn remap(&mut self, mapping: &HashMap<u64, Option<u64>>) -> Result<()> {
let remapped_ids =
UInt64Array::from_iter_values(self.row_ids.as_ref().unwrap().values().iter().map(
|old_row_id| {
mapping
.get(old_row_id)
.cloned()
// If the row id is not in the mapping then this row is not remapped and we keep as is
.unwrap_or(Some(*old_row_id))
// If the row is in the mapping, but maps to None, then it is deleted, and we insert
// a tombstone in its place
.unwrap_or(RowAddress::TOMBSTONE_ROW)
},
));

self.row_ids = Some(Arc::new(remapped_ids));
let code = self
.code
.as_ref()
.unwrap()
.values()
.chunks_exact(self.num_sub_vectors);
let row_ids = self.row_ids.as_ref().unwrap().values().iter();
let remapped = row_ids
.zip(code)
.filter_map(|(old_row_id, code)| {
let new_row_id = mapping.get(old_row_id).cloned();
// If the row id is not in the mapping then this row is not remapped and we keep as is
let new_row_id = new_row_id.unwrap_or(Some(*old_row_id));
new_row_id.map(|new_row_id| (new_row_id, code))
})
.collect::<Vec<_>>();

self.row_ids = Some(Arc::new(UInt64Array::from_iter_values(
remapped.iter().map(|(row_id, _)| *row_id),
)));
self.code = Some(Arc::new(UInt8Array::from_iter_values(
remapped.into_iter().flat_map(|(_, code)| code).copied(),
)));
Ok(())
}
}
Expand Down

0 comments on commit f291c74

Please sign in to comment.