From b4ad8f38250d82531439d6db33c8f81387c42496 Mon Sep 17 00:00:00 2001
From: Zachary James Williamson
Date: Fri, 29 Sep 2023 11:58:08 +0100
Subject: [PATCH] feat: consistent pedersen hash (work in progress) (#1945)

This PR implements part of our pedersen refactor project https://hackmd.io/XYBiWhHPT9C1bo4nrtoo0A?view

We introduce a new stdlib class `cycle_group`, which implements a full suite of group operations over a generic SNARK-friendly embedded curve. Of key interest is `cycle_group::batch_mul`, which uses both fixed-base and variable-base multiplication to optimally evaluate in-circuit scalar multiplications.

All external `cycle_group` operations are statistically complete, i.e. the edge cases for point addition on short Weierstrass curves are handled, either explicitly or statistically using 'offset generators' (i.e. when performing a cycle_group computation, precomputed generator points are introduced to prevent intermediate results from triggering addition formula edge cases). This enables us to efficiently represent points at infinity. In the future we can reduce the complexity of our stdlib/recursion implementation by removing the requirement that Prover commitments must not be points at infinity.

Additionally, `pedersen_commitment` and `pedersen_hash` have been refactored according to the project specification, using `cycle_group` methods internally instead of bespoke algorithms that are difficult to reproduce. This PR does not modify existing interfaces or implementations w.r.t. pedersen commitments/hashing. This will come in part 2 of the refactor, as the interface modifications would increase the code surface of an already large PR.

# Checklist: Remove the checklist to signal you've completed it. Enable auto-merge if the PR is ready to merge.
- [x] If the pull request requires a cryptography review (e.g. cryptographic algorithm implementations) I have added the 'crypto' tag.
- [x] I have reviewed my diff in github, line by line and removed unexpected formatting changes, testing logs, or commented-out code.
- [x] Every change is related to the PR description.
- [x] I have [linked](https://docs.github.com/en/issues/tracking-your-work-with-issues/linking-a-pull-request-to-an-issue) this pull request to relevant issues (if any exist).
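For intuition, the 'offset generator' technique described above can be sketched in native code: incomplete short-Weierstrass addition formulas divide by the difference of x-coordinates, so they break when an intermediate sum is the point at infinity or shares an x-coordinate with the next operand; seeding the accumulator with an independent precomputed generator makes such collisions statistically negligible. The sketch below is a minimal illustration using the native grumpkin types that appear elsewhere in this diff — the function name is illustrative and this is not the in-circuit `cycle_group` implementation.

```cpp
#include "barretenberg/ecc/curves/grumpkin/grumpkin.hpp"
#include <vector>

// Sketch: evaluate sum_i scalars[i] * points[i] without the accumulator ever starting at
// the point at infinity. The offset generator is independent of the input points, so the
// probability of an intermediate result triggering an addition edge case is negligible.
grumpkin::g1::element offset_generator_mul_sketch(const std::vector<grumpkin::g1::affine_element>& points,
                                                  const std::vector<grumpkin::fr>& scalars,
                                                  const grumpkin::g1::affine_element& offset_generator)
{
    grumpkin::g1::element accumulator = offset_generator; // never the identity
    for (size_t i = 0; i < points.size(); ++i) {
        accumulator += grumpkin::g1::element(points[i]) * scalars[i];
    }
    // remove the known offset so the result equals the plain linear combination
    return accumulator - grumpkin::g1::element(offset_generator);
}
```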
--------- Co-authored-by: Charlie Lye --- .../benchmark/relations_bench/CMakeLists.txt | 1 + .../generators/fixed_base_scalar_mul.hpp | 3 + .../crypto/generators/generator_data.cpp | 2 + .../crypto/generators/generator_data.hpp | 2 + .../crypto/generators/generator_data.test.cpp | 2 + .../crypto/pedersen_commitment/c_bind.cpp | 2 + .../crypto/pedersen_commitment/c_bind.hpp | 2 + .../crypto/pedersen_commitment/pedersen.cpp | 2 + .../crypto/pedersen_commitment/pedersen.hpp | 2 + .../pedersen_commitment/pedersen_lookup.cpp | 2 + .../pedersen_commitment/pedersen_lookup.hpp | 2 + .../pedersen_lookup.test.cpp | 2 + .../pedersen_commitment/pedersen_refactor.cpp | 51 + .../pedersen_commitment/pedersen_refactor.hpp | 109 ++ .../crypto/pedersen_hash/c_bind.cpp | 2 + .../crypto/pedersen_hash/c_bind.hpp | 2 + .../crypto/pedersen_hash/pedersen.cpp | 2 + .../crypto/pedersen_hash/pedersen.hpp | 3 + .../crypto/pedersen_hash/pedersen_lookup.cpp | 2 + .../crypto/pedersen_hash/pedersen_lookup.hpp | 1 + .../pedersen_hash/pedersen_refactor.cpp | 57 + .../pedersen_hash/pedersen_refactor.hpp | 78 + .../cpp/src/barretenberg/ecc/CMakeLists.txt | 2 +- .../ecc/curves/grumpkin/grumpkin.cpp | 2 + .../ecc/curves/grumpkin/grumpkin.hpp | 2 + .../ecc/curves/secp256k1/secp256k1.cpp | 2 + .../ecc/curves/secp256k1/secp256k1.hpp | 3 + .../ecc/curves/secp256r1/secp256r1.cpp | 2 + .../ecc/curves/secp256r1/secp256r1.hpp | 2 + .../ecc/fields/field_declarations.hpp | 16 + .../ecc/groups/affine_element.hpp | 10 +- .../ecc/groups/affine_element_impl.hpp | 101 +- .../barretenberg/ecc/groups/element_impl.hpp | 5 +- .../cpp/src/barretenberg/ecc/groups/group.hpp | 78 +- .../arithmetization/arithmetization.hpp | 48 +- .../arithmetization/gate_data.hpp | 6 + .../circuit_builder/circuit_builder_base.hpp | 4 + .../circuit_builder/ultra_circuit_builder.cpp | 154 +- .../circuit_builder/ultra_circuit_builder.hpp | 21 +- .../ultra_circuit_builder.test.cpp | 23 + .../plookup_tables/fixed_base/fixed_base.cpp | 275 ++++ .../plookup_tables/fixed_base/fixed_base.hpp | 110 ++ .../fixed_base/fixed_base_params.hpp | 76 + .../proof_system/plookup_tables/pedersen.hpp | 2 + .../plookup_tables/plookup_tables.cpp | 19 +- .../plookup_tables/plookup_tables.hpp | 25 +- .../proof_system/plookup_tables/types.hpp | 13 +- .../stdlib/hash/pedersen/pedersen.test.cpp | 53 + .../hash/pedersen/pedersen_refactor.cpp | 46 + .../hash/pedersen/pedersen_refactor.hpp | 43 + .../stdlib/primitives/field/field.cpp | 11 + .../stdlib/primitives/field/field.hpp | 40 +- .../stdlib/primitives/group/cycle_group.cpp | 1323 +++++++++++++++++ .../stdlib/primitives/group/cycle_group.hpp | 240 +++ .../primitives/group/cycle_group.test.cpp | 566 +++++++ .../stdlib/primitives/group/group.hpp | 1 + .../stdlib/primitives/group/group.test.cpp | 2 + 57 files changed, 3589 insertions(+), 68 deletions(-) create mode 100644 barretenberg/cpp/src/barretenberg/crypto/pedersen_commitment/pedersen_refactor.cpp create mode 100644 barretenberg/cpp/src/barretenberg/crypto/pedersen_commitment/pedersen_refactor.hpp create mode 100644 barretenberg/cpp/src/barretenberg/crypto/pedersen_hash/pedersen_refactor.cpp create mode 100644 barretenberg/cpp/src/barretenberg/crypto/pedersen_hash/pedersen_refactor.hpp create mode 100644 barretenberg/cpp/src/barretenberg/proof_system/plookup_tables/fixed_base/fixed_base.cpp create mode 100644 barretenberg/cpp/src/barretenberg/proof_system/plookup_tables/fixed_base/fixed_base.hpp create mode 100644 
barretenberg/cpp/src/barretenberg/proof_system/plookup_tables/fixed_base/fixed_base_params.hpp create mode 100644 barretenberg/cpp/src/barretenberg/stdlib/hash/pedersen/pedersen.test.cpp create mode 100644 barretenberg/cpp/src/barretenberg/stdlib/hash/pedersen/pedersen_refactor.cpp create mode 100644 barretenberg/cpp/src/barretenberg/stdlib/hash/pedersen/pedersen_refactor.hpp create mode 100644 barretenberg/cpp/src/barretenberg/stdlib/primitives/group/cycle_group.cpp create mode 100644 barretenberg/cpp/src/barretenberg/stdlib/primitives/group/cycle_group.hpp create mode 100644 barretenberg/cpp/src/barretenberg/stdlib/primitives/group/cycle_group.test.cpp diff --git a/barretenberg/cpp/src/barretenberg/benchmark/relations_bench/CMakeLists.txt b/barretenberg/cpp/src/barretenberg/benchmark/relations_bench/CMakeLists.txt index 57f0eea70ee..d6b5a9a0df3 100644 --- a/barretenberg/cpp/src/barretenberg/benchmark/relations_bench/CMakeLists.txt +++ b/barretenberg/cpp/src/barretenberg/benchmark/relations_bench/CMakeLists.txt @@ -7,6 +7,7 @@ relations.bench.cpp # Required libraries for benchmark suites set(LINKED_LIBRARIES polynomials + proof_system benchmark::benchmark ) diff --git a/barretenberg/cpp/src/barretenberg/crypto/generators/fixed_base_scalar_mul.hpp b/barretenberg/cpp/src/barretenberg/crypto/generators/fixed_base_scalar_mul.hpp index 555b8837f18..1fdd669d778 100644 --- a/barretenberg/cpp/src/barretenberg/crypto/generators/fixed_base_scalar_mul.hpp +++ b/barretenberg/cpp/src/barretenberg/crypto/generators/fixed_base_scalar_mul.hpp @@ -1,4 +1,7 @@ #pragma once + +// TODO(@zac-williamson #2341 delete this file once we migrate to new pedersen hash standard) + #include "./generator_data.hpp" #include "barretenberg/ecc/curves/grumpkin/grumpkin.hpp" diff --git a/barretenberg/cpp/src/barretenberg/crypto/generators/generator_data.cpp b/barretenberg/cpp/src/barretenberg/crypto/generators/generator_data.cpp index bcc500653bf..b8910ed897b 100644 --- a/barretenberg/cpp/src/barretenberg/crypto/generators/generator_data.cpp +++ b/barretenberg/cpp/src/barretenberg/crypto/generators/generator_data.cpp @@ -1,5 +1,7 @@ #include "./generator_data.hpp" +// TODO(@zac-williamson #2341 delete this file once we migrate to new pedersen hash standard) + namespace crypto { namespace generators { namespace { diff --git a/barretenberg/cpp/src/barretenberg/crypto/generators/generator_data.hpp b/barretenberg/cpp/src/barretenberg/crypto/generators/generator_data.hpp index 34b1d107df7..999b802ccc4 100644 --- a/barretenberg/cpp/src/barretenberg/crypto/generators/generator_data.hpp +++ b/barretenberg/cpp/src/barretenberg/crypto/generators/generator_data.hpp @@ -1,4 +1,6 @@ #pragma once + +// TODO(@zac-williamson #2341 delete this file once we migrate to new pedersen hash standard) #include "barretenberg/ecc/curves/grumpkin/grumpkin.hpp" #include #include diff --git a/barretenberg/cpp/src/barretenberg/crypto/generators/generator_data.test.cpp b/barretenberg/cpp/src/barretenberg/crypto/generators/generator_data.test.cpp index 8c6257be9f0..45b5b5f461f 100644 --- a/barretenberg/cpp/src/barretenberg/crypto/generators/generator_data.test.cpp +++ b/barretenberg/cpp/src/barretenberg/crypto/generators/generator_data.test.cpp @@ -1,3 +1,5 @@ +// TODO(@zac-williamson #2341 delete this file once we migrate to new pedersen hash standard) + #include "./generator_data.hpp" #include "./fixed_base_scalar_mul.hpp" #include "barretenberg/common/streams.hpp" diff --git a/barretenberg/cpp/src/barretenberg/crypto/pedersen_commitment/c_bind.cpp 
b/barretenberg/cpp/src/barretenberg/crypto/pedersen_commitment/c_bind.cpp index 1a174cb64fd..63b9b4834dd 100644 --- a/barretenberg/cpp/src/barretenberg/crypto/pedersen_commitment/c_bind.cpp +++ b/barretenberg/cpp/src/barretenberg/crypto/pedersen_commitment/c_bind.cpp @@ -1,3 +1,5 @@ +// TODO(@zac-wiliamson #2341 delete this file and rename c_bind_new to c_bind once we have migrated to new hash standard + #include "c_bind.hpp" #include "barretenberg/common/mem.hpp" #include "barretenberg/common/serialize.hpp" diff --git a/barretenberg/cpp/src/barretenberg/crypto/pedersen_commitment/c_bind.hpp b/barretenberg/cpp/src/barretenberg/crypto/pedersen_commitment/c_bind.hpp index 26d5308df70..19b5de4404c 100644 --- a/barretenberg/cpp/src/barretenberg/crypto/pedersen_commitment/c_bind.hpp +++ b/barretenberg/cpp/src/barretenberg/crypto/pedersen_commitment/c_bind.hpp @@ -1,3 +1,5 @@ +// TODO(@zac-wiliamson #2341 delete this file and rename c_bind_new to c_bind once we have migrated to new hash standard + #pragma once #include "barretenberg/common/mem.hpp" #include "barretenberg/common/serialize.hpp" diff --git a/barretenberg/cpp/src/barretenberg/crypto/pedersen_commitment/pedersen.cpp b/barretenberg/cpp/src/barretenberg/crypto/pedersen_commitment/pedersen.cpp index ae410af0197..924b3bb4b08 100644 --- a/barretenberg/cpp/src/barretenberg/crypto/pedersen_commitment/pedersen.cpp +++ b/barretenberg/cpp/src/barretenberg/crypto/pedersen_commitment/pedersen.cpp @@ -1,3 +1,5 @@ +// TODO(@zac-wiliamson #2341 delete this file once we migrate to new hash standard + #include "./pedersen.hpp" #include "./convert_buffer_to_field.hpp" #include "barretenberg/common/throw_or_abort.hpp" diff --git a/barretenberg/cpp/src/barretenberg/crypto/pedersen_commitment/pedersen.hpp b/barretenberg/cpp/src/barretenberg/crypto/pedersen_commitment/pedersen.hpp index 3571016ebd7..82493dedc14 100644 --- a/barretenberg/cpp/src/barretenberg/crypto/pedersen_commitment/pedersen.hpp +++ b/barretenberg/cpp/src/barretenberg/crypto/pedersen_commitment/pedersen.hpp @@ -1,3 +1,5 @@ +// TODO(@zac-wiliamson #2341 delete this file once we migrate to new hash standard + #pragma once #include "../generators/fixed_base_scalar_mul.hpp" #include "../generators/generator_data.hpp" diff --git a/barretenberg/cpp/src/barretenberg/crypto/pedersen_commitment/pedersen_lookup.cpp b/barretenberg/cpp/src/barretenberg/crypto/pedersen_commitment/pedersen_lookup.cpp index 5e6288e8dfa..1310afe8a33 100644 --- a/barretenberg/cpp/src/barretenberg/crypto/pedersen_commitment/pedersen_lookup.cpp +++ b/barretenberg/cpp/src/barretenberg/crypto/pedersen_commitment/pedersen_lookup.cpp @@ -1,3 +1,5 @@ +// TODO(@zac-wiliamson #2341 delete this file once we migrate to new hash standard + #include "./pedersen_lookup.hpp" #include "../pedersen_hash/pedersen_lookup.hpp" #include "./convert_buffer_to_field.hpp" diff --git a/barretenberg/cpp/src/barretenberg/crypto/pedersen_commitment/pedersen_lookup.hpp b/barretenberg/cpp/src/barretenberg/crypto/pedersen_commitment/pedersen_lookup.hpp index b77fac9688d..a0c4c50e02c 100644 --- a/barretenberg/cpp/src/barretenberg/crypto/pedersen_commitment/pedersen_lookup.hpp +++ b/barretenberg/cpp/src/barretenberg/crypto/pedersen_commitment/pedersen_lookup.hpp @@ -1,5 +1,7 @@ #pragma once +// TODO(@zac-wiliamson #2341 delete this file once we migrate to new hash standard + #include "barretenberg/ecc/curves/grumpkin/grumpkin.hpp" namespace crypto { diff --git a/barretenberg/cpp/src/barretenberg/crypto/pedersen_commitment/pedersen_lookup.test.cpp 
b/barretenberg/cpp/src/barretenberg/crypto/pedersen_commitment/pedersen_lookup.test.cpp index 49ca4825ab1..a83f903953d 100644 --- a/barretenberg/cpp/src/barretenberg/crypto/pedersen_commitment/pedersen_lookup.test.cpp +++ b/barretenberg/cpp/src/barretenberg/crypto/pedersen_commitment/pedersen_lookup.test.cpp @@ -1,3 +1,5 @@ +// TODO(@zac-wiliamson #2341 delete this file once we migrate to new hash standard + #include "barretenberg/numeric/bitop/get_msb.hpp" #include "barretenberg/numeric/random/engine.hpp" #include diff --git a/barretenberg/cpp/src/barretenberg/crypto/pedersen_commitment/pedersen_refactor.cpp b/barretenberg/cpp/src/barretenberg/crypto/pedersen_commitment/pedersen_refactor.cpp new file mode 100644 index 00000000000..cd044164241 --- /dev/null +++ b/barretenberg/cpp/src/barretenberg/crypto/pedersen_commitment/pedersen_refactor.cpp @@ -0,0 +1,51 @@ +// TODO(@zac-wiliamson #2341 rename to pedersen.cpp once we migrate to new hash standard) + +#include "./pedersen_refactor.hpp" +#include "./convert_buffer_to_field.hpp" +#include "barretenberg/common/serialize.hpp" +#include "barretenberg/common/throw_or_abort.hpp" +#include +#ifndef NO_OMP_MULTITHREADING +#include +#endif + +namespace crypto { + +/** + * @brief Given a vector of fields, generate a pedersen commitment using the indexed generators. + * + * @details This method uses `Curve::BaseField` members as inputs. This aligns with what we expect when creating + * grumpkin commitments to field elements inside a BN254 SNARK circuit. + * + * @note Fq is the *coordinate field* of Curve. Curve itself is a SNARK-friendly curve, + * i.e. Fq represents the native field type of the SNARK circuit. + * @param inputs + * @param hash_index + * @param generator_context + * @return Curve::AffineElement + */ +template +typename Curve::AffineElement pedersen_commitment_refactor::commit_native( + const std::vector& inputs, const size_t hash_index, const generator_data* const generator_context) +{ + const auto generators = generator_context->conditional_extend(inputs.size() + hash_index); + Element result = Group::point_at_infinity; + + // `Curve::Fq` represents the field that `Curve` is defined over (i.e. x/y coordinate field) and `Curve::Fr` is the + // field whose modulus = the group order of `Curve`. + // The `Curve` we're working over here is a generic SNARK-friendly curve. i.e. the SNARK circuit is defined over a + // field equivalent to `Curve::Fq`. This adds complexity when we wish to commit to SNARK circuit field elements, as + // these are members of `Fq` and *not* `Fr`. We cast to `uint256_t` in order to convert an element of `Fq` into an + // `Fr` element, which is the required type when performing scalar multiplications. + static_assert(Fr::modulus > Fq::modulus, + "pedersen_commitment::commit_native Curve subgroup field is smaller than coordinate field. 
Cannot " + "perform injective conversion"); + for (size_t i = 0; i < inputs.size(); ++i) { + Fr scalar_multiplier(static_cast(inputs[i])); + result += Element(generators.get(i, hash_index)) * scalar_multiplier; + } + return result; +} + +template class pedersen_commitment_refactor; +} // namespace crypto diff --git a/barretenberg/cpp/src/barretenberg/crypto/pedersen_commitment/pedersen_refactor.hpp b/barretenberg/cpp/src/barretenberg/crypto/pedersen_commitment/pedersen_refactor.hpp new file mode 100644 index 00000000000..5fec5e24186 --- /dev/null +++ b/barretenberg/cpp/src/barretenberg/crypto/pedersen_commitment/pedersen_refactor.hpp @@ -0,0 +1,109 @@ +#pragma once + +// TODO(@zac-wiliamson #2341 rename to pedersen.hpp once we migrate to new hash standard) + +#include "../generators/fixed_base_scalar_mul.hpp" +#include "../generators/generator_data.hpp" +#include "barretenberg/ecc/curves/bn254/bn254.hpp" +#include "barretenberg/ecc/curves/grumpkin/grumpkin.hpp" +#include + +namespace crypto { + +/** + * @brief Contains a vector of precomputed generator points. + * Generators are defined via a domain separator. + * Number of generators in generator_data is fixed for a given object instance. + * + * @details generator_data is used to precompute short lists of commonly used generators, + * (e.g. static inline const default_generators = generator_data()). + * If an algorithm requires more than `_size_ generators, + * the `conditional_extend` method can be called to return a new `generator_data` object. + * N.B. we explicitly do not support mutating an existing `generator_data` object to increase the size of + * its `std::vector generators` member variable. + * This is because this class is intended to be used as a `static` member of other classes to provide lists + * of precomputed generators. Mutating static member variables is *not* thread safe! + */ +template class generator_data { + public: + using Group = typename Curve::Group; + using AffineElement = typename Curve::AffineElement; + static inline constexpr size_t DEFAULT_NUM_GENERATORS = 32; + static inline const std::string DEFAULT_DOMAIN_SEPARATOR = "default_domain_separator"; + inline generator_data(const size_t num_generators = DEFAULT_NUM_GENERATORS, + const std::string& domain_separator = DEFAULT_DOMAIN_SEPARATOR) + : _domain_separator(domain_separator) + , _domain_separator_bytes(domain_separator.begin(), domain_separator.end()) + , _size(num_generators){}; + + [[nodiscard]] inline std::string domain_separator() const { return _domain_separator; } + [[nodiscard]] inline size_t size() const { return _size; } + [[nodiscard]] inline AffineElement get(const size_t index, const size_t offset = 0) const + { + ASSERT(index + offset <= _size); + return generators[index + offset]; + } + + /** + * @brief If more generators than `_size` are required, this method will return a new `generator_data` object + * with the required generators. + * + * @note Question: is this a good pattern to support? Ideally downstream code would ensure their + * `generator_data` object is sufficiently large to cover potential needs. + * But if we did not support this pattern, it would make downstream code more complex as each method that + * uses `generator_data` would have to perform this accounting logic. 
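Stepping back to the `commit_native` method defined above, a hypothetical call site might look like the following. The `curve::Grumpkin` instantiation and the default generator context are taken from this diff; the surrounding function is assumed purely for illustration.

```cpp
#include "barretenberg/crypto/pedersen_commitment/pedersen_refactor.hpp"
#include <vector>

grumpkin::g1::affine_element example_commitment()
{
    using Curve = curve::Grumpkin;
    using Fq = Curve::BaseField; // the SNARK circuit's native field, as explained above

    // Commit(x) = x[0]*g[0] + x[1]*g[1] + x[2]*g[2], using the default generator context
    std::vector<Fq> inputs = { Fq(1), Fq(2), Fq(3) };
    return crypto::pedersen_commitment_refactor<Curve>::commit_native(inputs);
}
```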
+ * + * @param target_num_generators + * @return generator_data + */ + [[nodiscard]] inline generator_data conditional_extend(const size_t target_num_generators) const + { + if (target_num_generators <= _size) { + return *this; + } + return { target_num_generators, _domain_separator }; + } + + private: + std::string _domain_separator; + std::vector _domain_separator_bytes; + size_t _size; + // ordering of static variable initialization is undefined, so we make `default_generators` private + // and only accessible via `get_default_generators()`, which ensures var will be initialized at the cost of some + // small runtime checks + inline static const generator_data default_generators = + generator_data(generator_data::DEFAULT_NUM_GENERATORS, generator_data::DEFAULT_DOMAIN_SEPARATOR); + + public: + inline static const generator_data* get_default_generators() { return &default_generators; } + const std::vector generators = (Group::derive_generators_secure(_domain_separator_bytes, _size)); +}; + +template class generator_data; + +/** + * @brief Performs pedersen commitments! + * + * To commit to a size-n list of field elements `x`, a commitment is defined as: + * + * Commit(x) = x[0].g[0] + x[1].g[1] + ... + x[n-1].g[n-1] + * + * Where `g` is a list of generator points defined by `generator_data` + * + */ +template class pedersen_commitment_refactor { + public: + using AffineElement = typename Curve::AffineElement; + using Element = typename Curve::Element; + using Fr = typename Curve::ScalarField; + using Fq = typename Curve::BaseField; + using Group = typename Curve::Group; + + static AffineElement commit_native( + const std::vector& inputs, + size_t hash_index = 0, + const generator_data* generator_context = generator_data::get_default_generators()); +}; + +extern template class pedersen_commitment_refactor; +} // namespace crypto diff --git a/barretenberg/cpp/src/barretenberg/crypto/pedersen_hash/c_bind.cpp b/barretenberg/cpp/src/barretenberg/crypto/pedersen_hash/c_bind.cpp index cf6ae337e76..be902124647 100644 --- a/barretenberg/cpp/src/barretenberg/crypto/pedersen_hash/c_bind.cpp +++ b/barretenberg/cpp/src/barretenberg/crypto/pedersen_hash/c_bind.cpp @@ -1,3 +1,5 @@ +// TODO(@zac-wiliamson #2341 delete this file and rename c_bind_new to c_bind once we have migrated to new hash standard + #include "barretenberg/common/mem.hpp" #include "barretenberg/common/serialize.hpp" #include "barretenberg/common/streams.hpp" diff --git a/barretenberg/cpp/src/barretenberg/crypto/pedersen_hash/c_bind.hpp b/barretenberg/cpp/src/barretenberg/crypto/pedersen_hash/c_bind.hpp index d9b8c8735f9..ca063950401 100644 --- a/barretenberg/cpp/src/barretenberg/crypto/pedersen_hash/c_bind.hpp +++ b/barretenberg/cpp/src/barretenberg/crypto/pedersen_hash/c_bind.hpp @@ -1,4 +1,6 @@ #pragma once +// TODO(@zac-wiliamson #2341 delete this file and rename c_bind_new to c_bind once we have migrated to new hash standard + #include "barretenberg/common/wasm_export.hpp" #include "barretenberg/ecc/curves/bn254/fr.hpp" diff --git a/barretenberg/cpp/src/barretenberg/crypto/pedersen_hash/pedersen.cpp b/barretenberg/cpp/src/barretenberg/crypto/pedersen_hash/pedersen.cpp index 6debd1b9ff3..ca3797cc16d 100644 --- a/barretenberg/cpp/src/barretenberg/crypto/pedersen_hash/pedersen.cpp +++ b/barretenberg/cpp/src/barretenberg/crypto/pedersen_hash/pedersen.cpp @@ -1,3 +1,5 @@ +// TODO(@zac-wiliamson #2341 delete this file once we migrate to new hash standard + #include "./pedersen.hpp" #include #ifndef NO_OMP_MULTITHREADING diff --git 
a/barretenberg/cpp/src/barretenberg/crypto/pedersen_hash/pedersen.hpp b/barretenberg/cpp/src/barretenberg/crypto/pedersen_hash/pedersen.hpp index 40bdfc7ff8d..1cedec07b4a 100644 --- a/barretenberg/cpp/src/barretenberg/crypto/pedersen_hash/pedersen.hpp +++ b/barretenberg/cpp/src/barretenberg/crypto/pedersen_hash/pedersen.hpp @@ -1,4 +1,7 @@ #pragma once + +// TODO(@zac-wiliamson #2341 delete this file once we migrate to new hash standard + #include "../generators/fixed_base_scalar_mul.hpp" #include "../generators/generator_data.hpp" #include "barretenberg/ecc/curves/grumpkin/grumpkin.hpp" diff --git a/barretenberg/cpp/src/barretenberg/crypto/pedersen_hash/pedersen_lookup.cpp b/barretenberg/cpp/src/barretenberg/crypto/pedersen_hash/pedersen_lookup.cpp index 980b41a2259..3c1cc5eb835 100644 --- a/barretenberg/cpp/src/barretenberg/crypto/pedersen_hash/pedersen_lookup.cpp +++ b/barretenberg/cpp/src/barretenberg/crypto/pedersen_hash/pedersen_lookup.cpp @@ -1,3 +1,5 @@ +// TODO(@zac-wiliamson #2341 delete this file once we migrate to new hash standard + #include "./pedersen_lookup.hpp" #include diff --git a/barretenberg/cpp/src/barretenberg/crypto/pedersen_hash/pedersen_lookup.hpp b/barretenberg/cpp/src/barretenberg/crypto/pedersen_hash/pedersen_lookup.hpp index 9a019a8547c..5e390776d90 100644 --- a/barretenberg/cpp/src/barretenberg/crypto/pedersen_hash/pedersen_lookup.hpp +++ b/barretenberg/cpp/src/barretenberg/crypto/pedersen_hash/pedersen_lookup.hpp @@ -1,4 +1,5 @@ #pragma once +// TODO(@zac-wiliamson #2341 delete this file once we migrate to new hash standard #include "barretenberg/ecc/curves/grumpkin/grumpkin.hpp" diff --git a/barretenberg/cpp/src/barretenberg/crypto/pedersen_hash/pedersen_refactor.cpp b/barretenberg/cpp/src/barretenberg/crypto/pedersen_hash/pedersen_refactor.cpp new file mode 100644 index 00000000000..681b12f64e2 --- /dev/null +++ b/barretenberg/cpp/src/barretenberg/crypto/pedersen_hash/pedersen_refactor.cpp @@ -0,0 +1,57 @@ +#include "./pedersen_refactor.hpp" +#include +#ifndef NO_OMP_MULTITHREADING +#include +#endif + +// TODO(@zac-wiliamson #2341 rename to pedersen.cpp once we migrate to new hash standard) + +namespace crypto { + +using namespace generators; + +/** + * Given a vector of fields, generate a pedersen hash using the indexed generators. + */ + +/** + * @brief Given a vector of fields, generate a pedersen hash using generators from `generator_context`. + * + * @details `hash_index` is used to access offset elements of `generator_context` if required. + * e.g. if one desires to compute + * `inputs[0] * [generators[hash_index]] + `inputs[1] * [generators[hash_index + 1]]` + ... etc + * Potentially useful to ensure multiple hashes with the same domain separator cannot collide. + * + * TODO(@suyash67) can we change downstream code so that `hash_index` is no longer required? Now we have a proper + * domain_separator parameter, we no longer need to specify different generator indices to ensure hashes cannot collide. + * @param inputs what are we hashing? + * @param hash_index Describes an offset into the list of generators, if required + * @param generator_context + * @return Fq (i.e. 
SNARK circuit scalar field, when hashing using a curve defined over the SNARK circuit scalar field) + */ +template +typename Curve::BaseField pedersen_hash_refactor::hash_multiple(const std::vector& inputs, + const size_t hash_index, + const generator_data* const generator_context) +{ + const auto generators = generator_context->conditional_extend(inputs.size() + hash_index); + + Element result = get_length_generator() * Fr(inputs.size()); + + for (size_t i = 0; i < inputs.size(); ++i) { + result += generators.get(i, hash_index) * Fr(static_cast(inputs[i])); + } + result = result.normalize(); + return result.x; +} + +template +typename Curve::BaseField pedersen_hash_refactor::hash(const std::vector& inputs, + size_t hash_index, + const generator_data* const generator_context) +{ + return hash_multiple(inputs, hash_index, generator_context); +} + +template class pedersen_hash_refactor; +} // namespace crypto \ No newline at end of file diff --git a/barretenberg/cpp/src/barretenberg/crypto/pedersen_hash/pedersen_refactor.hpp b/barretenberg/cpp/src/barretenberg/crypto/pedersen_hash/pedersen_refactor.hpp new file mode 100644 index 00000000000..abd898cc326 --- /dev/null +++ b/barretenberg/cpp/src/barretenberg/crypto/pedersen_hash/pedersen_refactor.hpp @@ -0,0 +1,78 @@ +#pragma once + +// TODO(@zac-wiliamson #2341 rename to pedersen.hpp once we migrate to new hash standard) + +#include "../generators/fixed_base_scalar_mul.hpp" +#include "../generators/generator_data.hpp" +#include "../pedersen_commitment/pedersen_refactor.hpp" +#include "barretenberg/ecc/curves/grumpkin/grumpkin.hpp" +#include + +namespace crypto { + +/** + * @brief Performs pedersen hashes! + * + * To hash to a size-n list of field elements `x`, we return the X-coordinate of: + * + * Hash(x) = n.[h] + x_0. [g_0] + x_1 . [g_1] +... + x_n . [g_n] + * + * Where `g` is a list of generator points defined by `generator_data` + * And `h` is a unique generator whose domain separator is the string `pedersen_hash_length`. + * + * The addition of `n.[h]` into the hash is to prevent length-extension attacks. + * It also ensures that the hash output is never the point at infinity. + * + * It is neccessary that all generator points are linearly independent of one another, + * so that finding collisions is equivalent to solving the discrete logarithm problem. + * This is ensured via the generator derivation algorithm in `generator_data` + */ +template class pedersen_hash_refactor { + public: + using AffineElement = typename Curve::AffineElement; + using Element = typename Curve::Element; + using Fr = typename Curve::ScalarField; + using Fq = typename Curve::BaseField; + using Group = typename Curve::Group; + using generator_data = typename crypto::generator_data; + + /** + * @brief lhs_generator is an alias for the first element in `default_generators`. + * i.e. the 1st generator point in a size-2 pedersen hash + * + * @details Short story: don't make global static member variables publicly accessible. + * Ordering of global static variable initialization is not defined. + * Consider a scenario where this class has `inline static const AffineElement lhs_generator;` + * If another static variable's init function accesses `pedersen_hash_refactor::lhs_generator`, + * there is a chance that `lhs_generator` is not yet initialized due to undefined init order. + * This creates merry havoc due to assertions triggering during runtime initialization of global statics. + * So...don't do that. Wrap your statics. 
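The 'wrap your statics' advice above is the standard construct-on-first-use idiom: a function-local static is initialized the first time the accessor runs (thread-safe since C++11), so initialization order across translation units stops mattering. A generic sketch with illustrative names, not the barretenberg code itself:

```cpp
#include <string>

struct wrapped_static_example {
    // Risky (what the comment above warns against): a class-scope static whose initializer
    // may run before or after other globals in different translation units.
    //   inline static const std::string domain_separator = make_separator();

    // Safer: expose the value through an accessor; the local static below is guaranteed
    // to be initialized exactly once, on first use.
    static const std::string& domain_separator()
    {
        static const std::string value = make_separator();
        return value;
    }

  private:
    static std::string make_separator() { return "default_domain_separator"; }
};
```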
+ */ + inline static AffineElement get_lhs_generator() { return generator_data::get_default_generators()->get(0); } + /** + * @brief rhs_generator is an alias for the second element in `default_generators`. + * i.e. the 2nd generator point in a size-2 pedersen hash + */ + inline static AffineElement get_rhs_generator() { return generator_data::get_default_generators()->get(1); } + /** + * @brief length_generator is used to ensure pedersen hash is not vulnerable to length-exstension attacks + */ + inline static AffineElement get_length_generator() + { + static const AffineElement length_generator = Group::get_secure_generator_from_index(0, "pedersen_hash_length"); + return length_generator; + } + + // TODO(@suyash67) as part of refactor project, can we remove this and replace with `hash` + // (i.e. simplify the name as we no longer have a need for `hash_single`) + static Fq hash_multiple(const std::vector& inputs, + size_t hash_index = 0, + const generator_data* generator_context = generator_data::get_default_generators()); + + static Fq hash(const std::vector& inputs, + size_t hash_index = 0, + const generator_data* generator_context = generator_data::get_default_generators()); +}; + +extern template class pedersen_hash_refactor; +} // namespace crypto diff --git a/barretenberg/cpp/src/barretenberg/ecc/CMakeLists.txt b/barretenberg/cpp/src/barretenberg/ecc/CMakeLists.txt index 1133de8da90..35e543283ee 100644 --- a/barretenberg/cpp/src/barretenberg/ecc/CMakeLists.txt +++ b/barretenberg/cpp/src/barretenberg/ecc/CMakeLists.txt @@ -1,4 +1,4 @@ -barretenberg_module(ecc numeric crypto_keccak) +barretenberg_module(ecc numeric crypto_keccak crypto_sha256) if(DISABLE_ADX) message(STATUS "Disabling ADX assembly variant.") diff --git a/barretenberg/cpp/src/barretenberg/ecc/curves/grumpkin/grumpkin.cpp b/barretenberg/cpp/src/barretenberg/ecc/curves/grumpkin/grumpkin.cpp index caa7f871fbc..d49057edda5 100644 --- a/barretenberg/cpp/src/barretenberg/ecc/curves/grumpkin/grumpkin.cpp +++ b/barretenberg/cpp/src/barretenberg/ecc/curves/grumpkin/grumpkin.cpp @@ -10,6 +10,8 @@ static std::array generators; static bool init_generators = false; } // namespace +// TODO(@zac-wiliamson #2341 remove this method once we migrate to new hash standard (derive_generators_secure is +// curve-agnostic) g1::affine_element get_generator(const size_t generator_index) { if (!init_generators) { diff --git a/barretenberg/cpp/src/barretenberg/ecc/curves/grumpkin/grumpkin.hpp b/barretenberg/cpp/src/barretenberg/ecc/curves/grumpkin/grumpkin.hpp index 0bad58a8d51..b351be77359 100644 --- a/barretenberg/cpp/src/barretenberg/ecc/curves/grumpkin/grumpkin.hpp +++ b/barretenberg/cpp/src/barretenberg/ecc/curves/grumpkin/grumpkin.hpp @@ -30,6 +30,8 @@ struct GrumpkinG1Params { }; using g1 = barretenberg::group; +// TODO(@zac-wiliamson #2341 remove this method once we migrate to new hash standard (derive_generators_secure is +// curve-agnostic) g1::affine_element get_generator(size_t generator_index); }; // namespace grumpkin diff --git a/barretenberg/cpp/src/barretenberg/ecc/curves/secp256k1/secp256k1.cpp b/barretenberg/cpp/src/barretenberg/ecc/curves/secp256k1/secp256k1.cpp index b199208cec9..6c3f7366c2d 100644 --- a/barretenberg/cpp/src/barretenberg/ecc/curves/secp256k1/secp256k1.cpp +++ b/barretenberg/cpp/src/barretenberg/ecc/curves/secp256k1/secp256k1.cpp @@ -13,6 +13,8 @@ static bool init_generators = false; /* In case where prime bit length is 256, the method produces a generator, but only with one less bit of randomness than the 
maximum possible, as the y coordinate in that case is determined by the x-coordinate. */ +// TODO(@zac-wiliamson #2341 remove this method once we migrate to new hash standard (derive_generators_secure is +// curve-agnostic) g1::affine_element get_generator(const size_t generator_index) { if (!init_generators) { diff --git a/barretenberg/cpp/src/barretenberg/ecc/curves/secp256k1/secp256k1.hpp b/barretenberg/cpp/src/barretenberg/ecc/curves/secp256k1/secp256k1.hpp index a2de49cd4c9..aa245e32f6e 100644 --- a/barretenberg/cpp/src/barretenberg/ecc/curves/secp256k1/secp256k1.hpp +++ b/barretenberg/cpp/src/barretenberg/ecc/curves/secp256k1/secp256k1.hpp @@ -120,6 +120,9 @@ struct Secp256k1G1Params { using g1 = barretenberg:: group, barretenberg::field, Secp256k1G1Params>; + +// TODO(@zac-wiliamson #2341 remove this method once we migrate to new hash standard (derive_generators_secure is +// curve-agnostic) g1::affine_element get_generator(size_t generator_index); } // namespace secp256k1 diff --git a/barretenberg/cpp/src/barretenberg/ecc/curves/secp256r1/secp256r1.cpp b/barretenberg/cpp/src/barretenberg/ecc/curves/secp256r1/secp256r1.cpp index 061bbd2c2fd..f5409d30436 100644 --- a/barretenberg/cpp/src/barretenberg/ecc/curves/secp256r1/secp256r1.cpp +++ b/barretenberg/cpp/src/barretenberg/ecc/curves/secp256r1/secp256r1.cpp @@ -13,6 +13,8 @@ static bool init_generators = false; /* In case where prime bit length is 256, the method produces a generator, but only with one less bit of randomness than the maximum possible, as the y coordinate in that case is determined by the x-coordinate. */ +// TODO(@zac-wiliamson #2341 remove this method once we migrate to new hash standard (derive_generators_secure is +// curve-agnostic) g1::affine_element get_generator(const size_t generator_index) { if (!init_generators) { diff --git a/barretenberg/cpp/src/barretenberg/ecc/curves/secp256r1/secp256r1.hpp b/barretenberg/cpp/src/barretenberg/ecc/curves/secp256r1/secp256r1.hpp index e7bf6422c95..653fa457435 100644 --- a/barretenberg/cpp/src/barretenberg/ecc/curves/secp256r1/secp256r1.hpp +++ b/barretenberg/cpp/src/barretenberg/ecc/curves/secp256r1/secp256r1.hpp @@ -106,6 +106,8 @@ struct Secp256r1G1Params { using g1 = barretenberg:: group, barretenberg::field, Secp256r1G1Params>; +// TODO(@zac-wiliamson #2341 remove this method once we migrate to new hash standard (derive_generators_secure is +// curve-agnostic) g1::affine_element get_generator(size_t generator_index); } // namespace secp256r1 diff --git a/barretenberg/cpp/src/barretenberg/ecc/fields/field_declarations.hpp b/barretenberg/cpp/src/barretenberg/ecc/fields/field_declarations.hpp index 98e47b7f354..4e2b292fa8c 100644 --- a/barretenberg/cpp/src/barretenberg/ecc/fields/field_declarations.hpp +++ b/barretenberg/cpp/src/barretenberg/ecc/fields/field_declarations.hpp @@ -89,6 +89,22 @@ template struct alignas(32) field { constexpr field(const uint64_t a, const uint64_t b, const uint64_t c, const uint64_t d) noexcept : data{ a, b, c, d } {}; + /** + * @brief Convert a 512-bit big integer into a field element. + * + * @details Used for deriving field elements from random values. 
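The bias claim made in the remainder of this doc comment can be made quantitative. Assuming $u$ is uniform on $[0, 2^{512})$ and $p$ is the roughly 254-bit field modulus, a standard bound on the statistical distance of $u \bmod p$ from a uniform field element is

$$\Delta\bigl(u \bmod p,\; \mathrm{Uniform}(\mathbb{F}_p)\bigr) \;\le\; \frac{p}{2^{512}} \;\approx\; 2^{-258},$$

which is why reducing a 512-bit hash output modulo $p$ is treated as unbiased for practical purposes.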
512-bits prevents biased output as 2^512>>modulus + * + */ + constexpr explicit field(const uint512_t& input) noexcept + { + uint256_t value = (input % modulus).lo; + data[0] = value.data[0]; + data[1] = value.data[1]; + data[2] = value.data[2]; + data[3] = value.data[3]; + self_to_montgomery_form(); + } + constexpr explicit operator uint32_t() const { field out = from_montgomery_form(); diff --git a/barretenberg/cpp/src/barretenberg/ecc/groups/affine_element.hpp b/barretenberg/cpp/src/barretenberg/ecc/groups/affine_element.hpp index b7d0a43712e..0c7c33cb482 100644 --- a/barretenberg/cpp/src/barretenberg/ecc/groups/affine_element.hpp +++ b/barretenberg/cpp/src/barretenberg/ecc/groups/affine_element.hpp @@ -6,6 +6,8 @@ #include namespace barretenberg::group_elements { +template +concept SupportsHashToCurve = T::can_hash_to_curve; template class alignas(64) affine_element { public: using in_buf = const uint8_t*; @@ -68,6 +70,8 @@ template class alignas(64) affine_el [[nodiscard]] constexpr bool on_curve() const noexcept; + static constexpr std::optional derive_from_x_coordinate(const Fq& x, bool sign_bit) noexcept; + /** * @brief Samples a random point on the curve. * @@ -75,8 +79,6 @@ template class alignas(64) affine_el */ static affine_element random_element(numeric::random::Engine* engine = nullptr) noexcept; - static std::optional derive_from_x_coordinate(const Fq& x, bool sign_bit) noexcept; - /** * @brief Hash a seed value to curve. * @@ -85,8 +87,8 @@ template class alignas(64) affine_el template > static affine_element hash_to_curve(uint64_t seed) noexcept; - template > - static affine_element hash_to_curve(const std::vector& seed) noexcept; + static affine_element hash_to_curve(const std::vector& seed, uint8_t attempt_count = 0) noexcept + requires SupportsHashToCurve; constexpr bool operator==(const affine_element& other) const noexcept; diff --git a/barretenberg/cpp/src/barretenberg/ecc/groups/affine_element_impl.hpp b/barretenberg/cpp/src/barretenberg/ecc/groups/affine_element_impl.hpp index debf4a41d74..e74938ff495 100644 --- a/barretenberg/cpp/src/barretenberg/ecc/groups/affine_element_impl.hpp +++ b/barretenberg/cpp/src/barretenberg/ecc/groups/affine_element_impl.hpp @@ -200,6 +200,27 @@ constexpr bool affine_element::operator>(const affine_element& other) return false; } +template +constexpr std::optional> affine_element::derive_from_x_coordinate( + const Fq& x, bool sign_bit) noexcept +{ + auto yy = x.sqr() * x + T::b; + if constexpr (T::has_a) { + yy += (x * T::a); + } + auto [found_root, y] = yy.sqrt(); + + if (found_root) { + // This is for determinism; a different sqrt algorithm could give -y instead of y and so this parity check + // allows all algorithms to get the "same" y + if (uint256_t(y).get_bit(0) != sign_bit) { + y = -y; + } + return affine_element(x, y); + } + return std::nullopt; +} + template template affine_element affine_element::hash_to_curve(uint64_t seed) noexcept @@ -218,22 +239,55 @@ affine_element affine_element::hash_to_curve(uint64_t seed bool y_bit = hash.get_bit(255); - Fq x_out = Fq(x_coordinate); - Fq y_out = (x_out.sqr() * x_out + T::b); - if constexpr (T::has_a) { - y_out += (x_out * T::a); + std::optional result = derive_from_x_coordinate(x_coordinate, y_bit); + + if (result.has_value()) { + return result.value(); } + return affine_element(0, 0); +} - // When the sqrt of y_out doesn't exist, return 0. 
- auto [is_quadratic_remainder, y_out_] = y_out.sqrt(); - if (!is_quadratic_remainder) { - return affine_element(Fq::zero(), Fq::zero()); +template +affine_element affine_element::hash_to_curve(const std::vector& seed, + uint8_t attempt_count) noexcept + requires SupportsHashToCurve + +{ + std::vector target_seed(seed); + // expand by 2 bytes to cover incremental hash attempts + const size_t seed_size = seed.size(); + for (size_t i = 0; i < 2; ++i) { + target_seed.push_back(0); } - if (uint256_t(y_out_).get_bit(0) != y_bit) { - y_out_ = -y_out_; + target_seed[seed_size] = attempt_count; + target_seed.back() = 0; + const auto hash_hi = sha256::sha256(target_seed); + target_seed.back() = 1; + const auto hash_lo = sha256::sha256(target_seed); + // custom serialize methods as common/serialize.hpp is not constexpr + // (next PR will make this method constexpr) + const auto read_uint256 = [](const uint8_t* in) { + const auto read_limb = [](const uint8_t* in, uint64_t& out) { + for (size_t i = 0; i < 8; ++i) { + out += static_cast(in[i]) << ((7 - i) * 8); + } + }; + uint256_t out = 0; + read_limb(&in[0], out.data[3]); + read_limb(&in[8], out.data[2]); + read_limb(&in[16], out.data[1]); + read_limb(&in[24], out.data[0]); + return out; + }; + // interpret 64 byte hash output as a uint512_t, reduce to Fq element + //(512 bits of entropy ensures result is not biased as 512 >> Fq::modulus.get_msb()) + Fq x(uint512_t(read_uint256(&hash_lo[0]), read_uint256(&hash_hi[0]))); + bool sign_bit = hash_hi[0] > 127; + std::optional result = derive_from_x_coordinate(x, sign_bit); + if (result.has_value()) { + return result.value(); } - - return affine_element(x_out, y_out_); + return hash_to_curve(seed, attempt_count + 1); } template @@ -243,28 +297,21 @@ affine_element affine_element::random_element(numeric::ran engine = &numeric::random::get_engine(); } - bool found_one = false; - Fq yy; Fq x; Fq y; - while (!found_one) { + while (true) { // Sample a random x-coordinate and check if it satisfies curve equation. x = Fq::random_element(engine); - yy = x.sqr() * x + T::b; - if constexpr (T::has_a) { - yy += (x * T::a); - } - auto [found_root, y1] = yy.sqrt(); - y = y1; - // Negate the y-coordinate based on a randomly sampled bit. - bool random_bit = (engine->get_random_uint8() & 1) != 0; - if (random_bit) { - y = -y; - } + bool sign_bit = (engine->get_random_uint8() & 1) != 0; + + std::optional result = derive_from_x_coordinate(x, sign_bit); - found_one = found_root; + if (result.has_value()) { + return result.value(); + } } + throw_or_abort("affine_element::random_element error"); return affine_element(x, y); } diff --git a/barretenberg/cpp/src/barretenberg/ecc/groups/element_impl.hpp b/barretenberg/cpp/src/barretenberg/ecc/groups/element_impl.hpp index b6483aa5776..2deae832391 100644 --- a/barretenberg/cpp/src/barretenberg/ecc/groups/element_impl.hpp +++ b/barretenberg/cpp/src/barretenberg/ecc/groups/element_impl.hpp @@ -1,5 +1,6 @@ #pragma once #include "barretenberg/ecc/groups/element.hpp" +#include "element.hpp" // NOLINTBEGIN(readability-implicit-bool-conversion, cppcoreguidelines-avoid-c-arrays) namespace barretenberg::group_elements { @@ -601,8 +602,8 @@ element element::mul_without_endomorphism(const Fr& expone element work_element(*this); const uint64_t maximum_set_bit = converted_scalar.get_msb(); - // This is simpler and doublings of infinity should be fast. 
We should think if we want to defend against the timing - // leak here (if used with ECDSA it can sometimes lead to private key compromise) + // This is simpler and doublings of infinity should be fast. We should think if we want to defend against the + // timing leak here (if used with ECDSA it can sometimes lead to private key compromise) for (uint64_t i = maximum_set_bit - 1; i < maximum_set_bit; --i) { work_element.self_dbl(); if (converted_scalar.get_bit(i)) { diff --git a/barretenberg/cpp/src/barretenberg/ecc/groups/group.hpp b/barretenberg/cpp/src/barretenberg/ecc/groups/group.hpp index 6d4fdbc550e..a7eb24b92ea 100644 --- a/barretenberg/cpp/src/barretenberg/ecc/groups/group.hpp +++ b/barretenberg/cpp/src/barretenberg/ecc/groups/group.hpp @@ -4,12 +4,13 @@ #include "./affine_element.hpp" #include "./element.hpp" #include "./wnaf.hpp" +#include "barretenberg/common/constexpr_utils.hpp" +#include "barretenberg/crypto/sha256/sha256.hpp" #include #include #include #include #include - namespace barretenberg { /** @@ -45,6 +46,8 @@ template static inline auto derive_generators() { std::array generators; @@ -62,6 +65,79 @@ template + */ + inline static std::vector derive_generators_secure(const std::vector& domain_separator, + const size_t num_generators, + const size_t starting_index = 0) + { + std::vector result; + std::array domain_hash = sha256::sha256(domain_separator); + std::vector generator_preimage; + generator_preimage.reserve(64); + std::copy(domain_hash.begin(), domain_hash.end(), std::back_inserter(generator_preimage)); + for (size_t i = 0; i < 32; ++i) { + generator_preimage.emplace_back(0); + } + for (size_t i = starting_index; i < starting_index + num_generators; ++i) { + auto generator_index = static_cast(i); + uint32_t mask = 0xff; + generator_preimage[32] = static_cast(generator_index >> 24); + generator_preimage[33] = static_cast((generator_index >> 16) & mask); + generator_preimage[34] = static_cast((generator_index >> 8) & mask); + generator_preimage[35] = static_cast(generator_index & mask); + result.push_back(affine_element::hash_to_curve(generator_preimage)); + } + return result; + } + + inline static affine_element get_secure_generator_from_index(size_t generator_index, + const std::string& domain_separator) + { + std::array domain_hash = sha256::sha256(domain_separator); + std::vector generator_preimage; + generator_preimage.reserve(64); + std::copy(domain_hash.begin(), domain_hash.end(), std::back_inserter(generator_preimage)); + for (size_t i = 0; i < 32; ++i) { + generator_preimage.emplace_back(0); + } + auto gen_idx = static_cast(generator_index); + uint32_t mask = 0xff; + generator_preimage[32] = static_cast(gen_idx >> 24); + generator_preimage[33] = static_cast((gen_idx >> 16) & mask); + generator_preimage[34] = static_cast((gen_idx >> 8) & mask); + generator_preimage[35] = static_cast(gen_idx & mask); + return affine_element::hash_to_curve(generator_preimage); + } + BBERG_INLINE static void conditional_negate_affine(const affine_element* src, affine_element* dest, uint64_t predicate); diff --git a/barretenberg/cpp/src/barretenberg/proof_system/arithmetization/arithmetization.hpp b/barretenberg/cpp/src/barretenberg/proof_system/arithmetization/arithmetization.hpp index 3c0b6ba8769..969aac893fc 100644 --- a/barretenberg/cpp/src/barretenberg/proof_system/arithmetization/arithmetization.hpp +++ b/barretenberg/cpp/src/barretenberg/proof_system/arithmetization/arithmetization.hpp @@ -81,7 +81,51 @@ template class Standard : public Arithmetization class Ultra : 
public Arithmetization { +template class Turbo : public Arithmetization { + public: + using FF = _FF; + struct Selectors : SelectorsBase { + std::vector>& q_m = std::get<0>(this->_data); + std::vector>& q_c = std::get<1>(this->_data); + std::vector>& q_1 = std::get<2>(this->_data); + std::vector>& q_2 = std::get<3>(this->_data); + std::vector>& q_3 = std::get<4>(this->_data); + std::vector>& q_4 = std::get<5>(this->_data); + std::vector>& q_5 = std::get<6>(this->_data); + std::vector>& q_arith = std::get<7>(this->_data); + std::vector>& q_fixed_base = std::get<8>(this->_data); + std::vector>& q_range = std::get<9>(this->_data); + std::vector>& q_logic = std::get<10>(this->_data); + Selectors() + : SelectorsBase(){}; + Selectors(const Selectors& other) + : SelectorsBase(other) + {} + Selectors(Selectors&& other) + { + this->_data = std::move(other._data); + this->q_m = std::get<0>(this->_data); + this->q_c = std::get<1>(this->_data); + this->q_1 = std::get<2>(this->_data); + this->q_2 = std::get<3>(this->_data); + this->q_3 = std::get<4>(this->_data); + this->q_4 = std::get<5>(this->_data); + this->q_5 = std::get<6>(this->_data); + this->q_arith = std::get<7>(this->_data); + this->q_fixed_base = std::get<8>(this->_data); + this->q_range = std::get<9>(this->_data); + this->q_logic = std::get<10>(this->_data); + }; + Selectors& operator=(Selectors&& other) + { + SelectorsBase::operator=(other); + return *this; + } + ~Selectors() = default; + }; +}; + +template class Ultra : public Arithmetization { public: using FF = _FF; struct Selectors : SelectorsBase { @@ -96,6 +140,7 @@ template class Ultra : public Arithmetization>& q_elliptic = std::get<8>(this->_data); std::vector>& q_aux = std::get<9>(this->_data); std::vector>& q_lookup_type = std::get<10>(this->_data); + std::vector>& q_elliptic_double = std::get<11>(this->_data); Selectors() : SelectorsBase(){}; Selectors(const Selectors& other) @@ -115,6 +160,7 @@ template class Ultra : public Arithmetizationq_elliptic = std::get<8>(this->_data); this->q_aux = std::get<9>(this->_data); this->q_lookup_type = std::get<10>(this->_data); + this->q_elliptic_double = std::get<11>(this->_data); }; Selectors& operator=(Selectors&& other) { diff --git a/barretenberg/cpp/src/barretenberg/proof_system/arithmetization/gate_data.hpp b/barretenberg/cpp/src/barretenberg/proof_system/arithmetization/gate_data.hpp index 50686a230ea..80909226332 100644 --- a/barretenberg/cpp/src/barretenberg/proof_system/arithmetization/gate_data.hpp +++ b/barretenberg/cpp/src/barretenberg/proof_system/arithmetization/gate_data.hpp @@ -127,4 +127,10 @@ template struct ecc_add_gate_ { FF endomorphism_coefficient; FF sign_coefficient; }; +template struct ecc_dbl_gate_ { + uint32_t x1; + uint32_t y1; + uint32_t x3; + uint32_t y3; +}; } // namespace proof_system diff --git a/barretenberg/cpp/src/barretenberg/proof_system/circuit_builder/circuit_builder_base.hpp b/barretenberg/cpp/src/barretenberg/proof_system/circuit_builder/circuit_builder_base.hpp index 5422cca43d3..cb226fc5255 100644 --- a/barretenberg/cpp/src/barretenberg/proof_system/circuit_builder/circuit_builder_base.hpp +++ b/barretenberg/cpp/src/barretenberg/proof_system/circuit_builder/circuit_builder_base.hpp @@ -1,6 +1,7 @@ #pragma once #include "barretenberg/common/slab_allocator.hpp" #include "barretenberg/ecc/curves/bn254/fr.hpp" +#include "barretenberg/ecc/curves/grumpkin/grumpkin.hpp" #include "barretenberg/proof_system/arithmetization/arithmetization.hpp" #include 
"barretenberg/proof_system/arithmetization/gate_data.hpp" #include "barretenberg/serialize/cbind.hpp" @@ -14,6 +15,9 @@ static constexpr uint32_t DUMMY_TAG = 0; template class CircuitBuilderBase { public: using FF = typename Arithmetization::FF; + using EmbeddedCurve = + std::conditional_t, curve::BN254, curve::Grumpkin>; + static constexpr size_t NUM_WIRES = Arithmetization::NUM_WIRES; // Keeping NUM_WIRES, at least temporarily, for backward compatibility static constexpr size_t program_width = Arithmetization::NUM_WIRES; diff --git a/barretenberg/cpp/src/barretenberg/proof_system/circuit_builder/ultra_circuit_builder.cpp b/barretenberg/cpp/src/barretenberg/proof_system/circuit_builder/ultra_circuit_builder.cpp index b3289e3ea59..990982aefce 100644 --- a/barretenberg/cpp/src/barretenberg/proof_system/circuit_builder/ultra_circuit_builder.cpp +++ b/barretenberg/cpp/src/barretenberg/proof_system/circuit_builder/ultra_circuit_builder.cpp @@ -78,6 +78,7 @@ template void UltraCircuitBuilder_::add_gates_to_ensure_all_po q_lookup_type.emplace_back(0); q_elliptic.emplace_back(1); q_aux.emplace_back(1); + q_elliptic_double.emplace_back(1); ++this->num_gates; // Some relations depend on wire shifts so we add another gate with @@ -135,6 +136,7 @@ template void UltraCircuitBuilder_::create_add_gate(const add_ q_sort.emplace_back(0); q_lookup_type.emplace_back(0); q_elliptic.emplace_back(0); + q_elliptic_double.emplace_back(0); q_aux.emplace_back(0); ++this->num_gates; } @@ -165,6 +167,7 @@ void UltraCircuitBuilder_::create_big_add_gate(const add_quad_& in, cons q_sort.emplace_back(0); q_lookup_type.emplace_back(0); q_elliptic.emplace_back(0); + q_elliptic_double.emplace_back(0); q_aux.emplace_back(0); ++this->num_gates; } @@ -256,6 +259,7 @@ template void UltraCircuitBuilder_::create_big_mul_gate(const q_sort.emplace_back(0); q_lookup_type.emplace_back(0); q_elliptic.emplace_back(0); + q_elliptic_double.emplace_back(0); q_aux.emplace_back(0); ++this->num_gates; } @@ -280,6 +284,7 @@ template void UltraCircuitBuilder_::create_balanced_add_gate(c q_sort.emplace_back(0); q_lookup_type.emplace_back(0); q_elliptic.emplace_back(0); + q_elliptic_double.emplace_back(0); q_aux.emplace_back(0); ++this->num_gates; // Why 3? 
TODO: return to this @@ -320,6 +325,7 @@ template void UltraCircuitBuilder_::create_mul_gate(const mul_ q_sort.emplace_back(0); q_lookup_type.emplace_back(0); q_elliptic.emplace_back(0); + q_elliptic_double.emplace_back(0); q_aux.emplace_back(0); ++this->num_gates; } @@ -347,6 +353,7 @@ template void UltraCircuitBuilder_::create_bool_gate(const uin q_4.emplace_back(0); q_lookup_type.emplace_back(0); q_elliptic.emplace_back(0); + q_elliptic_double.emplace_back(0); q_aux.emplace_back(0); ++this->num_gates; } @@ -376,6 +383,7 @@ template void UltraCircuitBuilder_::create_poly_gate(const pol q_4.emplace_back(0); q_lookup_type.emplace_back(0); q_elliptic.emplace_back(0); + q_elliptic_double.emplace_back(0); q_aux.emplace_back(0); ++this->num_gates; } @@ -393,11 +401,11 @@ template void UltraCircuitBuilder_::create_poly_gate(const pol template void UltraCircuitBuilder_::create_ecc_add_gate(const ecc_add_gate_& in) { /** + * gate structure: * | 1 | 2 | 3 | 4 | - * | a1 | a2 | x1 | y1 | - * | x2 | y2 | x3 | y3 | - * | -- | -- | x4 | y4 | - * + * | -- | x1 | y1 | -- | + * | x2 | x3 | y3 | y2 | + * we can chain successive ecc_add_gates if x3 y3 of previous gate equals x1 y1 of current gate **/ this->assert_valid_variables({ in.x1, in.x2, in.x3, in.y1, in.y2, in.y3 }); @@ -411,7 +419,6 @@ template void UltraCircuitBuilder_::create_ecc_add_gate(const can_fuse_into_previous_gate = can_fuse_into_previous_gate && (q_arith[this->num_gates - 1] == 0); if (can_fuse_into_previous_gate) { - q_3[this->num_gates - 1] = in.endomorphism_coefficient; q_4[this->num_gates - 1] = in.endomorphism_coefficient.sqr(); q_1[this->num_gates - 1] = in.sign_coefficient; @@ -432,10 +439,10 @@ template void UltraCircuitBuilder_::create_ecc_add_gate(const q_sort.emplace_back(0); q_lookup_type.emplace_back(0); q_elliptic.emplace_back(1); + q_elliptic_double.emplace_back(0); q_aux.emplace_back(0); ++this->num_gates; } - w_l.emplace_back(in.x2); w_4.emplace_back(in.y2); w_r.emplace_back(in.x3); @@ -450,6 +457,68 @@ template void UltraCircuitBuilder_::create_ecc_add_gate(const q_sort.emplace_back(0); q_lookup_type.emplace_back(0); q_elliptic.emplace_back(0); + q_elliptic_double.emplace_back(0); + q_aux.emplace_back(0); + ++this->num_gates; +} + +/** + * @brief Create an elliptic curve doubling gate + * + * + * @param in Elliptic curve point doubling gate parameters + */ +template void UltraCircuitBuilder_::create_ecc_dbl_gate(const ecc_dbl_gate_& in) +{ + /** + * gate structure: + * | 1 | 2 | 3 | 4 | + * | - | x1 | y1 | - | + * | - | x3 | y3 | - | + * we can chain an ecc_add_gate + an ecc_dbl_gate if x3 y3 of previous add_gate equals x1 y1 of current gate + * can also chain double gates together + **/ + bool can_fuse_into_previous_gate = true; + can_fuse_into_previous_gate = can_fuse_into_previous_gate && (w_r[this->num_gates - 1] == in.x1); + can_fuse_into_previous_gate = can_fuse_into_previous_gate && (w_o[this->num_gates - 1] == in.y1); + + if (can_fuse_into_previous_gate) { + q_elliptic_double[this->num_gates - 1] = 1; + } else { + w_r.emplace_back(in.x1); + w_o.emplace_back(in.y1); + w_l.emplace_back(this->zero_idx); + w_4.emplace_back(this->zero_idx); + q_elliptic_double.emplace_back(1); + q_m.emplace_back(0); + q_1.emplace_back(0); + q_2.emplace_back(0); + q_3.emplace_back(0); + q_c.emplace_back(0); + q_arith.emplace_back(0); + q_4.emplace_back(0); + q_sort.emplace_back(0); + q_lookup_type.emplace_back(0); + q_elliptic.emplace_back(0); + q_aux.emplace_back(0); + ++this->num_gates; + } + + w_r.emplace_back(in.x3); + 
w_o.emplace_back(in.y3); + w_l.emplace_back(this->zero_idx); + w_4.emplace_back(this->zero_idx); + q_elliptic_double.emplace_back(0); + q_m.emplace_back(0); + q_1.emplace_back(0); + q_2.emplace_back(0); + q_3.emplace_back(0); + q_c.emplace_back(0); + q_arith.emplace_back(0); + q_4.emplace_back(0); + q_sort.emplace_back(0); + q_lookup_type.emplace_back(0); + q_elliptic.emplace_back(0); q_aux.emplace_back(0); ++this->num_gates; } @@ -478,6 +547,7 @@ template void UltraCircuitBuilder_::fix_witness(const uint32_t q_sort.emplace_back(0); q_lookup_type.emplace_back(0); q_elliptic.emplace_back(0); + q_elliptic_double.emplace_back(0); q_aux.emplace_back(0); ++this->num_gates; } @@ -550,6 +620,7 @@ plookup::ReadData UltraCircuitBuilder_::create_gates_from_plookup_ q_4.emplace_back(0); q_sort.emplace_back(0); q_elliptic.emplace_back(0); + q_elliptic_double.emplace_back(0); q_aux.emplace_back(0); ++this->num_gates; } @@ -859,6 +930,7 @@ void UltraCircuitBuilder_::create_sort_constraint(const std::vector::create_sort_constraint(const std::vector::create_dummy_constraints(const std::vector::create_sort_constraint_with_edges(const std::vect q_4.emplace_back(0); q_sort.emplace_back(1); q_elliptic.emplace_back(0); + q_elliptic_double.emplace_back(0); q_lookup_type.emplace_back(0); q_aux.emplace_back(0); // enforce range check for middle rows @@ -960,6 +1035,7 @@ void UltraCircuitBuilder_::create_sort_constraint_with_edges(const std::vect q_4.emplace_back(0); q_sort.emplace_back(1); q_elliptic.emplace_back(0); + q_elliptic_double.emplace_back(0); q_lookup_type.emplace_back(0); q_aux.emplace_back(0); } @@ -979,6 +1055,7 @@ void UltraCircuitBuilder_::create_sort_constraint_with_edges(const std::vect q_4.emplace_back(0); q_sort.emplace_back(1); q_elliptic.emplace_back(0); + q_elliptic_double.emplace_back(0); q_lookup_type.emplace_back(0); q_aux.emplace_back(0); } @@ -999,6 +1076,7 @@ void UltraCircuitBuilder_::create_sort_constraint_with_edges(const std::vect q_4.emplace_back(0); q_sort.emplace_back(0); q_elliptic.emplace_back(0); + q_elliptic_double.emplace_back(0); q_lookup_type.emplace_back(0); q_aux.emplace_back(0); } @@ -1106,6 +1184,7 @@ template void UltraCircuitBuilder_::apply_aux_selectors(const q_sort.emplace_back(0); q_lookup_type.emplace_back(0); q_elliptic.emplace_back(0); + q_elliptic_double.emplace_back(0); switch (type) { case AUX_SELECTORS::LIMB_ACCUMULATE_1: { q_1.emplace_back(0); @@ -1771,6 +1850,7 @@ std::array UltraCircuitBuilder_::evaluate_non_native_field_addi q_sort.emplace_back(0); q_lookup_type.emplace_back(0); q_elliptic.emplace_back(0); + q_elliptic_double.emplace_back(0); q_aux.emplace_back(0); } @@ -1892,6 +1972,7 @@ std::array UltraCircuitBuilder_::evaluate_non_native_field_subt q_sort.emplace_back(0); q_lookup_type.emplace_back(0); q_elliptic.emplace_back(0); + q_elliptic_double.emplace_back(0); q_aux.emplace_back(0); } @@ -3114,6 +3195,50 @@ inline FF UltraCircuitBuilder_::compute_auxilary_identity(FF q_aux_value, return auxiliary_identity; } +/** + * @brief Compute a single general permutation sorting identity + * + * @param w_1_value + * @param w_2_value + * @param w_3_value + * @param w_4_value + * @param w_1_shifted_value + * @param alpha_base + * @param alpha + * @return fr + */ +template +inline FF UltraCircuitBuilder_::compute_elliptic_double_identity(FF q_elliptic_double_value, + FF w_2_value, + FF w_3_value, + FF w_2_shifted_value, + FF w_3_shifted_value, + FF alpha_base, + FF alpha) const +{ + constexpr FF curve_b = CircuitBuilderBase>::EmbeddedCurve::Group::curve_b; 
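For context on the identity assembled just below: it is the short-Weierstrass doubling law with denominators cleared, assuming $a = 0$ (which the `static_assert` that follows enforces). A sketch of the algebra, matching the in-code comments:

$$\lambda = \frac{3x_1^2}{2y_1}, \qquad x_3 = \lambda^2 - 2x_1, \qquad y_3 = \lambda(x_1 - x_3) - y_1.$$

Multiplying $x_3 + 2x_1 = \lambda^2$ by $(2y_1)^2$ gives $(x_3 + 2x_1)(4y_1^2) - 9x_1^4 = 0$; substituting the curve equation $y_1^2 = x_1^3 + b$ (so that $x_1^4 = x_1(y_1^2 - b)$) reduces this to the degree-3 x-relation

$$(x_3 + 2x_1)(4y_1^2) - 9x_1(y_1^2 - b) = 0.$$

Multiplying $y_3 + y_1 = \lambda(x_1 - x_3)$ by $2y_1$ gives the y-relation

$$3x_1^2(x_1 - x_3) - 2y_1(y_1 + y_3) = 0.$$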
+ static_assert(CircuitBuilderBase>::EmbeddedCurve::Group::curve_a == 0); + const auto x1 = w_2_value; + const auto y1 = w_3_value; + const auto x3 = w_2_shifted_value; + const auto y3 = w_3_shifted_value; + + // x-coordinate relation + // (x3 + 2x1)(4y^2) - (9x^4) = 0 + // This is degree 4...but + // we can use x^3 = y^2 - b + // hon hon hon + // (x3 + 2x1)(4y^2) - (9x(y^2 - b)) is degree 3 + const FF x_pow_4 = (y1 * y1 - curve_b) * x1; + const FF x_relation = (x3 + x1 + x1) * (y1 + y1) * (y1 + y1) - x_pow_4 * FF(9); + + // Y relation: (x1 - x3)(3x^2) - (2y1)(y1 + y3) = 0 + const FF x_pow_2 = (x1 * x1); + const FF y_relation = x_pow_2 * (x1 - x3) * 3 - (y1 + y1) * (y1 + y3); + + return q_elliptic_double_value * alpha_base * (x_relation + y_relation * alpha); +} + /** * @brief Check that the circuit is correct in its current state * @@ -3136,6 +3261,7 @@ template bool UltraCircuitBuilder_::check_circuit() const FF elliptic_base = FF::random_element(); const FF genperm_sort_base = FF::random_element(); const FF auxillary_base = FF::random_element(); + const FF elliptic_double_base = FF::random_element(); const FF alpha = FF::random_element(); const FF eta = FF::random_element(); @@ -3216,6 +3342,7 @@ template bool UltraCircuitBuilder_::check_circuit() FF q_elliptic_value; FF q_sort_value; FF q_lookup_type_value; + FF q_elliptic_double_value; FF q_1_value; FF q_2_value; FF q_3_value; @@ -3233,6 +3360,7 @@ template bool UltraCircuitBuilder_::check_circuit() q_elliptic_value = q_elliptic[i]; q_sort_value = q_sort[i]; q_lookup_type_value = q_lookup_type[i]; + q_elliptic_double_value = q_elliptic_double[i]; q_1_value = q_1[i]; q_2_value = q_2[i]; q_3_value = q_3[i]; @@ -3372,6 +3500,20 @@ template bool UltraCircuitBuilder_::check_circuit() break; } } + if (!compute_elliptic_double_identity(q_elliptic_double_value, + w_2_value, + w_3_value, + w_2_shifted_value, + w_3_shifted_value, + elliptic_double_base, + alpha) + .is_zero()) { +#ifndef FUZZING + info("Elliptic doubling identity fails at gate ", i); +#endif + result = false; + break; + } } if (left_tag_product != right_tag_product) { #ifndef FUZZING diff --git a/barretenberg/cpp/src/barretenberg/proof_system/circuit_builder/ultra_circuit_builder.hpp b/barretenberg/cpp/src/barretenberg/proof_system/circuit_builder/ultra_circuit_builder.hpp index cbd84d1306d..0b6f64e3bf7 100644 --- a/barretenberg/cpp/src/barretenberg/proof_system/circuit_builder/ultra_circuit_builder.hpp +++ b/barretenberg/cpp/src/barretenberg/proof_system/circuit_builder/ultra_circuit_builder.hpp @@ -214,8 +214,10 @@ template class UltraCircuitBuilder_ : public CircuitBuilderBase ultra_selector_names() { - std::vector result{ "q_m", "q_c", "q_1", "q_2", "q_3", "q_4", - "q_arith", "q_sort", "q_elliptic", "q_aux", "table_type" }; + std::vector result{ + "q_m", "q_c", "q_1", "q_2", "q_3", "q_4", + "q_arith", "q_sort", "q_elliptic", "q_aux", "table_type", "q_elliptic_double" + }; return result; } struct non_native_field_multiplication_cross_terms { @@ -264,6 +266,7 @@ template class UltraCircuitBuilder_ : public CircuitBuilderBase tau; @@ -316,6 +319,7 @@ template class UltraCircuitBuilder_ : public CircuitBuilderBase class UltraCircuitBuilder_ : public CircuitBuilderBaseq_elliptic.resize(num_gates); builder->q_aux.resize(num_gates); builder->q_lookup_type.resize(num_gates); + builder->q_elliptic_double.resize(num_gates); } /** * @brief Checks that the circuit state is the same as the stored circuit's one @@ -491,6 +496,9 @@ template class UltraCircuitBuilder_ : public 
CircuitBuilderBase class UltraCircuitBuilder_ : public CircuitBuilderBaseselectors.q_elliptic; SelectorVector& q_aux = this->selectors.q_aux; SelectorVector& q_lookup_type = this->selectors.q_lookup_type; + SelectorVector& q_elliptic_double = this->selectors.q_elliptic_double; // These are variables that we have used a gate on, to enforce that they are // equal to a defined value. @@ -642,6 +651,7 @@ template class UltraCircuitBuilder_ : public CircuitBuilderBase& in) override; void create_ecc_add_gate(const ecc_add_gate_& in); + void create_ecc_dbl_gate(const ecc_dbl_gate_& in); void fix_witness(const uint32_t witness_index, const FF& witness_value); @@ -1036,6 +1046,13 @@ template class UltraCircuitBuilder_ : public CircuitBuilderBase table_raw(MAX_TABLE_SIZE); + + element accumulator = offset_generator; + for (size_t i = 0; i < MAX_TABLE_SIZE; ++i) { + table_raw[i] = accumulator; + accumulator += base_point; + } + element::batch_normalize(&table_raw[0], MAX_TABLE_SIZE); + single_lookup_table table(MAX_TABLE_SIZE); + for (size_t i = 0; i < table_raw.size(); ++i) { + table[i] = affine_element{ table_raw[i].x, table_raw[i].y }; + } + return table; +} + +/** + * @brief For a given base point [P], compute the lookup tables required to traverse a `num_bits` sized lookup + * + * i.e. call `generate_single_lookup_table` for the following base points: + * + * { [P], [P] * (1 << BITS_PER_TABLE), [P] * (1 << BITS_PER_TABLE * 2), ..., [P] * (1 << BITS_PER_TABLE * (NUM_TABLES - + * 1)) } + * + * @tparam num_bits + * @param input + * @return table::fixed_base_scalar_mul_tables + */ +template table::fixed_base_scalar_mul_tables table::generate_tables(const affine_element& input) +{ + constexpr size_t NUM_TABLES = get_num_tables_per_multi_table(); + + fixed_base_scalar_mul_tables result; + result.reserve(NUM_TABLES); + + std::vector input_buf; + serialize::write(input_buf, input); + const auto offset_generators = grumpkin::g1::derive_generators_secure(input_buf, MAX_TABLE_SIZE); + + grumpkin::g1::element accumulator = input; + for (size_t i = 0; i < NUM_TABLES; ++i) { + result.emplace_back(generate_single_lookup_table(accumulator, offset_generators[i])); + for (size_t j = 0; j < BITS_PER_TABLE; ++j) { + accumulator = accumulator.dbl(); + } + } + return result; +} + +/** + * @brief For a fixed-base lookup of size `num_table_bits` and an input base point `input`, + * return the total contrbution in the scalar multiplication output from the offset generators in the lookup + * tables. + * + * @note We need the base point as an input parameter because we derive the offset generator using our hash-to-curve + * algorithm, where the base point is used as the domain separator. Ensures generator points cannot collide with base + * points w/o solving the dlog problem + * @tparam num_table_bits + * @param input + * @return grumpkin::g1::affine_element + */ +template +grumpkin::g1::affine_element table::generate_generator_offset(const grumpkin::g1::affine_element& input) +{ + constexpr size_t NUM_TABLES = get_num_tables_per_multi_table(); + + std::vector input_buf; + serialize::write(input_buf, input); + const auto offset_generators = grumpkin::g1::derive_generators_secure(input_buf, NUM_TABLES); + grumpkin::g1::element acc = grumpkin::g1::point_at_infinity; + for (const auto& gen : offset_generators) { + acc += gen; + } + return acc; +} + +/** + * @brief Given a point, do we have a precomputed lookup table for this point? 
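+ * (In this patch, only the two Pedersen hash base points, native_pedersen::get_lhs_generator() and
+ * native_pedersen::get_rhs_generator(), have precomputed tables.)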
+ * + * @param input + * @return true + * @return false + */ +bool table::lookup_table_exists_for_point(const grumpkin::g1::affine_element& input) +{ + return (input == native_pedersen::get_lhs_generator() || input == native_pedersen::get_rhs_generator()); +} + +/** + * @brief Given a point, return (if it exists) the 2 MultiTableId's that correspond to the LO_SCALAR, HI_SCALAR + * MultiTables + * + * @param input + * @return std::optional> + */ +std::optional> table::get_lookup_table_ids_for_point( + const grumpkin::g1::affine_element& input) +{ + if (input == native_pedersen::get_lhs_generator()) { + return { { FIXED_BASE_LEFT_LO, FIXED_BASE_LEFT_HI } }; + } + if (input == native_pedersen::get_rhs_generator()) { + return { { FIXED_BASE_RIGHT_LO, FIXED_BASE_RIGHT_HI } }; + } + return std::nullopt; +} + +/** + * @brief Given a table id, return the offset generator term that will be present in the final scalar mul output. + * + * Return value is std::optional in case the table_id is not a fixed-base table. + * + * @param table_id + * @return std::optional + */ +std::optional table::get_generator_offset_for_table_id(const MultiTableId table_id) +{ + if (table_id == FIXED_BASE_LEFT_LO) { + return fixed_base_table_offset_generators[0]; + } + if (table_id == FIXED_BASE_LEFT_HI) { + return fixed_base_table_offset_generators[1]; + } + if (table_id == FIXED_BASE_RIGHT_LO) { + return fixed_base_table_offset_generators[2]; + } + if (table_id == FIXED_BASE_RIGHT_HI) { + return fixed_base_table_offset_generators[3]; + } + return std::nullopt; +} + +using function_ptr = std::array (*)(const std::array); +using function_ptr_table = + std::array, table::NUM_FIXED_BASE_MULTI_TABLES>; +/** + * @brief create a compile-time static 2D array of all our required `get_basic_fixed_base_table_values` function + * pointers, so that we can specify the function pointer required for this method call using runtime variables + * `multitable_index`, `table_index`. (downstream code becomes a lot simpler if `table_index` is not compile time, + * particularly the init code in `plookup_tables.cpp`) + * @return constexpr function_ptr_table + */ +constexpr function_ptr_table make_function_pointer_table() +{ + function_ptr_table table; + barretenberg::constexpr_for<0, table::NUM_FIXED_BASE_MULTI_TABLES, 1>([&]() { + barretenberg::constexpr_for<0, table::MAX_NUM_TABLES_IN_MULTITABLE, 1>( + [&]() { table[i][j] = &table::get_basic_fixed_base_table_values; }); + }); + return table; +}; + +/** + * @brief Generate a single fixed-base-scalar-mul plookup table + * + * @tparam multitable_index , which of our 4 multitables is this basic table a part of? + * @param id the BasicTableId + * @param basic_table_index plookup table index + * @param table_index This index describes which bit-slice the basic table corresponds to. i.e. table_index = 0 maps to + * the least significant bit slice + * @return BasicTable + */ +template +BasicTable table::generate_basic_fixed_base_table(BasicTableId id, size_t basic_table_index, size_t table_index) +{ + static_assert(multitable_index < NUM_FIXED_BASE_MULTI_TABLES); + ASSERT(table_index < MAX_NUM_TABLES_IN_MULTITABLE); + + const size_t multitable_bits = get_num_bits_of_multi_table(multitable_index); + const size_t bits_covered_by_previous_tables_in_multitable = BITS_PER_TABLE * table_index; + const bool is_small_table = (multitable_bits - bits_covered_by_previous_tables_in_multitable) < BITS_PER_TABLE; + const size_t table_bits = + is_small_table ? 
multitable_bits - bits_covered_by_previous_tables_in_multitable : BITS_PER_TABLE; + const auto table_size = static_cast(1ULL << table_bits); + BasicTable table; + table.id = id; + table.table_index = basic_table_index; + table.size = table_size; + table.use_twin_keys = false; + + const auto& basic_table = fixed_base_tables[multitable_index][table_index]; + + for (size_t i = 0; i < table.size; ++i) { + table.column_1.emplace_back(i); + table.column_2.emplace_back(basic_table[i].x); + table.column_3.emplace_back(basic_table[i].y); + } + table.get_values_from_key = nullptr; + + constexpr function_ptr_table get_values_from_key_table = make_function_pointer_table(); + table.get_values_from_key = get_values_from_key_table[multitable_index][table_index]; + + ASSERT(table.get_values_from_key != nullptr); + table.column_1_step_size = table.size; + table.column_2_step_size = 0; + table.column_3_step_size = 0; + + return table; +} + +/** + * @brief Generate a multi-table that describes the lookups required to cover a fixed-base-scalar-mul of `num_bits` + * + * @tparam multitable_index , which one of our 4 multitables are we generating? + * @tparam num_bits , this will be either `BITS_PER_LO_SCALAR` or `BITS_PER_HI_SCALAR` + * @param id + * @return MultiTable + */ +template MultiTable table::get_fixed_base_table(const MultiTableId id) +{ + static_assert(num_bits == BITS_PER_LO_SCALAR || num_bits == BITS_PER_HI_SCALAR); + constexpr size_t NUM_TABLES = get_num_tables_per_multi_table(); + constexpr std::array basic_table_ids{ + FIXED_BASE_0_0, + FIXED_BASE_1_0, + FIXED_BASE_2_0, + FIXED_BASE_3_0, + }; + constexpr function_ptr_table get_values_from_key_table = make_function_pointer_table(); + + MultiTable table(MAX_TABLE_SIZE, 0, 0, NUM_TABLES); + table.id = id; + table.get_table_values.resize(NUM_TABLES); + table.lookup_ids.resize(NUM_TABLES); + for (size_t i = 0; i < NUM_TABLES; ++i) { + table.slice_sizes.emplace_back(MAX_TABLE_SIZE); + table.get_table_values[i] = get_values_from_key_table[multitable_index][i]; + static_assert(multitable_index < NUM_FIXED_BASE_MULTI_TABLES); + size_t idx = i + static_cast(basic_table_ids[multitable_index]); + table.lookup_ids[i] = static_cast(idx); + } + return table; +} + +template grumpkin::g1::affine_element table::generate_generator_offset( + const grumpkin::g1::affine_element& input); +template grumpkin::g1::affine_element table::generate_generator_offset( + const grumpkin::g1::affine_element& input); +template table::fixed_base_scalar_mul_tables table::generate_tables( + const table::affine_element& input); +template table::fixed_base_scalar_mul_tables table::generate_tables( + const table::affine_element& input); + +template BasicTable table::generate_basic_fixed_base_table<0>(BasicTableId, size_t, size_t); +template BasicTable table::generate_basic_fixed_base_table<1>(BasicTableId, size_t, size_t); +template BasicTable table::generate_basic_fixed_base_table<2>(BasicTableId, size_t, size_t); +template BasicTable table::generate_basic_fixed_base_table<3>(BasicTableId, size_t, size_t); +template MultiTable table::get_fixed_base_table<0, table::BITS_PER_LO_SCALAR>(MultiTableId); +template MultiTable table::get_fixed_base_table<1, table::BITS_PER_HI_SCALAR>(MultiTableId); +template MultiTable table::get_fixed_base_table<2, table::BITS_PER_LO_SCALAR>(MultiTableId); +template MultiTable table::get_fixed_base_table<3, table::BITS_PER_HI_SCALAR>(MultiTableId); + +} // namespace plookup::fixed_base \ No newline at end of file diff --git 
a/barretenberg/cpp/src/barretenberg/proof_system/plookup_tables/fixed_base/fixed_base.hpp b/barretenberg/cpp/src/barretenberg/proof_system/plookup_tables/fixed_base/fixed_base.hpp new file mode 100644 index 00000000000..010f222074d --- /dev/null +++ b/barretenberg/cpp/src/barretenberg/proof_system/plookup_tables/fixed_base/fixed_base.hpp @@ -0,0 +1,110 @@ +#pragma once + +#include "../types.hpp" +#include "./fixed_base_params.hpp" +#include "barretenberg/crypto/pedersen_hash/pedersen_refactor.hpp" +#include "barretenberg/ecc/curves/grumpkin/grumpkin.hpp" + +namespace plookup::fixed_base { + +/** + * @brief Generates plookup tables required to perform fixed-base scalar multiplication over a fixed number of points. + * + */ +class table : public FixedBaseParams { + public: + using affine_element = grumpkin::g1::affine_element; + using element = grumpkin::g1::element; + using single_lookup_table = std::vector; + using fixed_base_scalar_mul_tables = std::vector; + using all_multi_tables = std::array; + using native_pedersen = crypto::pedersen_hash_refactor; + + static inline single_lookup_table generate_single_lookup_table(const affine_element& base_point, + const affine_element& offset_generator); + template static fixed_base_scalar_mul_tables generate_tables(const affine_element& input); + + template static affine_element generate_generator_offset(const affine_element& input); + + static constexpr uint256_t MAX_LO_SCALAR = uint256_t(1) << BITS_PER_LO_SCALAR; + // We split each scalar mulitplier into BITS_PER_LO_SCALAR, BITS_PER_HI_SCALAR chunks and perform 2 scalar muls of + // size BITS_PER_LO_SCALAR, BITS_PER_HI_SCALAR (see fixed_base_params.hpp for more details) + // i.e. we treat 1 scalar mul as two independent scalar muls over (roughly) half-width input scalars. + // The base_point members describe the fixed-base points that correspond to the two independent scalar muls, + // for our two supported points + inline static const affine_element lhs_base_point_lo = native_pedersen::get_lhs_generator(); + inline static const affine_element lhs_base_point_hi = element(lhs_base_point_lo) * MAX_LO_SCALAR; + inline static const affine_element rhs_base_point_lo = native_pedersen::get_rhs_generator(); + inline static const affine_element rhs_base_point_hi = element(rhs_base_point_lo) * MAX_LO_SCALAR; + + // fixed_base_tables = lookup tables of precomputed base points required for our lookup arguments. + // N.B. these "tables" are not plookup tables, just regular ol' software lookup tables. + // Used to build the proper plookup table and in the `BasicTable::get_values_from_key` method + inline static const all_multi_tables fixed_base_tables = { + table::generate_tables(lhs_base_point_lo), + table::generate_tables(lhs_base_point_hi), + table::generate_tables(rhs_base_point_lo), + table::generate_tables(rhs_base_point_hi), + }; + + /** + * @brief offset generators! + * + * We add a unique "offset generator" into each lookup table to ensure that we never trigger + * incomplete addition formulae for short Weierstrass curves. + * The offset generators are linearly independent from the fixed-base points we're multiplying, ensuring that a + * collision is as likely as solving the discrete logarithm problem. + * For example, imagine a 2-bit lookup table of a point [P]. The table would normally contain { + * 0.[P], 1.[P], 2.[P], 3.[P]}. But, we dont want to have to handle points at infinity and we also don't want to + * deal with windowed-non-adjacent-form complexities. 
Instead, we derive offset generator [G] and make the table + * equal to { [G] + 0.[P], [G] + 1.[P], [G] + 2.[P], [G] + 3.[P]}. Each table uses a unique offset generator to + * prevent collisions. + * The final scalar multiplication output will have a precisely-known contribution from the offset generators, + * which can then be subtracted off with a single point subtraction. + **/ + inline static const std::array + fixed_base_table_offset_generators = { + table::generate_generator_offset(lhs_base_point_lo), + table::generate_generator_offset(lhs_base_point_hi), + table::generate_generator_offset(rhs_base_point_lo), + table::generate_generator_offset(rhs_base_point_hi), + }; + + static bool lookup_table_exists_for_point(const affine_element& input); + static std::optional> get_lookup_table_ids_for_point(const affine_element& input); + static std::optional get_generator_offset_for_table_id(MultiTableId table_id); + + template + static BasicTable generate_basic_fixed_base_table(BasicTableId id, size_t basic_table_index, size_t table_index); + template static MultiTable get_fixed_base_table(MultiTableId id); + + template + static std::array get_basic_fixed_base_table_values(const std::array key) + { + static_assert(multitable_index < NUM_FIXED_BASE_MULTI_TABLES); + static_assert(table_index < get_num_bits_of_multi_table(multitable_index)); + const auto& basic_table = fixed_base_tables[multitable_index][table_index]; + const auto index = static_cast(key[0]); + return { basic_table[index].x, basic_table[index].y }; + } +}; + +extern template table::affine_element table::generate_generator_offset( + const table::affine_element&); +extern template table::affine_element table::generate_generator_offset( + const table::affine_element&); +extern template table::fixed_base_scalar_mul_tables table::generate_tables( + const table::affine_element&); +extern template table::fixed_base_scalar_mul_tables table::generate_tables( + const table::affine_element&); + +extern template BasicTable table::generate_basic_fixed_base_table<0>(BasicTableId, size_t, size_t); +extern template BasicTable table::generate_basic_fixed_base_table<1>(BasicTableId, size_t, size_t); +extern template BasicTable table::generate_basic_fixed_base_table<2>(BasicTableId, size_t, size_t); +extern template BasicTable table::generate_basic_fixed_base_table<3>(BasicTableId, size_t, size_t); +extern template MultiTable table::get_fixed_base_table<0, table::BITS_PER_LO_SCALAR>(MultiTableId); +extern template MultiTable table::get_fixed_base_table<1, table::BITS_PER_HI_SCALAR>(MultiTableId); +extern template MultiTable table::get_fixed_base_table<2, table::BITS_PER_LO_SCALAR>(MultiTableId); +extern template MultiTable table::get_fixed_base_table<3, table::BITS_PER_HI_SCALAR>(MultiTableId); + +} // namespace plookup::fixed_base \ No newline at end of file diff --git a/barretenberg/cpp/src/barretenberg/proof_system/plookup_tables/fixed_base/fixed_base_params.hpp b/barretenberg/cpp/src/barretenberg/proof_system/plookup_tables/fixed_base/fixed_base_params.hpp new file mode 100644 index 00000000000..2f69820961e --- /dev/null +++ b/barretenberg/cpp/src/barretenberg/proof_system/plookup_tables/fixed_base/fixed_base_params.hpp @@ -0,0 +1,76 @@ +#pragma once + +#include "barretenberg/plonk/proof_system/constants.hpp" +#include +#include +#include +#include + +namespace plookup { +/** + * @brief Parameters definitions for our fixed-base-scalar-multiplication lookup tables + * + */ +struct FixedBaseParams { + static constexpr size_t BITS_PER_TABLE = 9; 
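+ // Worked example with the constants defined below: a BITS_PER_LO_SCALAR = 128 bit slice needs
+ // ceil(128 / 9) = 15 basic tables, a BITS_PER_HI_SCALAR = 254 - 128 = 126 bit slice needs
+ // ceil(126 / 9) = 14, and each basic table holds at most MAX_TABLE_SIZE = 2^9 = 512 precomputed points.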
+ static constexpr size_t BITS_ON_CURVE = 254; + + // We split 1 254-bit scalar mul into two scalar muls of size BITS_PER_LO_SCALAR, BITS_PER_HI_SCALAR. + // This enables us to efficiently decompose our input scalar multiplier into two chunks of a known size. + // (i.e. we get free BITS_PER_LO_SCALAR, BITS_PER_HI_SCALAR range checks as part of the lookup table subroutine) + // This in turn allows us to perform a primality test more efficiently. + // i.e. check that input scalar < prime modulus when evaluated over the integers + // (the primality check requires us to split the input into high / low bit chunks so getting this for free as part + // of the lookup algorithm is nice!) + static constexpr size_t BITS_PER_LO_SCALAR = 128; + static constexpr size_t BITS_PER_HI_SCALAR = BITS_ON_CURVE - BITS_PER_LO_SCALAR; + // max table size because the last lookup table might be smaller (BITS_PER_TABLE does not neatly divide + // BITS_PER_LO_SCALAR) + static constexpr size_t MAX_TABLE_SIZE = (1UL) << BITS_PER_TABLE; + // how many BITS_PER_TABLE lookup tables do we need to traverse BITS_PER_LO_SCALAR-amount of bits? + // (we implicitly assume BITS_PER_LO_SCALAR > BITS_PER_HI_SCALAR) + static constexpr size_t MAX_NUM_TABLES_IN_MULTITABLE = + (BITS_PER_LO_SCALAR / BITS_PER_TABLE) + (BITS_PER_LO_SCALAR % BITS_PER_TABLE == 0 ? 0 : 1); + static constexpr size_t NUM_POINTS = 2; + // how many multitables are we creating? It's 4 because we want enough lookup tables to cover two field elements, + // two field elements = 2 scalar muls = 4 scalar mul hi/lo slices = 4 multitables + static constexpr size_t NUM_FIXED_BASE_MULTI_TABLES = NUM_POINTS * 2; + static constexpr size_t NUM_TABLES_PER_LO_MULTITABLE = + (BITS_PER_LO_SCALAR / BITS_PER_TABLE) + ((BITS_PER_LO_SCALAR % BITS_PER_TABLE == 0) ? 0 : 1); + static constexpr size_t NUM_TABLES_PER_HI_MULTITABLE = + (BITS_PER_LO_SCALAR / BITS_PER_TABLE) + ((BITS_PER_LO_SCALAR % BITS_PER_TABLE == 0) ? 0 : 1); + // how many lookups are required to perform a scalar mul of a field element with a base point? + static constexpr size_t NUM_BASIC_TABLES_PER_BASE_POINT = + (NUM_TABLES_PER_LO_MULTITABLE + NUM_TABLES_PER_HI_MULTITABLE); + // how many basic lookup tables are we creating in total to support fixed-base-scalar-muls over two precomputed base + // points. + static constexpr size_t NUM_FIXED_BASE_BASIC_TABLES = NUM_BASIC_TABLES_PER_BASE_POINT * NUM_POINTS; + + /** + * @brief For a scalar multiplication table that covers input scalars up to `(1 << num_bits) - 1`, + * how many individual lookup tables of max size BITS_PER_TABLE do we need? + * (e.g. if BITS_PER_TABLE = 9, for `num_bits = 126` it's 14. For `num_bits = 128` it's 15) + * @tparam num_bits + * @return constexpr size_t + */ + template inline static constexpr size_t get_num_tables_per_multi_table() noexcept + { + return (num_bits / BITS_PER_TABLE) + ((num_bits % BITS_PER_TABLE == 0) ? 0 : 1); + } + + /** + * @brief For a given multitable index, how many scalar mul bits are we traversing with our multitable? 
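+ * e.g. the four multitables cover { 128, 126, 128, 126 } bits: the lo/hi slices of the scalar applied to the
+ * left base point, followed by the lo/hi slices of the scalar applied to the right base point.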
+ * + * @param multitable_index Ranges from 0 to NUM_FIXED_BASE_MULTI_TABLES - 1 + * @return constexpr size_t + */ + static constexpr size_t get_num_bits_of_multi_table(const size_t multitable_index) + { + ASSERT(multitable_index < NUM_FIXED_BASE_MULTI_TABLES); + constexpr std::array MULTI_TABLE_BIT_LENGTHS{ + BITS_PER_LO_SCALAR, BITS_PER_HI_SCALAR, BITS_PER_LO_SCALAR, BITS_PER_HI_SCALAR + }; + return MULTI_TABLE_BIT_LENGTHS[multitable_index]; + } +}; +} // namespace plookup \ No newline at end of file diff --git a/barretenberg/cpp/src/barretenberg/proof_system/plookup_tables/pedersen.hpp b/barretenberg/cpp/src/barretenberg/proof_system/plookup_tables/pedersen.hpp index e9f2e5490c2..0bafaead4d0 100644 --- a/barretenberg/cpp/src/barretenberg/proof_system/plookup_tables/pedersen.hpp +++ b/barretenberg/cpp/src/barretenberg/proof_system/plookup_tables/pedersen.hpp @@ -1,5 +1,7 @@ #pragma once +// TODO(@zac-wiliamson #2341 delete this file once we migrate to new hash standard + #include "./types.hpp" #include "barretenberg/crypto/pedersen_hash/pedersen_lookup.hpp" diff --git a/barretenberg/cpp/src/barretenberg/proof_system/plookup_tables/plookup_tables.cpp b/barretenberg/cpp/src/barretenberg/proof_system/plookup_tables/plookup_tables.cpp index 6580a47e8e4..9bd534982b6 100644 --- a/barretenberg/cpp/src/barretenberg/proof_system/plookup_tables/plookup_tables.cpp +++ b/barretenberg/cpp/src/barretenberg/proof_system/plookup_tables/plookup_tables.cpp @@ -6,8 +6,11 @@ namespace plookup { using namespace barretenberg; namespace { -static std::array MULTI_TABLES; -static bool inited = false; +// TODO(@zac-williamson) convert these into static const members of a struct +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +std::array MULTI_TABLES; +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +bool inited = false; void init_multi_tables() { @@ -91,9 +94,17 @@ void init_multi_tables() keccak_tables::Chi::get_chi_output_table(MultiTableId::KECCAK_CHI_OUTPUT); MULTI_TABLES[MultiTableId::KECCAK_FORMAT_OUTPUT] = keccak_tables::KeccakOutput::get_keccak_output_table(MultiTableId::KECCAK_FORMAT_OUTPUT); + MULTI_TABLES[MultiTableId::FIXED_BASE_LEFT_LO] = + fixed_base::table::get_fixed_base_table<0, 128>(MultiTableId::FIXED_BASE_LEFT_LO); + MULTI_TABLES[MultiTableId::FIXED_BASE_LEFT_HI] = + fixed_base::table::get_fixed_base_table<1, 126>(MultiTableId::FIXED_BASE_LEFT_HI); + MULTI_TABLES[MultiTableId::FIXED_BASE_RIGHT_LO] = + fixed_base::table::get_fixed_base_table<2, 128>(MultiTableId::FIXED_BASE_RIGHT_LO); + MULTI_TABLES[MultiTableId::FIXED_BASE_RIGHT_HI] = + fixed_base::table::get_fixed_base_table<3, 126>(MultiTableId::FIXED_BASE_RIGHT_HI); barretenberg::constexpr_for<0, 25, 1>([&]() { - MULTI_TABLES[(size_t)MultiTableId::KECCAK_NORMALIZE_AND_ROTATE + i] = + MULTI_TABLES[static_cast(MultiTableId::KECCAK_NORMALIZE_AND_ROTATE) + i] = keccak_tables::Rho<8, i>::get_rho_output_table(MultiTableId::KECCAK_NORMALIZE_AND_ROTATE); }); MULTI_TABLES[MultiTableId::HONK_DUMMY_MULTI] = dummy_tables::get_honk_dummy_multitable(); @@ -119,7 +130,6 @@ ReadData get_lookup_accumulators(const MultiTableId id, const size_t num_lookups = multi_table.lookup_ids.size(); ReadData lookup; - const auto key_a_slices = numeric::slice_input_using_variable_bases(key_a, multi_table.slice_sizes); const auto key_b_slices = numeric::slice_input_using_variable_bases(key_b, multi_table.slice_sizes); @@ -139,6 +149,7 @@ ReadData get_lookup_accumulators(const MultiTableId id, const BasicTable::KeyEntry 
key_entry{ { key_a_slices[i], key_b_slices[i] }, values }; lookup.key_entries.emplace_back(key_entry); } + lookup[ColumnIdx::C1].resize(num_lookups); lookup[ColumnIdx::C2].resize(num_lookups); lookup[ColumnIdx::C3].resize(num_lookups); diff --git a/barretenberg/cpp/src/barretenberg/proof_system/plookup_tables/plookup_tables.hpp b/barretenberg/cpp/src/barretenberg/proof_system/plookup_tables/plookup_tables.hpp index d915409550a..145959017e3 100644 --- a/barretenberg/cpp/src/barretenberg/proof_system/plookup_tables/plookup_tables.hpp +++ b/barretenberg/cpp/src/barretenberg/proof_system/plookup_tables/plookup_tables.hpp @@ -1,6 +1,7 @@ #pragma once #include "barretenberg/common/throw_or_abort.hpp" +#include "./fixed_base/fixed_base.hpp" #include "aes128.hpp" #include "blake2s.hpp" #include "dummy.hpp" @@ -18,15 +19,33 @@ namespace plookup { -const MultiTable& create_table(const MultiTableId id); +const MultiTable& create_table(MultiTableId id); -ReadData get_lookup_accumulators(const MultiTableId id, +ReadData get_lookup_accumulators(MultiTableId id, const barretenberg::fr& key_a, const barretenberg::fr& key_b = 0, - const bool is_2_to_1_map = false); + bool is_2_to_1_lookup = false); inline BasicTable create_basic_table(const BasicTableId id, const size_t index) { + // we have >50 basic fixed base tables so we match with some logic instead of a switch statement + auto id_var = static_cast(id); + if (id_var >= static_cast(FIXED_BASE_0_0) && id_var < static_cast(FIXED_BASE_1_0)) { + return fixed_base::table::generate_basic_fixed_base_table<0>( + id, index, id_var - static_cast(FIXED_BASE_0_0)); + } + if (id_var >= static_cast(FIXED_BASE_1_0) && id_var < static_cast(FIXED_BASE_2_0)) { + return fixed_base::table::generate_basic_fixed_base_table<1>( + id, index, id_var - static_cast(FIXED_BASE_1_0)); + } + if (id_var >= static_cast(FIXED_BASE_2_0) && id_var < static_cast(FIXED_BASE_3_0)) { + return fixed_base::table::generate_basic_fixed_base_table<2>( + id, index, id_var - static_cast(FIXED_BASE_2_0)); + } + if (id_var >= static_cast(FIXED_BASE_3_0) && id_var < static_cast(PEDERSEN_29_SMALL)) { + return fixed_base::table::generate_basic_fixed_base_table<3>( + id, index, id_var - static_cast(FIXED_BASE_3_0)); + } switch (id) { case AES_SPARSE_MAP: { return sparse_tables::generate_sparse_table_with_rotation<9, 8, 0>(AES_SPARSE_MAP, index); diff --git a/barretenberg/cpp/src/barretenberg/proof_system/plookup_tables/types.hpp b/barretenberg/cpp/src/barretenberg/proof_system/plookup_tables/types.hpp index e7fd4e400ef..514cebd6d02 100644 --- a/barretenberg/cpp/src/barretenberg/proof_system/plookup_tables/types.hpp +++ b/barretenberg/cpp/src/barretenberg/proof_system/plookup_tables/types.hpp @@ -3,6 +3,7 @@ #include #include +#include "./fixed_base/fixed_base_params.hpp" #include "barretenberg/ecc/curves/bn254/fr.hpp" namespace plookup { @@ -52,7 +53,12 @@ enum BasicTableId { BLAKE_XOR_ROTATE1, BLAKE_XOR_ROTATE2, BLAKE_XOR_ROTATE4, - PEDERSEN_29_SMALL, + FIXED_BASE_0_0, + FIXED_BASE_1_0 = FIXED_BASE_0_0 + FixedBaseParams::NUM_TABLES_PER_LO_MULTITABLE, + FIXED_BASE_2_0 = FIXED_BASE_1_0 + FixedBaseParams::NUM_TABLES_PER_HI_MULTITABLE, + FIXED_BASE_3_0 = FIXED_BASE_2_0 + FixedBaseParams::NUM_TABLES_PER_LO_MULTITABLE, + // TODO(@zac-wiliamson #2341 remove PEDERSEN basic tables) + PEDERSEN_29_SMALL = FIXED_BASE_3_0 + FixedBaseParams::NUM_TABLES_PER_HI_MULTITABLE, PEDERSEN_28, PEDERSEN_27, PEDERSEN_26, @@ -111,10 +117,15 @@ enum MultiTableId { AES_NORMALIZE, AES_INPUT, AES_SBOX, + // TODO(@zac-wiliamson #2341 
remove PEDERSEN_LEFT/RIGHT/HI/LO) PEDERSEN_LEFT_HI, PEDERSEN_LEFT_LO, PEDERSEN_RIGHT_HI, PEDERSEN_RIGHT_LO, + FIXED_BASE_LEFT_LO, + FIXED_BASE_LEFT_HI, + FIXED_BASE_RIGHT_LO, + FIXED_BASE_RIGHT_HI, UINT32_XOR, UINT32_AND, BN254_XLO, diff --git a/barretenberg/cpp/src/barretenberg/stdlib/hash/pedersen/pedersen.test.cpp b/barretenberg/cpp/src/barretenberg/stdlib/hash/pedersen/pedersen.test.cpp new file mode 100644 index 00000000000..c82041c4972 --- /dev/null +++ b/barretenberg/cpp/src/barretenberg/stdlib/hash/pedersen/pedersen.test.cpp @@ -0,0 +1,53 @@ +#include "../../primitives/circuit_builders/circuit_builders.hpp" +#include "./pedersen_refactor.hpp" +#include "barretenberg/crypto/pedersen_hash/pedersen_refactor.hpp" +#include "barretenberg/numeric/random/engine.hpp" +#include + +#define STDLIB_TYPE_ALIASES using Composer = TypeParam; + +namespace stdlib_pedersen_tests { +using namespace barretenberg; +using namespace proof_system::plonk; + +namespace { +auto& engine = numeric::random::get_debug_engine(); +} + +template class PedersenTest : public ::testing::Test { + public: + static void SetUpTestSuite(){ + + }; +}; + +using CircuitTypes = ::testing::Types; +TYPED_TEST_SUITE(PedersenTest, CircuitTypes); + +TYPED_TEST(PedersenTest, TestHash) +{ + STDLIB_TYPE_ALIASES; + using field_ct = stdlib::field_t; + using witness_ct = stdlib::witness_t; + auto composer = Composer(); + + const size_t num_inputs = 10; + + std::vector inputs; + std::vector inputs_native; + + for (size_t i = 0; i < num_inputs; ++i) { + const auto element = fr::random_element(&engine); + inputs_native.emplace_back(element); + inputs.emplace_back(field_ct(witness_ct(&composer, element))); + } + + auto result = stdlib::pedersen_hash_refactor::hash(inputs); + auto expected = crypto::pedersen_hash_refactor::hash(inputs_native); + + EXPECT_EQ(result.get_value(), expected); + + bool proof_result = composer.check_circuit(); + EXPECT_EQ(proof_result, true); +} +} // namespace stdlib_pedersen_tests \ No newline at end of file diff --git a/barretenberg/cpp/src/barretenberg/stdlib/hash/pedersen/pedersen_refactor.cpp b/barretenberg/cpp/src/barretenberg/stdlib/hash/pedersen/pedersen_refactor.cpp new file mode 100644 index 00000000000..ea387251323 --- /dev/null +++ b/barretenberg/cpp/src/barretenberg/stdlib/hash/pedersen/pedersen_refactor.cpp @@ -0,0 +1,46 @@ +#include "pedersen_refactor.hpp" +#include "barretenberg/ecc/curves/grumpkin/grumpkin.hpp" +namespace proof_system::plonk::stdlib { + +using namespace barretenberg; +using namespace crypto::generators; +using namespace proof_system; + +template +field_t pedersen_hash_refactor::hash_multiple(const std::vector& inputs, + const size_t hash_index, + const generator_data* generator_context, + const bool /*unused*/) +{ + + using cycle_group = cycle_group; + using cycle_scalar = typename cycle_group::cycle_scalar; + using Curve = EmbeddedCurve; + + auto base_points = generator_context->conditional_extend(inputs.size() + hash_index).generators; + + std::vector scalars; + std::vector points; + scalars.emplace_back(cycle_scalar::create_from_bn254_scalar(field_t(inputs.size()))); + points.emplace_back(crypto::pedersen_hash_refactor::get_length_generator()); + for (size_t i = 0; i < inputs.size(); ++i) { + scalars.emplace_back(cycle_scalar::create_from_bn254_scalar(inputs[i])); + // constructs constant cycle_group objects (non-witness) + points.emplace_back(base_points[i + hash_index]); + } + + auto result = cycle_group::batch_mul(scalars, points); + return result.x; +} + +template +field_t 
pedersen_hash_refactor::hash(const std::vector& in, + size_t hash_index, + const generator_data* generator_context, + bool validate_inputs_in_field) +{ + return hash_multiple(in, hash_index, generator_context, validate_inputs_in_field); +} +INSTANTIATE_STDLIB_TYPE(pedersen_hash_refactor); + +} // namespace proof_system::plonk::stdlib diff --git a/barretenberg/cpp/src/barretenberg/stdlib/hash/pedersen/pedersen_refactor.hpp b/barretenberg/cpp/src/barretenberg/stdlib/hash/pedersen/pedersen_refactor.hpp new file mode 100644 index 00000000000..30727dcee46 --- /dev/null +++ b/barretenberg/cpp/src/barretenberg/stdlib/hash/pedersen/pedersen_refactor.hpp @@ -0,0 +1,43 @@ +#pragma once +#include "../../primitives/field/field.hpp" +#include "../../primitives/point/point.hpp" +#include "barretenberg/ecc/curves/grumpkin/grumpkin.hpp" +#include "barretenberg/stdlib/primitives/group/cycle_group.hpp" + +#include "../../primitives/circuit_builders/circuit_builders.hpp" + +namespace proof_system::plonk::stdlib { + +using namespace barretenberg; +/** + * @brief stdlib class that evaluates in-circuit pedersen hashes, consistent with behavior in + * crypto::pedersen_hash_refactor + * + * @tparam ComposerContext + */ +template class pedersen_hash_refactor { + + private: + using field_t = stdlib::field_t; + using point = stdlib::point; + using bool_t = stdlib::bool_t; + using EmbeddedCurve = typename cycle_group::Curve; + using generator_data = crypto::generator_data; + + public: + // TODO(@suyash67) as part of refactor project, can we remove this and replace with `hash` + // (i.e. simplify the name as we no longer have a need for `hash_single`) + static field_t hash_multiple(const std::vector& in, + size_t hash_index = 0, + const generator_data* generator_context = generator_data::get_default_generators(), + bool validate_inputs_in_field = true); + + static field_t hash(const std::vector& in, + size_t hash_index = 0, + const generator_data* generator_context = generator_data::get_default_generators(), + bool validate_inputs_in_field = true); +}; + +EXTERN_STDLIB_TYPE(pedersen_hash_refactor); + +} // namespace proof_system::plonk::stdlib diff --git a/barretenberg/cpp/src/barretenberg/stdlib/primitives/field/field.cpp b/barretenberg/cpp/src/barretenberg/stdlib/primitives/field/field.cpp index 4e525e2fa5d..ae0cf7ae533 100644 --- a/barretenberg/cpp/src/barretenberg/stdlib/primitives/field/field.cpp +++ b/barretenberg/cpp/src/barretenberg/stdlib/primitives/field/field.cpp @@ -693,6 +693,9 @@ template bool_t field_t::operator!=(const f template field_t field_t::conditional_negate(const bool_t& predicate) const { + if (predicate.is_constant()) { + return predicate.get_value() ? -(*this) : *this; + } field_t predicate_field(predicate); field_t multiplicand = -(predicate_field + predicate_field); return multiplicand.madd(*this, *this); @@ -704,6 +707,14 @@ field_t field_t::conditional_assign(const bool_t& pre const field_t& lhs, const field_t& rhs) { + if (predicate.is_constant()) { + return predicate.get_value() ? lhs : rhs; + } + // if lhs and rhs are the same witness, just return it! 
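+ // (a field_t evaluates to multiplicative_constant * witness + additive_constant, so matching witness indices
+ // alone is not sufficient; both constants must also match for the two values to be guaranteed equal)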
+ if (lhs.get_witness_index() == rhs.get_witness_index() && (lhs.additive_constant == rhs.additive_constant) && + (lhs.multiplicative_constant == rhs.multiplicative_constant)) { + return lhs; + } return (lhs - rhs).madd(predicate, rhs); } diff --git a/barretenberg/cpp/src/barretenberg/stdlib/primitives/field/field.hpp b/barretenberg/cpp/src/barretenberg/stdlib/primitives/field/field.hpp index cd73554418e..baa0d15d388 100644 --- a/barretenberg/cpp/src/barretenberg/stdlib/primitives/field/field.hpp +++ b/barretenberg/cpp/src/barretenberg/stdlib/primitives/field/field.hpp @@ -4,8 +4,7 @@ #include "barretenberg/common/assert.hpp" #include -namespace proof_system::plonk { -namespace stdlib { +namespace proof_system::plonk::stdlib { template class bool_t; template class field_t { @@ -15,34 +14,36 @@ template class field_t { field_t(const int value) : context(nullptr) + , witness_index(IS_CONSTANT) { additive_constant = barretenberg::fr(value); multiplicative_constant = barretenberg::fr(0); - witness_index = IS_CONSTANT; } + // NOLINTNEXTLINE(google-runtime-int) intended behaviour field_t(const unsigned long long value) : context(nullptr) + , witness_index(IS_CONSTANT) { additive_constant = barretenberg::fr(value); multiplicative_constant = barretenberg::fr(0); - witness_index = IS_CONSTANT; } field_t(const unsigned int value) : context(nullptr) + , witness_index(IS_CONSTANT) { additive_constant = barretenberg::fr(value); multiplicative_constant = barretenberg::fr(0); - witness_index = IS_CONSTANT; } + // NOLINTNEXTLINE(google-runtime-int) intended behaviour field_t(const unsigned long value) : context(nullptr) + , witness_index(IS_CONSTANT) { additive_constant = barretenberg::fr(value); multiplicative_constant = barretenberg::fr(0); - witness_index = IS_CONSTANT; } field_t(const barretenberg::fr& value) @@ -68,7 +69,7 @@ template class field_t { , witness_index(other.witness_index) {} - field_t(field_t&& other) + field_t(field_t&& other) noexcept : context(other.context) , additive_constant(other.additive_constant) , multiplicative_constant(other.multiplicative_constant) @@ -77,15 +78,20 @@ template class field_t { field_t(const bool_t& other); + ~field_t() = default; + static constexpr bool is_composite = false; static constexpr uint256_t modulus = barretenberg::fr::modulus; - static field_t from_witness_index(Builder* parent_context, const uint32_t witness_index); + static field_t from_witness_index(Builder* parent_context, uint32_t witness_index); explicit operator bool_t() const; field_t& operator=(const field_t& other) { + if (this == &other) { + return *this; + } additive_constant = other.additive_constant; multiplicative_constant = other.multiplicative_constant; witness_index = other.witness_index; @@ -93,7 +99,7 @@ template class field_t { return *this; } - field_t& operator=(field_t&& other) + field_t& operator=(field_t&& other) noexcept { additive_constant = other.additive_constant; multiplicative_constant = other.multiplicative_constant; @@ -149,7 +155,8 @@ template class field_t { }; // Postfix increment (x++) - field_t operator++(int) + // NOLINTNEXTLINE + field_t operator++(const int) { field_t this_before_operation = field_t(*this); *this = *this + 1; @@ -244,7 +251,7 @@ template class field_t { * Slices a `field_t` at given indices (msb, lsb) both included in the slice, * returns three parts: [low, slice, high]. 
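 * e.g. for a value with bits (b7 ... b1 b0), slice(5, 2) returns low = (b1 b0), slice = (b5 b4 b3 b2),
 * high = (b7 b6).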
*/ - std::array slice(const uint8_t msb, const uint8_t lsb) const; + std::array slice(uint8_t msb, uint8_t lsb) const; /** * is_zero will return a bool_t, and add constraints that enforce its correctness @@ -252,7 +259,7 @@ template class field_t { **/ bool_t is_zero() const; - void create_range_constraint(const size_t num_bits, std::string const& msg = "field_t::range_constraint") const; + void create_range_constraint(size_t num_bits, std::string const& msg = "field_t::range_constraint") const; void assert_is_not_zero(std::string const& msg = "field_t::assert_is_not_zero") const; void assert_is_zero(std::string const& msg = "field_t::assert_is_zero") const; bool is_constant() const { return witness_index == IS_CONSTANT; } @@ -289,9 +296,11 @@ template class field_t { uint32_t get_witness_index() const { return witness_index; } std::vector> decompose_into_bits( - const size_t num_bits = 256, + size_t num_bits = 256, std::function(Builder* ctx, uint64_t, uint256_t)> get_bit = - [](Builder* ctx, uint64_t j, uint256_t val) { return witness_t(ctx, val.get_bit(j)); }) const; + [](Builder* ctx, uint64_t j, const uint256_t& val) { + return witness_t(ctx, val.get_bit(j)); + }) const; /** * @brief Return (a < b) as bool circuit type. @@ -419,5 +428,4 @@ template inline std::ostream& operator<<(std::ostream& os, fi EXTERN_STDLIB_TYPE(field_t); -} // namespace stdlib -} // namespace proof_system::plonk +} // namespace proof_system::plonk::stdlib diff --git a/barretenberg/cpp/src/barretenberg/stdlib/primitives/group/cycle_group.cpp b/barretenberg/cpp/src/barretenberg/stdlib/primitives/group/cycle_group.cpp new file mode 100644 index 00000000000..806d1fa3c77 --- /dev/null +++ b/barretenberg/cpp/src/barretenberg/stdlib/primitives/group/cycle_group.cpp @@ -0,0 +1,1323 @@ +#include "../field/field.hpp" +#include "barretenberg/crypto/pedersen_commitment/pedersen.hpp" +#include "barretenberg/ecc/curves/grumpkin/grumpkin.hpp" + +#include "../../hash/pedersen/pedersen.hpp" +#include "../../hash/pedersen/pedersen_gates.hpp" + +#include "./cycle_group.hpp" +#include "barretenberg/proof_system/plookup_tables/types.hpp" +#include "barretenberg/stdlib/primitives/plookup/plookup.hpp" +namespace proof_system::plonk::stdlib { + +template +cycle_group::cycle_group(Composer* _context) + : x(0) + , y(0) + , _is_infinity(true) + , _is_constant(true) + , context(_context) +{} + +/** + * @brief Construct a new cycle group::cycle group object + * + * @param _x + * @param _y + * @param is_infinity + */ +template +cycle_group::cycle_group(field_t _x, field_t _y, bool_t is_infinity) + : x(_x.normalize()) + , y(_y.normalize()) + , _is_infinity(is_infinity) + , _is_constant(_x.is_constant() && _y.is_constant() && is_infinity.is_constant()) +{ + if (_x.get_context() != nullptr) { + context = _x.get_context(); + } else if (_y.get_context() != nullptr) { + context = _y.get_context(); + } else { + context = is_infinity.get_context(); + } +} + +/** + * @brief Construct a new cycle group::cycle group object + * + * @details is_infinity is a circuit constant. We EXPLICITLY require that whether this point is infinity/not infinity is + * known at circuit-construction time *and* we know this point is on the curve. These checks are not constrained. + * Use from_witness if these conditions are not met. + * Examples of when conditions are met: point is a derived from a point that is on the curve + not at infinity. + * e.g. 
output of a doubling operation + * @tparam Composer + * @param _x + * @param _y + * @param is_infinity + */ +template +cycle_group::cycle_group(const FF& _x, const FF& _y, bool is_infinity) + : x(_x) + , y(_y) + , _is_infinity(is_infinity) + , _is_constant(true) + , context(nullptr) +{ + ASSERT(get_value().on_curve()); +} + +/** + * @brief Construct a cycle_group object out of an AffineElement object + * + * @note This produces a circuit-constant object i.e. known at compile-time, no constraints. + * If `_in` is not fixed for a given circuit, use `from_witness` instead + * + * @tparam Composer + * @param _in + */ +template +cycle_group::cycle_group(const AffineElement& _in) + : x(_in.x) + , y(_in.y) + , _is_infinity(_in.is_point_at_infinity()) + , _is_constant(true) + , context(nullptr) +{} + +/** + * @brief Converts an AffineElement into a circuit witness. + * + * @details Somewhat expensive as we do an on-curve check and `_is_infinity` is a witness and not a constant. + * If an element is being converted where it is known the element is on the curve and/or cannot be point at + * infinity, it is best to use other methods (e.g. direct conversion of field_t coordinates) + * + * @tparam Composer + * @param _context + * @param _in + * @return cycle_group + */ +template +cycle_group cycle_group::from_witness(Composer* _context, const AffineElement& _in) +{ + cycle_group result(_context); + result.x = field_t(witness_t(_context, _in.x)); + result.y = field_t(witness_t(_context, _in.y)); + result._is_infinity = bool_t(witness_t(_context, _in.is_point_at_infinity())); + result._is_constant = false; + result.validate_is_on_curve(); + return result; +} + +/** + * @brief Converts a native AffineElement into a witness, but constrains the witness values to be known constants. + * + * @details When performing group operations where one operand is a witness and one is a constant, + * it can be more efficient to convert the constant element into a witness. This is because we have custom gates + * that evaluate additions in one constraint, but only if both operands are witnesses. + * + * @tparam Composer + * @param _context + * @param _in + * @return cycle_group + */ +template +cycle_group cycle_group::from_constant_witness(Composer* _context, const AffineElement& _in) +{ + cycle_group result(_context); + result.x = field_t(witness_t(_context, _in.x)); + result.y = field_t(witness_t(_context, _in.y)); + result.x.assert_equal(_in.x); + result.y.assert_equal(_in.y); + // point at infinity is circuit constant + result._is_infinity = _in.is_point_at_infinity(); + result._is_constant = false; + return result; +} + +template Composer* cycle_group::get_context(const cycle_group& other) const +{ + if (get_context() != nullptr) { + return get_context(); + } + return other.get_context(); +} + +template typename cycle_group::AffineElement cycle_group::get_value() const +{ + AffineElement result(x.get_value(), y.get_value()); + if (is_point_at_infinity().get_value()) { + result.self_set_infinity(); + } + return result; +} + +/** + * @brief On-curve check. + * + * @tparam Composer + */ +template void cycle_group::validate_is_on_curve() const +{ + // This class is for short Weierstrass curves only! + static_assert(Group::curve_a == 0); + auto xx = x * x; + auto xxx = xx * x; + auto res = y.madd(y, -xxx - Group::curve_b); + res *= is_point_at_infinity(); + res.assert_is_zero(); +} + +/** + * @brief Evaluates a doubling. 
Does not use Ultra double gate + * + * @tparam Composer + * @return cycle_group + */ +template +cycle_group cycle_group::dbl() const + requires IsNotUltraArithmetic +{ + auto lambda = (x * x * 3) / (y + y); + auto x3 = lambda.madd(lambda, -x - x); + auto y3 = lambda.madd(x - x3, -y); + return cycle_group(x3, y3, false); +} + +/** + * @brief Evaluates a doubling. Uses Ultra double gate + * + * @tparam Composer + * @return cycle_group + */ +template +cycle_group cycle_group::dbl() const + requires IsUltraArithmetic +{ + // n.b. if p1 is point at infinity, calling p1.dbl() does not give us an output that satisfies the double gate + // :o) (native code just checks out of the dbl() method if point is at infinity) + auto x1 = x.get_value(); + auto y1 = y.get_value(); + auto lambda = (x1 * x1 * 3) / (y1 + y1); + auto x3 = lambda * lambda - x1 - x1; + auto y3 = lambda * (x1 - x3) - y1; + AffineElement p3(x3, y3); + + if (is_constant()) { + return cycle_group(p3); + } + + auto context = get_context(); + + field_t r_x(witness_t(context, p3.x)); + field_t r_y(witness_t(context, p3.y)); + cycle_group result = cycle_group(r_x, r_y, false); + result.set_point_at_infinity(is_point_at_infinity()); + proof_system::ecc_dbl_gate_ dbl_gate{ + .x1 = x.get_witness_index(), + .y1 = y.get_witness_index(), + .x3 = result.x.get_witness_index(), + .y3 = result.y.get_witness_index(), + }; + + context->create_ecc_dbl_gate(dbl_gate); + return result; +} + +/** + * @brief Will evaluate ECC point addition over `*this` and `other`. + * Incomplete addition formula edge cases are *NOT* checked! + * Only use this method if you know the x-coordinates of the operands cannot collide + * Standard version that does not use ecc group gate + * + * @tparam Composer + * @param other + * @return cycle_group + */ +template +cycle_group cycle_group::unconditional_add(const cycle_group& other) const + requires IsNotUltraArithmetic +{ + auto x_diff = other.x - x; + auto y_diff = other.y - y; + // unconditional add so do not check divisor is zero + // (this also makes it much easier to test failure cases as this does not segfault!) + auto lambda = y_diff.divide_no_zero_check(x_diff); + auto x3 = lambda.madd(lambda, -other.x - x); + auto y3 = lambda.madd(x - x3, -y); + cycle_group result(x3, y3, false); + return result; +} + +/** + * @brief Will evaluate ECC point addition over `*this` and `other`. + * Incomplete addition formula edge cases are *NOT* checked! 
+ * Only use this method if you know the x-coordinates of the operands cannot collide + * Ultra version that uses ecc group gate + * + * @tparam Composer + * @param other + * @return cycle_group + */ +template +cycle_group cycle_group::unconditional_add(const cycle_group& other) const + requires IsUltraArithmetic +{ + auto context = get_context(other); + + const bool lhs_constant = is_constant(); + const bool rhs_constant = other.is_constant(); + if (lhs_constant && !rhs_constant) { + auto lhs = cycle_group::from_constant_witness(context, get_value()); + return lhs.unconditional_add(other); + } + if (!lhs_constant && rhs_constant) { + auto rhs = cycle_group::from_constant_witness(context, other.get_value()); + return unconditional_add(rhs); + } + + const auto p1 = get_value(); + const auto p2 = other.get_value(); + AffineElement p3(Element(p1) + Element(p2)); + if (lhs_constant && rhs_constant) { + return cycle_group(p3); + } + field_t r_x(witness_t(context, p3.x)); + field_t r_y(witness_t(context, p3.y)); + cycle_group result(r_x, r_y, false); + + proof_system::ecc_add_gate_ add_gate{ + .x1 = x.get_witness_index(), + .y1 = y.get_witness_index(), + .x2 = other.x.get_witness_index(), + .y2 = other.y.get_witness_index(), + .x3 = result.x.get_witness_index(), + .y3 = result.y.get_witness_index(), + .endomorphism_coefficient = 1, + .sign_coefficient = 1, + }; + context->create_ecc_add_gate(add_gate); + + return result; +} + +/** + * @brief will evaluate ECC point subtraction over `*this` and `other`. + * Incomplete addition formula edge cases are *NOT* checked! + * Only use this method if you know the x-coordinates of the operands cannot collide + * + * @tparam Composer + * @param other + * @return cycle_group + */ +template +cycle_group cycle_group::unconditional_subtract(const cycle_group& other) const +{ + if constexpr (!IS_ULTRA) { + return unconditional_add(-other); + } else { + auto context = get_context(other); + + const bool lhs_constant = is_constant(); + const bool rhs_constant = other.is_constant(); + + if (lhs_constant && !rhs_constant) { + auto lhs = cycle_group::from_constant_witness(context, get_value()); + return lhs.unconditional_subtract(other); + } + if (!lhs_constant && rhs_constant) { + auto rhs = cycle_group::from_constant_witness(context, other.get_value()); + return unconditional_subtract(rhs); + } + auto p1 = get_value(); + auto p2 = other.get_value(); + AffineElement p3(Element(p1) - Element(p2)); + if (lhs_constant && rhs_constant) { + return cycle_group(p3); + } + field_t r_x(witness_t(context, p3.x)); + field_t r_y(witness_t(context, p3.y)); + cycle_group result(r_x, r_y, false); + + proof_system::ecc_add_gate_ add_gate{ + .x1 = x.get_witness_index(), + .y1 = y.get_witness_index(), + .x2 = other.x.get_witness_index(), + .y2 = other.y.get_witness_index(), + .x3 = result.x.get_witness_index(), + .y3 = result.y.get_witness_index(), + .endomorphism_coefficient = 1, + .sign_coefficient = -1, + }; + context->create_ecc_add_gate(add_gate); + + return result; + } +} + +/** + * @brief Will evaluate ECC point addition over `*this` and `other`. + * Uses incomplete addition formula + * If incomplete addition formula edge cases are triggered (x-coordinates of operands collide), + * the constraints produced by this method will be unsatisfiable. + * Useful when an honest prover will not produce a point collision with overwhelming probability, + * but a cheating prover will be able to. 
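+ * e.g. if the two operands share an x-coordinate (P + P, or P + (-P)), the x_delta term in the body below is
+ * zero and its assert_is_not_zero constraint has no satisfying witness, so the circuit rejects rather than
+ * accepting an incorrect sum.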
+ * + * @tparam Composer + * @param other + * @return cycle_group + */ +template +cycle_group cycle_group::checked_unconditional_add(const cycle_group& other) const +{ + field_t x_delta = x - other.x; + x_delta.assert_is_not_zero("cycle_group::checked_unconditional_add, x-coordinate collision"); + return unconditional_add(other); +} + +/** + * @brief Will evaluate ECC point subtraction over `*this` and `other`. + * Uses incomplete addition formula + * If incomplete addition formula edge cases are triggered (x-coordinates of operands collide), + * the constraints produced by this method will be unsatisfiable. + * Useful when an honest prover will not produce a point collision with overwhelming probability, + * but a cheating prover will be able to. + * + * @tparam Composer + * @param other + * @return cycle_group + */ +template +cycle_group cycle_group::checked_unconditional_subtract(const cycle_group& other) const +{ + field_t x_delta = x - other.x; + x_delta.assert_is_not_zero("cycle_group::checked_unconditional_subtract, x-coordinate collision"); + return unconditional_subtract(other); +} + +/** + * @brief Will evaluate ECC point addition over `*this` and `other`. + * This method uses complete addition i.e. is compatible with edge cases. + * Method is expensive due to needing to evaluate both an addition, a doubling, + * plus conditional logic to handle points at infinity. + * + * @tparam Composer + * @param other + * @return cycle_group + */ +template cycle_group cycle_group::operator+(const cycle_group& other) const +{ + Composer* context = get_context(other); + const bool_t x_coordinates_match = (x == other.x); + const bool_t y_coordinates_match = (y == other.y); + const bool_t double_predicate = (x_coordinates_match && y_coordinates_match); + const bool_t infinity_predicate = (x_coordinates_match && !y_coordinates_match); + + auto x1 = x; + auto y1 = y; + auto x2 = other.x; + auto y2 = other.y; + // if x_coordinates match, lambda triggers a divide by zero error. + // Adding in `x_coordinates_match` ensures that lambda will always be well-formed + auto x_diff = x2.add_two(-x1, x_coordinates_match); + auto lambda = (y2 - y1) / x_diff; + auto x3 = lambda.madd(lambda, -(x2 + x1)); + auto y3 = lambda.madd(x1 - x3, -y1); + cycle_group add_result(x3, y3, x_coordinates_match); + + auto dbl_result = dbl(); + + // dbl if x_match, y_match + // infinity if x_match, !y_match + cycle_group result(context); + result.x = field_t::conditional_assign(double_predicate, dbl_result.x, add_result.x); + result.y = field_t::conditional_assign(double_predicate, dbl_result.y, add_result.y); + + const bool_t lhs_infinity = is_point_at_infinity(); + const bool_t rhs_infinity = other.is_point_at_infinity(); + // if lhs infinity, return rhs + result.x = field_t::conditional_assign(lhs_infinity, other.x, result.x); + result.y = field_t::conditional_assign(lhs_infinity, other.y, result.y); + + // if rhs infinity, return lhs + result.x = field_t::conditional_assign(rhs_infinity, x, result.x); + result.y = field_t::conditional_assign(rhs_infinity, y, result.y); + + // is result point at infinity? + // yes = infinity_predicate && !lhs_infinity && !rhs_infinity + // yes = lhs_infinity && rhs_infinity + // n.b. 
can likely optimise this + bool_t result_is_infinity = infinity_predicate && (!lhs_infinity && !rhs_infinity); + result_is_infinity = result_is_infinity || (lhs_infinity && rhs_infinity); + result.set_point_at_infinity(result_is_infinity); + return result; +} + +/** + * @brief Will evaluate ECC point subtraction over `*this` and `other`. + * This method uses complete addition i.e. is compatible with edge cases. + * Method is expensive due to needing to evaluate both an addition, a doubling, + * plus conditional logic to handle points at infinity. + * + * @tparam Composer + * @param other + * @return cycle_group + */ +template cycle_group cycle_group::operator-(const cycle_group& other) const +{ + Composer* context = get_context(other); + const bool_t x_coordinates_match = (x == other.x); + const bool_t y_coordinates_match = (y == other.y); + const bool_t double_predicate = (x_coordinates_match && !y_coordinates_match).normalize(); + const bool_t infinity_predicate = (x_coordinates_match && y_coordinates_match).normalize(); + + auto x1 = x; + auto y1 = y; + auto x2 = other.x; + auto y2 = other.y; + auto x_diff = x2.add_two(-x1, x_coordinates_match); + auto lambda = (-y2 - y1) / x_diff; + auto x3 = lambda.madd(lambda, -(x2 + x1)); + auto y3 = lambda.madd(x1 - x3, -y1); + cycle_group add_result(x3, y3, x_coordinates_match); + + auto dbl_result = dbl(); + + // dbl if x_match, !y_match + // infinity if x_match, y_match + cycle_group result(context); + result.x = field_t::conditional_assign(double_predicate, dbl_result.x, add_result.x); + result.y = field_t::conditional_assign(double_predicate, dbl_result.y, add_result.y); + + const bool_t lhs_infinity = is_point_at_infinity(); + const bool_t rhs_infinity = other.is_point_at_infinity(); + // if lhs infinity, return -rhs + result.x = field_t::conditional_assign(lhs_infinity, other.x, result.x); + result.y = field_t::conditional_assign(lhs_infinity, (-other.y).normalize(), result.y); + + // if rhs infinity, return lhs + result.x = field_t::conditional_assign(rhs_infinity, x, result.x); + result.y = field_t::conditional_assign(rhs_infinity, y, result.y); + + // is result point at infinity? + // yes = infinity_predicate && !lhs_infinity && !rhs_infinity + // yes = lhs_infinity && rhs_infinity + // n.b. 
can likely optimise this + bool_t result_is_infinity = infinity_predicate && (!lhs_infinity && !rhs_infinity); + result_is_infinity = result_is_infinity || (lhs_infinity && rhs_infinity); + result.set_point_at_infinity(result_is_infinity); + + return result; +} + +/** + * @brief Negates a point + * + * @tparam Composer + * @param other + * @return cycle_group + */ +template cycle_group cycle_group::operator-() const +{ + cycle_group result(*this); + result.y = -y; + return result; +} + +template cycle_group& cycle_group::operator+=(const cycle_group& other) +{ + *this = *this + other; + return *this; +} + +template cycle_group& cycle_group::operator-=(const cycle_group& other) +{ + *this = *this - other; + return *this; +} + +template +cycle_group::cycle_scalar::cycle_scalar(const field_t& _lo, const field_t& _hi) + : lo(_lo) + , hi(_hi) +{} + +template cycle_group::cycle_scalar::cycle_scalar(const field_t& _in) +{ + const uint256_t value(_in.get_value()); + const uint256_t lo_v = value.slice(0, LO_BITS); + const uint256_t hi_v = value.slice(LO_BITS, HI_BITS); + constexpr uint256_t shift = uint256_t(1) << LO_BITS; + if (_in.is_constant()) { + lo = lo_v; + hi = hi_v; + } else { + lo = witness_t(_in.get_context(), lo_v); + hi = witness_t(_in.get_context(), hi_v); + (lo + hi * shift).assert_equal(_in); + } +} + +template cycle_group::cycle_scalar::cycle_scalar(const ScalarField& _in) +{ + const uint256_t value(_in); + const uint256_t lo_v = value.slice(0, LO_BITS); + const uint256_t hi_v = value.slice(LO_BITS, HI_BITS); + lo = lo_v; + hi = hi_v; +} + +template +typename cycle_group::cycle_scalar cycle_group::cycle_scalar::from_witness(Composer* context, + const ScalarField& value) +{ + const uint256_t value_u256(value); + const uint256_t lo_v = value_u256.slice(0, LO_BITS); + const uint256_t hi_v = value_u256.slice(LO_BITS, HI_BITS); + field_t lo = witness_t(context, lo_v); + field_t hi = witness_t(context, hi_v); + return cycle_scalar(lo, hi); +} + +/** + * @brief Use when we want to multiply a group element by a string of bits of known size. + * N.B. using this constructor method will make our scalar multiplication methods not perform primality tests. + * + * @tparam Composer + * @param context + * @param value + * @param num_bits + * @return cycle_group::cycle_scalar + */ +template +typename cycle_group::cycle_scalar cycle_group::cycle_scalar::from_witness_bitstring( + Composer* context, const uint256_t& bitstring, const size_t num_bits) +{ + ASSERT(bitstring.get_msb() < num_bits); + const uint256_t lo_v = bitstring.slice(0, LO_BITS); + const uint256_t hi_v = bitstring.slice(LO_BITS, HI_BITS); + field_t lo = witness_t(context, lo_v); + field_t hi = witness_t(context, hi_v); + cycle_scalar result{ lo, hi, num_bits, true, false }; + return result; +} + +/** + * @brief Use when we want to multiply a group element by a string of bits of known size. + * N.B. using this constructor method will make our scalar multiplication methods not perform primality tests. 
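+ *
+ * Editorial note (an illustrative reading of the implementation below, not an authoritative spec): the input
+ * bn254 scalar `in` is decomposed into two limbs satisfying
+ *     in = lo + hi * 2^LO_BITS,
+ * and for witness inputs this identity is enforced by the single add_two/assert_equal constraint in the body.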
+ *
+ * @tparam Composer
+ * @param in
+ * @return cycle_group::cycle_scalar
+ */
+template
+typename cycle_group::cycle_scalar cycle_group::cycle_scalar::create_from_bn254_scalar(
+    const field_t& in)
+{
+    const uint256_t value_u256(in.get_value());
+    const uint256_t lo_v = value_u256.slice(0, LO_BITS);
+    const uint256_t hi_v = value_u256.slice(LO_BITS, HI_BITS);
+    if (in.is_constant()) {
+        cycle_scalar result{ field_t(lo_v), field_t(hi_v), NUM_BITS, false, true };
+        return result;
+    }
+    field_t lo = witness_t(in.get_context(), lo_v);
+    field_t hi = witness_t(in.get_context(), hi_v);
+    lo.add_two(hi * (uint256_t(1) << LO_BITS), -in).assert_equal(0);
+    cycle_scalar result{ lo, hi, NUM_BITS, false, true };
+    return result;
+}
+
+template bool cycle_group::cycle_scalar::is_constant() const
+{
+    return (lo.is_constant() && hi.is_constant());
+}
+
+template
+typename cycle_group::ScalarField cycle_group::cycle_scalar::get_value() const
+{
+    uint256_t lo_v(lo.get_value());
+    uint256_t hi_v(hi.get_value());
+    return ScalarField(lo_v + (hi_v << LO_BITS));
+}
+
+/**
+ * @brief Construct a new cycle group::straus scalar slice::straus scalar slice object
+ *
+ * @details As part of the slicing algorithm, we also perform a primality test on the input scalar.
+ *
+ * TODO(@zac-williamson) make the primality test configurable.
+ * We may want to validate the input < BN254::Fr OR input < Grumpkin::Fr depending on context!
+ *
+ * @tparam Composer
+ * @param context
+ * @param scalar
+ * @param table_bits
+ */
+template
+cycle_group::straus_scalar_slice::straus_scalar_slice(Composer* context,
+                                                      const cycle_scalar& scalar,
+                                                      const size_t table_bits)
+    : _table_bits(table_bits)
+{
+    // convert an input cycle_scalar object into a vector of slices, each containing `table_bits` bits.
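+    // (editorial worked example, not from the original patch: with table_bits = 4, a 16-bit scalar with value
+    //  0xABCD is decomposed into the slice vector { 0xD, 0xC, 0xB, 0xA }, least-significant slice first)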
+ // this also performs an implicit range check on the input slices + const auto slice_scalar = [&](const field_t& scalar, const size_t num_bits) { + std::vector result; + if (num_bits == 0) { + return result; + } + if (scalar.is_constant()) { + const size_t num_slices = (num_bits + table_bits - 1) / table_bits; + const uint64_t table_mask = (1ULL << table_bits) - 1ULL; + uint256_t raw_value = scalar.get_value(); + for (size_t i = 0; i < num_slices; ++i) { + uint64_t slice_v = static_cast(raw_value.data[0]) & table_mask; + result.push_back(field_t(slice_v)); + raw_value = raw_value >> table_bits; + } + return result; + } + if constexpr (IS_ULTRA) { + const auto slice_indices = + context->decompose_into_default_range(scalar.normalize().get_witness_index(), + num_bits, + table_bits, + "straus_scalar_slice decompose_into_default_range"); + for (auto& idx : slice_indices) { + result.emplace_back(field_t::from_witness_index(context, idx)); + } + } else { + uint256_t raw_value = scalar.get_value(); + const uint64_t table_mask = (1ULL << table_bits) - 1ULL; + const size_t num_slices = (num_bits + table_bits - 1) / table_bits; + for (size_t i = 0; i < num_slices; ++i) { + uint64_t slice_v = static_cast(raw_value.data[0]) & table_mask; + field_t slice(witness_t(context, slice_v)); + + context->create_range_constraint( + slice.get_witness_index(), table_bits, "straus_scalar_slice create_range_constraint"); + + result.emplace_back(slice); + raw_value = raw_value >> table_bits; + } + std::vector linear_elements; + FF scaling_factor = 1; + for (size_t i = 0; i < num_slices; ++i) { + linear_elements.emplace_back(result[i] * scaling_factor); + scaling_factor += scaling_factor; + } + field_t::accumulate(linear_elements).assert_equal(scalar); + } + return result; + }; + + const size_t lo_bits = scalar.num_bits() > cycle_scalar::LO_BITS ? cycle_scalar::LO_BITS : scalar.num_bits(); + const size_t hi_bits = scalar.num_bits() > cycle_scalar::LO_BITS ? scalar.num_bits() - cycle_scalar::LO_BITS : 0; + auto hi_slices = slice_scalar(scalar.hi, hi_bits); + auto lo_slices = slice_scalar(scalar.lo, lo_bits); + + if (!scalar.is_constant() && !scalar.skip_primality_test()) { + // Check that scalar.hi * 2^LO_BITS + scalar.lo < cycle_group_modulus when evaluated over the integers + const uint256_t cycle_group_modulus = + scalar.use_bn254_scalar_field_for_primality_test() ? FF::modulus : ScalarField::modulus; + const uint256_t r_lo = cycle_group_modulus.slice(0, cycle_scalar::LO_BITS); + const uint256_t r_hi = cycle_group_modulus.slice(cycle_scalar::LO_BITS, cycle_scalar::HI_BITS); + + bool need_borrow = uint256_t(scalar.lo.get_value()) > r_lo; + field_t borrow = scalar.lo.is_constant() ? 
need_borrow : field_t::from_witness(context, need_borrow); + + // directly call `create_new_range_constraint` to avoid creating an arithmetic gate + if (!scalar.lo.is_constant()) { + if constexpr (IS_ULTRA) { + context->create_new_range_constraint(borrow.get_witness_index(), 1, "borrow"); + } else { + borrow.assert_equal(borrow * borrow); + } + } + // Hi range check = r_hi - y_hi - borrow + // Lo range check = r_lo - y_lo + borrow * 2^{126} + field_t hi = (-scalar.hi + r_hi) - borrow; + field_t lo = (-scalar.lo + r_lo) + (borrow * (uint256_t(1) << cycle_scalar::LO_BITS)); + + hi.create_range_constraint(cycle_scalar::HI_BITS); + lo.create_range_constraint(cycle_scalar::LO_BITS); + } + + std::copy(lo_slices.begin(), lo_slices.end(), std::back_inserter(slices)); + std::copy(hi_slices.begin(), hi_slices.end(), std::back_inserter(slices)); +} + +/** + * @brief Return a bit-slice associated with round `index`. + * + * @details In Straus algorithm, `index` is a known parameter, so no need for expensive lookup tables + * + * @tparam Composer + * @param index + * @return field_t + */ +template +std::optional> cycle_group::straus_scalar_slice::read(size_t index) +{ + if (index >= slices.size()) { + return std::nullopt; + } + return slices[index]; +} + +/** + * @brief Construct a new cycle group::straus lookup table::straus lookup table object + * + * @details Constructs a `table_bits` lookup table. + * + * If Composer is not ULTRA, `table_bits = 1` + * If Composer is ULTRA, ROM table is used as lookup table + * + * @tparam Composer + * @param context + * @param base_point + * @param offset_generator + * @param table_bits + */ +template +cycle_group::straus_lookup_table::straus_lookup_table(Composer* context, + const cycle_group& base_point, + const cycle_group& offset_generator, + size_t table_bits) + : _table_bits(table_bits) + , _context(context) +{ + const size_t table_size = 1UL << table_bits; + point_table.resize(table_size); + point_table[0] = offset_generator; + + // We want to support the case where input points are points at infinity. + // If base point is at infinity, we want every point in the table to just be `generator_point`. 
+ // We achieve this via the following: + // 1: We create a "work_point" that is base_point if not at infinity, otherwise is just 1 + // 2: When computing the point table, we use "work_point" in additions instead of the "base_point" (to prevent + // x-coordinate collisions in honest case) 3: When assigning to the point table, we conditionally assign either + // the output of the point addition (if not at infinity) or the generator point (if at infinity) + // Note: if `base_point.is_point_at_infinity()` is constant, these conditional assigns produce zero gate overhead + cycle_group fallback_point(Group::affine_one); + field_t modded_x = field_t::conditional_assign(base_point.is_point_at_infinity(), fallback_point.x, base_point.x); + field_t modded_y = field_t::conditional_assign(base_point.is_point_at_infinity(), fallback_point.y, base_point.y); + cycle_group modded_base_point(modded_x, modded_y, false); + for (size_t i = 1; i < table_size; ++i) { + auto add_output = point_table[i - 1].checked_unconditional_add(modded_base_point); + field_t x = field_t::conditional_assign(base_point.is_point_at_infinity(), offset_generator.x, add_output.x); + field_t y = field_t::conditional_assign(base_point.is_point_at_infinity(), offset_generator.y, add_output.y); + point_table[i] = cycle_group(x, y, false); + } + if constexpr (IS_ULTRA) { + rom_id = context->create_ROM_array(table_size); + for (size_t i = 0; i < table_size; ++i) { + if (point_table[i].is_constant()) { + auto element = point_table[i].get_value(); + point_table[i] = cycle_group::from_constant_witness(_context, element); + point_table[i].x.assert_equal(element.x); + point_table[i].y.assert_equal(element.y); + } + context->set_ROM_element_pair( + rom_id, + i, + std::array{ point_table[i].x.get_witness_index(), point_table[i].y.get_witness_index() }); + } + } else { + ASSERT(table_bits == 1); + } +} + +/** + * @brief Given an `_index` witness, return `straus_lookup_table[index]` + * + * @tparam Composer + * @param _index + * @return cycle_group + */ +template +cycle_group cycle_group::straus_lookup_table::read(const field_t& _index) +{ + if constexpr (IS_ULTRA) { + field_t index(_index); + if (index.is_constant()) { + index = witness_t(_context, _index.get_value()); + index.assert_equal(_index.get_value()); + } + auto output_indices = _context->read_ROM_array_pair(rom_id, index.get_witness_index()); + field_t x = field_t::from_witness_index(_context, output_indices[0]); + field_t y = field_t::from_witness_index(_context, output_indices[1]); + return cycle_group(x, y, false); + } + field_t x = _index * (point_table[1].x - point_table[0].x) + point_table[0].x; + field_t y = _index * (point_table[1].y - point_table[0].y) + point_table[0].y; + return cycle_group(x, y, false); +} + +/** + * @brief Internal algorithm to perform a variable-base batch mul. + * + * @note Explicit assumption that all base_points are witnesses and not constants! + * Constant points must be filtered out by `batch_mul` before calling this. + * + * @details batch mul performed via the Straus multiscalar multiplication algorithm + * (optimal for MSMs where num points <128-ish). + * If Composer is not ULTRA, number of bits per Straus round = 1, + * which reduces to the basic double-and-add algorithm + * + * @details If `unconditional_add = true`, we use `::unconditional_add` instead of `::checked_unconditional_add`. + * Use with caution! Only should be `true` if we're doing an ULTRA fixed-base MSM so we know the points cannot + * collide with the offset generators. 
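+ *
+ * Editorial sketch of the round structure implemented below (illustrative only, not an authoritative spec):
+ *
+ *     acc = offset_generators[0]
+ *     for each round r, most-significant slice first:
+ *         double acc TABLE_BITS times               (skipped on the first round)
+ *         for each point i: acc = acc + table_i[slice(i, r)]
+ *     return { acc, sum of offset-generator terms } (the offset sum is subtracted later in batch_mul)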
+ *
+ * @note ULTRA Composer will call `_variable_base_batch_mul_internal` to evaluate fixed-base MSMs over points that do
+ * not exist in our precomputed plookup tables. This is a compromise between maximising circuit efficiency and
+ * minimising the blowup size of our precomputed table polynomials. Variable-base mul uses small ROM lookup tables
+ * which are witness-defined and not part of the plookup protocol.
+ * @tparam Composer
+ * @param scalars
+ * @param base_points
+ * @param offset_generators
+ * @param unconditional_add
+ * @return cycle_group::batch_mul_internal_output
+ */
+template
+typename cycle_group::batch_mul_internal_output cycle_group::_variable_base_batch_mul_internal(
+    const std::span scalars,
+    const std::span base_points,
+    const std::span offset_generators,
+    const bool unconditional_add)
+{
+    ASSERT(scalars.size() == base_points.size());
+    Composer* context = nullptr;
+    for (auto& scalar : scalars) {
+        if (scalar.lo.get_context() != nullptr) {
+            context = scalar.get_context();
+            break;
+        }
+    }
+    for (auto& point : base_points) {
+        if (point.get_context() != nullptr) {
+            context = point.get_context();
+            break;
+        }
+    }
+
+    size_t num_bits = 0;
+    for (auto& s : scalars) {
+        num_bits = std::max(num_bits, s.num_bits());
+    }
+    size_t num_rounds = (num_bits + TABLE_BITS - 1) / TABLE_BITS;
+
+    const size_t num_points = scalars.size();
+
+    std::vector scalar_slices;
+    std::vector point_tables;
+    for (size_t i = 0; i < num_points; ++i) {
+        scalar_slices.emplace_back(straus_scalar_slice(context, scalars[i], TABLE_BITS));
+        point_tables.emplace_back(straus_lookup_table(context, base_points[i], offset_generators[i + 1], TABLE_BITS));
+    }
+
+    Element offset_generator_accumulator = offset_generators[0];
+    cycle_group accumulator = offset_generators[0];
+
+    // populate the set of points we are going to add into our accumulator, *before* we do any ECC operations
+    // this way we are able to fuse multiple ecc add / ecc double operations and reduce total gate count.
+    // (ecc add/ecc double gates normally cost 2 UltraPlonk gates.
However if we chain add->add, add->double,
+    // double->add, double->double, they only cost one)
+    std::vector points_to_add;
+    for (size_t i = 0; i < num_rounds; ++i) {
+        for (size_t j = 0; j < num_points; ++j) {
+            const std::optional scalar_slice = scalar_slices[j].read(num_rounds - i - 1);
+            // if we are doing a batch mul over scalars of different bit-lengths, we may not have any scalar bits for a
+            // given round and a given scalar
+            if (scalar_slice.has_value()) {
+                const cycle_group point = point_tables[j].read(scalar_slice.value());
+                points_to_add.emplace_back(point);
+            }
+        }
+    }
+    std::vector> x_coordinate_checks;
+    size_t point_counter = 0;
+    for (size_t i = 0; i < num_rounds; ++i) {
+        if (i != 0) {
+            for (size_t j = 0; j < TABLE_BITS; ++j) {
+                // offset_generator_accumulator is a regular Element, so dbl() won't add constraints
+                accumulator = accumulator.dbl();
+                offset_generator_accumulator = offset_generator_accumulator.dbl();
+            }
+        }
+
+        for (size_t j = 0; j < num_points; ++j) {
+            const std::optional scalar_slice = scalar_slices[j].read(num_rounds - i - 1);
+            // if we are doing a batch mul over scalars of different bit-lengths, we may not have a bit slice for a
+            // given round and a given scalar
+            if (scalar_slice.has_value()) {
+                const auto& point = points_to_add[point_counter++];
+                if (!unconditional_add) {
+                    x_coordinate_checks.push_back({ accumulator.x, point.x });
+                }
+                accumulator = accumulator.unconditional_add(point);
+                offset_generator_accumulator = offset_generator_accumulator + Element(offset_generators[j + 1]);
+            }
+        }
+    }
+
+    for (auto& [x1, x2] : x_coordinate_checks) {
+        auto x_diff = x2 - x1;
+        x_diff.assert_is_not_zero("_variable_base_batch_mul_internal x-coordinate collision");
+    }
+    /**
+     * offset_generator_accumulator represents the sum of all the offset generator terms present in `accumulator`.
+     * We don't subtract off yet, as we may be able to combine `offset_generator_accumulator` with other constant terms
+     * in `batch_mul` before performing the subtraction.
+     */
+    return { accumulator, AffineElement(offset_generator_accumulator) };
+}
+
+/**
+ * @brief Internal algorithm to perform a fixed-base batch mul for ULTRA Composer
+ *
+ * @details Uses plookup tables which contain lookups for precomputed multiples of the input base points.
+ * Means we can avoid all point doublings and reduce one scalar mul to ~29 lookups + 29 ecc addition gates + * + * @tparam Composer + * @param scalars + * @param base_points + * @param off + * @return cycle_group::batch_mul_internal_output + */ +template +typename cycle_group::batch_mul_internal_output cycle_group::_fixed_base_batch_mul_internal( + const std::span scalars, + const std::span base_points, + [[maybe_unused]] const std::span off) + requires IsUltraArithmetic +{ + ASSERT(scalars.size() == base_points.size()); + + const size_t num_points = base_points.size(); + using MultiTableId = plookup::MultiTableId; + using ColumnIdx = plookup::ColumnIdx; + + std::vector plookup_table_ids; + std::vector plookup_base_points; + std::vector plookup_scalars; + + for (size_t i = 0; i < num_points; ++i) { + std::optional> table_id = + plookup::fixed_base::table::get_lookup_table_ids_for_point(base_points[i]); + ASSERT(table_id.has_value()); + plookup_table_ids.emplace_back(table_id.value()[0]); + plookup_table_ids.emplace_back(table_id.value()[1]); + plookup_base_points.emplace_back(base_points[i]); + plookup_base_points.emplace_back(Element(base_points[i]) * (uint256_t(1) << cycle_scalar::LO_BITS)); + plookup_scalars.emplace_back(scalars[i].lo); + plookup_scalars.emplace_back(scalars[i].hi); + } + + std::vector lookup_points; + Element offset_generator_accumulator = Group::point_at_infinity; + for (size_t i = 0; i < plookup_scalars.size(); ++i) { + plookup::ReadData lookup_data = + plookup_read::get_lookup_accumulators(plookup_table_ids[i], plookup_scalars[i]); + for (size_t j = 0; j < lookup_data[ColumnIdx::C2].size(); ++j) { + const auto x = lookup_data[ColumnIdx::C2][j]; + const auto y = lookup_data[ColumnIdx::C3][j]; + lookup_points.emplace_back(cycle_group(x, y, false)); + } + + std::optional offset_1 = + plookup::fixed_base::table::get_generator_offset_for_table_id(plookup_table_ids[i]); + + ASSERT(offset_1.has_value()); + offset_generator_accumulator += offset_1.value(); + } + cycle_group accumulator = lookup_points[0]; + // Perform all point additions sequentially. The Ultra ecc_addition relation costs 1 gate iff additions are chained + // and output point of previous addition = input point of current addition. + // If this condition is not met, the addition relation costs 2 gates. So it's good to do these sequentially! + for (size_t i = 1; i < lookup_points.size(); ++i) { + accumulator = accumulator.unconditional_add(lookup_points[i]); + } + /** + * offset_generator_accumulator represents the sum of all the offset generator terms present in `accumulator`. + * We don't subtract off yet, as we may be able to combine `offset_generator_accumulator` with other constant terms + * in `batch_mul` before performing the subtraction. + */ + return { accumulator, offset_generator_accumulator }; +} + +/** + * @brief Internal algorithm to perform a fixed-base batch mul for Non-ULTRA Composers + * + * @details Multiples of the base point are precomputed, which avoids us having to add ecc doubling gates. + * More efficient than variable-base version. 
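+ *
+ * Editorial note (illustrative reading of the code below, not an authoritative spec): because the base points are
+ * constants, the per-round doubled multiples of each base point are computed natively out-of-circuit, so the
+ * in-circuit work per point per round reduces to a 1-bit table read plus one incomplete addition.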
+ * + * @tparam Composer + * @param scalars + * @param base_points + * @param off + * @return cycle_group::batch_mul_internal_output + */ +template +typename cycle_group::batch_mul_internal_output cycle_group::_fixed_base_batch_mul_internal( + const std::span scalars, + const std::span base_points, + const std::span offset_generators) + requires IsNotUltraArithmetic + +{ + ASSERT(scalars.size() == base_points.size()); + static_assert(TABLE_BITS == 1); + + Composer* context = nullptr; + for (auto& scalar : scalars) { + if (scalar.get_context() != nullptr) { + context = scalar.get_context(); + break; + } + } + + size_t num_bits = 0; + for (auto& s : scalars) { + num_bits = std::max(num_bits, s.num_bits()); + } + size_t num_rounds = (num_bits + TABLE_BITS - 1) / TABLE_BITS; + // core algorithm + // define a `table_bits` size lookup table + const size_t num_points = scalars.size(); + using straus_round_tables = std::vector; + + std::vector scalar_slices; + std::vector point_tables(num_points); + + // creating these point tables should cost 0 constraints if base points are constant + for (size_t i = 0; i < num_points; ++i) { + std::vector round_points(num_rounds); + std::vector round_offset_generators(num_rounds); + round_points[0] = base_points[i]; + round_offset_generators[0] = offset_generators[i + 1]; + for (size_t j = 1; j < num_rounds; ++j) { + round_points[j] = round_points[j - 1].dbl(); + round_offset_generators[j] = round_offset_generators[j - 1].dbl(); + } + Element::batch_normalize(&round_points[0], num_rounds); + Element::batch_normalize(&round_offset_generators[0], num_rounds); + point_tables[i].resize(num_rounds); + for (size_t j = 0; j < num_rounds; ++j) { + point_tables[i][j] = straus_lookup_table( + context, cycle_group(round_points[j]), cycle_group(round_offset_generators[j]), TABLE_BITS); + } + scalar_slices.emplace_back(straus_scalar_slice(context, scalars[i], TABLE_BITS)); + } + Element offset_generator_accumulator = offset_generators[0]; + cycle_group accumulator = cycle_group(Element(offset_generators[0]) * (uint256_t(1) << (num_rounds - 1))); + for (size_t i = 0; i < num_rounds; ++i) { + offset_generator_accumulator = (i > 0) ? offset_generator_accumulator.dbl() : offset_generator_accumulator; + for (size_t j = 0; j < num_points; ++j) { + auto& point_table = point_tables[j][i]; + const std::optional scalar_slice = scalar_slices[j].read(i); + // if we are doing a batch mul over scalars of different bit-lengths, we may not have any scalar bits for a + // given round and a given scalar + if (scalar_slice.has_value()) { + const cycle_group point = point_table.read(scalar_slice.value()); + accumulator = accumulator.unconditional_add(point); + offset_generator_accumulator = offset_generator_accumulator + Element(offset_generators[j + 1]); + } + } + } + + /** + * offset_generator_accumulator represents the sum of all the offset generator terms present in `accumulator`. + * We don't subtract off yet, as we may be able to combine `offset_generator_accumulator` with other constant terms + * in `batch_mul` before performing the subtraction. + */ + return { accumulator, offset_generator_accumulator }; +} + +/** + * @brief Multiscalar multiplication algorithm. + * + * @details Uses the Straus MSM algorithm. `batch_mul` splits inputs into three categories: + * 1. point and scalar multiplier are both constant + * 2. point is constant, scalar multiplier is a witness + * 3. 
point is a witness, scalar multiplier can be witness or constant
+ *
+ * For Category 1, the scalar mul can be precomputed without constraints
+ * For Category 2, we use a fixed-base variant of Straus (with plookup tables if available).
+ * For Category 3, we use standard Straus.
+ * The results from all 3 categories are combined and returned as an output point.
+ *
+ * @note batch_mul can handle all known cases that trigger incomplete addition formula exceptions and other weirdness:
+ * 1. some/all of the input points are points at infinity
+ * 2. some/all of the input scalars are 0
+ * 3. some/all input points are equal to each other
+ * 4. output is the point at infinity
+ * 5. input vectors are empty
+ *
+ * @note offset_generator_data is a pointer to a precomputed offset generator list.
+ * There is a default parameter that points to a list with DEFAULT_NUM_GENERATORS generator points (32).
+ * If more offset generators are required, they will be derived in-place which can be expensive.
+ * (num required offset generators is either num input points + 1 or num input points + 2,
+ * depending on whether one or both of _fixed_base_batch_mul_internal, _variable_base_batch_mul_internal are called)
+ * If you're calling this function repeatedly and you KNOW you need >32 offset generators,
+ * it's faster to create a `generator_data` object with the required size and pass it in as a parameter.
+ * @tparam Composer
+ * @param scalars
+ * @param base_points
+ * @param offset_generator_data
+ * @return cycle_group
+ */
+template
+cycle_group cycle_group::batch_mul(const std::vector& scalars,
+                                   const std::vector& base_points,
+                                   const generator_data* const offset_generator_data)
+{
+    ASSERT(scalars.size() == base_points.size());
+
+    std::vector variable_base_scalars;
+    std::vector variable_base_points;
+    std::vector fixed_base_scalars;
+    std::vector fixed_base_points;
+
+    size_t num_bits = 0;
+    for (auto& s : scalars) {
+        num_bits = std::max(num_bits, s.num_bits());
+    }
+
+    // if num_bits != NUM_BITS, skip lookup-version of fixed-base scalar mul. too much complexity
+    bool num_bits_not_full_field_size = num_bits != NUM_BITS;
+
+    // When calling `_variable_base_batch_mul_internal`, we can unconditionally add iff all of the input points
+    // are fixed-base points
+    // (i.e. we are ULTRA Composer and we are doing fixed-base mul over points not present in our plookup tables)
+    bool can_unconditional_add = true;
+    bool has_non_constant_component = false;
+    Element constant_acc = Group::point_at_infinity;
+    for (size_t i = 0; i < scalars.size(); ++i) {
+        bool scalar_constant = scalars[i].is_constant();
+        bool point_constant = base_points[i].is_constant();
+        if (scalar_constant && point_constant) {
+            constant_acc += (base_points[i].get_value()) * (scalars[i].get_value());
+        } else if (!scalar_constant && point_constant) {
+            if (base_points[i].get_value().is_point_at_infinity()) {
+                // oi mate, why are you creating a circuit that multiplies a known point at infinity?
+                continue;
+            }
+            if constexpr (IS_ULTRA) {
+                if (!num_bits_not_full_field_size &&
+                    plookup::fixed_base::table::lookup_table_exists_for_point(base_points[i].get_value())) {
+                    fixed_base_scalars.push_back(scalars[i]);
+                    fixed_base_points.push_back(base_points[i].get_value());
+                } else {
+                    // womp womp. We have lookup tables at home. ROM tables.
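+                    // (editorial note: points without a precomputed plookup table fall back to the
+                    //  variable-base path below, which builds witness-defined ROM lookup tables instead)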
+                    variable_base_scalars.push_back(scalars[i]);
+                    variable_base_points.push_back(base_points[i]);
+                }
+            } else {
+                fixed_base_scalars.push_back(scalars[i]);
+                fixed_base_points.push_back(base_points[i].get_value());
+            }
+            has_non_constant_component = true;
+        } else {
+            variable_base_scalars.push_back(scalars[i]);
+            variable_base_points.push_back(base_points[i]);
+            can_unconditional_add = false;
+            has_non_constant_component = true;
+            // variable base
+        }
+    }
+
+    // If all inputs are constant, return the computed constant component and call it a day.
+    if (!has_non_constant_component) {
+        return cycle_group(constant_acc);
+    }
+
+    // add the constant component into our offset accumulator
+    // (we'll subtract `offset_accumulator` from the MSM output i.e. we negate here to counter the future negation)
+    Element offset_accumulator = -constant_acc;
+    const bool has_variable_points = !variable_base_points.empty();
+    const bool has_fixed_points = !fixed_base_points.empty();
+
+    // Compute all required offset generators.
+    const size_t num_offset_generators =
+        variable_base_points.size() + fixed_base_points.size() + has_variable_points + has_fixed_points;
+    std::vector offset_generators =
+        offset_generator_data->conditional_extend(num_offset_generators).generators;
+
+    cycle_group result;
+    if (has_fixed_points) {
+        const auto [fixed_accumulator, offset_generator_delta] =
+            _fixed_base_batch_mul_internal(fixed_base_scalars, fixed_base_points, offset_generators);
+        offset_accumulator += offset_generator_delta;
+        result = fixed_accumulator;
+    }
+
+    if (has_variable_points) {
+        std::span offset_generators_for_variable_base_batch_mul{
+            offset_generators.data() + fixed_base_points.size(), offset_generators.size() - fixed_base_points.size()
+        };
+        const auto [variable_accumulator, offset_generator_delta] =
+            _variable_base_batch_mul_internal(variable_base_scalars,
+                                              variable_base_points,
+                                              offset_generators_for_variable_base_batch_mul,
+                                              can_unconditional_add);
+        offset_accumulator += offset_generator_delta;
+        if (has_fixed_points) {
+            result = can_unconditional_add ? result.unconditional_add(variable_accumulator)
+                                           : result.checked_unconditional_add(variable_accumulator);
+        } else {
+            result = variable_accumulator;
+        }
+    }
+
+    // Update `result` to remove the offset generator terms, and add in any constant terms from `constant_acc`.
+    // We have two potential modes here:
+    // 1. All inputs are fixed-base and constant_acc is not the point at infinity
+    // 2. Everything else.
+    // Case 1 is a special case, as we *know* we cannot hit incomplete addition edge cases,
+    // under the assumption that all input points are linearly independent of one another.
+    // Because constant_acc is not the point at infinity we know that at least 1 input scalar was not zero,
+    // i.e. the output will not be the point at infinity. We also know under case 1, we won't trigger the
+    // doubling formula either, as every point is linearly independent of every other point (including offset
+    // generators).
+    if (!constant_acc.is_point_at_infinity() && can_unconditional_add) {
+        result = result.unconditional_add(AffineElement(-offset_accumulator));
+    } else {
+        // For case 2, we must use a full subtraction operation that handles all possible edge cases, as the output
+        // point may be the point at infinity.
+        // TODO(@zac-williamson) We can probably optimise this a bit actually. We might hit the point at infinity,
+        // but an honest prover won't trigger the doubling edge case.
+ // (doubling edge case implies input points are also the offset generator points, + // which we can assume an honest Prover will not do if we make this case produce unsatisfiable constraints) + // We could do the following: + // 1. If x-coords match, assert y-coords do not match + // 2. If x-coords match, return point at infinity, else return result - offset_accumulator. + // This would be slightly cheaper than operator- as we do not have to evaluate the double edge case. + result = result - AffineElement(offset_accumulator); + } + return result; +} + +template cycle_group cycle_group::operator*(const cycle_scalar& scalar) const +{ + return batch_mul({ scalar }, { *this }); +} + +template cycle_group& cycle_group::operator*=(const cycle_scalar& scalar) +{ + *this = operator*(scalar); + return *this; +} + +template cycle_group cycle_group::operator/(const cycle_group& /*unused*/) const +{ + // TODO(@kevaundray solve the discrete logarithm problem) + throw_or_abort("Implementation under construction..."); +} + +INSTANTIATE_STDLIB_TYPE(cycle_group); + +} // namespace proof_system::plonk::stdlib diff --git a/barretenberg/cpp/src/barretenberg/stdlib/primitives/group/cycle_group.hpp b/barretenberg/cpp/src/barretenberg/stdlib/primitives/group/cycle_group.hpp new file mode 100644 index 00000000000..c2f03df4105 --- /dev/null +++ b/barretenberg/cpp/src/barretenberg/stdlib/primitives/group/cycle_group.hpp @@ -0,0 +1,240 @@ +#pragma once + +#include "../field/field.hpp" +#include "barretenberg/crypto/pedersen_commitment/pedersen.hpp" +#include "barretenberg/ecc/curves/grumpkin/grumpkin.hpp" + +#include "../../hash/pedersen/pedersen.hpp" +#include "../../hash/pedersen/pedersen_gates.hpp" +#include "barretenberg/proof_system/plookup_tables/fixed_base/fixed_base_params.hpp" +#include + +namespace proof_system::plonk::stdlib { + +template +concept IsUltraArithmetic = (Composer::CIRCUIT_TYPE == CircuitType::ULTRA); +template +concept IsNotUltraArithmetic = (Composer::CIRCUIT_TYPE != CircuitType::ULTRA); + +/** + * @brief cycle_group represents a group Element of the proving system's embedded curve + * i.e. a curve with a cofactor 1 defined over a field equal to the circuit's native field Composer::FF + * + * (todo @zac-williamson) once the pedersen refactor project is finished, this class will supercede + * `stdlib::group` + * + * @tparam Composer + */ +template class cycle_group { + public: + using field_t = stdlib::field_t; + using bool_t = stdlib::bool_t; + using witness_t = stdlib::witness_t; + using FF = typename Composer::FF; + using Curve = typename Composer::EmbeddedCurve; + using Group = typename Curve::Group; + using Element = typename Curve::Element; + using AffineElement = typename Curve::AffineElement; + using generator_data = crypto::generator_data; + using ScalarField = typename Curve::ScalarField; + + static constexpr size_t STANDARD_NUM_TABLE_BITS = 1; + static constexpr size_t ULTRA_NUM_TABLE_BITS = 4; + static constexpr bool IS_ULTRA = Composer::CIRCUIT_TYPE == CircuitType::ULTRA; + static constexpr size_t TABLE_BITS = IS_ULTRA ? 
ULTRA_NUM_TABLE_BITS : STANDARD_NUM_TABLE_BITS;
+    static constexpr size_t NUM_BITS = ScalarField::modulus.get_msb() + 1;
+    static constexpr size_t NUM_ROUNDS = (NUM_BITS + TABLE_BITS - 1) / TABLE_BITS;
+    inline static const std::string OFFSET_GENERATOR_DOMAIN_SEPARATOR = "cycle_group_offset_generator";
+
+  private:
+    inline static const generator_data default_offset_generators =
+        generator_data(generator_data::DEFAULT_NUM_GENERATORS, OFFSET_GENERATOR_DOMAIN_SEPARATOR);
+
+  public:
+    /**
+     * @brief cycle_scalar represents a member of the cycle curve SCALAR FIELD.
+     * This is NOT the native circuit field type.
+     * i.e. for a BN254 circuit, cycle_group will be Grumpkin and cycle_scalar will be Grumpkin::ScalarField
+     * (BN254 native field is BN254::ScalarField == Grumpkin::BaseField)
+     *
+     * @details We convert scalar multiplication inputs into cycle_scalars to enable scalar multiplication to be
+     * *complete* i.e. Grumpkin points multiplied by BN254 scalars do not produce a cyclic group
+     * as BN254::ScalarField < Grumpkin::ScalarField
+     * This complexity *should* not leak outside the cycle_group / cycle_scalar implementations, as cycle_scalar
+     * performs all required conversions if the input scalars are stdlib::field_t elements
+     *
+     * @note We opted to create a new class to represent `cycle_scalar` instead of using `bigfield`,
+     * as `bigfield` is inefficient in this context. All required range checks for `cycle_scalar` can be obtained for
+     * free from the `batch_mul` algorithm, making the range checks performed by `bigfield` largely redundant.
+     */
+    struct cycle_scalar {
+        static constexpr size_t LO_BITS = plookup::FixedBaseParams::BITS_PER_LO_SCALAR;
+        static constexpr size_t HI_BITS = NUM_BITS - LO_BITS;
+        field_t lo;
+        field_t hi;
+
+      private:
+        size_t _num_bits = NUM_BITS;
+        bool _skip_primality_test = false;
+        // if our scalar multiplier is a bn254 FF scalar (e.g. pedersen hash),
+        // we want to validate the cycle_scalar < bn254::fr::modulus *not* grumpkin::fr::modulus
+        bool _use_bn254_scalar_field_for_primality_test = false;
+
+      public:
+        cycle_scalar(const field_t& _lo,
+                     const field_t& _hi,
+                     const size_t bits,
+                     const bool skip_primality_test,
+                     const bool use_bn254_scalar_field_for_primality_test)
+            : lo(_lo)
+            , hi(_hi)
+            , _num_bits(bits)
+            , _skip_primality_test(skip_primality_test)
+            , _use_bn254_scalar_field_for_primality_test(use_bn254_scalar_field_for_primality_test){};
+        cycle_scalar(const ScalarField& _in = 0);
+        cycle_scalar(const field_t& _lo, const field_t& _hi);
+        cycle_scalar(const field_t& _in);
+        static cycle_scalar from_witness(Composer* context, const ScalarField& value);
+        static cycle_scalar from_witness_bitstring(Composer* context, const uint256_t& bitstring, size_t num_bits);
+        static cycle_scalar create_from_bn254_scalar(const field_t& _in);
+        [[nodiscard]] bool is_constant() const;
+        ScalarField get_value() const;
+        Composer* get_context() const { return lo.get_context() != nullptr ? lo.get_context() : hi.get_context(); }
+        [[nodiscard]] size_t num_bits() const { return _num_bits; }
+        [[nodiscard]] bool skip_primality_test() const { return _skip_primality_test; }
+        [[nodiscard]] bool use_bn254_scalar_field_for_primality_test() const
+        {
+            return _use_bn254_scalar_field_for_primality_test;
+        }
+    };
+
+    /**
+     * @brief straus_scalar_slice decomposes an input scalar into `table_bits` bit-slices.
+     * Used in `batch_mul`, which uses the Straus multiscalar multiplication algorithm.
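+     *
+     * Editorial usage sketch (illustrative only, mirroring how batch_mul consumes this type; `context`, `scalar`
+     * and `round` are assumed to be in scope):
+     *
+     *     straus_scalar_slice slices(context, scalar, TABLE_BITS);
+     *     auto slice = slices.read(round); // std::nullopt if `scalar` has no bits left for this round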
+     *
+     */
+    struct straus_scalar_slice {
+        straus_scalar_slice(Composer* context, const cycle_scalar& scalars, size_t table_bits);
+        std::optional read(size_t index);
+        size_t _table_bits;
+        std::vector slices;
+    };
+
+    /**
+     * @brief straus_lookup_table computes a lookup table of size 1 << table_bits
+     *
+     * @details for an input base_point [P] and offset_generator point [G], where N = 1 << table_bits, the following is
+     * computed:
+     *
+     * { [G] + 0.[P], [G] + 1.[P], ..., [G] + (N - 1).[P] }
+     *
+     * The point [G] is used to ensure that we do not have to handle the point at infinity associated with 0.[P].
+     *
+     * For an HONEST Prover, the probability of [G] and [P] colliding is equivalent to solving the dlog problem.
+     * This allows us to partially ignore the incomplete addition formula edge-cases for short Weierstrass curves.
+     *
+     * When adding group elements in `batch_mul`, we can constrain+assert the x-coordinates of the operand points do not
+     * match. An honest prover will never trigger the case where x-coordinates match due to the above. Validating
+     * x-coordinates do not match is much cheaper than evaluating the full complete addition formulae for short
+     * Weierstrass curves.
+     *
+     * @note For the case of fixed-base scalar multiplication, all input points are defined at circuit compile time.
+     * We can ensure that all Provers cannot create point collisions between the base points and offset generators.
+     * For this restricted case we can skip the x-coordinate collision checks when performing group operations.
+     *
+     * @note straus_lookup_table uses UltraPlonk ROM tables if available. If not, we use simple conditional assignment
+     * constraints and restrict the table size to be 1 bit.
+     */
+    struct straus_lookup_table {
+      public:
+        straus_lookup_table() = default;
+        straus_lookup_table(Composer* context,
+                            const cycle_group& base_point,
+                            const cycle_group& offset_generator,
+                            size_t table_bits);
+        cycle_group read(const field_t& index);
+        size_t _table_bits;
+        Composer* _context;
+        std::vector point_table;
+        size_t rom_id = 0;
+    };
+
+  private:
+    /**
+     * @brief Stores temporary variables produced by internal multiplication algorithms
+     *
+     */
+    struct batch_mul_internal_output {
+        cycle_group accumulator;
+        AffineElement offset_generator_delta;
+    };
+
+  public:
+    cycle_group(Composer* _context = nullptr);
+    cycle_group(field_t _x, field_t _y, bool_t _is_infinity);
+    cycle_group(const FF& _x, const FF& _y, bool _is_infinity);
+    cycle_group(const AffineElement& _in);
+    static cycle_group from_witness(Composer* _context, const AffineElement& _in);
+    static cycle_group from_constant_witness(Composer* _context, const AffineElement& _in);
+    Composer* get_context(const cycle_group& other) const;
+    Composer* get_context() const { return context; }
+    AffineElement get_value() const;
+    [[nodiscard]] bool is_constant() const { return _is_constant; }
+    bool_t is_point_at_infinity() const { return _is_infinity; }
+    void set_point_at_infinity(const bool_t& is_infinity) { _is_infinity = is_infinity; }
+    void validate_is_on_curve() const;
+    cycle_group dbl() const
+        requires IsUltraArithmetic;
+    cycle_group dbl() const
+        requires IsNotUltraArithmetic;
+    cycle_group unconditional_add(const cycle_group& other) const
+        requires IsUltraArithmetic;
+    cycle_group unconditional_add(const cycle_group& other) const
+        requires IsNotUltraArithmetic;
+    cycle_group unconditional_subtract(const cycle_group& other) const;
+    cycle_group checked_unconditional_add(const cycle_group& other) const;
+    cycle_group
checked_unconditional_subtract(const cycle_group& other) const; + cycle_group operator+(const cycle_group& other) const; + cycle_group operator-(const cycle_group& other) const; + cycle_group operator-() const; + cycle_group& operator+=(const cycle_group& other); + cycle_group& operator-=(const cycle_group& other); + static cycle_group batch_mul(const std::vector& scalars, + const std::vector& base_points, + const generator_data* offset_generator_data = &default_offset_generators); + cycle_group operator*(const cycle_scalar& scalar) const; + cycle_group& operator*=(const cycle_scalar& scalar); + cycle_group operator/(const cycle_group& other) const; + + field_t x; + field_t y; + + private: + bool_t _is_infinity; + bool _is_constant; + Composer* context; + + static batch_mul_internal_output _variable_base_batch_mul_internal(std::span scalars, + std::span base_points, + std::span offset_generators, + bool unconditional_add); + + static batch_mul_internal_output _fixed_base_batch_mul_internal(std::span scalars, + std::span base_points, + std::span offset_generators) + requires IsUltraArithmetic; + static batch_mul_internal_output _fixed_base_batch_mul_internal(std::span scalars, + std::span base_points, + std::span offset_generators) + requires IsNotUltraArithmetic; +}; + +template +inline std::ostream& operator<<(std::ostream& os, cycle_group const& v) +{ + return os << v.get_value(); +} + +EXTERN_STDLIB_TYPE(cycle_group); + +} // namespace proof_system::plonk::stdlib diff --git a/barretenberg/cpp/src/barretenberg/stdlib/primitives/group/cycle_group.test.cpp b/barretenberg/cpp/src/barretenberg/stdlib/primitives/group/cycle_group.test.cpp new file mode 100644 index 00000000000..65b722699db --- /dev/null +++ b/barretenberg/cpp/src/barretenberg/stdlib/primitives/group/cycle_group.test.cpp @@ -0,0 +1,566 @@ +#include "barretenberg/stdlib/primitives/group/cycle_group.hpp" +#include "barretenberg/crypto/pedersen_commitment/pedersen_refactor.hpp" +#include "barretenberg/crypto/pedersen_hash/pedersen.hpp" +#include "barretenberg/numeric/random/engine.hpp" +#include "barretenberg/stdlib/primitives/field/field.hpp" +#include "barretenberg/stdlib/primitives/witness/witness.hpp" +#include + +#define STDLIB_TYPE_ALIASES \ + using Composer = TypeParam; \ + using cycle_group_ct = stdlib::cycle_group; \ + using Curve = typename stdlib::cycle_group::Curve; \ + using Element = typename Curve::Element; \ + using AffineElement = typename Curve::AffineElement; \ + using Group = typename Curve::Group; \ + using bool_ct = stdlib::bool_t; \ + using witness_ct = stdlib::witness_t; + +namespace stdlib_cycle_group_tests { +using namespace barretenberg; +using namespace proof_system::plonk; + +namespace { +auto& engine = numeric::random::get_debug_engine(); +} +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wunused-local-typedefs" + +template class CycleGroupTest : public ::testing::Test { + public: + using Curve = typename stdlib::cycle_group::Curve; + using Group = typename Curve::Group; + + using Element = typename Curve::Element; + using AffineElement = typename Curve::AffineElement; + + static constexpr size_t num_generators = 110; + static inline std::array generators{}; + + static void SetUpTestSuite() + { + + for (size_t i = 0; i < num_generators; ++i) { + generators[i] = Group::one * Curve::ScalarField::random_element(&engine); + } + }; +}; + +using CircuitTypes = ::testing::Types; +TYPED_TEST_SUITE(CycleGroupTest, CircuitTypes); + +TYPED_TEST(CycleGroupTest, TestDbl) +{ + STDLIB_TYPE_ALIASES; + 
auto composer = Composer(); + + auto lhs = TestFixture::generators[0]; + cycle_group_ct a = cycle_group_ct::from_witness(&composer, lhs); + cycle_group_ct c = a.dbl(); + AffineElement expected(Element(lhs).dbl()); + AffineElement result = c.get_value(); + EXPECT_EQ(result, expected); + + bool proof_result = composer.check_circuit(); + EXPECT_EQ(proof_result, true); +} + +TYPED_TEST(CycleGroupTest, TestUnconditionalAdd) +{ + STDLIB_TYPE_ALIASES; + auto composer = Composer(); + + auto add = + [&](const AffineElement& lhs, const AffineElement& rhs, const bool lhs_constant, const bool rhs_constant) { + cycle_group_ct a = lhs_constant ? cycle_group_ct(lhs) : cycle_group_ct::from_witness(&composer, lhs); + cycle_group_ct b = rhs_constant ? cycle_group_ct(rhs) : cycle_group_ct::from_witness(&composer, rhs); + cycle_group_ct c = a.unconditional_add(b); + AffineElement expected(Element(lhs) + Element(rhs)); + AffineElement result = c.get_value(); + EXPECT_EQ(result, expected); + }; + + add(TestFixture::generators[0], TestFixture::generators[1], false, false); + add(TestFixture::generators[0], TestFixture::generators[1], false, true); + add(TestFixture::generators[0], TestFixture::generators[1], true, false); + add(TestFixture::generators[0], TestFixture::generators[1], true, true); + + bool proof_result = composer.check_circuit(); + EXPECT_EQ(proof_result, true); +} + +TYPED_TEST(CycleGroupTest, TestConstrainedUnconditionalAddSucceed) +{ + STDLIB_TYPE_ALIASES; + auto composer = Composer(); + + auto lhs = TestFixture::generators[0]; + auto rhs = TestFixture::generators[1]; + + // case 1. valid unconditional add + cycle_group_ct a = cycle_group_ct::from_witness(&composer, lhs); + cycle_group_ct b = cycle_group_ct::from_witness(&composer, rhs); + cycle_group_ct c = a.checked_unconditional_add(b); + AffineElement expected(Element(lhs) + Element(rhs)); + AffineElement result = c.get_value(); + EXPECT_EQ(result, expected); + + bool proof_result = composer.check_circuit(); + EXPECT_EQ(proof_result, true); +} + +TYPED_TEST(CycleGroupTest, TestConstrainedUnconditionalAddFail) +{ + STDLIB_TYPE_ALIASES; + auto composer = Composer(); + + auto lhs = TestFixture::generators[0]; + auto rhs = -TestFixture::generators[0]; // ruh roh + + // case 2. invalid unconditional add + cycle_group_ct a = cycle_group_ct::from_witness(&composer, lhs); + cycle_group_ct b = cycle_group_ct::from_witness(&composer, rhs); + a.checked_unconditional_add(b); + + EXPECT_TRUE(composer.failed()); + + bool proof_result = composer.check_circuit(); + EXPECT_EQ(proof_result, false); +} + +TYPED_TEST(CycleGroupTest, TestAdd) +{ + STDLIB_TYPE_ALIASES; + auto composer = Composer(); + + auto lhs = TestFixture::generators[0]; + auto rhs = -TestFixture::generators[1]; + + cycle_group_ct point_at_infinity = cycle_group_ct::from_witness(&composer, rhs); + point_at_infinity.set_point_at_infinity(bool_ct(witness_ct(&composer, true))); + + // case 1. no edge-cases triggered + { + cycle_group_ct a = cycle_group_ct::from_witness(&composer, lhs); + cycle_group_ct b = cycle_group_ct::from_witness(&composer, rhs); + cycle_group_ct c = a + b; + AffineElement expected(Element(lhs) + Element(rhs)); + AffineElement result = c.get_value(); + EXPECT_EQ(result, expected); + } + + // case 2. lhs is point at infinity + { + cycle_group_ct a = point_at_infinity; + cycle_group_ct b = cycle_group_ct::from_witness(&composer, rhs); + cycle_group_ct c = a + b; + AffineElement result = c.get_value(); + EXPECT_EQ(result, rhs); + } + + // case 3. 
rhs is point at infinity + { + cycle_group_ct a = cycle_group_ct::from_witness(&composer, lhs); + cycle_group_ct b = point_at_infinity; + cycle_group_ct c = a + b; + AffineElement result = c.get_value(); + EXPECT_EQ(result, lhs); + } + + // case 4. both points are at infinity + { + cycle_group_ct a = point_at_infinity; + cycle_group_ct b = point_at_infinity; + cycle_group_ct c = a + b; + EXPECT_TRUE(c.is_point_at_infinity().get_value()); + EXPECT_TRUE(c.get_value().is_point_at_infinity()); + } + + // case 5. lhs = -rhs + { + cycle_group_ct a = cycle_group_ct::from_witness(&composer, lhs); + cycle_group_ct b = cycle_group_ct::from_witness(&composer, -lhs); + cycle_group_ct c = a + b; + EXPECT_TRUE(c.is_point_at_infinity().get_value()); + EXPECT_TRUE(c.get_value().is_point_at_infinity()); + } + + // case 6. lhs = rhs + { + cycle_group_ct a = cycle_group_ct::from_witness(&composer, lhs); + cycle_group_ct b = cycle_group_ct::from_witness(&composer, lhs); + cycle_group_ct c = a + b; + AffineElement expected((Element(lhs)).dbl()); + AffineElement result = c.get_value(); + EXPECT_EQ(result, expected); + } + + bool proof_result = composer.check_circuit(); + EXPECT_EQ(proof_result, true); +} + +TYPED_TEST(CycleGroupTest, TestUnconditionalSubtract) +{ + STDLIB_TYPE_ALIASES; + auto composer = Composer(); + + auto add = + [&](const AffineElement& lhs, const AffineElement& rhs, const bool lhs_constant, const bool rhs_constant) { + cycle_group_ct a = lhs_constant ? cycle_group_ct(lhs) : cycle_group_ct::from_witness(&composer, lhs); + cycle_group_ct b = rhs_constant ? cycle_group_ct(rhs) : cycle_group_ct::from_witness(&composer, rhs); + cycle_group_ct c = a.unconditional_subtract(b); + AffineElement expected(Element(lhs) - Element(rhs)); + AffineElement result = c.get_value(); + EXPECT_EQ(result, expected); + }; + + add(TestFixture::generators[0], TestFixture::generators[1], false, false); + add(TestFixture::generators[0], TestFixture::generators[1], false, true); + add(TestFixture::generators[0], TestFixture::generators[1], true, false); + add(TestFixture::generators[0], TestFixture::generators[1], true, true); + + bool proof_result = composer.check_circuit(); + EXPECT_EQ(proof_result, true); +} + +TYPED_TEST(CycleGroupTest, TestConstrainedUnconditionalSubtractSucceed) +{ + STDLIB_TYPE_ALIASES; + auto composer = Composer(); + + auto lhs = TestFixture::generators[0]; + auto rhs = TestFixture::generators[1]; + + // case 1. valid unconditional add + cycle_group_ct a = cycle_group_ct::from_witness(&composer, lhs); + cycle_group_ct b = cycle_group_ct::from_witness(&composer, rhs); + cycle_group_ct c = a.checked_unconditional_subtract(b); + AffineElement expected(Element(lhs) - Element(rhs)); + AffineElement result = c.get_value(); + EXPECT_EQ(result, expected); + + bool proof_result = composer.check_circuit(); + EXPECT_EQ(proof_result, true); +} + +TYPED_TEST(CycleGroupTest, TestConstrainedUnconditionalSubtractFail) +{ + STDLIB_TYPE_ALIASES; + auto composer = Composer(); + + auto lhs = TestFixture::generators[0]; + auto rhs = -TestFixture::generators[0]; // ruh roh + + // case 2. 
invalid unconditional add + cycle_group_ct a = cycle_group_ct::from_witness(&composer, lhs); + cycle_group_ct b = cycle_group_ct::from_witness(&composer, rhs); + a.checked_unconditional_subtract(b); + + EXPECT_TRUE(composer.failed()); + + bool proof_result = composer.check_circuit(); + EXPECT_EQ(proof_result, false); +} + +TYPED_TEST(CycleGroupTest, TestSubtract) +{ + STDLIB_TYPE_ALIASES; + using bool_ct = stdlib::bool_t; + using witness_ct = stdlib::witness_t; + auto composer = Composer(); + + auto lhs = TestFixture::generators[0]; + auto rhs = -TestFixture::generators[1]; + + cycle_group_ct point_at_infinity = cycle_group_ct::from_witness(&composer, rhs); + point_at_infinity.set_point_at_infinity(bool_ct(witness_ct(&composer, true))); + + // case 1. no edge-cases triggered + { + cycle_group_ct a = cycle_group_ct::from_witness(&composer, lhs); + cycle_group_ct b = cycle_group_ct::from_witness(&composer, rhs); + cycle_group_ct c = a - b; + AffineElement expected(Element(lhs) - Element(rhs)); + AffineElement result = c.get_value(); + EXPECT_EQ(result, expected); + } + + // case 2. lhs is point at infinity + { + cycle_group_ct a = point_at_infinity; + cycle_group_ct b = cycle_group_ct::from_witness(&composer, rhs); + cycle_group_ct c = a - b; + AffineElement result = c.get_value(); + EXPECT_EQ(result, -rhs); + } + + // case 3. rhs is point at infinity + { + cycle_group_ct a = cycle_group_ct::from_witness(&composer, lhs); + cycle_group_ct b = point_at_infinity; + cycle_group_ct c = a - b; + AffineElement result = c.get_value(); + EXPECT_EQ(result, lhs); + } + + // case 4. both points are at infinity + { + cycle_group_ct a = point_at_infinity; + cycle_group_ct b = point_at_infinity; + cycle_group_ct c = a - b; + EXPECT_TRUE(c.is_point_at_infinity().get_value()); + EXPECT_TRUE(c.get_value().is_point_at_infinity()); + } + + // case 5. lhs = -rhs + { + cycle_group_ct a = cycle_group_ct::from_witness(&composer, lhs); + cycle_group_ct b = cycle_group_ct::from_witness(&composer, -lhs); + cycle_group_ct c = a - b; + AffineElement expected((Element(lhs)).dbl()); + AffineElement result = c.get_value(); + EXPECT_EQ(result, expected); + } + + // case 6. 
lhs = rhs + { + cycle_group_ct a = cycle_group_ct::from_witness(&composer, lhs); + cycle_group_ct b = cycle_group_ct::from_witness(&composer, lhs); + cycle_group_ct c = a - b; + EXPECT_TRUE(c.is_point_at_infinity().get_value()); + EXPECT_TRUE(c.get_value().is_point_at_infinity()); + } + + bool proof_result = composer.check_circuit(); + EXPECT_EQ(proof_result, true); +} + +TYPED_TEST(CycleGroupTest, TestBatchMul) +{ + STDLIB_TYPE_ALIASES; + auto composer = Composer(); + + const size_t num_muls = 1; + + // case 1, general MSM with inputs that are combinations of constant and witnesses + { + std::vector points; + std::vector scalars; + Element expected = Group::point_at_infinity; + + for (size_t i = 0; i < num_muls; ++i) { + auto element = TestFixture::generators[i]; + typename Group::subgroup_field scalar = Group::subgroup_field::random_element(&engine); + + // 1: add entry where point, scalar are witnesses + expected += (element * scalar); + points.emplace_back(cycle_group_ct::from_witness(&composer, element)); + scalars.emplace_back(cycle_group_ct::cycle_scalar::from_witness(&composer, scalar)); + + // 2: add entry where point is constant, scalar is witness + expected += (element * scalar); + points.emplace_back(cycle_group_ct(element)); + scalars.emplace_back(cycle_group_ct::cycle_scalar::from_witness(&composer, scalar)); + + // 3: add entry where point is witness, scalar is constant + expected += (element * scalar); + points.emplace_back(cycle_group_ct::from_witness(&composer, element)); + scalars.emplace_back(typename cycle_group_ct::cycle_scalar(scalar)); + + // 4: add entry where point is constant, scalar is constant + expected += (element * scalar); + points.emplace_back(cycle_group_ct(element)); + scalars.emplace_back(typename cycle_group_ct::cycle_scalar(scalar)); + } + auto result = cycle_group_ct::batch_mul(scalars, points); + EXPECT_EQ(result.get_value(), AffineElement(expected)); + } + + // case 2, MSM that produces point at infinity + { + std::vector points; + std::vector scalars; + + auto element = TestFixture::generators[0]; + typename Group::subgroup_field scalar = Group::subgroup_field::random_element(&engine); + points.emplace_back(cycle_group_ct::from_witness(&composer, element)); + scalars.emplace_back(cycle_group_ct::cycle_scalar::from_witness(&composer, scalar)); + + points.emplace_back(cycle_group_ct::from_witness(&composer, element)); + scalars.emplace_back(cycle_group_ct::cycle_scalar::from_witness(&composer, -scalar)); + + auto result = cycle_group_ct::batch_mul(scalars, points); + EXPECT_TRUE(result.is_point_at_infinity().get_value()); + } + + // case 3. Multiply by zero + { + std::vector points; + std::vector scalars; + + auto element = TestFixture::generators[0]; + typename Group::subgroup_field scalar = 0; + points.emplace_back(cycle_group_ct::from_witness(&composer, element)); + scalars.emplace_back(cycle_group_ct::cycle_scalar::from_witness(&composer, scalar)); + auto result = cycle_group_ct::batch_mul(scalars, points); + EXPECT_TRUE(result.is_point_at_infinity().get_value()); + } + + // case 4. 
+    // case 4. Inputs are points at infinity
+    {
+        std::vector<cycle_group_ct> points;
+        std::vector<typename cycle_group_ct::cycle_scalar> scalars;
+
+        auto element = TestFixture::generators[0];
+        typename Group::subgroup_field scalar = Group::subgroup_field::random_element(&engine);
+
+        // is_infinity = witness
+        {
+            cycle_group_ct point = cycle_group_ct::from_witness(&composer, element);
+            point.set_point_at_infinity(witness_ct(&composer, true));
+            points.emplace_back(point);
+            scalars.emplace_back(cycle_group_ct::cycle_scalar::from_witness(&composer, scalar));
+        }
+        // is_infinity = constant
+        {
+            cycle_group_ct point = cycle_group_ct::from_witness(&composer, element);
+            point.set_point_at_infinity(true);
+            points.emplace_back(point);
+            scalars.emplace_back(cycle_group_ct::cycle_scalar::from_witness(&composer, scalar));
+        }
+        auto result = cycle_group_ct::batch_mul(scalars, points);
+        EXPECT_TRUE(result.is_point_at_infinity().get_value());
+    }
+
+    // case 5, fixed-base MSM with inputs that are combinations of constants and witnesses (group elements are in
+    // lookup table)
+    {
+        std::vector<cycle_group_ct> points;
+        std::vector<typename cycle_group_ct::cycle_scalar> scalars;
+        std::vector<uint256_t> scalars_native;
+        Element expected = Group::point_at_infinity;
+        for (size_t i = 0; i < num_muls; ++i) {
+            auto element = crypto::pedersen_hash_refactor::get_lhs_generator();
+            typename Group::subgroup_field scalar = Group::subgroup_field::random_element(&engine);
+
+            // 1: add entry where point is constant, scalar is witness
+            expected += (element * scalar);
+            points.emplace_back(element);
+            scalars.emplace_back(cycle_group_ct::cycle_scalar::from_witness(&composer, scalar));
+            scalars_native.emplace_back(uint256_t(scalar));
+
+            // 2: add entry where point is constant, scalar is constant
+            element = crypto::pedersen_hash_refactor::get_rhs_generator();
+            expected += (element * scalar);
+            points.emplace_back(element);
+            scalars.emplace_back(typename cycle_group_ct::cycle_scalar(scalar));
+            scalars_native.emplace_back(uint256_t(scalar));
+        }
+        auto result = cycle_group_ct::batch_mul(scalars, points);
+        EXPECT_EQ(result.get_value(), AffineElement(expected));
+        EXPECT_EQ(result.get_value(), crypto::pedersen_commitment_refactor::commit_native(scalars_native));
+    }
+
+    // case 6, fixed-base MSM with inputs that are combinations of constants and witnesses (some group elements are
+    // in the lookup table)
+    {
+        std::vector<cycle_group_ct> points;
+        std::vector<typename cycle_group_ct::cycle_scalar> scalars;
+        std::vector<uint256_t> scalars_native;
+        Element expected = Group::point_at_infinity;
+        for (size_t i = 0; i < num_muls; ++i) {
+            auto element = crypto::pedersen_hash_refactor::get_lhs_generator();
+            typename Group::subgroup_field scalar = Group::subgroup_field::random_element(&engine);
+
+            // 1: add entry where point is constant, scalar is witness
+            expected += (element * scalar);
+            points.emplace_back(element);
+            scalars.emplace_back(cycle_group_ct::cycle_scalar::from_witness(&composer, scalar));
+            scalars_native.emplace_back(scalar);
+
+            // 2: add entry where point is constant, scalar is constant
+            element = crypto::pedersen_hash_refactor::get_rhs_generator();
+            expected += (element * scalar);
+            points.emplace_back(element);
+            scalars.emplace_back(typename cycle_group_ct::cycle_scalar(scalar));
+            scalars_native.emplace_back(scalar);
+
+            // 3: add entry where point is constant, scalar is witness
+            scalar = Group::subgroup_field::random_element(&engine);
+            element = Group::one * Group::subgroup_field::random_element(&engine);
+            expected += (element * scalar);
+            points.emplace_back(element);
+            scalars.emplace_back(cycle_group_ct::cycle_scalar::from_witness(&composer, scalar));
+            scalars_native.emplace_back(scalar);
+        }
+        auto result = cycle_group_ct::batch_mul(scalars, points);
+        EXPECT_EQ(result.get_value(), AffineElement(expected));
+    }
+
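+    // NOTE: case 6 above mixes generators that have precomputed lookup tables with an arbitrary
+    // point that does not, so it presumably forces batch_mul to combine its fixed-base and
+    // variable-base strategies within a single call. Case 7 below then checks the remaining
+    // fixed-base edge case: every input scalar is zero, so the output should again be the point
+    // at infinity.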
+    // case 7, fixed-base MSM where input scalars are 0
+    {
+        std::vector<cycle_group_ct> points;
+        std::vector<typename cycle_group_ct::cycle_scalar> scalars;
+
+        for (size_t i = 0; i < num_muls; ++i) {
+            auto element = crypto::pedersen_hash_refactor::get_lhs_generator();
+            typename Group::subgroup_field scalar = 0;
+
+            // 1: add entry where point is constant, scalar is witness
+            points.emplace_back(element);
+            scalars.emplace_back(cycle_group_ct::cycle_scalar::from_witness(&composer, scalar));
+
+            // 2: add entry where point is constant, scalar is constant
+            points.emplace_back(element);
+            scalars.emplace_back(typename cycle_group_ct::cycle_scalar(scalar));
+        }
+        auto result = cycle_group_ct::batch_mul(scalars, points);
+        EXPECT_EQ(result.is_point_at_infinity().get_value(), true);
+    }
+
+    bool proof_result = composer.check_circuit();
+    EXPECT_EQ(proof_result, true);
+}
+
+TYPED_TEST(CycleGroupTest, TestMul)
+{
+    STDLIB_TYPE_ALIASES;
+    auto composer = Composer();
+
+    const size_t num_muls = 5;
+
+    // case 1, single scalar multiplications with inputs that are combinations of constants and witnesses
+    {
+        cycle_group_ct point;
+        typename cycle_group_ct::cycle_scalar scalar;
+        for (size_t i = 0; i < num_muls; ++i) {
+            auto element = TestFixture::generators[i];
+            typename Group::subgroup_field native_scalar = Group::subgroup_field::random_element(&engine);
+
+            // 1: point and scalar are both witnesses
+            point = cycle_group_ct::from_witness(&composer, element);
+            scalar = cycle_group_ct::cycle_scalar::from_witness(&composer, native_scalar);
+            EXPECT_EQ((point * scalar).get_value(), (element * native_scalar));
+
+            // 2: point is constant, scalar is witness
+            point = cycle_group_ct(element);
+            scalar = cycle_group_ct::cycle_scalar::from_witness(&composer, native_scalar);
+            EXPECT_EQ((point * scalar).get_value(), (element * native_scalar));
+
+            // 3: point is witness, scalar is constant
+            point = cycle_group_ct::from_witness(&composer, element);
+            EXPECT_EQ((point * scalar).get_value(), (element * native_scalar));
+
+            // 4: point is constant, scalar is constant
+            point = cycle_group_ct(element);
+            EXPECT_EQ((point * scalar).get_value(), (element * native_scalar));
+        }
+    }
+
+    bool proof_result = composer.check_circuit();
+    EXPECT_EQ(proof_result, true);
+}
+#pragma GCC diagnostic pop
+
+} // namespace stdlib_cycle_group_tests
diff --git a/barretenberg/cpp/src/barretenberg/stdlib/primitives/group/group.hpp b/barretenberg/cpp/src/barretenberg/stdlib/primitives/group/group.hpp
index df7c2f9aa1e..f7a6f8f8d86 100644
--- a/barretenberg/cpp/src/barretenberg/stdlib/primitives/group/group.hpp
+++ b/barretenberg/cpp/src/barretenberg/stdlib/primitives/group/group.hpp
@@ -1,5 +1,6 @@
 #pragma once
+// TODO(@zac-williamson #2341 delete this file and rename cycle_group to group once we migrate to new hash standard)
 #include "../field/field.hpp"
 #include "barretenberg/crypto/pedersen_commitment/pedersen.hpp"
 #include "barretenberg/ecc/curves/grumpkin/grumpkin.hpp"
diff --git a/barretenberg/cpp/src/barretenberg/stdlib/primitives/group/group.test.cpp b/barretenberg/cpp/src/barretenberg/stdlib/primitives/group/group.test.cpp
index 752f4f3b80e..a14fe241ef9 100644
--- a/barretenberg/cpp/src/barretenberg/stdlib/primitives/group/group.test.cpp
+++ b/barretenberg/cpp/src/barretenberg/stdlib/primitives/group/group.test.cpp
@@ -1,3 +1,5 @@
+// TODO(@zac-williamson #2341 delete this file once we migrate to the new hash standard)
+
 #include "barretenberg/stdlib/primitives/group/group.hpp"
 #include "barretenberg/numeric/random/engine.hpp"
 #include "barretenberg/stdlib/primitives/field/field.hpp"