diff --git a/.gitignore b/.gitignore index 829691c6..ec2971fb 100644 --- a/.gitignore +++ b/.gitignore @@ -4,3 +4,7 @@ /target testdata +Cargo.lock +params +agg.pk +break_points.json \ No newline at end of file diff --git a/snark-verifier-sdk/Cargo.toml b/snark-verifier-sdk/Cargo.toml index adf30bec..fa423daf 100644 --- a/snark-verifier-sdk/Cargo.toml +++ b/snark-verifier-sdk/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "snark-verifier-sdk" -version = "0.1.2" +version = "0.1.3" edition = "2021" [dependencies] diff --git a/snark-verifier-sdk/benches/read_pk.rs b/snark-verifier-sdk/benches/read_pk.rs index 55154a2e..f87f52c8 100644 --- a/snark-verifier-sdk/benches/read_pk.rs +++ b/snark-verifier-sdk/benches/read_pk.rs @@ -9,7 +9,7 @@ use halo2_proofs::{halo2curves::bn256::Bn256, poly::kzg::commitment::ParamsKZG}; use pprof::criterion::{Output, PProfProfiler}; use rand::rngs::OsRng; -use snark_verifier_sdk::halo2::aggregation::AggregationConfigParams; +use snark_verifier_sdk::halo2::aggregation::{AggregationConfigParams, VerifierUniversality}; use snark_verifier_sdk::{ gen_pk, halo2::{aggregation::AggregationCircuit, gen_snark_shplonk}, @@ -190,6 +190,7 @@ fn bench(c: &mut Criterion) { None, ¶ms, snarks, + VerifierUniversality::None, ); std::fs::remove_file("examples/agg.pk").ok(); diff --git a/snark-verifier-sdk/benches/standard_plonk.rs b/snark-verifier-sdk/benches/standard_plonk.rs index eecb7140..873f77cb 100644 --- a/snark-verifier-sdk/benches/standard_plonk.rs +++ b/snark-verifier-sdk/benches/standard_plonk.rs @@ -12,7 +12,7 @@ use halo2_proofs::{ use pprof::criterion::{Output, PProfProfiler}; use rand::rngs::OsRng; use snark_verifier_sdk::evm::{evm_verify, gen_evm_proof_shplonk, gen_evm_verifier_shplonk}; -use snark_verifier_sdk::halo2::aggregation::AggregationConfigParams; +use snark_verifier_sdk::halo2::aggregation::{AggregationConfigParams, VerifierUniversality}; use snark_verifier_sdk::{ gen_pk, halo2::{aggregation::AggregationCircuit, gen_proof_shplonk, gen_snark_shplonk}, @@ -193,6 +193,7 @@ fn bench(c: &mut Criterion) { None, ¶ms, snarks.clone(), + VerifierUniversality::None, ); let start0 = start_timer!(|| "gen vk & pk"); @@ -213,6 +214,7 @@ fn bench(c: &mut Criterion) { Some(break_points.clone()), params, snarks.clone(), + VerifierUniversality::None, ); let instances = agg_circuit.instances(); gen_proof_shplonk(params, pk, agg_circuit, instances, None) @@ -230,6 +232,7 @@ fn bench(c: &mut Criterion) { Some(break_points), ¶ms, snarks.clone(), + VerifierUniversality::None, ); let num_instances = agg_circuit.num_instance(); let instances = agg_circuit.instances(); diff --git a/snark-verifier-sdk/examples/n_as_witness.rs b/snark-verifier-sdk/examples/n_as_witness.rs new file mode 100644 index 00000000..3ee25f23 --- /dev/null +++ b/snark-verifier-sdk/examples/n_as_witness.rs @@ -0,0 +1,188 @@ +use halo2_base::gates::builder::CircuitBuilderStage; +use halo2_base::halo2_proofs; +use halo2_base::halo2_proofs::arithmetic::Field; +use halo2_base::halo2_proofs::halo2curves::bn256::Fr; +use halo2_base::halo2_proofs::poly::commitment::Params; +use halo2_base::utils::fs::gen_srs; +use halo2_proofs::halo2curves as halo2_curves; + +use rand::rngs::StdRng; +use rand::SeedableRng; +use snark_verifier_sdk::halo2::aggregation::{AggregationConfigParams, VerifierUniversality}; +use snark_verifier_sdk::SHPLONK; +use snark_verifier_sdk::{ + gen_pk, + halo2::{aggregation::AggregationCircuit, gen_snark_shplonk}, + Snark, +}; + +mod application { + use super::halo2_curves::bn256::Fr; + use super::halo2_proofs::{ + circuit::{Layouter, SimpleFloorPlanner, Value}, + plonk::{Advice, Circuit, Column, ConstraintSystem, Error, Fixed, Instance}, + poly::Rotation, + }; + + use snark_verifier_sdk::CircuitExt; + + #[derive(Clone, Copy)] + pub struct StandardPlonkConfig { + a: Column, + b: Column, + c: Column, + q_a: Column, + q_b: Column, + q_c: Column, + q_ab: Column, + constant: Column, + #[allow(dead_code)] + instance: Column, + } + + impl StandardPlonkConfig { + fn configure(meta: &mut ConstraintSystem) -> Self { + let [a, b, c] = [(); 3].map(|_| meta.advice_column()); + let [q_a, q_b, q_c, q_ab, constant] = [(); 5].map(|_| meta.fixed_column()); + let instance = meta.instance_column(); + + [a, b, c].map(|column| meta.enable_equality(column)); + + meta.create_gate( + "q_a·a + q_b·b + q_c·c + q_ab·a·b + constant + instance = 0", + |meta| { + let [a, b, c] = + [a, b, c].map(|column| meta.query_advice(column, Rotation::cur())); + let [q_a, q_b, q_c, q_ab, constant] = [q_a, q_b, q_c, q_ab, constant] + .map(|column| meta.query_fixed(column, Rotation::cur())); + let instance = meta.query_instance(instance, Rotation::cur()); + Some( + q_a * a.clone() + + q_b * b.clone() + + q_c * c + + q_ab * a * b + + constant + + instance, + ) + }, + ); + + StandardPlonkConfig { a, b, c, q_a, q_b, q_c, q_ab, constant, instance } + } + } + + #[derive(Clone)] + pub struct StandardPlonk(pub Fr, pub usize); + + impl CircuitExt for StandardPlonk { + fn num_instance(&self) -> Vec { + vec![1] + } + + fn instances(&self) -> Vec> { + vec![vec![self.0]] + } + } + + impl Circuit for StandardPlonk { + type Config = StandardPlonkConfig; + type FloorPlanner = SimpleFloorPlanner; + + fn without_witnesses(&self) -> Self { + Self(Fr::zero(), self.1) + } + + fn configure(meta: &mut ConstraintSystem) -> Self::Config { + meta.set_minimum_degree(4); + StandardPlonkConfig::configure(meta) + } + + fn synthesize( + &self, + config: Self::Config, + mut layouter: impl Layouter, + ) -> Result<(), Error> { + layouter.assign_region( + || "", + |mut region| { + region.assign_advice(config.a, 0, Value::known(self.0)); + region.assign_fixed(config.q_a, 0, -Fr::one()); + region.assign_advice(config.a, 1, Value::known(-Fr::from(5u64))); + for (idx, column) in (1..).zip([ + config.q_a, + config.q_b, + config.q_c, + config.q_ab, + config.constant, + ]) { + region.assign_fixed(column, 1, Fr::from(idx as u64)); + } + let a = region.assign_advice(config.a, 2, Value::known(Fr::one())); + a.copy_advice(&mut region, config.b, 3); + a.copy_advice(&mut region, config.c, 4); + + // assuming <= 10 blinding factors + // fill in most of circuit with a computation + let n = self.1; + for offset in 5..n - 10 { + region.assign_advice(config.a, offset, Value::known(-Fr::from(5u64))); + for (idx, column) in (1..).zip([ + config.q_a, + config.q_b, + config.q_c, + config.q_ab, + config.constant, + ]) { + region.assign_fixed(column, offset, Fr::from(idx as u64)); + } + } + + Ok(()) + }, + ) + } + } +} + +fn gen_application_snark(k: u32) -> Snark { + let rng = StdRng::seed_from_u64(0); + let params = gen_srs(k); + let circuit = application::StandardPlonk(Fr::random(rng), params.n() as usize); + + let pk = gen_pk(¶ms, &circuit, None); + gen_snark_shplonk(¶ms, &pk, circuit, None::<&str>) +} + +fn main() { + let dummy_snark = gen_application_snark(8); + + let k = 15u32; + let params = gen_srs(k); + let lookup_bits = k as usize - 1; + let mut agg_circuit = AggregationCircuit::new::( + CircuitBuilderStage::Keygen, + AggregationConfigParams { degree: k, lookup_bits, ..Default::default() }, + None, + ¶ms, + vec![dummy_snark], + VerifierUniversality::Full, + ); + let agg_config = agg_circuit.config(Some(10)); + + let pk = gen_pk(¶ms, &agg_circuit, None); + let break_points = agg_circuit.break_points(); + + let snarks = [8, 12, 15, 20].map(|k| (k, gen_application_snark(k))); + for (k, snark) in snarks { + let agg_circuit = AggregationCircuit::new::( + CircuitBuilderStage::Prover, + agg_config, + Some(break_points.clone()), + ¶ms, + vec![snark], + VerifierUniversality::Full, + ); + let _snark = gen_snark_shplonk(¶ms, &pk, agg_circuit, None::<&str>); + println!("snark with k = {k} success"); + } +} diff --git a/snark-verifier-sdk/examples/range_check.rs b/snark-verifier-sdk/examples/range_check.rs new file mode 100644 index 00000000..c9200df7 --- /dev/null +++ b/snark-verifier-sdk/examples/range_check.rs @@ -0,0 +1,90 @@ +use ark_std::{end_timer, start_timer}; +use halo2_base::gates::builder::{ + BaseConfigParams, CircuitBuilderStage, GateThreadBuilder, RangeWithInstanceCircuitBuilder, +}; +use halo2_base::gates::flex_gate::GateStrategy; +use halo2_base::halo2_proofs::halo2curves::bn256::Fr; +use halo2_base::halo2_proofs::plonk::Circuit; +use halo2_base::safe_types::{GateInstructions, RangeChip, RangeInstructions}; +use halo2_base::utils::fs::gen_srs; + +use itertools::Itertools; +use snark_verifier_sdk::halo2::aggregation::{AggregationConfigParams, VerifierUniversality}; +use snark_verifier_sdk::SHPLONK; +use snark_verifier_sdk::{ + gen_pk, + halo2::{aggregation::AggregationCircuit, gen_snark_shplonk}, + Snark, +}; + +fn generate_circuit(k: u32) -> Snark { + let mut builder = GateThreadBuilder::new(false); + let ctx = builder.main(0); + let lookup_bits = k as usize - 1; + let range = RangeChip::::default(lookup_bits); + + let x = ctx.load_witness(Fr::from(14)); + range.range_check(ctx, x, 2 * lookup_bits + 1); + range.gate().add(ctx, x, x); + + let circuit = RangeWithInstanceCircuitBuilder::::keygen( + builder.clone(), + BaseConfigParams { + strategy: GateStrategy::Vertical, + k: k as usize, + num_advice_per_phase: vec![1], + num_lookup_advice_per_phase: vec![1], + num_fixed: 1, + lookup_bits: Some(lookup_bits), + }, + vec![], + ); + let params = gen_srs(k); + + let pk = gen_pk(¶ms, &circuit, None); + let breakpoints = circuit.break_points(); + + let circuit = RangeWithInstanceCircuitBuilder::::prover( + builder.clone(), + circuit.params(), + breakpoints, + vec![], + ); + gen_snark_shplonk(¶ms, &pk, circuit, None::<&str>) +} + +fn main() { + let dummy_snark = generate_circuit(13); + + let k = 14u32; + let lookup_bits = k as usize - 1; + let params = gen_srs(k); + let mut agg_circuit = AggregationCircuit::new::( + CircuitBuilderStage::Keygen, + AggregationConfigParams { degree: k, lookup_bits, ..Default::default() }, + None, + ¶ms, + vec![dummy_snark], + VerifierUniversality::Full, + ); + let agg_config = agg_circuit.config(Some(10)); + + let start0 = start_timer!(|| "gen vk & pk"); + let pk = gen_pk(¶ms, &agg_circuit, None); + end_timer!(start0); + let break_points = agg_circuit.break_points(); + + let snarks = (14..17).map(generate_circuit).collect_vec(); + for (i, snark) in snarks.into_iter().enumerate() { + let agg_circuit = AggregationCircuit::new::( + CircuitBuilderStage::Prover, + agg_config, + Some(break_points.clone()), + ¶ms, + vec![snark], + VerifierUniversality::Full, + ); + let _snark = gen_snark_shplonk(¶ms, &pk, agg_circuit, None::<&str>); + println!("snark {i} success"); + } +} diff --git a/snark-verifier-sdk/examples/vkey_as_witness.rs b/snark-verifier-sdk/examples/vkey_as_witness.rs new file mode 100644 index 00000000..a0eb30d2 --- /dev/null +++ b/snark-verifier-sdk/examples/vkey_as_witness.rs @@ -0,0 +1,184 @@ +use application::ComputeFlag; + +use halo2_base::gates::builder::CircuitBuilderStage; +use halo2_base::halo2_proofs; +use halo2_base::halo2_proofs::arithmetic::Field; +use halo2_base::halo2_proofs::halo2curves::bn256::Fr; +use halo2_base::utils::fs::gen_srs; +use halo2_proofs::halo2curves as halo2_curves; + +use halo2_proofs::{halo2curves::bn256::Bn256, poly::kzg::commitment::ParamsKZG}; +use rand::rngs::OsRng; +use snark_verifier_sdk::halo2::aggregation::{AggregationConfigParams, VerifierUniversality}; +use snark_verifier_sdk::SHPLONK; +use snark_verifier_sdk::{ + gen_pk, + halo2::{aggregation::AggregationCircuit, gen_snark_shplonk}, + Snark, +}; + +mod application { + use super::halo2_curves::bn256::Fr; + use super::halo2_proofs::{ + circuit::{Layouter, SimpleFloorPlanner, Value}, + plonk::{Advice, Circuit, Column, ConstraintSystem, Error, Fixed, Instance}, + poly::Rotation, + }; + + use snark_verifier_sdk::CircuitExt; + + #[derive(Clone, Copy)] + pub struct StandardPlonkConfig { + a: Column, + b: Column, + c: Column, + q_a: Column, + q_b: Column, + q_c: Column, + q_ab: Column, + constant: Column, + #[allow(dead_code)] + instance: Column, + } + + impl StandardPlonkConfig { + fn configure(meta: &mut ConstraintSystem) -> Self { + let [a, b, c] = [(); 3].map(|_| meta.advice_column()); + let [q_a, q_b, q_c, q_ab, constant] = [(); 5].map(|_| meta.fixed_column()); + let instance = meta.instance_column(); + + [a, b, c].map(|column| meta.enable_equality(column)); + + meta.create_gate( + "q_a·a + q_b·b + q_c·c + q_ab·a·b + constant + instance = 0", + |meta| { + let [a, b, c] = + [a, b, c].map(|column| meta.query_advice(column, Rotation::cur())); + let [q_a, q_b, q_c, q_ab, constant] = [q_a, q_b, q_c, q_ab, constant] + .map(|column| meta.query_fixed(column, Rotation::cur())); + let instance = meta.query_instance(instance, Rotation::cur()); + Some( + q_a * a.clone() + + q_b * b.clone() + + q_c * c + + q_ab * a * b + + constant + + instance, + ) + }, + ); + + StandardPlonkConfig { a, b, c, q_a, q_b, q_c, q_ab, constant, instance } + } + } + + #[derive(Clone, Copy, PartialEq, Eq)] + pub enum ComputeFlag { + All, + SkipFixed, + SkipCopy, + } + + #[derive(Clone)] + pub struct StandardPlonk(pub Fr, pub ComputeFlag); + + impl CircuitExt for StandardPlonk { + fn num_instance(&self) -> Vec { + vec![1] + } + + fn instances(&self) -> Vec> { + vec![vec![self.0]] + } + } + + impl Circuit for StandardPlonk { + type Config = StandardPlonkConfig; + type FloorPlanner = SimpleFloorPlanner; + + fn without_witnesses(&self) -> Self { + Self(Fr::zero(), self.1) + } + + fn configure(meta: &mut ConstraintSystem) -> Self::Config { + meta.set_minimum_degree(4); + StandardPlonkConfig::configure(meta) + } + + fn synthesize( + &self, + config: Self::Config, + mut layouter: impl Layouter, + ) -> Result<(), Error> { + layouter.assign_region( + || "", + |mut region| { + region.assign_advice(config.a, 0, Value::known(self.0)); + region.assign_fixed(config.q_a, 0, -Fr::one()); + region.assign_advice(config.a, 1, Value::known(-Fr::from(5u64))); + if self.1 != ComputeFlag::SkipFixed { + for (idx, column) in (1..).zip([ + config.q_a, + config.q_b, + config.q_c, + config.q_ab, + config.constant, + ]) { + region.assign_fixed(column, 1, Fr::from(idx as u64)); + } + } + let a = region.assign_advice(config.a, 2, Value::known(Fr::one())); + if self.1 != ComputeFlag::SkipCopy { + a.copy_advice(&mut region, config.b, 3); + a.copy_advice(&mut region, config.c, 4); + } + + Ok(()) + }, + ) + } + } +} + +fn gen_application_snark(params: &ParamsKZG, flag: ComputeFlag) -> Snark { + let circuit = application::StandardPlonk(Fr::random(OsRng), flag); + + let pk = gen_pk(params, &circuit, None); + gen_snark_shplonk(params, &pk, circuit, None::<&str>) +} + +fn main() { + let params_app = gen_srs(8); + let dummy_snark = gen_application_snark(¶ms_app, ComputeFlag::All); + + let k = 15u32; + let params = gen_srs(k); + let lookup_bits = k as usize - 1; + let mut agg_circuit = AggregationCircuit::new::( + CircuitBuilderStage::Keygen, + AggregationConfigParams { degree: k, lookup_bits, ..Default::default() }, + None, + ¶ms, + vec![dummy_snark], + VerifierUniversality::PreprocessedAsWitness, + ); + let agg_config = agg_circuit.config(Some(10)); + + let pk = gen_pk(¶ms, &agg_circuit, None); + let break_points = agg_circuit.break_points(); + + let snarks = [ComputeFlag::All, ComputeFlag::SkipFixed, ComputeFlag::SkipCopy] + .map(|flag| gen_application_snark(¶ms_app, flag)); + for (i, snark) in snarks.into_iter().enumerate() { + let agg_circuit = AggregationCircuit::new::( + CircuitBuilderStage::Prover, + agg_config, + Some(break_points.clone()), + ¶ms, + vec![snark], + VerifierUniversality::PreprocessedAsWitness, + ); + let _snark = gen_snark_shplonk(¶ms, &pk, agg_circuit, None::<&str>); + println!("snark {i} success"); + } +} diff --git a/snark-verifier-sdk/src/halo2.rs b/snark-verifier-sdk/src/halo2.rs index b0710230..06d5e622 100644 --- a/snark-verifier-sdk/src/halo2.rs +++ b/snark-verifier-sdk/src/halo2.rs @@ -35,6 +35,7 @@ use snark_verifier::{ AccumulationScheme, PolynomialCommitmentScheme, Query, }, system::halo2::{compile, Config}, + util::arithmetic::Rotation, util::transcript::TranscriptWrite, verifier::plonk::PlonkProof, }; @@ -122,20 +123,25 @@ where end_timer!(proof_time); // validate proof before caching - assert!({ - let mut transcript_read = - PoseidonTranscript::::from_spec(&proof[..], POSEIDON_SPEC.clone()); - VerificationStrategy::<_, V>::finalize( - verify_proof::<_, V, _, _, _>( - params.verifier_params(), - pk.get_vk(), - AccumulatorStrategy::new(params.verifier_params()), - &[instances.as_slice()], - &mut transcript_read, + assert!( + { + let mut transcript_read = PoseidonTranscript::::from_spec( + &proof[..], + POSEIDON_SPEC.clone(), + ); + VerificationStrategy::<_, V>::finalize( + verify_proof::<_, V, _, _, _>( + params.verifier_params(), + pk.get_vk(), + AccumulatorStrategy::new(params.verifier_params()), + &[instances.as_slice()], + &mut transcript_read, + ) + .unwrap(), ) - .unwrap(), - ) - }); + }, + "SNARK proof failed to verify" + ); if let Some((instance_path, proof_path)) = path { write_instances(&instances, instance_path); @@ -286,7 +292,7 @@ where NativeLoader, Accumulator = KzgAccumulator, VerifyingKey = KzgAsVerifyingKey, - > + CostEstimation>>, + > + CostEstimation>>, { struct CsProxy(PhantomData<(F, C)>); diff --git a/snark-verifier-sdk/src/halo2/aggregation.rs b/snark-verifier-sdk/src/halo2/aggregation.rs index 2d093f7b..0063a599 100644 --- a/snark-verifier-sdk/src/halo2/aggregation.rs +++ b/snark-verifier-sdk/src/halo2/aggregation.rs @@ -43,6 +43,37 @@ pub type Svk = KzgSuccinctVerifyingKey; pub type BaseFieldEccChip<'chip> = halo2_ecc::ecc::BaseFieldEccChip<'chip, G1Affine>; pub type Halo2Loader<'chip> = loader::halo2::Halo2Loader>; +pub struct SnarkAggregationWitness<'a> { + pub previous_instances: Vec>>, + pub accumulator: KzgAccumulator>>, + /// This returns the assigned `preprocessed` and `transcript_initial_state` values as a vector of assigned values, one for each aggregated snark. + /// These can then be exposed as public instances. + /// + /// This is only useful if preprocessed digest is loaded as witness (i.e., `preprocessed_as_witness` is true in `aggregate`), so we set it to `None` otherwise. + pub preprocessed_digests: Option>>>, +} + +/// Different possible stages of universality the aggregation circuit can support +#[derive(PartialEq, Eq, Clone, Copy, Debug)] +pub enum VerifierUniversality { + /// Default: verifier is specific to a single circuit + None, + /// Preprocessed digest (commitments to fixed columns) is loaded as witness + PreprocessedAsWitness, + /// Preprocessed as witness and number of rows in the circuit `n` loaded as witness + Full, +} + +impl VerifierUniversality { + pub fn preprocessed_as_witness(&self) -> bool { + self != &VerifierUniversality::None + } + + pub fn n_as_witness(&self) -> bool { + self == &VerifierUniversality::Full + } +} + #[allow(clippy::type_complexity)] /// Core function used in `synthesize` to aggregate multiple `snarks`. /// @@ -50,6 +81,9 @@ pub type Halo2Loader<'chip> = loader::halo2::Halo2Loader( @@ -57,7 +91,8 @@ pub fn aggregate<'a, AS>( loader: &Rc>, snarks: &[Snark], as_proof: &[u8], -) -> (Vec>>, KzgAccumulator>>) + universality: VerifierUniversality, +) -> SnarkAggregationWitness<'a> where AS: PolynomialCommitmentScheme< G1Affine, @@ -82,6 +117,7 @@ where }; let mut previous_instances = Vec::with_capacity(snarks.len()); + let mut preprocessed_digests = Vec::with_capacity(snarks.len()); // to avoid re-loading the spec each time, we create one transcript and clear the stream let mut transcript = PoseidonTranscript::>, &[u8]>::from_spec( loader, @@ -89,10 +125,42 @@ where POSEIDON_SPEC.clone(), ); + let preprocessed_as_witness = universality.preprocessed_as_witness(); let mut accumulators = snarks .iter() - .flat_map(|snark| { - let protocol = snark.protocol.loaded(loader); + .flat_map(|snark: &Snark| { + let protocol = if preprocessed_as_witness { + // always load `domain.n` as witness if vkey is witness + snark.protocol.loaded_preprocessed_as_witness(loader, universality.n_as_witness()) + } else { + snark.protocol.loaded(loader) + }; + let inputs = protocol + .preprocessed + .iter() + .flat_map(|preprocessed| { + let assigned = preprocessed.assigned(); + [assigned.x(), assigned.y()] + .into_iter() + .flat_map(|coordinate| coordinate.limbs().to_vec()) + .collect_vec() + }) + .chain( + protocol.transcript_initial_state.clone().map(|scalar| scalar.into_assigned()), + ) + .chain( + protocol + .domain_as_witness + .as_ref() + .map(|domain| domain.n.clone().into_assigned()), + ) // If `n` is witness, add it as part of input + .chain( + protocol + .domain_as_witness + .as_ref() + .map(|domain| domain.gen.clone().into_assigned()), + ) // If `n` is witness, add the generator of the order `n` subgroup as part of input + .collect_vec(); let instances = assign_instances(&snark.instances); // read the transcript and perform Fiat-Shamir @@ -111,6 +179,7 @@ where previous_instances.push( instances.into_iter().flatten().map(|scalar| scalar.into_assigned()).collect(), ); + preprocessed_digests.push(inputs); accumulator }) @@ -129,8 +198,9 @@ where } else { accumulators.pop().unwrap() }; + let preprocessed_digests = preprocessed_as_witness.then_some(preprocessed_digests); - (previous_instances, accumulator) + SnarkAggregationWitness { previous_instances, accumulator, preprocessed_digests } } /// Same as `FlexGateConfigParams` except we assume a single Phase and default 'Vertical' strategy. @@ -187,6 +257,14 @@ impl TryFrom<&BaseConfigParams> for AggregationConfigParams { } } +impl TryFrom for AggregationConfigParams { + type Error = &'static str; + + fn try_from(value: BaseConfigParams) -> Result { + Self::try_from(&value) + } +} + /// Holds virtual contexts for the cells used to verify a collection of snarks #[derive(Clone, Debug)] pub struct AggregationCtxBuilder { @@ -196,6 +274,11 @@ pub struct AggregationCtxBuilder { pub accumulator: Vec>, // the public instances from previous snarks that were aggregated pub previous_instances: Vec>>, + /// This returns the assigned `preprocessed_digest` (vkey), optional `transcript_initial_state`, `domain.n` (optional), and `omega` (optional) values as a vector of assigned values, one for each aggregated snark. + /// These can then be exposed as public instances. + /// + /// This is only useful if preprocessed digest is loaded as witness (i.e., `universality != None`), so we set it to `None` if `universality == None`. + pub preprocessed_digests: Option>>>, } #[derive(Clone, Debug)] @@ -204,6 +287,11 @@ pub struct AggregationCircuit { // the public instances from previous snarks that were aggregated, now collected as PRIVATE assigned values // the user can optionally append these to `inner.assigned_instances` to expose them pub previous_instances: Vec>>, + /// This returns the assigned `preprocessed_digest` (vkey), optional `transcript_initial_state`, `domain.n` (optional), and `omega` (optional) values as a vector of assigned values, one for each aggregated snark. + /// These can then be exposed as public instances. + /// + /// This is only useful if preprocessed digest is loaded as witness (i.e., `universality != None`), so we set it to `None` if `universality == None`. + pub preprocessed_digests: Option>>>, // accumulation scheme proof, private input // pub as_proof: Vec, } @@ -236,13 +324,21 @@ impl AggregationCtxBuilder { /// /// Also returns the limbs of the pair of elliptic curve points, referred to as the `accumulator`, that need to be verified in a final pairing check. /// - /// Warning: will fail silently if `snarks` were created using a different multi-open scheme than `AS` + /// # Universality + /// - If `universality` is not `None`, then the verifying keys of each snark in `snarks` is loaded as a witness in the circuit. + /// - Moreover, if `universality` is `Full`, then the number of rows `n` of each snark in `snarks` is also loaded as a witness. In this case the generator `omega` of the order `n` multiplicative subgroup of `F` is also loaded as a witness. + /// - By default, these witnesses are _private_ and returned in `self.preprocessed_ + /// - The user can optionally modify the circuit after calling this function to add more instances to `assigned_instances` to expose. + /// + /// # Warning + /// Will fail silently if `snarks` were created using a different multi-open scheme than `AS` /// where `AS` can be either [`crate::SHPLONK`] or [`crate::GWC`] (for original PLONK multi-open scheme) pub fn new( witness_gen_only: bool, lookup_bits: usize, params: &ParamsKZG, snarks: impl IntoIterator, + universality: VerifierUniversality, ) -> Self where AS: for<'a> Halo2KzgAccumulationScheme<'a>, @@ -289,8 +385,8 @@ impl AggregationCtxBuilder { let ecc_chip = BaseFieldEccChip::new(&fp_chip); let loader = Halo2Loader::new(ecc_chip, builder); - let (previous_instances, accumulator) = - aggregate::(&svk, &loader, &snarks, as_proof.as_slice()); + let SnarkAggregationWitness { previous_instances, accumulator, preprocessed_digests } = + aggregate::(&svk, &loader, &snarks, as_proof.as_slice(), universality); let lhs = accumulator.lhs.assigned(); let rhs = accumulator.rhs.assigned(); let accumulator = lhs @@ -314,7 +410,7 @@ impl AggregationCtxBuilder { } let builder = loader.take_ctx(); - Self { builder, accumulator, previous_instances } + Self { builder, accumulator, previous_instances, preprocessed_digests } } } @@ -322,6 +418,8 @@ impl AggregationCircuit { /// Given snarks, this creates a circuit and runs the `GateThreadBuilder` to verify all the snarks. /// By default, the returned circuit has public instances equal to the limbs of the pair of elliptic curve points, referred to as the `accumulator`, that need to be verified in a final pairing check. /// + /// See [`AggregationCtxBuilder`] for more details. + /// /// The user can optionally modify the circuit after calling this function to add more instances to `assigned_instances` to expose. /// /// Warning: will fail silently if `snarks` were created using a different multi-open scheme than `AS` @@ -332,17 +430,23 @@ impl AggregationCircuit { break_points: Option, params: &ParamsKZG, snarks: impl IntoIterator, + universality: VerifierUniversality, ) -> Self where AS: for<'a> Halo2KzgAccumulationScheme<'a>, { - let AggregationCtxBuilder { builder, accumulator, previous_instances } = - AggregationCtxBuilder::new::( - stage == CircuitBuilderStage::Prover, - agg_config.lookup_bits, - params, - snarks, - ); + let AggregationCtxBuilder { + builder, + accumulator, + previous_instances, + preprocessed_digests, + } = AggregationCtxBuilder::new::( + stage == CircuitBuilderStage::Prover, + agg_config.lookup_bits, + params, + snarks, + universality, + ); let inner = RangeWithInstanceCircuitBuilder::from_stage( stage, builder, @@ -350,7 +454,7 @@ impl AggregationCircuit { break_points, accumulator, ); - Self { inner, previous_instances } + Self { inner, previous_instances, preprocessed_digests } } pub fn public( @@ -364,7 +468,14 @@ impl AggregationCircuit { where AS: for<'a> Halo2KzgAccumulationScheme<'a>, { - let mut private = Self::new::(stage, agg_config, break_points, params, snarks); + let mut private = Self::new::( + stage, + agg_config, + break_points, + params, + snarks, + VerifierUniversality::None, + ); private.expose_previous_instances(has_prev_accumulator); private } @@ -380,11 +491,8 @@ impl AggregationCircuit { } /// Auto-configure the circuit and change the circuit's internal configuration parameters. - pub fn config(&mut self, k: u32, minimum_rows: Option) -> BaseConfigParams { - let mut new_config = self.inner.circuit.0.builder.borrow().config(k as usize, minimum_rows); - new_config.lookup_bits = self.inner.circuit.0.config_params.lookup_bits; - self.inner.circuit.0.config_params = new_config.clone(); - new_config + pub fn config(&mut self, minimum_rows: Option) -> AggregationConfigParams { + self.inner.config(minimum_rows).try_into().unwrap() } pub fn break_points(&self) -> MultiPhaseThreadBreakPoints { diff --git a/snark-verifier/Cargo.toml b/snark-verifier/Cargo.toml index 5606df37..ebbf283b 100644 --- a/snark-verifier/Cargo.toml +++ b/snark-verifier/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "snark-verifier" -version = "0.1.2" +version = "0.1.3" edition = "2021" [dependencies] diff --git a/snark-verifier/examples/recursion.rs b/snark-verifier/examples/recursion.rs index b469a51a..9c9b169a 100644 --- a/snark-verifier/examples/recursion.rs +++ b/snark-verifier/examples/recursion.rs @@ -373,7 +373,7 @@ mod recursion { ) -> (Vec>>, Vec>>>) { let protocol = if let Some(preprocessed_digest) = preprocessed_digest { let preprocessed_digest = loader.scalar_from_assigned(preprocessed_digest); - let protocol = snark.protocol.loaded_preprocessed_as_witness(loader); + let protocol = snark.protocol.loaded_preprocessed_as_witness(loader, false); let inputs = protocol .preprocessed .iter() diff --git a/snark-verifier/src/loader.rs b/snark-verifier/src/loader.rs index a3637f08..26450e1a 100644 --- a/snark-verifier/src/loader.rs +++ b/snark-verifier/src/loader.rs @@ -67,6 +67,12 @@ pub trait LoadedScalar: Clone + Debug + PartialEq + FieldOps { acc } + /// Returns power to exponent, where exponent is also [`LoadedScalar`]. + /// If `Loader` is for Halo2, then `exp` must have at most `exp_max_bits` bits (otherwise constraints will fail). + /// + /// Currently **unimplemented** for EvmLoader + fn pow_var(&self, exp: &Self, exp_max_bits: usize) -> Self; + /// Returns powers up to exponent `n-1`. fn powers(&self, n: usize) -> Vec { iter::once(self.loader().load_one()) diff --git a/snark-verifier/src/loader/evm/loader.rs b/snark-verifier/src/loader/evm/loader.rs index bfb37c8c..ba304f2b 100644 --- a/snark-verifier/src/loader/evm/loader.rs +++ b/snark-verifier/src/loader/evm/loader.rs @@ -632,6 +632,10 @@ impl> LoadedScalar for Scalar { fn loader(&self) -> &Self::Loader { &self.loader } + + fn pow_var(&self, _exp: &Self, _exp_max_bits: usize) -> Self { + todo!() + } } impl EcPointLoader for Rc diff --git a/snark-verifier/src/loader/halo2/loader.rs b/snark-verifier/src/loader/halo2/loader.rs index 105972c0..f8e1da7d 100644 --- a/snark-verifier/src/loader/halo2/loader.rs +++ b/snark-verifier/src/loader/halo2/loader.rs @@ -306,6 +306,14 @@ impl> LoadedScalar for Sc fn loader(&self) -> &Self::Loader { &self.loader } + + fn pow_var(&self, exp: &Self, max_bits: usize) -> Self { + let loader = self.loader(); + let base = self.clone().into_assigned(); + let exp = exp.clone().into_assigned(); + let res = loader.scalar_chip().pow_var(&mut loader.ctx_mut(), &base, &exp, max_bits); + loader.scalar_from_assigned(res) + } } impl> Debug for Scalar { diff --git a/snark-verifier/src/loader/halo2/shim.rs b/snark-verifier/src/loader/halo2/shim.rs index 9d010d2b..49bbad41 100644 --- a/snark-verifier/src/loader/halo2/shim.rs +++ b/snark-verifier/src/loader/halo2/shim.rs @@ -65,6 +65,15 @@ pub trait IntegerInstructions: Clone + Debug { lhs: &Self::AssignedInteger, rhs: &Self::AssignedInteger, ); + + /// Returns `base^exponent` and constrains that `exponent` has at most `max_bits` bits. + fn pow_var( + &self, + ctx: &mut Self::Context, + base: &Self::AssignedInteger, + exponent: &Self::AssignedInteger, + max_bits: usize, + ) -> Self::AssignedInteger; } /// Instructions to handle elliptic curve point operations. @@ -233,6 +242,16 @@ mod halo2_lib { ) { ctx.main(0).constrain_equal(a, b); } + + fn pow_var( + &self, + ctx: &mut Self::Context, + base: &Self::AssignedInteger, + exponent: &Self::AssignedInteger, + max_bits: usize, + ) -> Self::AssignedInteger { + GateInstructions::pow_var(self, ctx.main(0), *base, *exponent, max_bits) + } } impl<'chip, C: CurveAffineExt> EccInstructions for BaseFieldEccChip<'chip, C> diff --git a/snark-verifier/src/loader/native.rs b/snark-verifier/src/loader/native.rs index 783aaa89..a9aa86ff 100644 --- a/snark-verifier/src/loader/native.rs +++ b/snark-verifier/src/loader/native.rs @@ -2,7 +2,7 @@ use crate::{ loader::{EcPointLoader, LoadedEcPoint, LoadedScalar, Loader, ScalarLoader}, - util::arithmetic::{Curve, CurveAffine, FieldOps, PrimeField}, + util::arithmetic::{fe_to_big, Curve, CurveAffine, FieldOps, PrimeField}, Error, }; use lazy_static::lazy_static; @@ -38,6 +38,11 @@ impl LoadedScalar for F { fn loader(&self) -> &NativeLoader { &LOADER } + + fn pow_var(&self, exp: &Self, _: usize) -> Self { + let exp = fe_to_big(*exp).to_u64_digits(); + self.pow_vartime(exp) + } } impl EcPointLoader for NativeLoader { diff --git a/snark-verifier/src/pcs.rs b/snark-verifier/src/pcs.rs index 65b1325b..1ca9eedc 100644 --- a/snark-verifier/src/pcs.rs +++ b/snark-verifier/src/pcs.rs @@ -3,7 +3,7 @@ use crate::{ loader::{native::NativeLoader, Loader}, util::{ - arithmetic::{CurveAffine, PrimeField}, + arithmetic::{CurveAffine, Rotation}, msm::Msm, transcript::{TranscriptRead, TranscriptWrite}, }, @@ -18,24 +18,26 @@ pub mod kzg; /// Query to an oracle. /// It assumes all queries are based on the same point, but with some `shift`. #[derive(Clone, Debug)] -pub struct Query { +pub struct Query { /// Index of polynomial to query pub poly: usize, /// Shift of the query point. - pub shift: F, + pub shift: S, + /// Shift loaded as either constant or witness. It is user's job to ensure this is correctly constrained to have value equal to `shift` + pub loaded_shift: T, /// Evaluation read from transcript. pub eval: T, } -impl Query { +impl Query { /// Initialize [`Query`] without evaluation. - pub fn new(poly: usize, shift: F) -> Self { - Self { poly, shift, eval: () } + pub fn new(poly: usize, shift: S) -> Self { + Self { poly, shift, loaded_shift: (), eval: () } } - /// Returns [`Query`] with evaluation. - pub fn with_evaluation(self, eval: T) -> Query { - Query { poly: self.poly, shift: self.shift, eval } + /// Returns [`Query`] with evaluation and optionally the shift are loaded as. + pub fn with_evaluation(self, loaded_shift: T, eval: T) -> Query { + Query { poly: self.poly, shift: self.shift, loaded_shift, eval } } } @@ -55,7 +57,7 @@ where /// Read [`PolynomialCommitmentScheme::Proof`] from transcript. fn read_proof( vk: &Self::VerifyingKey, - queries: &[Query], + queries: &[Query], transcript: &mut T, ) -> Result where @@ -66,7 +68,7 @@ where vk: &Self::VerifyingKey, commitments: &[Msm], point: &L::LoadedScalar, - queries: &[Query], + queries: &[Query], proof: &Self::Proof, ) -> Result; } diff --git a/snark-verifier/src/pcs/ipa/multiopen/bgh19.rs b/snark-verifier/src/pcs/ipa/multiopen/bgh19.rs index 538aa4fb..e2cb87ab 100644 --- a/snark-verifier/src/pcs/ipa/multiopen/bgh19.rs +++ b/snark-verifier/src/pcs/ipa/multiopen/bgh19.rs @@ -5,7 +5,7 @@ use crate::{ PolynomialCommitmentScheme, Query, }, util::{ - arithmetic::{CurveAffine, Fraction, PrimeField}, + arithmetic::{CurveAffine, Fraction, PrimeField, Rotation}, msm::Msm, transcript::TranscriptRead, Itertools, @@ -35,7 +35,7 @@ where fn read_proof( svk: &Self::VerifyingKey, - queries: &[Query], + queries: &[Query], transcript: &mut T, ) -> Result where @@ -48,7 +48,7 @@ where svk: &Self::VerifyingKey, commitments: &[Msm], x: &L::LoadedScalar, - queries: &[Query], + queries: &[Query], proof: &Self::Proof, ) -> Result { let loader = x.loader(); @@ -119,7 +119,7 @@ where { fn read>( svk: &IpaSuccinctVerifyingKey, - queries: &[Query], + queries: &[Query], transcript: &mut T, ) -> Result { // Multiopen @@ -157,28 +157,33 @@ where } } -fn query_sets(queries: &[Query]) -> Vec> +fn query_sets(queries: &[Query]) -> Vec> where - F: PrimeField + Ord, + S: PartialEq + Ord + Copy, T: Clone, { let poly_shifts = - queries.iter().fold(Vec::<(usize, Vec, Vec<&T>)>::new(), |mut poly_shifts, query| { + queries.iter().fold(Vec::<(usize, Vec<_>, Vec<&T>)>::new(), |mut poly_shifts, query| { if let Some(pos) = poly_shifts.iter().position(|(poly, _, _)| *poly == query.poly) { let (_, shifts, evals) = &mut poly_shifts[pos]; - if !shifts.contains(&query.shift) { - shifts.push(query.shift); + if !shifts.iter().map(|(shift, _)| shift).contains(&query.shift) { + shifts.push((query.shift, query.loaded_shift.clone())); evals.push(&query.eval); } } else { - poly_shifts.push((query.poly, vec![query.shift], vec![&query.eval])); + poly_shifts.push(( + query.poly, + vec![(query.shift, query.loaded_shift.clone())], + vec![&query.eval], + )); } poly_shifts }); - poly_shifts.into_iter().fold(Vec::>::new(), |mut sets, (poly, shifts, evals)| { + poly_shifts.into_iter().fold(Vec::>::new(), |mut sets, (poly, shifts, evals)| { if let Some(pos) = sets.iter().position(|set| { - BTreeSet::from_iter(set.shifts.iter()) == BTreeSet::from_iter(shifts.iter()) + BTreeSet::from_iter(set.shifts.iter().map(|(shift, _)| shift)) + == BTreeSet::from_iter(shifts.iter().map(|(shift, _)| shift)) }) { let set = &mut sets[pos]; if !set.polys.contains(&poly) { @@ -187,7 +192,7 @@ where set.shifts .iter() .map(|lhs| { - let idx = shifts.iter().position(|rhs| lhs == rhs).unwrap(); + let idx = shifts.iter().position(|rhs| lhs.0 == rhs.0).unwrap(); evals[idx] }) .collect(), @@ -201,18 +206,23 @@ where }) } -fn query_set_coeffs(sets: &[QuerySet], x: &T, x_3: &T) -> Vec> +fn query_set_coeffs( + sets: &[QuerySet], + x: &T, + x_3: &T, +) -> Vec> where F: PrimeField + Ord, T: LoadedScalar, { - let loader = x.loader(); - let superset = sets.iter().flat_map(|set| set.shifts.clone()).sorted().dedup(); + let superset = BTreeMap::from_iter(sets.iter().flat_map(|set| set.shifts.clone())); let size = sets.iter().map(|set| set.shifts.len()).chain(Some(2)).max().unwrap(); let powers_of_x = x.powers(size); let x_3_minus_x_shift_i = BTreeMap::from_iter( - superset.map(|shift| (shift, x_3.clone() - x.clone() * loader.load_const(&shift))), + superset + .into_iter() + .map(|(shift, loaded_shift)| (shift, x_3.clone() - x.clone() * loaded_shift)), ); let mut coeffs = sets @@ -228,23 +238,22 @@ where } #[derive(Clone, Debug)] -struct QuerySet<'a, F, T> { - shifts: Vec, +struct QuerySet<'a, S, T> { + shifts: Vec<(S, T)>, polys: Vec, evals: Vec>, } -impl<'a, F, T> QuerySet<'a, F, T> -where - F: PrimeField, - T: LoadedScalar, -{ +impl<'a, S, T> QuerySet<'a, S, T> { fn msm>( &self, commitments: &[Msm<'a, C, L>], q_eval: &T, powers_of_x_1: &[T], - ) -> Msm { + ) -> Msm + where + T: LoadedScalar, + { self.polys .iter() .rev() @@ -254,7 +263,15 @@ where - Msm::constant(q_eval.clone()) } - fn f_eval(&self, coeff: &QuerySetCoeff, q_eval: &T, powers_of_x_1: &[T]) -> T { + fn f_eval( + &self, + coeff: &QuerySetCoeff, + q_eval: &T, + powers_of_x_1: &[T], + ) -> T + where + T: LoadedScalar, + { let loader = q_eval.loader(); let r_eval = { let r_evals = self @@ -291,7 +308,12 @@ where F: PrimeField + Ord, T: LoadedScalar, { - fn new(shifts: &[F], powers_of_x: &[T], x_3: &T, x_3_minus_x_shift_i: &BTreeMap) -> Self { + fn new( + shifts: &[(Rotation, T)], + powers_of_x: &[T], + x_3: &T, + x_3_minus_x_shift_i: &BTreeMap, + ) -> Self { let loader = x_3.loader(); let normalized_ell_primes = shifts .iter() @@ -301,9 +323,9 @@ where .iter() .enumerate() .filter(|&(i, _)| i != j) - .map(|(_, shift_i)| (*shift_j - shift_i)) + .map(|(_, shift_i)| (shift_j.1.clone() - &shift_i.1)) .reduce(|acc, value| acc * value) - .unwrap_or(F::ONE) + .unwrap_or_else(|| loader.load_const(&F::ONE)) }) .collect_vec(); @@ -313,17 +335,15 @@ where let barycentric_weights = shifts .iter() .zip(normalized_ell_primes.iter()) - .map(|(shift, normalized_ell_prime)| { - loader.sum_products_with_coeff(&[ - (*normalized_ell_prime, x_pow_k_minus_one, x_3), - (-(*normalized_ell_prime * shift), x_pow_k_minus_one, x), - ]) + .map(|((_, loaded_shift), normalized_ell_prime)| { + let tmp = normalized_ell_prime.clone() * x_pow_k_minus_one; + loader.sum_products(&[(&tmp, x_3), (&-(tmp.clone() * loaded_shift), x)]) }) .map(Fraction::one_over) .collect_vec(); let f_eval_coeff = Fraction::one_over(loader.product( - &shifts.iter().map(|shift| x_3_minus_x_shift_i.get(shift).unwrap()).collect_vec(), + &shifts.iter().map(|(shift, _)| x_3_minus_x_shift_i.get(shift).unwrap()).collect_vec(), )); Self { diff --git a/snark-verifier/src/pcs/kzg/multiopen/bdfg21.rs b/snark-verifier/src/pcs/kzg/multiopen/bdfg21.rs index d1398ebb..e13e09cc 100644 --- a/snark-verifier/src/pcs/kzg/multiopen/bdfg21.rs +++ b/snark-verifier/src/pcs/kzg/multiopen/bdfg21.rs @@ -6,7 +6,7 @@ use crate::{ PolynomialCommitmentScheme, Query, }, util::{ - arithmetic::{CurveAffine, Fraction, MultiMillerLoop, PrimeField}, + arithmetic::{CurveAffine, Fraction, MultiMillerLoop, PrimeField, Rotation}, msm::Msm, transcript::TranscriptRead, Itertools, @@ -37,7 +37,7 @@ where fn read_proof( _: &KzgSuccinctVerifyingKey, - _: &[Query], + _: &[Query], transcript: &mut T, ) -> Result, Error> where @@ -50,7 +50,7 @@ where svk: &KzgSuccinctVerifyingKey, commitments: &[Msm], z: &L::LoadedScalar, - queries: &[Query], + queries: &[Query], proof: &Bdfg21Proof, ) -> Result { let sets = query_sets(queries); @@ -106,24 +106,29 @@ where } } -fn query_sets(queries: &[Query]) -> Vec> { +fn query_sets(queries: &[Query]) -> Vec> { let poly_shifts = - queries.iter().fold(Vec::<(usize, Vec, Vec<&T>)>::new(), |mut poly_shifts, query| { + queries.iter().fold(Vec::<(usize, Vec<_>, Vec<&T>)>::new(), |mut poly_shifts, query| { if let Some(pos) = poly_shifts.iter().position(|(poly, _, _)| *poly == query.poly) { let (_, shifts, evals) = &mut poly_shifts[pos]; - if !shifts.contains(&query.shift) { - shifts.push(query.shift); + if !shifts.iter().map(|(shift, _)| shift).contains(&query.shift) { + shifts.push((query.shift, query.loaded_shift.clone())); evals.push(&query.eval); } } else { - poly_shifts.push((query.poly, vec![query.shift], vec![&query.eval])); + poly_shifts.push(( + query.poly, + vec![(query.shift, query.loaded_shift.clone())], + vec![&query.eval], + )); } poly_shifts }); - poly_shifts.into_iter().fold(Vec::>::new(), |mut sets, (poly, shifts, evals)| { + poly_shifts.into_iter().fold(Vec::>::new(), |mut sets, (poly, shifts, evals)| { if let Some(pos) = sets.iter().position(|set| { - BTreeSet::from_iter(set.shifts.iter()) == BTreeSet::from_iter(shifts.iter()) + BTreeSet::from_iter(set.shifts.iter().map(|(shift, _)| shift)) + == BTreeSet::from_iter(shifts.iter().map(|(shift, _)| shift)) }) { let set = &mut sets[pos]; if !set.polys.contains(&poly) { @@ -132,7 +137,7 @@ fn query_sets(queries: &[Query]) -> Vec(queries: &[Query]) -> Vec>( - sets: &[QuerySet], + sets: &[QuerySet], z: &T, z_prime: &T, ) -> Vec> { - let loader = z.loader(); - - let superset = sets.iter().flat_map(|set| set.shifts.clone()).sorted().dedup(); + // map of shift => loaded_shift, removing duplicate `shift` values + // shift is the rotation, not omega^rotation, to ensure BTreeMap does not depend on omega (otherwise ordering can change) + let superset = BTreeMap::from_iter(sets.iter().flat_map(|set| set.shifts.clone())); let size = sets.iter().map(|set| set.shifts.len()).chain(Some(2)).max().unwrap(); let powers_of_z = z.powers(size); let z_prime_minus_z_shift_i = BTreeMap::from_iter( - superset.map(|shift| (shift, z_prime.clone() - z.clone() * loader.load_const(&shift))), + superset + .into_iter() + .map(|(shift, loaded_shift)| (shift, z_prime.clone() - z.clone() * loaded_shift)), ); let mut z_s_1 = None; @@ -187,19 +194,22 @@ fn query_set_coeffs>( } #[derive(Clone, Debug)] -struct QuerySet<'a, F, T> { - shifts: Vec, +struct QuerySet<'a, S, T> { + shifts: Vec<(S, T)>, // vec of (shift, loaded_shift) polys: Vec, evals: Vec>, } -impl<'a, F: PrimeField, T: LoadedScalar> QuerySet<'a, F, T> { +impl<'a, S, T> QuerySet<'a, S, T> { fn msm>( &self, - coeff: &QuerySetCoeff, + coeff: &QuerySetCoeff, commitments: &[Msm<'a, C, L>], powers_of_mu: &[T], - ) -> Msm { + ) -> Msm + where + T: LoadedScalar, + { self.polys .iter() .zip(self.evals.iter()) @@ -242,10 +252,10 @@ where T: LoadedScalar, { fn new( - shifts: &[F], + shifts: &[(Rotation, T)], powers_of_z: &[T], z_prime: &T, - z_prime_minus_z_shift_i: &BTreeMap, + z_prime_minus_z_shift_i: &BTreeMap, z_s_1: &Option, ) -> Self { let loader = z_prime.loader(); @@ -258,9 +268,9 @@ where .iter() .enumerate() .filter(|&(i, _)| i != j) - .map(|(_, shift_i)| (*shift_j - shift_i)) + .map(|(_, shift_i)| (shift_j.1.clone() - &shift_i.1)) .reduce(|acc, value| acc * value) - .unwrap_or(F::ONE) + .unwrap_or_else(|| loader.load_const(&F::ONE)) }) .collect_vec(); @@ -270,17 +280,18 @@ where let barycentric_weights = shifts .iter() .zip(normalized_ell_primes.iter()) - .map(|(shift, normalized_ell_prime)| { - loader.sum_products_with_coeff(&[ - (*normalized_ell_prime, z_pow_k_minus_one, z_prime), - (-(*normalized_ell_prime * shift), z_pow_k_minus_one, z), - ]) + .map(|((_, loaded_shift), normalized_ell_prime)| { + let tmp = normalized_ell_prime.clone() * z_pow_k_minus_one; + loader.sum_products(&[(&tmp, z_prime), (&-(tmp.clone() * loaded_shift), z)]) }) .map(Fraction::one_over) .collect_vec(); let z_s = loader.product( - &shifts.iter().map(|shift| z_prime_minus_z_shift_i.get(shift).unwrap()).collect_vec(), + &shifts + .iter() + .map(|(shift, _)| z_prime_minus_z_shift_i.get(shift).unwrap()) + .collect_vec(), ); let z_s_1_over_z_s = z_s_1.clone().map(|z_s_1| Fraction::new(z_s_1, z_s.clone())); @@ -330,9 +341,9 @@ impl CostEstimation for KzgAs where M: MultiMillerLoop, { - type Input = Vec>; + type Input = Vec>; - fn estimate_cost(_: &Vec>) -> Cost { + fn estimate_cost(_: &Vec>) -> Cost { Cost { num_commitment: 2, num_msm: 2, ..Default::default() } } } diff --git a/snark-verifier/src/pcs/kzg/multiopen/gwc19.rs b/snark-verifier/src/pcs/kzg/multiopen/gwc19.rs index da5d51a6..e8114d09 100644 --- a/snark-verifier/src/pcs/kzg/multiopen/gwc19.rs +++ b/snark-verifier/src/pcs/kzg/multiopen/gwc19.rs @@ -6,7 +6,7 @@ use crate::{ PolynomialCommitmentScheme, Query, }, util::{ - arithmetic::{CurveAffine, MultiMillerLoop, PrimeField}, + arithmetic::{CurveAffine, MultiMillerLoop, Rotation}, msm::Msm, transcript::TranscriptRead, Itertools, @@ -32,7 +32,7 @@ where fn read_proof( _: &Self::VerifyingKey, - queries: &[Query], + queries: &[Query], transcript: &mut T, ) -> Result where @@ -45,7 +45,7 @@ where svk: &Self::VerifyingKey, commitments: &[Msm], z: &L::LoadedScalar, - queries: &[Query], + queries: &[Query], proof: &Self::Proof, ) -> Result { let sets = query_sets(queries); @@ -58,7 +58,10 @@ where .map(|(msm, power_of_u)| msm * power_of_u) .sum::>() }; - let z_omegas = sets.iter().map(|set| z.loader().load_const(&set.shift) * z); + let z_omegas = sets.iter().map(|set| { + let loaded_shift = set.loaded_shift.clone(); + loaded_shift * z + }); let rhs = proof .ws @@ -92,7 +95,7 @@ where C: CurveAffine, L: Loader, { - fn read(queries: &[Query], transcript: &mut T) -> Result + fn read(queries: &[Query], transcript: &mut T) -> Result where T: TranscriptRead, { @@ -103,22 +106,25 @@ where } } -struct QuerySet<'a, F, T> { - shift: F, +struct QuerySet<'a, S, T> { + shift: S, + loaded_shift: T, polys: Vec, evals: Vec<&'a T>, } -impl<'a, F, T> QuerySet<'a, F, T> +impl<'a, S, T> QuerySet<'a, S, T> where - F: PrimeField, T: Clone, { fn msm>( &self, commitments: &[Msm<'a, C, L>], powers_of_v: &[L::LoadedScalar], - ) -> Msm { + ) -> Msm + where + T: LoadedScalar, + { self.polys .iter() .zip(self.evals.iter().cloned()) @@ -132,9 +138,9 @@ where } } -fn query_sets(queries: &[Query]) -> Vec> +fn query_sets(queries: &[Query]) -> Vec> where - F: PrimeField, + S: PartialEq + Copy, T: Clone + PartialEq, { queries.iter().fold(Vec::new(), |mut sets, query| { @@ -144,6 +150,7 @@ where } else { sets.push(QuerySet { shift: query.shift, + loaded_shift: query.loaded_shift.clone(), polys: vec![query.poly], evals: vec![&query.eval], }); @@ -156,9 +163,9 @@ impl CostEstimation for KzgAs where M: MultiMillerLoop, { - type Input = Vec>; + type Input = Vec>; - fn estimate_cost(queries: &Vec>) -> Cost { + fn estimate_cost(queries: &Vec>) -> Cost { let num_w = query_sets(queries).len(); Cost { num_commitment: num_w, num_msm: num_w, ..Default::default() } } diff --git a/snark-verifier/src/system/halo2.rs b/snark-verifier/src/system/halo2.rs index 2dc5751d..74824361 100644 --- a/snark-verifier/src/system/halo2.rs +++ b/snark-verifier/src/system/halo2.rs @@ -141,6 +141,7 @@ pub fn compile<'a, C: CurveAffine, P: Params<'a, C>>( PlonkProtocol { domain, + domain_as_witness: None, preprocessed, num_instance: polynomials.num_instance(), num_witness: polynomials.num_witness(), diff --git a/snark-verifier/src/verifier/plonk.rs b/snark-verifier/src/verifier/plonk.rs index d5937ab8..9c910400 100644 --- a/snark-verifier/src/verifier/plonk.rs +++ b/snark-verifier/src/verifier/plonk.rs @@ -15,7 +15,10 @@ use crate::{ AccumulationDecider, AccumulationScheme, AccumulatorEncoding, PolynomialCommitmentScheme, Query, }, - util::{arithmetic::CurveAffine, transcript::TranscriptRead}, + util::{ + arithmetic::{CurveAffine, Rotation}, + transcript::TranscriptRead, + }, verifier::{plonk::protocol::CommonPolynomialEvaluation, SnarkVerifier}, Error, }; @@ -62,8 +65,12 @@ where proof: &Self::Proof, ) -> Result { let common_poly_eval = { - let mut common_poly_eval = - CommonPolynomialEvaluation::new(&protocol.domain, protocol.langranges(), &proof.z); + let mut common_poly_eval = CommonPolynomialEvaluation::new( + &protocol.domain, + protocol.langranges(), + &proof.z, + &protocol.domain_as_witness, + ); L::batch_invert(common_poly_eval.denoms()); common_poly_eval.evaluate(); @@ -140,7 +147,7 @@ where L: Loader, AS: AccumulationScheme + PolynomialCommitmentScheme - + CostEstimation>>, + + CostEstimation>>, { type Input = PlonkProtocol; @@ -168,7 +175,7 @@ where L: Loader, AS: AccumulationScheme + PolynomialCommitmentScheme - + CostEstimation>>, + + CostEstimation>>, { type Input = PlonkProtocol; diff --git a/snark-verifier/src/verifier/plonk/proof.rs b/snark-verifier/src/verifier/plonk/proof.rs index a42a56ae..459027ba 100644 --- a/snark-verifier/src/verifier/plonk/proof.rs +++ b/snark-verifier/src/verifier/plonk/proof.rs @@ -13,7 +13,10 @@ use crate::{ }, Error, }; -use std::{collections::HashMap, iter}; +use std::{ + collections::{BTreeMap, HashMap}, + iter, +}; /// Proof of PLONK with [`PolynomialCommitmentScheme`] that has /// [`AccumulationScheme`]. @@ -153,26 +156,42 @@ where } /// Empty queries - pub fn empty_queries(protocol: &PlonkProtocol) -> Vec> { - protocol - .queries - .iter() - .map(|query| { - let shift = protocol.domain.rotate_scalar(C::Scalar::ONE, query.rotation); - pcs::Query::new(query.poly, shift) - }) - .collect() + pub fn empty_queries(protocol: &PlonkProtocol) -> Vec> { + // `preprocessed` should always be non-empty, unless the circuit has no constraints or constants + protocol.queries.iter().map(|query| pcs::Query::new(query.poly, query.rotation)).collect() } pub(super) fn queries( &self, protocol: &PlonkProtocol, mut evaluations: HashMap, - ) -> Vec> { + ) -> Vec> { + if protocol.queries.is_empty() { + return vec![]; + } + let loader = evaluations[&protocol.queries[0]].loader(); + let rotations = + protocol.queries.iter().map(|query| query.rotation).sorted().dedup().collect_vec(); + let loaded_shifts = if let Some(domain) = protocol.domain_as_witness.as_ref() { + // the `rotation`s are still constants, it is only generator `omega` that might be witness + BTreeMap::from_iter( + rotations.into_iter().map(|rotation| (rotation, domain.rotate_one(rotation))), + ) + } else { + BTreeMap::from_iter(rotations.into_iter().map(|rotation| { + ( + rotation, + loader.load_const(&protocol.domain.rotate_scalar(C::Scalar::ONE, rotation)), + ) + })) + }; Self::empty_queries(protocol) .into_iter() .zip(protocol.queries.iter().map(|query| evaluations.remove(query).unwrap())) - .map(|(query, eval)| query.with_evaluation(eval)) + .map(|(query, eval)| { + let shift = loaded_shifts[&query.shift].clone(); + query.with_evaluation(shift, eval) + }) .collect() } diff --git a/snark-verifier/src/verifier/plonk/protocol.rs b/snark-verifier/src/verifier/plonk/protocol.rs index 97f6f336..098e7de9 100644 --- a/snark-verifier/src/verifier/plonk/protocol.rs +++ b/snark-verifier/src/verifier/plonk/protocol.rs @@ -1,7 +1,7 @@ use crate::{ loader::{native::NativeLoader, LoadedScalar, Loader}, util::{ - arithmetic::{CurveAffine, Domain, Field, Fraction, Rotation}, + arithmetic::{CurveAffine, Domain, Field, Fraction, PrimeField, Rotation}, Itertools, }, }; @@ -9,13 +9,44 @@ use num_integer::Integer; use num_traits::One; use serde::{Deserialize, Serialize}; use std::{ - cmp::max, + cmp::{max, Ordering}, collections::{BTreeMap, BTreeSet}, fmt::Debug, iter::{self, Sum}, ops::{Add, Mul, Neg, Sub}, }; +/// Domain parameters to be optionally loaded as witnesses +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct DomainAsWitness +where + C: CurveAffine, + L: Loader, +{ + /// Number of rows in the domain + pub n: L::LoadedScalar, + /// Generator of the domain + pub gen: L::LoadedScalar, + /// Inverse generator of the domain + pub gen_inv: L::LoadedScalar, +} + +impl DomainAsWitness +where + C: CurveAffine, + L: Loader, +{ + /// Rotate `F::one()` to given `rotation`. + pub fn rotate_one(&self, rotation: Rotation) -> L::LoadedScalar { + let loader = self.gen.loader(); + match rotation.0.cmp(&0) { + Ordering::Equal => loader.load_one(), + Ordering::Greater => self.gen.pow_const(rotation.0 as u64), + Ordering::Less => self.gen_inv.pow_const(-rotation.0 as u64), + } + } +} + /// Protocol specifying configuration of a PLONK. #[derive(Clone, Debug, Serialize, Deserialize)] pub struct PlonkProtocol @@ -29,6 +60,15 @@ where ))] /// Working domain. pub domain: Domain, + + #[serde(bound( + serialize = "L::LoadedScalar: Serialize", + deserialize = "L::LoadedScalar: Deserialize<'de>" + ))] + #[serde(skip_serializing_if = "Option::is_none")] + /// Optional: load `domain.n` and `domain.gen` as a witness + pub domain_as_witness: Option>, + #[serde(bound( serialize = "L::LoadedEcPoint: Serialize", deserialize = "L::LoadedEcPoint: Deserialize<'de>" @@ -115,6 +155,7 @@ where .map(|transcript_initial_state| loader.load_const(transcript_initial_state)); PlonkProtocol { domain: self.domain.clone(), + domain_as_witness: None, preprocessed, num_instance: self.num_instance.clone(), num_witness: self.num_witness.clone(), @@ -133,12 +174,17 @@ where #[cfg(feature = "loader_halo2")] mod halo2 { use crate::{ - loader::halo2::{EccInstructions, Halo2Loader}, + loader::{ + halo2::{EccInstructions, Halo2Loader}, + LoadedScalar, + }, util::arithmetic::CurveAffine, verifier::plonk::PlonkProtocol, }; use std::rc::Rc; + use super::DomainAsWitness; + impl PlonkProtocol where C: CurveAffine, @@ -149,7 +195,15 @@ mod halo2 { pub fn loaded_preprocessed_as_witness>( &self, loader: &Rc>, + load_n_as_witness: bool, ) -> PlonkProtocol>> { + let domain_as_witness = load_n_as_witness.then(|| { + let n = loader.assign_scalar(C::Scalar::from(self.domain.n as u64)); + let gen = loader.assign_scalar(self.domain.gen); + let gen_inv = gen.invert().expect("subgroup generation is invertible"); + DomainAsWitness { n, gen, gen_inv } + }); + let preprocessed = self .preprocessed .iter() @@ -161,6 +215,7 @@ mod halo2 { .map(|transcript_initial_state| loader.assign_scalar(*transcript_initial_state)); PlonkProtocol { domain: self.domain.clone(), + domain_as_witness, preprocessed, num_instance: self.num_instance.clone(), num_witness: self.num_witness.clone(), @@ -201,26 +256,39 @@ where C: CurveAffine, L: Loader, { + // if `n_as_witness` is Some, then we assume `n_as_witness` has value equal to `domain.n` (i.e., number of rows in the circuit) + // and is loaded as a witness instead of a constant. + // The generator of `domain` also depends on `n`. pub fn new( domain: &Domain, - langranges: impl IntoIterator, + lagranges: impl IntoIterator, z: &L::LoadedScalar, + domain_as_witness: &Option>, ) -> Self { let loader = z.loader(); - let zn = z.pow_const(domain.n as u64); - let langranges = langranges.into_iter().sorted().dedup().collect_vec(); - + let lagranges = lagranges.into_iter().sorted().dedup().collect_vec(); let one = loader.load_one(); + + let (zn, n_inv, omegas) = if let Some(domain) = domain_as_witness.as_ref() { + let zn = z.pow_var(&domain.n, C::Scalar::S as usize + 1); + let n_inv = domain.n.invert().expect("n is not zero"); + let omegas = lagranges.iter().map(|&i| domain.rotate_one(Rotation(i))).collect_vec(); + (zn, n_inv, omegas) + } else { + let zn = z.pow_const(domain.n as u64); + let n_inv = loader.load_const(&domain.n_inv); + let omegas = lagranges + .iter() + .map(|&i| loader.load_const(&domain.rotate_scalar(C::Scalar::ONE, Rotation(i)))) + .collect_vec(); + (zn, n_inv, omegas) + }; + let zn_minus_one = zn.clone() - &one; let zn_minus_one_inv = Fraction::one_over(zn_minus_one.clone()); - let n_inv = loader.load_const(&domain.n_inv); let numer = zn_minus_one.clone() * &n_inv; - let omegas = langranges - .iter() - .map(|&i| loader.load_const(&domain.rotate_scalar(C::Scalar::ONE, Rotation(i)))) - .collect_vec(); let lagrange_evals = omegas .iter() .map(|omega| Fraction::new(numer.clone() * omega, z.clone() - omega)) @@ -231,7 +299,7 @@ where zn_minus_one, zn_minus_one_inv, identity: z.clone(), - lagrange: langranges.into_iter().zip(lagrange_evals).collect(), + lagrange: lagranges.into_iter().zip(lagrange_evals).collect(), } }