diff --git a/triton-vm/src/arithmetic_domain.rs b/triton-vm/src/arithmetic_domain.rs index dc4ebf81e..cd6804dac 100644 --- a/triton-vm/src/arithmetic_domain.rs +++ b/triton-vm/src/arithmetic_domain.rs @@ -2,6 +2,7 @@ use std::ops::Mul; use std::ops::MulAssign; use num_traits::One; +use rayon::prelude::*; use twenty_first::math::traits::FiniteField; use twenty_first::math::traits::PrimitiveRootOfUnity; use twenty_first::prelude::*; @@ -57,27 +58,27 @@ impl ArithmeticDomain { + Mul + From, { - let mut indices_and_chunks = polynomial.coefficients.chunks(self.length).enumerate(); - let mut values = - match indices_and_chunks.next() { - Some((_, first_chunk)) => Polynomial::new(first_chunk.to_vec()) - .fast_coset_evaluate(self.offset, self.generator, self.length), - None => vec![FF::zero(); self.length], - }; - for (i, chunk) in indices_and_chunks { - let scalar = self.offset.mod_pow(i as u64 * self.length as u64); - let current_values = Polynomial::new(chunk.to_vec()).fast_coset_evaluate( - self.offset, - self.generator, - self.length, - ); + let (offset, generator, length) = (self.offset, self.generator, self.length); + let evaluate_from = + |chunk| Polynomial::from(chunk).fast_coset_evaluate(offset, generator, length); + + // avoid `enumerate` to directly get index of the right type + let mut indexed_chunks = (0..).zip(polynomial.coefficients.chunks(length)); + + // only allocate a bunch of zeros if there are no chunks + let mut values = indexed_chunks.next().map_or_else( + || vec![FF::zero(); length], + |(_, first_chunk)| evaluate_from(first_chunk), + ); + for (chunk_index, chunk) in indexed_chunks { + let coefficient_index = chunk_index * u64::try_from(length).unwrap(); + let scaled_offset = offset.mod_pow(coefficient_index); values - .iter_mut() - .zip(current_values.iter()) - .for_each(|(v, cv)| { - *v += *cv * scalar; - }); + .par_iter_mut() + .zip(evaluate_from(chunk)) + .for_each(|(value, evaluation)| *value += evaluation * scaled_offset); } + values } @@ -99,9 +100,9 @@ impl ArithmeticDomain { target_domain.evaluate(&self.interpolate(codeword)) } - /// Compute the nth element of the domain. - pub fn domain_value(&self, index: u32) -> BFieldElement { - self.generator.mod_pow_u32(index) * self.offset + /// Compute the `n`th element of the domain. + pub fn domain_value(&self, n: u32) -> BFieldElement { + self.generator.mod_pow_u32(n) * self.offset } pub fn domain_values(&self) -> Vec { @@ -300,13 +301,12 @@ mod tests { #[proptest] fn can_evaluate_polynomial_larger_than_domain( - #[strategy(1usize..10)] log_domain_length: usize, - #[strategy(1usize..5)] _polynomial_expansion_factor: usize, + #[strategy(1_usize..10)] _log_domain_length: usize, + #[strategy(1_usize..5)] _expansion_factor: usize, + #[strategy(Just(1 << #_log_domain_length))] domain_length: usize, + #[strategy(vec(arb(),#domain_length*#_expansion_factor))] coefficients: Vec, #[strategy(arb())] offset: BFieldElement, - #[strategy(vec(arb(),(1<<#log_domain_length)*#_polynomial_expansion_factor))] - coefficients: Vec, ) { - let domain_length = 1 << log_domain_length; let domain = ArithmeticDomain::of_length(domain_length) .unwrap() .with_offset(offset);