Skip to content

Commit

Permalink
feat!: constant inputs for blackbox (#7222)
Browse files Browse the repository at this point in the history
This PR allows to use constant values for blackbox inputs.
Only MultiScalarMul is currently handling constant input, so it will
fail if constant inputs are used for any other blackboxes. Noir does
ensure that other blackboxes functions do not use constant inputs in
this PR.
I will make a follow-up PR once this one is merged to have more blackbox
functions using constant inputs.

---------

Co-authored-by: TomAFrench <[email protected]>
Co-authored-by: Tom French <[email protected]>
  • Loading branch information
3 people authored Jul 8, 2024
1 parent 819f370 commit 9f9ded2
Show file tree
Hide file tree
Showing 30 changed files with 924 additions and 350 deletions.

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -129,8 +129,56 @@ TEST_F(EcOperations, TestECMultiScalarMul)
fr(0),
};
msm_constrain = MultiScalarMul{
.points = { 1, 2, 3, 1, 2, 3 },
.scalars = { 4, 5, 4, 5 },
.points = { WitnessConstant<fr>{
.index = 1,
.value = fr(0),
.is_constant = false,
},
WitnessConstant<fr>{
.index = 2,
.value = fr(0),
.is_constant = false,
},
WitnessConstant<fr>{
.index = 3,
.value = fr(0),
.is_constant = false,
},
WitnessConstant<fr>{
.index = 1,
.value = fr(0),
.is_constant = false,
},
WitnessConstant<fr>{
.index = 2,
.value = fr(0),
.is_constant = false,
},
WitnessConstant<fr>{
.index = 3,
.value = fr(0),
.is_constant = false,
} },
.scalars = { WitnessConstant<fr>{
.index = 4,
.value = fr(0),
.is_constant = false,
},
WitnessConstant<fr>{
.index = 5,
.value = fr(0),
.is_constant = false,
},
WitnessConstant<fr>{
.index = 4,
.value = fr(0),
.is_constant = false,
},
WitnessConstant<fr>{
.index = 5,
.value = fr(0),
.is_constant = false,
} },
.out_point_x = 6,
.out_point_y = 7,
.out_point_is_infinite = 0,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include "multi_scalar_mul.hpp"
#include "barretenberg/dsl/acir_format/serde/acir.hpp"
#include "barretenberg/ecc/curves/bn254/fr.hpp"
#include "barretenberg/ecc/curves/grumpkin/grumpkin.hpp"
#include "barretenberg/plonk_honk_shared/arithmetization/gate_data.hpp"
Expand All @@ -21,27 +22,62 @@ template <typename Builder> void create_multi_scalar_mul_constraint(Builder& bui

for (size_t i = 0; i < input.points.size(); i += 3) {
// Instantiate the input point/variable base as `cycle_group_ct`
auto point_x = field_ct::from_witness_index(&builder, input.points[i]);
auto point_y = field_ct::from_witness_index(&builder, input.points[i + 1]);
auto infinite = bool_ct(field_ct::from_witness_index(&builder, input.points[i + 2]));
field_ct point_x;
field_ct point_y;
bool_ct infinite;
if (input.points[i].is_constant) {
point_x = field_ct(input.points[i].value);
} else {
point_x = field_ct::from_witness_index(&builder, input.points[i].index);
}
if (input.points[i + 1].is_constant) {
point_y = field_ct(input.points[i + 1].value);
} else {
point_y = field_ct::from_witness_index(&builder, input.points[i + 1].index);
}
if (input.points[i + 2].is_constant) {
infinite = bool_ct(field_ct(input.points[i + 2].value));
} else {
infinite = bool_ct(field_ct::from_witness_index(&builder, input.points[i + 2].index));
}
cycle_group_ct input_point(point_x, point_y, infinite);
// Reconstruct the scalar from the low and high limbs
field_ct scalar_low_as_field = field_ct::from_witness_index(&builder, input.scalars[2 * (i / 3)]);
field_ct scalar_high_as_field = field_ct::from_witness_index(&builder, input.scalars[2 * (i / 3) + 1]);
field_ct scalar_low_as_field;
field_ct scalar_high_as_field;
if (input.scalars[2 * (i / 3)].is_constant) {
scalar_low_as_field = field_ct(input.scalars[2 * (i / 3)].value);
} else {
scalar_low_as_field = field_ct::from_witness_index(&builder, input.scalars[2 * (i / 3)].index);
}
if (input.scalars[2 * (i / 3) + 1].is_constant) {
scalar_high_as_field = field_ct(input.scalars[2 * (i / 3) + 1].value);
} else {
scalar_high_as_field = field_ct::from_witness_index(&builder, input.scalars[2 * (i / 3) + 1].index);
}
cycle_scalar_ct scalar(scalar_low_as_field, scalar_high_as_field);

// Add the point and scalar to the vectors
points.push_back(input_point);
scalars.push_back(scalar);
}

// Call batch_mul to multiply the points and scalars and sum the results
auto output_point = cycle_group_ct::batch_mul(points, scalars).get_standard_form();

// Add the constraints
builder.assert_equal(output_point.x.get_witness_index(), input.out_point_x);
builder.assert_equal(output_point.y.get_witness_index(), input.out_point_y);
builder.assert_equal(output_point.is_point_at_infinity().witness_index, input.out_point_is_infinite);
// Add the constraints and handle constant values
if (output_point.is_point_at_infinity().is_constant()) {
builder.fix_witness(input.out_point_is_infinite, output_point.is_point_at_infinity().get_value());
} else {
builder.assert_equal(output_point.is_point_at_infinity().witness_index, input.out_point_is_infinite);
}
if (output_point.x.is_constant()) {
builder.fix_witness(input.out_point_x, output_point.x.get_value());
} else {
builder.assert_equal(output_point.x.get_witness_index(), input.out_point_x);
}
if (output_point.y.is_constant()) {
builder.fix_witness(input.out_point_y, output_point.y.get_value());
} else {
builder.assert_equal(output_point.y.get_witness_index(), input.out_point_y);
}
}

template void create_multi_scalar_mul_constraint<UltraCircuitBuilder>(UltraCircuitBuilder& builder,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,23 @@
#pragma once
#include "barretenberg/serialize/msgpack.hpp"
#include "barretenberg/stdlib/primitives/field/field.hpp"
#include "serde/index.hpp"
#include <cstdint>
#include <vector>

namespace acir_format {

template <typename FF> struct WitnessConstant {
uint32_t index;
FF value;
bool is_constant;
MSGPACK_FIELDS(index, value, is_constant);
friend bool operator==(WitnessConstant const& lhs, WitnessConstant const& rhs) = default;
};

struct MultiScalarMul {
std::vector<uint32_t> points;
std::vector<uint32_t> scalars;
std::vector<WitnessConstant<bb::fr>> points;
std::vector<WitnessConstant<bb::fr>> scalars;

uint32_t out_point_x;
uint32_t out_point_y;
Expand Down
178 changes: 174 additions & 4 deletions barretenberg/cpp/src/barretenberg/dsl/acir_format/serde/acir.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,33 @@ struct Witness {
static Witness bincodeDeserialize(std::vector<uint8_t>);
};

struct ConstantOrWitnessEnum {

struct Constant {
std::string value;

friend bool operator==(const Constant&, const Constant&);
std::vector<uint8_t> bincodeSerialize() const;
static Constant bincodeDeserialize(std::vector<uint8_t>);
};

struct Witness {
Program::Witness value;

friend bool operator==(const Witness&, const Witness&);
std::vector<uint8_t> bincodeSerialize() const;
static Witness bincodeDeserialize(std::vector<uint8_t>);
};

std::variant<Constant, Witness> value;

friend bool operator==(const ConstantOrWitnessEnum&, const ConstantOrWitnessEnum&);
std::vector<uint8_t> bincodeSerialize() const;
static ConstantOrWitnessEnum bincodeDeserialize(std::vector<uint8_t>);
};

struct FunctionInput {
Program::Witness witness;
Program::ConstantOrWitnessEnum input;
uint32_t num_bits;

friend bool operator==(const FunctionInput&, const FunctionInput&);
Expand Down Expand Up @@ -6911,6 +6936,151 @@ Program::Circuit serde::Deserializable<Program::Circuit>::deserialize(Deserializ

namespace Program {

inline bool operator==(const ConstantOrWitnessEnum& lhs, const ConstantOrWitnessEnum& rhs)
{
if (!(lhs.value == rhs.value)) {
return false;
}
return true;
}

inline std::vector<uint8_t> ConstantOrWitnessEnum::bincodeSerialize() const
{
auto serializer = serde::BincodeSerializer();
serde::Serializable<ConstantOrWitnessEnum>::serialize(*this, serializer);
return std::move(serializer).bytes();
}

inline ConstantOrWitnessEnum ConstantOrWitnessEnum::bincodeDeserialize(std::vector<uint8_t> input)
{
auto deserializer = serde::BincodeDeserializer(input);
auto value = serde::Deserializable<ConstantOrWitnessEnum>::deserialize(deserializer);
if (deserializer.get_buffer_offset() < input.size()) {
throw_or_abort("Some input bytes were not read");
}
return value;
}

} // end of namespace Program

template <>
template <typename Serializer>
void serde::Serializable<Program::ConstantOrWitnessEnum>::serialize(const Program::ConstantOrWitnessEnum& obj,
Serializer& serializer)
{
serializer.increase_container_depth();
serde::Serializable<decltype(obj.value)>::serialize(obj.value, serializer);
serializer.decrease_container_depth();
}

template <>
template <typename Deserializer>
Program::ConstantOrWitnessEnum serde::Deserializable<Program::ConstantOrWitnessEnum>::deserialize(
Deserializer& deserializer)
{
deserializer.increase_container_depth();
Program::ConstantOrWitnessEnum obj;
obj.value = serde::Deserializable<decltype(obj.value)>::deserialize(deserializer);
deserializer.decrease_container_depth();
return obj;
}

namespace Program {

inline bool operator==(const ConstantOrWitnessEnum::Constant& lhs, const ConstantOrWitnessEnum::Constant& rhs)
{
if (!(lhs.value == rhs.value)) {
return false;
}
return true;
}

inline std::vector<uint8_t> ConstantOrWitnessEnum::Constant::bincodeSerialize() const
{
auto serializer = serde::BincodeSerializer();
serde::Serializable<ConstantOrWitnessEnum::Constant>::serialize(*this, serializer);
return std::move(serializer).bytes();
}

inline ConstantOrWitnessEnum::Constant ConstantOrWitnessEnum::Constant::bincodeDeserialize(std::vector<uint8_t> input)
{
auto deserializer = serde::BincodeDeserializer(input);
auto value = serde::Deserializable<ConstantOrWitnessEnum::Constant>::deserialize(deserializer);
if (deserializer.get_buffer_offset() < input.size()) {
throw_or_abort("Some input bytes were not read");
}
return value;
}

} // end of namespace Program

template <>
template <typename Serializer>
void serde::Serializable<Program::ConstantOrWitnessEnum::Constant>::serialize(
const Program::ConstantOrWitnessEnum::Constant& obj, Serializer& serializer)
{
serde::Serializable<decltype(obj.value)>::serialize(obj.value, serializer);
}

template <>
template <typename Deserializer>
Program::ConstantOrWitnessEnum::Constant serde::Deserializable<Program::ConstantOrWitnessEnum::Constant>::deserialize(
Deserializer& deserializer)
{
Program::ConstantOrWitnessEnum::Constant obj;
obj.value = serde::Deserializable<decltype(obj.value)>::deserialize(deserializer);
return obj;
}

namespace Program {

inline bool operator==(const ConstantOrWitnessEnum::Witness& lhs, const ConstantOrWitnessEnum::Witness& rhs)
{
if (!(lhs.value == rhs.value)) {
return false;
}
return true;
}

inline std::vector<uint8_t> ConstantOrWitnessEnum::Witness::bincodeSerialize() const
{
auto serializer = serde::BincodeSerializer();
serde::Serializable<ConstantOrWitnessEnum::Witness>::serialize(*this, serializer);
return std::move(serializer).bytes();
}

inline ConstantOrWitnessEnum::Witness ConstantOrWitnessEnum::Witness::bincodeDeserialize(std::vector<uint8_t> input)
{
auto deserializer = serde::BincodeDeserializer(input);
auto value = serde::Deserializable<ConstantOrWitnessEnum::Witness>::deserialize(deserializer);
if (deserializer.get_buffer_offset() < input.size()) {
throw_or_abort("Some input bytes were not read");
}
return value;
}

} // end of namespace Program

template <>
template <typename Serializer>
void serde::Serializable<Program::ConstantOrWitnessEnum::Witness>::serialize(
const Program::ConstantOrWitnessEnum::Witness& obj, Serializer& serializer)
{
serde::Serializable<decltype(obj.value)>::serialize(obj.value, serializer);
}

template <>
template <typename Deserializer>
Program::ConstantOrWitnessEnum::Witness serde::Deserializable<Program::ConstantOrWitnessEnum::Witness>::deserialize(
Deserializer& deserializer)
{
Program::ConstantOrWitnessEnum::Witness obj;
obj.value = serde::Deserializable<decltype(obj.value)>::deserialize(deserializer);
return obj;
}

namespace Program {

inline bool operator==(const Directive& lhs, const Directive& rhs)
{
if (!(lhs.value == rhs.value)) {
Expand Down Expand Up @@ -7360,7 +7530,7 @@ namespace Program {

inline bool operator==(const FunctionInput& lhs, const FunctionInput& rhs)
{
if (!(lhs.witness == rhs.witness)) {
if (!(lhs.input == rhs.input)) {
return false;
}
if (!(lhs.num_bits == rhs.num_bits)) {
Expand Down Expand Up @@ -7393,7 +7563,7 @@ template <typename Serializer>
void serde::Serializable<Program::FunctionInput>::serialize(const Program::FunctionInput& obj, Serializer& serializer)
{
serializer.increase_container_depth();
serde::Serializable<decltype(obj.witness)>::serialize(obj.witness, serializer);
serde::Serializable<decltype(obj.input)>::serialize(obj.input, serializer);
serde::Serializable<decltype(obj.num_bits)>::serialize(obj.num_bits, serializer);
serializer.decrease_container_depth();
}
Expand All @@ -7404,7 +7574,7 @@ Program::FunctionInput serde::Deserializable<Program::FunctionInput>::deserializ
{
deserializer.increase_container_depth();
Program::FunctionInput obj;
obj.witness = serde::Deserializable<decltype(obj.witness)>::deserialize(deserializer);
obj.input = serde::Deserializable<decltype(obj.input)>::deserialize(deserializer);
obj.num_bits = serde::Deserializable<decltype(obj.num_bits)>::deserialize(deserializer);
deserializer.decrease_container_depth();
return obj;
Expand Down
Loading

0 comments on commit 9f9ded2

Please sign in to comment.