Skip to content

Commit

Permalink
perf: use parallelism more when evaluating domain
Browse files Browse the repository at this point in the history
  • Loading branch information
jan-ferdinand committed Apr 30, 2024
1 parent d53c1b4 commit 8c623e8
Showing 1 changed file with 27 additions and 27 deletions.
54 changes: 27 additions & 27 deletions triton-vm/src/arithmetic_domain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use std::ops::Mul;
use std::ops::MulAssign;

use num_traits::One;
use rayon::prelude::*;
use twenty_first::math::traits::FiniteField;
use twenty_first::math::traits::PrimitiveRootOfUnity;
use twenty_first::prelude::*;
Expand Down Expand Up @@ -57,27 +58,27 @@ impl ArithmeticDomain {
+ Mul<BFieldElement, Output = FF>
+ From<BFieldElement>,
{
let mut indices_and_chunks = polynomial.coefficients.chunks(self.length).enumerate();
let mut values =
match indices_and_chunks.next() {
Some((_, first_chunk)) => Polynomial::new(first_chunk.to_vec())
.fast_coset_evaluate(self.offset, self.generator, self.length),
None => vec![FF::zero(); self.length],
};
for (i, chunk) in indices_and_chunks {
let scalar = self.offset.mod_pow(i as u64 * self.length as u64);
let current_values = Polynomial::new(chunk.to_vec()).fast_coset_evaluate(
self.offset,
self.generator,
self.length,
);
let (offset, generator, length) = (self.offset, self.generator, self.length);
let evaluate_from =
|chunk| Polynomial::from(chunk).fast_coset_evaluate(offset, generator, length);

// avoid `enumerate` to directly get index of the right type
let mut indexed_chunks = (0..).zip(polynomial.coefficients.chunks(length));

// only allocate a bunch of zeros if there are no chunks
let mut values = indexed_chunks.next().map_or_else(
|| vec![FF::zero(); length],
|(_, first_chunk)| evaluate_from(first_chunk),
);
for (chunk_index, chunk) in indexed_chunks {
let coefficient_index = chunk_index * u64::try_from(length).unwrap();
let scaled_offset = offset.mod_pow(coefficient_index);
values
.iter_mut()
.zip(current_values.iter())
.for_each(|(v, cv)| {
*v += *cv * scalar;
});
.par_iter_mut()
.zip(evaluate_from(chunk))
.for_each(|(value, evaluation)| *value += evaluation * scaled_offset);
}

values
}

Expand All @@ -99,9 +100,9 @@ impl ArithmeticDomain {
target_domain.evaluate(&self.interpolate(codeword))
}

/// Compute the nth element of the domain.
pub fn domain_value(&self, index: u32) -> BFieldElement {
self.generator.mod_pow_u32(index) * self.offset
/// Compute the `n`th element of the domain.
pub fn domain_value(&self, n: u32) -> BFieldElement {
self.generator.mod_pow_u32(n) * self.offset
}

pub fn domain_values(&self) -> Vec<BFieldElement> {
Expand Down Expand Up @@ -300,13 +301,12 @@ mod tests {

#[proptest]
fn can_evaluate_polynomial_larger_than_domain(
#[strategy(1usize..10)] log_domain_length: usize,
#[strategy(1usize..5)] _polynomial_expansion_factor: usize,
#[strategy(1_usize..10)] _log_domain_length: usize,
#[strategy(1_usize..5)] _expansion_factor: usize,
#[strategy(Just(1 << #_log_domain_length))] domain_length: usize,
#[strategy(vec(arb(),#domain_length*#_expansion_factor))] coefficients: Vec<BFieldElement>,
#[strategy(arb())] offset: BFieldElement,
#[strategy(vec(arb(),(1<<#log_domain_length)*#_polynomial_expansion_factor))]
coefficients: Vec<BFieldElement>,
) {
let domain_length = 1 << log_domain_length;
let domain = ArithmeticDomain::of_length(domain_length)
.unwrap()
.with_offset(offset);
Expand Down

0 comments on commit 8c623e8

Please sign in to comment.