Skip to content

Commit

Permalink
feat: Detect equivalent nodes in constraint circuit
Browse files Browse the repository at this point in the history
  • Loading branch information
aszepieniec authored and jan-ferdinand committed May 2, 2024
1 parent 79ceb10 commit 17c5b61
Show file tree
Hide file tree
Showing 3 changed files with 146 additions and 12 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ documentation = "https://triton-vm.org/spec/"
anyhow = "1.0"
arbitrary = { version = "1", features = ["derive"] }
assert2 = "0.3"
blake3 = "1.5.1"
colored = "2.1"
clap = { version = "4", features = ["derive", "cargo", "wrap_help", "unicode", "string"] }
criterion = { version = "0.5", features = ["html_reports"] }
Expand Down
1 change: 1 addition & 0 deletions triton-vm/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ readme.workspace = true

[dependencies]
arbitrary.workspace = true
blake3.workspace = true
colored.workspace = true
criterion.workspace = true
get-size.workspace = true
Expand Down
156 changes: 144 additions & 12 deletions triton-vm/src/table/constraint_circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ use num_traits::One;
use num_traits::Zero;
use quote::quote;
use quote::ToTokens;
use rand::thread_rng;
use rand::Rng;
use twenty_first::prelude::*;

use CircuitExpression::*;
Expand Down Expand Up @@ -652,7 +654,7 @@ impl<II: InputIndicator> ConstraintCircuitMonad<II> {
self.circuit.borrow().to_owned()
}

