Skip to content

Commit

Permalink
Ppsnark refactorings (#208)
Browse files Browse the repository at this point in the history
* refactor: Refactor row/col vector construction for efficiency

- Optimized the creation of `row` and `col` in `R1CSShapeSparkRepr::new` using map and unzip methods.
- Updated `R1CSShapeSparkRepr::evaluation_oracles` to create `E_row` and `E_col` using the same logic for consistency.

* refactor: Refactor and optimize `R1CSShapeSparkRepr` initialization

- Updated method for zero padding in `val_B` and `val_C` using `std::iter::repeat`, to need one vector allocation instead of two
- Functionality and outputs remain unchanged.

* refactor: Refactor polynomial struct in SumCheck to use generic Scalar type

- Updated `CompressedUniPoly` and `UniPoly` structs in `sumcheck.rs` to use the generic `Scalar` type.
- Adapted all methods within these structs to accommodate the `Scalar` type instead of `G: Group` type.
- Modified the type of `cubic_polys` in `ppsnark.rs` to `CompressedUniPoly<G::Scalar>`.

* refactor: Eliminate most instances of resize

resize in Rust may cause reallocation of the memory, which is an expensive operation. This is particularly true when the vector is resized to a larger size.
  • Loading branch information
huitseeker authored Jul 28, 2023
1 parent cdab403 commit eeb3e47
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 84 deletions.
103 changes: 46 additions & 57 deletions src/spartan/ppsnark.rs
Original file line number Diff line number Diff line change
Expand Up @@ -119,51 +119,34 @@ impl<G: Group> R1CSShapeSparkRepr<G> {
max(total_nz, max(2 * S.num_vars, S.num_cons)).next_power_of_two()
};

let row = {
let mut r = S
.A
.iter()
.chain(S.B.iter())
.chain(S.C.iter())
.map(|(r, _, _)| *r)
.collect::<Vec<usize>>();
r.resize(N, 0usize);
r
};
let (mut row, mut col) = (vec![0usize; N], vec![0usize; N]);

let col = {
let mut c = S
.A
.iter()
.chain(S.B.iter())
.chain(S.C.iter())
.map(|(_, c, _)| *c)
.collect::<Vec<usize>>();
c.resize(N, 0usize);
c
};
for (i, (r, c, _)) in S.A.iter().chain(S.B.iter()).chain(S.C.iter()).enumerate() {
row[i] = *r;
col[i] = *c;
}

let val_A = {
let mut val = S.A.iter().map(|(_, _, v)| *v).collect::<Vec<G::Scalar>>();
val.resize(N, G::Scalar::ZERO);
let mut val = vec![G::Scalar::ZERO; N];
for (i, (_, _, v)) in S.A.iter().enumerate() {
val[i] = *v;
}
val
};

let val_B = {
// prepend zeros
let mut val = vec![G::Scalar::ZERO; S.A.len()];
val.extend(S.B.iter().map(|(_, _, v)| *v).collect::<Vec<G::Scalar>>());
// append zeros
val.resize(N, G::Scalar::ZERO);
let mut val = vec![G::Scalar::ZERO; N];
for (i, (_, _, v)) in S.B.iter().enumerate() {
val[S.A.len() + i] = *v;
}
val
};

let val_C = {
// prepend zeros
let mut val = vec![G::Scalar::ZERO; S.A.len() + S.B.len()];
val.extend(S.C.iter().map(|(_, _, v)| *v).collect::<Vec<G::Scalar>>());
// append zeros
val.resize(N, G::Scalar::ZERO);
let mut val = vec![G::Scalar::ZERO; N];
for (i, (_, _, v)) in S.C.iter().enumerate() {
val[S.A.len() + S.B.len() + i] = *v;
}
val
};

Expand Down Expand Up @@ -265,29 +248,30 @@ impl<G: Group> R1CSShapeSparkRepr<G> {

let mem_row = EqPolynomial::new(r_x_padded).evals();
let mem_col = {
let mut z = z.to_vec();
z.resize(self.N, G::Scalar::ZERO);
z
let mut val = vec![G::Scalar::ZERO; self.N];
for (i, v) in z.iter().enumerate() {
val[i] = *v;
}
val
};

let mut E_row = S
.A
.iter()
.chain(S.B.iter())
.chain(S.C.iter())
.map(|(r, _, _)| mem_row[*r])
.collect::<Vec<G::Scalar>>();

let mut E_col = S
.A
.iter()
.chain(S.B.iter())
.chain(S.C.iter())
.map(|(_, c, _)| mem_col[*c])
.collect::<Vec<G::Scalar>>();
let (E_row, E_col) = {
let mut E_row = vec![mem_row[0]; self.N]; // we place mem_row[0] since resized row is appended with 0s
let mut E_col = vec![mem_col[0]; self.N];

E_row.resize(self.N, mem_row[0]); // we place mem_row[0] since resized row is appended with 0s
E_col.resize(self.N, mem_col[0]);
for (i, (val_r, val_c)) in S
.A
.iter()
.chain(S.B.iter())
.chain(S.C.iter())
.map(|(r, c, _)| (mem_row[*r], mem_col[*c]))
.enumerate()
{
E_row[i] = val_r;
E_col[i] = val_c;
}
(E_row, E_col)
};

(mem_row, mem_col, E_row, E_col)
}
Expand Down Expand Up @@ -862,7 +846,7 @@ impl<G: Group, EE: EvaluationEngineTrait<G, CE = G::CE>> RelaxedR1CSSNARK<G, EE>

let mut e = claim;
let mut r: Vec<G::Scalar> = Vec::new();
let mut cubic_polys: Vec<CompressedUniPoly<G>> = Vec::new();
let mut cubic_polys: Vec<CompressedUniPoly<G::Scalar>> = Vec::new();
let num_rounds = mem.size().log_2();
for _i in 0..num_rounds {
let mut evals: Vec<Vec<G::Scalar>> = Vec::new();
Expand Down Expand Up @@ -999,8 +983,13 @@ impl<G: Group, EE: EvaluationEngineTrait<G, CE = G::CE>> RelaxedR1CSSNARKTrait<G
Bz.resize(pk.S_repr.N, G::Scalar::ZERO);
Cz.resize(pk.S_repr.N, G::Scalar::ZERO);

let mut E = W.E.clone();
E.resize(pk.S_repr.N, G::Scalar::ZERO);
let E = {
let mut val = vec![G::Scalar::ZERO; pk.S_repr.N];
for (i, w_e) in W.E.iter().enumerate() {
val[i] = *w_e;
}
val
};

(Az, Bz, Cz, E)
};
Expand Down
51 changes: 24 additions & 27 deletions src/spartan/sumcheck.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,18 @@
use super::polynomial::MultilinearPolynomial;
use crate::errors::NovaError;
use crate::traits::{Group, TranscriptEngineTrait, TranscriptReprTrait};
use core::marker::PhantomData;
use ff::Field;
use ff::{Field, PrimeField};
use rayon::prelude::*;
use serde::{Deserialize, Serialize};

#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(bound = "")]
pub(crate) struct SumcheckProof<G: Group> {
compressed_polys: Vec<CompressedUniPoly<G>>,
compressed_polys: Vec<CompressedUniPoly<G::Scalar>>,
}

impl<G: Group> SumcheckProof<G> {
pub fn new(compressed_polys: Vec<CompressedUniPoly<G>>) -> Self {
pub fn new(compressed_polys: Vec<CompressedUniPoly<G::Scalar>>) -> Self {
Self { compressed_polys }
}

Expand Down Expand Up @@ -101,7 +100,7 @@ impl<G: Group> SumcheckProof<G> {
F: Fn(&G::Scalar, &G::Scalar) -> G::Scalar + Sync,
{
let mut r: Vec<G::Scalar> = Vec::new();
let mut polys: Vec<CompressedUniPoly<G>> = Vec::new();
let mut polys: Vec<CompressedUniPoly<G::Scalar>> = Vec::new();
let mut claim_per_round = *claim;
for _ in 0..num_rounds {
let poly = {
Expand Down Expand Up @@ -150,7 +149,7 @@ impl<G: Group> SumcheckProof<G> {
{
let mut e = *claim;
let mut r: Vec<G::Scalar> = Vec::new();
let mut quad_polys: Vec<CompressedUniPoly<G>> = Vec::new();
let mut quad_polys: Vec<CompressedUniPoly<G::Scalar>> = Vec::new();

for _j in 0..num_rounds {
let mut evals: Vec<(G::Scalar, G::Scalar)> = Vec::new();
Expand Down Expand Up @@ -204,7 +203,7 @@ impl<G: Group> SumcheckProof<G> {
F: Fn(&G::Scalar, &G::Scalar, &G::Scalar, &G::Scalar) -> G::Scalar + Sync,
{
let mut r: Vec<G::Scalar> = Vec::new();
let mut polys: Vec<CompressedUniPoly<G>> = Vec::new();
let mut polys: Vec<CompressedUniPoly<G::Scalar>> = Vec::new();
let mut claim_per_round = *claim;

for _ in 0..num_rounds {
Expand Down Expand Up @@ -288,34 +287,33 @@ impl<G: Group> SumcheckProof<G> {
// ax^2 + bx + c stored as vec![a,b,c]
// ax^3 + bx^2 + cx + d stored as vec![a,b,c,d]
#[derive(Debug)]
pub struct UniPoly<G: Group> {
coeffs: Vec<G::Scalar>,
pub struct UniPoly<Scalar: PrimeField> {
coeffs: Vec<Scalar>,
}

// ax^2 + bx + c stored as vec![a,c]
// ax^3 + bx^2 + cx + d stored as vec![a,c,d]
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct CompressedUniPoly<G: Group> {
coeffs_except_linear_term: Vec<G::Scalar>,
_p: PhantomData<G>,
pub struct CompressedUniPoly<Scalar: PrimeField> {
coeffs_except_linear_term: Vec<Scalar>,
}

impl<G: Group> UniPoly<G> {
pub fn from_evals(evals: &[G::Scalar]) -> Self {
impl<Scalar: PrimeField> UniPoly<Scalar> {
pub fn from_evals(evals: &[Scalar]) -> Self {
// we only support degree-2 or degree-3 univariate polynomials
assert!(evals.len() == 3 || evals.len() == 4);
let coeffs = if evals.len() == 3 {
// ax^2 + bx + c
let two_inv = G::Scalar::from(2).invert().unwrap();
let two_inv = Scalar::from(2).invert().unwrap();

let c = evals[0];
let a = two_inv * (evals[2] - evals[1] - evals[1] + c);
let b = evals[1] - c - a;
vec![c, b, a]
} else {
// ax^3 + bx^2 + cx + d
let two_inv = G::Scalar::from(2).invert().unwrap();
let six_inv = G::Scalar::from(6).invert().unwrap();
let two_inv = Scalar::from(2).invert().unwrap();
let six_inv = Scalar::from(6).invert().unwrap();

let d = evals[0];
let a = six_inv
Expand All @@ -338,18 +336,18 @@ impl<G: Group> UniPoly<G> {
self.coeffs.len() - 1
}

pub fn eval_at_zero(&self) -> G::Scalar {
pub fn eval_at_zero(&self) -> Scalar {
self.coeffs[0]
}

pub fn eval_at_one(&self) -> G::Scalar {
pub fn eval_at_one(&self) -> Scalar {
(0..self.coeffs.len())
.into_par_iter()
.map(|i| self.coeffs[i])
.reduce(|| G::Scalar::ZERO, |a, b| a + b)
.reduce(|| Scalar::ZERO, |a, b| a + b)
}

pub fn evaluate(&self, r: &G::Scalar) -> G::Scalar {
pub fn evaluate(&self, r: &Scalar) -> Scalar {
let mut eval = self.coeffs[0];
let mut power = *r;
for coeff in self.coeffs.iter().skip(1) {
Expand All @@ -359,27 +357,26 @@ impl<G: Group> UniPoly<G> {
eval
}

pub fn compress(&self) -> CompressedUniPoly<G> {
pub fn compress(&self) -> CompressedUniPoly<Scalar> {
let coeffs_except_linear_term = [&self.coeffs[0..1], &self.coeffs[2..]].concat();
assert_eq!(coeffs_except_linear_term.len() + 1, self.coeffs.len());
CompressedUniPoly {
coeffs_except_linear_term,
_p: Default::default(),
}
}
}

impl<G: Group> CompressedUniPoly<G> {
impl<Scalar: PrimeField> CompressedUniPoly<Scalar> {
// we require eval(0) + eval(1) = hint, so we can solve for the linear term as:
// linear_term = hint - 2 * constant_term - deg2 term - deg3 term
pub fn decompress(&self, hint: &G::Scalar) -> UniPoly<G> {
pub fn decompress(&self, hint: &Scalar) -> UniPoly<Scalar> {
let mut linear_term =
*hint - self.coeffs_except_linear_term[0] - self.coeffs_except_linear_term[0];
for i in 1..self.coeffs_except_linear_term.len() {
linear_term -= self.coeffs_except_linear_term[i];
}

let mut coeffs: Vec<G::Scalar> = Vec::new();
let mut coeffs: Vec<Scalar> = Vec::new();
coeffs.push(self.coeffs_except_linear_term[0]);
coeffs.push(linear_term);
coeffs.extend(&self.coeffs_except_linear_term[1..]);
Expand All @@ -388,7 +385,7 @@ impl<G: Group> CompressedUniPoly<G> {
}
}

impl<G: Group> TranscriptReprTrait<G> for UniPoly<G> {
impl<G: Group> TranscriptReprTrait<G> for UniPoly<G::Scalar> {
fn to_transcript_bytes(&self) -> Vec<u8> {
let coeffs = self.compress().coeffs_except_linear_term;
coeffs.as_slice().to_transcript_bytes()
Expand Down

0 comments on commit eeb3e47

Please sign in to comment.