Skip to content

Commit

Permalink
perf: compute Bézout coefficients faster
Browse files Browse the repository at this point in the history
  • Loading branch information
jan-ferdinand committed Apr 16, 2024
2 parents 22834b0 + 751aa0f commit 652b7e9
Show file tree
Hide file tree
Showing 7 changed files with 209 additions and 64 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ strum = { version = "0.26", features = ["derive"] }
syn = "2.0"
test-strategy = "0.3.1"
thiserror = "1.0"
twenty-first = "0.39"
twenty-first = "0.40"
unicode-width = "0.1"

[workspace.lints.clippy]
Expand Down
4 changes: 4 additions & 0 deletions triton-vm/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,10 @@ test-strategy.workspace = true
[lints]
workspace = true

[[bench]]
name = "bezout_coeffs"
harness = false

[[bench]]
name = "mem_io"
harness = false
Expand Down
63 changes: 63 additions & 0 deletions triton-vm/benches/bezout_coeffs.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
use criterion::criterion_group;
use criterion::criterion_main;
use criterion::Criterion;
use num_traits::Zero;
use twenty_first::prelude::b_field_element::BFIELD_ONE;
use twenty_first::prelude::*;

use triton_vm::prelude::*;
use triton_vm::table::ram_table::RamTable;

// Benchmark entry point: runs every target registered in the `benches` group below.
criterion_main!(benches);
// Registers both implementations (`current_design` and `with_xgcd`) for the same
// four input sizes (10, 100, 1_000 and 10_000 roots) so their timings can be
// compared side by side. The group is configured with a sample size of 10.
criterion_group!(
name = benches;
config = Criterion::default().sample_size(10);
targets = current_design<10>,
current_design<100>,
current_design<1_000>,
current_design<10_000>,
with_xgcd<10>,
with_xgcd<100>,
with_xgcd<1_000>,
with_xgcd<10_000>,
);

/// Benchmarks the XGCD-based computation of the Bézout coefficients
/// ([`bezout_coeffs_xgcd`]) on `N` distinct roots.
///
/// The roots are generated once, outside the measured closure, so only the
/// coefficient computation itself is timed.
fn with_xgcd<const N: u64>(c: &mut Criterion) {
    let roots = unique_roots::<N>();
    c.bench_function(
        &format!("Bézout coefficients (XGCD) (degree {N})"),
        |bencher| bencher.iter(|| bezout_coeffs_xgcd(&roots)),
    );
}

/// Benchmarks `RamTable::bezout_coefficient_polynomials_coefficients` on `N`
/// distinct roots.
///
/// The roots are generated once, outside the measured closure, so only the
/// coefficient computation itself is timed.
fn current_design<const N: u64>(c: &mut Criterion) {
    let roots = unique_roots::<N>();
    let compute = || RamTable::bezout_coefficient_polynomials_coefficients(&roots);
    c.bench_function(
        &format!("Bézout coefficients (current design) (degree {N})"),
        |bencher| bencher.iter(compute),
    );
}

/// Produces `N` pairwise distinct field elements by mapping the integers
/// `0..N` through `BFieldElement::new`.
fn unique_roots<const N: u64>() -> Vec<BFieldElement> {
    let consecutive_integers = 0..N;
    consecutive_integers.map(BFieldElement::new).collect()
}

/// Computes the coefficients of the two Bézout coefficient polynomials using
/// the extended Euclidean algorithm.
///
/// The steps are:
/// 1. build the polynomial that has exactly the given RAM pointers as roots,
///    as the product of the linear factors `(x - r)`;
/// 2. take its formal derivative;
/// 3. run `Polynomial::xgcd` on the pair and keep the two Bézout polynomials.
///
/// Both returned coefficient vectors are resized to
/// `unique_ram_pointers.len()` entries, padding with zeros where necessary.
/// For an empty input, the product of linear factors falls back to the zero
/// polynomial.
fn bezout_coeffs_xgcd(
    unique_ram_pointers: &[BFieldElement],
) -> (Vec<BFieldElement>, Vec<BFieldElement>) {
    let num_pointers = unique_ram_pointers.len();

    // Multiply all linear factors `(x - root)` together, seeding the product
    // with the first factor.
    let mut linear_factors = unique_ram_pointers
        .iter()
        .map(|&root| Polynomial::new(vec![-root, BFIELD_ONE]));
    let roots_polynomial = match linear_factors.next() {
        Some(first_factor) => linear_factors.fold(first_factor, |product, factor| product * factor),
        None => Polynomial::zero(),
    };
    let derivative = roots_polynomial.formal_derivative();

    // The gcd itself is irrelevant here; only the Bézout polynomials are kept.
    let (_gcd, bezout_0, bezout_1) = Polynomial::xgcd(roots_polynomial, derivative);

    // Bring a coefficient vector to exactly `num_pointers` entries.
    let fit_to_length = |mut coefficients: Vec<BFieldElement>| {
        coefficients.resize(num_pointers, bfe!(0));
        coefficients
    };

    (
        fit_to_length(bezout_0.coefficients),
        fit_to_length(bezout_1.coefficients),
    )
}
37 changes: 19 additions & 18 deletions triton-vm/src/arithmetic_domain.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
use itertools::Itertools;
use std::ops::Mul;
use std::ops::MulAssign;

