diff --git a/Cargo.toml b/Cargo.toml index c4266b68..7b474deb 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -28,6 +28,7 @@ bench = false [dependencies] primitive-types = { version = "0.12.1", features = ["serde"] } # H256 & U256 (I think parity uses this so maybe we just use that crate instead) +thiserror = "1.0" # crypto rand = "0.8.5" @@ -40,7 +41,6 @@ bulletproofs = "4.0.0" curve25519-dalek-ng = "4.1.1" # concurrency -thiserror = "1.0" displaydoc = "0.2" rayon = "1.7.0" dashmap = { version = "5.5.3", features = ["serde"] } @@ -57,9 +57,10 @@ patharg = "0.3.0" # files & serialization serde = { version = "1.0.188", features = ["derive"] } +serde_with = "3.4.0" +serde_bytes = "0.11.12" toml = "0.8.2" csv = "1.3.0" -serde_bytes = "0.11.12" bincode = "1.3.3" chrono = "0.4.31" derive_builder = "0.12.0" diff --git a/README.md b/README.md index 17c1aaa6..d96edf51 100644 --- a/README.md +++ b/README.md @@ -66,18 +66,18 @@ The CLI offers 3 main operations: tree building, proof generation & proof verifi #### Tree building Building a tree can be done: -- from a config file (see tree_config_example.toml) +- from a config file (see dapol_config_example.toml) - from CLI arguments - by deserializing an already-built tree Build a tree using config file (full log verbosity): ```bash -./target/release/dapol -vvv build-tree config-file ./examples/tree_config_example.toml +./target/release/dapol -vvv build-tree config-file ./examples/dapol_config_example.toml ``` Add serialization: ```bash -./target/release/dapol -vvv build-tree config-file ./examples/tree_config_example.toml --serialize . +./target/release/dapol -vvv build-tree config-file ./examples/dapol_config_example.toml --serialize . ``` Deserialize a tree from a file: @@ -87,7 +87,7 @@ Deserialize a tree from a file: Generate proofs (proofs will live in the `./inclusion_proofs/` directory): ```bash -./target/release/dapol -vvv build-tree config-file ./examples/tree_config_example.toml --gen-proofs ./examples/entities_example.csv +./target/release/dapol -vvv build-tree config-file ./examples/dapol_config_example.toml --gen-proofs ./examples/entities_example.csv ``` Build a tree using cli args as apposed to a config file: diff --git a/benches/criterion_benches.rs b/benches/criterion_benches.rs index 49390885..9b95c9c9 100644 --- a/benches/criterion_benches.rs +++ b/benches/criterion_benches.rs @@ -8,14 +8,14 @@ //! long (see large_input_benches.rs). use std::path::Path; +use std::str::FromStr; use criterion::measurement::Measurement; use criterion::{criterion_group, criterion_main}; use criterion::{BenchmarkId, Criterion, SamplingMode}; use statistical::*; -use dapol::accumulators::{NdmSmt, NdmSmtConfigBuilder}; -use dapol::{Accumulator, InclusionProof}; +use dapol::{DapolConfigBuilder, DapolTree, InclusionProof, Secret}; mod inputs; use inputs::{max_thread_counts_greater_than, num_entities_in_range, tree_heights_in_range}; @@ -43,6 +43,8 @@ pub fn bench_build_tree(c: &mut Criterion) { let epoch = jemalloc_ctl::epoch::mib().unwrap(); let allocated = jemalloc_ctl::stats::allocated::mib().unwrap(); + let master_secret = Secret::from_str("secret").unwrap(); + dapol::initialize_machine_parallelism(); dapol::utils::activate_logging(*LOG_VERBOSITY); @@ -103,7 +105,7 @@ pub fn bench_build_tree(c: &mut Criterion) { // Tree build. 
let mut memory_readings = vec![]; - let mut ndm_smt = Option::<NdmSmt>::None; + let mut dapol_tree = Option::<DapolTree>::None; group.bench_with_input( BenchmarkId::new( @@ -119,17 +121,20 @@ pub fn bench_build_tree(c: &mut Criterion) { |bench, tup| { bench.iter(|| { // this is necessary for the memory readings to work - ndm_smt = None; + dapol_tree = None; epoch.advance().unwrap(); let before = allocated.read().unwrap(); - ndm_smt = Some( - NdmSmtConfigBuilder::default() + dapol_tree = Some( + DapolConfigBuilder::default() + .accumulator_type(dapol::AccumulatorType::NdmSmt) .height(tup.0) .max_thread_count(tup.1) .num_random_entities(tup.2) + .master_secret(master_secret.clone()) .build() + .expect("Unable to build DapolConfig") .parse() .expect("Unable to parse NdmSmtConfig"), ); @@ -160,8 +165,8 @@ pub fn bench_build_tree(c: &mut Criterion) { let src_dir = env!("CARGO_MANIFEST_DIR"); let target_dir = Path::new(&src_dir).join("target"); let dir = target_dir.join("serialized_trees"); - let path = Accumulator::parse_accumulator_serialization_path(dir).unwrap(); - let acc = Accumulator::NdmSmt(ndm_smt.expect("Tree should have been built")); + let path = DapolTree::parse_serialization_path(dir).unwrap(); + let tree = dapol_tree.expect("Tree should have been built"); group.bench_function( BenchmarkId::new( @@ -174,7 +179,7 @@ pub fn bench_build_tree(c: &mut Criterion) { ), ), |bench| { - bench.iter(|| acc.serialize(path.clone()).unwrap()); + bench.iter(|| tree.serialize(path.clone()).unwrap()); }, ); @@ -196,6 +201,11 @@ pub fn bench_build_tree(c: &mut Criterion) { pub fn bench_generate_proof(c: &mut Criterion) { let mut group = c.benchmark_group("proofs"); + let master_secret = Secret::from_str("secret").unwrap(); + + dapol::initialize_machine_parallelism(); + dapol::utils::activate_logging(*LOG_VERBOSITY); + for h in tree_heights_in_range(*MIN_HEIGHT, *MAX_HEIGHT).into_iter() { for n in num_entities_in_range(*MIN_ENTITIES, *MAX_ENTITIES).into_iter() { { @@ -238,15 +248,19 @@ pub fn bench_generate_proof(c: &mut Criterion) { continue; } - let ndm_smt = NdmSmtConfigBuilder::default() + let dapol_tree = DapolConfigBuilder::default() + .accumulator_type(dapol::AccumulatorType::NdmSmt) + .master_secret(master_secret.clone()) .height(h) .num_random_entities(n) .build() + .expect("Unable to build DapolConfig") .parse() .expect("Unable to parse NdmSmtConfig"); - let entity_id = ndm_smt + let entity_id = dapol_tree .entity_mapping() + .unwrap() .keys() .next() .expect("Tree should have at least 1 entity"); @@ -261,7 +275,7 @@ pub fn bench_generate_proof(c: &mut Criterion) { |bench| { bench.iter(|| { proof = Some( - ndm_smt + dapol_tree .generate_inclusion_proof(entity_id) .expect("Proof should have been generated successfully"), ); @@ -296,6 +310,11 @@ pub fn bench_verify_proof(c: &mut Criterion) { let mut group = c.benchmark_group("proofs"); + let master_secret = Secret::from_str("secret").unwrap(); + + dapol::initialize_machine_parallelism(); + dapol::utils::activate_logging(*LOG_VERBOSITY); + for h in tree_heights_in_range(*MIN_HEIGHT, *MAX_HEIGHT).into_iter() { for n in num_entities_in_range(*MIN_ENTITIES, *MAX_ENTITIES).into_iter() { { @@ -338,22 +357,26 @@ pub fn bench_verify_proof(c: &mut Criterion) { continue; } - let ndm_smt = NdmSmtConfigBuilder::default() + let dapol_tree = DapolConfigBuilder::default() + .accumulator_type(dapol::AccumulatorType::NdmSmt) + .master_secret(master_secret.clone()) .height(h) .num_random_entities(n) .build() + .expect("Unable to 
build DapolConfig") .parse() .expect("Unable to parse NdmSmtConfig"); - let root_hash = ndm_smt.root_hash(); + let root_hash = dapol_tree.root_hash(); - let entity_id = ndm_smt + let entity_id = dapol_tree .entity_mapping() + .unwrap() .keys() .next() .expect("Tree should have at least 1 entity"); - let proof = ndm_smt + let proof = dapol_tree .generate_inclusion_proof(entity_id) .expect("Proof should have been generated successfully"); diff --git a/benches/manual_benches.rs b/benches/manual_benches.rs index 6bd2bd11..34e15858 100644 --- a/benches/manual_benches.rs +++ b/benches/manual_benches.rs @@ -6,11 +6,12 @@ //! unfortunately, but this is the trade-off. use std::path::Path; +use std::str::FromStr; use std::time::Instant; use statistical::*; -use dapol::accumulators::{Accumulator, NdmSmt, NdmSmtConfigBuilder}; +use dapol::{DapolConfigBuilder, DapolTree, Secret}; mod inputs; use inputs::{max_thread_counts_greater_than, num_entities_in_range, tree_heights_in_range}; @@ -36,6 +37,8 @@ fn main() { let total_mem = system_total_memory_mb(); + let master_secret = Secret::from_str("secret").unwrap(); + dapol::initialize_machine_parallelism(); dapol::utils::activate_logging(*LOG_VERBOSITY); @@ -100,14 +103,14 @@ fn main() { // ============================================================== // Tree build. - let mut ndm_smt = Option::<NdmSmt>::None; + let mut dapol_tree = Option::<DapolTree>::None; let mut memory_readings = vec![]; let mut timings = vec![]; // Do 3 readings (Criterion does 10 minimum). for _i in 0..3 { // this is necessary for the memory readings to work - ndm_smt = None; + dapol_tree = None; println!( "building tree i {} time {}", @@ -119,14 +122,17 @@ fn main() { let mem_before = allocated.read().unwrap(); let time_start = Instant::now(); - ndm_smt = Some( - NdmSmtConfigBuilder::default() + dapol_tree = Some( + DapolConfigBuilder::default() + .accumulator_type(dapol::AccumulatorType::NdmSmt) .height(h) .max_thread_count(t) + .master_secret(master_secret.clone()) .num_random_entities(n) .build() + .expect("Unable to build DapolConfig") .parse() - .expect("Unable to parse NdmSmtConfig"), + .expect("Unable to parse DapolConfig"), ); let tree_build_time = time_start.elapsed(); @@ -154,12 +160,13 @@ fn main() { let src_dir = env!("CARGO_MANIFEST_DIR"); let target_dir = Path::new(&src_dir).join("target"); let dir = target_dir.join("serialized_trees"); - let path = Accumulator::parse_accumulator_serialization_path(dir).unwrap(); - let acc = - Accumulator::NdmSmt(ndm_smt.expect("NDM SMT should have been set in loop")); + let path = DapolTree::parse_serialization_path(dir).unwrap(); let time_start = Instant::now(); - acc.serialize(path.clone()).unwrap(); + dapol_tree + .expect("DapolTree should have been set in loop") + .serialize(path.clone()) + .unwrap(); let serialization_time = time_start.elapsed(); let file_size = std::fs::metadata(path) diff --git a/benches/utils.rs b/benches/utils.rs index 99661c8d..66a6e7ec 100644 --- a/benches/utils.rs +++ b/benches/utils.rs @@ -30,13 +30,13 @@ pub fn bytes_to_string(num_bytes: usize) -> String { if n < kb { format!("{} bytes", num_bytes) } else if n < mb { - format!("{} kB", n / kb) + format!("{:.3} kB", n as f64 / kb as f64) } else if n < gb { - format!("{:.2} MB", n as f64 / mb as f64) + format!("{:.3} MB", n as f64 / mb as f64) } else if n < tb { - format!("{:.2} GB", n as f64 / gb as f64) + format!("{:.3} GB", n as f64 / gb as f64) } else { - format!("{:.2} TB", n as f64 / tb as f64) + format!("{:.3} TB", n as f64 / tb as f64) } } diff --git 
a/examples/dapol_config_example.toml b/examples/dapol_config_example.toml new file mode 100644 index 00000000..21d6aa13 --- /dev/null +++ b/examples/dapol_config_example.toml @@ -0,0 +1,66 @@ +# There are various different accumulator types (e.g. NDM-SMT). +# +# This value must be set. +accumulator_type = "ndm-smt" + +# This is a public value that is used to aid the KDF when generating secret +# blinding factors for the Pedersen commitments. +# +# If it is not set then it will be randomly generated. +salt_b = "salt_b" + +# This is a public value that is used to aid the KDF when generating secret +# salt values, which are in turn used in the hash function when generating +# node hashes. +# +# If it is not set then it will be randomly generated. +salt_s = "salt_s" + +# Height of the tree. +# +# If not set the default height will be used: +# `dapol::Height::default()`. +height = 16 + +# This is a public value representing the maximum amount that any single +# entity's liability can be, and is used in the range proofs: +# $[0, 2^{\text{height}} \times \text{max_liability}]$ +# +# If not set then the default value will be used: +# `2.pow(dapol::DEFAULT_RANGE_PROOF_UPPER_BOUND_BIT_LENGTH)`. +max_liability = 10_000_000 + +# Max number of threads to be spawned for multi-threading algorithms. +# +# If not set the max parallelism of the underlying machine will be used. +max_thread_count = 8 + +# Can be a file or directory (default file name given in this case). +# +# If not set then no serialization is done. +serialization_path = "./tree.dapoltree" + +# At least one of file_path or num_random_entities must be present. +# +# If both are given then file_path is preferred and num_random_entities is ignored. +[entities] + +# Path to a file containing a list of entity IDs and their liabilities. +file_path = "./entities_example.csv" + +# Generate the given number of entities, with random IDs & liabilities. +# This is useful for testing. +num_random_entities = 100 + +# At least one of file_path or master_secret must be present. +# The master secret is known only to the tree generator and is used to +# generate all other secret values required by the tree. +# +# If both are given then file_path is preferred and master_secret is ignored. +[secrets] + +# Path to a file containing the secret values. +file_path = "./dapol_secrets_example.toml" + +# String value of the master secret. +master_secret = "master_secret" diff --git a/examples/dapol_secrets_example.toml b/examples/dapol_secrets_example.toml new file mode 100644 index 00000000..1cab9c2e --- /dev/null +++ b/examples/dapol_secrets_example.toml @@ -0,0 +1 @@ +master_secret = "master_secret" \ No newline at end of file diff --git a/examples/main.rs b/examples/main.rs new file mode 100644 index 00000000..96cd18fa --- /dev/null +++ b/examples/main.rs @@ -0,0 +1,150 @@ +//! Example of a full PoL workflow. +//! +//! 1. Build a tree +//! 2. Generate an inclusion proof +//! 3. Verify an inclusion proof +//! +//! At the time of writing (Nov 2023) only the NDM-SMT accumulator is supported +//! so this is the only type of tree that is used in this example. + +use std::path::Path; +use std::str::FromStr; + +extern crate clap_verbosity_flag; +extern crate csv; +extern crate dapol; + +use dapol::utils::LogOnErrUnwrap; + +fn main() { + let log_level = clap_verbosity_flag::LevelFilter::Debug; + dapol::utils::activate_logging(log_level); + + // ========================================================================= + // Tree building. 
+ + let accumulator_type = dapol::AccumulatorType::NdmSmt; + + let dapol_tree_1 = build_dapol_tree_using_config_builder(accumulator_type); + let dapol_tree_2 = build_dapol_tree_using_config_file(); + + // The above 2 builder methods produce a different tree because the entities + // are mapped randomly to points on the bottom layer for NDM-SMT, but the + // entity mapping of one tree should simply be a permutation of the other. + // Let's check this: + match (dapol_tree_1.entity_mapping(), dapol_tree_2.entity_mapping()) { + (Some(entity_mapping_1), Some(entity_mapping_2)) => { + for (entity, _) in entity_mapping_1 { + assert!(entity_mapping_2.contains_key(&entity)); + } + } + _ => panic!("Expected both trees to be NDM-SMT"), + }; + + // Since the mappings are not the same the root hashes won't be either. + assert_ne!(dapol_tree_1.root_hash(), dapol_tree_2.root_hash()); + + // ========================================================================= + // Inclusion proof generation & verification. + + let entity_id = dapol::EntityId::from_str("john.doe@example.com").unwrap(); + simple_inclusion_proof_generation_and_verification(&dapol_tree_1, entity_id.clone()); + advanced_inclusion_proof_generation_and_verification(&dapol_tree_1, entity_id); +} + +/// Example on how to construct a DAPOL tree. +/// +/// Build the tree via the config builder. +pub fn build_dapol_tree_using_config_builder( + accumulator_type: dapol::AccumulatorType, +) -> dapol::DapolTree { + let src_dir = env!("CARGO_MANIFEST_DIR"); + let resources_dir = Path::new(&src_dir).join("examples"); + + let secrets_file_path = resources_dir.join("dapol_secrets_example.toml"); + let entities_file_path = resources_dir.join("entities_example.csv"); + + let height = dapol::Height::expect_from(16u8); + let salt_b = dapol::Salt::from_str("salt_b").unwrap(); + let salt_s = dapol::Salt::from_str("salt_s").unwrap(); + let max_liability = dapol::MaxLiability::from(10_000_000u64); + let max_thread_count = dapol::MaxThreadCount::from(8u8); + let master_secret = dapol::Secret::from_str("master_secret").unwrap(); + let num_entities = 100u64; + + // The builder requires at least the following to be given: + // - accumulator_type + // - entities + // - secrets + // The rest can be left to be default. + let mut config_builder = dapol::DapolConfigBuilder::default(); + config_builder + .accumulator_type(accumulator_type) + .height(height.clone()) + .salt_b(salt_b.clone()) + .salt_s(salt_s.clone()) + .max_liability(max_liability.clone()) + .max_thread_count(max_thread_count.clone()); + + // You only need to specify 1 of the following secret input methods. + config_builder + .secrets_file_path(secrets_file_path.clone()) + .master_secret(master_secret.clone()); + + // You only need to specify 1 of the following entity input methods. + config_builder + .entities_file_path(entities_file_path.clone()) + .num_random_entities(num_entities); + + config_builder.build().unwrap().parse().unwrap() +} + +/// Example on how to construct a DAPOL tree. +/// +/// Build the tree using a config file. +/// +/// This is also an example usage of [dapol][utils][LogOnErrUnwrap]. 
+pub fn build_dapol_tree_using_config_file() -> dapol::DapolTree { + let src_dir = env!("CARGO_MANIFEST_DIR"); + let resources_dir = Path::new(&src_dir).join("examples"); + let config_file = resources_dir.join("dapol_config_example.toml"); + + dapol::DapolConfig::deserialize(config_file) + .log_on_err_unwrap() + .parse() + .log_on_err_unwrap() +} + +/// Example on how to generate and verify inclusion proofs. +/// +/// An inclusion proof can be generated from only a tree + entity ID. +pub fn simple_inclusion_proof_generation_and_verification( + dapol_tree: &dapol::DapolTree, + entity_id: dapol::EntityId, +) { + let inclusion_proof = dapol_tree.generate_inclusion_proof(&entity_id).unwrap(); + inclusion_proof.verify(dapol_tree.root_hash()).unwrap(); +} + +/// Example on how to generate and verify inclusion proofs. +/// +/// The inclusion proof generation algorithm can be customized via some +/// parameters. See [dapol][InclusionProof] for more details. +pub fn advanced_inclusion_proof_generation_and_verification( + dapol_tree: &dapol::DapolTree, + entity_id: dapol::EntityId, +) { + // Determines how many of the range proofs in the inclusion proof are + // aggregated together. The ones that are not aggregated are proved + // individually. The more that are aggregated the faster the proving + // and verification times. + let aggregation_percentage = dapol::percentage::ONE_HUNDRED_PERCENT; + let aggregation_factor = dapol::AggregationFactor::Percent(aggregation_percentage); + let aggregation_factor = dapol::AggregationFactor::default(); + + let inclusion_proof = dapol_tree + .generate_inclusion_proof_with(&entity_id, aggregation_factor) + .unwrap(); + + inclusion_proof.verify(dapol_tree.root_hash()).unwrap(); +} diff --git a/examples/ndm_smt/accumulator_config_parser.rs b/examples/ndm_smt/accumulator_config_parser.rs deleted file mode 100644 index 09861f50..00000000 --- a/examples/ndm_smt/accumulator_config_parser.rs +++ /dev/null @@ -1,20 +0,0 @@ -//! Example on how to build a tree using a config file. -//! -//! The config file can be used for any accumulator type since the type is -//! specified by the config file. -//! -//! This is also an example usage of [dapol][utils][LogOnErrUnwrap]. - -use std::path::Path; -use dapol::utils::LogOnErrUnwrap; - -pub fn build_accumulator_using_config_file() -> dapol::Accumulator { - let src_dir = env!("CARGO_MANIFEST_DIR"); - let resources_dir = Path::new(&src_dir).join("examples"); - let config_file = resources_dir.join("tree_config_example.toml"); - - dapol::AccumulatorConfig::deserialize(config_file) - .log_on_err_unwrap() - .parse() - .log_on_err_unwrap() -} diff --git a/examples/ndm_smt/inclusion_proof_handling.rs b/examples/ndm_smt/inclusion_proof_handling.rs deleted file mode 100644 index 3fa5ca58..00000000 --- a/examples/ndm_smt/inclusion_proof_handling.rs +++ /dev/null @@ -1,36 +0,0 @@ -//! Examples on how to generate and verify inclusion proofs. - -/// An inclusion proof can be generated from only a tree + entity ID. -pub fn simple_inclusion_proof_generation_and_verification( - ndm_smt: &dapol::accumulators::NdmSmt, - entity_id: dapol::EntityId, -) { - let inclusion_proof = ndm_smt.generate_inclusion_proof(&entity_id).unwrap(); - inclusion_proof.verify(ndm_smt.root_hash()).unwrap(); -} - -/// The inclusion proof generation algorithm can be customized via some -/// parameters. See [dapol][InclusionProof] for more details. 
-pub fn advanced_inclusion_proof_generation_and_verification( - ndm_smt: &dapol::accumulators::NdmSmt, - entity_id: dapol::EntityId, -) { - // Determines how many of the range proofs in the inclusion proof are - // aggregated together. The ones that are not aggregated are proved - // individually. The more that are aggregated the faster the proving - // and verification times. - let aggregation_percentage = dapol::percentage::ONE_HUNDRED_PERCENT; - let aggregation_factor = dapol::AggregationFactor::Percent(aggregation_percentage); - let aggregation_factor = dapol::AggregationFactor::default(); - - // 2^upper_bound_bit_length is the upper bound used in the range proof i.e. - // the secret value is shown to reside in the range [0, 2^upper_bound_bit_length]. - let upper_bound_bit_length = 32u8; - let upper_bound_bit_length = dapol::DEFAULT_RANGE_PROOF_UPPER_BOUND_BIT_LENGTH; - - let inclusion_proof = ndm_smt - .generate_inclusion_proof_with(&entity_id, aggregation_factor, upper_bound_bit_length) - .unwrap(); - - inclusion_proof.verify(ndm_smt.root_hash()).unwrap(); -} diff --git a/examples/ndm_smt/main.rs b/examples/ndm_smt/main.rs deleted file mode 100644 index 25fc5bcc..00000000 --- a/examples/ndm_smt/main.rs +++ /dev/null @@ -1,56 +0,0 @@ -//! Example of a full PoL workflow. -//! -//! 1. Build a tree -//! 2. Generate an inclusion proof -//! 3. Verify an inclusion proof -//! -//! At the time of writing (Nov 2023) only the NDM-SMT accumulator is supported -//! so this is the only type of tree that is used in this example. - -use std::str::FromStr; - -extern crate clap_verbosity_flag; -extern crate csv; -extern crate dapol; - -mod ndm_smt_builder; -use ndm_smt_builder::build_ndm_smt_using_builder_pattern; - -mod accumulator_config_parser; -use accumulator_config_parser::build_accumulator_using_config_file; - -mod inclusion_proof_handling; -use inclusion_proof_handling::{simple_inclusion_proof_generation_and_verification, advanced_inclusion_proof_generation_and_verification}; - -fn main() { - let log_level = clap_verbosity_flag::LevelFilter::Debug; - dapol::utils::activate_logging(log_level); - - // ========================================================================= - // Tree building. - - let ndm_smt = build_ndm_smt_using_builder_pattern(); - let accumulator = build_accumulator_using_config_file(); - - // The above 2 builder methods produce a different tree because the entities - // are mapped randomly to points on the bottom layer, but the entity mapping - // of one tree should simply be a permutation of the other. We check this: - let ndm_smt_other = match accumulator { - dapol::Accumulator::NdmSmt(ndm_smt_other) => { - assert_ne!(ndm_smt_other.root_hash(), ndm_smt.root_hash()); - - for (entity, _) in ndm_smt_other.entity_mapping() { - assert!(ndm_smt.entity_mapping().contains_key(&entity)); - } - - ndm_smt_other - } - }; - - // ========================================================================= - // Inclusion proof generation & verification. - - let entity_id = dapol::EntityId::from_str("john.doe@example.com").unwrap(); - simple_inclusion_proof_generation_and_verification(&ndm_smt, entity_id.clone()); - advanced_inclusion_proof_generation_and_verification(&ndm_smt_other, entity_id); -} diff --git a/examples/ndm_smt/ndm_smt_builder.rs b/examples/ndm_smt/ndm_smt_builder.rs deleted file mode 100644 index e849033b..00000000 --- a/examples/ndm_smt/ndm_smt_builder.rs +++ /dev/null @@ -1,21 +0,0 @@ -//! Example on how to use the builder pattern to construct an NDM-SMT tree. 
- -use std::path::Path; - -pub fn build_ndm_smt_using_builder_pattern() -> dapol::accumulators::NdmSmt { - let src_dir = env!("CARGO_MANIFEST_DIR"); - let resources_dir = Path::new(&src_dir).join("examples"); - - let secrets_file = resources_dir.join("ndm_smt_secrets_example.toml"); - let entities_file = resources_dir.join("entities_example.csv"); - - let height = dapol::Height::expect_from(16); - - let config = dapol::accumulators::NdmSmtConfigBuilder::default() - .height(height) - .secrets_file_path(secrets_file) - .entities_path(entities_file) - .build(); - - config.parse().unwrap() -} diff --git a/examples/ndm_smt_secrets_example.toml b/examples/ndm_smt_secrets_example.toml deleted file mode 100644 index f3f7b57e..00000000 --- a/examples/ndm_smt_secrets_example.toml +++ /dev/null @@ -1,11 +0,0 @@ -# None of these values should be shared. They should be kept with the tree -# creator. - -# Used for generating secrets for each entity. -master_secret = "master_secret" - -# Used for generating blinding factors for Pedersen commitments. -salt_b = "salt_b" - -# Used as an input to the hash function when merging nodes. -salt_s = "salt_s" diff --git a/examples/tree_config_example.toml b/examples/tree_config_example.toml deleted file mode 100644 index f7afe348..00000000 --- a/examples/tree_config_example.toml +++ /dev/null @@ -1,29 +0,0 @@ -# Accumulator type of the tree. -# This value determines what other values are required. -accumulator_type = "ndm-smt" - -# Height of the tree. -# If the height is not set the default height will be used. -height = 16 - -# Max number of threads to be spawned for multi-threading algorithms. -# If the height is not set a default value will be used. -max_thread_count = 4 - -# Path to the secrets file. -# If not present the secrets will be generated randomly. -secrets_file_path = "./examples/ndm_smt_secrets_example.toml" - -# Can be a file or directory (default file name given in this case) -# If not present then no serialization is done. -serialization_path = "./tree.dapoltree" - -# At least one of file_path & generate_random must be present. -# If both are given then file_path is prioritized. -[entities] - -# Path to a file containing a list of entity IDs and their liabilities. -file_path = "./examples/entities_example.csv" - -# Generate the given number of entities, with random IDs & liabilities. -generate_random = 4 \ No newline at end of file diff --git a/src/accumulators.rs b/src/accumulators.rs index 8678aaac..90aaa6cb 100644 --- a/src/accumulators.rs +++ b/src/accumulators.rs @@ -1,192 +1,51 @@ //! Various accumulator variants of the DAPOL+ protocol. //! -//! This is the top-most module in the hierarchy of the [dapol] crate. An -//! accumulator defines how the binary tree is built. There are different types -//! of accumulators, which can all be found under this module. Each accumulator -//! has different configuration requirements, which are detailed in each of the -//! sub-modules. The currently supported accumulator types are: -//! - [Non-Deterministic Mapping Sparse Merkle Tree] -//! -//! Accumulators can be constructed via the configuration parsers: -//! - [AccumulatorConfig] is used to deserialize config from a file (the -//! specific type of accumulator is determined from the config file). After -//! parsing the config the accumulator can be constructed. -//! - [NdmSmtConfigBuilder] is used to construct the -//! config for the NDM-SMT accumulator type using a builder pattern. The config -//! can then be parsed to construct an NDM-SMT. -//! -//! 
[Non-Deterministic Mapping Sparse Merkle Tree]: crate::accumulators::NdmSmt +//! An accumulator defines how the binary tree is built. There are different +//! types of accumulators, which can all be found under this module. -use log::{debug, info}; +use clap::ValueEnum; +use primitive_types::H256; use serde::{Deserialize, Serialize}; -use std::path::PathBuf; - -use crate::{ - read_write_utils::{self, ReadWriteError}, - utils::LogOnErr, - AggregationFactor, EntityId, InclusionProof, -}; - -mod config; -pub use config::{AccumulatorConfig, AccumulatorConfigError, AccumulatorParserError}; mod ndm_smt; -pub use ndm_smt::{ - NdmSmt, NdmSmtConfig, NdmSmtConfigBuilder, NdmSmtError, NdmSmtConfigParserError, NdmSmtSecrets, - NdmSmtSecretsParser, RandomXCoordGenerator -}; +pub use ndm_smt::{NdmSmt, NdmSmtError, RandomXCoordGenerator}; -const SERIALIZED_ACCUMULATOR_EXTENSION: &str = "dapoltree"; -const SERIALIZED_ACCUMULATOR_FILE_PREFIX: &str = "accumulator_"; +use crate::Height; -/// Various supported accumulator types. -/// -/// Accumulators can be constructed via the configuration parsers: -/// - [AccumulatorConfig] is used to deserialize config from a file (the -/// specific type of accumulator is determined from the config file). After -/// parsing the config the accumulator can be constructed. -/// - [NdmSmtConfigBuilder] is used to construct the -/// config for the NDM-SMT accumulator type using a builder pattern. The config -/// can then be parsed to construct an NDM-SMT. -#[derive(Serialize, Deserialize)] +/// Supported accumulators, with their linked data. +#[derive(Debug, Serialize, Deserialize)] pub enum Accumulator { NdmSmt(ndm_smt::NdmSmt), - // TODO other accumulators.. + // TODO add other accumulators.. } impl Accumulator { - /// Try deserialize an accumulator from the given file path. - /// - /// The file is assumed to be in [bincode] format. - /// - /// An error is logged and returned if - /// 1. The file cannot be opened. - /// 2. The [bincode] deserializer fails. - pub fn deserialize(path: PathBuf) -> Result { - debug!( - "Deserializing accumulator from file {:?}", - path.clone().into_os_string() - ); - - match path.extension() { - Some(ext) => { - if ext != SERIALIZED_ACCUMULATOR_EXTENSION { - Err(ReadWriteError::UnsupportedFileExtension { - expected: SERIALIZED_ACCUMULATOR_EXTENSION.to_owned(), - actual: ext.to_os_string(), - })?; - } - } - None => Err(ReadWriteError::NotAFile(path.clone().into_os_string()))?, + /// Height of the binary tree. + pub fn height(&self) -> &Height { + match self { + Accumulator::NdmSmt(ndm_smt) => ndm_smt.height(), } - - let accumulator: Accumulator = - read_write_utils::deserialize_from_bin_file(path.clone()).log_on_err()?; - - let root_hash = match &accumulator { - Accumulator::NdmSmt(ndm_smt) => ndm_smt.root_hash(), - }; - - info!( - "Successfully deserialized accumulator from file {:?} with root hash {:?}", - path.clone().into_os_string(), - root_hash - ); - - Ok(accumulator) } - /// Parse `path` as one that points to a serialized dapol tree file. - /// - /// `path` can be either of the following: - /// 1. Existing directory: in this case a default file name is appended to - /// `path`. 2. Non-existing directory: in this case all dirs in the path - /// are created, and a default file name is appended. - /// 3. File in existing dir: in this case the extension is checked to be - /// [SERIALIZED_ACCUMULATOR_EXTENSION], then `path` is returned. - /// 4. File in non-existing dir: dirs in the path are created and the file - /// extension is checked. 
- /// - /// The file prefix is [SERIALIZED_ACCUMULATOR_FILE_PREFIX]. - /// - /// Example: - /// ``` - /// use dapol::Accumulator; - /// use std::path::PathBuf; - /// - /// let dir = PathBuf::from("./"); - /// let path = Accumulator::parse_accumulator_serialization_path(dir).unwrap(); - /// ``` - pub fn parse_accumulator_serialization_path(path: PathBuf) -> Result { - read_write_utils::parse_serialization_path( - path, - SERIALIZED_ACCUMULATOR_EXTENSION, - SERIALIZED_ACCUMULATOR_FILE_PREFIX, - ) - } - - /// Serialize to a file. - /// - /// Serialization is done using [bincode] - /// - /// An error is returned if - /// 1. [bincode] fails to serialize the file. - /// 2. There is an issue opening or writing the file. - pub fn serialize(&self, path: PathBuf) -> Result<(), AccumulatorError> { - info!( - "Serializing accumulator to file {:?}", - path.clone().into_os_string() - ); - - read_write_utils::serialize_to_bin_file(self, path).log_on_err()?; - Ok(()) - } - - /// Generate an inclusion proof for the given `entity_id`. - /// - /// `aggregation_factor` is used to determine how many of the range proofs - /// are aggregated. Those that do not form part of the aggregated proof - /// are just proved individually. The aggregation is a feature of the - /// Bulletproofs protocol that improves efficiency. - /// - /// `upper_bound_bit_length` is used to determine the upper bound for the - /// range proof, which is set to `2^upper_bound_bit_length` i.e. the - /// range proof shows `0 <= liability <= 2^upper_bound_bit_length` for - /// some liability. The type is set to `u8` because we are not expected - /// to require bounds higher than $2^256$. Note that if the value is set - /// to anything other than 8, 16, 32 or 64 the Bulletproofs code will return - /// an Err. - pub fn generate_inclusion_proof_with( - &self, - entity_id: &EntityId, - aggregation_factor: AggregationFactor, - upper_bound_bit_length: u8, - ) -> Result { + /// Return the accumulator type. + pub fn get_type(&self) -> AccumulatorType { match self { - Accumulator::NdmSmt(ndm_smt) => ndm_smt.generate_inclusion_proof_with( - entity_id, - aggregation_factor, - upper_bound_bit_length, - ), + Self::NdmSmt(_) => AccumulatorType::NdmSmt, } } - /// Generate an inclusion proof for the given `entity_id`. - pub fn generate_inclusion_proof( - &self, - entity_id: &EntityId, - ) -> Result { + /// Return the hash digest/bytes of the root node for the binary tree. + pub fn root_hash(&self) -> H256 { match self { - Accumulator::NdmSmt(ndm_smt) => ndm_smt.generate_inclusion_proof(entity_id), + Self::NdmSmt(ndm_smt) => ndm_smt.root_hash(), } } } -/// Errors encountered when handling an [Accumulator]. -#[derive(thiserror::Error, Debug)] -pub enum AccumulatorError { - #[error("Error serializing/deserializing file")] - SerdeError(#[from] ReadWriteError), +/// Various supported accumulator types. +#[derive(Clone, Deserialize, Debug, ValueEnum, PartialEq)] +#[serde(rename_all = "kebab-case")] +pub enum AccumulatorType { + NdmSmt, + // TODO add other accumulators.. } - -// NOTE no unit tests here because this code is tested in the integration tests. diff --git a/src/accumulators/config.rs b/src/accumulators/config.rs deleted file mode 100644 index 8e98daf0..00000000 --- a/src/accumulators/config.rs +++ /dev/null @@ -1,120 +0,0 @@ -use log::debug; -use serde::Deserialize; -use std::{ffi::OsString, fs::File, io::Read, path::PathBuf, str::FromStr}; - -use super::{ndm_smt, Accumulator}; - -/// Configuration required for building various accumulator types. 
-/// -/// Currently only TOML files are supported for config files. The only -/// config requirement at this level (not including the specific accumulator -/// config) is the accumulator type: -/// -/// ```toml,ignore -/// accumulator_type = "ndm-smt" -/// ``` -/// -/// The rest of the config details can be found in the sub-modules: -/// - [crate][accumulators][NdmSmtConfig] -/// -/// Config deserialization example: -/// ``` -/// use std::path::PathBuf; -/// use dapol::AccumulatorConfig; -/// -/// let file_path = PathBuf::from("./examples/tree_config_example.toml"); -/// let config = AccumulatorConfig::deserialize(file_path).unwrap(); -/// ``` -#[derive(Deserialize, Debug)] -#[serde(tag = "accumulator_type", rename_all = "kebab-case")] -pub enum AccumulatorConfig { - NdmSmt(ndm_smt::NdmSmtConfig), - // TODO other accumulators.. -} - -// STENT TODO rename all other builder methods that are 'new' to 'default' since -// this is what derive_default uses STENT TODO also maybe get rid of the 'with' -// in the setters - -impl AccumulatorConfig { - /// Open the config file, then try to create an accumulator object. - /// - /// An error is returned if: - /// 1. The file cannot be opened. - /// 2. The file cannot be read. - /// 3. The file type is not supported. - pub fn deserialize(config_file_path: PathBuf) -> Result { - debug!( - "Attempting to parse {:?} as a file containing accumulator config", - config_file_path.clone().into_os_string() - ); - - let ext = config_file_path - .extension() - .and_then(|s| s.to_str()) - .ok_or(AccumulatorConfigError::UnknownFileType( - config_file_path.clone().into_os_string(), - ))?; - - let config = match FileType::from_str(ext)? { - FileType::Toml => { - let mut buf = String::new(); - File::open(config_file_path)?.read_to_string(&mut buf)?; - let config: AccumulatorConfig = toml::from_str(&buf)?; - config - } - }; - - debug!("Successfully parsed accumulator config file"); - - Ok(config) - } - - /// Parse the config, attempting to create an accumulator object. - /// - /// An error is returned if the parser for the specific accumulator type - /// fails. - pub fn parse(self) -> Result { - let accumulator = match self { - AccumulatorConfig::NdmSmt(config) => Accumulator::NdmSmt(config.parse()?), - // TODO add more accumulators.. - }; - - Ok(accumulator) - } -} - -/// Supported file types for deserialization. -enum FileType { - Toml, -} - -impl FromStr for FileType { - type Err = AccumulatorConfigError; - - fn from_str(ext: &str) -> Result { - match ext { - "toml" => Ok(FileType::Toml), - _ => Err(AccumulatorConfigError::UnsupportedFileType { ext: ext.into() }), - } - } -} - -/// Errors encountered when handling [AccumulatorConfig]. 
-#[derive(thiserror::Error, Debug)] -pub enum AccumulatorConfigError { - #[error("Unable to find file extension for path {0:?}")] - UnknownFileType(OsString), - #[error("The file type with extension {ext:?} is not supported")] - UnsupportedFileType { ext: String }, - #[error("Error reading the file")] - FileReadError(#[from] std::io::Error), - #[error("Deserialization process failed")] - DeserializationError(#[from] toml::de::Error), -} - -#[derive(thiserror::Error, Debug)] -pub enum AccumulatorParserError { - #[error("Error parsing NDM-SMT config")] - NdmSmtError(#[from] ndm_smt::NdmSmtConfigParserError), -} diff --git a/src/accumulators/ndm_smt.rs b/src/accumulators/ndm_smt.rs index 03623af1..e15f8685 100644 --- a/src/accumulators/ndm_smt.rs +++ b/src/accumulators/ndm_smt.rs @@ -8,29 +8,20 @@ use logging_timer::{timer, Level}; use rayon::prelude::*; -use crate::binary_tree::{ - BinaryTree, Coordinate, Height, InputLeafNode, PathSiblings, TreeBuilder, +use crate::{ + binary_tree::{ + BinaryTree, BinaryTreeBuilder, Coordinate, FullNodeContent, Height, InputLeafNode, + PathSiblings, + }, + entity::{Entity, EntityId}, + inclusion_proof::{AggregationFactor, InclusionProof}, + kdf::generate_key, + MaxThreadCount, Salt, Secret, DEFAULT_RANGE_PROOF_UPPER_BOUND_BIT_LENGTH, }; -use crate::entity::{Entity, EntityId}; -use crate::inclusion_proof::{ - AggregationFactor, InclusionProof, DEFAULT_RANGE_PROOF_UPPER_BOUND_BIT_LENGTH, -}; -use crate::kdf::generate_key; -use crate::node_content::FullNodeContent; -use crate::MaxThreadCount; - -mod ndm_smt_secrets; -pub use ndm_smt_secrets::NdmSmtSecrets; - -mod ndm_smt_secrets_parser; -pub use ndm_smt_secrets_parser::NdmSmtSecretsParser; mod x_coord_generator; pub use x_coord_generator::RandomXCoordGenerator; -mod ndm_smt_config; -pub use ndm_smt_config::{NdmSmtConfig, NdmSmtConfigBuilder, NdmSmtConfigParserError}; - // ------------------------------------------------------------------------------------------------- // Main struct and implementation. @@ -55,14 +46,26 @@ type Content = FullNodeContent; #[derive(Debug, Serialize, Deserialize)] pub struct NdmSmt { - secrets: NdmSmtSecrets, - tree: BinaryTree, + binary_tree: BinaryTree, entity_mapping: HashMap, } impl NdmSmt { /// Constructor. /// + /// Parameters: + /// - `master_secret`: + #[doc = include_str!("../shared_docs/master_secret.md")] + /// - `salt_b`: + #[doc = include_str!("../shared_docs/salt_b.md")] + /// - `salt_s`: + #[doc = include_str!("../shared_docs/salt_s.md")] + /// - `height`: + #[doc = include_str!("../shared_docs/height.md")] + /// - `max_thread_count`: + #[doc = include_str!("../shared_docs/max_thread_count.md")] + /// - `entities`: + #[doc = include_str!("../shared_docs/entities_vector.md")] /// Each element in `entities` is converted to an /// [input leaf node] and randomly assigned a position on the /// bottom layer of the tree. 
@@ -80,27 +83,34 @@ impl NdmSmt { /// /// [input leaf node]: crate::binary_tree::InputLeafNode pub fn new( - secrets: NdmSmtSecrets, + master_secret: Secret, + salt_b: Salt, + salt_s: Salt, height: Height, max_thread_count: MaxThreadCount, entities: Vec, ) -> Result { - let master_secret_bytes = secrets.master_secret.as_bytes(); - let salt_b_bytes = secrets.salt_b.as_bytes(); - let salt_s_bytes = secrets.salt_s.as_bytes(); + let master_secret_bytes = master_secret.as_bytes(); + let salt_b_bytes = salt_b.as_bytes(); + let salt_s_bytes = salt_s.as_bytes(); info!( "\nCreating NDM-SMT with the following configuration:\n \ - height: {}\n \ - number of entities: {}\n \ - - master secret: 0x{}\n \ + - master secret: \n \ - salt b: 0x{}\n \ - salt s: 0x{}", height.as_u32(), entities.len(), - master_secret_bytes.iter().map(|b| format!("{:02x}", b)).collect::(), - salt_b_bytes.iter().map(|b| format!("{:02x}", b)).collect::(), - salt_s_bytes.iter().map(|b| format!("{:02x}", b)).collect::(), + salt_b_bytes + .iter() + .map(|b| format!("{:02x}", b)) + .collect::(), + salt_s_bytes + .iter() + .map(|b| format!("{:02x}", b)) + .collect::(), ); let (leaf_nodes, entity_coord_tuples) = { @@ -162,7 +172,7 @@ impl NdmSmt { entity_mapping.insert(entity.id, x_coord); } - let tree = TreeBuilder::new() + let tree = BinaryTreeBuilder::new() .with_height(height) .with_leaf_nodes(leaf_nodes) .with_max_thread_count(max_thread_count) @@ -173,8 +183,7 @@ impl NdmSmt { ))?; Ok(NdmSmt { - tree, - secrets, + binary_tree: tree, entity_mapping, }) } @@ -186,12 +195,19 @@ impl NdmSmt { /// factor for the range proof, which are both required for the range /// proof that is done in the [InclusionProof] constructor. /// - /// `aggregation_factor` is used to determine how many of the range proofs + /// Parameters: + /// - `master_secret`: + #[doc = include_str!("../shared_docs/master_secret.md")] + /// - `salt_b`: + #[doc = include_str!("../shared_docs/salt_b.md")] + /// - `salt_s`: + #[doc = include_str!("../shared_docs/salt_s.md")] + /// - `entity_id`: unique ID for the entity that the proof will be generated for. + /// - `aggregation_factor` is used to determine how many of the range proofs /// are aggregated. Those that do not form part of the aggregated proof /// are just proved individually. The aggregation is a feature of the /// Bulletproofs protocol that improves efficiency. - /// - /// `upper_bound_bit_length` is used to determine the upper bound for the + /// - `upper_bound_bit_length` is used to determine the upper bound for the /// range proof, which is set to `2^upper_bound_bit_length` i.e. the /// range proof shows `0 <= liability <= 2^upper_bound_bit_length` for /// some liability. The type is set to `u8` because we are not expected @@ -200,24 +216,27 @@ impl NdmSmt { /// an Err. 
pub fn generate_inclusion_proof_with( &self, + master_secret: &Secret, + salt_b: &Salt, + salt_s: &Salt, entity_id: &EntityId, aggregation_factor: AggregationFactor, upper_bound_bit_length: u8, ) -> Result { - let master_secret_bytes = self.secrets.master_secret.as_bytes(); - let salt_b_bytes = self.secrets.salt_b.as_bytes(); - let salt_s_bytes = self.secrets.salt_s.as_bytes(); + let master_secret_bytes = master_secret.as_bytes(); + let salt_b_bytes = salt_b.as_bytes(); + let salt_s_bytes = salt_s.as_bytes(); let new_padding_node_content = new_padding_node_content_closure(*master_secret_bytes, *salt_b_bytes, *salt_s_bytes); let leaf_node = self .entity_mapping .get(entity_id) - .and_then(|leaf_x_coord| self.tree.get_leaf_node(*leaf_x_coord)) + .and_then(|leaf_x_coord| self.binary_tree.get_leaf_node(*leaf_x_coord)) .ok_or(NdmSmtError::EntityIdNotFound)?; let path_siblings = PathSiblings::build_using_multi_threaded_algorithm( - &self.tree, + &self.binary_tree, &leaf_node, new_padding_node_content, )?; @@ -236,11 +255,26 @@ impl NdmSmt { /// - `aggregation_factor`: half of all the range proofs are aggregated /// - `upper_bound_bit_length`: 64 (which should be plenty enough for most /// real-world cases) + /// + /// Parameters: + /// - `master_secret`: + #[doc = include_str!("../shared_docs/master_secret.md")] + /// - `salt_b`: + #[doc = include_str!("../shared_docs/salt_b.md")] + /// - `salt_s`: + #[doc = include_str!("../shared_docs/salt_s.md")] + /// - `entity_id`: unique ID for the entity that the proof will be generated for. pub fn generate_inclusion_proof( &self, + master_secret: &Secret, + salt_b: &Salt, + salt_s: &Salt, entity_id: &EntityId, ) -> Result { self.generate_inclusion_proof_with( + master_secret, + salt_b, + salt_s, entity_id, AggregationFactor::default(), DEFAULT_RANGE_PROOF_UPPER_BOUND_BIT_LENGTH, @@ -249,7 +283,7 @@ impl NdmSmt { /// Return the hash digest/bytes of the root node for the binary tree. pub fn root_hash(&self) -> H256 { - self.tree.root().content.hash + self.binary_tree.root().content.hash } /// Return the entity mapping, the x-coord that each entity is mapped to. @@ -257,9 +291,8 @@ impl NdmSmt { &self.entity_mapping } - /// Return the height of the binary tree. 
pub fn height(&self) -> &Height { - self.tree.height() + self.binary_tree.height() } } @@ -324,13 +357,8 @@ mod tests { #[test] fn constructor_works() { let master_secret: Secret = 1u64.into(); - let salt_b: Secret = 2u64.into(); - let salt_s: Secret = 3u64.into(); - let secrets = NdmSmtSecrets { - master_secret, - salt_b, - salt_s, - }; + let salt_b: Salt = 2u64.into(); + let salt_s: Salt = 3u64.into(); let height = Height::expect_from(4u8); let max_thread_count = MaxThreadCount::default(); @@ -339,6 +367,6 @@ mod tests { id: EntityId::from_str("some entity").unwrap(), }]; - NdmSmt::new(secrets, height, max_thread_count, entities).unwrap(); + NdmSmt::new(master_secret, salt_b, salt_s, height, max_thread_count, entities).unwrap(); } } diff --git a/src/accumulators/ndm_smt/ndm_smt_config.rs b/src/accumulators/ndm_smt/ndm_smt_config.rs deleted file mode 100644 index 4e1762ca..00000000 --- a/src/accumulators/ndm_smt/ndm_smt_config.rs +++ /dev/null @@ -1,309 +0,0 @@ -use std::path::PathBuf; - -use derive_builder::Builder; -use log::{debug, info}; -use serde::Deserialize; - -use crate::entity::{self, EntitiesParser}; -use crate::utils::LogOnErr; -use crate::Height; -use crate::MaxThreadCount; - -use super::{ndm_smt_secrets_parser, NdmSmt, NdmSmtSecretsParser}; - -/// Configuration needed to construct an NDM-SMT. -/// -/// The config is defined by a struct. A builder pattern is used to construct -/// the config, but it can also be constructed by deserializing a file. -/// Construction is handled by [crate][AccumulatorConfig] and so have -/// a look there for more details on file format for deserialization or examples -/// on how to use the parser. Currently only toml files are supported, with the -/// following format: -/// -/// ```toml,ignore -/// accumulator_type = "ndm-smt" -/// -/// # Height of the tree. -/// # If the height is not set the default height will be used. -/// height = 32 -/// -/// # Max number of threads to be spawned for multi-threading algorithms. -/// # If the height is not set a default value will be used. -/// max_thread_count = 4 -/// -/// # Path to the secrets file. -/// # If not present the secrets will be generated randomly. -/// secrets_file_path = "./examples/ndm_smt_secrets_example.toml" -/// -/// # At least one of file_path & generate_random must be present. -/// # If both are given then file_path is prioritized. -/// [entities] -/// -/// # Path to a file containing a list of entity IDs and their liabilities. -/// file_path = "./examples/entities_example.csv" -/// -/// # Generate the given number of entities, with random IDs & liabilities. -/// generate_random = 4 -/// ``` -/// -/// Construction of this tree using a config file must be done via -/// [crate][AccumulatorConfig]. 
-/// -/// Example how to use the builder: -/// ``` -/// use std::path::PathBuf; -/// use dapol::{Height, MaxThreadCount}; -/// use dapol::accumulators::NdmSmtConfigBuilder; -/// -/// let height = Height::expect_from(8); -/// let max_thread_count = MaxThreadCount::default(); -/// -/// let config = NdmSmtConfigBuilder::default() -/// .height(height) -/// .secrets_file_path(PathBuf::from("./examples/ndm_smt_secrets_example.toml")) -/// .entities_path(PathBuf::from("./examples/entities_example.csv")) -/// .build(); -/// ``` -#[derive(Deserialize, Debug, Builder)] -#[builder(build_fn(skip))] -pub struct NdmSmtConfig { - height: Height, - max_thread_count: MaxThreadCount, - #[builder(setter(strip_option))] - secrets_file_path: Option, - #[builder(private)] - entities: EntityConfig, -} - -#[derive(Deserialize, Debug, Clone, Default)] -pub struct EntityConfig { - file_path: Option, - num_random_entities: Option, -} - -impl NdmSmtConfig { - /// Try to construct an NDM-SMT from the config. - pub fn parse(self) -> Result { - debug!("Parsing config to create a new NDM-SMT: {:?}", self); - - let secrets = NdmSmtSecretsParser::from(self.secrets_file_path) - .parse_or_generate_random()?; - - let height = self.height; - let max_thread_count = self.max_thread_count; - - let entities = EntitiesParser::new() - .with_path_opt(self.entities.file_path) - .with_num_entities_opt(self.entities.num_random_entities) - .parse_file_or_generate_random()?; - - let ndm_smt = NdmSmt::new(secrets, height, max_thread_count, entities).log_on_err()?; - - info!( - "Successfully built NDM-SMT with root hash {:?}", - ndm_smt.root_hash() - ); - - Ok(ndm_smt) - } -} - -impl NdmSmtConfigBuilder { - pub fn secrets_file_path_opt(&mut self, path: Option) -> &mut Self { - self.secrets_file_path = Some(path); - self - } - - pub fn entities_path_opt(&mut self, path: Option) -> &mut Self { - match &mut self.entities { - None => { - self.entities = Some(EntityConfig { - file_path: path, - num_random_entities: None, - }) - } - Some(entities) => entities.file_path = path, - } - self - } - - pub fn entities_path(&mut self, path: PathBuf) -> &mut Self { - self.entities_path_opt(Some(path)) - } - - pub fn num_random_entities_opt(&mut self, num_entities: Option) -> &mut Self { - match &mut self.entities { - None => { - self.entities = Some(EntityConfig { - file_path: None, - num_random_entities: num_entities, - }) - } - Some(entities) => entities.num_random_entities = num_entities, - } - self - } - - pub fn num_random_entities(&mut self, num_entities: u64) -> &mut Self { - self.num_random_entities_opt(Some(num_entities)) - } - - pub fn build(&self) -> NdmSmtConfig { - let entities = EntityConfig { - file_path: self.entities.clone().and_then(|e| e.file_path).or(None), - num_random_entities: self - .entities - .clone() - .and_then(|e| e.num_random_entities) - .or(None), - }; - - NdmSmtConfig { - height: self.height.unwrap_or_default(), - max_thread_count: self.max_thread_count.unwrap_or_default(), - secrets_file_path: self.secrets_file_path.clone().unwrap_or(None), - entities, - } - } -} - -/// Errors encountered when parsing [crate][accumulators][NdmSmtConfig]. 
-#[derive(thiserror::Error, Debug)] -pub enum NdmSmtConfigParserError { - #[error("Secrets parsing failed while trying to parse NDM-SMT config")] - SecretsError(#[from] ndm_smt_secrets_parser::NdmSmtSecretsParserError), - #[error("Entities parsing failed while trying to parse NDM-SMT config")] - EntitiesError(#[from] entity::EntitiesParserError), - #[error("Tree construction failed after parsing NDM-SMT config")] - BuildError(#[from] super::NdmSmtError), -} - -// ------------------------------------------------------------------------------------------------- -// Unit tests - -#[cfg(test)] -mod tests { - use crate::utils::test_utils::assert_err; - - use super::*; - use std::fs::File; - use std::io::{BufRead, BufReader}; - use std::path::Path; - - #[test] - fn builder_with_entities_file() { - let height = Height::expect_from(8); - - let src_dir = env!("CARGO_MANIFEST_DIR"); - let resources_dir = Path::new(&src_dir).join("examples"); - let secrets_file_path = resources_dir.join("ndm_smt_secrets_example.toml"); - let entities_file_path = resources_dir.join("entities_example.csv"); - - let entities_file = File::open(entities_file_path.clone()).unwrap(); - // "-1" because we don't include the top line of the csv which defines - // the column headings. - let num_entities = BufReader::new(entities_file).lines().count() - 1; - - let ndm_smt = NdmSmtConfigBuilder::default() - .height(height) - .secrets_file_path(secrets_file_path) - .entities_path(entities_file_path) - .build() - .parse() - .unwrap(); - - assert_eq!(ndm_smt.entity_mapping.len(), num_entities); - assert_eq!(ndm_smt.height(), &height); - } - - #[test] - fn builder_with_random_entities() { - let height = Height::expect_from(8); - let num_random_entities = 10; - - let src_dir = env!("CARGO_MANIFEST_DIR"); - let resources_dir = Path::new(&src_dir).join("examples"); - let secrets_file = resources_dir.join("ndm_smt_secrets_example.toml"); - - let ndm_smt = NdmSmtConfigBuilder::default() - .height(height) - .secrets_file_path(secrets_file) - .num_random_entities(num_random_entities) - .build() - .parse() - .unwrap(); - - assert_eq!(ndm_smt.entity_mapping.len(), num_random_entities as usize); - assert_eq!(ndm_smt.height(), &height); - } - - #[test] - fn builder_without_height_should_give_default() { - let num_random_entities = 10; - - let src_dir = env!("CARGO_MANIFEST_DIR"); - let resources_dir = Path::new(&src_dir).join("examples"); - let secrets_file = resources_dir.join("ndm_smt_secrets_example.toml"); - - let ndm_smt = NdmSmtConfigBuilder::default() - .secrets_file_path(secrets_file) - .num_random_entities(num_random_entities) - .build() - .parse() - .unwrap(); - - assert_eq!(ndm_smt.entity_mapping.len(), num_random_entities as usize); - assert_eq!(ndm_smt.height(), &Height::default()); - } - - #[test] - fn builder_without_any_values_fails() { - use crate::entity::EntitiesParserError; - let res = NdmSmtConfigBuilder::default().build().parse(); - assert_err!( - res, - Err(NdmSmtConfigParserError::EntitiesError( - EntitiesParserError::NumEntitiesNotSet - )) - ); - } - - #[test] - fn builder_with_all_values() { - let height = Height::expect_from(8); - let num_random_entities = 10; - - let src_dir = env!("CARGO_MANIFEST_DIR"); - let resources_dir = Path::new(&src_dir).join("examples"); - let secrets_file_path = resources_dir.join("ndm_smt_secrets_example.toml"); - let entities_file_path = resources_dir.join("entities_example.csv"); - - let entities_file = File::open(entities_file_path.clone()).unwrap(); - // "-1" because we don't include 
the top line of the csv which defines - // the column headings. - let num_entities = BufReader::new(entities_file).lines().count() - 1; - - let ndm_smt = NdmSmtConfigBuilder::default() - .height(height) - .secrets_file_path(secrets_file_path) - .entities_path(entities_file_path) - .num_random_entities(num_random_entities) - .build() - .parse() - .unwrap(); - - assert_eq!(ndm_smt.entity_mapping.len(), num_entities); - assert_eq!(ndm_smt.height(), &height); - } - - #[test] - fn builder_without_secrets_file_path() { - let num_random_entities = 10; - - let _ndm_smt = NdmSmtConfigBuilder::default() - .num_random_entities(num_random_entities) - .build() - .parse() - .unwrap(); - } -} diff --git a/src/accumulators/ndm_smt/ndm_smt_secrets.rs b/src/accumulators/ndm_smt/ndm_smt_secrets.rs deleted file mode 100644 index ef18f713..00000000 --- a/src/accumulators/ndm_smt/ndm_smt_secrets.rs +++ /dev/null @@ -1,80 +0,0 @@ -use std::{convert::TryFrom, str::FromStr}; - -use logging_timer::time; -use serde::{Deserialize, Serialize}; - -use crate::secret::{Secret, MAX_LENGTH_BYTES}; - -use rand::{ - distributions::{Alphanumeric, DistString}, - thread_rng, -}; - -/// Secret values required to construct the NDM-SMT. -/// -/// The values required for tree construction are as specified in the DAPOL+ -/// paper (with exactly the same names so as not to create confusion): -/// - master_secret: used for generating each entity's secret values. -/// - salt_b: salt used in the random generation of Pedersen blinding factors. -/// - salt_s: salt used in the node-merging hash function. -/// -/// These values should be generated by the same party that constructs the -/// tree. They should not be shared with anyone. -/// -/// The names of the fields are exactly the same as the ones given in the -/// DAPOL+ paper. -/// -/// See [crate][accumulators][NdmSmtSecretsParser] for how to -/// build this struct. -#[derive(Debug, Serialize, Deserialize)] -pub struct NdmSmtSecrets { - pub master_secret: Secret, - pub salt_b: Secret, - pub salt_s: Secret, -} - -/// This coding style is a bit ugly but it is the simplest way to get the -/// desired outcome, which is to deserialize toml values into a byte array. -/// We can't deserialize automatically to [a secret] without a custom -/// implementation of the [deserialize trait]. Instead we deserialize to -/// [NdmSmtSecretsInput] and then convert the individual string fields to byte -/// arrays. 
-/// -/// [a secret] crate::secret::Secret -/// [deserialize trait] serde::Deserialize -#[derive(Deserialize)] -pub struct NdmSmtSecretsInput { - master_secret: String, - salt_b: String, - salt_s: String, -} - -const STRING_CONVERSION_ERR_MSG: &str = "A failure should not be possible here because the length of the random string exactly matches the max allowed length"; - -impl NdmSmtSecrets { - #[time("debug", "NdmSmt::NdmSmtSecrets::{}")] - pub fn generate_random() -> Self { - let mut rng = thread_rng(); - let master_secret_str = Alphanumeric.sample_string(&mut rng, MAX_LENGTH_BYTES); - let salt_b_str = Alphanumeric.sample_string(&mut rng, MAX_LENGTH_BYTES); - let salt_s_str = Alphanumeric.sample_string(&mut rng, MAX_LENGTH_BYTES); - - NdmSmtSecrets { - master_secret: Secret::from_str(&master_secret_str).expect(STRING_CONVERSION_ERR_MSG), - salt_b: Secret::from_str(&salt_b_str).expect(STRING_CONVERSION_ERR_MSG), - salt_s: Secret::from_str(&salt_s_str).expect(STRING_CONVERSION_ERR_MSG), - } - } -} - -impl TryFrom for NdmSmtSecrets { - type Error = super::ndm_smt_secrets_parser::NdmSmtSecretsParserError; - - fn try_from(input: NdmSmtSecretsInput) -> Result { - Ok(NdmSmtSecrets { - master_secret: Secret::from_str(&input.master_secret)?, - salt_b: Secret::from_str(&input.salt_b)?, - salt_s: Secret::from_str(&input.salt_s)?, - }) - } -} diff --git a/src/accumulators/ndm_smt/ndm_smt_secrets_parser.rs b/src/accumulators/ndm_smt/ndm_smt_secrets_parser.rs deleted file mode 100644 index 26249781..00000000 --- a/src/accumulators/ndm_smt/ndm_smt_secrets_parser.rs +++ /dev/null @@ -1,174 +0,0 @@ -use log::{debug, warn}; -use std::{ffi::OsString, fs::File, io::Read, path::PathBuf, str::FromStr}; - -use super::ndm_smt_secrets::{NdmSmtSecrets, NdmSmtSecretsInput}; -use crate::secret::SecretParserError; - -/// Parser for files containing NDM-SMT-related secrets. -/// -/// Supported file types: toml -/// Note that the file type is inferred from its path extension. -/// -/// TOML format: -/// ```toml,ignore -/// # None of these values should be shared. They should be kept with the tree -/// # creator. -/// -/// # Used for generating secrets for each entity. -/// master_secret = "master_secret" -/// -/// # Used for generating blinding factors for Pedersen commitments. -/// salt_b = "salt_b" -/// -/// # Used as an input to the hash function when merging nodes. -/// salt_s = "salt_s" -/// ``` -/// -/// See [crate][accumulators][NdmSmtSecrets] for more details about the -/// secret values. -pub struct NdmSmtSecretsParser { - path: Option, -} - -impl NdmSmtSecretsParser { - /// Open and parse the file, returning a [NdmSmtSecrets] struct. - /// - /// An error is returned if: - /// 1. The path is None (i.e. was not set). - /// 2. The file cannot be opened. - /// 3. The file cannot be read. - /// 4. The file type is not supported. - /// 5. Deserialization of any of the records in the file fails. - pub fn parse(self) -> Result { - debug!( - "Attempting to parse {:?} as a file containing NDM-SMT secrets", - &self.path - ); - - let path = self.path.ok_or(NdmSmtSecretsParserError::PathNotSet)?; - - let ext = path.extension().and_then(|s| s.to_str()).ok_or( - NdmSmtSecretsParserError::UnknownFileType(path.clone().into_os_string()), - )?; - - let secrets = match FileType::from_str(ext)? { - FileType::Toml => { - let mut buf = String::new(); - File::open(path)?.read_to_string(&mut buf)?; - let secrets: NdmSmtSecretsInput = toml::from_str(&buf)?; - NdmSmtSecrets::try_from(secrets)? 
- } - }; - - debug!("Successfully parsed NDM-SMT secrets file",); - - Ok(secrets) - } - - pub fn parse_or_generate_random(self) -> Result { - match &self.path { - Some(_) => self.parse(), - None => { - warn!( - "Could not determine path for secrets file, defaulting to randomized secrets" - ); - Ok(NdmSmtSecrets::generate_random()) - } - } - } -} - -impl From> for NdmSmtSecretsParser { - /// Convert `PathBuf` to `NdmSmtSecretsParser`. - /// - /// `Option` is used to wrap the parameter to make the code work more - /// seamlessly with the config builders in [crate][accumulators]. - fn from(path: Option) -> Self { - Self { path } - } -} - -impl From for NdmSmtSecretsParser { - fn from(path: PathBuf) -> Self { - Self { path: Some(path) } - } -} - -/// Supported file types for the parser. -enum FileType { - Toml, -} - -impl FromStr for FileType { - type Err = NdmSmtSecretsParserError; - - fn from_str(ext: &str) -> Result { - match ext { - "toml" => Ok(FileType::Toml), - _ => Err(NdmSmtSecretsParserError::UnsupportedFileType { ext: ext.into() }), - } - } -} - -#[derive(thiserror::Error, Debug)] -pub enum NdmSmtSecretsParserError { - #[error("Expected path to be set but found none")] - PathNotSet, - #[error("Unable to find file extension for path {0:?}")] - UnknownFileType(OsString), - #[error("The file type with extension {ext:?} is not supported")] - UnsupportedFileType { ext: String }, - #[error("Error converting string found in file to Secret")] - StringConversionError(#[from] SecretParserError), - #[error("Error reading the file")] - FileReadError(#[from] std::io::Error), - #[error("Deserialization process failed")] - DeserializationError(#[from] toml::de::Error), -} - -// ------------------------------------------------------------------------------------------------- -// Unit tests - -#[cfg(test)] -mod tests { - use super::*; - use crate::utils::test_utils::assert_err; - use crate::Secret; - use std::path::Path; - - #[test] - fn parser_toml_file_happy_case() { - let src_dir = env!("CARGO_MANIFEST_DIR"); - let resources_dir = Path::new(&src_dir).join("examples"); - let path = resources_dir.join("ndm_smt_secrets_example.toml"); - - let secrets = NdmSmtSecretsParser::from(path).parse().unwrap(); - - assert_eq!( - secrets.master_secret, - Secret::from_str("master_secret").unwrap() - ); - assert_eq!(secrets.salt_b, Secret::from_str("salt_b").unwrap()); - assert_eq!(secrets.salt_s, Secret::from_str("salt_s").unwrap()); - } - - #[test] - fn unsupported_file_type() { - let this_file = std::file!(); - let path = PathBuf::from(this_file); - - assert_err!( - NdmSmtSecretsParser::from(path).parse(), - Err(NdmSmtSecretsParserError::UnsupportedFileType { ext: _ }) - ); - } - - #[test] - fn unknown_file_type() { - let path = PathBuf::from("./"); - assert_err!( - NdmSmtSecretsParser::from(path).parse(), - Err(NdmSmtSecretsParserError::UnknownFileType(_)) - ); - } -} diff --git a/src/accumulators/ndm_smt/x_coord_generator.rs b/src/accumulators/ndm_smt/x_coord_generator.rs index 6c8dfcfd..d987eb86 100644 --- a/src/accumulators/ndm_smt/x_coord_generator.rs +++ b/src/accumulators/ndm_smt/x_coord_generator.rs @@ -18,8 +18,8 @@ use std::collections::HashMap; /// - `i` is used to track the current position of the algorithm. 
/// /// Example: -/// ``` -/// use dapol::accumulators::RandomXCoordGenerator; +/// ```rust,ignore +/// use crate::accumulators::RandomXCoordGenerator; /// /// let height = dapol::Height::default(); /// let mut x_coord_generator = RandomXCoordGenerator::new(&height); diff --git a/src/binary_tree.rs b/src/binary_tree.rs index c8f056da..37d9df5d 100644 --- a/src/binary_tree.rs +++ b/src/binary_tree.rs @@ -35,17 +35,20 @@ use serde::{Deserialize, Serialize}; use std::fmt; +mod utils; + +mod node_content; +pub use node_content::{FullNodeContent, HiddenNodeContent, Mergeable}; + mod tree_builder; pub use tree_builder::multi_threaded; pub use tree_builder::{ - single_threaded, InputLeafNode, TreeBuildError, TreeBuilder, MIN_STORE_DEPTH, + single_threaded, InputLeafNode, TreeBuildError, BinaryTreeBuilder, MIN_STORE_DEPTH, }; mod path_siblings; pub use path_siblings::{PathSiblings, PathSiblingsBuildError, PathSiblingsError}; -mod utils; - mod height; pub use height::{Height, HeightError, MAX_HEIGHT, MIN_HEIGHT}; @@ -120,12 +123,6 @@ pub enum Store { SingleThreadedStore(single_threaded::HashMapStore), } -/// The generic content type of a [Node] must implement this trait to allow 2 -/// sibling nodes to be combined to make a new parent node. -pub trait Mergeable { - fn merge(left_sibling: &Self, right_sibling: &Self) -> Self; -} - // ------------------------------------------------------------------------------------------------- // Accessor methods. diff --git a/src/binary_tree/height.rs b/src/binary_tree/height.rs index 933be344..5b2bf392 100644 --- a/src/binary_tree/height.rs +++ b/src/binary_tree/height.rs @@ -146,7 +146,7 @@ impl FromStr for Height { } // ------------------------------------------------------------------------------------------------- -// From for OsStr. +// From for OsStr (for the CLI). use clap::builder::{OsStr, Str}; diff --git a/src/binary_tree/node_content.rs b/src/binary_tree/node_content.rs new file mode 100644 index 00000000..508bec73 --- /dev/null +++ b/src/binary_tree/node_content.rs @@ -0,0 +1,21 @@ +//! Implementation of the generic node content type. +//! +//! The [crate][binary_tree][BinaryTree] implementation uses a generic value for +//! the content so that all the code can be easily reused for different types of +//! nodes. +//! +//! In order to implement a node content type one must create a struct +//! containing the data for the node, and then implement the [Mergeable] trait +//! which takes 2 children nodes and combines them to make a parent node. + +mod full_node; +pub use full_node::FullNodeContent; + +mod hidden_node; +pub use hidden_node::HiddenNodeContent; + +/// The generic content type of a [Node] must implement this trait to allow 2 +/// sibling nodes to be combined to make a new parent node. 
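For context on the trait declared just below: a minimal, hypothetical sketch of a content type implementing `Mergeable`. `SumContent` is an illustrative stand-in and is not part of this patch; the real content types are `FullNodeContent` and `HiddenNodeContent`.

```rust
// Local copy of the trait, mirroring the declaration below, so the sketch is
// self-contained.
trait Mergeable {
    fn merge(left_sibling: &Self, right_sibling: &Self) -> Self;
}

// Hypothetical content type: each node carries only a liability sum.
#[derive(Clone)]
struct SumContent {
    liability_sum: u64,
}

impl Mergeable for SumContent {
    // A parent's content is derived purely from its two children.
    fn merge(left_sibling: &Self, right_sibling: &Self) -> Self {
        SumContent {
            liability_sum: left_sibling.liability_sum + right_sibling.liability_sum,
        }
    }
}

fn merge_example() {
    let left = SumContent { liability_sum: 3 };
    let right = SumContent { liability_sum: 4 };
    let parent = SumContent::merge(&left, &right);
    assert_eq!(parent.liability_sum, 7);
}
```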
+pub trait Mergeable { + fn merge(left_sibling: &Self, right_sibling: &Self) -> Self; +} diff --git a/src/node_content/full_node.rs b/src/binary_tree/node_content/full_node.rs similarity index 100% rename from src/node_content/full_node.rs rename to src/binary_tree/node_content/full_node.rs diff --git a/src/node_content/hidden_node.rs b/src/binary_tree/node_content/hidden_node.rs similarity index 100% rename from src/node_content/hidden_node.rs rename to src/binary_tree/node_content/hidden_node.rs diff --git a/src/binary_tree/path_siblings.rs b/src/binary_tree/path_siblings.rs index 9325783e..df326ee4 100644 --- a/src/binary_tree/path_siblings.rs +++ b/src/binary_tree/path_siblings.rs @@ -402,7 +402,7 @@ mod tests { let leaf_nodes = full_bottom_layer(&height); - let tree_single_threaded = TreeBuilder::new() + let tree_single_threaded = BinaryTreeBuilder::new() .with_height(height) .with_store_depth(MIN_STORE_DEPTH) .with_leaf_nodes(leaf_nodes.clone()) @@ -434,7 +434,7 @@ mod tests { let leaf_nodes = full_bottom_layer(&height); - let tree_multi_threaded = TreeBuilder::new() + let tree_multi_threaded = BinaryTreeBuilder::new() .with_height(height) .with_store_depth(MIN_STORE_DEPTH) .with_leaf_nodes(leaf_nodes.clone()) @@ -466,7 +466,7 @@ mod tests { let leaf_nodes = sparse_leaves(&height); - let tree_single_threaded = TreeBuilder::new() + let tree_single_threaded = BinaryTreeBuilder::new() .with_height(height) .with_leaf_nodes(leaf_nodes.clone()) .with_store_depth(MIN_STORE_DEPTH) @@ -498,7 +498,7 @@ mod tests { let leaf_nodes = sparse_leaves(&height); - let tree_multi_threaded = TreeBuilder::new() + let tree_multi_threaded = BinaryTreeBuilder::new() .with_height(height) .with_leaf_nodes(leaf_nodes.clone()) .with_store_depth(MIN_STORE_DEPTH) @@ -531,7 +531,7 @@ mod tests { for i in 0..height.max_bottom_layer_nodes() { let leaf_node = vec![single_leaf(i)]; - let tree_single_threaded = TreeBuilder::new() + let tree_single_threaded = BinaryTreeBuilder::new() .with_height(height) .with_leaf_nodes(leaf_node.clone()) .with_store_depth(MIN_STORE_DEPTH) @@ -565,7 +565,7 @@ mod tests { for x_coord in 0..height.max_bottom_layer_nodes() { let leaf_node = vec![single_leaf(x_coord)]; - let tree_multi_threaded = TreeBuilder::new() + let tree_multi_threaded = BinaryTreeBuilder::new() .with_height(height) .with_leaf_nodes(leaf_node.clone()) .with_store_depth(MIN_STORE_DEPTH) diff --git a/src/binary_tree/tree_builder.rs b/src/binary_tree/tree_builder.rs index 4492ce6e..89d1797d 100644 --- a/src/binary_tree/tree_builder.rs +++ b/src/binary_tree/tree_builder.rs @@ -47,7 +47,7 @@ pub const MIN_STORE_DEPTH: u8 = 1; /// /// [binary tree]: super::BinaryTree #[derive(Debug)] -pub struct TreeBuilder { +pub struct BinaryTreeBuilder { height: Option, leaf_nodes: Option>>, store_depth: Option, @@ -67,13 +67,13 @@ pub struct InputLeafNode { // ------------------------------------------------------------------------------------------------- // Implementations. -impl TreeBuilder +impl BinaryTreeBuilder where C: Clone + Mergeable + 'static, /* The static is needed when the single threaded builder * builds the boxed hashmap. 
*/ { pub fn new() -> Self { - TreeBuilder { + BinaryTreeBuilder { height: None, leaf_nodes: None, store_depth: None, @@ -334,13 +334,13 @@ mod tests { let leaf_nodes = sparse_leaves(&height); - let single_threaded = TreeBuilder::new() + let single_threaded = BinaryTreeBuilder::new() .with_height(height) .with_leaf_nodes(leaf_nodes.clone()) .build_using_single_threaded_algorithm(generate_padding_closure()) .unwrap(); - let multi_threaded = TreeBuilder::new() + let multi_threaded = BinaryTreeBuilder::new() .with_height(height) .with_leaf_nodes(leaf_nodes) .build_using_multi_threaded_algorithm(generate_padding_closure()) @@ -357,13 +357,13 @@ mod tests { let leaf_nodes = full_bottom_layer(&height); - let single_threaded = TreeBuilder::new() + let single_threaded = BinaryTreeBuilder::new() .with_height(height) .with_leaf_nodes(leaf_nodes.clone()) .build_using_single_threaded_algorithm(generate_padding_closure()) .unwrap(); - let multi_threaded = TreeBuilder::new() + let multi_threaded = BinaryTreeBuilder::new() .with_height(height) .with_leaf_nodes(leaf_nodes) .build_using_multi_threaded_algorithm(generate_padding_closure()) @@ -381,13 +381,13 @@ mod tests { for i in 0..height.max_bottom_layer_nodes() { let leaf_node = vec![single_leaf(i)]; - let single_threaded = TreeBuilder::new() + let single_threaded = BinaryTreeBuilder::new() .with_height(height) .with_leaf_nodes(leaf_node.clone()) .build_using_single_threaded_algorithm(generate_padding_closure()) .unwrap(); - let multi_threaded = TreeBuilder::new() + let multi_threaded = BinaryTreeBuilder::new() .with_height(height) .with_leaf_nodes(leaf_node) .build_using_multi_threaded_algorithm(generate_padding_closure()) @@ -405,7 +405,7 @@ mod tests { fn err_when_parent_builder_height_not_set() { let height = Height::expect_from(4); let leaf_nodes = full_bottom_layer(&height); - let res = TreeBuilder::new().with_leaf_nodes(leaf_nodes).height(); + let res = BinaryTreeBuilder::new().with_leaf_nodes(leaf_nodes).height(); // cannot use assert_err because it requires Func to have the Debug trait assert_err_simple!(res, Err(TreeBuildError::NoHeightProvided)); @@ -414,7 +414,7 @@ mod tests { #[test] fn err_when_parent_builder_leaf_nodes_not_set() { let height = Height::expect_from(4); - let res = TreeBuilder::::new() + let res = BinaryTreeBuilder::::new() .with_height(height) .leaf_nodes(&height); @@ -425,7 +425,7 @@ mod tests { #[test] fn err_for_empty_leaves() { let height = Height::expect_from(5); - let res = TreeBuilder::::new() + let res = BinaryTreeBuilder::::new() .with_height(height) .with_leaf_nodes(Vec::new()) .leaf_nodes(&height); @@ -446,7 +446,7 @@ mod tests { }, }); - let res = TreeBuilder::new() + let res = BinaryTreeBuilder::new() .with_height(height) .with_leaf_nodes(leaf_nodes) .leaf_nodes(&height); diff --git a/src/binary_tree/tree_builder/multi_threaded.rs b/src/binary_tree/tree_builder/multi_threaded.rs index fd691122..37b1375d 100644 --- a/src/binary_tree/tree_builder/multi_threaded.rs +++ b/src/binary_tree/tree_builder/multi_threaded.rs @@ -619,7 +619,7 @@ pub(crate) mod tests { fn err_when_parent_builder_height_not_set() { let height = Height::expect_from(4); let leaf_nodes = full_bottom_layer(&height); - let res = TreeBuilder::new() + let res = BinaryTreeBuilder::new() .with_leaf_nodes(leaf_nodes) .build_using_multi_threaded_algorithm(generate_padding_closure()); @@ -630,7 +630,7 @@ pub(crate) mod tests { #[test] fn err_when_parent_builder_leaf_nodes_not_set() { let height = Height::expect_from(4); - let res = TreeBuilder::new() + 
let res = BinaryTreeBuilder::new() .with_height(height) .build_using_multi_threaded_algorithm(generate_padding_closure()); @@ -641,7 +641,7 @@ pub(crate) mod tests { #[test] fn err_for_empty_leaves() { let height = Height::expect_from(5); - let res = TreeBuilder::::new() + let res = BinaryTreeBuilder::::new() .with_height(height) .with_leaf_nodes(Vec::>::new()) .build_using_multi_threaded_algorithm(generate_padding_closure()); @@ -663,7 +663,7 @@ pub(crate) mod tests { }, }); - let res = TreeBuilder::new() + let res = BinaryTreeBuilder::new() .with_height(height) .with_leaf_nodes(leaf_nodes) .build_using_multi_threaded_algorithm(generate_padding_closure()); @@ -684,7 +684,7 @@ pub(crate) mod tests { let mut leaf_nodes = sparse_leaves(&height); leaf_nodes.push(single_leaf(leaf_nodes.get(0).unwrap().x_coord)); - let res = TreeBuilder::new() + let res = BinaryTreeBuilder::new() .with_height(height) .with_leaf_nodes(leaf_nodes) .build_using_multi_threaded_algorithm(generate_padding_closure()); @@ -698,7 +698,7 @@ pub(crate) mod tests { let height = Height::expect_from(4); let leaf_node = single_leaf(height.max_bottom_layer_nodes() + 1); - let res = TreeBuilder::new() + let res = BinaryTreeBuilder::new() .with_height(height) .with_leaf_nodes(vec![leaf_node]) .build_using_multi_threaded_algorithm(generate_padding_closure()); @@ -716,7 +716,7 @@ pub(crate) mod tests { let height = Height::expect_from(4); let mut leaf_nodes = sparse_leaves(&height); - let tree = TreeBuilder::new() + let tree = BinaryTreeBuilder::new() .with_height(height) .with_leaf_nodes(leaf_nodes.clone()) .build_using_multi_threaded_algorithm(generate_padding_closure()) @@ -725,7 +725,7 @@ pub(crate) mod tests { leaf_nodes.shuffle(&mut thread_rng()); - let tree = TreeBuilder::new() + let tree = BinaryTreeBuilder::new() .with_height(height) .with_leaf_nodes(leaf_nodes) .build_using_multi_threaded_algorithm(generate_padding_closure()) @@ -739,7 +739,7 @@ pub(crate) mod tests { let height = Height::expect_from(5); let leaf_nodes = sparse_leaves(&height); - let tree = TreeBuilder::new() + let tree = BinaryTreeBuilder::new() .with_height(height) .with_leaf_nodes(leaf_nodes.clone()) .build_using_multi_threaded_algorithm(generate_padding_closure()) @@ -760,7 +760,7 @@ pub(crate) mod tests { let height = Height::expect_from(8); let leaf_nodes = full_bottom_layer(&height); - let tree = TreeBuilder::new() + let tree = BinaryTreeBuilder::new() .with_height(height) .with_leaf_nodes(leaf_nodes.clone()) .build_using_multi_threaded_algorithm(generate_padding_closure()) @@ -798,7 +798,7 @@ pub(crate) mod tests { // TODO fuzz on this store depth let store_depth = 1; - let tree = TreeBuilder::new() + let tree = BinaryTreeBuilder::new() .with_height(height) .with_leaf_nodes(leaf_nodes.clone()) .with_store_depth(store_depth) @@ -847,7 +847,7 @@ pub(crate) mod tests { let store_depth = height.as_u8(); let leaf_nodes = random_leaf_nodes(num_leaf_nodes, &height, randomness); - let tree = TreeBuilder::new() + let tree = BinaryTreeBuilder::new() .with_height(height) .with_leaf_nodes(leaf_nodes) .with_store_depth(store_depth) @@ -869,7 +869,7 @@ pub(crate) mod tests { let leaf_nodes = random_leaf_nodes(num_leaf_nodes, &height, seed); let expected_number_of_nodes_in_store = max_nodes_to_store(num_leaf_nodes, &height) - 1; - let tree = TreeBuilder::new() + let tree = BinaryTreeBuilder::new() .with_height(height) .with_leaf_nodes(leaf_nodes) .with_store_depth(store_depth) diff --git a/src/binary_tree/tree_builder/single_threaded.rs 
b/src/binary_tree/tree_builder/single_threaded.rs index 69f4adb0..8d7d3e6c 100644 --- a/src/binary_tree/tree_builder/single_threaded.rs +++ b/src/binary_tree/tree_builder/single_threaded.rs @@ -328,7 +328,7 @@ mod tests { fn err_when_parent_builder_height_not_set() { let height = Height::expect_from(4); let leaf_nodes = full_bottom_layer(&height); - let res = TreeBuilder::new() + let res = BinaryTreeBuilder::new() .with_leaf_nodes(leaf_nodes) .build_using_single_threaded_algorithm(generate_padding_closure()); @@ -339,7 +339,7 @@ mod tests { #[test] fn err_when_parent_builder_leaf_nodes_not_set() { let height = Height::expect_from(4); - let res = TreeBuilder::new() + let res = BinaryTreeBuilder::new() .with_height(height) .build_using_single_threaded_algorithm(generate_padding_closure()); @@ -350,7 +350,7 @@ mod tests { #[test] fn err_for_empty_leaves() { let height = Height::expect_from(5); - let res = TreeBuilder::::new() + let res = BinaryTreeBuilder::::new() .with_height(height) .with_leaf_nodes(Vec::>::new()) .build_using_single_threaded_algorithm(generate_padding_closure()); @@ -372,7 +372,7 @@ mod tests { }, }); - let res = TreeBuilder::new() + let res = BinaryTreeBuilder::new() .with_height(height) .with_leaf_nodes(leaf_nodes) .build_using_single_threaded_algorithm(generate_padding_closure()); @@ -392,7 +392,7 @@ mod tests { let mut leaf_nodes = sparse_leaves(&height); leaf_nodes.push(single_leaf(leaf_nodes.get(0).unwrap().x_coord)); - let res = TreeBuilder::new() + let res = BinaryTreeBuilder::new() .with_height(height) .with_leaf_nodes(leaf_nodes) .build_using_single_threaded_algorithm(generate_padding_closure()); @@ -406,7 +406,7 @@ mod tests { let height = Height::expect_from(4); let leaf_node = single_leaf(height.max_bottom_layer_nodes() + 1); - let res = TreeBuilder::new() + let res = BinaryTreeBuilder::new() .with_height(height) .with_leaf_nodes(vec![leaf_node]) .build_using_single_threaded_algorithm(generate_padding_closure()); @@ -424,7 +424,7 @@ mod tests { let height = Height::expect_from(4); let mut leaf_nodes = sparse_leaves(&height); - let tree = TreeBuilder::new() + let tree = BinaryTreeBuilder::new() .with_height(height) .with_leaf_nodes(leaf_nodes.clone()) .build_using_single_threaded_algorithm(&generate_padding_closure()) @@ -433,7 +433,7 @@ mod tests { leaf_nodes.shuffle(&mut thread_rng()); - let tree = TreeBuilder::new() + let tree = BinaryTreeBuilder::new() .with_height(height) .with_leaf_nodes(leaf_nodes) .build_using_single_threaded_algorithm(&generate_padding_closure()) @@ -447,7 +447,7 @@ mod tests { let height = Height::expect_from(5); let leaf_nodes = sparse_leaves(&height); - let tree = TreeBuilder::new() + let tree = BinaryTreeBuilder::new() .with_height(height) .with_leaf_nodes(leaf_nodes.clone()) .build_using_single_threaded_algorithm(&generate_padding_closure()) @@ -468,7 +468,7 @@ mod tests { let height = Height::expect_from(8); let leaf_nodes = full_bottom_layer(&height); - let tree = TreeBuilder::new() + let tree = BinaryTreeBuilder::new() .with_height(height) .with_leaf_nodes(leaf_nodes.clone()) .build_using_single_threaded_algorithm(&generate_padding_closure()) @@ -506,7 +506,7 @@ mod tests { // TODO fuzz on this store depth let store_depth = 1; - let tree = TreeBuilder::new() + let tree = BinaryTreeBuilder::new() .with_height(height) .with_leaf_nodes(leaf_nodes.clone()) .with_store_depth(store_depth) diff --git a/src/cli.rs b/src/cli.rs index 00147ad6..0602881a 100644 --- a/src/cli.rs +++ b/src/cli.rs @@ -2,7 +2,7 @@ //! //! 
See [MAIN_LONG_ABOUT] for more information. -use clap::{command, Args, Parser, Subcommand, ValueEnum}; +use clap::{command, Args, Parser, Subcommand}; use clap_verbosity_flag::{Verbosity, WarnLevel}; use patharg::{InputArg, OutputArg}; use primitive_types::H256; @@ -10,9 +10,10 @@ use primitive_types::H256; use std::str::FromStr; use crate::{ + accumulators::AccumulatorType, binary_tree::Height, - inclusion_proof::DEFAULT_RANGE_PROOF_UPPER_BOUND_BIT_LENGTH, - percentage::{Percentage, ONE_HUNDRED_PERCENT}, MaxThreadCount, + percentage::{Percentage, ONE_HUNDRED_PERCENT}, + MaxLiability, MaxThreadCount, Salt }; // ------------------------------------------------------------------------------------------------- @@ -44,7 +45,7 @@ pub enum Command { /// Inclusion proofs can be generated, but configuration is not supported. /// If you want more config options then use the `gen-proofs` command. BuildTree { - /// Choose the accumulator type for the tree. + /// Config DAPOL tree. #[command(subcommand)] build_kind: BuildKindCommand, @@ -81,10 +82,6 @@ pub enum Command { /// are aggregated using the Bulletproofs protocol. #[arg(short, long, value_parser = Percentage::from_str, default_value = ONE_HUNDRED_PERCENT, value_name = "PERCENTAGE")] range_proof_aggregation: Percentage, - - /// Upper bound for the range proofs is 2^(this_number). - #[arg(short, long, default_value_t = DEFAULT_RANGE_PROOF_UPPER_BOUND_BIT_LENGTH, value_name = "U8_INT")] - upper_bound_bit_length: u8, }, /// Verify an inclusion proof. @@ -109,19 +106,25 @@ pub enum BuildKindCommand { /// supported by the configuration file format which can be found in the ///`build-tree config-file` command."; New { - /// Choose an accumulator type for the tree. - #[arg(short, long, value_enum)] - accumulator: AccumulatorType, + #[arg(short, long, value_enum, help = include_str!("./shared_docs/accumulator_type.md"))] + accumulator_type: AccumulatorType, + + #[arg(long, value_parser = Salt::from_str, help = include_str!("./shared_docs/salt_b.md"))] + salt_b: Option, + + #[arg(long, value_parser = Salt::from_str, help = include_str!("./shared_docs/salt_s.md"))] + salt_s: Option, - /// Height to use for the binary tree. - #[arg(long, value_parser = Height::from_str, default_value = Height::default(), value_name = "U8_INT")] + #[arg(long, value_parser = Height::from_str, default_value = Height::default(), value_name = "U8_INT", help = include_str!("./shared_docs/height.md"))] height: Height, - /// Max thread count allowed for parallel tree builder. - #[arg(long, value_parser = MaxThreadCount::from_str, default_value = MaxThreadCount::default(), value_name = "U8_INT")] + #[arg(long, value_parser = MaxLiability::from_str, default_value = MaxLiability::default(), value_name = "U64_INT", help = include_str!("./shared_docs/max_liability.md"))] + max_liability: MaxLiability, + + #[arg(long, value_parser = MaxThreadCount::from_str, default_value = MaxThreadCount::default(), value_name = "U8_INT", help = include_str!("./shared_docs/max_thread_count.md"))] max_thread_count: MaxThreadCount, - #[arg(short, long, value_name = "FILE_PATH", long_help = NDM_SMT_SECRETS_HELP)] + #[arg(short, long, value_name = "FILE_PATH", long_help = SECRETS_HELP)] secrets_file: Option, #[command(flatten)] @@ -138,12 +141,6 @@ pub enum BuildKindCommand { Deserialize { path: InputArg }, } -#[derive(ValueEnum, Debug, Clone)] -pub enum AccumulatorType { - NdmSmt, - // TODO other accumulators.. 
-} - #[derive(Args, Debug)] #[group(required = true, multiple = false)] pub struct EntitySource { @@ -187,12 +184,10 @@ overwritten (if it exists) or created (if it does not exist). The file extension must be `.dapoltree`. The serialization option is ignored if `build-tree deserialize` command is used."; -const NDM_SMT_SECRETS_HELP: &str = " +const SECRETS_HELP: &str = " TOML file containing secrets. The file format is as follows: ``` master_secret = \"master_secret\" -salt_b = \"salt_b\" -salt_s = \"salt_s\" ``` All secrets should have at least 128-bit security, but need not be chosen from a uniform distribution as they are passed through a key derivation function before @@ -206,37 +201,17 @@ CSV file format: entity_id,liability"; const COMMAND_CONFIG_FILE_ABOUT: &str = - "Read accumulator type and other tree configuration from a file. Supported file formats: TOML."; + "Read tree configuration from a file. Supported file formats: TOML."; -const COMMAND_CONFIG_FILE_LONG_ABOUT: &str = " -Read accumulator type and other tree configuration from a file. +const COMMAND_CONFIG_FILE_LONG_ABOUT: &str = concat!( + " +Read tree configuration from a file. Supported file formats: TOML. Config file format (TOML): ``` -# Accumulator type of the tree. -# This value determines what other values are required. -accumulator_type = \"ndm-smt\" - -# Height of the tree. -# If the height is not set the default height will be used. -height = 16 - -# Path to the secrets file. -# If not present the secrets will be generated randomly. -secrets_file_path = \"./examples/ndm_smt_secrets_example.toml\" - -# Can be a file or directory (default file name given in this case) -# If not present then no serialization is done. -serialization_path = \"./tree.dapoltree\" - -# At least one of file_path & generate_random must be present. -# If both are given then file_path is prioritized. -[entities] - -# Path to a file containing a list of entity IDs and their liabilities. -file_path = \"./examples/entities_example.csv\" - -# Generate the given number of entities, with random IDs & liabilities. -generate_random = 4 -```"; +", + include_str!("../examples/dapol_config_example.toml"), + " +```" +); diff --git a/src/dapol_config.rs b/src/dapol_config.rs new file mode 100644 index 00000000..fb2ae3b9 --- /dev/null +++ b/src/dapol_config.rs @@ -0,0 +1,815 @@ +use derive_builder::Builder; +use log::{debug, info}; +use serde::Deserialize; +use std::{ffi::OsString, fs::File, io::Read, path::PathBuf, str::FromStr}; + +use crate::{ + accumulators::AccumulatorType, + entity::{self, EntitiesParser}, + utils::LogOnErr, + DapolTree, DapolTreeError, Height, MaxLiability, MaxThreadCount, Salt, Secret, +}; +use crate::{salt, secret}; + +/// Configuration needed to construct a [crate][DapolTree]. +/// +/// The config is defined by a struct. A builder pattern is used to construct +/// the config, but it can also be constructed by deserializing a file. 
+/// Currently only toml files are supported, with the following format: +/// +/// ```toml,ignore +#[doc = include_str!("../examples/dapol_config_example.toml")] +/// ``` +/// +/// Example of how to use the builder to construct a [crate][DapolTree]: +/// ``` +/// use std::{path::PathBuf, str::FromStr}; +/// use dapol::{ +/// AccumulatorType, DapolConfigBuilder, DapolTree, Entity, Height, +/// MaxLiability, MaxThreadCount, Salt, Secret, +/// }; +/// +/// let secrets_file_path = +/// PathBuf::from("./examples/dapol_secrets_example.toml"); +/// let entities_file_path = PathBuf::from("./examples/entities_example.csv"); +/// let height = Height::expect_from(8); +/// let salt_b = Salt::from_str("salt_b").unwrap(); +/// let salt_s = Salt::from_str("salt_s").unwrap(); +/// let max_liability = MaxLiability::from(10_000_000); +/// let max_thread_count = MaxThreadCount::from(8); +/// +/// // The builder requires at least the following to be given: +/// // - accumulator_type +/// // - entities +/// // - secrets +/// let dapol_config = DapolConfigBuilder::default() +/// .accumulator_type(AccumulatorType::NdmSmt) +/// .height(height.clone()) +/// .salt_b(salt_b.clone()) +/// .salt_s(salt_s.clone()) +/// .max_liability(max_liability.clone()) +/// .max_thread_count(max_thread_count.clone()) +/// .secrets_file_path(secrets_file_path.clone()) +/// .entities_file_path(entities_file_path.clone()) +/// .build() +/// .unwrap(); +/// ``` +/// +/// Example of how to use a config file to construct a [crate][DapolTree]: +/// ``` +/// use std::{path::PathBuf, str::FromStr}; +/// use dapol::DapolConfig; +/// +/// let config_file_path = +/// PathBuf::from("./examples/dapol_config_example.toml"); +/// let dapol_config_from_file = +/// DapolConfig::deserialize(config_file_path).unwrap(); +/// ``` +/// +/// Note that you can also construct a [crate][DapolTree] by calling the +/// constructor directly (see [crate][DapolTree]). +#[derive(Deserialize, Debug, Builder, PartialEq)] +#[builder(build_fn(skip))] +pub struct DapolConfig { + #[doc = include_str!("./shared_docs/accumulator_type.md")] + accumulator_type: AccumulatorType, + + #[doc = include_str!("./shared_docs/salt_b.md")] + salt_b: Salt, + + #[doc = include_str!("./shared_docs/salt_s.md")] + salt_s: Salt, + + #[doc = include_str!("./shared_docs/max_liability.md")] + max_liability: MaxLiability, + + #[doc = include_str!("./shared_docs/height.md")] + height: Height, + + #[doc = include_str!("./shared_docs/max_thread_count.md")] + max_thread_count: MaxThreadCount, + + #[builder(private)] + entities: EntityConfig, + + #[builder(private)] + secrets: SecretsConfig, +} + +#[derive(Deserialize, Debug, Clone, Default, PartialEq)] +pub struct SecretsConfig { + file_path: Option, + master_secret: Option, +} + +#[derive(Deserialize, Debug, Clone, Default, PartialEq)] +pub struct EntityConfig { + file_path: Option, + num_random_entities: Option, +} + +// ------------------------------------------------------------------------------------------------- +// Builder. + +impl DapolConfigBuilder { + /// Set the path for the file containing the entity data. + /// + /// Wrapped in an option to provide ease of use if the PathBuf is already + /// an option. + pub fn entities_file_path_opt(&mut self, path: Option) -> &mut Self { + match &mut self.entities { + None => { + self.entities = Some(EntityConfig { + file_path: path, + num_random_entities: None, + }) + } + Some(entities) => entities.file_path = path, + } + self + } + + /// Set the path for the file containing the entity data. 
+ pub fn entities_file_path(&mut self, path: PathBuf) -> &mut Self { + self.entities_file_path_opt(Some(path)) + } + + /// Set the number of entities that will be generated randomly. + /// + /// If a path is also given for the entities then that is used instead, + /// i.e. they are not combined. + /// + /// Wrapped in an option to provide ease of use if the PathBuf is already + /// an option. + pub fn num_random_entities_opt(&mut self, num_entities: Option) -> &mut Self { + match &mut self.entities { + None => { + self.entities = Some(EntityConfig { + file_path: None, + num_random_entities: num_entities, + }) + } + Some(entities) => entities.num_random_entities = num_entities, + } + self + } + + /// Set the number of entities that will be generated randomly. + /// + /// If a path is also given for the entities then that is used instead, + /// i.e. they are not combined. + pub fn num_random_entities(&mut self, num_entities: u64) -> &mut Self { + self.num_random_entities_opt(Some(num_entities)) + } + + /// Set the path for the file containing the secrets. + /// + /// Wrapped in an option to provide ease of use if the PathBuf is already + /// an option. + pub fn secrets_file_path_opt(&mut self, path: Option) -> &mut Self { + match &mut self.secrets { + None => { + self.secrets = Some(SecretsConfig { + file_path: path, + master_secret: None, + }) + } + Some(secrets) => secrets.file_path = path, + } + self + } + + /// Set the path for the file containing the secrets. + pub fn secrets_file_path(&mut self, path: PathBuf) -> &mut Self { + self.secrets_file_path_opt(Some(path)) + } + + /// Set the master secret value directly. + #[doc = include_str!("./shared_docs/master_secret.md")] + pub fn master_secret(&mut self, master_secret: Secret) -> &mut Self { + match &mut self.secrets { + None => { + self.secrets = Some(SecretsConfig { + file_path: None, + master_secret: Some(master_secret), + }) + } + Some(secrets) => secrets.master_secret = Some(master_secret), + } + self + } + + #[doc = include_str!("./shared_docs/salt_b.md")] + /// + /// Wrapped in an option to provide ease of use if the value is already + /// an option. + pub fn salt_b_opt(&mut self, salt_b: Option) -> &mut Self { + self.salt_b = salt_b; + self + } + + #[doc = include_str!("./shared_docs/salt_s.md")] + /// + /// Wrapped in an option to provide ease of use if the value is already + /// an option. + pub fn salt_s_opt(&mut self, salt_s: Option) -> &mut Self { + self.salt_s = salt_s; + self + } + + /// Build the config struct. 
+ pub fn build(&self) -> Result { + let accumulator_type = + self.accumulator_type + .clone() + .ok_or(DapolConfigBuilderError::UninitializedField( + "accumulator_type", + ))?; + + let entities = EntityConfig { + file_path: self.entities.clone().and_then(|e| e.file_path).or(None), + num_random_entities: self + .entities + .clone() + .and_then(|e| e.num_random_entities) + .or(None), + }; + + if entities.file_path.is_none() && entities.num_random_entities.is_none() { + return Err(DapolConfigBuilderError::UninitializedField("entities")); + } + + let secrets = SecretsConfig { + file_path: self.secrets.clone().and_then(|e| e.file_path).or(None), + master_secret: self.secrets.clone().and_then(|e| e.master_secret).or(None), + }; + + if secrets.file_path.is_none() && secrets.master_secret.is_none() { + return Err(DapolConfigBuilderError::UninitializedField("secrets")); + } + + let salt_b = self.salt_b.clone().unwrap_or_default(); + let salt_s = self.salt_s.clone().unwrap_or_default(); + let height = self.height.unwrap_or_default(); + let max_thread_count = self.max_thread_count.unwrap_or_default(); + let max_liability = self.max_liability.unwrap_or_default(); + + Ok(DapolConfig { + accumulator_type, + salt_b, + salt_s, + max_liability, + height, + max_thread_count, + entities, + secrets, + }) + } +} + +// ------------------------------------------------------------------------------------------------- +// Deserialization & parsing. + +impl DapolConfig { + /// Open the file, then try to create the [DapolConfig] struct. + /// + /// An error is returned if: + /// 1. The file cannot be opened. + /// 2. The file cannot be read. + /// 3. The file type is not supported. + /// + /// Config deserialization example: + /// ``` + /// use std::path::PathBuf; + /// use dapol::DapolConfig; + /// + /// let file_path = PathBuf::from("./examples/dapol_config_example.toml"); + /// let config = DapolConfig::deserialize(file_path).unwrap(); + /// ``` + pub fn deserialize(config_file_path: PathBuf) -> Result { + debug!( + "Attempting to deserialize {:?} as a file containing DAPOL config", + config_file_path.clone().into_os_string() + ); + + let ext = config_file_path + .extension() + .and_then(|s| s.to_str()) + .ok_or(DapolConfigError::UnknownFileType( + config_file_path.clone().into_os_string(), + ))?; + + let mut config = match FileType::from_str(ext)? { + FileType::Toml => { + let mut buf = String::new(); + File::open(config_file_path.clone())?.read_to_string(&mut buf)?; + let config: DapolConfig = toml::from_str(&buf)?; + config + } + }; + + config.entities.file_path = + extend_path_if_relative(config_file_path.clone(), config.entities.file_path); + config.secrets.file_path = + extend_path_if_relative(config_file_path, config.secrets.file_path); + + debug!("Successfully deserialized DAPOL config file"); + + Ok(config) + } + + /// Try to construct a [crate][DapolTree] from the config. + pub fn parse(self) -> Result { + debug!("Parsing config to create a new DAPOL tree: {:?}", self); + + let salt_b = self.salt_b; + let salt_s = self.salt_s; + + let entities = EntitiesParser::new() + .with_path_opt(self.entities.file_path) + .with_num_entities_opt(self.entities.num_random_entities) + .parse_file_or_generate_random()?; + + let master_secret = if let Some(path) = self.secrets.file_path { + Ok(DapolConfig::parse_secrets_file(path)?) 
+ } else if let Some(master_secret) = self.secrets.master_secret { + Ok(master_secret) + } else { + Err(DapolConfigError::CannotFindMasterSecret) + }?; + + let dapol_tree = DapolTree::new( + self.accumulator_type, + master_secret, + salt_b, + salt_s, + self.max_liability, + self.max_thread_count, + self.height, + entities, + ) + .log_on_err()?; + + info!( + "Successfully built DAPOL tree with root hash {:?}", + dapol_tree.root_hash() + ); + + Ok(dapol_tree) + } + + /// Open and parse the secrets file, returning a [crate][Secret]. + /// + /// An error is returned if: + /// 1. The path is None (i.e. was not set). + /// 2. The file cannot be opened. + /// 3. The file cannot be read. + /// 4. The file type is not supported. + fn parse_secrets_file(path: PathBuf) -> Result { + debug!( + "Attempting to parse {:?} as a file containing secrets", + path + ); + + let ext = path.extension().and_then(|s| s.to_str()).ok_or( + SecretsParserError::UnknownFileType(path.clone().into_os_string()), + )?; + + let master_secret = match FileType::from_str(ext)? { + FileType::Toml => { + let mut buf = String::new(); + File::open(path)?.read_to_string(&mut buf)?; + let secrets: DapolSecrets = toml::from_str(&buf)?; + secrets.master_secret + } + }; + + debug!("Successfully parsed DAPOL secrets file",); + + Ok(master_secret) + } +} + +fn extend_path_if_relative( + leader_path: PathBuf, + possibly_relative_path: Option, +) -> Option { + match possibly_relative_path { + Some(path) => Some( + path.strip_prefix("./") + .map(|p| p.to_path_buf()) + .ok() + .and_then(|tail| leader_path.parent().map(|parent| parent.join(tail))) + .unwrap_or(path.clone()), + ), + None => None, + } +} + +/// Supported file types for deserialization. +enum FileType { + Toml, +} + +impl FromStr for FileType { + type Err = SecretsParserError; + + fn from_str(ext: &str) -> Result { + match ext { + "toml" => Ok(FileType::Toml), + _ => Err(SecretsParserError::UnsupportedFileType { ext: ext.into() }), + } + } +} + +#[derive(Deserialize, Debug)] +struct DapolSecrets { + master_secret: Secret, +} + +// ------------------------------------------------------------------------------------------------- +// Errors. + +/// Errors encountered when parsing [crate][DapolConfig]. 
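A short usage sketch combining `DapolConfig::deserialize` and `DapolConfig::parse` into the full config-to-tree flow. The config path points at the example file under `examples/`; the helper function name and the `expect` messages are illustrative only.

```rust
use std::path::PathBuf;

use dapol::{DapolConfig, DapolTree};

// Illustrative helper: read the example TOML config and build the tree.
fn tree_from_config_file() -> DapolTree {
    let config_file_path = PathBuf::from("./examples/dapol_config_example.toml");

    // Deserialize the TOML into a DapolConfig; relative paths inside the
    // config are resolved against the config file's own directory.
    let config = DapolConfig::deserialize(config_file_path)
        .expect("config file should deserialize");

    // Parse the config into an actual DAPOL tree (entities and secrets are
    // read from the paths given in the config).
    let tree = config.parse().expect("tree should build from the config");

    // The root hash is the public commitment published by the tree owner.
    println!("root hash: {:?}", tree.root_hash());

    tree
}
```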
+#[derive(thiserror::Error, Debug)] +pub enum DapolConfigError { + #[error("Entities parsing failed while trying to parse DAPOL config")] + EntitiesError(#[from] entity::EntitiesParserError), + #[error("Error parsing the master secret string")] + MasterSecretParseError(#[from] secret::SecretParserError), + #[error("Error parsing the master secret file")] + MasterSecretFileParseError(#[from] SecretsParserError), + #[error("Either master secret must be set directly, or a path to a file containing it must be given")] + CannotFindMasterSecret, + #[error("Error parsing the salt string")] + SaltParseError(#[from] salt::SaltParserError), + #[error("Tree construction failed after parsing DAPOL config")] + BuildError(#[from] DapolTreeError), + #[error("Unable to find file extension for path {0:?}")] + UnknownFileType(OsString), + #[error("The file type with extension {ext:?} is not supported")] + UnsupportedFileType { ext: String }, + #[error("Error reading the file")] + FileReadError(#[from] std::io::Error), + #[error("Deserialization process failed")] + DeserializationError(#[from] toml::de::Error), +} + +#[derive(thiserror::Error, Debug)] +pub enum SecretsParserError { + #[error("Unable to find file extension for path {0:?}")] + UnknownFileType(OsString), + #[error("The file type with extension {ext:?} is not supported")] + UnsupportedFileType { ext: String }, + #[error("Error reading the file")] + FileReadError(#[from] std::io::Error), + #[error("Deserialization process failed")] + DeserializationError(#[from] toml::de::Error), +} + +// ------------------------------------------------------------------------------------------------- +// Unit tests + +#[cfg(test)] +mod tests { + use crate::accumulators::Accumulator; + use crate::utils::test_utils::assert_err; + + use super::*; + use std::fs::File; + use std::io::{BufRead, BufReader}; + use std::path::Path; + + // Matches the config found in the dapol_config_example.toml file. + fn dapol_config_builder_matching_example_file() -> DapolConfigBuilder { + let src_dir = env!("CARGO_MANIFEST_DIR"); + let resources_dir = Path::new(&src_dir).join("examples"); + let secrets_file_path = resources_dir.join("dapol_secrets_example.toml"); + let entities_file_path = resources_dir.join("entities_example.csv"); + + let height = Height::expect_from(16u8); + let salt_b = Salt::from_str("salt_b").unwrap(); + let salt_s = Salt::from_str("salt_s").unwrap(); + let max_liability = MaxLiability::from(10_000_000u64); + let max_thread_count = MaxThreadCount::from(8u8); + let master_secret = Secret::from_str("master_secret").unwrap(); + let num_entities = 100u64; + + DapolConfigBuilder::default() + .accumulator_type(AccumulatorType::NdmSmt) + .height(height.clone()) + .salt_b(salt_b.clone()) + .salt_s(salt_s.clone()) + .max_liability(max_liability.clone()) + .max_thread_count(max_thread_count.clone()) + .secrets_file_path(secrets_file_path.clone()) + .master_secret(master_secret.clone()) + .entities_file_path(entities_file_path.clone()) + .num_random_entities(num_entities) + .clone() + } + + mod creating_config { + use super::*; + + #[test] + fn builder_with_all_default_values_gives_correct_config() { + // The builder requires at least the following to be given: + // - accumulator_type + // - entities + // - secrets + // The rest are left as default. 
+ + let src_dir = env!("CARGO_MANIFEST_DIR"); + let resources_dir = Path::new(&src_dir).join("examples"); + let secrets_file_path = resources_dir.join("dapol_secrets_example.toml"); + let entities_file_path = resources_dir.join("entities_example.csv"); + + let dapol_config = DapolConfigBuilder::default() + .accumulator_type(AccumulatorType::NdmSmt) + .secrets_file_path(secrets_file_path.clone()) + .entities_file_path(entities_file_path.clone()) + .build() + .unwrap(); + + // Assert the values that were explicitly set: + assert_eq!(dapol_config.accumulator_type, AccumulatorType::NdmSmt); + assert_eq!(dapol_config.entities.file_path, Some(entities_file_path)); + assert_eq!(dapol_config.secrets.file_path, Some(secrets_file_path)); + + // Assert the values that were not set: + assert_eq!(dapol_config.entities.num_random_entities, None); + assert_eq!(dapol_config.secrets.master_secret, None); + assert_eq!(dapol_config.max_thread_count, MaxThreadCount::default()); + assert_eq!(dapol_config.height, Height::default()); + assert_eq!(dapol_config.max_liability, MaxLiability::default()); + + // Salts should be random bytes. Check that at least one byte is non-zero. + assert!(dapol_config.salt_b.as_bytes().iter().any(|b| *b != 0u8)); + assert!(dapol_config.salt_s.as_bytes().iter().any(|b| *b != 0u8)); + } + + #[test] + fn builder_with_no_default_values_gives_correct_config() { + let src_dir = env!("CARGO_MANIFEST_DIR"); + let resources_dir = Path::new(&src_dir).join("examples"); + let secrets_file_path = resources_dir.join("dapol_secrets_example.toml"); + let entities_file_path = resources_dir.join("entities_example.csv"); + + let height = Height::expect_from(16u8); + let salt_b = Salt::from_str("salt_b").unwrap(); + let salt_s = Salt::from_str("salt_s").unwrap(); + let max_liability = MaxLiability::from(10_000_000u64); + let max_thread_count = MaxThreadCount::from(8u8); + let master_secret = Secret::from_str("master_secret").unwrap(); + let num_entities = 100u64; + + let dapol_config = dapol_config_builder_matching_example_file() + .build() + .unwrap(); + + assert_eq!(dapol_config.accumulator_type, AccumulatorType::NdmSmt); + assert_eq!(dapol_config.entities.file_path, Some(entities_file_path)); + assert_eq!(dapol_config.secrets.file_path, Some(secrets_file_path)); + assert_eq!( + dapol_config.entities.num_random_entities, + Some(num_entities) + ); + assert_eq!(dapol_config.secrets.master_secret, Some(master_secret)); + assert_eq!(dapol_config.max_thread_count, max_thread_count); + assert_eq!(dapol_config.max_liability, max_liability); + assert_eq!(dapol_config.height, height); + assert_eq!(dapol_config.salt_b, salt_b); + assert_eq!(dapol_config.salt_s, salt_s); + } + + #[test] + fn config_file_gives_same_config_as_builder() { + let src_dir = env!("CARGO_MANIFEST_DIR"); + let resources_dir = Path::new(&src_dir).join("examples"); + let config_file_path = resources_dir.join("dapol_config_example.toml"); + + let dapol_config_from_file = DapolConfig::deserialize(config_file_path).unwrap(); + let dapol_config_from_builder = dapol_config_builder_matching_example_file() + .build() + .unwrap(); + + assert_eq!(dapol_config_from_file, dapol_config_from_builder); + } + + #[test] + fn builder_without_accumulator_type_fails() { + let master_secret = Secret::from_str("master_secret").unwrap(); + let num_entities = 100u64; + + let res = DapolConfigBuilder::default() + .master_secret(master_secret) + .num_random_entities(num_entities) + .build(); + + assert_err!( + res, + 
Err(DapolConfigBuilderError::UninitializedField( + "accumulator_type" + )) + ); + } + + #[test] + fn builder_without_secrets_fails() { + let num_entities = 100u64; + + let res = DapolConfigBuilder::default() + .accumulator_type(AccumulatorType::NdmSmt) + .num_random_entities(num_entities) + .build(); + + assert_err!( + res, + Err(DapolConfigBuilderError::UninitializedField("secrets")) + ); + } + + #[test] + fn builder_without_entities_fails() { + let master_secret = Secret::from_str("master_secret").unwrap(); + + let res = DapolConfigBuilder::default() + .accumulator_type(AccumulatorType::NdmSmt) + .master_secret(master_secret) + .build(); + + assert_err!( + res, + Err(DapolConfigBuilderError::UninitializedField("entities")) + ); + } + + #[test] + fn fail_when_unsupproted_secrets_file_type() { + let this_file = std::file!(); + let unsupported_path = PathBuf::from(this_file); + + let num_entities = 100u64; + + let res = DapolConfigBuilder::default() + .accumulator_type(AccumulatorType::NdmSmt) + .num_random_entities(num_entities) + .secrets_file_path(unsupported_path) + .build() + .unwrap() + .parse(); + + assert_err!( + res, + Err(DapolConfigError::MasterSecretFileParseError( + SecretsParserError::UnsupportedFileType { ext: _ } + )) + ); + } + + #[test] + fn fail_when_unknown_secrets_file_type() { + let no_file_ext = PathBuf::from("../LICENSE"); + + let num_entities = 100u64; + + let res = DapolConfigBuilder::default() + .accumulator_type(AccumulatorType::NdmSmt) + .num_random_entities(num_entities) + .secrets_file_path(no_file_ext) + .build() + .unwrap() + .parse(); + + assert_err!( + res, + Err(DapolConfigError::MasterSecretFileParseError( + SecretsParserError::UnknownFileType(_) + )) + ); + } + } + + // TODO these are actually integration tests, so move them to tests dir + mod config_to_tree { + use super::*; + + #[test] + fn parsing_config_gives_correct_tree() { + let src_dir = env!("CARGO_MANIFEST_DIR"); + let resources_dir = Path::new(&src_dir).join("examples"); + let entities_file_path = resources_dir.join("entities_example.csv"); + + let entities_file = File::open(entities_file_path.clone()).unwrap(); + // "-1" because we don't include the top line of the csv which defines + // the column headings. 
+ let num_entities = BufReader::new(entities_file).lines().count() - 1; + + let height = Height::expect_from(8u8); + let master_secret = Secret::from_str("master_secret").unwrap(); + let salt_b = Salt::from_str("salt_b").unwrap(); + let salt_s = Salt::from_str("salt_s").unwrap(); + + let dapol_tree = DapolConfigBuilder::default() + .accumulator_type(AccumulatorType::NdmSmt) + .height(height.clone()) + .salt_b(salt_b.clone()) + .salt_s(salt_s.clone()) + .master_secret(master_secret.clone()) + .entities_file_path(entities_file_path.clone()) + .build() + .unwrap() + .parse() + .unwrap(); + + assert_eq!(dapol_tree.entity_mapping().unwrap().len(), num_entities as usize); + assert_eq!(dapol_tree.accumulator_type(), AccumulatorType::NdmSmt); + assert_eq!(*dapol_tree.height(), height); + assert_eq!(*dapol_tree.master_secret(), master_secret); + assert_eq!(dapol_tree.max_liability(), MaxLiability::default()); + assert_eq!(*dapol_tree.salt_b(), salt_b); + assert_eq!(*dapol_tree.salt_s(), salt_s); + } + + #[test] + fn config_with_random_entities_gives_correct_tree() { + let height = Height::expect_from(8); + let num_random_entities = 10; + let master_secret = Secret::from_str("master_secret").unwrap(); + + let dapol_tree = DapolConfigBuilder::default() + .accumulator_type(AccumulatorType::NdmSmt) + .height(height) + .master_secret(master_secret) + .num_random_entities(num_random_entities) + .build() + .unwrap() + .parse() + .unwrap(); + + assert_eq!(dapol_tree.entity_mapping().unwrap().len(), num_random_entities as usize); + } + + #[test] + fn secrets_file_gives_same_master_secret_as_setting_directly() { + let src_dir = env!("CARGO_MANIFEST_DIR"); + let resources_dir = Path::new(&src_dir).join("examples"); + let secrets_file_path = resources_dir.join("dapol_secrets_example.toml"); + let entities_file_path = resources_dir.join("entities_example.csv"); + let master_secret = Secret::from_str("master_secret").unwrap(); + let height = Height::expect_from(8u8); + + let tree_from_secrets_file = DapolConfigBuilder::default() + .accumulator_type(AccumulatorType::NdmSmt) + .height(height) + .secrets_file_path(secrets_file_path.clone()) + .entities_file_path(entities_file_path.clone()) + .build() + .unwrap() + .parse() + .unwrap(); + + let tree_from_direct_secret = DapolConfigBuilder::default() + .accumulator_type(AccumulatorType::NdmSmt) + .height(height) + .master_secret(master_secret.clone()) + .entities_file_path(entities_file_path.clone()) + .build() + .unwrap() + .parse() + .unwrap(); + + assert_eq!( + tree_from_direct_secret.master_secret(), + tree_from_secrets_file.master_secret() + ); + } + + #[test] + fn secrets_file_preferred_over_setting_directly() { + let src_dir = env!("CARGO_MANIFEST_DIR"); + let resources_dir = Path::new(&src_dir).join("examples"); + let secrets_file_path = resources_dir.join("dapol_secrets_example.toml"); + let entities_file_path = resources_dir.join("entities_example.csv"); + let master_secret = Secret::from_str("garbage").unwrap(); + let height = Height::expect_from(8u8); + + let dapol_tree = DapolConfigBuilder::default() + .accumulator_type(AccumulatorType::NdmSmt) + .height(height) + .secrets_file_path(secrets_file_path.clone()) + .master_secret(master_secret) + .entities_file_path(entities_file_path.clone()) + .build() + .unwrap() + .parse() + .unwrap(); + + assert_eq!( + dapol_tree.master_secret(), + &Secret::from_str("master_secret").unwrap() + ); + } + } +} diff --git a/src/dapol_tree.rs b/src/dapol_tree.rs new file mode 100644 index 00000000..69ae95b3 --- /dev/null +++ 
b/src/dapol_tree.rs @@ -0,0 +1,333 @@ +use log::{debug, info}; +use primitive_types::H256; +use serde::{Deserialize, Serialize}; +use std::path::PathBuf; + +use crate::{ + accumulators::{Accumulator, AccumulatorType, NdmSmt, NdmSmtError}, + read_write_utils::{self, ReadWriteError}, + utils::LogOnErr, + AggregationFactor, Entity, EntityId, Height, InclusionProof, MaxLiability, MaxThreadCount, + Salt, Secret, +}; + +const SERIALIZED_TREE_EXTENSION: &str = "dapoltree"; +const SERIALIZED_TREE_FILE_PREFIX: &str = "proof_of_liabilities_merkle_sum_tree_"; + +/// Proof of Liabilities Sparse Merkle Sum Tree. +/// +/// This is the top-most module in the hierarchy of the [dapol] crate. +/// +/// It is recommended that one use [crate][DapolConfig] to construct the +/// tree, which has extra sanity checks on the inputs and more ways to set +/// the parameters. But there is also a `new` function for direct construction. +#[derive(Debug, Serialize, Deserialize)] +pub struct DapolTree { + accumulator: Accumulator, + master_secret: Secret, + salt_s: Salt, + salt_b: Salt, + max_liability: MaxLiability, +} + +// ------------------------------------------------------------------------------------------------- +// Construction & proof generation. + +impl DapolTree { + /// Construct a new tree. + /// + /// It is recommended to rather use [crate][DapolConfig] to construct the + /// tree, which has extra sanity checks on the inputs and more ways to set + /// the parameters. + /// + /// An error is returned if the underlying accumulator type construction + /// fails. + /// + /// - `accumulator_type`: This value must be set. + #[doc = include_str!("./shared_docs/accumulator_type.md")] + /// - `master_secret`: This value is known only to the tree generator, and + /// is used to determine all other secret values needed in the tree. This + /// value must be set. + /// - `salt_b`: If not set then it will be randomly generated. + #[doc = include_str!("./shared_docs/salt_b.md")] + /// - `salt_s`: If not set then it will be + /// randomly generated. + #[doc = include_str!("./shared_docs/salt_s.md")] + /// - `max_liability`: If not set then a default value is used. + #[doc = include_str!("./shared_docs/max_liability.md")] + /// - `height`: If not set the [default height] will be used [crate][Height]. + #[doc = include_str!("./shared_docs/height.md")] + /// - `max_thread_count`: If not set the max parallelism of the + /// underlying machine will be used. + #[doc = include_str!("./shared_docs/max_thread_count.md")] + /// - `secrets_file_path`: Path to the secrets file. If not present the + /// secrets will be generated randomly. 
+    /// - `entities`:
+    #[doc = include_str!("./shared_docs/entities_vector.md")]
+    ///
+    /// Example of how to use the constructor:
+    /// ```
+    /// use std::str::FromStr;
+    /// use dapol::{
+    ///     AccumulatorType, DapolTree, Entity, EntityId, Height, MaxLiability,
+    ///     MaxThreadCount, Salt, Secret,
+    /// };
+    ///
+    /// let accumulator_type = AccumulatorType::NdmSmt;
+    /// let height = Height::expect_from(8);
+    /// let salt_b = Salt::from_str("salt_b").unwrap();
+    /// let salt_s = Salt::from_str("salt_s").unwrap();
+    /// let master_secret = Secret::from_str("master_secret").unwrap();
+    /// let max_liability = MaxLiability::from(10_000_000);
+    /// let max_thread_count = MaxThreadCount::from(8);
+    ///
+    /// let entity = Entity {
+    ///     liability: 1u64,
+    ///     id: EntityId::from_str("id").unwrap(),
+    /// };
+    /// let entities = vec![entity];
+    ///
+    /// let dapol_tree = DapolTree::new(
+    ///     accumulator_type,
+    ///     master_secret,
+    ///     salt_b,
+    ///     salt_s,
+    ///     max_liability,
+    ///     max_thread_count,
+    ///     height,
+    ///     entities,
+    /// ).unwrap();
+    /// ```
+    ///
+    /// [default height]: crate::Height::default
+    pub fn new(
+        accumulator_type: AccumulatorType,
+        master_secret: Secret,
+        salt_b: Salt,
+        salt_s: Salt,
+        max_liability: MaxLiability,
+        max_thread_count: MaxThreadCount,
+        height: Height,
+        entities: Vec<Entity>,
+    ) -> Result<Self, DapolTreeError> {
+        let accumulator = match accumulator_type {
+            AccumulatorType::NdmSmt => {
+                let ndm_smt = NdmSmt::new(
+                    master_secret.clone(),
+                    salt_b.clone(),
+                    salt_s.clone(),
+                    height,
+                    max_thread_count,
+                    entities,
+                )?;
+                Accumulator::NdmSmt(ndm_smt)
+            }
+        };
+
+        Ok(DapolTree {
+            accumulator,
+            master_secret,
+            salt_s,
+            salt_b,
+            max_liability,
+        })
+    }
+
+    /// Generate an inclusion proof for the given `entity_id`.
+    ///
+    /// `aggregation_factor` is used to determine how many of the range proofs
+    /// are aggregated. Those that do not form part of the aggregated proof
+    /// are just proved individually. The aggregation is a feature of the
+    /// Bulletproofs protocol that improves efficiency.
+    pub fn generate_inclusion_proof_with(
+        &self,
+        entity_id: &EntityId,
+        aggregation_factor: AggregationFactor,
+    ) -> Result<InclusionProof, NdmSmtError> {
+        match &self.accumulator {
+            Accumulator::NdmSmt(ndm_smt) => ndm_smt.generate_inclusion_proof_with(
+                &self.master_secret,
+                &self.salt_b,
+                &self.salt_s,
+                entity_id,
+                aggregation_factor,
+                self.max_liability.as_range_proof_upper_bound_bit_length(),
+            ),
+        }
+    }
+
+    /// Generate an inclusion proof for the given `entity_id`.
+    pub fn generate_inclusion_proof(
+        &self,
+        entity_id: &EntityId,
+    ) -> Result<InclusionProof, NdmSmtError> {
+        match &self.accumulator {
+            Accumulator::NdmSmt(ndm_smt) => ndm_smt.generate_inclusion_proof(
+                &self.master_secret,
+                &self.salt_b,
+                &self.salt_s,
+                entity_id,
+            ),
+        }
+    }
+}
+
+// -------------------------------------------------------------------------------------------------
+// Accessor methods.
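A hedged sketch of proof generation against a built tree, using only the methods defined above. The entity ID, helper name, and use of `expect` are illustrative; it assumes the entity exists in the tree.

```rust
use std::str::FromStr;

use dapol::{DapolTree, EntityId, InclusionProof};

// Illustrative helper: prove inclusion of one entity in an already-built tree.
fn prove_entity(dapol_tree: &DapolTree) -> InclusionProof {
    // Hypothetical entity ID; it must match an entity the tree was built with.
    let entity_id = EntityId::from_str("id").unwrap();

    // Uses the default aggregation behaviour; `generate_inclusion_proof_with`
    // exposes the Bulletproofs aggregation factor if more control is needed.
    dapol_tree
        .generate_inclusion_proof(&entity_id)
        .expect("proof generation should succeed for an entity in the tree")
}
```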
+
+impl DapolTree {
+    #[doc = include_str!("./shared_docs/accumulator_type.md")]
+    pub fn accumulator_type(&self) -> AccumulatorType {
+        self.accumulator.get_type()
+    }
+
+    #[doc = include_str!("./shared_docs/master_secret.md")]
+    pub fn master_secret(&self) -> &Secret {
+        &self.master_secret
+    }
+
+    #[doc = include_str!("./shared_docs/salt_b.md")]
+    pub fn salt_b(&self) -> &Salt {
+        &self.salt_b
+    }
+
+    #[doc = include_str!("./shared_docs/salt_s.md")]
+    pub fn salt_s(&self) -> &Salt {
+        &self.salt_s
+    }
+
+    #[doc = include_str!("./shared_docs/max_liability.md")]
+    pub fn max_liability(&self) -> MaxLiability {
+        self.max_liability
+    }
+
+    #[doc = include_str!("./shared_docs/height.md")]
+    pub fn height(&self) -> &Height {
+        self.accumulator.height()
+    }
+
+    /// Mapping of [crate][EntityId] to x-coord on the bottom layer of the tree.
+    ///
+    /// If the underlying accumulator is an NDM-SMT then a hashmap is returned,
+    /// otherwise None is returned.
+    pub fn entity_mapping(&self) -> Option<&std::collections::HashMap<EntityId, u64>> {
+        match &self.accumulator {
+            Accumulator::NdmSmt(ndm_smt) => Some(ndm_smt.entity_mapping()),
+            _ => None,
+        }
+    }
+
+    /// Return the hash digest/bytes of the root node for the binary tree.
+    pub fn root_hash(&self) -> H256 {
+        self.accumulator.root_hash()
+    }
+}
+
+// -------------------------------------------------------------------------------------------------
+// Serialization & deserialization.
+
+impl DapolTree {
+    /// Try deserialize from the given file path.
+    ///
+    /// The file is assumed to be in [bincode] format.
+    ///
+    /// An error is logged and returned if
+    /// 1. The file cannot be opened.
+    /// 2. The [bincode] deserializer fails.
+    pub fn deserialize(path: PathBuf) -> Result<DapolTree, DapolTreeError> {
+        debug!(
+            "Deserializing accumulator from file {:?}",
+            path.clone().into_os_string()
+        );
+
+        match path.extension() {
+            Some(ext) => {
+                if ext != SERIALIZED_TREE_EXTENSION {
+                    Err(ReadWriteError::UnsupportedFileExtension {
+                        expected: SERIALIZED_TREE_EXTENSION.to_owned(),
+                        actual: ext.to_os_string(),
+                    })?;
+                }
+            }
+            None => Err(ReadWriteError::NotAFile(path.clone().into_os_string()))?,
+        }
+
+        let dapol_tree: DapolTree =
+            read_write_utils::deserialize_from_bin_file(path.clone()).log_on_err()?;
+
+        let root_hash = match &dapol_tree.accumulator {
+            Accumulator::NdmSmt(ndm_smt) => ndm_smt.root_hash(),
+        };
+
+        info!(
+            "Successfully deserialized dapol tree from file {:?} with root hash {:?}",
+            path.clone().into_os_string(),
+            root_hash
+        );
+
+        Ok(dapol_tree)
+    }
+
+    /// Parse `path` as one that points to a serialized dapol tree file.
+    ///
+    /// `path` can be either of the following:
+    /// 1. Existing directory: in this case a default file name is appended to
+    ///    `path`.
+    /// 2. Non-existing directory: in this case all dirs in the path are
+    ///    created, and a default file name is appended.
+    /// 3. File in existing dir: in this case the extension is checked to be
+    ///    [SERIALIZED_TREE_EXTENSION], then `path` is returned.
+    /// 4. File in non-existing dir: dirs in the path are created and the file
+    ///    extension is checked.
+    ///
+    /// The file prefix is [SERIALIZED_TREE_FILE_PREFIX].
+    ///
+    /// Example:
+    /// ```
+    /// use dapol::DapolTree;
+    /// use std::path::PathBuf;
+    ///
+    /// let dir = PathBuf::from("./");
+    /// let path = DapolTree::parse_serialization_path(dir).unwrap();
+    /// ```
+    pub fn parse_serialization_path(path: PathBuf) -> Result<PathBuf, ReadWriteError> {
+        read_write_utils::parse_serialization_path(
+            path,
+            SERIALIZED_TREE_EXTENSION,
+            SERIALIZED_TREE_FILE_PREFIX,
+        )
+    }
+
+    /// Serialize to a file.
+    ///
+    /// Serialization is done using [bincode].
+    ///
+    /// An error is returned if
+    /// 1. [bincode] fails to serialize the file.
+    /// 2. There is an issue opening or writing the file.
+    pub fn serialize(&self, path: PathBuf) -> Result<(), DapolTreeError> {
+        let path = DapolTree::parse_serialization_path(path)?;
+
+        info!(
+            "Serializing accumulator to file {:?}",
+            path.clone().into_os_string()
+        );
+
+        read_write_utils::serialize_to_bin_file(self, path).log_on_err()?;
+        Ok(())
+    }
+}
+
+// -------------------------------------------------------------------------------------------------
+// Errors.
+
+/// Errors encountered when handling a [DapolTree].
+#[derive(thiserror::Error, Debug)]
+pub enum DapolTreeError {
+    #[error("Error serializing/deserializing file")]
+    SerdeError(#[from] ReadWriteError),
+    #[error("Error constructing a new NDM-SMT")]
+    NdmSmtConstructionError(#[from] NdmSmtError),
+}
+
+// -------------------------------------------------------------------------------------------------
+// STENT TODO test
diff --git a/src/entity.rs b/src/entity.rs
index 7344617e..a29a938a 100644
--- a/src/entity.rs
+++ b/src/entity.rs
@@ -1,6 +1,6 @@
 use serde::{Deserialize, Serialize};
-
-use std::{convert::From};
+use serde_with::DeserializeFromStr;
+use std::convert::From;
 use std::str::FromStr;
 
 mod entities_parser;
@@ -24,7 +24,7 @@ pub use entity_ids_parser::{EntityIdsParser, EntityIdsParserError};
 /// chosen above 'user' because it has a more general connotation.
 ///
 /// The entity struct has only 2 fields: ID and liability.
-#[derive(Deserialize, PartialEq)]
+#[derive(Debug, Deserialize, PartialEq)]
 pub struct Entity {
     pub liability: u64,
     pub id: EntityId,
@@ -33,11 +33,10 @@ pub struct Entity {
 /// The max size of the entity ID is 256 bits, but this is a soft limit so it
 /// can be increased if necessary. Note that the underlying array length will
 /// also have to be increased.
-// STENT TODO this is not enforced on deserialization, do that
 pub const ENTITY_ID_MAX_BYTES: usize = 32;
 
 /// Abstract representation of an entity ID.
-#[derive(PartialEq, Eq, Hash, Clone, Debug, Deserialize, Serialize)]
+#[derive(PartialEq, Eq, Hash, Clone, Debug, DeserializeFromStr, Serialize)]
 pub struct EntityId(String);
 
 impl FromStr for EntityId {
diff --git a/src/entity/entities_parser.rs b/src/entity/entities_parser.rs
index 87197538..dc5395fe 100644
--- a/src/entity/entities_parser.rs
+++ b/src/entity/entities_parser.rs
@@ -183,6 +183,7 @@ pub enum EntitiesParserError {
 mod tests {
     use super::*;
     use std::path::Path;
+    use crate::utils::test_utils::assert_err;
 
     #[test]
     fn parser_csv_file_happy_case() {
@@ -218,4 +219,19 @@
             .unwrap();
         assert_eq!(entities.len(), num_entities as usize);
     }
+
+    #[test]
+    fn fail_when_unsupported_file_type() {
+        let this_file = std::file!();
+        let unsupported_path = PathBuf::from(this_file);
+        let res = EntitiesParser::new().with_path(unsupported_path).parse_file();
+        assert_err!(res, Err(EntitiesParserError::UnsupportedFileType { ext: _ }));
+    }
+
+    #[test]
+    fn fail_when_unknown_file_type() {
+        let no_file_ext = PathBuf::from("../../LICENSE");
+        let res = EntitiesParser::new().with_path(no_file_ext).parse_file();
+        assert_err!(res, Err(EntitiesParserError::UnknownFileType(_)));
+    }
 }
diff --git a/src/inclusion_proof.rs b/src/inclusion_proof.rs
index 0b6073f7..0cfefce2 100644
--- a/src/inclusion_proof.rs
+++ b/src/inclusion_proof.rs
@@ -6,7 +6,7 @@ use std::{fmt::Debug, path::PathBuf};
 use log::info;
 
 use crate::binary_tree::{Coordinate, Height, Node, PathSiblings};
-use crate::node_content::{FullNodeContent, HiddenNodeContent};
+use crate::binary_tree::{FullNodeContent, HiddenNodeContent};
 use crate::{read_write_utils, EntityId};
 
 mod individual_range_proof;
@@ -18,11 +18,6 @@ use aggregated_range_proof::AggregatedRangeProof;
 
 mod aggregation_factor;
 pub use aggregation_factor::AggregationFactor;
 
-/// Default upper bound for the range proof in the inclusion proof.
-/// 64 bits should be more than enough bits to represent liabilities for real
-/// world applications such as crypto asset exchange balances.
-pub const DEFAULT_RANGE_PROOF_UPPER_BOUND_BIT_LENGTH: u8 = 64u8;
-
 /// The file extension used when writing serialized binary files.
 const SERIALIZED_PROOF_EXTENSION: &str = "dapolproof";
diff --git a/src/lib.rs b/src/lib.rs
index edb79630..a6c2d086 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -49,132 +49,40 @@
 //! - verify an inclusion proof using a root hash (no tree required)
 //!
 //! ```
-//! use std::str::FromStr;
-//! use std::path::Path;
-//!
-//! use dapol::utils::LogOnErrUnwrap;
-//!
-//! fn main() {
-//!     let log_level = clap_verbosity_flag::LevelFilter::Debug;
-//!     dapol::utils::activate_logging(log_level);
-//!
-//!     // =========================================================================
-//!     // Tree building.
-//!
-//!     let ndm_smt = build_ndm_smt_using_builder_pattern();
-//!     let accumulator = build_accumulator_using_config_file();
-//!
-//!     // The above 2 builder methods produce a different tree because the entities
-//!     // are mapped randomly to points on the bottom layer, but the entity mapping
-//!     // of one tree should simply be a permutation of the other. We check this:
-//!     let ndm_smt_other = match accumulator {
-//!         dapol::Accumulator::NdmSmt(ndm_smt_other) => {
-//!             assert_ne!(ndm_smt_other.root_hash(), ndm_smt.root_hash());
-//!
-//!             for (entity, _) in ndm_smt_other.entity_mapping() {
-//!                 assert!(ndm_smt.entity_mapping().contains_key(&entity));
-//!             }
-//!
-//!             ndm_smt_other
-//!         }
-//!     };
-//!
-//! 
// ========================================================================= -//! // Inclusion proof generation & verification. -//! -//! let entity_id = dapol::EntityId::from_str("john.doe@example.com").unwrap(); -//! simple_inclusion_proof_generation_and_verification(&ndm_smt, entity_id.clone()); -//! advanced_inclusion_proof_generation_and_verification(&ndm_smt_other, entity_id); -//! } -//! -//! /// Example on how to use the builder pattern to construct an NDM-SMT tree. -//! pub fn build_ndm_smt_using_builder_pattern() -> dapol::accumulators::NdmSmt { -//! let src_dir = env!("CARGO_MANIFEST_DIR"); -//! let resources_dir = Path::new(&src_dir).join("examples"); -//! -//! let secrets_file = resources_dir.join("ndm_smt_secrets_example.toml"); -//! let entities_file = resources_dir.join("entities_example.csv"); -//! -//! let height = dapol::Height::expect_from(16); -//! -//! let config = dapol::accumulators::NdmSmtConfigBuilder::default() -//! .height(height) -//! .secrets_file_path(secrets_file) -//! .entities_path(entities_file) -//! .build(); -//! -//! config.parse().unwrap() -//! } -//! -//! /// An inclusion proof can be generated from only a tree + entity ID. -//! pub fn simple_inclusion_proof_generation_and_verification( -//! ndm_smt: &dapol::accumulators::NdmSmt, -//! entity_id: dapol::EntityId, -//! ) { -//! let inclusion_proof = ndm_smt.generate_inclusion_proof(&entity_id).unwrap(); -//! inclusion_proof.verify(ndm_smt.root_hash()).unwrap(); -//! } -//! -//! /// The inclusion proof generation algorithm can be customized via some -//! /// parameters. See [dapol][InclusionProof] for more details. -//! pub fn advanced_inclusion_proof_generation_and_verification( -//! ndm_smt: &dapol::accumulators::NdmSmt, -//! entity_id: dapol::EntityId, -//! ) { -//! // Determines how many of the range proofs in the inclusion proof are -//! // aggregated together. The ones that are not aggregated are proved -//! // individually. The more that are aggregated the faster the proving -//! // and verification times. -//! let aggregation_percentage = dapol::percentage::ONE_HUNDRED_PERCENT; -//! let aggregation_factor = dapol::AggregationFactor::Percent(aggregation_percentage); -//! let aggregation_factor = dapol::AggregationFactor::default(); -//! -//! // 2^upper_bound_bit_length is the upper bound used in the range proof i.e. -//! // the secret value is shown to reside in the range [0, 2^upper_bound_bit_length]. -//! let upper_bound_bit_length = 32u8; -//! let upper_bound_bit_length = dapol::DEFAULT_RANGE_PROOF_UPPER_BOUND_BIT_LENGTH; -//! -//! let inclusion_proof = ndm_smt -//! .generate_inclusion_proof_with(&entity_id, aggregation_factor, upper_bound_bit_length) -//! .unwrap(); -//! -//! inclusion_proof.verify(ndm_smt.root_hash()).unwrap(); -//! } -//! -//! /// Example on how to build a tree using a config file. -//! /// -//! /// The config file can be used for any accumulator type since the type is -//! /// specified by the config file. -//! /// -//! /// This is also an example usage of [dapol][utils][LogOnErrUnwrap]. -//! pub fn build_accumulator_using_config_file() -> dapol::Accumulator { -//! let src_dir = env!("CARGO_MANIFEST_DIR"); -//! let resources_dir = Path::new(&src_dir).join("examples"); -//! let config_file = resources_dir.join("tree_config_example.toml"); -//! -//! dapol::AccumulatorConfig::deserialize(config_file) -//! .log_on_err_unwrap() -//! .parse() -//! .log_on_err_unwrap() -//! } +#![doc = include_str!("../examples/main.rs")] //! 
``` mod kdf; -mod node_content; pub mod cli; pub mod percentage; pub mod read_write_utils; pub mod utils; +mod dapol_tree; +pub use dapol_tree::{DapolTree, DapolTreeError}; + +mod dapol_config; +pub use dapol_config::{ + DapolConfig, DapolConfigBuilder, DapolConfigBuilderError, DapolConfigError, +}; + +mod accumulators; +pub use accumulators::AccumulatorType; + +mod salt; +pub use salt::Salt; + mod hasher; pub use hasher::Hasher; mod max_thread_count; pub use max_thread_count::{initialize_machine_parallelism, MaxThreadCount, MACHINE_PARALLELISM}; -pub mod accumulators; -pub use accumulators::{Accumulator, AccumulatorConfig, AccumulatorConfigError, AccumulatorError}; +mod max_liability; +pub use max_liability::{ + MaxLiability, DEFAULT_MAX_LIABILITY, DEFAULT_RANGE_PROOF_UPPER_BOUND_BIT_LENGTH, +}; mod binary_tree; pub use binary_tree::{Height, HeightError, MAX_HEIGHT, MIN_HEIGHT}; @@ -183,10 +91,7 @@ mod secret; pub use secret::{Secret, SecretParserError}; mod inclusion_proof; -pub use inclusion_proof::{ - AggregationFactor, InclusionProof, InclusionProofError, - DEFAULT_RANGE_PROOF_UPPER_BOUND_BIT_LENGTH, -}; +pub use inclusion_proof::{AggregationFactor, InclusionProof, InclusionProofError}; mod entity; pub use entity::{Entity, EntityId, EntityIdsParser, EntityIdsParserError}; diff --git a/src/main.rs b/src/main.rs index 7cd25f16..87f31425 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,14 +1,13 @@ -use std::path::PathBuf; +use std::{path::PathBuf, str::FromStr}; use clap::Parser; use log::debug; use dapol::{ - accumulators::NdmSmtConfigBuilder, - cli::{AccumulatorType, BuildKindCommand, Cli, Command}, + cli::{BuildKindCommand, Cli, Command}, initialize_machine_parallelism, utils::{activate_logging, Consume, IfNoneThen, LogOnErr, LogOnErrUnwrap}, - Accumulator, AccumulatorConfig, AggregationFactor, EntityIdsParser, InclusionProof, + AggregationFactor, DapolConfig, DapolConfigBuilder, DapolTree, EntityIdsParser, InclusionProof, }; use patharg::InputArg; @@ -25,6 +24,8 @@ fn main() { } => { initialize_machine_parallelism(); + // It's not necessary to do this first, but it allows fast-failure + // for bad paths. let serialization_path = // Do not try serialize if the command is Deserialize because // this means there already is a serialized file. 
@@ -34,7 +35,7 @@ fn main() { match serialize { Some(patharg) => { let path = patharg.into_path().expect("Expected a file path, not stdout"); - Accumulator::parse_accumulator_serialization_path(path).log_on_err().ok() + DapolTree::parse_serialization_path(path).log_on_err().ok() } None => None, } @@ -42,35 +43,37 @@ fn main() { None }; - let accumulator: Accumulator = match build_kind { + let dapol_tree: DapolTree = match build_kind { BuildKindCommand::New { - accumulator, + accumulator_type, + salt_b, + salt_s, height, + max_liability, max_thread_count, secrets_file, entity_source, - } => match accumulator { - AccumulatorType::NdmSmt => { - let ndm_smt = NdmSmtConfigBuilder::default() - .height(height) - .max_thread_count(max_thread_count) - .secrets_file_path_opt(secrets_file.and_then(|arg| arg.into_path())) - .entities_path_opt( - entity_source.entities_file.and_then(|arg| arg.into_path()), - ) - .num_random_entities_opt(entity_source.random_entities) - .build() - .parse() - .log_on_err_unwrap(); - - Accumulator::NdmSmt(ndm_smt) - } - }, - BuildKindCommand::Deserialize { path } => Accumulator::deserialize( + } => DapolConfigBuilder::default() + .accumulator_type(accumulator_type) + .salt_b_opt(salt_b) + .salt_s_opt(salt_s) + .max_liability(max_liability) + .height(height) + .max_thread_count(max_thread_count) + .entities_file_path_opt( + entity_source.entities_file.and_then(|arg| arg.into_path()), + ) + .num_random_entities_opt(entity_source.random_entities) + .secrets_file_path_opt(secrets_file.and_then(|arg| arg.into_path())) + .build() + .log_on_err_unwrap() + .parse() + .log_on_err_unwrap(), + BuildKindCommand::Deserialize { path } => DapolTree::deserialize( path.into_path().expect("Expected file path, not stdout"), ) .log_on_err_unwrap(), - BuildKindCommand::ConfigFile { file_path } => AccumulatorConfig::deserialize( + BuildKindCommand::ConfigFile { file_path } => DapolConfig::deserialize( file_path .into_path() .expect("Expected file path, not stdin"), @@ -84,7 +87,7 @@ fn main() { .if_none_then(|| { debug!("No serialization path set, skipping serialization of the tree"); }) - .consume(|path| accumulator.serialize(path).unwrap()); + .consume(|path| dapol_tree.serialize(path).unwrap()); if let Some(patharg) = gen_proofs { let entity_ids = EntityIdsParser::from( @@ -97,7 +100,7 @@ fn main() { std::fs::create_dir(dir.as_path()).log_on_err_unwrap(); for entity_id in entity_ids { - let proof = accumulator + let proof = dapol_tree .generate_inclusion_proof(&entity_id) .log_on_err_unwrap(); @@ -109,9 +112,8 @@ fn main() { entity_ids, tree_file, range_proof_aggregation, - upper_bound_bit_length, } => { - let accumulator = Accumulator::deserialize( + let dapol_tree = DapolTree::deserialize( tree_file .into_path() .expect("Expected file path, not stdout"), @@ -133,11 +135,10 @@ fn main() { let aggregation_factor = AggregationFactor::Percent(range_proof_aggregation); for entity_id in entity_ids { - let proof = accumulator + let proof = dapol_tree .generate_inclusion_proof_with( &entity_id, aggregation_factor.clone(), - upper_bound_bit_length, ) .log_on_err_unwrap(); diff --git a/src/max_liability.rs b/src/max_liability.rs new file mode 100644 index 00000000..1c71593e --- /dev/null +++ b/src/max_liability.rs @@ -0,0 +1,131 @@ +use log::error; +use serde::{Deserialize, Serialize}; + +/// The default max liability. 
+///
+/// We would like to accommodate as high a value as possible while still being
+/// able to add $N$ of these together without overflow (where $N$ is the number
+/// of entities). A reasonable expectation is that $N < 1_000_000_000$ which
+/// gives us $L_{\text{max}} = 2^{64} / 1_000_000_000$. But things are simpler
+/// and computationally easier if we stick to powers of $2$, so we rather use:
+/// $L_{\text{max}} = 2^{64} / 2^{30} = 2^{34}$.
+///
+/// $L_{\text{max}}$ is used as the upper bound in the range proof. We use
+/// Bulletproofs for range proofs. The Bulletproofs library only allows upper
+/// bounds that are powers of 2 and have bit lengths in this list: $[8, 16, 32,
+/// 64]$. So instead of $2^{34}$ we use $2^{32}$.
+pub const DEFAULT_MAX_LIABILITY: u64 = 2u64.pow(DEFAULT_RANGE_PROOF_UPPER_BOUND_BIT_LENGTH as u32);
+
+/// Default upper bound for the range proof in the inclusion proof.
+///
+/// This value is determined by the max liability since we want to produce
+/// proofs of liabilities being within the range $[0, L_{\text{max}}]$.
+pub const DEFAULT_RANGE_PROOF_UPPER_BOUND_BIT_LENGTH: u8 = 32u8;
+
+/// These are the only allowed powers of 2 for the range proof upper bound
+/// when using Bulletproofs.
+pub const ALLOWED_RANGE_PROOF_UPPER_BIT_SIZES: [u8; 4] = [8, 16, 32, 64];
+
+/// Abstraction for the max liability value.
+#[doc = include_str!("./shared_docs/max_liability.md")]
+///
+/// Example:
+/// ```
+/// use dapol::MaxLiability;
+/// use std::str::FromStr;
+///
+/// let max_liability = MaxLiability::default();
+/// let max_liability = MaxLiability::from(1000u64);
+/// let max_liability = MaxLiability::from_str("1000").unwrap();
+/// ```
+#[derive(Copy, Clone, Debug, Deserialize, Serialize, PartialEq, PartialOrd)]
+pub struct MaxLiability(u64);
+
+impl MaxLiability {
+    pub fn as_u64(&self) -> u64 {
+        self.0
+    }
+
+    /// Take the logarithm of the underlying value and return the smallest
+    /// allowed bit length that is greater than this value.
+    pub fn as_range_proof_upper_bound_bit_length(&self) -> u8 {
+        let pow_2 = (self.0 as f64).log2() as u8;
+        *ALLOWED_RANGE_PROOF_UPPER_BIT_SIZES
+            .iter()
+            .find(|i| **i > pow_2)
+            .unwrap_or_else(|| {
+                panic!(
+                    "[BUG] It should not be possible for the base 2 logarithm of {} to be above 64",
+                    self.0
+                )
+            })
+    }
+}
+
+// -------------------------------------------------------------------------------------------------
+// From for u64
+
+impl From<u64> for MaxLiability {
+    fn from(max_liability: u64) -> Self {
+        Self(max_liability)
+    }
+}
+
+// -------------------------------------------------------------------------------------------------
+// Default.
+
+impl Default for MaxLiability {
+    fn default() -> Self {
+        Self(DEFAULT_MAX_LIABILITY)
+    }
+}
+
+// -------------------------------------------------------------------------------------------------
+// From for str.
+
+use std::str::FromStr;
+
+impl FromStr for MaxLiability {
+    type Err = MaxLiabilityError;
+
+    fn from_str(s: &str) -> Result<Self, Self::Err> {
+        Ok(MaxLiability(u64::from_str(s)?))
+    }
+}
+
+// -------------------------------------------------------------------------------------------------
+// Into for OsStr.
+
+use clap::builder::{OsStr, Str};
+
+impl From<MaxLiability> for OsStr {
+    fn from(max_liability: MaxLiability) -> OsStr {
+        OsStr::from(Str::from(max_liability.as_u64().to_string()))
+    }
+}
+
+// -------------------------------------------------------------------------------------------------
+// Errors.
+
+#[derive(thiserror::Error, Debug)]
+pub enum MaxLiabilityError {
+    #[error("Malformed string input for u64 type")]
+    MalformedString(#[from] std::num::ParseIntError),
+}
+
+// -------------------------------------------------------------------------------------------------
+// Unit tests.
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn default_max_liability_is_in_allowed_list() {
+        let pow_2 = (DEFAULT_MAX_LIABILITY as f64).log2() as u8;
+        assert!(ALLOWED_RANGE_PROOF_UPPER_BIT_SIZES
+            .iter()
+            .find(|i| **i == pow_2)
+            .is_some());
+    }
+}
diff --git a/src/max_thread_count.rs b/src/max_thread_count.rs
index 89fe7b7c..a5494ad4 100644
--- a/src/max_thread_count.rs
+++ b/src/max_thread_count.rs
@@ -9,8 +9,7 @@ pub const DEFAULT_MAX_THREAD_COUNT: u8 = 4;
 
 /// Abstraction for the max number of threads.
 ///
-/// This struct is used when determining how many threads can be spawned when
-/// doing work in parallel.
+#[doc = include_str!("./shared_docs/max_thread_count.md")]
 ///
 /// Example:
 /// ```
@@ -30,6 +29,9 @@ impl MaxThreadCount {
     }
 }
 
+// -------------------------------------------------------------------------------------------------
+// From for u8
+
 impl From<u8> for MaxThreadCount {
     fn from(max_thread_count: u8) -> Self {
         Self(max_thread_count)
@@ -68,7 +70,7 @@ impl FromStr for MaxThreadCount {
 }
 
 // -------------------------------------------------------------------------------------------------
-// From for OsStr.
+// Into for OsStr.
 
 use clap::builder::{OsStr, Str};
diff --git a/src/node_content.rs b/src/node_content.rs
deleted file mode 100644
index 52ab5ea9..00000000
--- a/src/node_content.rs
+++ /dev/null
@@ -1,5 +0,0 @@
-mod full_node;
-pub use full_node::FullNodeContent;
-
-mod hidden_node;
-pub use hidden_node::HiddenNodeContent;
diff --git a/src/read_write_utils.rs b/src/read_write_utils.rs
index 6d49202d..a9f45cb9 100644
--- a/src/read_write_utils.rs
+++ b/src/read_write_utils.rs
@@ -55,7 +55,8 @@ pub fn deserialize_from_bin_file(path: PathBuf) -> Result &[u8; 32] {
+        &self.0
+    }
+
+    /// Use a cryptographic PRNG to produce a random salt value.
+    #[time("debug", "NdmSmt::NdmSmtSalts::{}")]
+    pub fn generate_random() -> Self {
+        let mut rng = thread_rng();
+        let random_str = Alphanumeric.sample_string(&mut rng, MAX_LENGTH_BYTES);
+        Salt::from_str(&random_str).expect(STRING_CONVERSION_ERR_MSG)
+    }
+}
+
+// -------------------------------------------------------------------------------------------------
+// From for KDF key.
+
+use crate::kdf;
+
+impl From<kdf::Key> for Salt {
+    fn from(key: kdf::Key) -> Self {
+        let bytes: [u8; 32] = key.into();
+        Salt(bytes)
+    }
+}
+
+// -------------------------------------------------------------------------------------------------
+// From for str.
+
+use std::str::FromStr;
+
+impl FromStr for Salt {
+    type Err = SaltParserError;
+
+    /// Constructor that takes in a string slice.
+    /// If the length of the str is greater than the max then [Err] is returned.
+    fn from_str(s: &str) -> Result<Self, Self::Err> {
+        if s.len() > MAX_LENGTH_BYTES {
+            Err(SaltParserError::StringTooLongError)
+        } else {
+            let mut arr = [0u8; 32];
+            // this works because string slices are stored fundamentally as u8 arrays
+            arr[..s.len()].copy_from_slice(s.as_bytes());
+            Ok(Salt(arr))
+        }
+    }
+}
+
+// -------------------------------------------------------------------------------------------------
+// Into for raw bytes.
+
+impl From<Salt> for [u8; 32] {
+    fn from(item: Salt) -> Self {
+        item.0
+    }
+}
+
+// -------------------------------------------------------------------------------------------------
+// From for u64.
+
+impl From<u64> for Salt {
+    /// Constructor that takes in a u64.
+    fn from(num: u64) -> Self {
+        let bytes = num.to_le_bytes();
+        let mut arr = [0u8; 32];
+        arr[..8].copy_from_slice(&bytes[..8]);
+        Salt(arr)
+    }
+}
+
+// -------------------------------------------------------------------------------------------------
+// From for OsStr (for the CLI).
+
+use clap::builder::OsStr;
+
+impl From<Salt> for OsStr {
+    // https://stackoverflow.com/questions/19076719/how-do-i-convert-a-vector-of-bytes-u8-to-a-string
+    fn from(salt: Salt) -> OsStr {
+        OsStr::from(String::from_utf8_lossy(&salt.0).into_owned())
+    }
+}
+
+// -------------------------------------------------------------------------------------------------
+// Default.
+
+impl Default for Salt {
+    fn default() -> Self {
+        Salt::generate_random()
+    }
+}
+
+// -------------------------------------------------------------------------------------------------
+// Errors.
+
+/// Errors encountered when parsing [Salt].
+#[derive(Debug, thiserror::Error)]
+pub enum SaltParserError {
+    #[error("The given string has more than the max allowed bytes of {MAX_LENGTH_BYTES}")]
+    StringTooLongError,
+}
+
+// -------------------------------------------------------------------------------------------------
+// Unit tests.
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn randomly_generated_salts_differ_enough() {
+        let salt_1 = Salt::generate_random();
+        let salt_2 = Salt::generate_random();
+        let threshold = 10;
+
+        let iter_1 = salt_1.0.iter();
+        let iter_2 = salt_2.0.iter();
+
+        // Technically there is a small chance that this test fails.
+        assert!(
+            iter_1
+                .zip(iter_2)
+                .filter(|(byte_1, byte_2)| byte_1 == byte_2)
+                .count()
+                < threshold
+        );
+    }
+}
diff --git a/src/secret.rs b/src/secret.rs
index 159d8e06..9dee2c49 100644
--- a/src/secret.rs
+++ b/src/secret.rs
@@ -1,9 +1,7 @@
 use std::convert::From;
-use std::str::FromStr;
-
-use serde::{Serialize, Deserialize};
 
-use crate::kdf::Key;
+use serde::Serialize;
+use serde_with::DeserializeFromStr;
 
 /// The max size of the secret is 256 bits, but this is a soft limit so it
 /// can be increased if necessary. Note that the underlying array length will
@@ -24,7 +22,7 @@ pub const MAX_LENGTH_BYTES: usize = 32;
 /// Currently there is no need for the functionality provided by something like
 /// [primitive_types][U256] or [num256][Uint256] but those are options for
 /// later need be.
-#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
+#[derive(Debug, Clone, PartialEq, Serialize, DeserializeFromStr)]
 pub struct Secret([u8; 32]);
 
 impl Secret {
@@ -33,13 +31,21 @@ impl Secret {
     }
 }
 
-impl From<Key> for Secret {
-    fn from(key: Key) -> Self {
+// -------------------------------------------------------------------------------------------------
+// From for KDF key.
+
+use crate::kdf;
+
+impl From<kdf::Key> for Secret {
+    fn from(key: kdf::Key) -> Self {
         let bytes: [u8; 32] = key.into();
         Secret(bytes)
     }
 }
 
+// -------------------------------------------------------------------------------------------------
+// From for u64.
+
 impl From<u64> for Secret {
     /// Constructor that takes in a u64.
     fn from(num: u64) -> Self {
@@ -50,6 +56,11 @@ impl From<u64> for Secret {
     }
 }
 
+// -------------------------------------------------------------------------------------------------
+// From for str.
+
+use std::str::FromStr;
+
 impl FromStr for Secret {
     type Err = SecretParserError;
 
@@ -67,6 +78,9 @@ impl FromStr for Secret {
     }
 }
 
+// -------------------------------------------------------------------------------------------------
+// Into for raw bytes.
+
 impl From<Secret> for [u8; 32] {
     fn from(item: Secret) -> Self {
         item.0
diff --git a/src/shared_docs/accumulator_type.md b/src/shared_docs/accumulator_type.md
new file mode 100644
index 00000000..bcdb4bd3
--- /dev/null
+++ b/src/shared_docs/accumulator_type.md
@@ -0,0 +1 @@
+There are various different accumulator types (e.g. NDM-SMT).
diff --git a/src/shared_docs/entities_vector.md b/src/shared_docs/entities_vector.md
new file mode 100644
index 00000000..69fbc354
--- /dev/null
+++ b/src/shared_docs/entities_vector.md
@@ -0,0 +1 @@
+Vector of [Entity][Entity] (aka users)--the bottom-layer leaf nodes in the tree.
diff --git a/src/shared_docs/height.md b/src/shared_docs/height.md
new file mode 100644
index 00000000..a032fc8f
--- /dev/null
+++ b/src/shared_docs/height.md
@@ -0,0 +1 @@
+Height of the binary tree.
diff --git a/src/shared_docs/master_secret.md b/src/shared_docs/master_secret.md
new file mode 100644
index 00000000..e8977001
--- /dev/null
+++ b/src/shared_docs/master_secret.md
@@ -0,0 +1 @@
+Tree generator's singular secret value. This value is known only to the tree generator, and is used to determine all other secret values needed in the tree.
\ No newline at end of file
diff --git a/src/shared_docs/max_liability.md b/src/shared_docs/max_liability.md
new file mode 100644
index 00000000..0f9644df
--- /dev/null
+++ b/src/shared_docs/max_liability.md
@@ -0,0 +1 @@
+This is a public value representing the maximum amount that any single entity's liability can be, and is used in the range proofs: $[0, 2^{\text{height}} \times \text{max_liability}]$
diff --git a/src/shared_docs/max_thread_count.md b/src/shared_docs/max_thread_count.md
new file mode 100644
index 00000000..dde4265b
--- /dev/null
+++ b/src/shared_docs/max_thread_count.md
@@ -0,0 +1 @@
+Max size of the thread pool when using the parallelized tree build algorithms.
diff --git a/src/shared_docs/salt_b.md b/src/shared_docs/salt_b.md
new file mode 100644
index 00000000..af2dabff
--- /dev/null
+++ b/src/shared_docs/salt_b.md
@@ -0,0 +1 @@
+This is a public value that is used to aid the KDF when generating secret blinding factors for the Pedersen commitments.
diff --git a/src/shared_docs/salt_s.md b/src/shared_docs/salt_s.md
new file mode 100644
index 00000000..96ef94ed
--- /dev/null
+++ b/src/shared_docs/salt_s.md
@@ -0,0 +1 @@
+This is a public value that is used to aid the KDF when generating secret salt values, which are in turn used in the hash function when generating node hashes.
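For reference, the pieces above fit together roughly as follows. This is a minimal sketch that uses only the API introduced in this patch (`DapolTree::new`, `generate_inclusion_proof`, `serialize`, and the `MaxLiability` helper); the entity ID, liability amount, and the `./` output directory are illustrative assumptions, not values taken from the change itself.

```rust
use std::path::PathBuf;
use std::str::FromStr;

use dapol::{
    AccumulatorType, DapolTree, Entity, EntityId, Height, MaxLiability, MaxThreadCount, Salt,
    Secret,
};

fn main() {
    // Public parameters (values chosen for illustration only).
    let height = Height::expect_from(8);
    let max_liability = MaxLiability::from(10_000_000u64);
    let max_thread_count = MaxThreadCount::from(4);

    // Secret & salt material; salts could also come from `Salt::generate_random()`.
    let master_secret = Secret::from_str("master_secret").unwrap();
    let salt_b = Salt::from_str("salt_b").unwrap();
    let salt_s = Salt::from_str("salt_s").unwrap();

    // A single bottom-layer entity keeps the sketch short.
    let entity_id = EntityId::from_str("entity_1").unwrap();
    let entities = vec![Entity {
        liability: 100u64,
        id: entity_id.clone(),
    }];

    // Build the tree directly; `DapolConfig`/`DapolConfigBuilder` is the
    // recommended route when parameters come from files.
    let tree = DapolTree::new(
        AccumulatorType::NdmSmt,
        master_secret,
        salt_b,
        salt_s,
        max_liability,
        max_thread_count,
        height,
        entities,
    )
    .expect("tree construction should succeed");

    // The range-proof upper bound is the smallest allowed bit length
    // (8, 16, 32 or 64) strictly greater than log2(max_liability).
    let bit_length = tree.max_liability().as_range_proof_upper_bound_bit_length();
    println!(
        "root hash: {:?}, range proof bits: {}",
        tree.root_hash(),
        bit_length
    );

    // Inclusion proof for the entity added above.
    let _proof = tree
        .generate_inclusion_proof(&entity_id)
        .expect("proof generation should succeed");

    // Serializing to an existing directory appends a default file name with
    // the `.dapoltree` extension.
    tree.serialize(PathBuf::from("./")).unwrap();
}
```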