Skip to content

Commit

Permalink
feat: compare msm with proper benches
Browse files Browse the repository at this point in the history
  • Loading branch information
ed255 committed Jun 4, 2024
1 parent db46631 commit d1f79a5
Show file tree
Hide file tree
Showing 2 changed files with 247 additions and 18 deletions.
97 changes: 95 additions & 2 deletions benches/msm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,13 @@ extern crate criterion;

use criterion::{BenchmarkId, Criterion};
use ff::Field;
use ff::PrimeField;
use group::prime::PrimeCurveAffine;
use halo2curves::bn256::{Fr as Scalar, G1Affine as Point};
use halo2curves::msm::{best_multiexp, multiexp_serial};
use halo2curves::msm::{
best_multiexp, best_multiexp_jonathan, best_multiexp_skip_zeros, multiexp_serial,
};
use rand_core::RngCore;
use rand_core::SeedableRng;
use rand_xorshift::XorShiftRng;
use rayon::current_thread_index;
Expand Down Expand Up @@ -112,5 +116,94 @@ fn msm(c: &mut Criterion) {
group.finish();
}

criterion_group!(benches, msm);
/// Generates `2^k` scalars and `2^k` curve points for the MSM benchmarks.
///
/// When `small` is true every scalar is drawn from `[0, 2^8)`; otherwise the
/// scalars are sampled with `Scalar::random`.
///
/// Each element gets its own RNG seeded from its index. Seeding from the
/// rayon thread index inside `map_init` is not sufficient: rayon may invoke
/// the init closure many times on the same thread, and every invocation would
/// restart the same random stream, yielding repeated points/scalars and data
/// that changes with work-stealing from run to run.
fn gen_scalars_points(k: u8, small: bool) -> (Vec<Scalar>, Vec<Point>) {
    // Keeps the scalar seeds disjoint from the point seeds (k is far below 32
    // for any input that fits in memory).
    const SCALAR_SEED_OFFSET: u64 = 1 << 32;
    // 1 byte upper bound (exclusive) for the "small" scalars.
    const SMALL_SCALAR_BOUND: u64 = 1 << 8;

    let n = 1usize << k;

    let points = (0..n)
        .into_par_iter()
        .map(|i| {
            let mut rng = XorShiftRng::seed_from_u64(i as u64);
            Point::random(&mut rng)
        })
        .collect();

    let scalars = (0..n)
        .into_par_iter()
        .map(|i| {
            let mut rng = XorShiftRng::seed_from_u64(SCALAR_SEED_OFFSET + i as u64);
            if small {
                let v = rng.next_u64() % SMALL_SCALAR_BOUND;
                Scalar::from_u128(u128::from(v))
            } else {
                Scalar::random(&mut rng)
            }
        })
        .collect();

    (scalars, points)
}

/// Benchmarks the three MSM implementations (`best_multiexp`,
/// `best_multiexp_skip_zeros`, `best_multiexp_jonathan`) against each other
/// for `k` in `18..=22`, once with full-size random scalars and once with
/// 1-byte scalars.
fn msm_cmp(c: &mut Criterion) {
    let mut group = c.benchmark_group("msm_cmp");
    // Must be configured before the first `bench_function`: the setting only
    // applies to benchmarks registered afterwards, so chaining it after a
    // benchmark leaves the very first one running with the default size.
    group.sample_size(10);

    let min_k = 18;
    let max_k = 22;
    let (scalars_small, points_small) = gen_scalars_points(max_k, true);
    let (scalars_big, points_big) = gen_scalars_points(max_k, false);

    for small in [false, true] {
        let (scalars, points) = if small {
            (&scalars_small, &points_small)
        } else {
            (&scalars_big, &points_big)
        };
        for k in min_k..=max_k {
            // Every implementation is measured on the same 2^k-sized prefix.
            let n: usize = 1 << k;
            let (scalars, points) = (&scalars[..n], &points[..n]);

            // The closures return the MSM result so that Criterion passes it
            // through `black_box` and the computation cannot be elided.
            let name = format!("msm func={}, k={}, small={}", "original", k, small);
            group.bench_function(BenchmarkId::new(name, k), |b| {
                b.iter(|| best_multiexp(scalars, points))
            });

            let name = format!("msm func={}, k={}, small={}", "skip_zeros_edu", k, small);
            group.bench_function(BenchmarkId::new(name, k), |b| {
                b.iter(|| best_multiexp_skip_zeros(scalars, points))
            });

            let name = format!(
                "msm func={}, k={}, small={}",
                "skip_zeros_jonathan", k, small
            );
            group.bench_function(BenchmarkId::new(name, k), |b| {
                b.iter(|| best_multiexp_jonathan(scalars, points))
            });
        }
    }
    group.finish();
}

