From ab135be4834195b446986dc848513b70933bc645 Mon Sep 17 00:00:00 2001 From: Jan Ferdinand Sauer Date: Fri, 22 Mar 2024 14:20:59 +0100 Subject: [PATCH] =?UTF-8?q?test:=20benchmark=20B=C3=A9zout=20coefficient?= =?UTF-8?q?=20computation?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- triton-vm/Cargo.toml | 4 ++ triton-vm/benches/bezout_coeffs.rs | 63 ++++++++++++++++++++++++++++++ triton-vm/src/table/ram_table.rs | 14 ++++--- 3 files changed, 76 insertions(+), 5 deletions(-) create mode 100644 triton-vm/benches/bezout_coeffs.rs diff --git a/triton-vm/Cargo.toml b/triton-vm/Cargo.toml index 9cd9105d..7d62b2bb 100644 --- a/triton-vm/Cargo.toml +++ b/triton-vm/Cargo.toml @@ -51,6 +51,10 @@ test-strategy.workspace = true [lints] workspace = true +[[bench]] +name = "bezout_coeffs" +harness = false + [[bench]] name = "mem_io" harness = false diff --git a/triton-vm/benches/bezout_coeffs.rs b/triton-vm/benches/bezout_coeffs.rs new file mode 100644 index 00000000..118bb42a --- /dev/null +++ b/triton-vm/benches/bezout_coeffs.rs @@ -0,0 +1,63 @@ +use criterion::criterion_group; +use criterion::criterion_main; +use criterion::Criterion; +use num_traits::Zero; +use twenty_first::prelude::b_field_element::BFIELD_ONE; +use twenty_first::prelude::*; + +use triton_vm::prelude::*; +use triton_vm::table::ram_table::RamTable; + +criterion_main!(benches); +criterion_group!( + name = benches; + config = Criterion::default().sample_size(10); + targets = current_design<10>, + current_design<100>, + current_design<1_000>, + current_design<10_000>, + with_xgcd<10>, + with_xgcd<100>, + with_xgcd<1_000>, + with_xgcd<10_000>, +); + +fn with_xgcd(c: &mut Criterion) { + let roots = unique_roots::(); + let bench_id = format!("Bézout coefficients (XGCD) (degree {N})"); + c.bench_function(&bench_id, |b| b.iter(|| bezout_coeffs_xgcd(&roots))); +} + +fn current_design(c: &mut Criterion) { + let roots = unique_roots::(); + let bench_id = format!("Bézout coefficients (current design) (degree {N})"); + c.bench_function(&bench_id, |b| { + b.iter(|| RamTable::bezout_coefficient_polynomials_coefficients(&roots)) + }); +} + +fn unique_roots() -> Vec { + (0..N).map(BFieldElement::new).collect() +} + +fn bezout_coeffs_xgcd( + unique_ram_pointers: &[BFieldElement], +) -> (Vec, Vec) { + let linear_poly_with_root = |&r: &BFieldElement| Polynomial::new(vec![-r, BFIELD_ONE]); + + let polynomial_with_ram_pointers_as_roots = unique_ram_pointers + .iter() + .map(linear_poly_with_root) + .reduce(|accumulator, linear_poly| accumulator * linear_poly) + .unwrap_or_else(Polynomial::zero); + let formal_derivative = polynomial_with_ram_pointers_as_roots.formal_derivative(); + + let (_, bezout_poly_0, bezout_poly_1) = + Polynomial::xgcd(polynomial_with_ram_pointers_as_roots, formal_derivative); + + let mut coefficients_0 = bezout_poly_0.coefficients; + let mut coefficients_1 = bezout_poly_1.coefficients; + coefficients_0.resize(unique_ram_pointers.len(), bfe!(0)); + coefficients_1.resize(unique_ram_pointers.len(), bfe!(0)); + (coefficients_0, coefficients_1) +} diff --git a/triton-vm/src/table/ram_table.rs b/triton-vm/src/table/ram_table.rs index 88dc574e..0ff3e0c6 100644 --- a/triton-vm/src/table/ram_table.rs +++ b/triton-vm/src/table/ram_table.rs @@ -101,12 +101,16 @@ impl RamTable { compare_ram_pointers.then(compare_clocks) } - fn bezout_coefficient_polynomials_coefficients( - unique_ram_pointers: &[BFieldElement], + /// Compute the [Bézout coefficients](https://en.wikipedia.org/wiki/B%C3%A9zout%27s_identity) + /// of the polynomial with the given roots and its formal derivative. + /// + /// All roots _must_ be unique. That is, the corresponding polynomial must be square free. + pub fn bezout_coefficient_polynomials_coefficients( + unique_roots: &[BFieldElement], ) -> (Vec, Vec) { let linear_poly_with_root = |&r: &BFieldElement| Polynomial::new(vec![-r, BFIELD_ONE]); - let polynomial_with_ram_pointers_as_roots = unique_ram_pointers + let polynomial_with_ram_pointers_as_roots = unique_roots .iter() .map(linear_poly_with_root) .reduce(|accumulator, linear_poly| accumulator * linear_poly) @@ -118,8 +122,8 @@ impl RamTable { let mut coefficients_0 = bezout_poly_0.coefficients; let mut coefficients_1 = bezout_poly_1.coefficients; - coefficients_0.resize(unique_ram_pointers.len(), bfe!(0)); - coefficients_1.resize(unique_ram_pointers.len(), bfe!(0)); + coefficients_0.resize(unique_roots.len(), bfe!(0)); + coefficients_1.resize(unique_roots.len(), bfe!(0)); (coefficients_0, coefficients_1) }