diff --git a/Cargo.toml b/Cargo.toml index 233bf95ad0..1e75bcb87e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -5,6 +5,7 @@ members = [ "halo2_frontend", "halo2_middleware", "halo2_backend", + "halo2_debug", "p3_frontend", ] resolver = "2" diff --git a/halo2_backend/src/plonk/circuit.rs b/halo2_backend/src/plonk/circuit.rs index c36bd7b9b4..716284b65a 100644 --- a/halo2_backend/src/plonk/circuit.rs +++ b/halo2_backend/src/plonk/circuit.rs @@ -25,6 +25,12 @@ pub enum VarBack { Challenge(ChallengeMid), } +impl std::fmt::Display for VarBack { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{:?}", self) + } +} + impl Variable for VarBack { fn degree(&self) -> usize { match self { @@ -40,8 +46,8 @@ impl Variable for VarBack { } } - fn write_identifier(&self, _writer: &mut W) -> std::io::Result<()> { - unimplemented!("unused method") + fn write_identifier(&self, writer: &mut W) -> std::io::Result<()> { + write!(writer, "{}", self) } } diff --git a/halo2_debug/Cargo.toml b/halo2_debug/Cargo.toml new file mode 100644 index 0000000000..d86ea3755b --- /dev/null +++ b/halo2_debug/Cargo.toml @@ -0,0 +1,27 @@ +[package] +name = "halo2_debug" +version = "0.3.0" +authors = [ + "Privacy Scaling Explorations team", +] +edition = "2021" +rust-version = "1.66.0" +description = """ +Halo2 Debug. This package contains utilities for debugging and testing within +the halo2 ecosystem. +""" +license = "MIT OR Apache-2.0" +repository = "https://github.com/privacy-scaling-explorations/halo2" +documentation = "https://privacy-scaling-explorations.github.io/halo2/" +categories = ["cryptography"] +keywords = ["halo", "proofs", "zkp", "zkSNARKs"] + +[package.metadata.docs.rs] +all-features = true +rustdoc-args = ["--cfg", "docsrs", "--html-in-header", "katex-header.html"] + +[dependencies] +ff = "0.13" +halo2curves = { version = "0.6.1", default-features = false } +num-bigint = "0.4.5" +halo2_middleware = { path = "../halo2_middleware" } diff --git a/halo2_debug/src/display.rs b/halo2_debug/src/display.rs new file mode 100644 index 0000000000..f4feba48f6 --- /dev/null +++ b/halo2_debug/src/display.rs @@ -0,0 +1,351 @@ +use ff::PrimeField; +use halo2_middleware::circuit::{ColumnMid, VarMid}; +use halo2_middleware::expression::{Expression, Variable}; +use halo2_middleware::{lookup, shuffle}; +use num_bigint::BigUint; +use std::collections::HashMap; +use std::fmt; + +/// Wrapper type over `PrimeField` that implements Display with nice output. +/// - If the value is a power of two, format it as `2^k` +/// - If the value is smaller than 2^16, format it in decimal +/// - If the value is bigger than congruent -2^16, format it in decimal as the negative congruent +/// (between -2^16 and 0). +/// - Else format it in hex without leading zeros. +pub struct FDisp<'a, F: PrimeField>(pub &'a F); + +impl fmt::Display for FDisp<'_, F> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let v = (*self.0).to_repr(); + let v = v.as_ref(); + let v = BigUint::from_bytes_le(v); + let v_bits = v.bits(); + if v_bits >= 8 && v.count_ones() == 1 { + write!(f, "2^{}", v.trailing_zeros().unwrap_or_default()) + } else if v_bits < 16 { + write!(f, "{}", v) + } else { + let v_neg = (F::ZERO - self.0).to_repr(); + let v_neg = v_neg.as_ref(); + let v_neg = BigUint::from_bytes_le(v_neg); + let v_neg_bits = v_neg.bits(); + if v_neg_bits < 16 { + write!(f, "-{}", v_neg) + } else { + write!(f, "0x{:x}", v) + } + } + } +} + +/// Wrapper type over `Expression` that implements Display with nice output. +/// The formatting of the `Expression::Variable` case is parametrized with the second field, which +/// take as auxiliary value the third field. +/// Use the constructor `expr_disp` to format variables using their `Display` implementation. +/// Use the constructor `expr_disp_names` for an `Expression` with `V: VarMid` to format column +/// queries according to their string names. +pub struct ExprDisp<'a, F: PrimeField, V: Variable, A>( + /// Expression to display + pub &'a Expression, + /// `V: Variable` formatter method that uses auxiliary value + pub fn(&V, &mut fmt::Formatter<'_>, a: &A) -> fmt::Result, + /// Auxiliary value to be passed to the `V: Variable` formatter + pub &'a A, +); + +fn var_fmt_default(v: &V, f: &mut fmt::Formatter<'_>, _: &()) -> fmt::Result { + write!(f, "{}", v) +} + +fn var_fmt_names( + v: &VarMid, + f: &mut fmt::Formatter<'_>, + names: &HashMap, +) -> fmt::Result { + if let VarMid::Query(q) = v { + if let Some(name) = names.get(&ColumnMid::new(q.column_type, q.column_index)) { + return write!(f, "{}", name); + } + } + write!(f, "{}", v) +} + +/// ExprDisp constructor that formats viariables using their `Display` implementation. +pub fn expr_disp(e: &Expression) -> ExprDisp { + ExprDisp(e, var_fmt_default, &()) +} + +/// ExprDisp constructor for an `Expression` with `V: VarMid` that formats column queries according +/// to their string names. +pub fn expr_disp_names<'a, F: PrimeField>( + e: &'a Expression, + names: &'a HashMap, +) -> ExprDisp<'a, F, VarMid, HashMap> { + ExprDisp(e, var_fmt_names, names) +} + +impl fmt::Display for ExprDisp<'_, F, V, A> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let is_sum = |e: &Expression| -> bool { matches!(e, Expression::Sum(_, _)) }; + let fmt_expr = + |e: &Expression, f: &mut fmt::Formatter<'_>, parens: bool| -> fmt::Result { + if parens { + write!(f, "(")?; + } + write!(f, "{}", ExprDisp(e, self.1, self.2))?; + if parens { + write!(f, ")")?; + } + Ok(()) + }; + + match self.0 { + Expression::Constant(c) => write!(f, "{}", FDisp(c)), + Expression::Var(v) => self.1(v, f, self.2), + Expression::Negated(a) => { + write!(f, "-")?; + fmt_expr(a, f, is_sum(a)) + } + Expression::Sum(a, b) => { + fmt_expr(a, f, false)?; + if let Expression::Negated(neg) = &**b { + write!(f, " - ")?; + fmt_expr(neg, f, is_sum(neg)) + } else { + write!(f, " + ")?; + fmt_expr(b, f, false) + } + } + Expression::Product(a, b) => { + fmt_expr(a, f, is_sum(a))?; + write!(f, " * ")?; + fmt_expr(b, f, is_sum(b)) + } + } + } +} + +/// Wrapper type over `lookup::Argument` that implements Display with nice output. +/// The formatting of the `Expression::Variable` case is parametrized with the second field, which +/// take as auxiliary value the third field. +/// Use the constructor `lookup_arg_disp` to format variables using their `Display` implementation. +/// Use the constructor `lookup_arg_disp_names` for a lookup of `Expression` with `V: VarMid` that +/// formats column queries according to their string names. +pub struct LookupArgDisp<'a, F: PrimeField, V: Variable, A>( + /// Lookup argument to display + pub &'a lookup::Argument, + /// `V: Variable` formatter method that uses auxiliary value + pub fn(&V, &mut fmt::Formatter<'_>, a: &A) -> fmt::Result, + /// Auxiliary value to be passed to the `V: Variable` formatter + pub &'a A, +); + +/// LookupArgDisp constructor that formats viariables using their `Display` implementation. +pub fn lookup_arg_disp( + a: &lookup::Argument, +) -> LookupArgDisp { + LookupArgDisp(a, var_fmt_default, &()) +} + +/// LookupArgDisp constructor for a lookup of `Expression` with `V: VarMid` that formats column +/// queries according to their string names. +pub fn lookup_arg_disp_names<'a, F: PrimeField>( + a: &'a lookup::Argument, + names: &'a HashMap, +) -> LookupArgDisp<'a, F, VarMid, HashMap> { + LookupArgDisp(a, var_fmt_names, names) +} + +impl fmt::Display for LookupArgDisp<'_, F, V, A> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "[")?; + for (i, expr) in self.0.input_expressions.iter().enumerate() { + if i != 0 { + write!(f, ", ")?; + } + write!(f, "{}", ExprDisp(expr, self.1, self.2))?; + } + write!(f, "] in [")?; + for (i, expr) in self.0.table_expressions.iter().enumerate() { + if i != 0 { + write!(f, ", ")?; + } + write!(f, "{}", ExprDisp(expr, self.1, self.2))?; + } + write!(f, "]")?; + Ok(()) + } +} + +/// Wrapper type over `shuffle::Argument` that implements Display with nice output. +/// The formatting of the `Expression::Variable` case is parametrized with the second field, which +/// take as auxiliary value the third field. +/// Use the constructor `shuffle_arg_disp` to format variables using their `Display` +/// implementation. +/// Use the constructor `shuffle_arg_disp_names` for a shuffle of `Expression` with `V: VarMid` +/// that formats column queries according to their string names. +pub struct ShuffleArgDisp<'a, F: PrimeField, V: Variable, A>( + /// Shuffle argument to display + pub &'a shuffle::Argument, + /// `V: Variable` formatter method that uses auxiliary value + pub fn(&V, &mut fmt::Formatter<'_>, a: &A) -> fmt::Result, + /// Auxiliary value to be passed to the `V: Variable` formatter + pub &'a A, +); + +/// ShuffleArgDisp constructor that formats viariables using their `Display` implementation. +pub fn shuffle_arg_disp( + a: &shuffle::Argument, +) -> ShuffleArgDisp { + ShuffleArgDisp(a, var_fmt_default, &()) +} + +/// ShuffleArgDisp constructor for a shuffle of `Expression` with `V: VarMid` that formats column +/// queries according to their string names. +pub fn shuffle_arg_disp_names<'a, F: PrimeField>( + a: &'a shuffle::Argument, + names: &'a HashMap, +) -> ShuffleArgDisp<'a, F, VarMid, HashMap> { + ShuffleArgDisp(a, var_fmt_names, names) +} + +impl fmt::Display for ShuffleArgDisp<'_, F, V, A> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "[")?; + for (i, expr) in self.0.input_expressions.iter().enumerate() { + if i != 0 { + write!(f, ", ")?; + } + write!(f, "{}", ExprDisp(expr, self.1, self.2))?; + } + write!(f, "] shuff [")?; + for (i, expr) in self.0.shuffle_expressions.iter().enumerate() { + if i != 0 { + write!(f, ", ")?; + } + write!(f, "{}", ExprDisp(expr, self.1, self.2))?; + } + write!(f, "]")?; + Ok(()) + } +} + +#[cfg(test)] +mod test { + use super::*; + use ff::Field; + use halo2_middleware::circuit::{Any, QueryMid, VarMid}; + use halo2_middleware::poly::Rotation; + use halo2curves::bn256::Fr; + + #[test] + fn test_lookup_shuffle_arg_disp() { + type E = Expression; + let a0 = VarMid::Query(QueryMid::new(Any::Advice, 0, Rotation(0))); + let a1 = VarMid::Query(QueryMid::new(Any::Advice, 1, Rotation(0))); + let f0 = VarMid::Query(QueryMid::new(Any::Fixed, 0, Rotation(0))); + let a0: E = Expression::Var(a0); + let a1: E = Expression::Var(a1); + let f0: E = Expression::Var(f0); + + let names = [ + (ColumnMid::new(Any::Advice, 0), "a".to_string()), + (ColumnMid::new(Any::Advice, 1), "b".to_string()), + (ColumnMid::new(Any::Fixed, 0), "s".to_string()), + ] + .into_iter() + .collect(); + + let arg = lookup::Argument { + name: "lookup".to_string(), + input_expressions: vec![f0.clone() * a0.clone(), f0.clone() * a1.clone()], + table_expressions: vec![f0.clone(), f0.clone() * (a0.clone() + a1.clone())], + }; + assert_eq!( + "[f0 * a0, f0 * a1] in [f0, f0 * (a0 + a1)]", + format!("{}", lookup_arg_disp(&arg)) + ); + assert_eq!( + "[s * a, s * b] in [s, s * (a + b)]", + format!("{}", lookup_arg_disp_names(&arg, &names)) + ); + + let arg = shuffle::Argument { + name: "shuffle".to_string(), + input_expressions: vec![f0.clone() * a0.clone(), f0.clone() * a1.clone()], + shuffle_expressions: vec![f0.clone(), f0.clone() * (a0.clone() + a1.clone())], + }; + assert_eq!( + "[f0 * a0, f0 * a1] shuff [f0, f0 * (a0 + a1)]", + format!("{}", shuffle_arg_disp(&arg)) + ); + assert_eq!( + "[s * a, s * b] shuff [s, s * (a + b)]", + format!("{}", shuffle_arg_disp_names(&arg, &names)) + ); + } + + #[test] + fn test_expr_disp() { + type E = Expression; + let a0 = VarMid::Query(QueryMid::new(Any::Advice, 0, Rotation(0))); + let a1 = VarMid::Query(QueryMid::new(Any::Advice, 1, Rotation(0))); + let a0: E = Expression::Var(a0); + let a1: E = Expression::Var(a1); + + let e = a0.clone() + a1.clone(); + assert_eq!("a0 + a1", format!("{}", expr_disp(&e))); + let e = a0.clone() + a1.clone() + a0.clone(); + assert_eq!("a0 + a1 + a0", format!("{}", expr_disp(&e))); + + let e = a0.clone() * a1.clone(); + assert_eq!("a0 * a1", format!("{}", expr_disp(&e))); + let e = a0.clone() * a1.clone() * a0.clone(); + assert_eq!("a0 * a1 * a0", format!("{}", expr_disp(&e))); + + let e = a0.clone() - a1.clone(); + assert_eq!("a0 - a1", format!("{}", expr_disp(&e))); + let e = (a0.clone() - a1.clone()) - a0.clone(); + assert_eq!("a0 - a1 - a0", format!("{}", expr_disp(&e))); + let e = a0.clone() - (a1.clone() - a0.clone()); + assert_eq!("a0 - (a1 - a0)", format!("{}", expr_disp(&e))); + + let e = a0.clone() * a1.clone() + a0.clone(); + assert_eq!("a0 * a1 + a0", format!("{}", expr_disp(&e))); + let e = a0.clone() * (a1.clone() + a0.clone()); + assert_eq!("a0 * (a1 + a0)", format!("{}", expr_disp(&e))); + + let e = a0.clone() + a1.clone(); + let names = [ + (ColumnMid::new(Any::Advice, 0), "a".to_string()), + (ColumnMid::new(Any::Advice, 1), "b".to_string()), + ] + .into_iter() + .collect(); + assert_eq!("a + b", format!("{}", expr_disp_names(&e, &names))); + } + + #[test] + fn test_f_disp() { + let v = Fr::ZERO; + assert_eq!("0", format!("{}", FDisp(&v))); + + let v = Fr::ONE; + assert_eq!("1", format!("{}", FDisp(&v))); + + let v = Fr::from(12345u64); + assert_eq!("12345", format!("{}", FDisp(&v))); + + let v = Fr::from(0x10000); + assert_eq!("2^16", format!("{}", FDisp(&v))); + + let v = Fr::from(0x12345); + assert_eq!("0x12345", format!("{}", FDisp(&v))); + + let v = -Fr::ONE; + assert_eq!("-1", format!("{}", FDisp(&v))); + + let v = -Fr::from(12345u64); + assert_eq!("-12345", format!("{}", FDisp(&v))); + } +} diff --git a/halo2_debug/src/lib.rs b/halo2_debug/src/lib.rs new file mode 100644 index 0000000000..8754563b71 --- /dev/null +++ b/halo2_debug/src/lib.rs @@ -0,0 +1 @@ +pub mod display; diff --git a/halo2_frontend/src/plonk/circuit/constraint_system.rs b/halo2_frontend/src/plonk/circuit/constraint_system.rs index 7a393a6ac0..cabe718042 100644 --- a/halo2_frontend/src/plonk/circuit/constraint_system.rs +++ b/halo2_frontend/src/plonk/circuit/constraint_system.rs @@ -790,6 +790,16 @@ impl ConstraintSystem { /// Annotate an Instance column. pub fn annotate_lookup_any_column(&mut self, column: T, annotation: A) + where + A: Fn() -> AR, + AR: Into, + T: Into>, + { + self.annotate_column(column, annotation) + } + + /// Annotate a column. + pub fn annotate_column(&mut self, column: T, annotation: A) where A: Fn() -> AR, AR: Into, diff --git a/halo2_middleware/Cargo.toml b/halo2_middleware/Cargo.toml index 1c24c77c69..63b0f7e0b2 100644 --- a/halo2_middleware/Cargo.toml +++ b/halo2_middleware/Cargo.toml @@ -35,8 +35,13 @@ rayon = "1.8" ark-std = { version = "0.3" } proptest = "1" group = "0.13" -rand_xorshift = "0.3.0" -rand_core = "0.6.4" +rand_core = { version = "0.6", default-features = false } +halo2_backend = { path = "../halo2_backend" } +halo2_frontend = { path = "../halo2_frontend" } +rand_xorshift = "0.3" + +[target.'cfg(all(target_arch = "wasm32", target_os = "unknown"))'.dev-dependencies] +getrandom = { version = "0.2", features = ["js"] } [lib] bench = false diff --git a/halo2_middleware/src/circuit.rs b/halo2_middleware/src/circuit.rs index 98c2891b74..4c8d4ee0d9 100644 --- a/halo2_middleware/src/circuit.rs +++ b/halo2_middleware/src/circuit.rs @@ -34,6 +34,16 @@ pub struct QueryMid { pub rotation: Rotation, } +impl QueryMid { + pub fn new(column_type: Any, column_index: usize, rotation: Rotation) -> Self { + Self { + column_index, + column_type, + rotation, + } + } +} + #[derive(Clone, Copy, Debug, Eq, PartialEq)] pub enum VarMid { /// This is a generic column query @@ -42,6 +52,28 @@ pub enum VarMid { Challenge(ChallengeMid), } +impl fmt::Display for VarMid { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + VarMid::Query(query) => { + match query.column_type { + Any::Fixed => write!(f, "f")?, + Any::Advice => write!(f, "a")?, + Any::Instance => write!(f, "i")?, + }; + write!(f, "{}", query.column_index)?; + if query.rotation.0 != 0 { + write!(f, "[{}]", query.rotation.0)?; + } + Ok(()) + } + VarMid::Challenge(challenge) => { + write!(f, "ch{}", challenge.index()) + } + } + } +} + impl Variable for VarMid { fn degree(&self) -> usize { match self { @@ -58,19 +90,7 @@ impl Variable for VarMid { } fn write_identifier(&self, writer: &mut W) -> std::io::Result<()> { - match self { - VarMid::Query(query) => { - match query.column_type { - Any::Fixed => write!(writer, "fixed")?, - Any::Advice => write!(writer, "advice")?, - Any::Instance => write!(writer, "instance")?, - }; - write!(writer, "[{}][{}]", query.column_index, query.rotation.0) - } - VarMid::Challenge(challenge) => { - write!(writer, "challenge[{}]", challenge.index()) - } - } + write!(writer, "{}", self) } } @@ -136,6 +156,19 @@ pub struct ConstraintSystemMid { pub minimum_degree: Option, } +impl ConstraintSystemMid { + /// Returns the number of phases + pub fn phases(&self) -> usize { + let max_phase = self + .advice_column_phase + .iter() + .copied() + .max() + .unwrap_or_default(); + max_phase as usize + 1 + } +} + /// Data that needs to be preprocessed from a circuit #[derive(Debug, Clone)] pub struct Preprocessing { diff --git a/halo2_middleware/src/expression.rs b/halo2_middleware/src/expression.rs index 217d26c321..1cbb557ed4 100644 --- a/halo2_middleware/src/expression.rs +++ b/halo2_middleware/src/expression.rs @@ -3,7 +3,7 @@ use core::ops::{Add, Mul, Neg, Sub}; use ff::Field; use std::iter::{Product, Sum}; -pub trait Variable: Clone + Copy + std::fmt::Debug + Eq + PartialEq { +pub trait Variable: Clone + Copy + std::fmt::Debug + std::fmt::Display + Eq + PartialEq { /// Degree that an expression would have if it was only this variable. fn degree(&self) -> usize; diff --git a/halo2_middleware/src/zal.rs b/halo2_middleware/src/zal.rs index 5d376e3e5b..28d9491ce6 100644 --- a/halo2_middleware/src/zal.rs +++ b/halo2_middleware/src/zal.rs @@ -257,117 +257,94 @@ mod test { use ark_std::{end_timer, start_timer}; use ff::Field; use group::{Curve, Group}; - use rand_core::SeedableRng; - use rand_xorshift::XorShiftRng; + use rand_core::OsRng; - fn gen_points_scalars(k: usize) -> (Vec, Vec) { - let mut rng = XorShiftRng::seed_from_u64(3141592u64); - - let points = (0..1 << k) - .map(|_| C::Curve::random(&mut rng)) + fn run_msm_zal_default(min_k: usize, max_k: usize) { + let points = (0..1 << max_k) + .map(|_| C::Curve::random(OsRng)) .collect::>(); - let mut affine_points = vec![C::identity(); 1 << k]; + let mut affine_points = vec![C::identity(); 1 << max_k]; C::Curve::batch_normalize(&points[..], &mut affine_points[..]); let points = affine_points; - let scalars = (0..1 << k) - .map(|_| C::Scalar::random(&mut rng)) + let scalars = (0..1 << max_k) + .map(|_| C::Scalar::random(OsRng)) .collect::>(); - (points, scalars) - } + for k in min_k..=max_k { + let points = &points[..1 << k]; + let scalars = &scalars[..1 << k]; - fn run_msm_zal_default(points: &[C], scalars: &[C::Scalar], k: usize) { - let points = &points[..1 << k]; - let scalars = &scalars[..1 << k]; + let t0 = start_timer!(|| format!("freestanding msm k={}", k)); + let e0 = best_multiexp(scalars, points); + end_timer!(t0); - let t0 = start_timer!(|| format!("freestanding msm k={}", k)); - let e0 = best_multiexp(scalars, points); - end_timer!(t0); + let engine = PlonkEngineConfig::build_default::(); + let t1 = start_timer!(|| format!("H2cEngine msm k={}", k)); + let e1 = engine.msm_backend.msm(scalars, points); + end_timer!(t1); - let engine = PlonkEngineConfig::build_default::(); - let t1 = start_timer!(|| format!("H2cEngine msm k={}", k)); - let e1 = engine.msm_backend.msm(scalars, points); - end_timer!(t1); + assert_eq!(e0, e1); - assert_eq!(e0, e1); + // Caching API + // ----------- + let t2 = start_timer!(|| format!("H2cEngine msm cached base k={}", k)); + let base_descriptor = engine.msm_backend.get_base_descriptor(points); + let e2 = engine + .msm_backend + .msm_with_cached_base(scalars, &base_descriptor); + end_timer!(t2); - // Caching API - // ----------- - let t2 = start_timer!(|| format!("H2cEngine msm cached base k={}", k)); - let base_descriptor = engine.msm_backend.get_base_descriptor(points); - let e2 = engine - .msm_backend - .msm_with_cached_base(scalars, &base_descriptor); - end_timer!(t2); - assert_eq!(e0, e2); - - let t3 = start_timer!(|| format!("H2cEngine msm cached coeffs k={}", k)); - let coeffs_descriptor = engine.msm_backend.get_coeffs_descriptor(scalars); - let e3 = engine - .msm_backend - .msm_with_cached_scalars(&coeffs_descriptor, points); - end_timer!(t3); - assert_eq!(e0, e3); - - let t4 = start_timer!(|| format!("H2cEngine msm cached inputs k={}", k)); - let e4 = engine - .msm_backend - .msm_with_cached_inputs(&coeffs_descriptor, &base_descriptor); - end_timer!(t4); - assert_eq!(e0, e4); + assert_eq!(e0, e2) + } } - fn run_msm_zal_custom(points: &[C], scalars: &[C::Scalar], k: usize) { - let points = &points[..1 << k]; - let scalars = &scalars[..1 << k]; - - let t0 = start_timer!(|| format!("freestanding msm k={}", k)); - let e0 = best_multiexp(scalars, points); - end_timer!(t0); - - let engine = PlonkEngineConfig::new() - .set_curve::() - .set_msm(H2cEngine::new()) - .build(); - let t1 = start_timer!(|| format!("H2cEngine msm k={}", k)); - let e1 = engine.msm_backend.msm(scalars, points); - end_timer!(t1); - - assert_eq!(e0, e1); - - // Caching API - // ----------- - let t2 = start_timer!(|| format!("H2cEngine msm cached base k={}", k)); - let base_descriptor = engine.msm_backend.get_base_descriptor(points); - let e2 = engine - .msm_backend - .msm_with_cached_base(scalars, &base_descriptor); - end_timer!(t2); - - assert_eq!(e0, e2) - } + fn run_msm_zal_custom(min_k: usize, max_k: usize) { + let points = (0..1 << max_k) + .map(|_| C::Curve::random(OsRng)) + .collect::>(); + let mut affine_points = vec![C::identity(); 1 << max_k]; + C::Curve::batch_normalize(&points[..], &mut affine_points[..]); + let points = affine_points; - #[test] - #[ignore] - fn test_performance_h2c_msm_zal() { - let (min_k, max_k) = (3, 14); - let (points, scalars) = gen_points_scalars::(max_k); + let scalars = (0..1 << max_k) + .map(|_| C::Scalar::random(OsRng)) + .collect::>(); for k in min_k..=max_k { let points = &points[..1 << k]; let scalars = &scalars[..1 << k]; - run_msm_zal_default(points, scalars, k); - run_msm_zal_custom(points, scalars, k); + let t0 = start_timer!(|| format!("freestanding msm k={}", k)); + let e0 = best_multiexp(scalars, points); + end_timer!(t0); + + let engine = PlonkEngineConfig::new() + .set_curve::() + .set_msm(H2cEngine::new()) + .build(); + let t1 = start_timer!(|| format!("H2cEngine msm k={}", k)); + let e1 = engine.msm_backend.msm(scalars, points); + end_timer!(t1); + + assert_eq!(e0, e1); + + // Caching API + // ----------- + let t2 = start_timer!(|| format!("H2cEngine msm cached base k={}", k)); + let base_descriptor = engine.msm_backend.get_base_descriptor(points); + let e2 = engine + .msm_backend + .msm_with_cached_base(scalars, &base_descriptor); + end_timer!(t2); + + assert_eq!(e0, e2) } } #[test] fn test_msm_zal() { - const MSM_SIZE: usize = 12; - let (points, scalars) = gen_points_scalars::(MSM_SIZE); - run_msm_zal_default(&points, &scalars, MSM_SIZE); - run_msm_zal_custom(&points, &scalars, MSM_SIZE); + run_msm_zal_default::(3, 14); + run_msm_zal_custom::(3, 14); } } diff --git a/halo2_proofs/Cargo.toml b/halo2_proofs/Cargo.toml index a68f422934..4302dd7276 100644 --- a/halo2_proofs/Cargo.toml +++ b/halo2_proofs/Cargo.toml @@ -61,6 +61,7 @@ gumdrop = "0.8" proptest = "1" dhat = "0.3.2" serde_json = "1" +halo2_debug = { path = "../halo2_debug" } [target.'cfg(all(target_arch = "wasm32", target_os = "unknown"))'.dev-dependencies] getrandom = { version = "0.2", features = ["js"] } diff --git a/halo2_proofs/tests/compress_selectors.rs b/halo2_proofs/tests/compress_selectors.rs index ec87354fc2..b34a099151 100644 --- a/halo2_proofs/tests/compress_selectors.rs +++ b/halo2_proofs/tests/compress_selectors.rs @@ -3,6 +3,8 @@ use std::marker::PhantomData; use ff::PrimeField; +use halo2_debug::display::expr_disp_names; +use halo2_frontend::circuit::compile_circuit; use halo2_frontend::plonk::Error; use halo2_proofs::circuit::{Cell, Layouter, SimpleFloorPlanner, Value}; use halo2_proofs::poly::Rotation; @@ -10,6 +12,7 @@ use halo2_proofs::poly::Rotation; use halo2_backend::transcript::{ Blake2bRead, Blake2bWrite, Challenge255, TranscriptReadBuffer, TranscriptWriterBuffer, }; +use halo2_middleware::circuit::{Any, ColumnMid}; use halo2_middleware::zal::impls::{H2cEngine, PlonkEngineConfig}; use halo2_proofs::arithmetic::Field; use halo2_proofs::plonk::{ @@ -101,12 +104,16 @@ impl MyCircuitChip { let l = meta.advice_column(); let r = meta.advice_column(); let o = meta.advice_column(); + meta.annotate_column(l, || "l"); + meta.annotate_column(r, || "r"); + meta.annotate_column(o, || "o"); let s_add = meta.selector(); let s_mul = meta.selector(); let s_cubed = meta.selector(); let PI = meta.instance_column(); + meta.annotate_column(PI, || "pi"); meta.enable_equality(l); meta.enable_equality(r); @@ -435,6 +442,63 @@ How the `compress_selectors` works in `MyCircuit` under the hood: */ +#[test] +fn test_compress_gates() { + let k = 4; + let circuit: MyCircuit = MyCircuit { + x: Value::known(Fr::one()), + y: Value::known(Fr::one()), + constant: Fr::one(), + }; + + // Without compression + + let (mut compress_false, _, _) = compile_circuit(k, &circuit, false).unwrap(); + + let names = &mut compress_false.cs.general_column_annotations; + names.insert(ColumnMid::new(Any::Fixed, 0), "s_add".to_string()); + names.insert(ColumnMid::new(Any::Fixed, 1), "s_mul".to_string()); + names.insert(ColumnMid::new(Any::Fixed, 2), "s_cubed".to_string()); + let cs = &compress_false.cs; + let names = &cs.general_column_annotations; + assert_eq!(3, cs.gates.len()); + assert_eq!( + "s_add * (l + r - o)", + format!("{}", expr_disp_names(&cs.gates[0].poly, names)) + ); + assert_eq!( + "s_mul * (l * r - o)", + format!("{}", expr_disp_names(&cs.gates[1].poly, names)) + ); + assert_eq!( + "s_cubed * (l * l * l - o)", + format!("{}", expr_disp_names(&cs.gates[2].poly, names)) + ); + + // With compression + + let (mut compress_true, _, _) = compile_circuit(k, &circuit, true).unwrap(); + + let names = &mut compress_true.cs.general_column_annotations; + names.insert(ColumnMid::new(Any::Fixed, 0), "s_add_mul".to_string()); + names.insert(ColumnMid::new(Any::Fixed, 1), "s_cubed".to_string()); + let cs = &compress_true.cs; + let names = &cs.general_column_annotations; + assert_eq!(3, cs.gates.len()); + assert_eq!( + "s_add_mul * (2 - s_add_mul) * (l + r - o)", + format!("{}", expr_disp_names(&cs.gates[0].poly, names)) + ); + assert_eq!( + "s_add_mul * (1 - s_add_mul) * (l * r - o)", + format!("{}", expr_disp_names(&cs.gates[1].poly, names)) + ); + assert_eq!( + "s_cubed * (l * l * l - o)", + format!("{}", expr_disp_names(&cs.gates[2].poly, names)) + ); +} + #[test] fn test_success() { // vk & pk keygen both WITH compress