criterion_group!(benches, msm_cmp);
criterion_main!(benches);
168 changes: 152 additions & 16 deletions src/msm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -940,6 +940,139 @@ pub fn best_multiexp_independent_points_small<C: CurveAffine, const N: usize>(
acc.into_iter().sum::<_>()
}

/// Single-threaded bucket-method multi-scalar multiplication that skips the
/// leading windows in which every coefficient has a zero window digit.
///
/// The result is accumulated into `acc`. Note that `acc` is doubled `c` times
/// before each processed window, so the plain sum `sum(coeffs[i] * bases[i])`
/// is obtained when `acc` starts as the identity (as `best_multiexp_jonathan`
/// does). If every window digit of every coefficient is zero, `acc` is left
/// untouched.
pub fn multiexp_serial_jonathan<C: CurveAffine>(
    coeffs: &[C::Scalar],
    bases: &[C],
    acc: &mut C::Curve,
) {
    /// Lazily-upgraded bucket: stays empty or affine for as long as possible
    /// and only switches to projective coordinates on the second addition.
    #[derive(Clone, Copy)]
    enum Bucket<C: CurveAffine> {
        None,
        Affine(C),
        Projective(C::Curve),
    }

    impl<C: CurveAffine> Bucket<C> {
        /// Adds the affine point `other` into this bucket.
        fn add_assign(&mut self, other: &C) {
            *self = match *self {
                Bucket::None => Bucket::Affine(*other),
                Bucket::Affine(a) => Bucket::Projective(a + *other),
                Bucket::Projective(mut a) => {
                    a += *other;
                    Bucket::Projective(a)
                }
            }
        }

        /// Returns `other` plus the content of this bucket.
        fn add(self, mut other: C::Curve) -> C::Curve {
            match self {
                Bucket::None => other,
                Bucket::Affine(a) => {
                    other += a;
                    other
                }
                Bucket::Projective(a) => other + a,
            }
        }
    }

    // Convert every scalar to its byte representation exactly once.
    let reprs: Vec<_> = coeffs.iter().map(|scalar| scalar.to_repr()).collect();

    // Window width in bits, chosen from the number of bases.
    let n = bases.len();
    let c = match n {
        0..=3 => 1,
        4..=31 => 3,
        _ => (f64::from(n as u32)).ln().ceil() as usize,
    };

    let number_of_windows = C::Scalar::NUM_BITS as usize / c + 1;

    // For every window, the signed digit (as returned by `get_booth_index`)
    // of every coefficient. Windows are visited in increasing order, so the
    // last value written to `max_nonzero_window` is the highest window that
    // contains any nonzero digit.
    let mut max_nonzero_window = None;
    let mut coeffs_in_windows = Vec::with_capacity(number_of_windows);
    for current_window in 0..number_of_windows {
        let mut digits: Vec<i32> = Vec::with_capacity(reprs.len());
        for repr in &reprs {
            let digit = get_booth_index(current_window, c, repr.as_ref());
            if digit != 0 {
                max_nonzero_window = Some(current_window);
            }
            digits.push(digit);
        }
        coeffs_in_windows.push(digits);
    }
    // The byte representations are no longer needed; free them early.
    drop(reprs);

    // Nothing to accumulate when all digits are zero.
    let last_window = match max_nonzero_window {
        Some(window) => window,
        None => return,
    };

    // Discard the all-zero top windows and walk the rest from the most
    // significant window down to window 0.
    coeffs_in_windows.truncate(last_window + 1);
    for digits in coeffs_in_windows.into_iter().rev() {
        // Shift the accumulator up by one window.
        for _ in 0..c {
            *acc = acc.double();
        }

        // Bucket `j` collects the bases whose digit has absolute value `j + 1`.
        let mut buckets: Vec<Bucket<C>> = vec![Bucket::None; 1 << (c - 1)];

        for (digit, base) in digits.into_iter().zip(bases.iter()) {
            if digit > 0 {
                buckets[digit as usize - 1].add_assign(base);
            } else if digit < 0 {
                // A negative digit contributes the negated base.
                buckets[digit.unsigned_abs() as usize - 1].add_assign(&base.neg());
            }
        }

        // Summation by parts
        // e.g. 3a + 2b + 1c = a +
        //                    (a) + b +
        //                    ((a) + b) + c
        let mut running_sum = C::Curve::identity();
        for bucket in buckets.into_iter().rev() {
            running_sum = bucket.add(running_sum);
            *acc += &running_sum;
        }
    }
}

