Skip to content

Commit

Permalink
Fixed hash decomposition component.
Browse files Browse the repository at this point in the history
  • Loading branch information
Iluvmagick committed Nov 16, 2023
1 parent 93b6463 commit 158192b
Show file tree
Hide file tree
Showing 6 changed files with 176 additions and 188 deletions.
201 changes: 130 additions & 71 deletions include/nil/blueprint/components/hashes/sha2/plonk/decomposition.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
// Copyright (c) 2021 Nikita Kaskov <[email protected]>
// Copyright (c) 2022 Alisa Cherniaeva <[email protected]>
// Copyright (c) 2022 Ekaterina Chukavina <[email protected]>
// Copyright (c) 2023 Dmitrii Tabalin <[email protected]>
//
// MIT License
//
Expand Down Expand Up @@ -80,11 +81,11 @@ namespace nil {

constexpr static std::size_t get_rows_amount(std::size_t witness_amount,
std::size_t lookup_column_amount) {
return 3;
return 4;
}

const std::size_t rows_amount = get_rows_amount(this->witness_amount(), 0);
constexpr static const std::size_t gates_amount = 1;
constexpr static const std::size_t gates_amount = 2;

struct input_type {
std::array<var, 2> data;
Expand All @@ -98,14 +99,14 @@ namespace nil {
std::array<var, 8> output;

result_type(const decomposition &component, std::uint32_t start_row_index) {
output = {var(component.W(0), start_row_index + 1, false),
var(component.W(1), start_row_index + 1, false),
var(component.W(2), start_row_index + 1, false),
var(component.W(3), start_row_index + 1, false),
var(component.W(4), start_row_index + 1, false),
output = {var(component.W(6), start_row_index + 1, false),
var(component.W(5), start_row_index + 1, false),
var(component.W(6), start_row_index + 1, false),
var(component.W(7), start_row_index + 1, false)};
var(component.W(4), start_row_index + 1, false),
var(component.W(3), start_row_index + 1, false),
var(component.W(6), start_row_index + 3, false),
var(component.W(5), start_row_index + 3, false),
var(component.W(4), start_row_index + 3, false),
var(component.W(3), start_row_index + 3, false)};
}

std::vector<var> all_vars() const {
Expand All @@ -130,6 +131,13 @@ namespace nil {
std::initializer_list<typename component_type::public_input_container_type::value_type>
public_inputs) :
component_type(witnesses, constants, public_inputs, get_manifest()) {};

std::map<std::string, std::size_t> component_lookup_tables(){
std::map<std::string, std::size_t> lookup_tables;
lookup_tables["sha256_sparse_base4/first_column"] = 0; // REQUIRED_TABLE

return lookup_tables;
}
};

template<typename BlueprintFieldType, typename ArithmetizationParams>
Expand All @@ -146,77 +154,126 @@ namespace nil {
const typename plonk_native_decomposition<BlueprintFieldType, ArithmetizationParams>::input_type
instance_input,
const std::uint32_t start_row_index) {
using integral_type = typename BlueprintFieldType::integral_type;

std::size_t row = start_row_index;
std::array<typename BlueprintFieldType::integral_type, 2> data = {
typename BlueprintFieldType::integral_type(var_value(assignment, instance_input.data[0]).data),
typename BlueprintFieldType::integral_type(var_value(assignment, instance_input.data[1]).data)};
std::array<typename BlueprintFieldType::integral_type, 16> range_chunks;
std::array<integral_type, 2> data = {
integral_type(var_value(assignment, instance_input.data[0]).data),
integral_type(var_value(assignment, instance_input.data[1]).data)};
std::array<std::array<std::array<integral_type, 3>, 4>, 2> range_chunks;
std::array<std::array<integral_type, 4>, 2> output_chunks;
std::size_t shift = 0;

for (std::size_t i = 0; i < 8; i++) {
range_chunks[i] = (data[0] >> shift) & ((65536) - 1);
assignment.witness(component.W(i), row) = range_chunks[i];
range_chunks[i + 8] = (data[1] >> shift) & ((65536) - 1);
assignment.witness(component.W(i), row + 2) = range_chunks[i + 8];
shift += 16;
for (std::size_t data_idx = 0; data_idx < 2; data_idx++) {
for (std::size_t chunk_idx = 0; chunk_idx < 4; chunk_idx++) {
output_chunks[data_idx][chunk_idx] = (data[data_idx] >> (chunk_idx * 32)) & 0xFFFFFFFF;
// subchunks are 14, 14, and 4 bits long respectively
range_chunks[data_idx][chunk_idx][0] =
(output_chunks[data_idx][chunk_idx] & 0b11111111111111000000000000000000) >> 18;
range_chunks[data_idx][chunk_idx][1] =
(output_chunks[data_idx][chunk_idx] & 0b00000000000000111111111111110000) >> 4;
range_chunks[data_idx][chunk_idx][2] =
(output_chunks[data_idx][chunk_idx] & 0b00000000000000000000000000001111);
BOOST_ASSERT(
output_chunks[data_idx][chunk_idx] ==
range_chunks[data_idx][chunk_idx][0] * (1 << 18) +
range_chunks[data_idx][chunk_idx][1] * (1 << 4) +
range_chunks[data_idx][chunk_idx][2]);
}
}
for (std::size_t data_idx = 0; data_idx < 2; data_idx++) {
const std::size_t first_row = start_row_index + 2 * data_idx,
second_row = start_row_index + 2 * data_idx + 1;
// placing subchunks for first three chunks
for (std::size_t chunk_idx = 0; chunk_idx < 3; chunk_idx++) {
for (std::size_t subchunk_idx = 0; subchunk_idx < 3; subchunk_idx++) {
assignment.witness(component.W(3 * chunk_idx + subchunk_idx), first_row) =
range_chunks[data_idx][chunk_idx][subchunk_idx];
}
}
// placing subchunk for the last chunk
for (std::size_t subchunk_idx = 0; subchunk_idx < 3; subchunk_idx++) {
assignment.witness(component.W(subchunk_idx), second_row) =
range_chunks[data_idx][3][subchunk_idx];
}
// placing chunks
for (std::size_t chunk_idx = 0; chunk_idx < 4; chunk_idx++) {
assignment.witness(component.W(3 + chunk_idx), second_row) =
output_chunks[data_idx][chunk_idx];
}
// placing the original data
assignment.witness(component.W(7), second_row) = data[data_idx];
}

assignment.witness(component.W(8), row) = data[0];
assignment.witness(component.W(8), row + 2) = data[1];

assignment.witness(component.W(3), row + 1) = range_chunks[1] * (65536) + range_chunks[0];
assignment.witness(component.W(2), row + 1) = range_chunks[3] * (65536) + range_chunks[2];
assignment.witness(component.W(1), row + 1) = range_chunks[5] * (65536) + range_chunks[4];
assignment.witness(component.W(0), row + 1) = range_chunks[7] * (65536) + range_chunks[6];

assignment.witness(component.W(7), row + 1) = range_chunks[9] * (65536) + range_chunks[8];
assignment.witness(component.W(6), row + 1) = range_chunks[11] * (65536) + range_chunks[10];
assignment.witness(component.W(5), row + 1) = range_chunks[13] * (65536) + range_chunks[12];
assignment.witness(component.W(4), row + 1) = range_chunks[15] * (65536) + range_chunks[14];

return typename plonk_native_decomposition<BlueprintFieldType, ArithmetizationParams>::result_type(
component, start_row_index);
}

template<typename BlueprintFieldType, typename ArithmetizationParams>
std::size_t generate_gates(
std::array<std::size_t, 2> generate_gates(
const plonk_native_decomposition<BlueprintFieldType, ArithmetizationParams> &component,
circuit<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>> &bp,
assignment<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>>
&assignment,
const typename plonk_native_decomposition<BlueprintFieldType, ArithmetizationParams>::input_type
&instance_input) {
&instance_input,
const typename lookup_library<BlueprintFieldType>::left_reserved_type &lookup_tables_indices) {

using var = typename plonk_native_decomposition<BlueprintFieldType, ArithmetizationParams>::var;

auto constraint_1 =
var(component.W(8), -1) - (var(component.W(3), 0) + var(component.W(2), 0) * 0x100000000_cppui255 +
var(component.W(1), 0) * 0x10000000000000000_cppui255 +
var(component.W(0), 0) * 0x1000000000000000000000000_cppui255);
auto constraint_2 =
var(component.W(8), 1) - (var(component.W(7), 0) + var(component.W(6), 0) * 0x100000000_cppui255 +
var(component.W(5), 0) * 0x10000000000000000_cppui255 +
var(component.W(4), 0) * 0x1000000000000000000000000_cppui255);
auto constraint_3 = var(component.W(3), 0) -
(var(component.W(0), -1) + var(component.W(1), -1) * (65536));
auto constraint_4 = var(component.W(2), 0) -
(var(component.W(2), -1) + var(component.W(3), -1) * (65536));
auto constraint_5 = var(component.W(1), 0) -
(var(component.W(4), -1) + var(component.W(5), -1) * (65536));
auto constraint_6 = var(component.W(0), 0) -
(var(component.W(6), -1) + var(component.W(7), -1) * (65536));
auto constraint_7 = var(component.W(7), 0) -
(var(component.W(0), +1) + var(component.W(1), +1) * (65536));
auto constraint_8 = var(component.W(6), 0) -
(var(component.W(2), +1) + var(component.W(3), +1) * (65536));
auto constraint_9 = var(component.W(5), 0) -
(var(component.W(4), +1) + var(component.W(5), +1) * (65536));
auto constraint_10 = var(component.W(4), 0) -
(var(component.W(6), +1) + var(component.W(7), +1) * (65536));
return bp.add_gate(
{constraint_1, constraint_2, constraint_3, constraint_4, constraint_5, constraint_6,
constraint_7, constraint_8, constraint_9, constraint_10});
using lookup_constraint = crypto3::zk::snark::plonk_lookup_constraint<BlueprintFieldType>;
using constraint = crypto3::zk::snark::plonk_constraint<BlueprintFieldType>;

const typename BlueprintFieldType::integral_type one = 1;
std::array<std::size_t, 2> selectors;

std::vector<lookup_constraint> subchunk_lookup_constraints(12);
// lookup constraints for the first three chunks
for (std::size_t chunk_idx = 0; chunk_idx < 3; chunk_idx++) {
subchunk_lookup_constraints[3 * chunk_idx] =
{lookup_tables_indices.at("sha256_sparse_base4/first_column"),
{var(component.W(3 * chunk_idx), -1)}};
subchunk_lookup_constraints[3 * chunk_idx + 1] =
{lookup_tables_indices.at("sha256_sparse_base4/first_column"),
{var(component.W(3 * chunk_idx + 1), -1)}};
subchunk_lookup_constraints[3 * chunk_idx + 2] =
{lookup_tables_indices.at("sha256_sparse_base4/first_column"),
{1024 * var(component.W(3 * chunk_idx + 2), -1)}};
}
// lookup constraints for the last chunk
subchunk_lookup_constraints[9] =
{lookup_tables_indices.at("sha256_sparse_base4/first_column"),
{var(component.W(0), 0)}};
subchunk_lookup_constraints[10] =
{lookup_tables_indices.at("sha256_sparse_base4/first_column"),
{var(component.W(1), 0)}};
subchunk_lookup_constraints[11] =
{lookup_tables_indices.at("sha256_sparse_base4/first_column"),
{1024 * var(component.W(2), 0)}};

selectors[0] = bp.add_lookup_gate(subchunk_lookup_constraints);

std::vector<constraint> chunk_constraints(5);
// chunk sum constraints for the first three chunks
for (std::size_t chunk_idx = 0; chunk_idx < 3; chunk_idx++) {
chunk_constraints[chunk_idx] =
var(component.W(3 * chunk_idx), -1) * (1 << 18) +
var(component.W(3 * chunk_idx + 1), -1) * (1 << 4) +
var(component.W(3 * chunk_idx + 2), -1) -
var(component.W(3 + chunk_idx), 0);
}
// chunk sum constraints for the last chunk
chunk_constraints[3] =
var(component.W(0), 0) * (1 << 18) +
var(component.W(1), 0) * (1 << 4) +
var(component.W(2), 0) -
var(component.W(6), 0);
// chunk sum constraint for input
chunk_constraints[4] =
var(component.W(3), 0) + var(component.W(4), 0) * (one << 32) +
var(component.W(5), 0) * (one << 64) + var(component.W(6), 0) * (one << 96) -
var(component.W(7), 0);
selectors[1] = bp.add_gate(chunk_constraints);

return selectors;
}

template<typename BlueprintFieldType, typename ArithmetizationParams>
Expand All @@ -230,11 +287,9 @@ namespace nil {
const std::size_t start_row_index) {

using var = typename plonk_native_decomposition<BlueprintFieldType, ArithmetizationParams>::var;
// CRITICAL: these copy constraints might not be sufficient, but are definitely required.
// I've added copy constraints for the inputs, but internal ones might be missing
// Proceed with care
bp.add_copy_constraint({instance_input.data[0], var(component.W(8), start_row_index, false)});
bp.add_copy_constraint({instance_input.data[1], var(component.W(8), start_row_index + 2, false)});

bp.add_copy_constraint({instance_input.data[0], var(component.W(7), start_row_index + 1, false)});
bp.add_copy_constraint({instance_input.data[1], var(component.W(7), start_row_index + 3, false)});
}

template<typename BlueprintFieldType, typename ArithmetizationParams>
Expand All @@ -248,10 +303,14 @@ namespace nil {
&instance_input,
const std::size_t start_row_index) {

std::size_t j = start_row_index + 1;
std::size_t selector_index = generate_gates(component, bp, assignment, instance_input);
std::array<std::size_t, 2> selector_indices =
generate_gates(component, bp, assignment, instance_input, bp.get_reserved_indices());

assignment.enable_selector(selector_indices[0], start_row_index + 1);
assignment.enable_selector(selector_indices[0], start_row_index + 3);
assignment.enable_selector(selector_indices[1], start_row_index + 1);
assignment.enable_selector(selector_indices[1], start_row_index + 3);

assignment.enable_selector(selector_index, j);
generate_copy_constraints(component, bp, assignment, instance_input, start_row_index);

return typename plonk_native_decomposition<BlueprintFieldType, ArithmetizationParams>::result_type(
Expand Down
24 changes: 0 additions & 24 deletions include/nil/blueprint/components/hashes/sha2/plonk/sha256.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -139,30 +139,6 @@ namespace nil {
using lookup_table_definition = typename
nil::crypto3::zk::snark::lookup_table_definition<BlueprintFieldType>;

std::vector<std::shared_ptr<lookup_table_definition>> component_custom_lookup_tables(){
std::vector<std::shared_ptr<lookup_table_definition>> result = {};

auto sparse_values_base4 = std::shared_ptr<lookup_table_definition>(new typename sha256_process_type::sparse_values_base4_table());
result.push_back(sparse_values_base4);

auto sparse_values_base7 = std::shared_ptr<lookup_table_definition>(new typename sha256_process_type::sparse_values_base7_table());
result.push_back(sparse_values_base7);

auto maj = std::shared_ptr<lookup_table_definition>(new typename sha256_process_type::maj_function_table());
result.push_back(maj);

auto reverse_sparse_sigmas_base4 = std::shared_ptr<lookup_table_definition>(new typename sha256_process_type::reverse_sparse_sigmas_base4_table());
result.push_back(reverse_sparse_sigmas_base4);

auto reverse_sparse_sigmas_base7 = std::shared_ptr<lookup_table_definition>(new typename sha256_process_type::reverse_sparse_sigmas_base7_table());
result.push_back(reverse_sparse_sigmas_base7);

auto ch = std::shared_ptr<lookup_table_definition>(new typename sha256_process_type::ch_function_table());
result.push_back(ch);

return result;
}

std::map<std::string, std::size_t> component_lookup_tables(){
std::map<std::string, std::size_t> lookup_tables;
lookup_tables["sha256_sparse_base4/full"] = 0; // REQUIRED_TABLE
Expand Down
Loading

0 comments on commit 158192b

Please sign in to comment.