Skip to content

Commit

Permalink
refactor!: use Polynomial in FRI proof item
Browse files Browse the repository at this point in the history
Also
- reduce the number of small methods that were used only once
- `match` on `bool` a bit less often
- derive basic derivables for various FRI structs

BREAKING CHANGE: The proof item `FriPolynomial` has payload of type
`Polynomial<XFieldElement>` instead of `Vec<XFieldElement>`.
  • Loading branch information
jan-ferdinand committed Apr 24, 2024
1 parent 2c7e286 commit 7367c67
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 54 deletions.
70 changes: 17 additions & 53 deletions triton-vm/src/fri.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ pub struct Fri<H: AlgebraicHasher> {
_hasher: PhantomData<H>,
}

#[derive(Debug, Eq, PartialEq)]
struct FriProver<'stream, H: AlgebraicHasher> {
proof_stream: &'stream mut ProofStream,
rounds: Vec<ProverRound<H>>,
Expand All @@ -42,6 +43,7 @@ struct FriProver<'stream, H: AlgebraicHasher> {
first_round_collinearity_check_indices: Vec<usize>,
}

#[derive(Debug, Clone, Eq, PartialEq)]
struct ProverRound<H: AlgebraicHasher> {
domain: ArithmeticDomain,
codeword: Vec<XFieldElement>,
Expand Down Expand Up @@ -102,7 +104,7 @@ impl<'stream, H: AlgebraicHasher> FriProver<'stream, H> {
let last_polynomial = ArithmeticDomain::of_length(last_codeword.len())
.unwrap()
.interpolate(last_codeword);
let proof_item = ProofItem::FriPolynomial(last_polynomial.coefficients);
let proof_item = ProofItem::FriPolynomial(last_polynomial);
self.proof_stream.enqueue(proof_item);
}

Expand Down Expand Up @@ -200,6 +202,7 @@ impl<H: AlgebraicHasher> ProverRound<H> {
}
}

#[derive(Debug, Eq, PartialEq)]
struct FriVerifier<'stream, H: AlgebraicHasher> {
proof_stream: &'stream mut ProofStream,
rounds: Vec<VerifierRound>,
Expand All @@ -213,6 +216,7 @@ struct FriVerifier<'stream, H: AlgebraicHasher> {
_hasher: PhantomData<H>,
}

#[derive(Debug, Clone, Eq, PartialEq)]
struct VerifierRound {
domain: ArithmeticDomain,
partial_codeword_a: Vec<XFieldElement>,
Expand All @@ -223,48 +227,30 @@ struct VerifierRound {

impl<'stream, H: AlgebraicHasher> FriVerifier<'stream, H> {
fn initialize(&mut self) -> VerifierResult<()> {
self.initialize_verification_rounds()?;
self.receive_last_round_codeword()?;
self.receive_last_round_polynomial()
}
let domain = self.first_round_domain;
let first_round = self.construct_round_with_domain(domain)?;
self.rounds.push(first_round);

fn initialize_verification_rounds(&mut self) -> VerifierResult<()> {
self.initialize_first_round()?;
for _ in 0..self.num_rounds {
self.initialize_next_round()?;
let previous_round = self.rounds.last().unwrap();
let domain = previous_round.domain.halve()?;
let next_round = self.construct_round_with_domain(domain)?;
self.rounds.push(next_round);
}
Ok(())
}

fn initialize_first_round(&mut self) -> VerifierResult<()> {
let first_round = self.construct_first_round()?;
self.store_round(first_round);
Ok(())
}

fn initialize_next_round(&mut self) -> VerifierResult<()> {
let next_round = self.construct_next_round()?;
self.store_round(next_round);
self.last_round_codeword = self.proof_stream.dequeue()?.try_into_fri_codeword()?;
self.last_round_polynomial = self.proof_stream.dequeue()?.try_into_fri_polynomial()?;
Ok(())
}

fn construct_first_round(&mut self) -> VerifierResult<VerifierRound> {
let domain = self.first_round_domain;
self.construct_round_with_domain(domain)
}

fn construct_next_round(&mut self) -> VerifierResult<VerifierRound> {
let previous_round = self.rounds.last().unwrap();
let domain = previous_round.domain.halve()?;
self.construct_round_with_domain(domain)
}

fn construct_round_with_domain(
&mut self,
domain: ArithmeticDomain,
) -> VerifierResult<VerifierRound> {
let merkle_root = self.proof_stream.dequeue()?.try_into_merkle_root()?;
let folding_challenge = self.maybe_sample_folding_challenge();
let folding_challenge = self
.need_more_folding_challenges()
.then(|| self.proof_stream.sample_scalars(1)[0]);

let verifier_round = VerifierRound {
domain,
Expand All @@ -276,13 +262,6 @@ impl<'stream, H: AlgebraicHasher> FriVerifier<'stream, H> {
Ok(verifier_round)
}

fn maybe_sample_folding_challenge(&mut self) -> Option<XFieldElement> {
match self.need_more_folding_challenges() {
true => Some(self.proof_stream.sample_scalars(1)[0]),
false => None,
}
}

fn need_more_folding_challenges(&self) -> bool {
if self.num_rounds == 0 {
return false;
Expand All @@ -293,21 +272,6 @@ impl<'stream, H: AlgebraicHasher> FriVerifier<'stream, H> {
num_initialized_rounds <= num_rounds_that_have_a_next_round
}

fn store_round(&mut self, round: VerifierRound) {
self.rounds.push(round);
}

fn receive_last_round_codeword(&mut self) -> VerifierResult<()> {
self.last_round_codeword = self.proof_stream.dequeue()?.try_into_fri_codeword()?;
Ok(())
}

fn receive_last_round_polynomial(&mut self) -> VerifierResult<()> {
let coefficients = self.proof_stream.dequeue()?.try_into_fri_polynomial()?;
self.last_round_polynomial = Polynomial::new(coefficients);
Ok(())
}

fn compute_last_round_folded_partial_codeword(&mut self) -> VerifierResult<()> {
self.sample_first_round_collinearity_check_indices();
self.receive_authentic_partially_revealed_codewords()?;
Expand Down
2 changes: 1 addition & 1 deletion triton-vm/src/proof_item.rs
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ proof_items!(
Log2PaddedHeight(u32) => false, try_into_log2_padded_height,
QuotientSegmentsElements(Vec<QuotientSegments>) => false, try_into_quot_segments_elements,
FriCodeword(Vec<XFieldElement>) => false, try_into_fri_codeword,
FriPolynomial(Vec<XFieldElement>) => false, try_into_fri_polynomial,
FriPolynomial(Polynomial<XFieldElement>) => false, try_into_fri_polynomial,
FriResponse(FriResponse) => false, try_into_fri_response,
);

Expand Down

0 comments on commit 7367c67

Please sign in to comment.