Skip to content

Commit

Permalink
ivf
Browse files Browse the repository at this point in the history
  • Loading branch information
eddyxu committed Oct 11, 2023
1 parent dc73857 commit 7b011d1
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 13 deletions.
3 changes: 2 additions & 1 deletion rust/lance-index/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,10 @@ categories.workspace = true
rust-version.workspace = true

[dependencies]
arrow-schema.workspace = true
arrow-array.workspace = true
arrow-ord.workspace = true
arrow-schema.workspace = true
arrow-select.workspace = true
futures.workspace = true
lance-arrow.workspace = true
lance-core.workspace = true
Expand Down
59 changes: 47 additions & 12 deletions rust/lance-index/src/vector/ivf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,17 @@

//! IVF - Inverted File Index

use std::ops::Range;
use std::sync::Arc;

use arrow_array::{
cast::AsArray, types::Float32Type, Array, ArrayRef, FixedSizeListArray, Float32Array,
RecordBatch, UInt32Array,
cast::AsArray,
types::{Float32Type, UInt32Type},
Array, ArrayRef, Float32Array, RecordBatch, UInt32Array,
};
use arrow_ord::sort::sort_to_indices;
use arrow_schema::Field;
use arrow_select::take::take;
use lance_arrow::RecordBatchExt;
use lance_core::{Error, Result};
use lance_linalg::{
Expand All @@ -43,30 +46,48 @@ pub struct Ivf {
///
/// It is a 2-D `(num_partitions * dimension)` of float32 array, 64-bit aligned via Arrow
/// memory allocator.
centroids: Arc<FixedSizeListArray>,
centroids: MatrixView<Float32Type>,

/// Transform applied to each partition.
transforms: Vec<Arc<dyn Transformer>>,

/// Metric type to compute pair-wise vector distance.
metric_type: MetricType,

/// Only covers a range of partitions.
partition_range: Option<Range<u32>>,
}

impl Ivf {
pub fn new(
centroids: Arc<FixedSizeListArray>,
centroids: MatrixView<Float32Type>,
metric_type: MetricType,
transforms: Vec<Arc<dyn Transformer>>,
) -> Self {
Self {
centroids,
metric_type,
transforms,
partition_range: None,
}
}

pub fn new_with_range(
centroids: MatrixView<Float32Type>,
metric_type: MetricType,
transforms: Vec<Arc<dyn Transformer>>,
range: Range<u32>,
) -> Self {
Self {
centroids,
metric_type,
transforms,
partition_range: Some(range),
}
}

fn dimension(&self) -> usize {
self.centroids.value_length() as usize
self.centroids.ndim()
}

/// Use the query vector to find `nprobes` closest partitions.
Expand All @@ -82,12 +103,9 @@ impl Ivf {
});
}
let dist_func = self.metric_type.batch_func();
let centroid_values = self.centroids.values();
let distances = dist_func(
query.values(),
centroid_values.as_primitive::<Float32Type>().values(),
self.dimension(),
) as ArrayRef;
let centroid_values = self.centroids.data();
let distances =
dist_func(query.values(), centroid_values.values(), self.dimension()) as ArrayRef;
let top_k_partitions = sort_to_indices(&distances, None, Some(nprobes))?;
Ok(top_k_partitions)
}
Expand All @@ -114,6 +132,23 @@ impl Ivf {
let matrix = MatrixView::<Float32Type>::try_from(data)?;
let part_ids = self.compute_partitions(&matrix);

let (part_ids, batch) = if let Some(part_range) = self.partition_range.as_ref() {
let idx_in_range: UInt32Array = part_ids
.values()
.iter()
.enumerate()
.filter(|(_, part_id)| part_range.contains(*part_id))
.map(|(idx, _)| idx as u32)
.collect();
let part_ids = take(&part_ids, &idx_in_range, None)?
.as_primitive::<UInt32Type>()
.clone();
let batch = batch.take(&idx_in_range)?;
(part_ids, batch)
} else {
(part_ids, batch.clone())
};

let field = Field::new(PART_ID_COLUMN, part_ids.data_type().clone(), false);
let mut batch = batch.try_with_column(field, Arc::new(part_ids))?;

Expand All @@ -130,7 +165,7 @@ impl Ivf {
#[instrument(skip(data))]
fn compute_partitions(&self, data: &MatrixView<Float32Type>) -> UInt32Array {
let ndim = data.ndim();
let centroids_arr: &Float32Array = self.centroids.values().as_primitive();
let centroids_arr = self.centroids.data();
let centroid_norms = centroids_arr
.values()
.chunks(ndim)
Expand Down

0 comments on commit 7b011d1

Please sign in to comment.