diff --git a/constraint-evaluation-generator/src/main.rs b/constraint-evaluation-generator/src/main.rs index 134120d3..24e698ce 100644 --- a/constraint-evaluation-generator/src/main.rs +++ b/constraint-evaluation-generator/src/main.rs @@ -280,9 +280,6 @@ fn turn_circuits_into_degree_bounds_string( fn turn_circuits_into_string( constraint_circuits: &mut [ConstraintCircuit], ) -> String { - // Delete redundant nodes - ConstraintCircuit::constant_folding(&mut constraint_circuits.iter_mut().collect_vec()); - // Assert that all node IDs are unique (sanity check) ConstraintCircuit::assert_has_unique_ids(constraint_circuits); @@ -445,7 +442,7 @@ fn is_bfield_element(circuit: &ConstraintCircuit) -> boo match &circuit.expression { CircuitExpression::XConstant(_) => false, CircuitExpression::BConstant(_) => true, - CircuitExpression::Input(indicator) => indicator.is_base_table_row(), + CircuitExpression::Input(indicator) => indicator.is_base_table_column(), CircuitExpression::Challenge(_) => false, CircuitExpression::BinaryOperation(_, lhs, rhs) => { is_bfield_element(&lhs.as_ref().borrow()) && is_bfield_element(&rhs.as_ref().borrow()) diff --git a/triton-vm/src/table/constraint_circuit.rs b/triton-vm/src/table/constraint_circuit.rs index 5b4742cf..0325c15f 100644 --- a/triton-vm/src/table/constraint_circuit.rs +++ b/triton-vm/src/table/constraint_circuit.rs @@ -9,6 +9,7 @@ use std::borrow::BorrowMut; use std::cell::RefCell; use std::cmp; +use std::collections::HashMap; use std::collections::HashSet; use std::fmt::Debug; use std::fmt::Display; @@ -63,10 +64,10 @@ impl Display for BinOp { /// Having `Clone + Copy + Hash + PartialEq + Eq` helps putting `InputIndicator`s into containers. pub trait InputIndicator: Debug + Clone + Copy + Hash + PartialEq + Eq + Display { /// `true` iff `self` refers to a column in the base table. - fn is_base_table_row(&self) -> bool; + fn is_base_table_column(&self) -> bool; - fn base_row_index(&self) -> usize; - fn ext_row_index(&self) -> usize; + fn base_col_index(&self) -> usize; + fn ext_col_index(&self) -> usize; fn evaluate( &self, @@ -96,12 +97,12 @@ impl Display for SingleRowIndicator { } impl InputIndicator for SingleRowIndicator { - fn is_base_table_row(&self) -> bool { + fn is_base_table_column(&self) -> bool { use SingleRowIndicator::*; matches!(self, BaseRow(_)) } - fn base_row_index(&self) -> usize { + fn base_col_index(&self) -> usize { use SingleRowIndicator::*; match self { BaseRow(i) => *i, @@ -109,7 +110,7 @@ impl InputIndicator for SingleRowIndicator { } } - fn ext_row_index(&self) -> usize { + fn ext_col_index(&self) -> usize { use SingleRowIndicator::*; match self { BaseRow(_) => panic!("not an ext row"), @@ -155,12 +156,12 @@ impl Display for DualRowIndicator { } impl InputIndicator for DualRowIndicator { - fn is_base_table_row(&self) -> bool { + fn is_base_table_column(&self) -> bool { use DualRowIndicator::*; matches!(self, CurrentBaseRow(_) | NextBaseRow(_)) } - fn base_row_index(&self) -> usize { + fn base_col_index(&self) -> usize { use DualRowIndicator::*; match self { CurrentBaseRow(i) | NextBaseRow(i) => *i, @@ -168,7 +169,7 @@ impl InputIndicator for DualRowIndicator { } } - fn ext_row_index(&self) -> usize { + fn ext_col_index(&self) -> usize { use DualRowIndicator::*; match self { CurrentBaseRow(_) | NextBaseRow(_) => panic!("not an ext row"), @@ -223,13 +224,24 @@ pub enum CircuitExpression { impl Hash for CircuitExpression { fn hash(&self, state: &mut H) { match self { - BConstant(bfe) => bfe.hash(state), - XConstant(xfe) => xfe.hash(state), - Input(index) => index.hash(state), + BConstant(bfe) => { + "bfe".hash(state); + bfe.hash(state); + } + XConstant(xfe) => { + "xfe".hash(state); + xfe.hash(state); + } + Input(index) => { + "input".hash(state); + index.hash(state); + } Challenge(table_challenge_id) => { + "challenge".hash(state); table_challenge_id.hash(state); } BinaryOperation(binop, lhs, rhs) => { + "binop".hash(state); binop.hash(state); lhs.as_ref().borrow().hash(state); rhs.as_ref().borrow().hash(state); @@ -238,6 +250,41 @@ impl Hash for CircuitExpression { } } +impl PartialEq for CircuitExpression { + fn eq(&self, other: &Self) -> bool { + match self { + XConstant(self_xfe) => match other { + XConstant(other_xfe) => self_xfe == other_xfe, + _ => false, + }, + BConstant(self_bfe) => match other { + BConstant(other_bfe) => self_bfe == other_bfe, + _ => false, + }, + Input(self_input) => match other { + Input(other_input) => self_input == other_input, + _ => false, + }, + Challenge(self_challenge_id) => match other { + Challenge(other_challenge_id) => self_challenge_id == other_challenge_id, + _ => false, + }, + BinaryOperation(binop_self, lhs_self, rhs_self) => { + match other { + BinaryOperation(binop_other, lhs_other, rhs_other) => { + // a = b `op0` c, + // d = e `op1` f => + // a = d <= op0 == op1 && b == e && c ==f + binop_self == binop_other && lhs_self == lhs_other && rhs_self == rhs_other + } + + _ => false, + } + } + } + } +} + impl Hash for ConstraintCircuit { fn hash(&self, state: &mut H) { self.expression.hash(state) @@ -246,7 +293,7 @@ impl Hash for ConstraintCircuit { impl Hash for ConstraintCircuitMonad { fn hash(&self, state: &mut H) { - self.circuit.as_ref().borrow().expression.hash(state) + self.circuit.as_ref().borrow().hash(state) } } @@ -265,36 +312,7 @@ impl PartialEq for ConstraintCircuit { /// simplify or reduce neutral terms or products. So this comparison will return false for /// `a == a + 0`. It will also return false for `XFieldElement(7) == BFieldElement(7)` fn eq(&self, other: &Self) -> bool { - match &self.expression { - XConstant(self_xfe) => match &other.expression { - XConstant(other_xfe) => self_xfe == other_xfe, - _ => false, - }, - BConstant(self_bfe) => match &other.expression { - BConstant(other_bfe) => self_bfe == other_bfe, - _ => false, - }, - Input(self_input) => match &other.expression { - Input(other_input) => self_input == other_input, - _ => false, - }, - Challenge(self_challenge_id) => match &other.expression { - Challenge(other_challenge_id) => self_challenge_id == other_challenge_id, - _ => false, - }, - BinaryOperation(binop_self, lhs_self, rhs_self) => { - match &other.expression { - BinaryOperation(binop_other, lhs_other, rhs_other) => { - // a = b `op0` c, - // d = e `op1` f => - // a = d <= op0 == op1 && b == e && c ==f - binop_self == binop_other && lhs_self == lhs_other && rhs_self == rhs_other - } - - _ => false, - } - } - } + self.expression == other.expression } } @@ -358,12 +376,13 @@ impl ConstraintCircuit { } /// Verify that all IDs in the subtree are unique. Panics otherwise. - fn inner_has_unique_ids(&mut self, ids: &mut HashSet) { - let new_value = ids.insert(self.id); + fn inner_has_unique_ids(&mut self, ids: &mut HashMap>) { + let node_with_repeated_id = ids.insert(self.id, self.clone()); assert!( - !self.visited_counter.is_zero() || new_value, - "ID = {} was repeated", - self.id + !self.visited_counter.is_zero() || node_with_repeated_id.is_none(), + "ID = {} was repeated. Self was: {self:?}, other node: {:?}", + self.id, + node_with_repeated_id.unwrap(), ); self.visited_counter += 1; if let BinaryOperation(_, lhs, rhs) = &self.expression { @@ -372,9 +391,9 @@ impl ConstraintCircuit { } } - /// Verify that a multitree has unique IDs. Panics otherwise. + /// Verify that a multicircuit has unique IDs. Panics otherwise. pub fn assert_has_unique_ids(constraints: &mut [ConstraintCircuit]) { - let mut ids: HashSet = HashSet::new(); + let mut ids: HashMap> = HashMap::new(); for circuit in constraints.iter_mut() { circuit.inner_has_unique_ids(&mut ids); @@ -385,113 +404,6 @@ impl ConstraintCircuit { } } - /// Apply constant folding to simplify the (sub)tree. If the subtree is a leaf (terminal), no - /// change. If the subtree is a binary operation on: - /// - /// - one constant x one constant => fold - /// - one constant x one expr => can't - /// - one expr x one constant => can't - /// - one expr x one expr => can't - /// - /// This operation mutates self and returns true if a change was applied anywhere in the tree. - fn constant_fold_inner(&mut self) -> bool { - let mut change_tracker = false; - if let BinaryOperation(_, lhs, rhs) = &self.expression { - change_tracker |= lhs.clone().as_ref().borrow_mut().constant_fold_inner(); - change_tracker |= rhs.clone().as_ref().borrow_mut().constant_fold_inner(); - } - - match &self.expression.clone() { - BinaryOperation(binop, lhs, rhs) => { - // a + 0 = a ∧ a - 0 = a - if matches!(binop, BinOp::Add | BinOp::Sub) && rhs.as_ref().borrow().is_zero() { - *self.expression.borrow_mut() = lhs.as_ref().borrow().expression.clone(); - return true; - } - - // 0 + a = a - if *binop == BinOp::Add && lhs.as_ref().borrow().is_zero() { - *self.expression.borrow_mut() = rhs.as_ref().borrow().expression.clone(); - return true; - } - - if matches!(binop, BinOp::Mul) { - // a * 1 = a - if rhs.as_ref().borrow().is_one() { - *self.expression.borrow_mut() = lhs.as_ref().borrow().expression.clone(); - return true; - } - - // 1 * a = a - if lhs.as_ref().borrow().is_one() { - *self.expression.borrow_mut() = rhs.as_ref().borrow().expression.clone(); - return true; - } - - // 0 * a = a * 0 = 0 - if lhs.as_ref().borrow().is_zero() || rhs.as_ref().borrow().is_zero() { - *self.expression.borrow_mut() = BConstant(0u64.into()); - return true; - } - } - - // if left and right hand sides are both constants - if let XConstant(lhs_xfe) = lhs.as_ref().borrow().expression { - if let XConstant(rhs_xfe) = rhs.as_ref().borrow().expression { - *self.expression.borrow_mut() = match binop { - BinOp::Add => XConstant(lhs_xfe + rhs_xfe), - BinOp::Sub => XConstant(lhs_xfe - rhs_xfe), - BinOp::Mul => XConstant(lhs_xfe * rhs_xfe), - }; - return true; - } - - if let BConstant(rhs_bfe) = rhs.as_ref().borrow().expression { - *self.expression.borrow_mut() = match binop { - BinOp::Add => XConstant(lhs_xfe + rhs_bfe.lift()), - BinOp::Sub => XConstant(lhs_xfe - rhs_bfe.lift()), - BinOp::Mul => XConstant(lhs_xfe * rhs_bfe), - }; - return true; - } - } - - if let BConstant(lhs_bfe) = lhs.as_ref().borrow().expression { - if let XConstant(rhs_xfe) = rhs.as_ref().borrow().expression { - *self.expression.borrow_mut() = match binop { - BinOp::Add => XConstant(lhs_bfe.lift() + rhs_xfe), - BinOp::Sub => XConstant(lhs_bfe.lift() - rhs_xfe), - BinOp::Mul => XConstant(rhs_xfe * lhs_bfe), - }; - return true; - } - - if let BConstant(rhs_bfe) = rhs.as_ref().borrow().expression { - *self.expression.borrow_mut() = match binop { - BinOp::Add => BConstant(lhs_bfe + rhs_bfe), - BinOp::Sub => BConstant(lhs_bfe - rhs_bfe), - BinOp::Mul => BConstant(lhs_bfe * rhs_bfe), - }; - return true; - } - } - - change_tracker - } - _ => change_tracker, - } - } - - /// Reduce size of multitree by simplifying constant expressions such as `1·X` to `X`. - pub fn constant_folding(circuits: &mut [&mut ConstraintCircuit]) { - for circuit in circuits.iter_mut() { - let mut mutated = circuit.constant_fold_inner(); - while mutated { - mutated = circuit.constant_fold_inner(); - } - } - } - /// Return max degree after evaluating the circuit with an input of specified degree pub fn symbolic_degree_bound( &self, @@ -518,9 +430,9 @@ impl ConstraintCircuit { } } } - Input(input) => match input.is_base_table_row() { - true => max_base_degrees[input.base_row_index()], - false => max_ext_degrees[input.ext_row_index()], + Input(input) => match input.is_base_table_column() { + true => max_base_degrees[input.base_col_index()], + false => max_ext_degrees[input.ext_col_index()], }, XConstant(xfe) => { if xfe.is_zero() { @@ -634,44 +546,117 @@ impl ConstraintCircuit { } } - /// Replace all challenges with constants in subtree. - fn apply_challenges_to_one_root(&mut self, challenges: &Challenges) { - match &self.expression { - Challenge(challenge_id) => { - *self.expression.borrow_mut() = XConstant(challenges.get_challenge(*challenge_id)) - } - BinaryOperation(_, lhs, rhs) => { - lhs.as_ref() - .borrow_mut() - .apply_challenges_to_one_root(challenges); - rhs.as_ref() - .borrow_mut() - .apply_challenges_to_one_root(challenges); - } - _ => (), + /// Panics if two nodes evaluate to the same value + pub fn assert_all_evaluate_different( + constraints: &[Self], + challenges: &Challenges, + base_table: ArrayView2, + ext_table: ArrayView2, + ) { + let mut evaluated_values = HashMap::default(); + for constraint in constraints.iter() { + Self::evaluate_and_store_and_assert_unique( + constraint, + challenges, + base_table, + ext_table, + &mut evaluated_values, + ); } } - /// Simplify the circuit constraints by replacing the known challenges with roots. - pub fn apply_challenges(constraints: &mut [ConstraintCircuit], challenges: &Challenges) { - for circuit in constraints.iter_mut() { - circuit.apply_challenges_to_one_root(challenges); + /// Return own value and whether own value was seen before. Stores own value in hash map. + fn evaluate_and_store_and_assert_unique( + &self, + challenges: &Challenges, + base_table: ArrayView2, + ext_table: ArrayView2, + evaluated_values: &mut HashMap)>, + ) -> XFieldElement { + // assert_eq!( + // self.var_count, + // input.len(), + // "Input length match circuit's var count" + // ); + let value = match &self.expression { + XConstant(xfe) => { + if self.id == 107 || self.id == 454 { + println!("{}: XFE", self.id); + } + xfe.to_owned() + } + BConstant(bfe) => { + if self.id == 107 || self.id == 454 { + println!("{}: BFE", self.id); + } + bfe.lift() + } + Input(s) => { + s.evaluate(base_table, ext_table) + // if s.is_base_table_row() { + // base_input[s.base_row_index()].lift() + // } else { + // ext_input[s.ext_row_index()] + // } + } + Challenge(cid) => challenges.get_challenge(*cid), + BinaryOperation(binop, lhs, rhs) => { + let lhs = lhs.as_ref().borrow().evaluate_and_store_and_assert_unique( + challenges, + base_table, + ext_table, + evaluated_values, + ); + let rhs = rhs.as_ref().borrow().evaluate_and_store_and_assert_unique( + challenges, + base_table, + ext_table, + evaluated_values, + ); + match binop { + BinOp::Add => lhs + rhs, + BinOp::Sub => lhs - rhs, + BinOp::Mul => lhs * rhs, + } + } + }; + + let self_evaluated_is_unique = + evaluated_values.insert(value, (self.id.to_owned(), self.clone())); + if let Some((collided_circuit_id, collided_circuit)) = self_evaluated_is_unique { + let own_id = self.id.to_owned(); + if collided_circuit_id != self.id { + panic!( + "Circuit ID {collided_circuit_id} and circuit ID {own_id} are not unique. \ + Collission on:\n \ + {collided_circuit_id}: {collided_circuit}\n {own_id}: {self}. \ + Value was {value}", + ); + } } + value } - fn evaluate_inner( + pub fn evaluate( &self, base_table: ArrayView2, ext_table: ArrayView2, + challenges: &Challenges, ) -> XFieldElement { match self.clone().expression { XConstant(xfe) => xfe, BConstant(bfe) => bfe.lift(), Input(input) => input.evaluate(base_table, ext_table), - Challenge(challenge_id) => panic!("Challenge {challenge_id} not evaluated"), + Challenge(challenge_id) => challenges.get_challenge(challenge_id), BinaryOperation(binop, lhs, rhs) => { - let lhs_value = lhs.as_ref().borrow().evaluate_inner(base_table, ext_table); - let rhs_value = rhs.as_ref().borrow().evaluate_inner(base_table, ext_table); + let lhs_value = lhs + .as_ref() + .borrow() + .evaluate(base_table, ext_table, challenges); + let rhs_value = rhs + .as_ref() + .borrow() + .evaluate(base_table, ext_table, challenges); match binop { BinOp::Add => lhs_value + rhs_value, BinOp::Sub => lhs_value - rhs_value, @@ -680,26 +665,14 @@ impl ConstraintCircuit { } } } - - pub fn evaluate( - &self, - base_table: ArrayView2, - ext_table: ArrayView2, - challenges: &Challenges, - ) -> XFieldElement { - let mut self_to_evaluate = self.clone(); - self_to_evaluate.apply_challenges_to_one_root(challenges); - self_to_evaluate.evaluate_inner(base_table, ext_table) - } } -/// The inner type used in the [`ConstraintCircuitBuilder`] to build a circuit. Provides -/// convenience methods, for example by allowing to use `+` and `*` to add and multiply. +/// Constraint expressions, with context needed to ensure that two equal nodes are not added to +/// the multicircuit. #[derive(Clone)] pub struct ConstraintCircuitMonad { pub circuit: Rc>>, - pub all_nodes: Rc>>>, - pub id_counter_ref: Rc>, + pub builder: ConstraintCircuitBuilder, } impl Debug for ConstraintCircuitMonad { @@ -710,11 +683,11 @@ impl Debug for ConstraintCircuitMonad { .field("id", &self.circuit) .field( "all_nodes length: ", - &self.all_nodes.as_ref().borrow().len(), + &self.builder.all_nodes.as_ref().borrow().len(), ) .field( "id_counter_ref value: ", - &self.id_counter_ref.as_ref().borrow(), + &self.builder.id_counter.as_ref().borrow(), ) .finish() } @@ -745,24 +718,37 @@ fn binop( rhs: ConstraintCircuitMonad, ) -> ConstraintCircuitMonad { // Get ID for the new node - let new_index = lhs.id_counter_ref.as_ref().borrow().to_owned(); + let new_index = lhs.builder.id_counter.as_ref().borrow().to_owned(); + let lhs = Rc::new(RefCell::new(lhs)); + let rhs = Rc::new(RefCell::new(rhs)); let new_node = ConstraintCircuitMonad { circuit: Rc::new(RefCell::new(ConstraintCircuit { visited_counter: 0, - expression: BinaryOperation(binop, Rc::clone(&lhs.circuit), Rc::clone(&rhs.circuit)), + expression: BinaryOperation( + binop, + Rc::clone(&lhs.as_ref().borrow().circuit), + Rc::clone(&rhs.as_ref().borrow().circuit), + ), id: new_index, })), - id_counter_ref: Rc::clone(&lhs.id_counter_ref), - all_nodes: Rc::clone(&lhs.all_nodes), + builder: lhs.as_ref().borrow().builder.clone(), }; // check if node already exists - let contained = lhs.all_nodes.as_ref().borrow().contains(&new_node); + let contained = lhs + .as_ref() + .borrow() + .builder + .all_nodes + .as_ref() + .borrow() + .contains(&new_node); if contained { - let ret0 = &lhs.all_nodes.as_ref().borrow(); - let ret1 = &(*ret0.get(&new_node).as_ref().unwrap()).clone(); - return ret1.to_owned(); + let ret0 = &lhs.as_ref().borrow(); + let ret1 = ret0.builder.all_nodes.as_ref().borrow(); + let ret2 = &(*ret1.get(&new_node).as_ref().unwrap()).clone(); + return ret2.to_owned(); } // If the operator commutes, check if the inverse node has already been constructed. @@ -774,33 +760,47 @@ fn binop( expression: BinaryOperation( binop, // Switch rhs and lhs for symmetric operators to check membership in hash set - Rc::clone(&rhs.circuit), - Rc::clone(&lhs.circuit), + Rc::clone(&rhs.as_ref().borrow().circuit), + Rc::clone(&lhs.as_ref().borrow().circuit), ), id: new_index, })), - id_counter_ref: Rc::clone(&lhs.id_counter_ref), - all_nodes: Rc::clone(&lhs.all_nodes), + builder: lhs.as_ref().borrow().builder.clone(), }; // check if node already exists - let inverted_contained = lhs.all_nodes.as_ref().borrow().contains(&new_node_inverted); + let inverted_contained = lhs + .as_ref() + .borrow() + .builder + .all_nodes + .as_ref() + .borrow() + .contains(&new_node_inverted); if inverted_contained { - let ret0 = &lhs.all_nodes.as_ref().borrow(); - let ret1 = &(*ret0.get(&new_node_inverted).as_ref().unwrap()).clone(); - return ret1.to_owned(); + let ret0 = &lhs.as_ref().borrow(); + let ret1 = ret0.builder.all_nodes.as_ref().borrow(); + let ret2 = &(*ret1.get(&new_node_inverted).as_ref().unwrap()).clone(); + return ret2.to_owned(); } } // Increment counter index - *lhs.id_counter_ref.as_ref().borrow_mut() = new_index + 1; + *lhs.as_ref() + .borrow() + .builder + .id_counter + .as_ref() + .borrow_mut() = new_index + 1; // Store new node in HashSet - new_node + let inserted_value_was_new = new_node + .builder .all_nodes .as_ref() .borrow_mut() .insert(new_node.clone()); + assert!(inserted_value_was_new, "Binop-created value must be new"); new_node } @@ -843,6 +843,223 @@ impl ConstraintCircuitMonad { pub fn consume(self) -> ConstraintCircuit { self.circuit.try_borrow().unwrap().to_owned() } + + pub fn max_id(&self) -> usize { + let max_from_hash_map = self + .builder + .all_nodes + .as_ref() + .borrow() + .iter() + .map(|x| x.circuit.as_ref().borrow().id) + .max() + .unwrap(); + + let id_ref_value = *self.builder.id_counter.borrow(); + assert_eq!(id_ref_value - 1, max_from_hash_map); + max_from_hash_map + } + + fn replace_references(&self, old_id: usize, new: Rc>>) { + for node in self.builder.all_nodes.as_ref().borrow().clone().into_iter() { + if node.circuit.as_ref().borrow().id == old_id { + continue; + } + + if let BinaryOperation(_, ref mut lhs, ref mut rhs) = + node.circuit.as_ref().borrow_mut().expression + { + if lhs.as_ref().borrow().id == old_id { + *lhs = new.clone(); + } + if rhs.as_ref().borrow().id == old_id { + *rhs = new.clone(); + } + } + } + } + + fn find_equivalent_expression(&self) -> Option>>> { + if let BinaryOperation(binop, lhs, rhs) = &self.circuit.as_ref().borrow().expression { + // a + 0 = a ∧ a - 0 = a + if matches!(binop, BinOp::Add | BinOp::Sub) && rhs.borrow().is_zero() { + return Some(Rc::clone(lhs)); + } + + // 0 + a = a + if *binop == BinOp::Add && lhs.borrow().is_zero() { + return Some(Rc::clone(rhs)); + } + + if matches!(binop, BinOp::Mul) { + // a * 1 = a + if rhs.borrow().is_one() { + return Some(Rc::clone(lhs)); + } + + // 1 * a = a + if lhs.borrow().is_one() { + return Some(Rc::clone(rhs)); + } + + // 0 * a = 0 + if lhs.borrow().is_zero() { + return Some(Rc::clone(lhs)); + } + + // a * 0 = 0 + if rhs.borrow().is_zero() { + return Some(Rc::clone(rhs)); + } + } + + // if left and right hand sides are both constants + if let XConstant(lhs_xfe) = lhs.borrow().expression { + if let XConstant(rhs_xfe) = rhs.borrow().expression { + return match binop { + BinOp::Add => Some(Rc::new(RefCell::new( + self.builder + .make_leaf(XConstant(lhs_xfe + rhs_xfe)) + .consume(), + ))), + BinOp::Sub => Some(Rc::new(RefCell::new( + self.builder + .make_leaf(XConstant(lhs_xfe - rhs_xfe)) + .consume(), + ))), + BinOp::Mul => Some(Rc::new(RefCell::new( + self.builder + .make_leaf(XConstant(lhs_xfe * rhs_xfe)) + .consume(), + ))), + }; + } + + if let BConstant(rhs_bfe) = rhs.borrow().expression { + return match binop { + BinOp::Add => Some(Rc::new(RefCell::new( + self.builder + .make_leaf(XConstant(lhs_xfe + rhs_bfe.lift())) + .consume(), + ))), + BinOp::Sub => Some(Rc::new(RefCell::new( + self.builder + .make_leaf(XConstant(lhs_xfe - rhs_bfe.lift())) + .consume(), + ))), + BinOp::Mul => Some(Rc::new(RefCell::new( + self.builder + .make_leaf(XConstant(lhs_xfe * rhs_bfe)) + .consume(), + ))), + }; + } + } + + if let BConstant(lhs_bfe) = lhs.borrow().expression { + if let XConstant(rhs_xfe) = rhs.borrow().expression { + return match binop { + BinOp::Add => Some(Rc::new(RefCell::new( + self.builder + .make_leaf(XConstant(lhs_bfe.lift() + rhs_xfe)) + .consume(), + ))), + BinOp::Sub => Some(Rc::new(RefCell::new( + self.builder + .make_leaf(XConstant(lhs_bfe.lift() - rhs_xfe)) + .consume(), + ))), + BinOp::Mul => Some(Rc::new(RefCell::new( + self.builder + .make_leaf(XConstant(rhs_xfe * lhs_bfe)) + .consume(), + ))), + }; + } + + if let BConstant(rhs_bfe) = rhs.borrow().expression { + return match binop { + BinOp::Add => Some(Rc::new(RefCell::new( + self.builder + .make_leaf(BConstant(lhs_bfe + rhs_bfe)) + .consume(), + ))), + BinOp::Sub => Some(Rc::new(RefCell::new( + self.builder + .make_leaf(BConstant(lhs_bfe - rhs_bfe)) + .consume(), + ))), + BinOp::Mul => Some(Rc::new(RefCell::new( + self.builder + .make_leaf(BConstant(lhs_bfe * lhs_bfe)) + .consume(), + ))), + }; + } + } + return None; + } + None + } + + /// Apply constant folding to simplify the (sub)tree. + /// If the subtree is a leaf (terminal), no change. + /// If the subtree is a binary operation on: + /// + /// - one constant x one constant => fold + /// - one constant x one expr => can't + /// - one expr x one constant => can't + /// - one expr x one expr => can't + /// + /// This operation mutates self and returns true if a change was + /// applied anywhere in the tree. + fn constant_fold_inner(&mut self) -> (bool, Option>>>) { + let mut change_tracker = false; + let self_expr = self.circuit.as_ref().borrow().expression.clone(); + if let BinaryOperation(_, lhs, rhs) = &self_expr { + let mut lhs_as_monadic_value = ConstraintCircuitMonad { + circuit: lhs.clone(), + builder: self.builder.clone(), + }; + let (change_in_lhs, _) = lhs_as_monadic_value.constant_fold_inner(); + change_tracker |= change_in_lhs; + let mut rhs_as_monadic_value = ConstraintCircuitMonad { + circuit: rhs.clone(), + builder: self.builder.clone(), + }; + let (change_in_rhs, _) = rhs_as_monadic_value.constant_fold_inner(); + change_tracker |= change_in_rhs; + } + + let equivalent_circuit = self.find_equivalent_expression(); + change_tracker |= equivalent_circuit.is_some(); + + if equivalent_circuit.is_some() { + let equivalent_circuit = equivalent_circuit.as_ref().unwrap().clone(); + let id_of_node_to_be_deleted = self.circuit.borrow().id; + self.replace_references(id_of_node_to_be_deleted, equivalent_circuit); + self.builder.all_nodes.as_ref().borrow_mut().remove(self); + } + + (change_tracker, equivalent_circuit) + } + + /// Reduce size of multitree by simplifying constant expressions such as `1 * MPol(_,_)` + pub fn constant_folding(circuits: &mut [ConstraintCircuitMonad]) { + for circuit in circuits.iter_mut() { + let mut mutated = true; + while mutated { + let (mutated_inner, maybe_new_root) = circuit.constant_fold_inner(); + mutated = mutated_inner; + if let Some(new_root) = maybe_new_root { + *circuit = ConstraintCircuitMonad { + circuit: new_root, + builder: circuit.builder.clone(), + }; + } + } + } + } } #[derive(Debug, Clone)] @@ -867,6 +1084,15 @@ impl ConstraintCircuitBuilder { } } + pub fn get_node_by_id(&self, id: usize) -> Option> { + for node in self.all_nodes.as_ref().borrow().iter() { + if node.circuit.as_ref().borrow().id == id { + return Some(node.clone()); + } + } + None + } + /// Create constant leaf node. pub fn x_constant(&self, xfe: XFieldElement) -> ConstraintCircuitMonad { let expression = XConstant(xfe); @@ -891,7 +1117,14 @@ impl ConstraintCircuitBuilder { self.make_leaf(expression) } - fn make_leaf(&self, expression: CircuitExpression) -> ConstraintCircuitMonad { + fn make_leaf(&self, mut expression: CircuitExpression) -> ConstraintCircuitMonad { + // Don't generate an X field leaf if it can be expressed as a B field leaf + if let XConstant(xfe) = expression { + if let Some(bfe) = xfe.unlift() { + expression = BConstant(bfe); + } + } + let new_id = self.id_counter.as_ref().borrow().to_owned(); let new_node = ConstraintCircuitMonad { circuit: Rc::new(RefCell::new(ConstraintCircuit { @@ -899,8 +1132,7 @@ impl ConstraintCircuitBuilder { expression, id: new_id, })), - id_counter_ref: Rc::clone(&self.id_counter), - all_nodes: Rc::clone(&self.all_nodes), + builder: self.clone(), }; // Check if node already exists, return the existing one if it does @@ -927,18 +1159,22 @@ mod constraint_circuit_tests { use std::collections::hash_map::DefaultHasher; use std::hash::Hasher; - use itertools::Itertools; + use ndarray::Array2; + use rand::random; use rand::thread_rng; use rand::RngCore; use twenty_first::shared_math::other::random_elements; use crate::table::challenges::ChallengeId::U32Indeterminate; use crate::table::challenges::Challenges; + use crate::table::hash_table::ExtHashTable; use crate::table::jump_stack_table::ExtJumpStackTable; + use crate::table::master_table; use crate::table::op_stack_table::ExtOpStackTable; use crate::table::processor_table::ExtProcessorTable; use crate::table::program_table::ExtProgramTable; use crate::table::ram_table::ExtRamTable; + use crate::table::u32_table::ExtU32Table; use super::*; @@ -1135,64 +1371,103 @@ mod constraint_circuit_tests { let var_0_copy_0 = deep_copy(&var_0.circuit.as_ref().borrow()); let var_0_mul_one_0 = var_0_copy_0.clone() * one.clone(); assert_ne!(var_0_copy_0, var_0_mul_one_0); - let mut var_0_circuit_0 = var_0_copy_0.consume(); - let mut var_0_same_circuit_0 = var_0_mul_one_0.consume(); - ConstraintCircuit::constant_folding(&mut [&mut var_0_circuit_0, &mut var_0_same_circuit_0]); - assert_eq!(var_0_circuit_0, var_0_same_circuit_0); - assert_eq!(var_0_same_circuit_0, var_0_circuit_0); + let mut circuits = [var_0_copy_0, var_0_mul_one_0]; + ConstraintCircuitMonad::constant_folding(&mut circuits); + assert_eq!( + circuits[0], circuits[1], + "{} != {}", + circuits[0], circuits[1] + ); + assert_eq!( + circuits[1], circuits[0], + "{} != {}", + circuits[1], circuits[0] + ); // Verify that constant folding can handle a = 1 * a let var_0_copy_1 = deep_copy(&var_0.circuit.as_ref().borrow()); let var_0_one_mul_1 = one.clone() * var_0_copy_1.clone(); assert_ne!(var_0_copy_1, var_0_one_mul_1); - let mut var_0_circuit_1 = var_0_copy_1.consume(); - let mut var_0_same_circuit_1 = var_0_one_mul_1.consume(); - ConstraintCircuit::constant_folding(&mut [&mut var_0_circuit_1, &mut var_0_same_circuit_1]); - assert_eq!(var_0_circuit_1, var_0_same_circuit_1); - assert_eq!(var_0_same_circuit_1, var_0_circuit_1); + let mut circuits = [var_0_copy_1, var_0_one_mul_1]; + ConstraintCircuitMonad::constant_folding(&mut circuits); + assert_eq!( + circuits[0], circuits[1], + "{} != {}", + circuits[0], circuits[1] + ); + assert_eq!( + circuits[1], circuits[0], + "{} != {}", + circuits[1], circuits[0] + ); // Verify that constant folding can handle a = 1 * a * 1 let var_0_copy_2 = deep_copy(&var_0.circuit.as_ref().borrow()); let var_0_one_mul_2 = one.clone() * var_0_copy_2.clone() * one; assert_ne!(var_0_copy_2, var_0_one_mul_2); - let mut var_0_circuit_2 = var_0_copy_2.consume(); - let mut var_0_same_circuit_2 = var_0_one_mul_2.consume(); - ConstraintCircuit::constant_folding(&mut [&mut var_0_circuit_2, &mut var_0_same_circuit_2]); - assert_eq!(var_0_circuit_2, var_0_same_circuit_2); - assert_eq!(var_0_same_circuit_2, var_0_circuit_2); + let mut circuits = [var_0_copy_2, var_0_one_mul_2]; + ConstraintCircuitMonad::constant_folding(&mut circuits); + assert_eq!( + circuits[0], circuits[1], + "{} != {}", + circuits[0], circuits[1] + ); + assert_eq!( + circuits[1], circuits[0], + "{} != {}", + circuits[1], circuits[0] + ); // Verify that constant folding handles a + 0 = a let var_0_copy_3 = deep_copy(&var_0.circuit.as_ref().borrow()); let var_0_plus_zero_3 = var_0_copy_3.clone() + zero.clone(); assert_ne!(var_0_copy_3, var_0_plus_zero_3); - let mut var_0_circuit_3 = var_0_copy_3.consume(); - let mut var_0_same_circuit_3 = var_0_plus_zero_3.consume(); - ConstraintCircuit::constant_folding(&mut [&mut var_0_circuit_3, &mut var_0_same_circuit_3]); - assert_eq!(var_0_circuit_3, var_0_same_circuit_3); - assert_eq!(var_0_same_circuit_3, var_0_circuit_3); + let mut circuits = [var_0_copy_3, var_0_plus_zero_3]; + ConstraintCircuitMonad::constant_folding(&mut circuits); + assert_eq!( + circuits[0], circuits[1], + "{} != {}", + circuits[0], circuits[1] + ); + assert_eq!( + circuits[1], circuits[0], + "{} != {}", + circuits[1], circuits[0] + ); // Verify that constant folding handles a + (a * 0) = a let var_0_copy_4 = deep_copy(&var_0.circuit.as_ref().borrow()); let var_0_plus_zero_4 = var_0_copy_4.clone() + var_0_copy_4.clone() * zero.clone(); assert_ne!(var_0_copy_4, var_0_plus_zero_4); - let mut var_0_circuit_4 = var_0_copy_4.consume(); - let mut var_0_same_circuit_4 = var_0_plus_zero_4.consume(); - ConstraintCircuit::constant_folding(&mut [&mut var_0_circuit_4, &mut var_0_same_circuit_4]); - assert_eq!(var_0_circuit_4, var_0_same_circuit_4); - assert_eq!(var_0_same_circuit_4, var_0_circuit_4); + let mut circuits = [var_0_copy_4, var_0_plus_zero_4]; + ConstraintCircuitMonad::constant_folding(&mut circuits); + assert_eq!( + circuits[0], circuits[1], + "{} != {}", + circuits[0], circuits[1] + ); + assert_eq!( + circuits[1], circuits[0], + "{} != {}", + circuits[1], circuits[0] + ); // Verify that constant folding does not equate `0 - a` with `a` let var_0_copy_5 = deep_copy(&var_0.circuit.as_ref().borrow()); let zero_minus_var_0 = zero - var_0_copy_5.clone(); assert_ne!(var_0_copy_5, zero_minus_var_0); - let mut var_0_circuit_5 = var_0_copy_5.consume(); - let mut var_0_not_same_circuit_5 = zero_minus_var_0.consume(); - ConstraintCircuit::constant_folding(&mut [ - &mut var_0_circuit_5, - &mut var_0_not_same_circuit_5, - ]); - assert_ne!(var_0_circuit_5, var_0_not_same_circuit_5); - assert_ne!(var_0_not_same_circuit_5, var_0_circuit_5); + let mut circuits = [var_0_copy_5, zero_minus_var_0]; + ConstraintCircuitMonad::constant_folding(&mut circuits); + assert_ne!( + circuits[0], circuits[1], + "{} == {}", + circuits[0], circuits[1] + ); + assert_ne!( + circuits[1], circuits[0], + "{} == {}", + circuits[1], circuits[0] + ); } #[test] @@ -1206,184 +1481,284 @@ mod constraint_circuit_tests { let copy_0 = deep_copy(&circuit.circuit.as_ref().borrow()); let copy_0_alt = copy_0.clone() * one.clone(); assert_ne!(copy_0, copy_0_alt); - let mut circuit_0 = copy_0.consume(); - let mut same_circuit_0 = copy_0_alt.consume(); - ConstraintCircuit::constant_folding(&mut [&mut circuit_0, &mut same_circuit_0]); - assert_eq!(circuit_0, same_circuit_0); - assert_eq!(same_circuit_0, circuit_0); + let mut circuits = [copy_0.clone(), copy_0_alt.clone()]; + ConstraintCircuitMonad::constant_folding(&mut circuits); + assert_eq!( + circuits[0], circuits[1], + "{} != {}", + circuits[0], circuits[1] + ); + assert_eq!( + circuits[1], circuits[0], + "{} != {}", + circuits[1], circuits[0] + ); // Verify that constant folding can handle a = 1 * a let copy_1 = deep_copy(&circuit.circuit.as_ref().borrow()); let copy_1_alt = one.clone() * copy_1.clone(); assert_ne!(copy_1, copy_1_alt); - let mut circuit_1 = copy_1.consume(); - let mut circuit_1_alt = copy_1_alt.consume(); - ConstraintCircuit::constant_folding(&mut [&mut circuit_1, &mut circuit_1_alt]); - assert_eq!(circuit_1, circuit_1_alt); - assert_eq!(circuit_1_alt, circuit_1); + let mut circuits = [copy_1, copy_1_alt]; + ConstraintCircuitMonad::constant_folding(&mut circuits); + assert_eq!( + circuits[0], circuits[1], + "{} != {}", + circuits[0], circuits[1] + ); + assert_eq!( + circuits[1], circuits[0], + "{} != {}", + circuits[1], circuits[0] + ); // Verify that constant folding can handle a = 1 * a * 1 let copy_2 = deep_copy(&circuit.circuit.as_ref().borrow()); let copy_2_alt = one.clone() * copy_2.clone() * one.clone(); assert_ne!(copy_2, copy_2_alt); - let mut circuit_1 = copy_2.consume(); - let mut circuit_1_alt = copy_2_alt.consume(); - ConstraintCircuit::constant_folding(&mut [&mut circuit_1, &mut circuit_1_alt]); - assert_eq!(circuit_1, circuit_1_alt); - assert_eq!(circuit_1_alt, circuit_1); + let mut circuits = [copy_2, copy_2_alt]; + ConstraintCircuitMonad::constant_folding(&mut circuits); + assert_eq!( + circuits[0], circuits[1], + "{} != {}", + circuits[0], circuits[1] + ); + assert_eq!( + circuits[1], circuits[0], + "{} != {}", + circuits[1], circuits[0] + ); // Verify that constant folding handles a + 0 = a let copy_3 = deep_copy(&circuit.circuit.as_ref().borrow()); let copy_3_alt = copy_3.clone() + zero.clone(); assert_ne!(copy_3, copy_3_alt); - let mut circuit_3 = copy_3.consume(); - let mut circuit_3_alt = copy_3_alt.consume(); - ConstraintCircuit::constant_folding(&mut [&mut circuit_3, &mut circuit_3_alt]); - assert_eq!(circuit_3, circuit_3_alt); - assert_eq!(circuit_3_alt, circuit_3); + let mut circuits = [copy_3, copy_3_alt]; + ConstraintCircuitMonad::constant_folding(&mut circuits); + assert_eq!( + circuits[0], circuits[1], + "{} != {}", + circuits[0], circuits[1] + ); + assert_eq!( + circuits[1], circuits[0], + "{} != {}", + circuits[1], circuits[0] + ); // Verify that constant folding handles a + (a * 0) = a let copy_4 = deep_copy(&circuit.circuit.as_ref().borrow()); let copy_4_alt = copy_4.clone() + copy_4.clone() * zero.clone(); assert_ne!(copy_4, copy_4_alt); - let mut circuit_4 = copy_4.consume(); - let mut circuit_4_alt = copy_4_alt.consume(); - ConstraintCircuit::constant_folding(&mut [&mut circuit_4, &mut circuit_4_alt]); - assert_eq!(circuit_4, circuit_4_alt); - assert_eq!(circuit_4_alt, circuit_4); + let mut circuits = [copy_4, copy_4_alt]; + ConstraintCircuitMonad::constant_folding(&mut circuits); + assert_eq!( + circuits[0], circuits[1], + "{} != {}", + circuits[0], circuits[1] + ); + assert_eq!( + circuits[1], circuits[0], + "{} != {}", + circuits[1], circuits[0] + ); // Verify that constant folding handles a + (0 * a) = a let copy_5 = deep_copy(&circuit.circuit.as_ref().borrow()); let copy_5_alt = copy_5.clone() + copy_5.clone() * zero.clone(); assert_ne!(copy_5, copy_5_alt); - let mut circuit_5 = copy_5.consume(); - let mut circuit_5_alt = copy_5_alt.consume(); - ConstraintCircuit::constant_folding(&mut [&mut circuit_5, &mut circuit_5_alt]); - assert_eq!(circuit_5, circuit_5_alt); - assert_eq!(circuit_5_alt, circuit_5); + let mut circuits = [copy_5, copy_5_alt]; + ConstraintCircuitMonad::constant_folding(&mut circuits); + assert_eq!( + circuits[0], circuits[1], + "{} != {}", + circuits[0], circuits[1] + ); + assert_eq!( + circuits[1], circuits[0], + "{} != {}", + circuits[1], circuits[0] + ); // Verify that constant folding does not equate `0 - a` with `a` // But only if `a != 0` let copy_6 = deep_copy(&circuit.circuit.as_ref().borrow()); let zero_minus_copy_6 = zero.clone() - copy_6.clone(); assert_ne!(copy_6, zero_minus_copy_6); - let mut var_0_circuit_6 = copy_6.consume(); - let mut var_0_not_same_circuit_6 = zero_minus_copy_6.consume(); - ConstraintCircuit::constant_folding(&mut [ - &mut var_0_circuit_6, - &mut var_0_not_same_circuit_6, - ]); + let mut circuits = [copy_6, zero_minus_copy_6]; + ConstraintCircuitMonad::constant_folding(&mut circuits); + let copy_6_is_zero = circuits[0].circuit.as_ref().borrow().is_zero(); + let copy_6_expr = circuits[0].circuit.as_ref().borrow().expression.clone(); + let zero_minus_copy_6_expr = circuits[1].circuit.as_ref().borrow().expression.clone(); // An X field and a B field leaf will never be equal - if var_0_circuit_6.is_zero() - && (matches!(var_0_circuit_6.expression, CircuitExpression::BConstant(_)) - && matches!( - var_0_not_same_circuit_6.expression, - CircuitExpression::BConstant(_) - ) - || matches!(var_0_circuit_6.expression, CircuitExpression::XConstant(_)) - && matches!( - var_0_not_same_circuit_6.expression, - CircuitExpression::XConstant(_) - )) + if copy_6_is_zero + && (matches!(copy_6_expr, CircuitExpression::BConstant(_)) + && matches!(zero_minus_copy_6_expr, CircuitExpression::BConstant(_)) + || matches!(copy_6_expr, CircuitExpression::XConstant(_)) + && matches!(zero_minus_copy_6_expr, CircuitExpression::XConstant(_))) { - assert_eq!(var_0_circuit_6, var_0_not_same_circuit_6); - assert_eq!(var_0_not_same_circuit_6, var_0_circuit_6); + assert_eq!( + circuits[0], circuits[1], + "{} != {}", + circuits[0], circuits[1] + ); + assert_eq!( + circuits[1], circuits[0], + "{} != {}", + circuits[1], circuits[0] + ); } else { - assert_ne!(var_0_circuit_6, var_0_not_same_circuit_6); - assert_ne!(var_0_not_same_circuit_6, var_0_circuit_6); + assert_ne!( + circuits[0], circuits[1], + "{} == {}", + circuits[0], circuits[1] + ); + assert_ne!( + circuits[1], circuits[0], + "{} == {}", + circuits[1], circuits[0] + ); } // Verify that constant folding handles a - 0 = a let copy_7 = deep_copy(&circuit.circuit.as_ref().borrow()); let copy_7_alt = copy_7.clone() - zero.clone(); assert_ne!(copy_7, copy_7_alt); - let mut circuit_7 = copy_7.consume(); - let mut circuit_7_alt = copy_7_alt.consume(); - ConstraintCircuit::constant_folding(&mut [&mut circuit_7, &mut circuit_7_alt]); - assert_eq!(circuit_7, circuit_7_alt); - assert_eq!(circuit_7_alt, circuit_7); + let mut circuits = [copy_7, copy_7_alt]; + ConstraintCircuitMonad::constant_folding(&mut circuits); + assert_eq!( + circuits[0], circuits[1], + "{} != {}", + circuits[0], circuits[1] + ); + assert_eq!( + circuits[1], circuits[0], + "{} != {}", + circuits[1], circuits[0] + ); } } - fn constant_folding_of_table_constraints_test( + fn table_constraints_prop( mut constraints: Vec>, challenges: &Challenges, table_name: &str, ) { ConstraintCircuit::assert_has_unique_ids(&mut constraints); - println!( - "nodes in {table_name} table constraint multitree prior to constant folding: {}", - node_counter(&mut constraints) - ); - ConstraintCircuit::constant_folding(&mut constraints.iter_mut().collect_vec()); - println!( - "nodes in {table_name} constraint multitree after constant folding: {}", - node_counter(&mut constraints) + // Verify that all nodes evaluate to a unique value when given a randomized input. + // If this is not the case two nodes that are not equal evaluate to the same value. + // let input: Vec = random_elements(constraints[0].var_count); + // let base_input: Vec = random_elements(master_table::NUM_BASE_COLUMNS); + // let ext_input: Vec = random_elements(master_table::NUM_EXT_COLUMNS); + let base_table = Array2::from_shape_simple_fn( + [2, master_table::NUM_BASE_COLUMNS], + random::, ); - ConstraintCircuit::assert_has_unique_ids(&mut constraints); - - assert!( - constraints - .iter() - .any(|constraint| constraint.is_randomized()), - "Constraint must contain randomness before challenges have been applied" - ); - - // apply challenges and verify that subtree no longer contains randomness - ConstraintCircuit::apply_challenges(&mut constraints, challenges); - assert!( - constraints - .iter() - .all(|constraint| !constraint.is_randomized()), - "Constraint may not contain randomness after challenges have been applied" + let ext_table = Array2::from_shape_simple_fn( + [2, master_table::NUM_EXT_COLUMNS], + random::, ); - - ConstraintCircuit::constant_folding(&mut constraints.iter_mut().collect_vec()); - println!( - "nodes in {table_name} constraint multitree after applying challenges and constant \ - folding again: {}", - node_counter(&mut constraints) + ConstraintCircuit::assert_all_evaluate_different( + &constraints, + challenges, + base_table.view(), + ext_table.view(), ); - ConstraintCircuit::assert_has_unique_ids(&mut constraints); - let circuit_degree = constraints.iter().map(|c| c.degree()).max().unwrap(); + println!("nodes in {table_name}: {}", node_counter(&mut constraints)); + let circuit_degree = constraints.iter().map(|c| c.degree()).max().unwrap_or(-1); println!("Max degree constraint for {table_name} table: {circuit_degree}"); } #[test] fn constant_folding_processor_table_test() { let challenges = Challenges::placeholder(&[], &[]); + let constraint_circuits = ExtProcessorTable::ext_initial_constraints_as_circuits(); + table_constraints_prop(constraint_circuits, &challenges, "processor initial"); + let constraint_circuits = ExtProcessorTable::ext_consistency_constraints_as_circuits(); + table_constraints_prop(constraint_circuits, &challenges, "processor consistency"); let constraint_circuits = ExtProcessorTable::ext_transition_constraints_as_circuits(); - constant_folding_of_table_constraints_test(constraint_circuits, &challenges, "processor"); + table_constraints_prop(constraint_circuits, &challenges, "processor transition"); + let constraint_circuits = ExtProcessorTable::ext_terminal_constraints_as_circuits(); + table_constraints_prop(constraint_circuits, &challenges, "processor terminal"); } #[test] fn constant_folding_program_table_test() { let challenges = Challenges::placeholder(&[], &[]); + let constraint_circuits = ExtProgramTable::ext_initial_constraints_as_circuits(); + table_constraints_prop(constraint_circuits, &challenges, "program initial"); + let constraint_circuits = ExtProgramTable::ext_consistency_constraints_as_circuits(); + table_constraints_prop(constraint_circuits, &challenges, "program consistency"); let constraint_circuits = ExtProgramTable::ext_transition_constraints_as_circuits(); - constant_folding_of_table_constraints_test(constraint_circuits, &challenges, "program"); + table_constraints_prop(constraint_circuits, &challenges, "program transition"); + let constraint_circuits = ExtProgramTable::ext_terminal_constraints_as_circuits(); + table_constraints_prop(constraint_circuits, &challenges, "program terminal"); } #[test] fn constant_folding_jump_stack_table_test() { let challenges = Challenges::placeholder(&[], &[]); + let constraint_circuits = ExtJumpStackTable::ext_initial_constraints_as_circuits(); + table_constraints_prop(constraint_circuits, &challenges, "jump stack initial"); + let constraint_circuits = ExtJumpStackTable::ext_consistency_constraints_as_circuits(); + table_constraints_prop(constraint_circuits, &challenges, "jump stack consistency"); let constraint_circuits = ExtJumpStackTable::ext_transition_constraints_as_circuits(); - constant_folding_of_table_constraints_test(constraint_circuits, &challenges, "jump stack"); + table_constraints_prop(constraint_circuits, &challenges, "jump stack transition"); + let constraint_circuits = ExtJumpStackTable::ext_terminal_constraints_as_circuits(); + table_constraints_prop(constraint_circuits, &challenges, "jump stack terminal"); } #[test] fn constant_folding_op_stack_table_test() { let challenges = Challenges::placeholder(&[], &[]); + let constraint_circuits = ExtOpStackTable::ext_initial_constraints_as_circuits(); + table_constraints_prop(constraint_circuits, &challenges, "op stack initial"); + let constraint_circuits = ExtOpStackTable::ext_consistency_constraints_as_circuits(); + table_constraints_prop(constraint_circuits, &challenges, "op stack consistency"); let constraint_circuits = ExtOpStackTable::ext_transition_constraints_as_circuits(); - constant_folding_of_table_constraints_test(constraint_circuits, &challenges, "op stack"); + table_constraints_prop(constraint_circuits, &challenges, "op stack transition"); + let constraint_circuits = ExtOpStackTable::ext_terminal_constraints_as_circuits(); + table_constraints_prop(constraint_circuits, &challenges, "op stack terminal"); } #[test] - fn constant_folding_ram_stack_table_test() { + fn constant_folding_ram_table_test() { let challenges = Challenges::placeholder(&[], &[]); + let constraint_circuits = ExtRamTable::ext_initial_constraints_as_circuits(); + table_constraints_prop(constraint_circuits, &challenges, "ram initial"); + let constraint_circuits = ExtRamTable::ext_consistency_constraints_as_circuits(); + table_constraints_prop(constraint_circuits, &challenges, "ram consistency"); let constraint_circuits = ExtRamTable::ext_transition_constraints_as_circuits(); - constant_folding_of_table_constraints_test(constraint_circuits, &challenges, "ram"); + table_constraints_prop(constraint_circuits, &challenges, "ram transition"); + let constraint_circuits = ExtRamTable::ext_terminal_constraints_as_circuits(); + table_constraints_prop(constraint_circuits, &challenges, "ram terminal"); + } + + #[test] + fn constant_folding_hash_table_test() { + let challenges = Challenges::placeholder(&[], &[]); + let constraint_circuits = ExtHashTable::ext_initial_constraints_as_circuits(); + table_constraints_prop(constraint_circuits, &challenges, "hash initial"); + let constraint_circuits = ExtHashTable::ext_consistency_constraints_as_circuits(); + table_constraints_prop(constraint_circuits, &challenges, "hash consistency"); + let constraint_circuits = ExtHashTable::ext_transition_constraints_as_circuits(); + table_constraints_prop(constraint_circuits, &challenges, "hash transition"); + let constraint_circuits = ExtHashTable::ext_terminal_constraints_as_circuits(); + table_constraints_prop(constraint_circuits, &challenges, "hash terminal"); + } + + #[test] + fn constant_folding_u32_table_test() { + let challenges = Challenges::placeholder(&[], &[]); + let constraint_circuits = ExtU32Table::ext_initial_constraints_as_circuits(); + table_constraints_prop(constraint_circuits, &challenges, "u32 initial"); + let constraint_circuits = ExtU32Table::ext_consistency_constraints_as_circuits(); + table_constraints_prop(constraint_circuits, &challenges, "u32 consistency"); + let constraint_circuits = ExtU32Table::ext_transition_constraints_as_circuits(); + table_constraints_prop(constraint_circuits, &challenges, "u32 transition"); + let constraint_circuits = ExtU32Table::ext_terminal_constraints_as_circuits(); + table_constraints_prop(constraint_circuits, &challenges, "u32 terminal"); } } diff --git a/triton-vm/src/table/hash_table.rs b/triton-vm/src/table/hash_table.rs index 32277c98..a234c395 100644 --- a/triton-vm/src/table/hash_table.rs +++ b/triton-vm/src/table/hash_table.rs @@ -134,15 +134,16 @@ impl ExtHashTable { * running_evaluation_sponge_has_accumulated_first_row + ci_is_absorb_init * running_evaluation_sponge_is_default_initial; - [ + let mut constraints = [ round_number_is_0_or_1, current_instruction_is_absorb_init_or_hash, running_evaluation_hash_input_is_initialized_correctly, running_evaluation_hash_digest_is_default_initial, running_evaluation_sponge_absorb_is_initialized_correctly, - ] - .map(|circuit| circuit.consume()) - .to_vec() + ]; + + ConstraintCircuitMonad::constant_folding(&mut constraints); + constraints.map(|circuit| circuit.consume()).to_vec() } fn round_number_deselector( @@ -179,7 +180,7 @@ impl ExtHashTable { Self::round_number_deselector(&circuit_builder, &round_number, 0); let round_number_is_not_1 = Self::round_number_deselector(&circuit_builder, &round_number, 1); - let mut consistency_constraint_circuits = vec![ + let mut constraints = vec![ round_number_is_not_0 * ci_is_hash.clone(), round_number_is_not_1.clone() * ci_is_absorb_init @@ -211,10 +212,11 @@ impl ExtHashTable { - circuit_builder.b_constant(ROUND_CONSTANTS[round_constant_idx])) }) .sum(); - consistency_constraint_circuits.push(round_constant_constraint_circuit); + constraints.push(round_constant_constraint_circuit); } - consistency_constraint_circuits + ConstraintCircuitMonad::constant_folding(&mut constraints); + constraints .into_iter() .map(|circuit| circuit.consume()) .collect() @@ -562,7 +564,7 @@ impl ExtHashTable { + if_round_no_next_is_not_1_then_running_evaluation_sponge_absorb_remains + if_ci_next_is_not_spongy_then_running_evaluation_sponge_absorb_remains; - [ + let mut constraints = [ vec![ round_number_is_1_through_9_or_round_number_next_is_0, round_number_is_0_through_8_or_round_number_next_is_0_or_1, @@ -579,10 +581,13 @@ impl ExtHashTable { running_evaluation_sponge_absorb_is_updated_correctly, ], ] - .concat() - .into_iter() - .map(|circuit| circuit.consume()) - .collect() + .concat(); + + ConstraintCircuitMonad::constant_folding(&mut constraints); + constraints + .into_iter() + .map(|circuit| circuit.consume()) + .collect() } pub fn ext_terminal_constraints_as_circuits() -> Vec> { diff --git a/triton-vm/src/table/jump_stack_table.rs b/triton-vm/src/table/jump_stack_table.rs index c4094897..c365ca8e 100644 --- a/triton-vm/src/table/jump_stack_table.rs +++ b/triton-vm/src/table/jump_stack_table.rs @@ -35,6 +35,8 @@ use crate::table::table_column::MasterExtTableColumn; use crate::table::table_column::ProcessorBaseTableColumn; use crate::vm::AlgebraicExecutionTrace; +use super::constraint_circuit::ConstraintCircuitMonad; + pub const BASE_WIDTH: usize = JumpStackBaseTableColumn::COUNT; pub const EXT_WIDTH: usize = JumpStackExtTableColumn::COUNT; pub const FULL_WIDTH: usize = BASE_WIDTH + EXT_WIDTH; @@ -68,16 +70,18 @@ impl ExtJumpStackTable { let clock_jump_diff_log_derivative_starts_correctly = clock_jump_diff_log_derivative - circuit_builder.x_constant(LookupArg::default_initial()); - [ + let mut constraints = [ clk, jsp, jso, jsd, rppa_starts_correctly, clock_jump_diff_log_derivative_starts_correctly, - ] - .map(|circuit| circuit.consume()) - .to_vec() + ]; + + ConstraintCircuitMonad::constant_folding(&mut constraints); + + constraints.map(|circuit| circuit.consume()).to_vec() } pub fn ext_consistency_constraints_as_circuits() -> Vec> { @@ -118,13 +122,13 @@ impl ExtJumpStackTable { // 1. The jump stack pointer jsp increases by 1 // or the jump stack pointer jsp does not change let jsp_inc_or_stays = - (jsp_next.clone() - (jsp.clone() + one.clone())) * (jsp_next.clone() - jsp.clone()); + (jsp_next.clone() - jsp.clone() - one.clone()) * (jsp_next.clone() - jsp.clone()); // 2. The jump stack pointer jsp increases by 1 // or current instruction ci is return // or the jump stack origin jso does not change let jsp_inc_by_one_or_ci_is_return = - (jsp_next.clone() - (jsp.clone() + one.clone())) * (ci.clone() - return_opcode.clone()); + (jsp_next.clone() - jsp.clone() - one.clone()) * (ci.clone() - return_opcode.clone()); let jsp_inc_or_jso_stays_or_ci_is_ret = jsp_inc_by_one_or_ci_is_return.clone() * (jso_next.clone() - jso); @@ -138,11 +142,11 @@ impl ExtJumpStackTable { // or the cycle count clk increases by 1 // or current instruction ci is call // or current instruction ci is return - let jsp_inc_or_clk_inc_or_ci_call_or_ci_ret = (jsp_next.clone() - - (jsp.clone() + one.clone())) - * (clk_next.clone() - (clk.clone() + one.clone())) - * (ci.clone() - call_opcode) - * (ci - return_opcode); + let jsp_inc_or_clk_inc_or_ci_call_or_ci_ret = + (jsp_next.clone() - jsp.clone() - one.clone()) + * (clk_next.clone() - clk.clone() - one.clone()) + * (ci.clone() - call_opcode) + * (ci - return_opcode); // The running product for the permutation argument `rppa` accumulates one row in each // row, relative to weights `a`, `b`, `c`, `d`, `e`, and indeterminate `α`. @@ -171,16 +175,17 @@ impl ExtJumpStackTable { * log_derivative_accumulates + (jsp_next - jsp) * log_derivative_remains; - [ + let mut constraints = [ jsp_inc_or_stays, jsp_inc_or_jso_stays_or_ci_is_ret, jsp_inc_or_jsd_stays_or_ci_ret, jsp_inc_or_clk_inc_or_ci_call_or_ci_ret, rppa_updates_correctly, log_derivative_updates_correctly, - ] - .map(|circuit| circuit.consume()) - .to_vec() + ]; + + ConstraintCircuitMonad::constant_folding(&mut constraints); + constraints.map(|circuit| circuit.consume()).to_vec() } pub fn ext_terminal_constraints_as_circuits() -> Vec> { diff --git a/triton-vm/src/table/op_stack_table.rs b/triton-vm/src/table/op_stack_table.rs index f4c47b1a..520b17a2 100644 --- a/triton-vm/src/table/op_stack_table.rs +++ b/triton-vm/src/table/op_stack_table.rs @@ -33,6 +33,8 @@ use crate::table::table_column::OpStackExtTableColumn::*; use crate::table::table_column::ProcessorBaseTableColumn; use crate::vm::AlgebraicExecutionTrace; +use super::constraint_circuit::ConstraintCircuitMonad; + pub const BASE_WIDTH: usize = OpStackBaseTableColumn::COUNT; pub const EXT_WIDTH: usize = OpStackExtTableColumn::COUNT; pub const FULL_WIDTH: usize = BASE_WIDTH + EXT_WIDTH; @@ -72,15 +74,15 @@ impl ExtOpStackTable { let clock_jump_diff_log_derivative_is_initialized_correctly = clock_jump_diff_log_derivative - circuit_builder.x_constant(LookupArg::default_initial()); - [ + let mut constraints = [ clk_is_0, osv_is_0, osp_is_16, rppa_starts_correctly, clock_jump_diff_log_derivative_is_initialized_correctly, - ] - .map(|circuit| circuit.consume()) - .to_vec() + ]; + ConstraintCircuitMonad::constant_folding(&mut constraints); + constraints.map(|circuit| circuit.consume()).to_vec() } pub fn ext_consistency_constraints_as_circuits() -> Vec> { @@ -119,15 +121,15 @@ impl ExtOpStackTable { // // $(osp' - (osp + 1))·(osp' - osp) = 0$ let osp_increases_by_1_or_does_not_change = - (osp_next.clone() - (osp.clone() + one.clone())) * (osp_next.clone() - osp.clone()); + (osp_next.clone() - osp.clone() - one.clone()) * (osp_next.clone() - osp.clone()); // the osp increases by 1 or the osv does not change OR the ci shrinks the OpStack // // $ (osp' - (osp + 1)) · (osv' - osv) · (1 - ib1) = 0$ - let osp_increases_by_1_or_osv_does_not_change_or_shrink_stack = (osp_next.clone() - - (osp.clone() + one.clone())) - * (osv_next.clone() - osv) - * (one.clone() - ib1_shrink_stack); + let osp_increases_by_1_or_osv_does_not_change_or_shrink_stack = + (osp_next.clone() - osp.clone() - one.clone()) + * (osv_next.clone() - osv) + * (one.clone() - ib1_shrink_stack); // The running product for the permutation argument `rppa` is updated correctly. let alpha = circuit_builder.challenge(OpStackIndeterminate); @@ -154,14 +156,15 @@ impl ExtOpStackTable { * log_derivative_accumulates + (osp_next - osp) * log_derivative_remains; - [ + let mut constraints = [ osp_increases_by_1_or_does_not_change, osp_increases_by_1_or_osv_does_not_change_or_shrink_stack, rppa_updates_correctly, log_derivative_updates_correctly, - ] - .map(|circuit| circuit.consume()) - .to_vec() + ]; + + ConstraintCircuitMonad::constant_folding(&mut constraints); + constraints.map(|circuit| circuit.consume()).to_vec() } pub fn ext_terminal_constraints_as_circuits() -> Vec> { diff --git a/triton-vm/src/table/processor_table.rs b/triton-vm/src/table/processor_table.rs index 42fa0026..6bd31e56 100644 --- a/triton-vm/src/table/processor_table.rs +++ b/triton-vm/src/table/processor_table.rs @@ -423,8 +423,8 @@ impl ExtProcessorTable { .map(|tc_polys_for_instr| tc_polys_for_instr.len()) .max() .unwrap(); - let zero_poly = DualRowConstraints::default().zero(); + let zero_poly = factory.zero(); let all_tc_polys_for_all_instructions_transposed = (0..max_number_of_constraints) .map(|idx| { all_tc_polys_for_all_instructions @@ -614,7 +614,7 @@ impl ExtProcessorTable { .u32_table_running_sum_log_derivative() - constant_x(LookupArg::default_initial()); - [ + let mut constraints = [ clk_is_0, ip_is_0, jsp_is_0, @@ -652,9 +652,10 @@ impl ExtProcessorTable { running_evaluation_hash_digest_is_initialized_correctly, running_evaluation_sponge_absorb_is_initialized_correctly, running_sum_log_derivative_for_u32_table_is_initialized_correctly, - ] - .map(|circuit| circuit.consume()) - .to_vec() + ]; + + ConstraintCircuitMonad::constant_folding(&mut constraints); + constraints.map(|circuit| circuit.consume()).to_vec() } pub fn ext_consistency_constraints_as_circuits() -> Vec> { @@ -691,7 +692,7 @@ impl ExtProcessorTable { * (factory.clk() - factory.one()) * factory.clock_jump_difference_lookup_multiplicity(); - [ + let mut constraints = [ ib0_is_bit, ib1_is_bit, ib2_is_bit, @@ -703,9 +704,10 @@ impl ExtProcessorTable { is_padding_is_bit, ci_corresponds_to_ib0_thru_ib7, clock_jump_diff_lookup_multiplicity_is_0_in_padding_rows, - ] - .map(|circuit| circuit.consume()) - .to_vec() + ]; + + ConstraintCircuitMonad::constant_folding(&mut constraints); + constraints.map(|circuit| circuit.consume()).to_vec() } pub fn ext_transition_constraints_as_circuits() -> Vec> { @@ -797,14 +799,11 @@ impl ExtProcessorTable { transition_constraints.push(factory.running_evaluation_sponge_updates_correctly()); transition_constraints.push(factory.running_product_to_u32_table_updates_correctly()); - let mut built_transition_constraints = transition_constraints + ConstraintCircuitMonad::constant_folding(&mut transition_constraints); + transition_constraints .into_iter() .map(|tc_ref| tc_ref.consume()) - .collect_vec(); - ConstraintCircuit::constant_folding( - &mut built_transition_constraints.iter_mut().collect_vec(), - ); - built_transition_constraints + .collect_vec() } pub fn ext_terminal_constraints_as_circuits() -> Vec> { @@ -813,7 +812,10 @@ impl ExtProcessorTable { // In the last row, current instruction register ci is 0, corresponding to instruction halt. let last_ci_is_halt = factory.ci(); - vec![last_ci_is_halt.consume()] + let mut constraints = [last_ci_is_halt]; + + ConstraintCircuitMonad::constant_folding(&mut constraints); + constraints.map(|circuit| circuit.consume()).to_vec() } } @@ -1292,11 +1294,11 @@ impl DualRowConstraints { // 6. (Register `st0` is 0 or `ip` is incremented by 1), and // (`st0` has a multiplicative inverse or `hv` is 1 or `ip` is incremented by 2), and // (`st0` has a multiplicative inverse or `hv0` is 0 or `ip` is incremented by 3). - let ip_case_1 = (self.ip_next() - (self.ip() + self.one())) * self.st0(); - let ip_case_2 = (self.ip_next() - (self.ip() + self.two())) + let ip_case_1 = (self.ip_next() - self.ip() - self.one()) * self.st0(); + let ip_case_2 = (self.ip_next() - self.ip() - self.two()) * (self.st0() * self.hv2() - self.one()) * (self.hv0() - self.one()); - let ip_case_3 = (self.ip_next() - (self.ip() + self.constant(3))) + let ip_case_3 = (self.ip_next() - self.ip() - self.constant(3)) * (self.st0() * self.hv2() - self.one()) * self.hv0(); let ip_incr_by_1_or_2_or_3 = ip_case_1 + ip_case_2 + ip_case_3; @@ -1317,7 +1319,7 @@ impl DualRowConstraints { let jsp_incr_1 = self.jsp_next() - (self.jsp() + self.one()); // The jump's origin jso is set to the current instruction pointer ip plus 2. - let jso_becomes_ip_plus_2 = self.jso_next() - (self.ip() + self.two()); + let jso_becomes_ip_plus_2 = self.jso_next() - self.ip() - self.two(); // The jump's destination jsd is set to the instruction's argument. let jsd_becomes_nia = self.jsd_next() - self.nia(); @@ -1854,15 +1856,15 @@ impl DualRowConstraints { } pub fn zero(&self) -> ConstraintCircuitMonad { - self.zero.to_owned() + self.zero.clone() } pub fn one(&self) -> ConstraintCircuitMonad { - self.one.to_owned() + self.one.clone() } pub fn two(&self) -> ConstraintCircuitMonad { - self.two.to_owned() + self.two.clone() } pub fn constant(&self, constant: u32) -> ConstraintCircuitMonad { diff --git a/triton-vm/src/table/program_table.rs b/triton-vm/src/table/program_table.rs index 64f251e2..3b79f382 100644 --- a/triton-vm/src/table/program_table.rs +++ b/triton-vm/src/table/program_table.rs @@ -27,6 +27,8 @@ use crate::table::table_column::ProgramExtTableColumn; use crate::table::table_column::ProgramExtTableColumn::*; use crate::vm::AlgebraicExecutionTrace; +use super::constraint_circuit::ConstraintCircuitMonad; + pub const BASE_WIDTH: usize = ProgramBaseTableColumn::COUNT; pub const EXT_WIDTH: usize = ProgramExtTableColumn::COUNT; pub const FULL_WIDTH: usize = BASE_WIDTH + EXT_WIDTH; @@ -52,10 +54,13 @@ impl ExtProgramTable { instruction_lookup_log_derivative - circuit_builder.x_constant(LookupArg::default_initial()); - vec![ - first_address_is_zero.consume(), - instruction_lookup_log_derivative_is_initialized_correctly.consume(), - ] + let mut constraints = [ + first_address_is_zero, + instruction_lookup_log_derivative_is_initialized_correctly, + ]; + + ConstraintCircuitMonad::constant_folding(&mut constraints); + constraints.map(|circuit| circuit.consume()).to_vec() } pub fn ext_consistency_constraints_as_circuits() -> Vec> { @@ -65,7 +70,9 @@ impl ExtProgramTable { let is_padding = circuit_builder.input(BaseRow(IsPadding.master_base_table_index())); let is_padding_is_bit = is_padding.clone() * (is_padding - one); - vec![is_padding_is_bit.consume()] + let mut constraints = [is_padding_is_bit]; + ConstraintCircuitMonad::constant_folding(&mut constraints); + constraints.map(|circuit| circuit.consume()).to_vec() } pub fn ext_transition_constraints_as_circuits() -> Vec> { @@ -107,13 +114,13 @@ impl ExtProgramTable { * log_derivative_updates + is_padding * log_derivative_remains; - [ + let mut constraints = [ address_increases_by_one, is_padding_is_0_or_remains_unchanged, log_derivative_updates_if_and_only_if_not_a_padding_row, - ] - .map(|circuit| circuit.consume()) - .to_vec() + ]; + ConstraintCircuitMonad::constant_folding(&mut constraints); + constraints.map(|circuit| circuit.consume()).to_vec() } pub fn ext_terminal_constraints_as_circuits() -> Vec> { diff --git a/triton-vm/src/table/ram_table.rs b/triton-vm/src/table/ram_table.rs index a04229f7..8c534f1e 100644 --- a/triton-vm/src/table/ram_table.rs +++ b/triton-vm/src/table/ram_table.rs @@ -36,6 +36,8 @@ use crate::table::table_column::RamExtTableColumn; use crate::table::table_column::RamExtTableColumn::*; use crate::vm::AlgebraicExecutionTrace; +use super::constraint_circuit::ConstraintCircuitMonad; + pub const BASE_WIDTH: usize = RamBaseTableColumn::COUNT; pub const EXT_WIDTH: usize = RamExtTableColumn::COUNT; pub const FULL_WIDTH: usize = BASE_WIDTH + EXT_WIDTH; @@ -362,7 +364,7 @@ impl ExtRamTable { let running_product_permutation_argument_is_initialized_correctly = rppa - (rppa_challenge - compressed_row_for_permutation_argument); - [ + let mut constraints = [ ramv_is_0_or_was_written_to, bezout_coefficient_polynomial_coefficient_0_is_0, bezout_coefficient_0_is_0, @@ -371,9 +373,10 @@ impl ExtRamTable { formal_derivative_is_1, running_product_permutation_argument_is_initialized_correctly, clock_jump_diff_log_derivative_is_initialized_correctly, - ] - .map(|circuit| circuit.consume()) - .to_vec() + ]; + + ConstraintCircuitMonad::constant_folding(&mut constraints); + constraints.map(|circuit| circuit.consume()).to_vec() } pub fn ext_consistency_constraints_as_circuits() -> Vec> { @@ -507,7 +510,7 @@ impl ExtRamTable { let log_derivative_updates_correctly = (one - ramp_changes) * log_derivative_accumulates + ramp_diff * log_derivative_remains; - [ + let mut constraints = [ iord_is_0_or_iord_is_inverse_of_ramp_diff, ramp_diff_is_0_or_iord_is_inverse_of_ramp_diff, ramp_changes_or_write_mem_or_ramv_stays, @@ -520,9 +523,10 @@ impl ExtRamTable { bezout_coefficient_1_is_constructed_correctly, rppa_updates_correctly, log_derivative_updates_correctly, - ] - .map(|circuit| circuit.consume()) - .to_vec() + ]; + + ConstraintCircuitMonad::constant_folding(&mut constraints); + constraints.map(|circuit| circuit.consume()).to_vec() } pub fn ext_terminal_constraints_as_circuits() -> Vec> { @@ -536,6 +540,8 @@ impl ExtRamTable { let bezout_relation_holds = bc0 * rp + bc1 * fd - one; - vec![bezout_relation_holds.consume()] + let mut constraints = [bezout_relation_holds]; + ConstraintCircuitMonad::constant_folding(&mut constraints); + constraints.map(|circuit| circuit.consume()).to_vec() } } diff --git a/triton-vm/src/table/u32_table.rs b/triton-vm/src/table/u32_table.rs index 84e78192..9380e03d 100644 --- a/triton-vm/src/table/u32_table.rs +++ b/triton-vm/src/table/u32_table.rs @@ -101,9 +101,9 @@ impl ExtU32Table { if_copy_flag_is_0_then_log_derivative_is_default_initial + if_copy_flag_is_1_then_log_derivative_has_accumulated_first_row; - [running_sum_log_derivative_starts_correctly] - .map(|circuit| circuit.consume()) - .to_vec() + let mut constraints = [running_sum_log_derivative_starts_correctly]; + ConstraintCircuitMonad::constant_folding(&mut constraints); + constraints.map(|circuit| circuit.consume()).to_vec() } pub fn ext_consistency_constraints_as_circuits() -> Vec> { @@ -170,7 +170,7 @@ impl ExtU32Table { let if_copy_flag_is_0_then_lookup_multiplicity_is_0 = (copy_flag - one) * lookup_multiplicity; - [ + let mut constraints = [ copy_flag_is_bit, copy_flag_is_0_or_bits_is_0, bits_minus_33_inv_is_inverse_of_bits_minus_33, @@ -185,9 +185,9 @@ impl ExtU32Table { result_is_initialized_correctly_for_log_2_floor, if_log_2_floor_on_0_then_vm_crashes, if_copy_flag_is_0_then_lookup_multiplicity_is_0, - ] - .map(|circuit| circuit.consume()) - .to_vec() + ]; + ConstraintCircuitMonad::constant_folding(&mut constraints); + constraints.map(|circuit| circuit.consume()).to_vec() } pub fn ext_transition_constraints_as_circuits() -> Vec> { @@ -350,7 +350,7 @@ impl ExtU32Table { * (challenge(U32Indeterminate) - compressed_row_next) - lookup_multiplicity_next); - [ + let mut constraints = [ if_copy_flag_next_is_1_then_lhs_is_0_or_ci_is_pow, if_copy_flag_next_is_1_then_rhs_is_0, if_copy_flag_next_is_0_then_ci_stays, @@ -372,9 +372,9 @@ impl ExtU32Table { if_copy_flag_next_is_0_and_ci_is_pow_and_rhs_lsb_is_1_then_result_squares_and_mults, if_copy_flag_next_is_0_then_running_sum_log_derivative_stays, if_copy_flag_next_is_1_then_running_sum_log_derivative_accumulates_next_row, - ] - .map(|circuit| circuit.consume()) - .to_vec() + ]; + ConstraintCircuitMonad::constant_folding(&mut constraints); + constraints.map(|circuit| circuit.consume()).to_vec() } pub fn ext_terminal_constraints_as_circuits() -> Vec> { @@ -388,9 +388,9 @@ impl ExtU32Table { lhs * (ci - circuit_builder.b_constant(Instruction::Pow.opcode_b())); let rhs_is_0 = rhs; - [lhs_is_0_or_ci_is_pow, rhs_is_0] - .map(|circuit| circuit.consume()) - .to_vec() + let mut constraints = [lhs_is_0_or_ci_is_pow, rhs_is_0]; + ConstraintCircuitMonad::constant_folding(&mut constraints); + constraints.map(|circuit| circuit.consume()).to_vec() } }