use num_traits::One;
use rayon::prelude::*;
use twenty_first::math::traits::FiniteField;
use twenty_first::math::traits::PrimitiveRootOfUnity;
use twenty_first::prelude::*;
Expand Down Expand Up @@ -52,37 +53,37 @@ impl ArithmeticDomain {

pub fn evaluate<FF>(&self, polynomial: &Polynomial<FF>) -> Vec<FF>
where
FF: FiniteField + MulAssign<BFieldElement> + From<BFieldElement>,
FF: FiniteField
+ MulAssign<BFieldElement>
+ Mul<BFieldElement, Output = FF>
+ From<BFieldElement>,
{
// The limitation arises in `Polynomial::fast_coset_evaluate` in dependency `twenty-first`.
let batch_evaluation_is_possible = self.length >= polynomial.coefficients.len();
if batch_evaluation_is_possible {
polynomial.fast_coset_evaluate(self.offset.into(), self.generator, self.length)
let (offset, generator, length) = (self.offset, self.generator, self.length);
polynomial.fast_coset_evaluate::<BFieldElement>(offset, generator, length)
} else {
self.evaluate_in_every_point_individually(polynomial)
let domain_values = self.domain_values().into_iter();
let domain_values = domain_values.map(FF::from).collect_vec();
polynomial.batch_evaluate(&domain_values)
}
}

fn evaluate_in_every_point_individually<FF>(&self, polynomial: &Polynomial<FF>) -> Vec<FF>
where
FF: FiniteField + MulAssign<BFieldElement> + From<BFieldElement>,
{
self.domain_values()
.par_iter()
.map(|&v| polynomial.evaluate(&v.into()))
.collect()
}

pub fn interpolate<FF>(&self, values: &[FF]) -> Polynomial<FF>
where
FF: FiniteField + MulAssign<BFieldElement> + From<BFieldElement>,
FF: FiniteField + MulAssign<BFieldElement> + Mul<BFieldElement, Output = FF>,
{
Polynomial::fast_coset_interpolate(self.offset.into(), self.generator, values)
// generic type made explicit to avoid performance regressions due to auto-conversion
Polynomial::fast_coset_interpolate::<BFieldElement>(self.offset, self.generator, values)
}

pub fn low_degree_extension<FF>(&self, codeword: &[FF], target_domain: Self) -> Vec<FF>
where
FF: FiniteField + MulAssign<BFieldElement> + From<BFieldElement>,
FF: FiniteField
+ MulAssign<BFieldElement>
+ Mul<BFieldElement, Output = FF>
+ From<BFieldElement>,
{
target_domain.evaluate(&self.interpolate(codeword))
}
Expand Down Expand Up @@ -228,7 +229,7 @@ mod tests {
// Verify that batch-evaluated values match a manual evaluation
for i in 0..order {
assert_eq!(
poly.evaluate(&b_domain.domain_value(i as u32)),
poly.evaluate(b_domain.domain_value(i as u32)),
values[i as usize]
);
}
Expand Down
14 changes: 7 additions & 7 deletions triton-vm/src/stark.rs
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ impl Stark {
let out_of_domain_point_curr_row_pow_num_segments =
out_of_domain_point_curr_row.mod_pow_u32(NUM_QUOTIENT_SEGMENTS as u32);
let out_of_domain_curr_row_quot_segments = quotient_segment_polynomials
.map(|poly| poly.evaluate(&out_of_domain_point_curr_row_pow_num_segments))
.map(|poly| poly.evaluate(out_of_domain_point_curr_row_pow_num_segments))
.to_vec()
.try_into()
.unwrap();
Expand Down Expand Up @@ -341,7 +341,7 @@ impl Stark {
prof_stop!(maybe_profiler, "interpolate");
prof_start!(maybe_profiler, "base&ext curr row");
let out_of_domain_curr_row_base_and_ext_value =
base_and_ext_interpolation_poly.evaluate(&out_of_domain_point_curr_row);
base_and_ext_interpolation_poly.evaluate(out_of_domain_point_curr_row);
let base_and_ext_curr_row_deep_codeword = Self::deep_codeword(
&base_and_ext_codeword.to_vec(),
short_domain,
Expand All @@ -352,7 +352,7 @@ impl Stark {

prof_start!(maybe_profiler, "base&ext next row");
let out_of_domain_next_row_base_and_ext_value =
base_and_ext_interpolation_poly.evaluate(&out_of_domain_point_next_row);
base_and_ext_interpolation_poly.evaluate(out_of_domain_point_next_row);
let base_and_ext_next_row_deep_codeword = Self::deep_codeword(
&base_and_ext_codeword.to_vec(),
short_domain,
Expand All @@ -363,7 +363,7 @@ impl Stark {

prof_start!(maybe_profiler, "segmented quotient");
let out_of_domain_curr_row_quot_segments_value = quotient_segments_interpolation_poly
.evaluate(&out_of_domain_point_curr_row_pow_num_segments);
.evaluate(out_of_domain_point_curr_row_pow_num_segments);
let quotient_segments_curr_row_deep_codeword = Self::deep_codeword(
&quotient_segments_codeword.to_vec(),
short_domain,
Expand Down Expand Up @@ -2404,7 +2404,7 @@ pub(crate) mod tests {
let low_deg_codeword = domain.evaluate(&low_deg_poly);

let out_of_domain_point: XFieldElement = thread_rng().gen();
let out_of_domain_value = low_deg_poly.evaluate(&out_of_domain_point);
let out_of_domain_value = low_deg_poly.evaluate(out_of_domain_point);

let deep_poly = Stark::deep_codeword(
&low_deg_codeword,
Expand Down Expand Up @@ -2435,11 +2435,11 @@ pub(crate) mod tests {
) {
let x_pow_n = x.mod_pow_u32(N as u32);
let evaluate_segment = |(segment_idx, segment): (_, &Polynomial<_>)| {
segment.evaluate(&x_pow_n) * x.mod_pow_u32(segment_idx as u32)
segment.evaluate(x_pow_n) * x.mod_pow_u32(segment_idx as u32)
};
let evaluated_segments = segments.iter().enumerate().map(evaluate_segment);
let sum_of_evaluated_segments = evaluated_segments.fold(FF::zero(), |acc, x| acc + x);
assert!(f.evaluate(&x) == sum_of_evaluated_segments);
assert!(f.evaluate(x) == sum_of_evaluated_segments);
}

fn assert_segments_degrees_are_small_enough<const N: usize, FF: FiniteField>(
Expand Down
18 changes: 11 additions & 7 deletions triton-vm/src/table/master_table.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use std::ops::Mul;
use std::ops::MulAssign;
use std::ops::Range;

Expand Down Expand Up @@ -206,7 +207,10 @@ pub enum TableId {
/// [master_quot_table]: all_quotients_combined
pub trait MasterTable<FF>: Sync
where
FF: FiniteField + MulAssign<BFieldElement> + From<BFieldElement>,
FF: FiniteField
+ MulAssign<BFieldElement>
+ Mul<BFieldElement, Output = FF>
+ From<BFieldElement>,
Standard: Distribution<FF>,
{
fn trace_domain(&self) -> ArithmeticDomain;
Expand Down Expand Up @@ -451,7 +455,7 @@ impl MasterTable<BFieldElement> for MasterBaseTable {
fn row(&self, row_index: XFieldElement) -> Array1<XFieldElement> {
self.interpolation_polynomials()
.into_par_iter()
.map(|polynomial| polynomial.evaluate(&row_index))
.map(|polynomial| polynomial.evaluate(row_index))
.collect::<Vec<_>>()
.into()
}
Expand Down Expand Up @@ -552,7 +556,7 @@ impl MasterTable<XFieldElement> for MasterExtTable {
self.interpolation_polynomials()
.slice(s![..NUM_EXT_COLUMNS])
.into_par_iter()
.map(|polynomial| polynomial.evaluate(&row_index))
.map(|polynomial| polynomial.evaluate(row_index))
.collect::<Vec<_>>()
.into()
}
Expand Down Expand Up @@ -1287,7 +1291,7 @@ mod tests {
assert_eq!(big_order as usize, initial_zerofier_inv.len());
assert_eq!(1, initial_zerofier_poly.degree());
assert!(initial_zerofier_poly
.evaluate(&small_domain.domain_value(0))
.evaluate(small_domain.domain_value(0))
.is_zero());

let consistency_zerofier_inv =
Expand All @@ -1298,7 +1302,7 @@ mod tests {
assert_eq!(big_order as usize, consistency_zerofier_inv.len());
assert_eq!(small_order as isize, consistency_zerofier_poly.degree());
for val in small_domain.domain_values() {
assert!(consistency_zerofier_poly.evaluate(&val).is_zero());
assert!(consistency_zerofier_poly.evaluate(val).is_zero());
}

let transition_zerofier_inv =
Expand All @@ -1307,7 +1311,7 @@ mod tests {
let transition_zerofier_poly = big_domain.interpolate(&transition_zerofier);
assert_eq!(big_order as usize, transition_zerofier_inv.len());
assert_eq!(small_order as isize - 1, transition_zerofier_poly.degree());
for val in small_domain
for &val in small_domain
.domain_values()
.iter()
.take(small_order as usize - 1)
Expand All @@ -1321,7 +1325,7 @@ mod tests {
assert_eq!(big_order as usize, terminal_zerofier_inv.len());
assert_eq!(1, terminal_zerofier_poly.degree());
assert!(terminal_zerofier_poly
.evaluate(&small_domain.domain_value(small_order as u32 - 1))
.evaluate(small_domain.domain_value(small_order as u32 - 1))
.is_zero());
}

Expand Down
Loading

0 comments on commit 652b7e9

Please sign in to comment.