-
Notifications
You must be signed in to change notification settings - Fork 221
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: benchmark and improve L2 partition compute #1453
Conversation
rust/lance-linalg/src/distance/l2.rs
Outdated
sums = _mm256_fmadd_ps(sub, sub, sums); | ||
sums = _mm256_fmadd_ps(s2, s2, sums); | ||
sums = _mm256_fmadd_ps(s3, s3, sums); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
you can use 4 accumulators for this sums1..4
and only add the four together at the end. Might be better throughput
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done. Delivered another 10%
rust/lance-linalg/src/kmeans.rs
Outdated
} | ||
|
||
/// Fast partition computation for L2 distance. | ||
fn compute_partitions_l2(centroids: &[f32], data: &[f32], dim: usize) -> Vec<u32> { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: compute_partitions_l2_f32
or compute_partitions_l2<f32>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ok, lemme make it _f32
first, get things going. Because f16/bf16 can prob increase STRIPE / TILE by 2x
const STRIPE_SIZE: usize = 128; | ||
const TILE_SIZE: usize = 16; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: we may want to benchmark these and use different numbers on different CPUs,
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
one idea is to use something like L1 cache size * factor
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is single CPU, so the cache is target to the 32KB / 64KB L1 per core.
Expect to have higher level (i.e., kmeans ) to handle distribution of multiple batches to centroids computation.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This looks good. I won't promise to have stepped through the math in detail but the idea of a tiling approach seems very sound to me.
let len = from.len() / 8 * 8; | ||
let mut sums = _mm256_setzero_ps(); | ||
for i in (0..len).step_by(8) { | ||
let len = from.len() / 32 * 32; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we have any unit tests comparing this with a naive approach so we know we are calculating the right thing here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
// Get a slice of `data[di][s..s+STRIP_SIZE]`. | ||
let cent_slice = get_slice(centroids, ci, s, dim, slice_len); | ||
let dist = data_slice.l2(cent_slice); | ||
dists[di * TILE_SIZE + (ci - centroid_start)] += dist; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's very likely I'm reading this incorrectly but it looks like you are calculating:
sqrt(diff_sq(0) + ... + diff_sq(N))
by breaking it into pieces:
sqrt(diff_sq(0) + ... + diff_sq(a)) + sqrt(diff_sq(a+1) + ... + diff_sq(N))
But doesn't this require sqrt(a + b) = sqrt(a) + sqrt(b)
which isn't true?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
lance/rust/lance-linalg/src/distance/l2.rs
Line 190 in 2e4ce22
results[0] += l2_scalar(&from[len..], &to[len..]); |
Our L2 implementation does not calculate sqrt()
Co-authored-by: Weston Pace <[email protected]>
Improve
compute_partitions(centroids, vectors, L2)
by 2x (6.9s -> 3.5s)