From 77761c670f2d516ab486de0f7bde036ff00ebd99 Mon Sep 17 00:00:00 2001 From: ledwards2225 <98505400+ledwards2225@users.noreply.github.com> Date: Wed, 19 Jun 2024 14:09:53 -0700 Subject: [PATCH] chore: Ultra flavor cleanup (#7070) Removes a ton of duplication from ultra and ultra recursive flavors. This had the side effect of slightly changing the order of one of the get_all methods which required updates to the combiner test suite. The suite has been updated to include one test that depends on the "hand-computable" python values as originally intended, and one that checks the combiner optimization consistency without dependence on python generated values. --------- Co-authored-by: Rumata888 --- .../protogalaxy/combiner.test.cpp | 214 +++++++++++++++-- .../protogalaxy/combiner_example_gen.py | 192 +++++++-------- .../stdlib_circuit_builders/ultra_flavor.hpp | 166 +++++-------- .../ultra_recursive_flavor.hpp | 224 +----------------- 4 files changed, 345 insertions(+), 451 deletions(-) diff --git a/barretenberg/cpp/src/barretenberg/protogalaxy/combiner.test.cpp b/barretenberg/cpp/src/barretenberg/protogalaxy/combiner.test.cpp index 4ff7f81cb51..4c47ac5086b 100644 --- a/barretenberg/cpp/src/barretenberg/protogalaxy/combiner.test.cpp +++ b/barretenberg/cpp/src/barretenberg/protogalaxy/combiner.test.cpp @@ -2,6 +2,7 @@ #include "barretenberg/polynomials/pow.hpp" #include "barretenberg/protogalaxy/protogalaxy_prover.hpp" #include "barretenberg/relations/relation_parameters.hpp" +#include "barretenberg/relations/ultra_arithmetic_relation.hpp" #include "barretenberg/stdlib_circuit_builders/ultra_flavor.hpp" #include "barretenberg/sumcheck/instance/instances.hpp" #include @@ -44,12 +45,6 @@ TEST(Protogalaxy, CombinerOn2Instances) auto prover_polynomials = get_sequential_prover_polynomials( /*log_circuit_size=*/1, idx * 128); restrict_to_standard_arithmetic_relation(prover_polynomials); - // This ensures that the combiner accumulator for second instance = 0 - // The value is computed by generating the python script values, computing the resulting accumulator and - // taking the value at index 1 - if (idx == NUM_INSTANCES - 1) { - prover_polynomials.q_c[0] -= 13644570; - } instance->proving_key.polynomials = std::move(prover_polynomials); instance->proving_key.circuit_size = 2; instance_data[idx] = instance; @@ -57,21 +52,204 @@ TEST(Protogalaxy, CombinerOn2Instances) ProverInstances instances{ instance_data }; instances.alphas.fill(bb::Univariate(FF(0))); // focus on the arithmetic relation only + auto pow_polynomial = PowPolynomial(std::vector{ 2 }); + auto result = prover.compute_combiner(instances, pow_polynomial); + // The expected_result values are computed by running the python script combiner_example_gen.py + auto expected_result = Univariate(std::array{ 8600UL, + 12679448UL, + 73617560UL, + 220571672UL, + 491290520UL, + 923522840UL, + 1555017368UL, + 2423522840UL, + 3566787992UL, + 5022561560UL, + 6828592280UL, + 9022628888UL }); + EXPECT_EQ(result, expected_result); + } else { + std::vector> instance_data(NUM_INSTANCES); + ProtoGalaxyProver prover; + + for (size_t idx = 0; idx < NUM_INSTANCES; idx++) { + auto instance = std::make_shared(); + auto prover_polynomials = get_zero_prover_polynomials( + /*log_circuit_size=*/1); + restrict_to_standard_arithmetic_relation(prover_polynomials); + instance->proving_key.polynomials = std::move(prover_polynomials); + instance->proving_key.circuit_size = 2; + instance_data[idx] = instance; + } + + ProverInstances instances{ instance_data }; + instances.alphas.fill(bb::Univariate(FF(0))); // focus on the arithmetic relation only + + const auto create_add_gate = [](auto& polys, const size_t idx, FF w_l, FF w_r) { + polys.w_l[idx] = w_l; + polys.w_r[idx] = w_r; + polys.w_o[idx] = w_l + w_r; + polys.q_l[idx] = 1; + polys.q_r[idx] = 1; + polys.q_o[idx] = -1; + }; + + const auto create_mul_gate = [](auto& polys, const size_t idx, FF w_l, FF w_r) { + polys.w_l[idx] = w_l; + polys.w_r[idx] = w_r; + polys.w_o[idx] = w_l * w_r; + polys.q_m[idx] = 1; + polys.q_o[idx] = -1; + }; + + create_add_gate(instances[0]->proving_key.polynomials, 0, 1, 2); + create_add_gate(instances[0]->proving_key.polynomials, 1, 0, 4); + create_add_gate(instances[1]->proving_key.polynomials, 0, 3, 4); + create_mul_gate(instances[1]->proving_key.polynomials, 1, 1, 4); + + restrict_to_standard_arithmetic_relation(instances[0]->proving_key.polynomials); + restrict_to_standard_arithmetic_relation(instances[1]->proving_key.polynomials); + + /* Instance 0 Instance 1 + w_l w_r w_o q_m q_l q_r q_o q_c w_l w_r w_o q_m q_l q_r q_o q_c + 1 2 3 0 1 1 -1 0 3 4 7 0 1 1 -1 0 + 0 4 4 0 1 1 -1 0 1 4 4 1 0 0 -1 0 */ + + /* Lagrange-combined values, row index 0 Lagrange-combined values, row index 1 + in 0 1 2 3 4 5 6 in 0 1 2 3 4 5 6 + w_l 1 3 5 7 9 11 13 w_l 0 1 2 3 4 5 6 + w_r 2 4 6 8 10 12 14 w_r 4 4 4 4 4 4 4 + w_o 3 7 11 15 19 23 27 w_o 4 4 4 4 4 4 0 + q_m 0 0 0 0 0 0 0 q_m 0 1 2 3 4 5 6 + q_l 1 1 1 1 1 1 1 q_l 1 0 -1 -2 -3 -4 -5 + q_r 1 1 1 1 1 1 1 q_r 1 0 -1 -2 -3 -4 -5 + q_o -1 -1 -1 -1 -1 -1 -1 q_o -1 -1 -1 -1 -1 -1 -1 + q_c 0 0 0 0 0 0 0 q_c 0 0 0 0 0 0 0 + + relation value: + 0 0 0 0 0 0 0 0 0 6 18 36 60 90 */ + auto pow_polynomial = PowPolynomial(std::vector{ 2 }); auto result = prover.compute_combiner(instances, pow_polynomial); auto optimised_result = prover.compute_combiner(instances, pow_polynomial); - auto expected_result = Univariate(std::array{ 87706, - 0, - 0x02ee2966, - 0x0b0bd2cc, - 0x00001a98fc32, - 0x000033d5a598, - 0x00005901cefe, - 0x00008c5d7864, - 0x0000d028a1ca, - 0x000126a34b30UL, - 0x0001920d7496UL, - 0x000214a71dfcUL }); + auto expected_result = + Univariate(std::array{ 0, 0, 12, 36, 72, 120, 180, 252, 336, 432, 540, 660 }); + + EXPECT_EQ(result, expected_result); + EXPECT_EQ(optimised_result, expected_result); + } + }; + run_test(true); + run_test(false); +}; + +// Check that the optimized combiner computation yields a result consistent with the unoptimized version +TEST(Protogalaxy, CombinerOptimizationConsistency) +{ + constexpr size_t NUM_INSTANCES = 2; + using ProverInstance = ProverInstance_; + using ProverInstances = ProverInstances_; + using ProtoGalaxyProver = ProtoGalaxyProver_; + using UltraArithmeticRelation = UltraArithmeticRelation; + + constexpr size_t UNIVARIATE_LENGTH = 12; + const auto restrict_to_standard_arithmetic_relation = [](auto& polys) { + std::fill(polys.q_arith.begin(), polys.q_arith.end(), 1); + std::fill(polys.q_delta_range.begin(), polys.q_delta_range.end(), 0); + std::fill(polys.q_elliptic.begin(), polys.q_elliptic.end(), 0); + std::fill(polys.q_aux.begin(), polys.q_aux.end(), 0); + std::fill(polys.q_lookup.begin(), polys.q_lookup.end(), 0); + std::fill(polys.q_4.begin(), polys.q_4.end(), 0); + std::fill(polys.w_4.begin(), polys.w_4.end(), 0); + std::fill(polys.w_4_shift.begin(), polys.w_4_shift.end(), 0); + }; + + auto run_test = [&](bool is_random_input) { + // Combiner test on prover polynomisls containing random values, restricted to only the standard arithmetic + // relation. + if (is_random_input) { + std::vector> instance_data(NUM_INSTANCES); + ASSERT(NUM_INSTANCES == 2); // Don't want to handle more here + ProtoGalaxyProver prover; + + for (size_t idx = 0; idx < NUM_INSTANCES; idx++) { + auto instance = std::make_shared(); + auto prover_polynomials = get_sequential_prover_polynomials( + /*log_circuit_size=*/1, idx * 128); + restrict_to_standard_arithmetic_relation(prover_polynomials); + instance->proving_key.polynomials = std::move(prover_polynomials); + instance->proving_key.circuit_size = 2; + instance_data[idx] = instance; + } + + ProverInstances instances{ instance_data }; + instances.alphas.fill( + bb::Univariate(FF(0))); // focus on the arithmetic relation only + auto pow_polynomial = PowPolynomial(std::vector{ 2 }); + pow_polynomial.compute_values(); + + // Relation parameters are all zeroes + RelationParameters relation_parameters; + // Temporary accumulator to compute the sumcheck on the second instance + typename Flavor::TupleOfArraysOfValues temporary_accumulator; + + // Accumulate arithmetic relation over 2 rows on the second instance + for (size_t i = 0; i < 2; i++) { + UltraArithmeticRelation::accumulate( + std::get<0>(temporary_accumulator), + instance_data[NUM_INSTANCES - 1]->proving_key.polynomials.get_row(i), + relation_parameters, + pow_polynomial[i]); + } + // Get the result of the 0th subrelation of the arithmetic relation + FF instance_offset = std::get<0>(temporary_accumulator)[0]; + // Subtract it from q_c[0] (it directly affect the target sum, making it zero and enabling the optimisation) + instance_data[1]->proving_key.polynomials.q_c[0] -= instance_offset; + std::vector + extended_polynomials; // These hold the extensions of prover polynomials + + // Manually extend all polynomials. Create new ProverPolynomials from extended values + for (size_t idx = NUM_INSTANCES; idx < UNIVARIATE_LENGTH; idx++) { + + auto instance = std::make_shared(); + auto prover_polynomials = get_zero_prover_polynomials(1); + for (auto [instance_0_polynomial, instance_1_polynomial, new_polynomial] : + zip_view(instance_data[0]->proving_key.polynomials.get_all(), + instance_data[1]->proving_key.polynomials.get_all(), + prover_polynomials.get_all())) { + for (size_t i = 0; i < /*circuit_size*/ 2; i++) { + new_polynomial[i] = + instance_0_polynomial[i] + ((instance_1_polynomial[i] - instance_0_polynomial[i]) * idx); + } + } + extended_polynomials.push_back(std::move(prover_polynomials)); + } + std::array precomputed_result{ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }; + // Compute the sum for each index separately, treating each extended instance independently + for (size_t idx = 0; idx < UNIVARIATE_LENGTH; idx++) { + + typename Flavor::TupleOfArraysOfValues accumulator; + if (idx < NUM_INSTANCES) { + for (size_t i = 0; i < 2; i++) { + UltraArithmeticRelation::accumulate(std::get<0>(accumulator), + instance_data[idx]->proving_key.polynomials.get_row(i), + relation_parameters, + pow_polynomial[i]); + } + } else { + for (size_t i = 0; i < 2; i++) { + UltraArithmeticRelation::accumulate(std::get<0>(accumulator), + extended_polynomials[idx - NUM_INSTANCES].get_row(i), + relation_parameters, + pow_polynomial[i]); + } + } + precomputed_result[idx] = std::get<0>(accumulator)[0]; + } + auto expected_result = Univariate(precomputed_result); + auto result = prover.compute_combiner(instances, pow_polynomial); + auto optimised_result = prover.compute_combiner(instances, pow_polynomial); + EXPECT_EQ(result, expected_result); EXPECT_EQ(optimised_result, expected_result); } else { diff --git a/barretenberg/cpp/src/barretenberg/protogalaxy/combiner_example_gen.py b/barretenberg/cpp/src/barretenberg/protogalaxy/combiner_example_gen.py index 0ace782160a..bc0fb19128b 100644 --- a/barretenberg/cpp/src/barretenberg/protogalaxy/combiner_example_gen.py +++ b/barretenberg/cpp/src/barretenberg/protogalaxy/combiner_example_gen.py @@ -6,71 +6,55 @@ EXTENDED_RELATION_LENGTH = 13 class Row: - def __init__(self, start): - - self.q_c = start + 2 * 0 - self.q_l = start + 2 * 1 - self.q_r = start + 2 * 2 - self.q_o = start + 2 * 3 - self.q_4 = start + 2 * 4 - self.q_m = start + 2 * 5 - self.q_arith = start + 2 * 6 - self.q_delta_range = start + 2 * 7 - self.q_elliptic = start + 2 * 8 - self.q_aux = start + 2 * 9 - self.q_lookup = start + 2 * 10 - self.sigma_1 = start + 2 * 11 - self.sigma_2 = start + 2 * 12 - self.sigma_3 = start + 2 * 13 - self.sigma_4 = start + 2 * 14 - self.id_1 = start + 2 * 15 - self.id_2 = start + 2 * 16 - self.id_3 = start + 2 * 17 - self.id_4 = start + 2 * 18 - self.table_1 = start + 2 * 19 - self.table_2 = start + 2 * 20 - self.table_3 = start + 2 * 21 - self.table_4 = start + 2 * 22 - self.lagrange_first = start + 2 * 23 - self.lagrange_last = start + 2 * 24 - self.w_l = start + 2 * 25 - self.w_r = start + 2 * 26 - self.w_o = start + 2 * 27 - self.w_4 = start + 2 * 28 - self.sorted_accum = start + 2 * 29 - self.z_perm = start + 2 * 30 - self.z_lookup = start + 2 * 31 - self.table_1_shift = start + 2 * 32 - self.table_2_shift = start + 2 * 33 - self.table_3_shift = start + 2 * 34 - self.table_4_shift = start + 2 * 35 - self.w_l_shift = start + 2 * 36 - self.w_r_shift = start + 2 * 37 - self.w_o_shift = start + 2 * 38 - self.w_4_shift = start + 2 * 39 - self.sorted_accum_shift = start + 2 * 40 - self.z_perm_shift = start + 2 * 41 - self.z_lookup_shift = start + 2 * 42 - - self.entities = [self.q_c, self.q_l, self.q_r, self.q_o, self.q_m, self.sigma_1, self.sigma_2, self.sigma_3, self.id_1, - self.id_2, self.id_3, self.lagrange_first, self.lagrange_last, self.w_l, self.w_r, self.w_o, self.z_perm, self.z_perm_shift] - - -class Instance: - def __init__(self, rows): - self.num_entities = len(rows[0].entities) - self.rows = rows - - -class Instances: - def __init__(self, instances): - self.num_entities = instances[0].num_entities - self.data = instances - - -def rel(w_l, w_r, w_o, q_m, q_l, q_r, q_o, q_c): - return q_m * w_l * w_r + q_l * w_l + q_r * w_r + q_o * w_o + q_c - + # Construct a set of 'all' polynomials with a very simple structure + def __init__(self, base_poly): + # Constuct polys by adding increasing factors of 2 to an input poly + self.q_m = base_poly + 2 * 0 + self.q_c = base_poly + 2 * 1 + self.q_l = base_poly + 2 * 2 + self.q_r = base_poly + 2 * 3 + self.q_o = base_poly + 2 * 4 + self.q_4 = base_poly + 2 * 5 + self.q_arith = base_poly + 2 * 6 + self.q_delta_range = base_poly + 2 * 7 + self.q_elliptic = base_poly + 2 * 8 + self.q_aux = base_poly + 2 * 9 + self.q_lookup = base_poly + 2 * 10 + self.sigma_1 = base_poly + 2 * 11 + self.sigma_2 = base_poly + 2 * 12 + self.sigma_3 = base_poly + 2 * 13 + self.sigma_4 = base_poly + 2 * 14 + self.id_1 = base_poly + 2 * 15 + self.id_2 = base_poly + 2 * 16 + self.id_3 = base_poly + 2 * 17 + self.id_4 = base_poly + 2 * 18 + self.table_1 = base_poly + 2 * 19 + self.table_2 = base_poly + 2 * 20 + self.table_3 = base_poly + 2 * 21 + self.table_4 = base_poly + 2 * 22 + self.lagrange_first = base_poly + 2 * 23 + self.lagrange_last = base_poly + 2 * 24 + self.w_l = base_poly + 2 * 25 + self.w_r = base_poly + 2 * 26 + self.w_o = base_poly + 2 * 27 + self.w_4 = base_poly + 2 * 28 + self.sorted_accum = base_poly + 2 * 29 + self.z_perm = base_poly + 2 * 30 + self.z_lookup = base_poly + 2 * 31 + self.table_1_shift = base_poly + 2 * 32 + self.table_2_shift = base_poly + 2 * 33 + self.table_3_shift = base_poly + 2 * 34 + self.table_4_shift = base_poly + 2 * 35 + self.w_l_shift = base_poly + 2 * 36 + self.w_r_shift = base_poly + 2 * 37 + self.w_o_shift = base_poly + 2 * 38 + self.w_4_shift = base_poly + 2 * 39 + self.sorted_accum_shift = base_poly + 2 * 40 + self.z_perm_shift = base_poly + 2 * 41 + self.z_lookup_shift = base_poly + 2 * 42 + + def arith_relation(self): + return self.q_m * self.w_l * self.w_r + self.q_l * self.w_l + self.q_r * self.w_r + self.q_o * self.w_o + self.q_c def extend_one_entity(input): result = input @@ -79,60 +63,54 @@ def extend_one_entity(input): result.append(delta + result[-1]) return result - -def get_extended_univariates(instances, row_idx): - rows = [instance.rows[row_idx] for instance in instances.data] - for entity_idx in range(instances.num_entities): - result = [row.entities[entity_idx] for row in rows] - result = np.array(extend_one_entity(result)) - return result - def compute_first_example(): - i0 = Instance([Row(0), Row(1)]) - i1 = Instance([Row(128), Row(129)]) - instances = Instances([i0, i1]) + # Construct baseline extensions for the two rows; extentions for all polys will be computed via the Row constructor + baseline_extension_0 = np.array(extend_one_entity([0, 128])) + baseline_extension_1 = baseline_extension_0 + 1 - row_0_extended = Row(get_extended_univariates(instances, 0)) - row_1_extended = Row(get_extended_univariates(instances, 1)) + # Construct extensions for all polys for the two rows in consideration + row_0_extended = Row(baseline_extension_0) + row_1_extended = Row(baseline_extension_1) accumulator = np.array([0 for _ in range(EXTENDED_RELATION_LENGTH)]) zeta_pow = 1 zeta = 2 for row in [row_0_extended, row_1_extended]: - relation_value = rel(row.w_l, row.w_r, row.w_o, row.q_m, - row.q_l, row.q_r, row.q_o, row.q_c) - accumulator += zeta_pow * relation_value + accumulator += zeta_pow * row.arith_relation() zeta_pow *= zeta return accumulator def compute_second_example(): - result = 0 - # 0 1 2 3 4 5 6 7 8 9 10 11 12 - w_l = np.array([ 1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25]) - w_r = np.array([ 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26]) - w_o = np.array([ 3, 7, 11, 15, 19, 23, 27, 31, 35, 39, 43, 47, 51]) - q_m = np.array([ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]) - q_l = np.array([ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]) - q_r = np.array([ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]) - q_o = np.array([-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1]) - q_c = np.array([ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]) - # contribution is zero, but why not? - result += rel(w_l, w_r, w_o, q_m, q_l, q_r, q_o, q_c) - - w_l = np.array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]) - w_r = np.array([ 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4]) - w_o = np.array([ 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4]) - q_m = np.array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]) - q_l = np.array([ 1, 0, -1, -2, -3, -4, -5, -6, -7, -8, -9,-10,-11]) - q_r = np.array([ 1, 0, -1, -2, -3, -4, -5, -6, -7, -8, -9,-10,-11]) - q_o = np.array([-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1]) - q_c = np.array([ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]) - - result += rel(w_l, w_r, w_o, q_m, q_l, q_r, q_o, q_c) - result *= 2 - - return result + def arith_relation(w_l, w_r, w_o, q_m, q_l, q_r, q_o, q_c): + return q_m * w_l * w_r + q_l * w_l + q_r * w_r + q_o * w_o + q_c + + result = 0 + # 0 1 2 3 4 5 6 7 8 9 10 11 12 + w_l = np.array([ 1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25]) + w_r = np.array([ 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26]) + w_o = np.array([ 3, 7, 11, 15, 19, 23, 27, 31, 35, 39, 43, 47, 51]) + q_m = np.array([ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]) + q_l = np.array([ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]) + q_r = np.array([ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]) + q_o = np.array([-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1]) + q_c = np.array([ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]) + # contribution is zero, but why not? + result += arith_relation(w_l, w_r, w_o, q_m, q_l, q_r, q_o, q_c) + + w_l = np.array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]) + w_r = np.array([ 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4]) + w_o = np.array([ 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4]) + q_m = np.array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]) + q_l = np.array([ 1, 0, -1, -2, -3, -4, -5, -6, -7, -8, -9,-10,-11]) + q_r = np.array([ 1, 0, -1, -2, -3, -4, -5, -6, -7, -8, -9,-10,-11]) + q_o = np.array([-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1]) + q_c = np.array([ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]) + + result += arith_relation(w_l, w_r, w_o, q_m, q_l, q_r, q_o, q_c) + result *= 2 + + return result if __name__ == "__main__": print(f"First example: \n {compute_first_example()}") diff --git a/barretenberg/cpp/src/barretenberg/stdlib_circuit_builders/ultra_flavor.hpp b/barretenberg/cpp/src/barretenberg/stdlib_circuit_builders/ultra_flavor.hpp index 5cd7326d445..909aa29d0d2 100644 --- a/barretenberg/cpp/src/barretenberg/stdlib_circuit_builders/ultra_flavor.hpp +++ b/barretenberg/cpp/src/barretenberg/stdlib_circuit_builders/ultra_flavor.hpp @@ -48,12 +48,15 @@ class UltraFlavor { using GrandProductRelations = std::tuple, bb::LookupRelation>; // define the tuple of Relations that comprise the Sumcheck relation - using Relations = std::tuple, - bb::UltraPermutationRelation, - bb::LookupRelation, - bb::DeltaRangeConstraintRelation, - bb::EllipticRelation, - bb::AuxiliaryRelation>; + // Note: made generic for use in MegaRecursive. + template + using Relations_ = std::tuple, + bb::UltraPermutationRelation, + bb::LookupRelation, + bb::DeltaRangeConstraintRelation, + bb::EllipticRelation, + bb::AuxiliaryRelation>; + using Relations = Relations_; static constexpr size_t MAX_PARTIAL_RELATION_LENGTH = compute_max_partial_relation_length(); static_assert(MAX_PARTIAL_RELATION_LENGTH == 6); @@ -88,7 +91,6 @@ class UltraFlavor { static constexpr bool is_decider = true; - private: /** * @brief A base class labelling precomputed entities and (ordered) subsets of interest. * @details Used to build the proving key and verification key. @@ -189,92 +191,33 @@ class UltraFlavor { * Symbolically we have: AllEntities = PrecomputedEntities + WitnessEntities + "ShiftedEntities". It could be * implemented as such, but we have this now. */ - template class AllEntities { + template + class AllEntities : public PrecomputedEntities, + public WitnessEntities, + public ShiftedEntities { public: - DEFINE_FLAVOR_MEMBERS(DataType, - q_c, // column 0 - q_l, // column 1 - q_r, // column 2 - q_o, // column 3 - q_4, // column 4 - q_m, // column 5 - q_arith, // column 6 - q_delta_range, // column 7 - q_elliptic, // column 8 - q_aux, // column 9 - q_lookup, // column 10 - sigma_1, // column 11 - sigma_2, // column 12 - sigma_3, // column 13 - sigma_4, // column 14 - id_1, // column 15 - id_2, // column 16 - id_3, // column 17 - id_4, // column 18 - table_1, // column 19 - table_2, // column 20 - table_3, // column 21 - table_4, // column 22 - lagrange_first, // column 23 - lagrange_last, // column 24 - w_l, // column 25 - w_r, // column 26 - w_o, // column 27 - w_4, // column 28 - sorted_accum, // column 29 - z_perm, // column 30 - z_lookup, // column 31 - table_1_shift, // column 32 - table_2_shift, // column 33 - table_3_shift, // column 34 - table_4_shift, // column 35 - w_l_shift, // column 36 - w_r_shift, // column 37 - w_o_shift, // column 38 - w_4_shift, // column 39 - sorted_accum_shift, // column 40 - z_perm_shift, // column 41 - z_lookup_shift) // column 42 + DEFINE_COMPOUND_GET_ALL(PrecomputedEntities, WitnessEntities, ShiftedEntities) - auto get_wires() { return RefArray{ w_l, w_r, w_o, w_4 }; }; - auto get_selectors() - { - return RefArray{ q_m, q_c, q_l, q_r, q_o, q_4, q_arith, q_delta_range, q_elliptic, q_aux, q_lookup }; - } - auto get_sigmas() { return RefArray{ sigma_1, sigma_2, sigma_3, sigma_4 }; }; - auto get_ids() { return RefArray{ id_1, id_2, id_3, id_4 }; }; - auto get_tables() { return RefArray{ table_1, table_2, table_3, table_4 }; }; + auto get_wires() { return RefArray{ this->w_l, this->w_r, this->w_o, this->w_4 }; }; + auto get_selectors() { return PrecomputedEntities::get_selectors(); } + auto get_sigmas() { return RefArray{ this->sigma_1, this->sigma_2, this->sigma_3, this->sigma_4 }; }; + auto get_ids() { return RefArray{ this->id_1, this->id_2, this->id_3, this->id_4 }; }; + auto get_tables() { return RefArray{ this->table_1, this->table_2, this->table_3, this->table_4 }; }; // Gemini-specific getters. auto get_unshifted() { - return RefArray{ q_m, q_c, q_l, q_r, q_o, q_4, q_arith, q_delta_range, - q_elliptic, q_aux, q_lookup, sigma_1, sigma_2, sigma_3, sigma_4, id_1, - id_2, id_3, id_4, table_1, table_2, table_3, table_4, lagrange_first, - lagrange_last, w_l, w_r, w_o, w_4, sorted_accum, z_perm, z_lookup - - }; + return concatenate(PrecomputedEntities::get_all(), WitnessEntities::get_all()); }; - auto get_precomputed() - { - return RefArray{ q_m, q_c, q_l, q_r, q_o, q_4, q_arith, q_delta_range, - q_elliptic, q_aux, q_lookup, sigma_1, sigma_2, sigma_3, sigma_4, id_1, - id_2, id_3, id_4, table_1, table_2, table_3, table_4, lagrange_first, - lagrange_last + auto get_precomputed() { return PrecomputedEntities::get_all(); } - }; - } - - auto get_witness() { return RefArray{ w_l, w_r, w_o, w_4, sorted_accum, z_perm, z_lookup }; }; + auto get_witness() { return WitnessEntities::get_all(); }; auto get_to_be_shifted() { - return RefArray{ table_1, table_2, table_3, table_4, w_l, w_r, w_o, w_4, sorted_accum, z_perm, z_lookup }; - }; - auto get_shifted() - { - return RefArray{ table_1_shift, table_2_shift, table_3_shift, table_4_shift, w_l_shift, w_r_shift, - w_o_shift, w_4_shift, sorted_accum_shift, z_perm_shift, z_lookup_shift }; + return RefArray{ this->table_1, this->table_2, this->table_3, this->table_4, this->w_l, this->w_r, + this->w_o, this->w_4, this->sorted_accum, this->z_perm, this->z_lookup }; }; + auto get_shifted() { return ShiftedEntities::get_all(); }; }; public: @@ -667,36 +610,37 @@ class UltraFlavor { * witness polynomials). * */ - class VerifierCommitments : public AllEntities { + template + class VerifierCommitments_ : public AllEntities { public: - VerifierCommitments(const std::shared_ptr& verification_key, - const std::optional& witness_commitments = std::nullopt) + VerifierCommitments_(const std::shared_ptr& verification_key, + const std::optional& witness_commitments = std::nullopt) { - q_m = verification_key->q_m; - q_c = verification_key->q_c; - q_l = verification_key->q_l; - q_r = verification_key->q_r; - q_o = verification_key->q_o; - q_4 = verification_key->q_4; - q_arith = verification_key->q_arith; - q_delta_range = verification_key->q_delta_range; - q_elliptic = verification_key->q_elliptic; - q_aux = verification_key->q_aux; - q_lookup = verification_key->q_lookup; - sigma_1 = verification_key->sigma_1; - sigma_2 = verification_key->sigma_2; - sigma_3 = verification_key->sigma_3; - sigma_4 = verification_key->sigma_4; - id_1 = verification_key->id_1; - id_2 = verification_key->id_2; - id_3 = verification_key->id_3; - id_4 = verification_key->id_4; - table_1 = verification_key->table_1; - table_2 = verification_key->table_2; - table_3 = verification_key->table_3; - table_4 = verification_key->table_4; - lagrange_first = verification_key->lagrange_first; - lagrange_last = verification_key->lagrange_last; + this->q_m = verification_key->q_m; + this->q_c = verification_key->q_c; + this->q_l = verification_key->q_l; + this->q_r = verification_key->q_r; + this->q_o = verification_key->q_o; + this->q_4 = verification_key->q_4; + this->q_arith = verification_key->q_arith; + this->q_delta_range = verification_key->q_delta_range; + this->q_elliptic = verification_key->q_elliptic; + this->q_aux = verification_key->q_aux; + this->q_lookup = verification_key->q_lookup; + this->sigma_1 = verification_key->sigma_1; + this->sigma_2 = verification_key->sigma_2; + this->sigma_3 = verification_key->sigma_3; + this->sigma_4 = verification_key->sigma_4; + this->id_1 = verification_key->id_1; + this->id_2 = verification_key->id_2; + this->id_3 = verification_key->id_3; + this->id_4 = verification_key->id_4; + this->table_1 = verification_key->table_1; + this->table_2 = verification_key->table_2; + this->table_3 = verification_key->table_3; + this->table_4 = verification_key->table_4; + this->lagrange_first = verification_key->lagrange_first; + this->lagrange_last = verification_key->lagrange_last; if (witness_commitments.has_value()) { auto commitments = witness_commitments.value(); @@ -710,6 +654,8 @@ class UltraFlavor { } } }; + // Specialize for Ultra (general case used in UltraRecursive). + using VerifierCommitments = VerifierCommitments_; /** * @brief Derived class that defines proof structure for Ultra proofs, as well as supporting functions. diff --git a/barretenberg/cpp/src/barretenberg/stdlib_circuit_builders/ultra_recursive_flavor.hpp b/barretenberg/cpp/src/barretenberg/stdlib_circuit_builders/ultra_recursive_flavor.hpp index 2979b9ac0d0..a73509fe018 100644 --- a/barretenberg/cpp/src/barretenberg/stdlib_circuit_builders/ultra_recursive_flavor.hpp +++ b/barretenberg/cpp/src/barretenberg/stdlib_circuit_builders/ultra_recursive_flavor.hpp @@ -71,12 +71,7 @@ template class UltraRecursiveFlavor_ { static constexpr size_t NUM_WITNESS_ENTITIES = 7; // define the tuple of Relations that comprise the Sumcheck relation - using Relations = std::tuple, - bb::UltraPermutationRelation, - bb::LookupRelation, - bb::DeltaRangeConstraintRelation, - bb::EllipticRelation, - bb::AuxiliaryRelation>; + using Relations = UltraFlavor::Relations_; static constexpr size_t MAX_PARTIAL_RELATION_LENGTH = compute_max_partial_relation_length(); static_assert(MAX_PARTIAL_RELATION_LENGTH == 6); @@ -99,167 +94,6 @@ template class UltraRecursiveFlavor_ { // define the container for storing the univariate contribution from each relation in Sumcheck using TupleOfArraysOfValues = decltype(create_tuple_of_arrays_of_values()); - private: - template - /** - * @brief A base class labelling precomputed entities and (ordered) subsets of interest. - * @details Used to build the proving key and verification key. - */ - class PrecomputedEntities : public PrecomputedEntitiesBase { - public: - DEFINE_FLAVOR_MEMBERS(DataType, - q_m, // column 0 - q_c, // column 1 - q_l, // column 2 - q_r, // column 3 - q_o, // column 4 - q_4, // column 5 - q_arith, // column 6 - q_delta_range, // column 7 - q_elliptic, // column 8 - q_aux, // column 9 - q_lookup, // column 10 - sigma_1, // column 11 - sigma_2, // column 12 - sigma_3, // column 13 - sigma_4, // column 14 - id_1, // column 15 - id_2, // column 16 - id_3, // column 17 - id_4, // column 18 - table_1, // column 19 - table_2, // column 20 - table_3, // column 21 - table_4, // column 22 - lagrange_first, // column 23 - lagrange_last); // column 24 - - auto get_selectors() - { - return RefArray{ q_m, q_c, q_l, q_r, q_o, q_4, q_arith, q_delta_range, q_elliptic, q_aux, q_lookup }; - }; - auto get_sigma_polynomials() { return RefArray{ sigma_1, sigma_2, sigma_3, sigma_4 }; }; - auto get_id_polynomials() { return RefArray{ id_1, id_2, id_3, id_4 }; }; - - auto get_table_polynomials() { return RefArray{ table_1, table_2, table_3, table_4 }; }; - }; - - /** - * @brief Container for all witness polynomials used/constructed by the prover. - * @details Shifts are not included here since they do not occupy their own memory. - */ - template class WitnessEntities { - public: - DEFINE_FLAVOR_MEMBERS(DataType, - w_l, // column 0 - w_r, // column 1 - w_o, // column 2 - w_4, // column 3 - sorted_accum, // column 4 - z_perm, // column 5 - z_lookup // column 6 - - ); - - auto get_wires() { return RefArray{ w_l, w_r, w_o, w_4 }; }; - }; - - public: - /** - * @brief A container for the witness commitments. - */ - using WitnessCommitments = WitnessEntities; - - /** - * @brief A base class labelling all entities (for instance, all of the polynomials used by the prover during - * sumcheck) in this Honk variant along with particular subsets of interest - * @details Used to build containers for: the prover's polynomial during sumcheck; the sumcheck's folded - * polynomials; the univariates consturcted during during sumcheck; the evaluations produced by sumcheck. - * - * Symbolically we have: AllEntities = PrecomputedEntities + WitnessEntities + "ShiftedEntities". It could be - * implemented as such, but we have this now. - */ - template class AllEntities { - public: - DEFINE_FLAVOR_MEMBERS(DataType, - q_c, // column 0 - q_l, // column 1 - q_r, // column 2 - q_o, // column 3 - q_4, // column 4 - q_m, // column 5 - q_arith, // column 6 - q_delta_range, // column 7 - q_elliptic, // column 8 - q_aux, // column 9 - q_lookup, // column 10 - sigma_1, // column 11 - sigma_2, // column 12 - sigma_3, // column 13 - sigma_4, // column 14 - id_1, // column 15 - id_2, // column 16 - id_3, // column 17 - id_4, // column 18 - table_1, // column 19 - table_2, // column 20 - table_3, // column 21 - table_4, // column 22 - lagrange_first, // column 23 - lagrange_last, // column 24 - w_l, // column 25 - w_r, // column 26 - w_o, // column 27 - w_4, // column 28 - sorted_accum, // column 29 - z_perm, // column 30 - z_lookup, // column 31 - table_1_shift, // column 32 - table_2_shift, // column 33 - table_3_shift, // column 34 - table_4_shift, // column 35 - w_l_shift, // column 36 - w_r_shift, // column 37 - w_o_shift, // column 38 - w_4_shift, // column 39 - sorted_accum_shift, // column 40 - z_perm_shift, // column 41 - z_lookup_shift // column 42 - ); - - auto get_wires() { return RefArray{ w_l, w_r, w_o, w_4 }; }; - // Gemini-specific getters. - auto get_unshifted() - { - return RefArray{ q_m, q_c, q_l, q_r, q_o, q_4, q_arith, q_delta_range, - q_elliptic, q_aux, q_lookup, sigma_1, sigma_2, sigma_3, sigma_4, id_1, - id_2, id_3, id_4, table_1, table_2, table_3, table_4, lagrange_first, - lagrange_last, w_l, w_r, w_o, w_4, sorted_accum, z_perm, z_lookup - - }; - }; - auto get_precomputed() - { - return RefArray{ q_m, q_c, q_l, q_r, q_o, q_4, q_arith, q_delta_range, - q_elliptic, q_aux, q_lookup, sigma_1, sigma_2, sigma_3, sigma_4, id_1, - id_2, id_3, id_4, table_1, table_2, table_3, table_4, lagrange_first, - lagrange_last - - }; - } - - auto get_witness() { return RefArray{ w_l, w_r, w_o, w_4, sorted_accum, z_perm, z_lookup }; }; - auto get_to_be_shifted() - { - return RefArray{ table_1, table_2, table_3, table_4, w_l, w_r, w_o, w_4, sorted_accum, z_perm, z_lookup }; - }; - auto get_shifted() - { - return RefArray{ table_1_shift, table_2_shift, table_3_shift, table_4_shift, w_l_shift, w_r_shift, - w_o_shift, w_4_shift, sorted_accum_shift, z_perm_shift, z_lookup_shift }; - }; - }; - public: /** * @brief The verification key is responsible for storing the the commitments to the precomputed (non-witnessk) @@ -269,7 +103,8 @@ template class UltraRecursiveFlavor_ { * that, and split out separate PrecomputedPolynomials/Commitments data for clarity but also for portability of our * circuits. */ - class VerificationKey : public VerificationKey_, VerifierCommitmentKey> { + class VerificationKey + : public VerificationKey_, VerifierCommitmentKey> { public: VerificationKey(const size_t circuit_size, const size_t num_public_inputs) { @@ -357,60 +192,17 @@ template class UltraRecursiveFlavor_ { * @brief A field element for each entity of the flavor. These entities represent the prover polynomials * evaluated at one point. */ - class AllValues : public AllEntities { + class AllValues : public UltraFlavor::AllEntities { public: - using Base = AllEntities; + using Base = UltraFlavor::AllEntities; using Base::Base; - AllValues(std::array _data_in) { this->_data = _data_in; } }; - /** - * @brief A container for commitment labels. - * @note It's debatable whether this should inherit from AllEntities. since most entries are not strictly - * needed. It has, however, been useful during debugging to have these labels available. - * - */ - class CommitmentLabels : public AllEntities { - public: - CommitmentLabels() - { - this->w_l = "W_L"; - this->w_r = "W_R"; - this->w_o = "W_O"; - this->w_4 = "W_4"; - this->z_perm = "Z_PERM"; - this->z_lookup = "Z_LOOKUP"; - this->sorted_accum = "SORTED_ACCUM"; + using CommitmentLabels = UltraFlavor::CommitmentLabels; - this->q_c = "Q_C"; - this->q_l = "Q_L"; - this->q_r = "Q_R"; - this->q_o = "Q_O"; - this->q_4 = "Q_4"; - this->q_m = "Q_M"; - this->q_arith = "Q_ARITH"; - this->q_delta_range = "Q_SORT"; - this->q_elliptic = "Q_ELLIPTIC"; - this->q_aux = "Q_AUX"; - this->q_lookup = "Q_LOOKUP"; - this->sigma_1 = "SIGMA_1"; - this->sigma_2 = "SIGMA_2"; - this->sigma_3 = "SIGMA_3"; - this->sigma_4 = "SIGMA_4"; - this->id_1 = "ID_1"; - this->id_2 = "ID_2"; - this->id_3 = "ID_3"; - this->id_4 = "ID_4"; - this->table_1 = "TABLE_1"; - this->table_2 = "TABLE_2"; - this->table_3 = "TABLE_3"; - this->table_4 = "TABLE_4"; - this->lagrange_first = "LAGRANGE_FIRST"; - this->lagrange_last = "LAGRANGE_LAST"; - }; - }; + using WitnessCommitments = UltraFlavor::WitnessEntities; - class VerifierCommitments : public AllEntities { + class VerifierCommitments : public UltraFlavor::AllEntities { public: VerifierCommitments(const std::shared_ptr& verification_key, const std::optional& witness_commitments = std::nullopt)