Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

perf: use selection vector strategy to improve exact knn performance with deletions #1418

Merged
merged 3 commits into from
Oct 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions python/python/tests/test_vector_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,3 +426,40 @@ def check_index(has_knn_combined, delete_has_happened):
# Optimize the index, combined KNN should no longer be needed
dataset.optimize.optimize_indices()
check_index(has_knn_combined=False, delete_has_happened=True)


def test_knn_with_deletions(tmp_path):
dims = 5
values = pa.array(
[x for val in range(50) for x in [float(val)] * 5], type=pa.float32()
)
tbl = pa.Table.from_pydict(
{
"vector": pa.FixedSizeListArray.from_arrays(values, dims),
"filterable": pa.array(range(50)),
}
)
dataset = lance.write_dataset(tbl, tmp_path, max_rows_per_group=10)

dataset.delete("not (filterable % 5 == 0)")

# Do KNN with k=100, should return 10 vectors
expected = [
[0.0] * 5,
[5.0] * 5,
[10.0] * 5,
[15.0] * 5,
[20.0] * 5,
[25.0] * 5,
[30.0] * 5,
[35.0] * 5,
[40.0] * 5,
[45.0] * 5,
]

results = dataset.to_table(
nearest={"column": "vector", "q": [0.0] * 5, "k": 100}
).column("vector")
assert len(results) == 10

assert expected == [r.as_py() for r in results]
1 change: 1 addition & 0 deletions rust/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ datafusion-common = "32.0"
datafusion-sql = "32.0"
either = "1.0"
futures = "0.3"
lazy_static = "1"
log = "0.4"
num_cpus = "1.0"
num-traits = "0.2"
Expand Down
1 change: 1 addition & 0 deletions rust/lance-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ bytes.workspace = true
datafusion-common.workspace = true
datafusion-sql.workspace = true
futures.workspace = true
lazy_static.workspace = true
num_cpus.workspace = true
object_store.workspace = true
pin-project.workspace = true
Expand Down
11 changes: 11 additions & 0 deletions rust/lance-core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,21 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use arrow_schema::{DataType, Field as ArrowField};

pub mod datatypes;
pub mod encodings;
pub mod error;
pub mod format;
pub mod io;

pub use error::{Error, Result};

/// Column name for the meta row ID.
pub const ROW_ID: &str = "_rowid";

lazy_static::lazy_static! {
/// Row ID field. This is nullable because its validity bitmap is sometimes used
/// as a selection vector.
pub static ref ROW_ID_FIELD: ArrowField = ArrowField::new(ROW_ID, DataType::UInt64, true);
}
47 changes: 33 additions & 14 deletions rust/lance-index/src/vector/flat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,26 +18,25 @@
use std::sync::Arc;

use arrow_array::{
cast::AsArray, types::Float32Type, Array, ArrayRef, FixedSizeListArray, RecordBatch,
StructArray,
cast::AsArray, make_array, Array, ArrayRef, FixedSizeListArray, RecordBatch, StructArray,
};
use arrow_ord::sort::sort_to_indices;
use arrow_schema::{DataType, Field as ArrowField, SchemaRef};
use arrow_schema::{DataType, Field as ArrowField, SchemaRef, SortOptions};
use arrow_select::{concat::concat, take::take};
use futures::{
future,
stream::{repeat_with, StreamExt, TryStreamExt},
};
use lance_arrow::*;
use lance_core::{io::RecordBatchStream, Error, Result};
use lance_core::{io::RecordBatchStream, Error, Result, ROW_ID};
use lance_linalg::distance::DistanceType;
use snafu::{location, Location};
use tracing::instrument;

use super::{Query, DIST_COL};

fn distance_field() -> ArrowField {
ArrowField::new(DIST_COL, DataType::Float32, false)
ArrowField::new(DIST_COL, DataType::Float32, true)
}

#[instrument(level = "debug", skip_all)]
Expand Down Expand Up @@ -87,17 +86,37 @@ async fn flat_search_batch(
.ok_or_else(|| Error::Schema {
message: format!("column {} does not exist in dataset", query.column),
location: location!(),
})?
.clone();
let flatten_vectors = as_fixed_size_list_array(vectors.as_ref()).values().clone();
})?;

// A selection vector may have been applied to _rowid column, so we need to
// push that onto vectors if possible.
let vectors = as_fixed_size_list_array(vectors.as_ref()).clone();
let validity_buffer = if let Some(rowids) = batch.column_by_name(ROW_ID) {
rowids.nulls().map(|nulls| nulls.buffer().clone())
} else {
None
};

let vectors = vectors
.into_data()
.into_builder()
.null_bit_buffer(validity_buffer)
.build()
.map(make_array)?;
let vectors = as_fixed_size_list_array(vectors.as_ref()).clone();

