just use _rowid validity as selection vector
wjones127 committed Oct 15, 2023
1 parent a5d4813 commit 1fd2203
Showing 11 changed files with 88 additions and 70 deletions.
4 changes: 2 additions & 2 deletions rust/lance/src/dataset.rs
@@ -2700,7 +2700,7 @@ mod tests {
);
assert_eq!(
schema.field_with_name("_distance").unwrap(),
&Field::new("_distance", DataType::Float32, false)
&Field::new("_distance", DataType::Float32, true)
);
}
}
@@ -2759,7 +2759,7 @@ mod tests {
);
assert_eq!(
schema.field_with_name("_distance").unwrap(),
&Field::new("_distance", DataType::Float32, false)
&Field::new("_distance", DataType::Float32, true)
);
}
}
18 changes: 8 additions & 10 deletions rust/lance/src/dataset/fragment.rs
@@ -752,32 +752,30 @@ mod tests {
let test_dir = tempdir().unwrap();
let test_uri = test_dir.path().to_str().unwrap();
let mut dataset = create_dataset(test_uri).await;
dataset.delete("i >= 80 and i < 95").await.unwrap();
dataset.delete("i >= 0 and i < 15").await.unwrap();

- let fragment = &dataset.get_fragments()[2];
+ let fragment = &dataset.get_fragments()[0];
let mut reader = fragment.open(dataset.schema()).await.unwrap();
reader.with_make_deletions_null();
+ reader.with_row_id();

// Since the first batch is all deleted, it will return an empty batch.
let batch1 = reader.read_batch(0, ..).await.unwrap();
assert_eq!(batch1.num_rows(), 0);

// The second batch is partially deleted, so the deleted rows will be
- // marked null across all columns.
+ // marked null with null row ids.
let batch2 = reader.read_batch(1, ..).await.unwrap();
- for i in 0..batch2.num_columns() {
- assert_eq!(batch2.column(i).null_count(), 5);
- }
assert_eq!(
- batch2.column_by_name("i").unwrap().as_ref(),
- &Int32Array::from_iter((90..100).map(|v| if v < 95 { None } else { Some(v) }))
+ batch2.column_by_name(ROW_ID).unwrap().as_ref(),
+ &UInt64Array::from_iter((10..20).map(|v| if v < 15 { None } else { Some(v) }))
);

// The final batch is not deleted, so it will be returned as-is.
let batch3 = reader.read_batch(2, ..).await.unwrap();
assert_eq!(
- batch3.column_by_name("i").unwrap().as_ref(),
- &Int32Array::from_iter_values(100..110)
+ batch3.column_by_name(ROW_ID).unwrap().as_ref(),
+ &UInt64Array::from_iter_values(20..30)
);
}

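Not part of this commit — a minimal consumer-side sketch of what the updated test exercises. With with_make_deletions_null() and with_row_id(), deleted rows come back with null _rowid values, so a caller can treat the _rowid validity bitmap as a selection vector and filter the batch with it. The helper name and the use of arrow's is_not_null/filter_record_batch kernels are assumptions for illustration, not code from this change.

use arrow::compute::{filter_record_batch, is_not_null};
use arrow::error::ArrowError;
use arrow::record_batch::RecordBatch;

/// Hypothetical helper: drop rows whose `_rowid` is null, i.e. rows deleted from
/// the fragment, by using the `_rowid` validity bitmap as a selection vector.
fn apply_rowid_selection(batch: &RecordBatch) -> Result<RecordBatch, ArrowError> {
    let row_ids = batch
        .column_by_name("_rowid")
        .expect("batch must be read with with_row_id()");
    // Deleted rows are materialized as null `_rowid` values; keep everything else.
    let keep = is_not_null(row_ids.as_ref())?;
    filter_record_batch(batch, &keep)
}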
12 changes: 9 additions & 3 deletions rust/lance/src/dataset/scanner.rs
@@ -54,6 +54,12 @@ use crate::{Error, Result};
use snafu::{location, Location};
/// Column name for the meta row ID.
pub const ROW_ID: &str = "_rowid";

+ lazy_static::lazy_static! {
+ /// Row ID field. This is nullable because its validity bitmap is sometimes used
+ /// as a selection vector.
+ pub static ref ROW_ID_FIELD: ArrowField = ArrowField::new(ROW_ID, DataType::UInt64, true);
+ }
pub const DEFAULT_BATCH_SIZE: usize = 8192;

// Same as pyarrow Dataset::scanner()
@@ -376,7 +382,7 @@ impl Scanner {
extra_columns.push(ArrowField::new(DIST_COL, DataType::Float32, true));
};
if self.with_row_id {
- extra_columns.push(ArrowField::new(ROW_ID, DataType::UInt64, false));
+ extra_columns.push(ROW_ID_FIELD.clone());
}

let schema = if !extra_columns.is_empty() {
@@ -1274,7 +1280,7 @@ mod test {
),
true,
),
ArrowField::new("_distance", DataType::Float32, false),
ArrowField::new("_distance", DataType::Float32, true),
])
);

@@ -1378,7 +1384,7 @@ mod test {
),
true,
),
ArrowField::new("_distance", DataType::Float32, false),
ArrowField::new("_distance", DataType::Float32, true),
])
);

4 changes: 2 additions & 2 deletions rust/lance/src/index/vector/diskann/search.rs
@@ -29,7 +29,7 @@ use tracing::instrument;

use super::row_vertex::{RowVertex, RowVertexSerDe};
use crate::{
- dataset::{Dataset, ROW_ID},
+ dataset::{scanner::ROW_ID_FIELD, Dataset},
index::{
prefilter::PreFilter,
vector::{
@@ -227,7 +227,7 @@ impl VectorIndex for DiskANNIndex {
async fn search(&self, query: &Query, pre_filter: &PreFilter) -> Result<RecordBatch> {
let state = greedy_search(&self.graph, 0, query.key.values(), query.k, query.k * 2).await?;
let schema = Arc::new(Schema::new(vec![
- Field::new(ROW_ID, DataType::UInt64, true),
+ ROW_ID_FIELD.clone(),
Field::new(DIST_COL, DataType::Float32, true),
]));

22 changes: 20 additions & 2 deletions rust/lance/src/index/vector/flat.rs
@@ -17,6 +17,7 @@

use std::sync::Arc;

+ use arrow_array::make_array;
use arrow_array::{
cast::as_struct_array, Array, ArrayRef, FixedSizeListArray, RecordBatch, StructArray,
};
@@ -31,6 +32,7 @@ use tracing::instrument;

use super::{Query, DIST_COL};
use crate::arrow::*;
+ use crate::dataset::ROW_ID;
use crate::io::RecordBatchStream;
use crate::{Error, Result};

@@ -61,8 +63,7 @@ pub async fn flat_search(
// do this in a streaming fashion. See also: https://github.com/lancedb/lance/issues/1324
let batch = concat_batches(&batches[0].schema(), &batches)?;
let distances = batch.column_by_name(DIST_COL).unwrap();
- let k = std::cmp::min(query.k, distances.len() - distances.null_count());
- let indices = sort_to_indices(distances, None, Some(k))?;
+ let indices = sort_to_indices(distances, None, Some(query.k))?;

let struct_arr = StructArray::from(batch);
let selected_arr = take(&struct_arr, &indices, None)?;
@@ -87,7 +88,24 @@ async fn flat_search_batch(
message: format!("column {} does not exist in dataset", query.column),
location: location!(),
})?;

+ // A selection vector may have been applied to _rowid column, so we need to
+ // push that onto vectors if possible.
+ let vectors = as_fixed_size_list_array(vectors.as_ref()).clone();
+ let validity_buffer = if let Some(rowids) = batch.column_by_name(ROW_ID) {
+ rowids.nulls().map(|nulls| nulls.buffer().clone())
+ } else {
+ None
+ };
+
+ let vectors = vectors
+ .into_data()
+ .into_builder()
+ .null_bit_buffer(validity_buffer)
+ .build()
+ .map(make_array)?;
+ let vectors = as_fixed_size_list_array(vectors.as_ref()).clone();

tokio::task::spawn_blocking(move || {
let distances = mt.arrow_batch_func()(key.values(), &vectors) as ArrayRef;

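For reference, here is the core trick from flat_search_batch above, restated as a standalone sketch: the validity bitmap of the _rowid column (acting as a selection vector) is copied onto the vector column before distances are computed, so filtered-out rows yield null distances. The helper name is hypothetical; the arrow calls mirror the ones used in the hunk.

use arrow_array::cast::as_fixed_size_list_array;
use arrow_array::{make_array, Array, ArrayRef, FixedSizeListArray};
use arrow_schema::ArrowError;

/// Hypothetical helper: copy the validity bitmap of `selection` (e.g. the `_rowid`
/// column) onto `vectors`, so masked-out rows produce null distances downstream.
fn with_selection_validity(
    vectors: &FixedSizeListArray,
    selection: &ArrayRef,
) -> Result<FixedSizeListArray, ArrowError> {
    // Null row ids mark rows excluded by deletions or a prefilter.
    let validity = selection.nulls().map(|nulls| nulls.buffer().clone());
    let data = vectors
        .clone()
        .into_data()
        .into_builder()
        .null_bit_buffer(validity)
        .build()?;
    let rebuilt = make_array(data);
    Ok(as_fixed_size_list_array(rebuilt.as_ref()).clone())
}

In flat_search_batch the same steps run inline on the batch's vector column before the distance kernel executes.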
3 changes: 2 additions & 1 deletion rust/lance/src/index/vector/ivf/builder.rs
@@ -30,6 +30,7 @@ use lance_linalg::{distance::MetricType, MatrixView};
use snafu::{location, Location};
use tracing::instrument;

+ use crate::dataset::scanner::ROW_ID_FIELD;
use crate::dataset::ROW_ID;
use crate::index::vector::ivf::{
io::write_index_partitions,
@@ -95,7 +96,7 @@ pub async fn shuffle_dataset(

// TODO: dynamically detect schema from the transforms.
let schema = Schema::new(vec![
- Field::new(ROW_ID, DataType::UInt64, false),
+ ROW_ID_FIELD.clone(),
Field::new(
PQ_CODE_COLUMN,
DataType::FixedSizeList(
4 changes: 2 additions & 2 deletions rust/lance/src/index/vector/pq.rs
@@ -34,7 +34,7 @@ use serde::Serialize;
use tracing::instrument;

use super::{MetricType, Query, VectorIndex};
- use crate::dataset::ROW_ID;
+ use crate::dataset::scanner::ROW_ID_FIELD;
use crate::index::{pb, prefilter::PreFilter, vector::DIST_COL, Index};
use crate::io::object_reader::{read_fixed_stride_array, ObjectReader};
use crate::{arrow::*, utils::tokio::spawn_cpu};
@@ -279,7 +279,7 @@ impl VectorIndex for PQIndex {

let schema = Arc::new(ArrowSchema::new(vec![
ArrowField::new(DIST_COL, DataType::Float32, true),
- ArrowField::new(ROW_ID, DataType::UInt64, true),
+ ROW_ID_FIELD.clone(),
]));
Ok(RecordBatch::try_new(schema, vec![distances, row_ids])?)
})
2 changes: 1 addition & 1 deletion rust/lance/src/index/vector/traits.rs
@@ -42,7 +42,7 @@ pub(crate) trait VectorIndex: Send + Sync + std::fmt::Debug + Index {
/// use arrow_schema::{Schema, Field, DataType};
///
/// Schema::new(vec![
/// Field::new("_rowid", DataType::UInt64, false),
/// Field::new("_rowid", DataType::UInt64, true),
/// Field::new("_distance", DataType::Float32, false),
/// ]);
/// ```
10 changes: 5 additions & 5 deletions rust/lance/src/io/exec/knn.rs
@@ -31,8 +31,8 @@ use tokio::sync::mpsc::Receiver;
use tokio::task::JoinHandle;
use tracing::{instrument, Instrument};

- use crate::dataset::scanner::DatasetRecordBatchStream;
- use crate::dataset::{Dataset, ROW_ID};
+ use crate::dataset::scanner::{DatasetRecordBatchStream, ROW_ID_FIELD};
+ use crate::dataset::Dataset;
use crate::index::prefilter::PreFilter;
use crate::index::vector::flat::flat_search;
use crate::index::vector::{open_index, Query, DIST_COL};
@@ -121,7 +121,7 @@ impl DFRecordBatchStream for KNNFlatStream {
fn schema(&self) -> arrow_schema::SchemaRef {
Arc::new(Schema::new(vec![
Field::new(DIST_COL, DataType::Float32, true),
- Field::new(ROW_ID, DataType::UInt64, true),
+ ROW_ID_FIELD.clone(),
]))
}
}
@@ -304,7 +304,7 @@ impl DFRecordBatchStream for KNNIndexStream {
fn schema(&self) -> arrow_schema::SchemaRef {
Arc::new(Schema::new(vec![
Field::new(DIST_COL, DataType::Float32, true),
- Field::new(ROW_ID, DataType::UInt64, true),
+ ROW_ID_FIELD.clone(),
]))
}
}
@@ -389,7 +389,7 @@ impl ExecutionPlan for KNNIndexExec {
fn schema(&self) -> arrow_schema::SchemaRef {
Arc::new(Schema::new(vec![
Field::new(DIST_COL, DataType::Float32, true),
- Field::new(ROW_ID, DataType::UInt64, true),
+ ROW_ID_FIELD.clone(),
]))
}

9 changes: 5 additions & 4 deletions rust/lance/src/io/exec/scan.rs
@@ -19,7 +19,7 @@ use std::sync::Arc;
use std::task::{Context, Poll};

use arrow_array::RecordBatch;
- use arrow_schema::{DataType, Field, Schema as ArrowSchema, SchemaRef};
+ use arrow_schema::{Field, Schema as ArrowSchema, SchemaRef};
use datafusion::error::{DataFusionError, Result};
use datafusion::physical_plan::{
DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, RecordBatchStream,
@@ -30,7 +30,8 @@ use futures::{stream, Future};
use futures::{StreamExt, TryStreamExt};

use crate::dataset::fragment::{FileFragment, FragmentReader};
- use crate::dataset::{Dataset, ROW_ID};
+ use crate::dataset::scanner::ROW_ID_FIELD;
+ use crate::dataset::Dataset;
use crate::datatypes::Schema;
use crate::format::Fragment;

@@ -184,7 +185,7 @@ impl RecordBatchStream for LanceStream {
let schema: ArrowSchema = self.projection.as_ref().into();
if self.with_row_id {
let mut fields: Vec<Arc<Field>> = schema.fields.to_vec();
- fields.push(Arc::new(Field::new(ROW_ID, DataType::UInt64, false)));
+ fields.push(Arc::new(ROW_ID_FIELD.clone()));
Arc::new(ArrowSchema::new(fields))
} else {
Arc::new(schema)
@@ -274,7 +275,7 @@ impl ExecutionPlan for LanceScanExec {
let schema: ArrowSchema = self.projection.as_ref().into();
if self.with_row_id {
let mut fields: Vec<Arc<Field>> = schema.fields.to_vec();
- fields.push(Arc::new(Field::new(ROW_ID, DataType::UInt64, false)));
+ fields.push(Arc::new(ROW_ID_FIELD.clone()));
Arc::new(ArrowSchema::new(fields))
} else {
Arc::new(schema)