fn find_equivalent_expression(&self) -> Option<Rc<RefCell<ConstraintCircuit<II>>>> {
fn find_equivalent_reduced_expression(&self) -> Option<Rc<RefCell<ConstraintCircuit<II>>>> {
let BinaryOperation(op, lhs, rhs) = &self.circuit.borrow().expression else {
return None;
};
Expand Down Expand Up @@ -702,7 +704,7 @@ impl<II: InputIndicator> ConstraintCircuitMonad<II> {
let (change_in_rhs, _) = rhs_as_monadic_value.constant_fold_inner();
change_tracker |= change_in_rhs;

let equivalent_circuit = self.find_equivalent_expression();
let equivalent_circuit = self.find_equivalent_reduced_expression();
if let Some(ref circuit) = equivalent_circuit {
change_tracker = true;
let id_to_remove = self.circuit.borrow().id;
Expand All @@ -727,6 +729,123 @@ impl<II: InputIndicator> ConstraintCircuitMonad<II> {
}
}

/// Traverse the circuit and find all nodes that are equivalent. Note that
/// two nodes are equivalent if they compute the same value on all identical
/// inputs. Equivalence is different from identity, which is when two nodes
/// connect the same set of neighbors in the same way. (There may be two
/// different ways to compute the same result; they are equivalent but
/// unequal.)
///
/// This function returns a list of lists of equivalent nodes such that
/// every inner list can be reduced to a single node without changing the
/// circuit's function.
///
/// Equivalent nodes are detected probabilistically using the multivariate
/// Schwartz-Zippel lemma. The false positive probability is zero (we can be
/// certain that equivalent nodes will be found). The false negative
/// probability is bounded by max_degree / (2^64 - 2^32 + 1)^3.
pub fn find_equivalent_nodes(&self) -> Vec<Vec<Rc<RefCell<ConstraintCircuit<II>>>>> {
let mut values: HashMap<usize, XFieldElement> = HashMap::new();
let mut ids: HashMap<XFieldElement, Vec<usize>> = HashMap::new();
let mut nodes: HashMap<usize, Rc<RefCell<ConstraintCircuit<II>>>> = HashMap::new();
let seed: [u8; 32] = thread_rng().gen();
Self::probe_random(
self.circuit.clone(),
&mut values,
&mut ids,
&mut nodes,
seed,
);

ids.values()
.filter(|l| l.len() >= 2)
.cloned()
.map(|l| l.iter().map(|i| nodes[i].clone()).collect_vec())
.collect_vec()
}

/// Populate the dictionaries such that they associate with every node in
/// the circuit its evaluation in a random point. The inputs are assigned
/// random values. Equivalent nodes are detected based on evaluating to the
/// same value using the Schwartz-Zippel lemma.
fn probe_random(
circuit: Rc<RefCell<ConstraintCircuit<II>>>,
values: &mut HashMap<usize, XFieldElement>,
ids: &mut HashMap<XFieldElement, Vec<usize>>,
nodes: &mut HashMap<usize, Rc<RefCell<ConstraintCircuit<II>>>>,
master_seed: [u8; 32],
) {
// the node was already touched; nothing to do
if values.contains_key(&circuit.borrow().id) {
return;
}

// compute the node's value; recurse if necessary
let value = match &circuit.borrow().expression {
BConstant(bfe) => bfe.lift(),
XConstant(xfe) => *xfe,
Input(input) => {
let mut hasher = blake3::Hasher::new();
hasher.update(&master_seed);
hasher.update(b"input");
hasher.update(&usize::from(input.is_base_table_column()).to_ne_bytes());
hasher.update(&input.column().to_ne_bytes());
let mut output = [0u8; 24];
hasher.finalize_xof().fill(&mut output);
let x0 = BFieldElement::from_ne_bytes(&output[0..8]);
let x1 = BFieldElement::from_ne_bytes(&output[8..16]);
let x2 = BFieldElement::from_ne_bytes(&output[16..24]);

XFieldElement::new([x0, x1, x2])
}
Challenge(challenge) => {
let mut hasher = blake3::Hasher::new();
hasher.update(&master_seed);
hasher.update(b"challenge");
hasher.update(&challenge.to_ne_bytes());
let mut output = [0u8; 24];
hasher.finalize_xof().fill(&mut output);
let x0 = BFieldElement::from_ne_bytes(&output[0..8]);
let x1 = BFieldElement::from_ne_bytes(&output[8..16]);
let x2 = BFieldElement::from_ne_bytes(&output[16..24]);

XFieldElement::new([x0, x1, x2])
}
BinaryOperation(operation, lhs, rhs) => {
// if lhs or rhs wasn't touched yet, recurse
if !values.contains_key(&lhs.borrow().id) {
Self::probe_random(lhs.clone(), values, ids, nodes, master_seed);
}
if !values.contains_key(&rhs.borrow().id) {
Self::probe_random(rhs.clone(), values, ids, nodes, master_seed);
}

// lookup values
let lhs_value = *values.get(&lhs.borrow().id).unwrap();
let rhs_value = *values.get(&rhs.borrow().id).unwrap();

// combine using appropriate operator
match operation {
BinOp::Add => lhs_value + rhs_value,
BinOp::Mul => lhs_value * rhs_value,
}
}
};

// value already exists; keep books
if let Some(peers) = ids.get_mut(&value) {
values.insert(circuit.borrow().id, value);
peers.push(circuit.borrow().id);
nodes.insert(circuit.borrow().id, circuit.clone());
}
// value is new; keep books
else {
values.insert(circuit.borrow().id, value);
ids.insert(value, vec![circuit.borrow().id]);
nodes.insert(circuit.borrow().id, circuit.clone());
}
}

/// Lowers the degree of a given multicircuit to the target degree.
/// This is achieved by introducing additional variables and constraints.
/// The appropriate substitutions are applied to the given multicircuit.
Expand Down Expand Up @@ -2137,18 +2256,31 @@ mod tests {
}

#[test]
#[ignore = "requires a proper debugging session, or maybe additional optimizations"]
fn constraint_circuit_builder_reuses_existing_nodes_when_folding_constants() {
fn equivalent_nodes_are_detected_when_present() {
let builder = ConstraintCircuitBuilder::new();
let constant = |c| builder.b_constant(c);
let base_row = |r| builder.input(BaseRow(r));

let c_0 = base_row(0);
let c_1 = base_row(0) - constant(0);
let mut constraints = [c_0, c_1];
let x = |i| builder.input(BaseRow(i));
let ch = |i: usize| builder.challenge(i);

ConstraintCircuitMonad::constant_folding(&mut constraints);
let mut constraints = constraints.iter().map(|c| c.consume()).collect_vec();
ConstraintCircuit::assert_unique_ids(&mut constraints);
let x0 = x(0);
let x1 = x(1);
let y0 = x(2);
let y1 = x(3);
let ch0 = ch(0);
let ch1 = ch(1);

let u0 = x0.clone() + x1.clone();
let u1 = y0.clone() + y1.clone();
let v = u0 * u1;

let z0 = x0.clone() * y0.clone();
let z2 = x1.clone() * y1.clone();

let z1 = x1.clone() * y0.clone() + x0.clone() * y1.clone();
let w = v - z0 - z2;
assert!(w.find_equivalent_nodes().is_empty());

let o = ch0 * z1 - ch1 * w;
assert!(!o.find_equivalent_nodes().is_empty());
}
}

0 comments on commit 17c5b61

Please sign in to comment.