tokio::task::spawn_blocking(move || {
let distances = mt.batch_func()(
key.values(),
flatten_vectors.as_primitive::<Float32Type>().values(),
key.len(),
) as ArrayRef;
let distances = mt.arrow_batch_func()(key.values(), &vectors) as ArrayRef;

// We don't want any nulls in result, so limit to k or the number of valid values.
let k = std::cmp::min(k, distances.len() - distances.null_count());

let sort_options = SortOptions {
nulls_first: false,
..Default::default()
};
let indices = sort_to_indices(&distances, Some(sort_options), Some(k))?;

let indices = sort_to_indices(&distances, None, Some(k))?;
let batch_with_distance = batch.try_with_column(distance_field(), distances)?;
let struct_arr = StructArray::from(batch_with_distance);
let selected_arr = take(&struct_arr, &indices, None)?;
Expand Down
14 changes: 13 additions & 1 deletion rust/lance-linalg/src/distance.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

use std::sync::Arc;

use arrow_array::Float32Array;
use arrow_array::{FixedSizeListArray, Float32Array};

pub mod cosine;
pub mod dot;
Expand Down Expand Up @@ -50,6 +50,7 @@ pub type MetricType = DistanceType;

pub type DistanceFunc = fn(&[f32], &[f32]) -> f32;
pub type BatchDistanceFunc = fn(&[f32], &[f32], usize) -> Arc<Float32Array>;
pub type ArrowBatchDistanceFunc = fn(&[f32], &FixedSizeListArray) -> Arc<Float32Array>;

impl DistanceType {
/// Compute the distance from one vector to a batch of vectors.
Expand All @@ -61,6 +62,17 @@ impl DistanceType {
}
}

/// Compute the distance from one vector to a batch of vectors.
///
/// This propagates nulls to the output.
pub fn arrow_batch_func(&self) -> ArrowBatchDistanceFunc {
match self {
Self::L2 => l2_distance_arrow_batch,
Self::Cosine => cosine_distance_arrow_batch,
Self::Dot => dot_distance_arrow_batch,
}
}

/// Returns the distance function between two vectors.
pub fn func(&self) -> DistanceFunc {
match self {
Expand Down
42 changes: 34 additions & 8 deletions rust/lance-linalg/src/distance/cosine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@
use std::iter::Sum;
use std::sync::Arc;

use arrow_array::Float32Array;
use arrow_array::cast::AsArray;
use arrow_array::types::Float32Type;
use arrow_array::{Array, FixedSizeListArray, Float32Array};
use half::{bf16, f16};
use num_traits::real::Real;
use num_traits::{AsPrimitive, FromPrimitive};
Expand Down Expand Up @@ -219,13 +221,37 @@ pub fn cosine_distance<T: Cosine + ?Sized>(from: &T, to: &T) -> T::Output {
pub fn cosine_distance_batch(from: &[f32], to: &[f32], dimension: usize) -> Arc<Float32Array> {
let x_norm = norm_l2(from);

let dists = unsafe {
Float32Array::from_trusted_len_iter(
to.chunks_exact(dimension)
.map(|y| Some(from.cosine_fast(x_norm, y))),
)
};
Arc::new(dists)
let dists = to
.chunks_exact(dimension)
.map(|y| from.cosine_fast(x_norm, y));
Arc::new(Float32Array::new(dists.collect(), None))
}

/// Compute Cosine distance between a vector and a batch of vectors.
///
/// Null buffer of `to` is propagated to the returned array.
///
/// Parameters
///
/// - `from`: the vector to compute distance from.
/// - `to`: a list of vectors to compute distance to.
///
/// # Panics
///
/// Panics if the length of `from` is not equal to the dimension (value length) of `to`.
pub fn cosine_distance_arrow_batch(from: &[f32], to: &FixedSizeListArray) -> Arc<Float32Array> {
let dimension = to.value_length() as usize;
debug_assert_eq!(from.len(), dimension);

let x_norm = norm_l2(from);

// TODO: if we detect there is a run of nulls, should we skip those?
let to_values = to.values().as_primitive::<Float32Type>().values();
let dists = to_values
.chunks_exact(dimension)
.map(|v| from.cosine_fast(x_norm, v));

Arc::new(Float32Array::new(dists.collect(), to.nulls().cloned()))
}

#[cfg(target_arch = "x86_64")]
Expand Down
37 changes: 29 additions & 8 deletions rust/lance-linalg/src/distance/dot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
use std::iter::Sum;
use std::sync::Arc;

use arrow_array::Float32Array;
use arrow_array::{cast::AsArray, types::Float32Type, Array, FixedSizeListArray, Float32Array};
use half::{bf16, f16};
use num_traits::real::Real;

Expand Down Expand Up @@ -79,13 +79,34 @@ pub fn dot_distance_batch(from: &[f32], to: &[f32], dimension: usize) -> Arc<Flo
debug_assert_eq!(from.len(), dimension);
debug_assert_eq!(to.len() % dimension, 0);

let dists = unsafe {
Float32Array::from_trusted_len_iter(
to.chunks_exact(dimension)
.map(|v| Some(dot_distance(from, v))),
)
};
Arc::new(dists)
let dists = to.chunks_exact(dimension).map(|v| dot_distance(from, v));

Arc::new(Float32Array::new(dists.collect(), None))
}

/// Compute negative dot product distance between a vector and a batch of vectors.
///
/// Null buffer of `to` is propagated to the returned array.
///
/// Parameters
///
/// - `from`: the vector to compute distance from.
/// - `to`: a list of vectors to compute distance to.
///
/// # Panics
///
/// Panics if the length of `from` is not equal to the dimension (value length) of `to`.
pub fn dot_distance_arrow_batch(from: &[f32], to: &FixedSizeListArray) -> Arc<Float32Array> {
let dimension = to.value_length() as usize;
debug_assert_eq!(from.len(), dimension);

// TODO: if we detect there is a run of nulls, should we skip those?
let to_values = to.values().as_primitive::<Float32Type>().values();
let dists = to_values
.chunks_exact(dimension)
.map(|v| dot_distance(from, v));

Arc::new(Float32Array::new(dists.collect(), to.nulls().cloned()))
}

/// Negative dot distance.
Expand Down
31 changes: 26 additions & 5 deletions rust/lance-linalg/src/distance/l2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
use std::iter::Sum;
use std::sync::Arc;

use arrow_array::Float32Array;
use arrow_array::{cast::AsArray, types::Float32Type, Array, FixedSizeListArray, Float32Array};
use half::{bf16, f16};
use num_traits::real::Real;

Expand Down Expand Up @@ -127,10 +127,31 @@ pub fn l2_distance_batch(from: &[f32], to: &[f32], dimension: usize) -> Arc<Floa
assert_eq!(from.len(), dimension);
assert_eq!(to.len() % dimension, 0);

let dists = unsafe {
Float32Array::from_trusted_len_iter(to.chunks_exact(dimension).map(|v| Some(from.l2(v))))
};
Arc::new(dists)
let dists = to.chunks_exact(dimension).map(|v| from.l2(v));
Arc::new(Float32Array::new(dists.collect(), None))
}

/// Compute L2 distance between a vector and a batch of vectors.
///
/// Null buffer of `to` is propagated to the returned array.
///
/// Parameters
///
/// - `from`: the vector to compute distance from.
/// - `to`: a list of vectors to compute distance to.
///
/// # Panics
///
/// Panics if the length of `from` is not equal to the dimension (value length) of `to`.
pub fn l2_distance_arrow_batch(from: &[f32], to: &FixedSizeListArray) -> Arc<Float32Array> {
let dimension = to.value_length() as usize;
debug_assert_eq!(from.len(), dimension);

// TODO: if we detect there is a run of nulls, should we skip those?
let to_values = to.values().as_primitive::<Float32Type>().values();
let dists = to_values.chunks_exact(dimension).map(|v| from.l2(v));

Arc::new(Float32Array::new(dists.collect(), to.nulls().cloned()))
}

#[cfg(target_arch = "x86_64")]
Expand Down
2 changes: 1 addition & 1 deletion rust/lance/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ tfrecord = { version = "0.14.0", optional = true, features = ["async"] }
aws-sdk-dynamodb = { version = "0.30.0", optional = true }
tempfile = { workspace = true }
tracing = { workspace = true }
lazy_static = "1"
lazy_static = { workspace = true }
base64 = "0.21.4"
async_cell = "0.2.2"

Expand Down
11 changes: 6 additions & 5 deletions rust/lance/src/dataset.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ use crate::session::Session;
use crate::utils::temporal::{utc_now, SystemTime};
use crate::{Error, Result};
use hash_joiner::HashJoiner;
pub use scanner::ROW_ID;
pub use lance_core::ROW_ID;
pub use write::{WriteMode, WriteParams};

const INDICES_DIR: &str = "_indices";
Expand Down Expand Up @@ -1381,6 +1381,7 @@ mod tests {
use arrow_schema::{DataType, Field, Schema as ArrowSchema};
use arrow_select::take::take;
use futures::stream::TryStreamExt;
use lance_index::vector::DIST_COL;
use lance_linalg::distance::MetricType;
use lance_testing::datagen::generate_random_array;
use tempfile::tempdir;
Expand Down Expand Up @@ -2718,8 +2719,8 @@ mod tests {
)
);
assert_eq!(
schema.field_with_name("_distance").unwrap(),
&Field::new("_distance", DataType::Float32, false)
schema.field_with_name(DIST_COL).unwrap(),
&Field::new(DIST_COL, DataType::Float32, true)
);
}
}
Expand Down Expand Up @@ -2777,8 +2778,8 @@ mod tests {
)
);
assert_eq!(
schema.field_with_name("_distance").unwrap(),
&Field::new("_distance", DataType::Float32, false)
schema.field_with_name(DIST_COL).unwrap(),
&Field::new(DIST_COL, DataType::Float32, true)
);
}
}
Expand Down
Loading