Skip to content

Commit

Permalink
Cleanup memory pipeline for SuperNova (#227)
Browse files Browse the repository at this point in the history
* NIFS::prove -> NIFS::prove_mut

* preallocate witness

* simplify resourcebuffer init
  • Loading branch information
winston-h-zhang authored Jan 4, 2024
1 parent ae521be commit 52a57d6
Show file tree
Hide file tree
Showing 3 changed files with 117 additions and 62 deletions.
12 changes: 6 additions & 6 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -415,17 +415,17 @@ where
let buffer_primary = ResourceBuffer {
l_w: None,
l_u: None,
ABC_Z_1: R1CSResult::default(r1cs_primary),
ABC_Z_2: R1CSResult::default(r1cs_primary),
T: r1cs::default_T(r1cs_primary),
ABC_Z_1: R1CSResult::default(r1cs_primary.num_cons),
ABC_Z_2: R1CSResult::default(r1cs_primary.num_cons),
T: r1cs::default_T::<E1>(r1cs_primary.num_cons),
};

let buffer_secondary = ResourceBuffer {
l_w: None,
l_u: None,
ABC_Z_1: R1CSResult::default(r1cs_secondary),
ABC_Z_2: R1CSResult::default(r1cs_secondary),
T: r1cs::default_T(r1cs_secondary),
ABC_Z_1: R1CSResult::default(r1cs_secondary.num_cons),
ABC_Z_2: R1CSResult::default(r1cs_secondary.num_cons),
T: r1cs::default_T::<E2>(r1cs_secondary.num_cons),
};

Ok(Self {
Expand Down
12 changes: 6 additions & 6 deletions src/r1cs/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -577,11 +577,11 @@ impl<E: Engine> R1CSShape<E> {

impl<E: Engine> R1CSResult<E> {
/// Produces a default `R1CSResult` given an `R1CSShape`
pub fn default(S: &R1CSShape<E>) -> Self {
pub fn default(num_cons: usize) -> Self {
Self {
AZ: vec![E::Scalar::ZERO; S.num_cons],
BZ: vec![E::Scalar::ZERO; S.num_cons],
CZ: vec![E::Scalar::ZERO; S.num_cons],
AZ: vec![E::Scalar::ZERO; num_cons],
BZ: vec![E::Scalar::ZERO; num_cons],
CZ: vec![E::Scalar::ZERO; num_cons],
}
}
}
Expand Down Expand Up @@ -817,8 +817,8 @@ impl<E: Engine> AbsorbInROTrait<E> for RelaxedR1CSInstance<E> {
}

/// Empty buffer for `commit_T_into`
pub fn default_T<E: Engine>(shape: &R1CSShape<E>) -> Vec<E::Scalar> {
Vec::with_capacity(shape.num_cons)
pub fn default_T<E: Engine>(num_cons: usize) -> Vec<E::Scalar> {
Vec::with_capacity(num_cons)
}

#[cfg(test)]
Expand Down
155 changes: 105 additions & 50 deletions src/supernova/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use crate::{
digest::{DigestComputer, SimpleDigestible},
errors::NovaError,
r1cs::{
commitment_key_size, CommitmentKeyHint, R1CSInstance, R1CSShape, R1CSWitness,
self, commitment_key_size, CommitmentKeyHint, R1CSInstance, R1CSResult, R1CSShape, R1CSWitness,
RelaxedR1CSInstance, RelaxedR1CSWitness,
},
scalar_as_base,
Expand Down Expand Up @@ -393,6 +393,21 @@ where
}
}

/// A resource buffer for SuperNova's [`RecursiveSNARK`] for storing scratch values that are computed by `prove_step`,
/// which allows the reuse of memory allocations and avoids unnecessary new allocations in the critical section.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(bound = "")]
pub struct ResourceBuffer<E: Engine> {
l_w: Option<R1CSWitness<E>>,
l_u: Option<R1CSInstance<E>>,

ABC_Z_1: R1CSResult<E>,
ABC_Z_2: R1CSResult<E>,

/// buffer for `commit_T`
T: Vec<E::Scalar>,
}

/// A SNARK that proves the correct execution of an non-uniform incremental computation
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(bound = "")]
Expand All @@ -416,6 +431,11 @@ where
proven_circuit_index: usize,
program_counter: E1::Scalar,

/// Buffer for memory needed by the primary fold-step
buffer_primary: ResourceBuffer<E1>,
/// Buffer for memory needed by the secondary fold-step
buffer_secondary: ResourceBuffer<E2>,

// Relaxed instances for the primary circuits
// Entries are `None` if the circuit has not been executed yet
r_W_primary: Vec<Option<RelaxedR1CSWitness<E1>>>,
Expand Down Expand Up @@ -454,6 +474,8 @@ where
let num_augmented_circuits = non_uniform_circuit.num_circuits();
let circuit_index = non_uniform_circuit.initial_circuit_index();

let r1cs_secondary = &pp.circuit_shape_secondary.r1cs_shape;

// check the length of the secondary initial input
if z0_secondary.len() != pp.circuit_shape_secondary.F_arity {
return Err(SuperNovaError::NovaError(
Expand Down Expand Up @@ -541,7 +563,7 @@ where
return Err(NovaError::InvalidStepOutputLength.into());
}
let (u_secondary, w_secondary) = cs_secondary
.r1cs_instance_and_witness(&pp.circuit_shape_secondary.r1cs_shape, &pp.ck_secondary)
.r1cs_instance_and_witness(r1cs_secondary, &pp.ck_secondary)
.map_err(|_| SuperNovaError::NovaError(NovaError::UnSat))?;

// IVC proof for the primary circuit
Expand All @@ -561,9 +583,8 @@ where
let l_u_secondary = u_secondary;

// Initialize relaxed instance/witness pair for the secondary circuit proofs
let r_W_secondary = RelaxedR1CSWitness::<E2>::default(&pp.circuit_shape_secondary.r1cs_shape);
let r_U_secondary =
RelaxedR1CSInstance::default(&pp.ck_secondary, &pp.circuit_shape_secondary.r1cs_shape);
let r_W_secondary: RelaxedR1CSWitness<E2> = RelaxedR1CSWitness::<E2>::default(r1cs_secondary);
let r_U_secondary = RelaxedR1CSInstance::default(&pp.ck_secondary, r1cs_secondary);

// Outputs of the two circuits and next program counter thus far.
let zi_primary = zi_primary
Expand Down Expand Up @@ -593,6 +614,31 @@ where
let r_U_primary_initial_list = (0..num_augmented_circuits)
.map(|i| (i == circuit_index).then(|| r_U_primary.clone()))
.collect::<Vec<Option<RelaxedR1CSInstance<E1>>>>();

// find the largest length r1cs shape for the buffer size
let max_num_cons = pp
.circuit_shapes
.iter()
.map(|circuit| circuit.r1cs_shape.num_cons)
.max()
.unwrap();

let buffer_primary = ResourceBuffer {
l_w: None,
l_u: None,
ABC_Z_1: R1CSResult::default(max_num_cons),
ABC_Z_2: R1CSResult::default(max_num_cons),
T: r1cs::default_T::<E1>(max_num_cons),
};

let buffer_secondary = ResourceBuffer {
l_w: None,
l_u: None,
ABC_Z_1: R1CSResult::default(r1cs_secondary.num_cons),
ABC_Z_2: R1CSResult::default(r1cs_secondary.num_cons),
T: r1cs::default_T::<E2>(r1cs_secondary.num_cons),
};

Ok(Self {
pp_digest: pp.digest(),
num_augmented_circuits,
Expand All @@ -603,6 +649,9 @@ where
proven_circuit_index: circuit_index,
program_counter: zi_primary_pc_next,

buffer_primary,
buffer_secondary,

r_W_primary: r_W_primary_initial_list,
r_U_primary: r_U_primary_initial_list,
z0_secondary: z0_secondary.to_vec(),
Expand All @@ -629,40 +678,45 @@ where
return Ok(());
}

// save the inputs before proceeding to the `i+1`th step
let r_U_primary_i = self.r_U_primary.clone();
// Create single-entry accumulator list for the secondary circuit to hand to SuperNovaAugmentedCircuitInputs
let r_U_secondary_i = vec![Some(self.r_U_secondary.clone())];
let l_u_secondary_i = self.l_u_secondary.clone();

let circuit_index = c_primary.circuit_index();
assert_eq!(self.program_counter, E1::Scalar::from(circuit_index as u64));

// fold the secondary circuit's instance
let (nifs_secondary, (r_U_secondary_folded, r_W_secondary_folded)) = NIFS::prove(
let nifs_secondary = NIFS::prove_mut(
&pp.ck_secondary,
&pp.ro_consts_secondary,
&scalar_as_base::<E1>(self.pp_digest),
&pp.circuit_shape_secondary.r1cs_shape,
&self.r_U_secondary,
&self.r_W_secondary,
&mut self.r_U_secondary,
&mut self.r_W_secondary,
&self.l_u_secondary,
&self.l_w_secondary,
&mut self.buffer_secondary.T,
&mut self.buffer_secondary.ABC_Z_1,
&mut self.buffer_secondary.ABC_Z_2,
)
.map_err(SuperNovaError::NovaError)?;

// clone and updated running instance on respective circuit_index
let r_U_secondary_next = r_U_secondary_folded;
let r_W_secondary_next = r_W_secondary_folded;

// Create single-entry accumulator list for the secondary circuit to hand to SuperNovaAugmentedCircuitInputs
let r_U_secondary = vec![Some(self.r_U_secondary.clone())];

let mut cs_primary = SatisfyingAssignment::<E1>::new();
let T =
let mut cs_primary = SatisfyingAssignment::<E1>::with_capacity(
pp[circuit_index].r1cs_shape.num_io + 1,
pp[circuit_index].r1cs_shape.num_vars,
);
let T: <<E2 as Engine>::CE as CommitmentEngineTrait<E2>>::Commitment =
Commitment::<E2>::decompress(&nifs_secondary.comm_T).map_err(SuperNovaError::NovaError)?;
let inputs_primary: SuperNovaAugmentedCircuitInputs<'_, E2> =
SuperNovaAugmentedCircuitInputs::new(
scalar_as_base::<E1>(self.pp_digest),
E1::Scalar::from(self.i as u64),
&self.z0_primary,
Some(&self.zi_primary),
Some(&r_U_secondary),
Some(&self.l_u_secondary),
Some(&r_U_secondary_i),
Some(&l_u_secondary_i),
Some(&T),
Some(self.program_counter),
E1::Scalar::ZERO,
Expand All @@ -689,37 +743,43 @@ where
.r1cs_instance_and_witness(&pp[circuit_index].r1cs_shape, &pp.ck_primary)
.map_err(SuperNovaError::NovaError)?;

// Split into `if let`/`else` statement
// to avoid `returns a value referencing data owned by closure` error on `&RelaxedR1CSInstance::default` and `RelaxedR1CSWitness::default`
let (nifs_primary, (r_U_primary_folded, r_W_primary_folded)) = match (
self.r_U_primary.get(circuit_index),
self.r_W_primary.get(circuit_index),
let (r_U_primary, r_W_primary) = if let (Some(Some(r_U_primary)), Some(Some(r_W_primary))) = (
self.r_U_primary.get_mut(circuit_index),
self.r_W_primary.get_mut(circuit_index),
) {
(Some(Some(r_U_primary)), Some(Some(r_W_primary))) => NIFS::prove(
&pp.ck_primary,
&pp.ro_consts_primary,
&self.pp_digest,
&pp[circuit_index].r1cs_shape,
r_U_primary,
r_W_primary,
&l_u_primary,
&l_w_primary,
)
.map_err(SuperNovaError::NovaError)?,
_ => NIFS::prove(
(r_U_primary, r_W_primary)
} else {
self.r_U_primary[circuit_index] = Some(RelaxedR1CSInstance::default(
&pp.ck_primary,
&pp.ro_consts_primary,
&self.pp_digest,
&pp[circuit_index].r1cs_shape,
&RelaxedR1CSInstance::default(&pp.ck_primary, &pp[circuit_index].r1cs_shape),
&RelaxedR1CSWitness::default(&pp[circuit_index].r1cs_shape),
&l_u_primary,
&l_w_primary,
));
self.r_W_primary[circuit_index] =
Some(RelaxedR1CSWitness::default(&pp[circuit_index].r1cs_shape));
(
self.r_U_primary[circuit_index].as_mut().unwrap(),
self.r_W_primary[circuit_index].as_mut().unwrap(),
)
.map_err(SuperNovaError::NovaError)?,
};

let mut cs_secondary = SatisfyingAssignment::<E2>::new();
let nifs_primary = NIFS::prove_mut(
&pp.ck_primary,
&pp.ro_consts_primary,
&self.pp_digest,
&pp[circuit_index].r1cs_shape,
r_U_primary,
r_W_primary,
&l_u_primary,
&l_w_primary,
&mut self.buffer_primary.T,
&mut self.buffer_primary.ABC_Z_1,
&mut self.buffer_primary.ABC_Z_2,
)
.map_err(SuperNovaError::NovaError)?;

let mut cs_secondary = SatisfyingAssignment::<E2>::with_capacity(
pp.circuit_shape_secondary.r1cs_shape.num_io + 1,
pp.circuit_shape_secondary.r1cs_shape.num_vars,
);
let binding =
Commitment::<E1>::decompress(&nifs_primary.comm_T).map_err(SuperNovaError::NovaError)?;
let inputs_secondary: SuperNovaAugmentedCircuitInputs<'_, E1> =
Expand All @@ -728,7 +788,7 @@ where
E2::Scalar::from(self.i as u64),
&self.z0_secondary,
Some(&self.zi_secondary),
Some(&self.r_U_primary),
Some(&r_U_primary_i),
Some(&l_u_primary),
Some(&binding),
None, // pc is always None for secondary circuit
Expand Down Expand Up @@ -782,11 +842,6 @@ where
));
}

// clone and updated running instance on respective circuit_index
self.r_U_primary[circuit_index] = Some(r_U_primary_folded);
self.r_W_primary[circuit_index] = Some(r_W_primary_folded);
self.r_W_secondary = r_W_secondary_next;
self.r_U_secondary = r_U_secondary_next;
self.l_w_secondary = l_w_secondary_next;
self.l_u_secondary = l_u_secondary_next;
self.i += 1;
Expand Down

0 comments on commit 52a57d6

Please sign in to comment.