/// Multi-scalar multiplication built on `multiexp_serial_jonathan`.
///
/// When there are more coefficients than rayon threads, the input is split
/// into chunks of `coeffs.len() / num_threads` elements, each chunk is
/// processed in its own rayon task, and the partial results are added up.
/// Otherwise the whole input is processed serially on the calling thread.
///
/// # Panics
///
/// Panics if `coeffs` and `bases` have different lengths.
pub fn best_multiexp_jonathan<C: CurveAffine>(coeffs: &[C::Scalar], bases: &[C]) -> C::Curve {
    assert_eq!(coeffs.len(), bases.len());

    let num_threads = rayon::current_num_threads();

    // Too few elements to be worth splitting: run serially.
    if coeffs.len() <= num_threads {
        let mut acc = C::Curve::identity();
        multiexp_serial_jonathan(coeffs, bases, &mut acc);
        return acc;
    }

    // `coeffs.len() > num_threads` guarantees a chunk size of at least 1.
    let chunk = coeffs.len() / num_threads;
    let num_chunks = coeffs.chunks(chunk).len();
    let mut results = vec![C::Curve::identity(); num_chunks];

    rayon::scope(|scope| {
        let jobs = coeffs
            .chunks(chunk)
            .zip(bases.chunks(chunk))
            .zip(results.iter_mut());
        // One task per chunk, each writing into its own result slot.
        for ((coeffs, bases), acc) in jobs {
            scope.spawn(move |_| {
                multiexp_serial_jonathan(coeffs, bases, acc);
            });
        }
    });

    // Combine the per-chunk partial sums.
    let mut total = C::Curve::identity();
    for partial in results.iter() {
        total = total + partial;
    }
    total
}

#[cfg(test)]
mod test {

Expand Down Expand Up @@ -1046,21 +1179,23 @@ mod test {
C::Curve::batch_normalize(&points[..], &mut affine_points[..]);
let points = affine_points;

const BYTES: usize = 1;
println!("bits = {}", BYTES * 8);
assert!(BYTES <= 16);
let max_val = 2u128.pow((BYTES * 8) as u32);
// const BYTES: usize = 1;
// println!("bits = {}", BYTES * 8);
// assert!(BYTES <= 16);
// let max_val = 2u128.pow((BYTES * 8) as u32);
let mut scalars = vec![C::Scalar::ZERO; 1 << max_k];
let mut scalars_small = vec![[0; BYTES]; 1 << max_k];
// let mut scalars_small = vec![[0; BYTES]; 1 << max_k];
for i in 0..1 << max_k {
let v_lo = OsRng.next_u64() as u128;
let v_hi = OsRng.next_u64() as u128;
let mut v = v_lo + (v_hi << 64);
if BYTES < 16 {
v %= max_val;
}
scalars[i] = C::Scalar::from_u128(v);
scalars_small[i] = v.to_le_bytes()[..BYTES].try_into().unwrap();
// let v_lo = OsRng.next_u64() as u128;
// let v_hi = OsRng.next_u64() as u128;
// let mut v = v_lo + (v_hi << 64);
// if BYTES < 16 {
// v %= max_val;
// }
// scalars[i] = C::Scalar::from_u128(v);
// scalars_small[i] = v.to_le_bytes()[..BYTES].try_into().unwrap();

scalars[i] = C::Scalar::random(OsRng);
}

for k in min_k..=max_k {
Expand All @@ -1077,14 +1212,15 @@ mod test {
// end_timer!(t01);
// assert_eq!(e01, e0);

let t11 = start_timer!(|| format!("older_skip_zeros k={}", k));
let e11 = super::best_multiexp_skip_zeros(scalars, points);
end_timer!(t11);

let t1 = start_timer!(|| format!("older k={}", k));
let e1 = super::best_multiexp(scalars, points);
end_timer!(t1);
// assert_eq!(e0, e1);

let t11 = start_timer!(|| format!("older_skip_zeros k={}", k));
let e11 = super::best_multiexp_skip_zeros(scalars, points);
end_timer!(t11);
assert_eq!(e11, e1);

// let t11 = start_timer!(|| format!("older_small k={}", k));
Expand Down

0 comments on commit d1f79a5

Please sign in to comment.