diff --git a/Cargo.toml b/Cargo.toml index c3fc50f5..3fb30630 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -25,6 +25,7 @@ documentation = "https://triton-vm.org/spec/" anyhow = "1.0" arbitrary = { version = "1", features = ["derive"] } assert2 = "0.3" +blake3 = "1.5.1" colored = "2.1" clap = { version = "4", features = ["derive", "cargo", "wrap_help", "unicode", "string"] } criterion = { version = "0.5", features = ["html_reports"] } diff --git a/triton-vm/Cargo.toml b/triton-vm/Cargo.toml index 7d62b2bb..e3577e9b 100644 --- a/triton-vm/Cargo.toml +++ b/triton-vm/Cargo.toml @@ -19,6 +19,7 @@ readme.workspace = true [dependencies] arbitrary.workspace = true +blake3.workspace = true colored.workspace = true criterion.workspace = true get-size.workspace = true diff --git a/triton-vm/src/table/constraint_circuit.rs b/triton-vm/src/table/constraint_circuit.rs index bca7cbba..ecf3755e 100644 --- a/triton-vm/src/table/constraint_circuit.rs +++ b/triton-vm/src/table/constraint_circuit.rs @@ -31,6 +31,8 @@ use num_traits::One; use num_traits::Zero; use quote::quote; use quote::ToTokens; +use rand::thread_rng; +use rand::Rng; use twenty_first::prelude::*; use CircuitExpression::*; @@ -652,7 +654,7 @@ impl ConstraintCircuitMonad { self.circuit.borrow().to_owned() } - fn find_equivalent_expression(&self) -> Option>>> { + fn find_equivalent_reduced_expression(&self) -> Option>>> { let BinaryOperation(op, lhs, rhs) = &self.circuit.borrow().expression else { return None; }; @@ -702,7 +704,7 @@ impl ConstraintCircuitMonad { let (change_in_rhs, _) = rhs_as_monadic_value.constant_fold_inner(); change_tracker |= change_in_rhs; - let equivalent_circuit = self.find_equivalent_expression(); + let equivalent_circuit = self.find_equivalent_reduced_expression(); if let Some(ref circuit) = equivalent_circuit { change_tracker = true; let id_to_remove = self.circuit.borrow().id; @@ -727,6 +729,123 @@ impl ConstraintCircuitMonad { } } + /// Traverse the circuit and find all nodes that are equivalent. Note that + /// two nodes are equivalent if they compute the same value on all identical + /// inputs. Equivalence is different from identity, which is when two nodes + /// connect the same set of neighbors in the same way. (There may be two + /// different ways to compute the same result; they are equivalent but + /// unequal.) + /// + /// This function returns a list of lists of equivalent nodes such that + /// every inner list can be reduced to a single node without changing the + /// circuit's function. + /// + /// Equivalent nodes are detected probabilistically using the multivariate + /// Schwartz-Zippel lemma. The false positive probability is zero (we can be + /// certain that equivalent nodes will be found). The false negative + /// probability is bounded by max_degree / (2^64 - 2^32 + 1)^3. + pub fn find_equivalent_nodes(&self) -> Vec>>>> { + let mut values: HashMap = HashMap::new(); + let mut ids: HashMap> = HashMap::new(); + let mut nodes: HashMap>>> = HashMap::new(); + let seed: [u8; 32] = thread_rng().gen(); + Self::probe_random( + self.circuit.clone(), + &mut values, + &mut ids, + &mut nodes, + seed, + ); + + ids.values() + .filter(|l| l.len() >= 2) + .cloned() + .map(|l| l.iter().map(|i| nodes[i].clone()).collect_vec()) + .collect_vec() + } + + /// Populate the dictionaries such that they associate with every node in + /// the circuit its evaluation in a random point. The inputs are assigned + /// random values. Equivalent nodes are detected based on evaluating to the + /// same value using the Schwartz-Zippel lemma. + fn probe_random( + circuit: Rc>>, + values: &mut HashMap, + ids: &mut HashMap>, + nodes: &mut HashMap>>>, + master_seed: [u8; 32], + ) { + // the node was already touched; nothing to do + if values.contains_key(&circuit.borrow().id) { + return; + } + + // compute the node's value; recurse if necessary + let value = match &circuit.borrow().expression { + BConstant(bfe) => bfe.lift(), + XConstant(xfe) => *xfe, + Input(input) => { + let mut hasher = blake3::Hasher::new(); + hasher.update(&master_seed); + hasher.update(b"input"); + hasher.update(&usize::from(input.is_base_table_column()).to_ne_bytes()); + hasher.update(&input.column().to_ne_bytes()); + let mut output = [0u8; 24]; + hasher.finalize_xof().fill(&mut output); + let x0 = BFieldElement::from_ne_bytes(&output[0..8]); + let x1 = BFieldElement::from_ne_bytes(&output[8..16]); + let x2 = BFieldElement::from_ne_bytes(&output[16..24]); + + XFieldElement::new([x0, x1, x2]) + } + Challenge(challenge) => { + let mut hasher = blake3::Hasher::new(); + hasher.update(&master_seed); + hasher.update(b"challenge"); + hasher.update(&challenge.to_ne_bytes()); + let mut output = [0u8; 24]; + hasher.finalize_xof().fill(&mut output); + let x0 = BFieldElement::from_ne_bytes(&output[0..8]); + let x1 = BFieldElement::from_ne_bytes(&output[8..16]); + let x2 = BFieldElement::from_ne_bytes(&output[16..24]); + + XFieldElement::new([x0, x1, x2]) + } + BinaryOperation(operation, lhs, rhs) => { + // if lhs or rhs wasn't touched yet, recurse + if !values.contains_key(&lhs.borrow().id) { + Self::probe_random(lhs.clone(), values, ids, nodes, master_seed); + } + if !values.contains_key(&rhs.borrow().id) { + Self::probe_random(rhs.clone(), values, ids, nodes, master_seed); + } + + // lookup values + let lhs_value = *values.get(&lhs.borrow().id).unwrap(); + let rhs_value = *values.get(&rhs.borrow().id).unwrap(); + + // combine using appropriate operator + match operation { + BinOp::Add => lhs_value + rhs_value, + BinOp::Mul => lhs_value * rhs_value, + } + } + }; + + // value already exists; keep books + if let Some(peers) = ids.get_mut(&value) { + values.insert(circuit.borrow().id, value); + peers.push(circuit.borrow().id); + nodes.insert(circuit.borrow().id, circuit.clone()); + } + // value is new; keep books + else { + values.insert(circuit.borrow().id, value); + ids.insert(value, vec![circuit.borrow().id]); + nodes.insert(circuit.borrow().id, circuit.clone()); + } + } + /// Lowers the degree of a given multicircuit to the target degree. /// This is achieved by introducing additional variables and constraints. /// The appropriate substitutions are applied to the given multicircuit. @@ -2137,18 +2256,31 @@ mod tests { } #[test] - #[ignore = "requires a proper debugging session, or maybe additional optimizations"] - fn constraint_circuit_builder_reuses_existing_nodes_when_folding_constants() { + fn equivalent_nodes_are_detected_when_present() { let builder = ConstraintCircuitBuilder::new(); - let constant = |c| builder.b_constant(c); - let base_row = |r| builder.input(BaseRow(r)); - let c_0 = base_row(0); - let c_1 = base_row(0) - constant(0); - let mut constraints = [c_0, c_1]; + let x = |i| builder.input(BaseRow(i)); + let ch = |i: usize| builder.challenge(i); - ConstraintCircuitMonad::constant_folding(&mut constraints); - let mut constraints = constraints.iter().map(|c| c.consume()).collect_vec(); - ConstraintCircuit::assert_unique_ids(&mut constraints); + let x0 = x(0); + let x1 = x(1); + let y0 = x(2); + let y1 = x(3); + let ch0 = ch(0); + let ch1 = ch(1); + + let u0 = x0.clone() + x1.clone(); + let u1 = y0.clone() + y1.clone(); + let v = u0 * u1; + + let z0 = x0.clone() * y0.clone(); + let z2 = x1.clone() * y1.clone(); + + let z1 = x1.clone() * y0.clone() + x0.clone() * y1.clone(); + let w = v - z0 - z2; + assert!(w.find_equivalent_nodes().is_empty()); + + let o = ch0 * z1 - ch1 * w; + assert!(!o.find_equivalent_nodes().is_empty()); } }