perf: dimension-based kernel for L2 and Cosine (#1503)
Provide a faster routine for large PQ training, where each subvector has only
8 floats.

L2:

```
L2(simd,f32x8)          time:   [1.9428 ms 1.9509 ms 1.9561 ms]
                        change: [-26.987% -25.863% -24.769%] (p = 0.00 < 0.10)
                        Performance has improved.
```

Cosine:

```
Cosine(simd,8-f32) rng seed
                        time:   [2.8588 ms 2.8693 ms 2.8960 ms]
                        change: [-62.852% -62.407% -61.952%] (p = 0.00 < 0.10)
                        Performance has improved.
```
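
For reference, a minimal sketch of how the specialized path is exercised from the batch API, mirroring the benchmarks above. The `lance_linalg::distance` import path is assumed from the bench code, and `query`/`batch` are made-up example data; with `dimension == 8`, `cosine_distance_batch` dispatches to the new `f32x8` kernel (the L2 path is presumably specialized the same way, per the commit title).

```
// Sketch only: distances from one 8-float PQ sub-vector to a flattened batch.
use lance_linalg::distance::{cosine_distance_batch, l2_distance_batch};

fn main() {
    let dim = 8; // PQ sub-vector length targeted by the new fast path
    let query: Vec<f32> = (0..dim).map(|i| i as f32).collect();
    // Four sub-vectors stored contiguously, `dim` floats each.
    let batch: Vec<f32> = (0..4 * dim).map(|i| (i % 7) as f32 + 1.0).collect();

    // Both functions return an iterator with one distance per sub-vector in `batch`.
    let l2: Vec<f32> = l2_distance_batch(&query, &batch, dim).collect();
    let cos: Vec<f32> = cosine_distance_batch(&query, &batch, dim).collect();
    println!("l2 = {l2:?}");
    println!("cosine = {cos:?}");
}
```

Dispatching on `dimension` (8 or 16) lets one PQ sub-vector be handled in a single SIMD register pass instead of the generic unrolled loop.
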
eddyxu authored Nov 2, 2023
1 parent f427bd2 commit c82edbc
Showing 11 changed files with 187 additions and 216 deletions.
9 changes: 9 additions & 0 deletions rust/lance-linalg/benches/cosine.rs
@@ -48,6 +48,15 @@ fn bench_distance(c: &mut Criterion) {
)
})
});

let key: Float32Array = generate_random_array_with_seed(8, [0; 32]);
let target: Float32Array = generate_random_array_with_seed(TOTAL * 8, [42; 32]);

c.bench_function("Cosine(simd,f32x8) rng seed", |b| {
b.iter(|| {
black_box(cosine_distance_batch(key.values(), target.values(), 8).collect::<Vec<_>>())
})
});
}

#[cfg(target_os = "linux")]
9 changes: 9 additions & 0 deletions rust/lance-linalg/benches/l2.rs
@@ -129,6 +129,15 @@ fn bench_distance(c: &mut Criterion) {
}))
});
});

let key: Float32Array = generate_random_array_with_seed(8, [5; 32]);
// 1M of 8-D vectors (PQ sub-vector size).
let target: Float32Array = generate_random_array_with_seed(TOTAL * 8, [7; 32]);
c.bench_function("L2(simd,f32x8)", |b| {
b.iter(|| {
black_box(l2_distance_batch(key.values(), target.values(), 8).count());
})
});
}

#[cfg(target_os = "linux")]
2 changes: 1 addition & 1 deletion rust/lance-linalg/benches/norm_l2.rs
@@ -14,7 +14,7 @@

use arrow_arith::{aggregate::sum, numeric::mul};
use arrow_array::{cast::AsArray, types::Float32Type, Float32Array};
use criterion::{black_box, criterion_group, criterion_main, Criterion};
use criterion::{criterion_group, criterion_main, Criterion};

#[cfg(target_os = "linux")]
use pprof::criterion::{Output, PProfProfiler};
3 changes: 0 additions & 3 deletions rust/lance-linalg/src/distance.rs
@@ -28,9 +28,6 @@ pub mod dot;
pub mod l2;
pub mod norm_l2;

#[cfg(target_arch = "x86_64")]
mod x86_64;

use arrow_schema::ArrowError;
pub use cosine::*;
pub use dot::*;
255 changes: 96 additions & 159 deletions rust/lance-linalg/src/distance/cosine.rs
@@ -30,6 +30,10 @@ use num_traits::{AsPrimitive, FromPrimitive};

use super::dot::dot;
use super::norm_l2::{norm_l2, Normalize};
use crate::simd::{
f32::{f32x16, f32x8},
SIMD,
};

/// Cosine Distance
pub trait Cosine {
@@ -113,20 +117,34 @@ impl Cosine for [f32] {

#[inline]
fn cosine_fast(&self, x_norm: Self::Output, other: &Self) -> Self::Output {
#[cfg(target_arch = "aarch64")]
{
aarch64::neon::cosine_f32(self, other, x_norm)
let dim = self.len();
let unrolled_len = dim / 16 * 16;
let mut y_norm16 = f32x16::zeros();
let mut xy16 = f32x16::zeros();
for i in (0..unrolled_len).step_by(16) {
unsafe {
let x = f32x16::load_unaligned(self.as_ptr().add(i));
let y = f32x16::load_unaligned(other.as_ptr().add(i));
xy16.multiply_add(x, y);
y_norm16.multiply_add(y, y);
}
}

#[cfg(target_arch = "x86_64")]
{
if is_x86_feature_detected!("fma") {
return x86_64::avx::cosine_f32(self, other, x_norm);
let aligned_len = dim / 8 * 8;
let mut y_norm8 = f32x8::zeros();
let mut xy8 = f32x8::zeros();
for i in (unrolled_len..aligned_len).step_by(8) {
unsafe {
let x = f32x8::load_unaligned(self.as_ptr().add(i));
let y = f32x8::load_unaligned(other.as_ptr().add(i));
xy8.multiply_add(x, y);
y_norm8.multiply_add(y, y);
}
}

#[cfg(not(target_arch = "aarch64"))]
cosine_scalar(self, x_norm, other)
let y_norm =
y_norm16.reduce_sum() + y_norm8.reduce_sum() + norm_l2(&other[aligned_len..]).powi(2);
let xy =
xy16.reduce_sum() + xy8.reduce_sum() + dot(&self[aligned_len..], &other[aligned_len..]);
1.0 - xy / x_norm / y_norm.sqrt()
}

#[inline]
@@ -136,23 +154,28 @@ impl Cosine for [f32] {
y_norm: Self::Output,
y: &Self,
) -> Self::Output {
#[cfg(target_arch = "aarch64")]
{
// TODO: SIMD with normalized X and Y.
let _ = y_norm; // Make compiler happy.
aarch64::neon::cosine_f32(self, y, x_norm)
let dim = self.len();
let unrolled_len = dim / 16 * 16;
let mut xy16 = f32x16::zeros();
for i in (0..unrolled_len).step_by(16) {
unsafe {
let x = f32x16::load_unaligned(self.as_ptr().add(i));
let y = f32x16::load_unaligned(y.as_ptr().add(i));
xy16.multiply_add(x, y);
}
}

#[cfg(target_arch = "x86_64")]
{
if is_x86_feature_detected!("fma") {
return x86_64::avx::cosine_f32_norms(self, y, x_norm, y_norm);
let aligned_len = dim / 8 * 8;
let mut xy8 = f32x8::zeros();
for i in (unrolled_len..aligned_len).step_by(8) {
unsafe {
let x = f32x8::load_unaligned(self.as_ptr().add(i));
let y = f32x8::load_unaligned(y.as_ptr().add(i));
xy8.multiply_add(x, y);
}
}

// Slow path
#[cfg(not(target_arch = "aarch64"))]
cosine_scalar_fast(self, x_norm, y, y_norm)
let xy =
xy16.reduce_sum() + xy8.reduce_sum() + dot(&self[aligned_len..], &y[aligned_len..]);
1.0 - xy / x_norm / y_norm
}
}

@@ -215,20 +238,63 @@ pub fn cosine_distance<T: Cosine + ?Sized>(from: &T, to: &T) -> T::Output {
from.cosine(to)
}

mod f32 {
use super::*;

// TODO: how can we explicitly infer N?
#[inline]
pub(super) fn cosine_once<S: SIMD<f32, N>, const N: usize>(
x: &[f32],
x_norm: f32,
y: &[f32],
) -> f32 {
let x = unsafe { S::load_unaligned(x.as_ptr()) };
let y = unsafe { S::load_unaligned(y.as_ptr()) };
let y2 = y * y;
let xy = x * y;
1.0 - xy.reduce_sum() / x_norm / y2.reduce_sum().sqrt()
}
}

/// Cosine Distance
///
/// <https://en.wikipedia.org/wiki/Cosine_similarity>
///
/// Parameters
/// -----------
///
/// - *from*: the vector to compute distance from.
/// - *batch*: the batch of vectors to compute distance to.
/// - *dimension*: the dimension of the vector.
///
/// Returns
/// -------
/// An iterator of pair-wise cosine distances from the *from* vector to each vector in the batch.
///
pub fn cosine_distance_batch<'a>(
from: &'a [f32],
to: &'a [f32],
batch: &'a [f32],
dimension: usize,
) -> Box<dyn Iterator<Item = f32> + 'a> {
let x_norm = norm_l2(from);

Box::new(
to.chunks_exact(dimension)
.map(move |y| from.cosine_fast(x_norm, y)),
)
match dimension {
8 => Box::new(
batch
.chunks_exact(dimension)
.map(move |y| f32::cosine_once::<f32x8, 8>(from, x_norm, y)),
),
16 => Box::new(
batch
.chunks_exact(dimension)
.map(move |y| f32::cosine_once::<f32x16, 16>(from, x_norm, y)),
),
_ => Box::new(
batch
.chunks_exact(dimension)
.map(move |y| from.cosine_fast(x_norm, y)),
),
}
}

/// Compute Cosine distance between a vector and a batch of vectors.
@@ -258,135 +324,6 @@ pub fn cosine_distance_arrow_batch(from: &[f32], to: &FixedSizeListArray) -> Arc
Arc::new(Float32Array::new(dists.collect(), to.nulls().cloned()))
}

#[cfg(target_arch = "x86_64")]
mod x86_64 {
use std::arch::x86_64::*;

use super::dot;
use super::norm_l2;

pub mod avx {
use super::*;

#[inline]
pub fn cosine_f32(x_vector: &[f32], y_vector: &[f32], x_norm: f32) -> f32 {
unsafe {
use crate::distance::x86_64::avx::add_f32_register;

let len = x_vector.len() / 8 * 8;
let mut xy = _mm256_setzero_ps();
let mut y_sq = _mm256_setzero_ps();
for i in (0..len).step_by(8) {
let x = _mm256_loadu_ps(x_vector.as_ptr().add(i));
let y = _mm256_loadu_ps(y_vector.as_ptr().add(i));
xy = _mm256_fmadd_ps(x, y, xy);
y_sq = _mm256_fmadd_ps(y, y, y_sq);
}
// handle remaining elements
let mut dotprod = add_f32_register(xy);
dotprod += dot(&x_vector[len..], &y_vector[len..]);
let mut y_sq_sum = add_f32_register(y_sq);
y_sq_sum += norm_l2(&y_vector[len..]).powi(2);
let div = x_norm * y_sq_sum.sqrt();
if div == 0.0 {
1.0
} else {
1.0 - dotprod / div
}
}
}

#[inline]
pub fn cosine_f32_norms(
x_vector: &[f32],
y_vector: &[f32],
x_norm: f32,
y_norm: f32,
) -> f32 {
unsafe {
use crate::distance::x86_64::avx::add_f32_register;

let len = x_vector.len() / 8 * 8;
let mut xy = _mm256_setzero_ps();
for i in (0..len).step_by(8) {
let x = _mm256_loadu_ps(x_vector.as_ptr().add(i));
let y = _mm256_loadu_ps(y_vector.as_ptr().add(i));
xy = _mm256_fmadd_ps(x, y, xy);
}
// handle remaining elements
let mut dotprod = add_f32_register(xy);
dotprod += dot(&x_vector[len..], &y_vector[len..]);
let div = x_norm * y_norm;
if div == 0.0 {
1.0
} else {
1.0 - dotprod / div
}
}
}
}
}

#[cfg(target_arch = "aarch64")]
mod aarch64 {
use std::arch::aarch64::*;

use super::dot;
use super::norm_l2;

pub mod neon {
use super::*;

#[inline]
pub fn cosine_f32(x: &[f32], y: &[f32], x_norm: f32) -> f32 {
unsafe {
let len = x.len() / 16 * 16;
let buf = [0.0_f32; 4];
let mut xy = vld1q_f32(buf.as_ptr());
let mut y_sq = xy;

let mut xy1 = vld1q_f32(buf.as_ptr());
let mut y_sq1 = xy1;

let mut xy2 = vld1q_f32(buf.as_ptr());
let mut y_sq2 = xy2;

let mut xy3 = vld1q_f32(buf.as_ptr());
let mut y_sq3 = xy3;
for i in (0..len).step_by(16) {
let left = vld1q_f32(x.as_ptr().add(i));
let right = vld1q_f32(y.as_ptr().add(i));
xy = vfmaq_f32(xy, left, right);
y_sq = vfmaq_f32(y_sq, right, right);

let left1 = vld1q_f32(x.as_ptr().add(i + 4));
let right1 = vld1q_f32(y.as_ptr().add(i + 4));
xy1 = vfmaq_f32(xy1, left1, right1);
y_sq1 = vfmaq_f32(y_sq1, right1, right1);

let left2 = vld1q_f32(x.as_ptr().add(i + 8));
let right2 = vld1q_f32(y.as_ptr().add(i + 8));
xy2 = vfmaq_f32(xy2, left2, right2);
y_sq2 = vfmaq_f32(y_sq2, right2, right2);

let left3 = vld1q_f32(x.as_ptr().add(i + 12));
let right3 = vld1q_f32(y.as_ptr().add(i + 12));
xy3 = vfmaq_f32(xy3, left3, right3);
y_sq3 = vfmaq_f32(y_sq3, right3, right3);
}
xy = vaddq_f32(vaddq_f32(xy, xy3), vaddq_f32(xy1, xy2));
y_sq = vaddq_f32(vaddq_f32(y_sq, y_sq3), vaddq_f32(y_sq1, y_sq2));
// handle remaining elements
let mut dotprod = vaddvq_f32(xy);
dotprod += dot(&x[len..], &y[len..]);
let mut y_sq_sum = vaddvq_f32(y_sq);
y_sq_sum += norm_l2(&y[len..]).powi(2);
1.0 - dotprod / (x_norm * y_sq_sum.sqrt())
}
}
}
}

#[cfg(test)]
mod tests {
use super::*;