Merge pull request #151 from HungryCatsStudio/cleanup-dense-poly
Cleanup dense poly
sragss committed Jan 25, 2024
2 parents 889e2ec + 2307aa6 commit 72e9d36
Showing 4 changed files with 95 additions and 111 deletions.
2 changes: 1 addition & 1 deletion jolt-core/src/jolt/vm/instruction_lookups.rs
@@ -122,7 +122,7 @@ where
fn batch(&self) -> Self::BatchedPolynomials {
use rayon::prelude::*;
let (batched_dim_read, (batched_final, batched_E, batched_flag)) = rayon::join(
|| DensePolynomial::merge_dual(self.dim.as_ref(), self.read_cts.as_ref()),
|| DensePolynomial::merge(self.dim.iter().chain(&self.read_cts)),
|| {
let batched_final = DensePolynomial::merge(&self.final_cts);
let (batched_E, batched_flag) = rayon::join(
2 changes: 1 addition & 1 deletion jolt-core/src/lasso/surge.rs
@@ -65,7 +65,7 @@ where
#[tracing::instrument(skip_all, name = "SurgePolys::batch")]
fn batch(&self) -> Self::BatchedPolynomials {
let (batched_dim_read, (batched_final, batched_E)) = rayon::join(
|| DensePolynomial::merge_dual(self.dim.as_ref(), self.read_cts.as_ref()),
|| DensePolynomial::merge(self.dim.iter().chain(&self.read_cts)),
|| {
rayon::join(
|| DensePolynomial::merge(&self.final_cts),
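Both hunks above replace merge_dual(dim, read_cts) with a single merge call over a chained iterator. A minimal, self-contained sketch of why that is the same concatenation, with plain Vec<u64> standing in for DensePolynomial and concat_all as a hypothetical stand-in for merge:

// Chaining the two collections concatenates them in order, exactly what the
// removed merge_dual produced by copying polys_a and then polys_b.
fn concat_all<'a>(polys: impl IntoIterator<Item = &'a Vec<u64>>) -> Vec<u64> {
    polys.into_iter().flat_map(|p| p.iter().copied()).collect()
}

fn main() {
    let dim = vec![vec![1, 2], vec![3, 4]];
    let read_cts = vec![vec![5, 6]];
    // Same shape as DensePolynomial::merge(self.dim.iter().chain(&self.read_cts)).
    let merged = concat_all(dim.iter().chain(&read_cts));
    assert_eq!(merged, vec![1, 2, 3, 4, 5, 6]);
}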
83 changes: 83 additions & 0 deletions jolt-core/src/msm/mod.rs
@@ -328,3 +328,86 @@ fn ln_without_floats(a: usize) -> usize {
// log2(a) * ln(2)
(ark_std::log2(a) * 69 / 100) as usize
}

/// Special MSM where all scalar values are 0 / 1 – does not verify.
pub(crate) fn flags_msm<G: CurveGroup>(scalars: &[G::ScalarField], bases: &[G::Affine]) -> G {
assert_eq!(scalars.len(), bases.len());
let result = scalars
.into_iter()
.enumerate()
.filter(|(_index, scalar)| !scalar.is_zero())
.map(|(index, scalar)| bases[index])
.sum();

result
}
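As the doc comment notes, the function trusts the caller: with 0/1 scalars the MSM degenerates to a sum over the bases whose flag is set, and no check is made that the scalars really are boolean. A hypothetical, self-contained illustration of that filtered-sum shape, with u64 values in place of curve points and flag_sum as an invented name:

// Keep only the entries whose flag is non-zero and add up the paired values.
fn flag_sum(flags: &[u64], values: &[u64]) -> u64 {
    flags
        .iter()
        .zip(values)
        .filter(|(flag, _)| **flag != 0)
        .map(|(_, value)| *value)
        .sum()
}

fn main() {
    // Flags [1, 0, 1] over values [10, 20, 30] select 10 and 30.
    assert_eq!(flag_sum(&[1, 0, 1], &[10, 20, 30]), 40);
}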

pub(crate) fn sm_msm<V: VariableBaseMSM>(
scalars: &[<V::ScalarField as PrimeField>::BigInt],
bases: &[V::MulBase],
) -> V {
assert_eq!(scalars.len(), bases.len());
let num_buckets: usize = 1 << 16; // TODO(sragss): This should be passed in / dependent on M = N^{1/C}

// #[cfg(test)]
// scalars.for_each(|scalar| {
// assert!(scalar < V::ScalarField::from(num_buckets as u64).into_bigint())
// });

// Assign things to buckets based on the scalar
let mut buckets: Vec<V> = vec![V::zero(); num_buckets];
scalars.into_iter().enumerate().for_each(|(index, scalar)| {
let bucket_index: u64 = scalar.as_ref()[0];
buckets[bucket_index as usize] += bases[index];
});

let mut result = V::zero();
let mut running_sum = V::zero();
buckets
.into_iter()
.skip(1)
.enumerate()
.rev()
.for_each(|(index, bucket)| {
running_sum += bucket;
result += running_sum;
});
result
}

#[cfg(test)]
mod tests {

use ark_std::test_rng;

use crate::poly::dense_mlpoly::DensePolynomial;

use super::*;

#[test]
fn sm_msm_parity() {
use ark_curve25519::{EdwardsAffine as G1Affine, EdwardsProjective as G1Projective, Fr};
let mut rng = test_rng();
let bases = vec![
G1Affine::rand(&mut rng),
G1Affine::rand(&mut rng),
G1Affine::rand(&mut rng),
];
let scalars = vec![Fr::from(3), Fr::from(2), Fr::from(1)];
let expected_result = bases[0] + bases[0] + bases[0] + bases[1] + bases[1] + bases[2];
assert_eq!(bases[0] + bases[0] + bases[0], bases[0] * scalars[0]);
let expected_result_b =
bases[0] * scalars[0] + bases[1] * scalars[1] + bases[2] * scalars[2];
assert_eq!(expected_result, expected_result_b);

let calc_result_a: G1Projective = VariableBaseMSM::msm(&bases, &scalars).unwrap();
assert_eq!(calc_result_a, expected_result);

let scalars_bigint: Vec<_> = scalars
.into_iter()
.map(|scalar| scalar.into_bigint())
.collect();
let calc_result_b: G1Projective = sm_msm(&scalars_bigint, &bases);
assert_eq!(calc_result_b, expected_result);
}
}
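The bucket walk in sm_msm relies on a standard identity: after dropping bucket 0, accumulating a reversed running suffix sum over buckets 1..k adds bucket b to the result exactly b times, so the total equals the sum of b * bucket[b], i.e. the full sum of scalar_i * base_i. A self-contained sketch under that reading, with u64 addition standing in for curve-point addition and bucket_total as a hypothetical name, not part of the crate:

// The reversed running sum over buckets 1.. adds bucket[b] to the result b times.
fn bucket_total(buckets: &[u64]) -> u64 {
    let mut result = 0;
    let mut running_sum = 0;
    for bucket in buckets.iter().skip(1).rev() {
        running_sum += bucket;
        result += running_sum;
    }
    result
}

fn main() {
    // Scalars 3, 2, 1 paired with "bases" 10, 20, 30: the naive sum is
    // 3*10 + 2*20 + 1*30 = 100, and the bucket walk reproduces it.
    let mut buckets = vec![0u64; 4];
    for (scalar, base) in [(3usize, 10u64), (2, 20), (1, 30)] {
        buckets[scalar] += base;
    }
    assert_eq!(bucket_total(&buckets), 100);
}

This mirrors the sm_msm_parity test above, where the scalars 3, 2, 1 are applied to three random bases and checked against a plain VariableBaseMSM::msm call.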
119 changes: 10 additions & 109 deletions jolt-core/src/poly/dense_mlpoly.rs
@@ -1,4 +1,5 @@
#![allow(clippy::too_many_arguments)]
use crate::msm::{flags_msm, sm_msm};
use crate::poly::eq_poly::EqPolynomial;
use crate::utils::{self, compute_dotproduct, compute_dotproduct_low_optimized, mul_0_1_optimized};

@@ -219,63 +220,17 @@ impl<F: PrimeField> DensePolynomial<F> {
let scalars = self.Z[R_size * i..R_size * (i + 1)].as_ref();
match hint {
CommitHint::Normal => Commitments::batch_commit(scalars, &gens),
CommitHint::Flags => Self::flags_msm(scalars, &gens),
CommitHint::Flags => flags_msm(scalars, &gens),
CommitHint::Small => {
let bigints: Vec<_> = scalars.iter().map(|s| s.into_bigint()).collect();
Self::sm_msm(&bigints, &gens)
sm_msm(&bigints, &gens)
}
}
})
.collect();
PolyCommitment { C }
}

/// Special MSM where all scalar values are 0 / 1 – does not verify.
fn flags_msm<G: CurveGroup>(scalars: &[G::ScalarField], bases: &[G::Affine]) -> G {
assert_eq!(scalars.len(), bases.len());
let result = scalars
.into_iter()
.enumerate()
.filter(|(_index, scalar)| !scalar.is_zero())
.map(|(index, scalar)| bases[index])
.sum();

result
}

pub fn sm_msm<V: VariableBaseMSM>(
scalars: &[<V::ScalarField as PrimeField>::BigInt],
bases: &[V::MulBase],
) -> V {
assert_eq!(scalars.len(), bases.len());
let num_buckets: usize = 1 << 16; // TODO(sragss): This should be passed in / dependent on M = N^{1/C}

// #[cfg(test)]
// scalars.for_each(|scalar| {
// assert!(scalar < V::ScalarField::from(num_buckets as u64).into_bigint())
// });

// Assign things to buckets based on the scalar
let mut buckets: Vec<V> = vec![V::zero(); num_buckets];
scalars.into_iter().enumerate().for_each(|(index, scalar)| {
let bucket_index: u64 = scalar.as_ref()[0];
buckets[bucket_index as usize] += bases[index];
});

let mut result = V::zero();
let mut running_sum = V::zero();
buckets
.into_iter()
.skip(1)
.enumerate()
.rev()
.for_each(|(index, bucket)| {
running_sum += bucket;
result += running_sum;
});
result
}

#[tracing::instrument(skip_all, name = "DensePolynomial.bound")]
pub fn bound(&self, L: &[F]) -> Vec<F> {
let (left_num_vars, right_num_vars) =
@@ -442,24 +397,15 @@ impl<F: PrimeField> DensePolynomial<F> {
self.Z.as_ref()
}

pub fn extend(&mut self, other: &DensePolynomial<F>) {
assert_eq!(self.Z.len(), self.len);
let other_vec = other.vec();
assert_eq!(other_vec.len(), self.len);
self.Z.extend(other_vec);
self.num_vars += 1;
self.len *= 2;
assert_eq!(self.Z.len(), self.len);
}

#[tracing::instrument(skip_all, name = "DensePoly.merge")]
pub fn merge<T>(polys: &[T]) -> DensePolynomial<F>
where
T: AsRef<DensePolynomial<F>>,
{
let total_len: usize = polys.iter().map(|poly| poly.as_ref().vec().len()).sum();
pub fn merge(polys: impl IntoIterator<Item = impl AsRef<Self>> + Clone) -> DensePolynomial<F> {
let polys_iter_cloned = polys.clone().into_iter();
let total_len: usize = polys
.into_iter()
.map(|poly| poly.as_ref().vec().len())
.sum();
let mut Z: Vec<F> = Vec::with_capacity(total_len.next_power_of_two());
for poly in polys {
for poly in polys_iter_cloned {
Z.extend_from_slice(poly.as_ref().vec());
}

@@ -469,25 +415,6 @@ impl<F: PrimeField> DensePolynomial<F> {
DensePolynomial::new(Z)
}

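The rewritten merge above takes the polynomials twice: the iterator is cloned once to compute the total length for the allocation, then consumed to copy the coefficients, which is why the argument is bounded by Clone. A self-contained sketch of that two-pass shape, with Vec<u64> standing in for DensePolynomial<F> and concat_padded as a hypothetical name:

// Clone the iterator for the sizing pass, consume it for the copy pass,
// then pad with zeros to the next power of two, as merge does.
fn concat_padded(polys: impl IntoIterator<Item = Vec<u64>> + Clone) -> Vec<u64> {
    let total_len: usize = polys.clone().into_iter().map(|p| p.len()).sum();
    let padded_len = total_len.next_power_of_two();
    let mut z: Vec<u64> = Vec::with_capacity(padded_len);
    for p in polys {
        z.extend_from_slice(&p);
    }
    z.resize(padded_len, 0);
    z
}

fn main() {
    let z = concat_padded([vec![1, 2, 3], vec![4, 5]]);
    assert_eq!(&z[..5], &[1, 2, 3, 4, 5]);
    assert_eq!(z.len(), 8); // 5 coefficients padded to the next power of two
}

Existing call sites such as DensePolynomial::merge(&self.final_cts) keep working because a shared reference to a Vec is Clone and iterates by reference, while the now-removed merge_dual below is expressed by chaining two such iterators at the call site.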
#[tracing::instrument(skip_all, name = "DensePoly.merge_dual")]
pub fn merge_dual<T>(polys_a: &[T], polys_b: &[T]) -> DensePolynomial<F>
where
T: AsRef<DensePolynomial<F>>,
{
let total_len_a: usize = polys_a.iter().map(|poly| poly.as_ref().len()).sum();
let total_len_b: usize = polys_b.iter().map(|poly| poly.as_ref().len()).sum();
let total_len = total_len_a + total_len_b;

let mut Z: Vec<F> = Vec::with_capacity(total_len.next_power_of_two());
polys_a.iter().for_each(|poly| Z.extend_from_slice(poly.as_ref().vec()));
polys_b.iter().for_each(|poly| Z.extend_from_slice(poly.as_ref().vec()));

// pad the polynomial with zero polynomial at the end
Z.resize(Z.capacity(), F::zero());

DensePolynomial::new(Z)
}

pub fn combined_commit<G>(
&self,
label: &'static [u8],
@@ -1024,32 +951,6 @@ mod tests {
);
}

#[test]
fn sm_msm_parity() {
use ark_curve25519::{EdwardsAffine as G1Affine, EdwardsProjective as G1Projective, Fr};
let mut rng = test_rng();
let bases = vec![
G1Affine::rand(&mut rng),
G1Affine::rand(&mut rng),
G1Affine::rand(&mut rng),
];
let scalars = vec![Fr::from(3), Fr::from(2), Fr::from(1)];
let expected_result = bases[0] + bases[0] + bases[0] + bases[1] + bases[1] + bases[2];
assert_eq!(bases[0] + bases[0] + bases[0], bases[0] * scalars[0]);
let expected_result_b =
bases[0] * scalars[0] + bases[1] * scalars[1] + bases[2] * scalars[2];
assert_eq!(expected_result, expected_result_b);

let calc_result_a: G1Projective = VariableBaseMSM::msm(&bases, &scalars).unwrap();
assert_eq!(calc_result_a, expected_result);

let scalars_bigint: Vec<_> = scalars
.into_iter()
.map(|scalar| scalar.into_bigint())
.collect();
let calc_result_b: G1Projective = DensePolynomial::<Fr>::sm_msm(&scalars_bigint, &bases);
assert_eq!(calc_result_b, expected_result);
}

#[test]
fn commit_with_hint_parity() {
