Skip to content

Commit

Permalink
Merge pull request #213 from a16z/refactor-spartan
Browse files Browse the repository at this point in the history
Refactor Spartan commitments
  • Loading branch information
sragss authored Mar 26, 2024
2 parents 4f74b5c + e823be0 commit 891206b
Show file tree
Hide file tree
Showing 4 changed files with 250 additions and 173 deletions.
2 changes: 1 addition & 1 deletion jolt-core/src/jolt/vm/instruction_lookups.rs
Original file line number Diff line number Diff line change
Expand Up @@ -960,7 +960,7 @@ where
pub fn verify(
preprocessing: &InstructionLookupsPreprocessing<F>,
proof: InstructionLookupsProof<C, M, F, G, InstructionSet, Subtables>,
commitment: InstructionCommitment<G>,
commitment: &InstructionCommitment<G>,
transcript: &mut Transcript,
) -> Result<(), ProofVerifyError> {
<Transcript as ProofTranscript<G>>::append_protocol_name(transcript, Self::protocol_name());
Expand Down
137 changes: 61 additions & 76 deletions jolt-core/src/jolt/vm/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use crate::lasso::memory_checking::{MemoryCheckingProver, MemoryCheckingVerifier
use crate::poly::hyrax::{HyraxCommitment, HyraxGenerators};
use crate::poly::pedersen::PedersenGenerators;
use crate::poly::structured_poly::BatchablePolynomials;
use crate::r1cs::snark::{R1CSCommitments, R1CSInputs, R1CSProof};
use crate::r1cs::snark::{R1CSUniqueCommitments, R1CSInputs, R1CSProof};
use crate::utils::errors::ProofVerifyError;
use crate::utils::thread::drop_in_background_thread;
use common::{
Expand Down Expand Up @@ -76,6 +76,7 @@ pub struct JoltCommitments<G: CurveGroup> {
pub bytecode: BytecodeCommitment<G>,
pub read_write_memory: MemoryCommitment<G>,
pub instruction_lookups: InstructionCommitment<G>,
pub r1cs: R1CSUniqueCommitments<G>
}

pub trait Jolt<F: PrimeField, G: CurveGroup<ScalarField = F>, const C: usize, const M: usize>
Expand Down Expand Up @@ -196,14 +197,8 @@ pub trait Jolt<F: PrimeField, G: CurveGroup<ScalarField = F>, const C: usize, co
read_write_memory: memory_polynomials,
instruction_lookups: instruction_lookups_polynomials,
};
let jolt_commitments = JoltCommitments {
bytecode: bytecode_commitment,
read_write_memory: memory_commitment,
instruction_lookups: instruction_lookups_commitment,
};

// Note: Some of the commitments in r1cs_commitments are duplicates of elsewhere.
let r1cs_proof = Self::prove_r1cs(
let (r1cs_proof, r1cs_commitment) = Self::prove_r1cs(
preprocessing,
instructions,
bytecode_rows,
Expand All @@ -212,12 +207,18 @@ pub trait Jolt<F: PrimeField, G: CurveGroup<ScalarField = F>, const C: usize, co
memory_trace.into_iter().flatten().collect(),
circuit_flags,
&jolt_polynomials,
&jolt_commitments,
&mut transcript,
);

drop_in_background_thread(jolt_polynomials);

let jolt_commitments = JoltCommitments {
bytecode: bytecode_commitment,
read_write_memory: memory_commitment,
instruction_lookups: instruction_lookups_commitment,
r1cs: r1cs_commitment
};

let jolt_proof = JoltProof {
bytecode: bytecode_proof,
read_write_memory: memory_proof,
Expand All @@ -237,25 +238,22 @@ pub trait Jolt<F: PrimeField, G: CurveGroup<ScalarField = F>, const C: usize, co
Self::verify_bytecode(
&preprocessing.bytecode,
proof.bytecode,
commitments.bytecode,
&commitments.bytecode,
&mut transcript,
)?;
Self::verify_memory(
&mut preprocessing.read_write_memory,
proof.read_write_memory,
commitments.read_write_memory,
&commitments.read_write_memory,
&mut transcript,
)?;
Self::verify_instruction_lookups(
&preprocessing.instruction_lookups,
proof.instruction_lookups,
commitments.instruction_lookups,
&commitments.instruction_lookups,
&mut transcript,
)?;
proof
.r1cs
.verify(&mut transcript)
.map_err(|e| ProofVerifyError::SpartanError(e.to_string()))?;
Self::verify_r1cs(proof.r1cs, commitments, &mut transcript)?;
Ok(())
}

Expand All @@ -276,7 +274,7 @@ pub trait Jolt<F: PrimeField, G: CurveGroup<ScalarField = F>, const C: usize, co
fn verify_instruction_lookups(
preprocessing: &InstructionLookupsPreprocessing<F>,
proof: InstructionLookupsProof<C, M, F, G, Self::InstructionSet, Self::Subtables>,
commitment: InstructionCommitment<G>,
commitment: &InstructionCommitment<G>,
transcript: &mut Transcript,
) -> Result<(), ProofVerifyError> {
InstructionLookupsProof::verify(preprocessing, proof, commitment, transcript)
Expand All @@ -303,7 +301,7 @@ pub trait Jolt<F: PrimeField, G: CurveGroup<ScalarField = F>, const C: usize, co
fn verify_bytecode(
preprocessing: &BytecodePreprocessing<F>,
proof: BytecodeProof<F, G>,
commitment: BytecodeCommitment<G>,
commitment: &BytecodeCommitment<G>,
transcript: &mut Transcript,
) -> Result<(), ProofVerifyError> {
BytecodeProof::verify_memory_checking(preprocessing, proof, &commitment, transcript)
Expand Down Expand Up @@ -356,7 +354,7 @@ pub trait Jolt<F: PrimeField, G: CurveGroup<ScalarField = F>, const C: usize, co
fn verify_memory(
preprocessing: &mut ReadWriteMemoryPreprocessing,
mut proof: ReadWriteMemoryProof<F, G>,
commitment: MemoryCommitment<G>,
commitment: &MemoryCommitment<G>,
transcript: &mut Transcript,
) -> Result<(), ProofVerifyError> {
assert!(proof.program_io.inputs.len() <= MAX_INPUT_SIZE as usize);
Expand All @@ -383,9 +381,8 @@ pub trait Jolt<F: PrimeField, G: CurveGroup<ScalarField = F>, const C: usize, co
memory_trace: Vec<MemoryOp>,
circuit_flags: Vec<F>,
jolt_polynomials: &JoltPolynomials<F, G>,
jolt_commitments: &JoltCommitments<G>,
transcript: &mut Transcript,
) -> R1CSProof<F, G> {
) -> (R1CSProof<F, G>, R1CSUniqueCommitments<G>) {
let N_FLAGS = 17;
let trace_len = trace.len();
let padded_trace_len = trace_len.next_power_of_two();
Expand All @@ -406,7 +403,7 @@ pub trait Jolt<F: PrimeField, G: CurveGroup<ScalarField = F>, const C: usize, co
- circuit_flags_bits
*/

// OBtain circuit_flags_packed to prog_v_rw. Pack them in little-endian order.
// Obtain circuit_flags_packed to prog_v_rw. Pack them in little-endian order.
let span = tracing::span!(tracing::Level::INFO, "pack_flags");
let _enter = span.enter();
let precomputed_powers: Vec<F> = (0..N_FLAGS)
Expand All @@ -421,7 +418,7 @@ pub trait Jolt<F: PrimeField, G: CurveGroup<ScalarField = F>, const C: usize, co
})
})
.collect();
packed_flags.extend(vec![F::zero(); padded_trace_len - packed_flags.len()]);
packed_flags.resize(padded_trace_len, F::zero());
drop(_enter);
drop(span);

Expand Down Expand Up @@ -469,10 +466,7 @@ pub trait Jolt<F: PrimeField, G: CurveGroup<ScalarField = F>, const C: usize, co
drop(_enter);
drop(span);

// Assemble the polynomials
let (bytecode_a, mut bytecode_v) = jolt_polynomials.bytecode.get_polys_r1cs();
bytecode_v.par_extend(packed_flags.par_iter());

let (bytecode_a, bytecode_v) = jolt_polynomials.bytecode.get_polys_r1cs();
let (memreg_a_rw, memreg_v_reads, memreg_v_writes) =
jolt_polynomials.read_write_memory.get_polys_r1cs();

Expand All @@ -489,20 +483,7 @@ pub trait Jolt<F: PrimeField, G: CurveGroup<ScalarField = F>, const C: usize, co
}
drop(_guard);

// Assemble the commitments
let span = tracing::span!(tracing::Level::INFO, "bytecode_commitment_conversions");
let _guard = span.enter();
// TODO(sragss): JoltCommitment::convert_to_pre_r1cs();
let bytecode_comms: Vec<HyraxCommitment<NUM_R1CS_POLYS, G>> = vec![
jolt_commitments.bytecode.read_write_commitments[0].clone(), // a
jolt_commitments.bytecode.read_write_commitments[2].clone(), // opcode,
jolt_commitments.bytecode.read_write_commitments[3].clone(), // rd
jolt_commitments.bytecode.read_write_commitments[4].clone(), // rs1
jolt_commitments.bytecode.read_write_commitments[5].clone(), // rs2
jolt_commitments.bytecode.read_write_commitments[6].clone(), // imm
];
drop(_guard);

// Commit to R1CS specific items
let commit_to_chunks = |data: &Vec<F>| -> Vec<HyraxCommitment<NUM_R1CS_POLYS, G>> {
data.par_chunks(padded_trace_len)
.map(|chunk| HyraxCommitment::commit_slice(chunk, &hyrax_generators))
Expand All @@ -514,41 +495,18 @@ pub trait Jolt<F: PrimeField, G: CurveGroup<ScalarField = F>, const C: usize, co
let chunks_x_comms = commit_to_chunks(&chunks_x);
let chunks_y_comms = commit_to_chunks(&chunks_y);
let lookup_outputs_comms = commit_to_chunks(&lookup_outputs);
let packed_flags_comm = vec![HyraxCommitment::commit_slice(
&packed_flags,
&hyrax_generators,
)];
let packed_flags_comm = HyraxCommitment::commit_slice(&packed_flags, &hyrax_generators);
let circuit_flags_comm = commit_to_chunks(&circuit_flags_bits);
drop(_guard);

let span = tracing::span!(tracing::Level::INFO, "conversions");
let _guard = span.enter();
let memory_comms = jolt_commitments
.read_write_memory
.a_v_read_write_commitments
.clone();
let dim_read_comms =
jolt_commitments.instruction_lookups.dim_read_commitment[0..C].to_vec();
drop(_guard);

let jolt_commitments_spartan = [
bytecode_comms,
packed_flags_comm,
memory_comms,
chunks_x_comms,
chunks_y_comms,
dim_read_comms,
lookup_outputs_comms,
circuit_flags_comm,
]
.concat();

// Flattening this out into a Vec<F> and chunking into PADDED_TRACE_LEN-sized chunks
// Flattening this out into a Vec<F> and chunking into PADDED_TRACE_LEN-sized chunks
// will be the exact witness vector to feed into the R1CS
// after pre-pending IO and appending the AUX
let inputs: R1CSInputs<F> = R1CSInputs::new(
padded_trace_len,
bytecode_a,
bytecode_v,
packed_flags,
memreg_a_rw,
memreg_v_reads,
memreg_v_writes,
Expand All @@ -559,18 +517,45 @@ pub trait Jolt<F: PrimeField, G: CurveGroup<ScalarField = F>, const C: usize, co
circuit_flags_bits,
);

let proof = R1CSProof::prove::<F>(
32,
C,
padded_trace_len,
inputs,
hyrax_generators.clone(),
&jolt_commitments_spartan,
transcript,
let (key, witness_segments, io_aux_commitments) = R1CSProof::<F,G>::compute_witness_commit(
32,
C,
padded_trace_len,
inputs,
&hyrax_generators)
.expect("R1CSProof setup failed");

let r1cs_commitments = R1CSUniqueCommitments::new(
io_aux_commitments,
chunks_x_comms,
chunks_y_comms,
lookup_outputs_comms,
packed_flags_comm,
circuit_flags_comm,
hyrax_generators
);

r1cs_commitments.append_to_transcript(transcript);

let proof = R1CSProof::prove(
key,
witness_segments,
transcript
)
.expect("proof failed");

(proof, r1cs_commitments)
}

fn verify_r1cs(
proof: R1CSProof<F, G>,
commitments: JoltCommitments<G>,
transcript: &mut Transcript,
) -> Result<(), ProofVerifyError> {
commitments.r1cs.append_to_transcript(transcript);
proof
.verify(commitments, C, transcript)
.map_err(|e| ProofVerifyError::SpartanError(e.to_string()))
}

#[tracing::instrument(skip_all, name = "Jolt::compute_lookup_outputs")]
Expand Down
Loading

0 comments on commit 891206b

Please sign in to comment.