Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: MSM skip doubling when window has all zeros #152

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "halo2curves"
version = "0.6.1"
version = "0.6.2"
authors = ["Privacy Scaling Explorations team"]
license = "MIT/Apache-2.0"
edition = "2021"
Expand Down
144 changes: 101 additions & 43 deletions src/msm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use crate::CurveAffine;
use ff::Field;
use ff::PrimeField;
use group::Group;
use rayon::iter::IntoParallelIterator;
use rayon::iter::{
IndexedParallelIterator, IntoParallelRefIterator, IntoParallelRefMutIterator, ParallelIterator,
};
Expand Down Expand Up @@ -287,6 +288,7 @@ impl<C: CurveAffine> Schedule<C> {
}

pub fn multiexp_serial<C: CurveAffine>(coeffs: &[C::Scalar], bases: &[C], acc: &mut C::Curve) {
// Do conversion to bytes once
let coeffs: Vec<_> = coeffs.iter().map(|a| a.to_repr()).collect();

let c = if bases.len() < 4 {
Expand All @@ -299,7 +301,34 @@ pub fn multiexp_serial<C: CurveAffine>(coeffs: &[C::Scalar], bases: &[C], acc: &

let number_of_windows = C::Scalar::NUM_BITS as usize / c + 1;

for current_window in (0..number_of_windows).rev() {
// In each window, get the booth index of each coefficient
let mut coeffs_in_windows = Vec::with_capacity(number_of_windows);
// Track what is the last window where we actually have nonzero booth index, so we completely skip buckets where the scalar bits for all coeffs are 0
let mut max_nonzero_window = None;
for current_window in 0..number_of_windows {
let coeffs_in_window: Vec<i32> = coeffs
.iter()
.map(|coeff| {
let coeff = get_booth_index(current_window, c, coeff.as_ref());
if coeff != 0 {
max_nonzero_window = Some(current_window);
}
coeff
})
.collect();
coeffs_in_windows.push(coeffs_in_window);
}
// Save memory and drop coeffs as bytes since it's not needed anymore
drop(coeffs);

if max_nonzero_window.is_none() {
return;
}
for coeffs_in_window in coeffs_in_windows
.into_iter()
.take(max_nonzero_window.unwrap() + 1)
.rev()
{
for _ in 0..c {
*acc = acc.double();
}
Expand Down Expand Up @@ -337,8 +366,7 @@ pub fn multiexp_serial<C: CurveAffine>(coeffs: &[C::Scalar], bases: &[C], acc: &

let mut buckets: Vec<Bucket<C>> = vec![Bucket::None; 1 << (c - 1)];

for (coeff, base) in coeffs.iter().zip(bases.iter()) {
let coeff = get_booth_index(current_window, c, coeff.as_ref());
for (coeff, base) in coeffs_in_window.into_iter().zip(bases.iter()) {
if coeff.is_positive() {
buckets[coeff as usize - 1].add_assign(base);
}
Expand Down Expand Up @@ -422,52 +450,82 @@ pub fn best_multiexp_independent_points<C: CurveAffine>(

// number of windows
let number_of_windows = C::Scalar::NUM_BITS as usize / c + 1;
// accumumator for each window
let mut acc = vec![C::Curve::identity(); number_of_windows];
acc.par_iter_mut().enumerate().rev().for_each(|(w, acc)| {
// jacobian buckets for already scheduled points
let mut j_bucks = vec![Bucket::<C>::None; 1 << (c - 1)];

// schedular for affine addition
let mut sched = Schedule::new(c);

for (base_idx, coeff) in coeffs.iter().enumerate() {
let buck_idx = get_booth_index(w, c, coeff.as_ref());

if buck_idx != 0 {
// parse bucket index
let sign = buck_idx.is_positive();
let buck_idx = buck_idx.unsigned_abs() as usize - 1;
// In each window, get the booth index of each coefficient
let mut coeffs_in_windows = Vec::with_capacity(number_of_windows);
// Track what is the last window where we actually have nonzero booth index, so we completely skip buckets where the scalar bits for all coeffs are 0
let mut max_nonzero_window = None;
for current_window in 0..number_of_windows {
let coeffs_in_window: Vec<i32> = coeffs
.iter()
.map(|coeff| {
let coeff = get_booth_index(current_window, c, coeff.as_ref());
if coeff != 0 {
max_nonzero_window = Some(current_window);
}
coeff
})
.collect();
coeffs_in_windows.push(coeffs_in_window);
}
// Save memory and drop coeffs as bytes since it's not needed anymore
drop(coeffs);

if sched.contains(buck_idx) {
// greedy accumulation
// we use original bases here
j_bucks[buck_idx].add_assign(&bases[base_idx], sign);
} else {
// also flushes the schedule if full
sched.add(&bases_local, base_idx, buck_idx, sign);
if max_nonzero_window.is_none() {
// Everything is zero
return C::Curve::identity();
}
let number_of_windows = max_nonzero_window.unwrap() + 1;
// accumumator for each window
let mut acc = vec![C::Curve::identity(); number_of_windows];
coeffs_in_windows
.into_par_iter()
.take(number_of_windows)
.zip(acc.par_iter_mut())
.enumerate()
.rev()
.for_each(|(w, (coeffs_in_window, acc))| {
// jacobian buckets for already scheduled points
let mut j_bucks = vec![Bucket::<C>::None; 1 << (c - 1)];

// schedular for affine addition
let mut sched = Schedule::new(c);

for (base_idx, buck_idx) in coeffs_in_window.into_iter().enumerate() {
if buck_idx != 0 {
// parse bucket index
let sign = buck_idx.is_positive();
let buck_idx = buck_idx.unsigned_abs() as usize - 1;

if sched.contains(buck_idx) {
// greedy accumulation
// we use original bases here
j_bucks[buck_idx].add_assign(&bases[base_idx], sign);
} else {
// also flushes the schedule if full
sched.add(&bases_local, base_idx, buck_idx, sign);
}
}
}
}

// flush the schedule
sched.execute(&bases_local);

// summation by parts
// e.g. 3a + 2b + 1c = a +
// (a) + b +
// ((a) + b) + c
let mut running_sum = C::Curve::identity();
for (j_buck, a_buck) in j_bucks.iter().zip(sched.buckets.iter()).rev() {
running_sum += j_buck.add(a_buck);
*acc += running_sum;
}
// flush the schedule
sched.execute(&bases_local);

// summation by parts
// e.g. 3a + 2b + 1c = a +
// (a) + b +
// ((a) + b) + c
let mut running_sum = C::Curve::identity();
for (j_buck, a_buck) in j_bucks.iter().zip(sched.buckets.iter()).rev() {
running_sum += j_buck.add(a_buck);
*acc += running_sum;
}

// shift accumulator to the window position
for _ in 0..c * w {
*acc = acc.double();
}
});
// shift accumulator to the window position
for _ in 0..c * w {
*acc = acc.double();
}
});
acc.into_iter().sum::<_>()
}

Expand Down
Loading