Skip to content

Commit

Permalink
feat: zk sumcheck (#7517)
Browse files Browse the repository at this point in the history
Added ZK Sumcheck that ensures that neither round univariates nor
claimed evaluations leak witness information

ZK Sumcheck is "togglable": only Flavors with (HasZK = true) use it 

Refactored sumcheck tests: now they are typed by the Flavor (Ultra or
UltraWithZK)

Made sumcheck-outline.md consistent with the implementation, expanded
docs in sumcheck.hpp and sumcheck_round.hpp

Note: ultra/mega/... -provers and verifiers using ZK sumcheck will be
added later

Closes AztecProtocol/barretenberg#979
  • Loading branch information
iakovenkos authored Aug 19, 2024
1 parent 8f9dfd9 commit 0e9a530
Show file tree
Hide file tree
Showing 9 changed files with 932 additions and 308 deletions.
2 changes: 1 addition & 1 deletion barretenberg/cpp/docs/Doxyfile
Original file line number Diff line number Diff line change
Expand Up @@ -1355,7 +1355,7 @@ HTML_EXTRA_FILES =
# The default value is: AUTO_LIGHT.
# This tag requires that the tag GENERATE_HTML is set to YES.

HTML_COLORSTYLE = AUTO_LIGHT
HTML_COLORSTYLE = TOGGLE

# The HTML_COLORSTYLE_HUE tag controls the color of the HTML output. Doxygen
# will adjust the colors in the style sheet and background images according to
Expand Down
229 changes: 125 additions & 104 deletions barretenberg/cpp/docs/src/sumcheck-outline.md

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions barretenberg/cpp/src/barretenberg/flavor/flavor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -443,6 +443,8 @@ template <typename T> concept IsFoldingFlavor = IsAnyOf<T, UltraFlavor,
UltraRecursiveFlavor_<CircuitSimulatorBN254>,
MegaRecursiveFlavor_<UltraCircuitBuilder>,
MegaRecursiveFlavor_<MegaCircuitBuilder>, MegaRecursiveFlavor_<CircuitSimulatorBN254>>;
template <typename T>
concept FlavorHasZK = T::HasZK;

template <typename Container, typename Element>
inline std::string flavor_get_label(Container&& container, const Element& element) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ class UltraKeccakFlavor {
using CommitmentKey = bb::CommitmentKey<Curve>;
using VerifierCommitmentKey = bb::VerifierCommitmentKey<Curve>;

// Indicates that this flavor runs with non-ZK Sumcheck.
static constexpr bool HasZK = false;
static constexpr size_t NUM_WIRES = CircuitBuilder::NUM_WIRES;
// The number of multivariate polynomials on which a sumcheck prover sumcheck operates (including shifts). We often
// need containers of this size to hold related data, so we choose a name more agnostic than `NUM_POLYNOMIALS`.
Expand All @@ -42,6 +44,8 @@ class UltraKeccakFlavor {
static constexpr size_t NUM_PRECOMPUTED_ENTITIES = 25;
// The total number of witness entities not including shifts.
static constexpr size_t NUM_WITNESS_ENTITIES = 8;
// The total number of witnesses including shifts and derived entities.
static constexpr size_t NUM_ALL_WITNESS_ENTITIES = 13;
// Total number of folded polynomials, which is just all polynomials except the shifts
static constexpr size_t NUM_FOLDED_ENTITIES = NUM_PRECOMPUTED_ENTITIES + NUM_WITNESS_ENTITIES;

Expand Down
417 changes: 388 additions & 29 deletions barretenberg/cpp/src/barretenberg/sumcheck/sumcheck.hpp

Large diffs are not rendered by default.

400 changes: 250 additions & 150 deletions barretenberg/cpp/src/barretenberg/sumcheck/sumcheck.test.cpp

Large diffs are not rendered by default.

25 changes: 22 additions & 3 deletions barretenberg/cpp/src/barretenberg/sumcheck/sumcheck_output.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#pragma once

#include "barretenberg/flavor/flavor.hpp"
#include <array>
#include <optional>
#include <vector>
Expand All @@ -11,15 +11,34 @@ namespace bb {
* =(u_0,\ldots, u_{d-1})\f$. These are computed by \ref bb::SumcheckProver< Flavor > "Sumcheck Prover" and need to be
* checked using Zeromorph.
*/
template <typename Flavor> struct SumcheckOutput {
template <typename Flavor, typename = void> struct SumcheckOutput {
using FF = typename Flavor::FF;
using ClaimedEvaluations = typename Flavor::AllValues;
// \f$ \vec u = (u_0, ..., u_{d-1}) \f$
std::vector<FF> challenge;
// Evaluations in \f$ \vec u \f$ of the polynomials used in Sumcheck
// Evaluations at \f$ \vec u \f$ of the polynomials used in Sumcheck
ClaimedEvaluations claimed_evaluations;
// Whether or not the evaluations of multilinear polynomials \f$ P_1, \ldots, P_N \f$ and final Sumcheck evaluation
// have been confirmed
std::optional<bool> verified = false; // optional b/c this struct is shared by the Prover/Verifier
};
/**
* @brief A modification of SumcheckOutput required by ZK Flavors where a vector of evaluations of Libra univariates is
* included.
*
* @tparam Flavor
*/
template <typename Flavor> struct SumcheckOutput<Flavor, std::enable_if_t<FlavorHasZK<Flavor>>> {
using FF = typename Flavor::FF;
using ClaimedEvaluations = typename Flavor::AllValues;
// \f$ \vec u = (u_0, ..., u_{d-1}) \f$
std::vector<FF> challenge;
// Evaluations at \f$ \vec u \f$ of the polynomials used in Sumcheck
ClaimedEvaluations claimed_evaluations;
// Include ClaimedLibraEvaluations conditioned on FlavorHasZK concept
std::vector<FF> claimed_libra_evaluations;
// Whether or not the evaluations of multilinear polynomials \f$ P_1, \ldots, P_N \f$ and final Sumcheck evaluation
// have been confirmed
std::optional<bool> verified = false; // Optional b/c this struct is shared by the Prover/Verifier
};
} // namespace bb
108 changes: 87 additions & 21 deletions barretenberg/cpp/src/barretenberg/sumcheck/sumcheck_round.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include "barretenberg/relations/relation_types.hpp"
#include "barretenberg/relations/utils.hpp"
#include "barretenberg/stdlib/primitives/bool/bool.hpp"
#include "zk_sumcheck_data.hpp"

namespace bb {

Expand Down Expand Up @@ -59,9 +60,8 @@ template <typename Flavor> class SumcheckProverRound {
* "MAX_PARTIAL_RELATION_LENGTH + 1".
*/
static constexpr size_t BATCHED_RELATION_PARTIAL_LENGTH = Flavor::BATCHED_RELATION_PARTIAL_LENGTH;

using SumcheckRoundUnivariate = bb::Univariate<FF, BATCHED_RELATION_PARTIAL_LENGTH>;
SumcheckTupleOfTuplesOfUnivariates univariate_accumulators;

// Prover constructor
SumcheckProverRound(size_t initial_round_size)
: round_size(initial_round_size)
Expand All @@ -87,7 +87,9 @@ template <typename Flavor> class SumcheckProverRound {
input in the first round, or from the \ref multivariates table. Using general method
\ref bb::Univariate::extend_to "extend_to", the evaluations of these polynomials are extended from the
domain \f$ \{0,1\} \f$ to the domain \f$ \{0,\ldots, D\} \f$ required for the computation of the round univariate.
* In the case when witness polynomials are masked (ZK Flavors), this method has to distinguish between witness and
* non-witness polynomials. The witness univariates obtained from witness multilinears are corrected by a masking
* quadratic term extended to the same length MAX_PARTIAL_RELATION_LENGTH.
* Should only be called externally with relation_idx equal to 0.
* In practice, #multivariates is either ProverPolynomials or PartiallyEvaluatedMultivariates.
*
Expand All @@ -98,13 +100,33 @@ template <typename Flavor> class SumcheckProverRound {
*/
template <typename ProverPolynomialsOrPartiallyEvaluatedMultivariates>
void extend_edges(ExtendedEdges& extended_edges,
const ProverPolynomialsOrPartiallyEvaluatedMultivariates& multivariates,
size_t edge_idx)
ProverPolynomialsOrPartiallyEvaluatedMultivariates& multivariates,
size_t edge_idx,
std::optional<ZKSumcheckData<Flavor>> zk_sumcheck_data = std::nullopt)
{
for (auto [extended_edge, multivariate] : zip_view(extended_edges.get_all(), multivariates.get_all())) {
bb::Univariate<FF, 2> edge({ multivariate[edge_idx], multivariate[edge_idx + 1] });
extended_edge = edge.template extend_to<MAX_PARTIAL_RELATION_LENGTH>();
}

if constexpr (!Flavor::HasZK) {
for (auto [extended_edge, multivariate] : zip_view(extended_edges.get_all(), multivariates.get_all())) {
bb::Univariate<FF, 2> edge({ multivariate[edge_idx], multivariate[edge_idx + 1] });
extended_edge = edge.template extend_to<MAX_PARTIAL_RELATION_LENGTH>();
}
} else {
// extend edges of witness polynomials and add correcting terms
for (auto [extended_edge, multivariate, masking_univariate] :
zip_view(extended_edges.get_all_witnesses(),
multivariates.get_all_witnesses(),
zk_sumcheck_data.value().masking_terms_evaluations)) {
bb::Univariate<FF, 2> edge({ multivariate[edge_idx], multivariate[edge_idx + 1] });
extended_edge = edge.template extend_to<MAX_PARTIAL_RELATION_LENGTH>();
extended_edge += masking_univariate;
};
// extend edges of public polynomials
for (auto [extended_edge, multivariate] :
zip_view(extended_edges.get_non_witnesses(), multivariates.get_non_witnesses())) {
bb::Univariate<FF, 2> edge({ multivariate[edge_idx], multivariate[edge_idx + 1] });
extended_edge = edge.template extend_to<MAX_PARTIAL_RELATION_LENGTH>();
};
};
}

/**
Expand All @@ -130,11 +152,13 @@ template <typename Flavor> class SumcheckProverRound {
method \ref extend_and_batch_univariates "extend and batch univariates".
*/
template <typename ProverPolynomialsOrPartiallyEvaluatedMultivariates>
bb::Univariate<FF, BATCHED_RELATION_PARTIAL_LENGTH> compute_univariate(
SumcheckRoundUnivariate compute_univariate(
const size_t round_idx,
ProverPolynomialsOrPartiallyEvaluatedMultivariates& polynomials,
const bb::RelationParameters<FF>& relation_parameters,
const bb::PowPolynomial<FF>& pow_polynomial,
const RelationSeparator alpha)
const RelationSeparator alpha,
std::optional<ZKSumcheckData<Flavor>> zk_sumcheck_data = std::nullopt) // only submitted when Flavor HasZK
{
BB_OP_COUNT_TIME();

Expand Down Expand Up @@ -162,8 +186,11 @@ template <typename Flavor> class SumcheckProverRound {
size_t end = (thread_idx + 1) * iterations_per_thread;

for (size_t edge_idx = start; edge_idx < end; edge_idx += 2) {
extend_edges(extended_edges[thread_idx], polynomials, edge_idx);

if constexpr (!Flavor::HasZK) {
extend_edges(extended_edges[thread_idx], polynomials, edge_idx);
} else {
extend_edges(extended_edges[thread_idx], polynomials, edge_idx, zk_sumcheck_data);
}
// Compute the \f$ \ell \f$-th edge's univariate contribution,
// scale it by the corresponding \f$ pow_{\beta} \f$ contribution and add it to the accumulators for \f$
// \tilde{S}^i(X_i) \f$. If \f$ \ell \f$'s binary representation is given by \f$ (\ell_{i+1},\ldots,
Expand All @@ -180,10 +207,19 @@ template <typename Flavor> class SumcheckProverRound {
for (auto& accumulators : thread_univariate_accumulators) {
Utils::add_nested_tuples(univariate_accumulators, accumulators);
}

// For ZK Flavors: The evaluations of the round univariates are masked by the evaluations of Libra univariates
if constexpr (Flavor::HasZK) {
auto libra_round_univariate = compute_libra_round_univariate(zk_sumcheck_data.value(), round_idx);
// Batch the univariate contributions from each sub-relation to obtain the round univariate
auto round_univariate =
batch_over_relations<SumcheckRoundUnivariate>(univariate_accumulators, alpha, pow_polynomial);
// Mask the round univariate
return round_univariate + libra_round_univariate;
}
// Batch the univariate contributions from each sub-relation to obtain the round univariate
return batch_over_relations<bb::Univariate<FF, BATCHED_RELATION_PARTIAL_LENGTH>>(
univariate_accumulators, alpha, pow_polynomial);
else {
return batch_over_relations<SumcheckRoundUnivariate>(univariate_accumulators, alpha, pow_polynomial);
}
}

/**
Expand Down Expand Up @@ -263,6 +299,32 @@ template <typename Flavor> class SumcheckProverRound {
Utils::apply_to_tuple_of_tuples(tuple, extend_and_sum);
}

/**
* @brief Compute Libra round univariate expressed given by the formula
\f{align}{
\texttt{libra_round_univariate}_i(k) =
\rho \cdot 2^{d-1-i} \left(\sum_{j = 0}^{i-1} g_j(u_{j}) + g_{i,k}+
\sum_{j=i+1}^{d-1}\left(g_{j,0}+g_{j,1}\right)\right)
= \texttt{libra_univariates}_{i}(k) + \texttt{libra_running_sum}
\f}.
*
* @param zk_sumcheck_data
* @param round_idx
*/
static SumcheckRoundUnivariate compute_libra_round_univariate(ZKSumcheckData<Flavor> zk_sumcheck_data,
size_t round_idx)
{
SumcheckRoundUnivariate libra_round_univariate;
// select the i'th column of Libra book-keeping table
auto current_column = zk_sumcheck_data.libra_univariates[round_idx];
// the evaluation of Libra round univariate at k=0...D are equal to \f$\texttt{libra_univariates}_{i}(k)\f$
// corrected by the Libra running sum
for (size_t idx = 0; idx < BATCHED_RELATION_PARTIAL_LENGTH; ++idx) {
libra_round_univariate.value_at(idx) = current_column.value_at(idx) + zk_sumcheck_data.libra_running_sum;
};
return libra_round_univariate;
}

private:
/**
* @brief In Round \f$ i \f$, for a given point \f$ \vec \ell \in \{0,1\}^{d-1 - i}\f$, calculate the contribution
Expand Down Expand Up @@ -295,7 +357,6 @@ template <typename Flavor> class SumcheckProverRound {
const FF& scaling_factor)
{
using Relation = std::tuple_element_t<relation_idx, Relations>;

// Check if the relation is skippable to speed up accumulation
if constexpr (!isSkippable<Relation, decltype(extended_edges)>) {
// If not, accumulate normally
Expand All @@ -310,7 +371,6 @@ template <typename Flavor> class SumcheckProverRound {
scaling_factor);
}
}

// Repeat for the next relation.
if constexpr (relation_idx + 1 < NUM_RELATIONS) {
accumulate_relation_univariates<relation_idx + 1>(
Expand Down Expand Up @@ -340,6 +400,7 @@ template <typename Flavor> class SumcheckVerifierRound {
public:
using FF = typename Flavor::FF;
using ClaimedEvaluations = typename Flavor::AllValues;
using ClaimedLibraEvaluations = typename std::vector<FF>;

bool round_failed = false;
/**
Expand All @@ -352,6 +413,7 @@ template <typename Flavor> class SumcheckVerifierRound {
* MAX_PARTIAL_RELATION_LENGTH "MAX_PARTIAL_RELATION_LENGTH + 1".
*/
static constexpr size_t BATCHED_RELATION_PARTIAL_LENGTH = Flavor::BATCHED_RELATION_PARTIAL_LENGTH;
using SumcheckRoundUnivariate = bb::Univariate<FF, BATCHED_RELATION_PARTIAL_LENGTH>;

FF target_total_sum = 0;

Expand All @@ -370,7 +432,7 @@ template <typename Flavor> class SumcheckVerifierRound {
* @param univariate Round univariate \f$\tilde{S}^{i}\f$ represented by its evaluations over \f$0,\ldots,D\f$.
*
*/
bool check_sum(bb::Univariate<FF, BATCHED_RELATION_PARTIAL_LENGTH>& univariate)
bool check_sum(SumcheckRoundUnivariate& univariate)
{
FF total_sum = univariate.value_at(0) + univariate.value_at(1);
// TODO(#673): Conditionals like this can go away once native verification is is just recursive verification
Expand Down Expand Up @@ -437,7 +499,7 @@ template <typename Flavor> class SumcheckVerifierRound {
* @param round_challenge \f$ u_i\f$
* @return FF \f$ \sigma_{i+1} = \tilde{S}^i(u_i)\f$
*/
FF compute_next_target_sum(bb::Univariate<FF, BATCHED_RELATION_PARTIAL_LENGTH>& univariate, FF& round_challenge)
FF compute_next_target_sum(SumcheckRoundUnivariate& univariate, FF& round_challenge)
{
// Evaluate \f$\tilde{S}^{i}(u_{i}) \f$
target_total_sum = univariate.evaluate(round_challenge);
Expand Down Expand Up @@ -473,7 +535,8 @@ template <typename Flavor> class SumcheckVerifierRound {
FF compute_full_honk_relation_purported_value(ClaimedEvaluations purported_evaluations,
const bb::RelationParameters<FF>& relation_parameters,
const bb::PowPolynomial<FF>& pow_polynomial,
const RelationSeparator alpha)
const RelationSeparator alpha,
std::optional<FF> full_libra_purported_value = std::nullopt)
{
// The verifier should never skip computation of contributions from any relation
Utils::template accumulate_relation_evaluations_without_skipping<>(
Expand All @@ -482,6 +545,9 @@ template <typename Flavor> class SumcheckVerifierRound {
FF running_challenge{ 1 };
FF output{ 0 };
Utils::scale_and_batch_elements(relation_evaluations, alpha, running_challenge, output);
if constexpr (Flavor::HasZK) {
output += full_libra_purported_value.value();
};
return output;
}
};
Expand Down
53 changes: 53 additions & 0 deletions barretenberg/cpp/src/barretenberg/sumcheck/zk_sumcheck_data.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
#pragma once

#include <array>
#include <optional>
#include <vector>

namespace bb {

/**
* @brief This structure is created to contain various polynomials and constants required by ZK Sumcheck.
*
*/
template <typename Flavor> struct ZKSumcheckData {
using FF = typename Flavor::FF;
/**
* @brief The total algebraic degree of the Sumcheck relation \f$ F \f$ as a polynomial in Prover Polynomials
* \f$P_1,\ldots, P_N\f$.
*/
static constexpr size_t MAX_PARTIAL_RELATION_LENGTH = Flavor::MAX_PARTIAL_RELATION_LENGTH;
// The number of all witnesses including shifts and derived witnesses from flavors that have ZK,
// otherwise, set this constant to 0.
/**
* @brief The total algebraic degree of the Sumcheck relation \f$ F \f$ as a polynomial in Prover Polynomials
* \f$P_1,\ldots, P_N\f$ <b> incremented by </b> 1, i.e. it is equal \ref MAX_PARTIAL_RELATION_LENGTH
* "MAX_PARTIAL_RELATION_LENGTH + 1".
*/
static constexpr size_t BATCHED_RELATION_PARTIAL_LENGTH = Flavor::BATCHED_RELATION_PARTIAL_LENGTH;
// Initialize the length of the array of evaluation masking scalars as 0 for non-ZK Flavors and as
// NUM_ALL_WITNESS_ENTITIES for ZK FLavors
static constexpr size_t MASKING_SCALARS_LENGTH = Flavor::HasZK ? Flavor::NUM_ALL_WITNESS_ENTITIES : 0;
// Array of random scalars used to hide the witness info from leaking through the claimed evaluations
using EvalMaskingScalars = std::array<FF, MASKING_SCALARS_LENGTH>;
// Auxiliary table that represents the evaluations of quadratic polynomials r_j * X(1-X) at 0,...,
// MAX_PARTIAL_RELATION_LENGTH - 1
using EvaluationMaskingTable = std::array<bb::Univariate<FF, MAX_PARTIAL_RELATION_LENGTH>, MASKING_SCALARS_LENGTH>;
// The size of the LibraUnivariates. We ensure that they do not take extra space when Flavor runs non-ZK
// Sumcheck.
static constexpr size_t LIBRA_UNIVARIATES_LENGTH = Flavor::HasZK ? Flavor::BATCHED_RELATION_PARTIAL_LENGTH : 0;
// Container for the Libra Univariates. Their number depends on the size of the circuit.
using LibraUnivariates = std::vector<bb::Univariate<FF, LIBRA_UNIVARIATES_LENGTH>>;
// Container for the evaluations of Libra Univariates that have to be proven.
using ClaimedLibraEvaluations = std::vector<FF>;

EvalMaskingScalars eval_masking_scalars;
EvaluationMaskingTable masking_terms_evaluations;
LibraUnivariates libra_univariates;
FF libra_scaling_factor{ 1 };
FF libra_challenge;
FF libra_running_sum;
ClaimedLibraEvaluations libra_evaluations;
};

} // namespace bb

0 comments on commit 0e9a530

Please sign in to comment.