
feat: benchmark and improve L2 partition compute #1453

Merged: 14 commits into main from lei/bench_partition on Oct 23, 2023

Conversation

@eddyxu (Contributor) commented on Oct 22, 2023:

Improve compute_partitions(centroids, vectors, L2) by 2x (6.9s -> 3.5s)

sums = _mm256_fmadd_ps(sub, sub, sums);
sums = _mm256_fmadd_ps(s2, s2, sums);
sums = _mm256_fmadd_ps(s3, s3, sums);
Contributor commented:

You can use four accumulators for this (sums1..sums4) and only add the four together at the end. Might give better throughput.

@eddyxu (Author) replied on Oct 23, 2023:

Done. Delivered another 10%
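
For reference, a minimal sketch of the multi-accumulator pattern being discussed (assumed shape only; the function name l2_squared_avx2 and the 32-element stride are illustrative, not the exact code merged here): four independent FMA accumulators break the single loop-carried dependency chain and are reduced once after the loop.

use std::arch::x86_64::*;

/// Sketch only: squared L2 distance with four independent AVX2/FMA accumulators.
#[target_feature(enable = "avx2", enable = "fma")]
unsafe fn l2_squared_avx2(from: &[f32], to: &[f32]) -> f32 {
    let len = from.len() / 32 * 32;
    let mut sums1 = _mm256_setzero_ps();
    let mut sums2 = _mm256_setzero_ps();
    let mut sums3 = _mm256_setzero_ps();
    let mut sums4 = _mm256_setzero_ps();
    for i in (0..len).step_by(32) {
        // Each accumulator owns one 8-lane chunk of the 32-element stride,
        // so the four fmadds in an iteration do not depend on each other.
        let d1 = _mm256_sub_ps(
            _mm256_loadu_ps(from.as_ptr().add(i)),
            _mm256_loadu_ps(to.as_ptr().add(i)),
        );
        let d2 = _mm256_sub_ps(
            _mm256_loadu_ps(from.as_ptr().add(i + 8)),
            _mm256_loadu_ps(to.as_ptr().add(i + 8)),
        );
        let d3 = _mm256_sub_ps(
            _mm256_loadu_ps(from.as_ptr().add(i + 16)),
            _mm256_loadu_ps(to.as_ptr().add(i + 16)),
        );
        let d4 = _mm256_sub_ps(
            _mm256_loadu_ps(from.as_ptr().add(i + 24)),
            _mm256_loadu_ps(to.as_ptr().add(i + 24)),
        );
        sums1 = _mm256_fmadd_ps(d1, d1, sums1);
        sums2 = _mm256_fmadd_ps(d2, d2, sums2);
        sums3 = _mm256_fmadd_ps(d3, d3, sums3);
        sums4 = _mm256_fmadd_ps(d4, d4, sums4);
    }
    // Reduce the four accumulators once, then sum the 8 lanes and the scalar tail.
    let sums = _mm256_add_ps(_mm256_add_ps(sums1, sums2), _mm256_add_ps(sums3, sums4));
    let mut buf = [0.0f32; 8];
    _mm256_storeu_ps(buf.as_mut_ptr(), sums);
    let mut total: f32 = buf.iter().sum();
    for i in len..from.len() {
        let d = from[i] - to[i];
        total += d * d;
    }
    total
}

Compared to a single accumulator, this trades a few extra registers for a shorter dependency chain, which is consistent with the roughly 10% win reported above.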

}

/// Fast partition computation for L2 distance.
fn compute_partitions_l2(centroids: &[f32], data: &[f32], dim: usize) -> Vec<u32> {
Contributor commented:

nit: compute_partitions_l2_f32 or compute_partitions_l2<f32>

@eddyxu (Author) replied on Oct 22, 2023:

OK, let me make it _f32 first to get things going, because f16/bf16 can probably increase STRIPE / TILE by 2x.
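
The "2x" is just element width (my arithmetic, not spelled out in the thread): a stripe of 128 f32 values takes 128 * 4 = 512 bytes, and the same 512 bytes hold 512 / 2 = 256 f16 or bf16 values, so STRIPE_SIZE or TILE_SIZE could double without growing the L1 footprint.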

Comment on lines +550 to +551
const STRIPE_SIZE: usize = 128;
const TILE_SIZE: usize = 16;
Contributor commented:

nit: we may want to benchmark these and use different numbers on different CPUs.

Contributor commented:

One idea is to use something like L1 cache size * factor.

@eddyxu (Author) replied:

This is single-CPU code, so the tile sizes target the 32 KB / 64 KB L1 cache available per core.

We expect the higher level (i.e., kmeans) to handle distributing multiple batches across the centroid computation.
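
For intuition, a rough sketch of the cache-budget idea above (the names, the halving factor, and the assumption that a tile keeps TILE_SIZE data stripes plus TILE_SIZE centroid stripes resident are mine, not stated in the PR):

const L1D_BYTES: usize = 32 * 1024; // common per-core L1 data cache size
const F32_BYTES: usize = std::mem::size_of::<f32>();

/// Pick a stripe length (in f32 elements) so the tile's working set fits in
/// roughly half of L1d, leaving headroom for the distance accumulators.
fn stripe_size_for(tile_size: usize) -> usize {
    let budget = L1D_BYTES / 2;
    let stripes_in_flight = 2 * tile_size; // data stripes + centroid stripes
    let raw = budget / (stripes_in_flight * F32_BYTES);
    (raw / 32) * 32 // multiple of 32 so the AVX2 loop has no tail inside a stripe
}

fn main() {
    // tile_size = 16 yields 128, matching the constants in this PR:
    // 32 stripes x 128 f32 x 4 bytes = 16 KB, half of a 32 KB L1d.
    println!("{}", stripe_size_for(16));
}

Detecting the actual per-core L1d size at runtime instead of assuming 32 KB would be the natural next step the reviewers are hinting at.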

@eddyxu changed the title from "feat: benchmark and improve l2 partition compute" to "feat: benchmark and improve L2 partition compute" on Oct 23, 2023.
@westonpace (Contributor) left a review comment:

This looks good. I won't promise to have stepped through the math in detail, but the idea of a tiling approach seems very sound to me.

rust/lance-linalg/benches/compute_partition.rs (outdated review thread, resolved)
let len = from.len() / 8 * 8;
let mut sums = _mm256_setzero_ps();
for i in (0..len).step_by(8) {
let len = from.len() / 32 * 32;
Contributor commented:

Do we have any unit tests comparing this with a naive approach so we know we are calculating the right thing here?
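
A check along these lines would pin the SIMD path to a scalar reference (hypothetical sketch; argmin_naive and the test data are mine, and it assumes compute_partitions_l2 from this PR is in scope):

fn l2_squared_naive(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b).map(|(x, y)| (x - y) * (x - y)).sum()
}

fn argmin_naive(centroids: &[f32], data: &[f32], dim: usize) -> Vec<u32> {
    data.chunks(dim)
        .map(|row| {
            centroids
                .chunks(dim)
                .enumerate()
                .min_by(|(_, a), (_, b)| {
                    l2_squared_naive(row, a)
                        .partial_cmp(&l2_squared_naive(row, b))
                        .unwrap()
                })
                .map(|(idx, _)| idx as u32)
                .unwrap()
        })
        .collect()
}

#[test]
fn test_compute_partitions_matches_naive() {
    let dim = 75; // deliberately not a multiple of 8 or 32 to hit the scalar tail
    let centroids: Vec<f32> = (0..32 * dim).map(|i| (i % 97) as f32 * 0.5).collect();
    let data: Vec<f32> = (0..100 * dim).map(|i| (i % 89) as f32 * 0.25).collect();
    // In practice one may want to compare distances with a tolerance rather than
    // exact indices, since SIMD and scalar summation order can differ in the last ulp.
    assert_eq!(
        compute_partitions_l2(&centroids, &data, dim),
        argmin_naive(&centroids, &data, dim)
    );
}

Running it over a few different dims that are not multiples of 8 or 32 would also exercise the scalar tail path.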


// Get a slice of `data[di][s..s+STRIP_SIZE]`.
let cent_slice = get_slice(centroids, ci, s, dim, slice_len);
let dist = data_slice.l2(cent_slice);
dists[di * TILE_SIZE + (ci - centroid_start)] += dist;
Contributor commented:

It's very likely I'm reading this incorrectly but it looks like you are calculating:

sqrt(diff_sq(0) + ... + diff_sq(N))

by breaking it into pieces:

sqrt(diff_sq(0) + ... + diff_sq(a)) + sqrt(diff_sq(a+1) + ... + diff_sq(N))

But doesn't this require sqrt(a + b) = sqrt(a) + sqrt(b) which isn't true?

@eddyxu (Author) replied:

results[0] += l2_scalar(&from[len..], &to[len..]);

Our L2 implementation does not calculate sqrt()
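
To spell out why that answers the concern (my paraphrase, not part of the thread): the kernel accumulates squared L2 distance, so splitting the reduction across stripes only relies on addition being associative,

diff_sq(0) + ... + diff_sq(N) = (diff_sq(0) + ... + diff_sq(a)) + (diff_sq(a+1) + ... + diff_sq(N)),

and no sqrt(a + b) = sqrt(a) + sqrt(b) identity is ever needed. For partition assignment, the argmin over squared distances equals the argmin over true distances, so the square root can be skipped entirely.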

@eddyxu merged commit 4df9d33 into main on Oct 23, 2023. 16 checks passed.
@eddyxu deleted the lei/bench_partition branch on October 23, 2023 at 15:57.