Skip to content

Commit

Permalink
refactor!: remove BinOp::Sub
Browse files Browse the repository at this point in the history
BREAKING CHANGE: Substitute all subtractions of the form `a - b` by
`a + (-1·b)`.
  • Loading branch information
jan-ferdinand committed Mar 3, 2024
1 parent c3e9053 commit 675acc6
Show file tree
Hide file tree
Showing 3 changed files with 96 additions and 72 deletions.
2 changes: 0 additions & 2 deletions constraint-evaluation-generator/src/codegen/tasm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -318,10 +318,8 @@ impl TasmBackend {
}

fn tokenize_bin_op(binop: BinOp) -> TokenStream {
let minus_one = Self::load_base_field_constant(-BFieldElement::new(1));
match binop {
BinOp::Add => quote!(AnInstruction::XxAdd,),
BinOp::Sub => quote!(#minus_one AnInstruction::XxMul, AnInstruction::XxAdd,),
BinOp::Mul => quote!(AnInstruction::XxMul,),
}
}
Expand Down
9 changes: 7 additions & 2 deletions constraint-evaluation-generator/src/substitution.rs
Original file line number Diff line number Diff line change
Expand Up @@ -230,13 +230,18 @@ impl Substitutions {
fn substitution_rule_to_code<II: InputIndicator>(
circuit: ConstraintCircuit<II>,
) -> TokenStream {
let CircuitExpression::BinaryOperation(BinOp::Sub, new_var, expr) = circuit.expression
let CircuitExpression::BinaryOperation(BinOp::Add, new_var, expr) = circuit.expression
else {
panic!("Substitution rule must be a subtraction.");
panic!("Substitution rule must be a subtraction, i.e., addition of `x` and `-expr`.");
};
let CircuitExpression::Input(_) = new_var.borrow().expression else {
panic!("Substitution rule must be a simple substitution.");
};
let expr = expr.borrow();
let CircuitExpression::BinaryOperation(BinOp::Mul, neg_one, expr) = &expr.expression else {
panic!("Substitution rule must be a subtraction.");
};
assert!(neg_one.borrow().is_neg_one());

let expr = expr.borrow();
RustBackend::default().evaluate_single_node(&expr)
Expand Down
157 changes: 89 additions & 68 deletions triton-vm/src/table/constraint_circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,15 +34,13 @@ use crate::table::challenges::Challenges;
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
pub enum BinOp {
Add,
Sub,
Mul,
}

impl Display for BinOp {
fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
match self {
BinOp::Add => write!(f, "+"),
BinOp::Sub => write!(f, "-"),
BinOp::Mul => write!(f, "*"),
}
}
Expand All @@ -52,7 +50,6 @@ impl ToTokens for BinOp {
fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
match self {
BinOp::Add => tokens.extend(quote!(+)),
BinOp::Sub => tokens.extend(quote!(-)),
BinOp::Mul => tokens.extend(quote!(*)),
}
}
Expand All @@ -61,11 +58,10 @@ impl ToTokens for BinOp {
impl BinOp {
pub fn operation<L, R, O>(&self, lhs: L, rhs: R) -> O
where
L: Add<R, Output = O> + Sub<R, Output = O> + Mul<R, Output = O>,
L: Add<R, Output = O> + Mul<R, Output = O>,
{
match self {
BinOp::Add => lhs + rhs,
BinOp::Sub => lhs - rhs,
BinOp::Mul => lhs * rhs,
}
}
Expand Down Expand Up @@ -367,6 +363,14 @@ impl<II: InputIndicator> Display for ConstraintCircuit<II> {
}

impl<II: InputIndicator> ConstraintCircuit<II> {
fn new(id: usize, expression: CircuitExpression<II>) -> Self {
Self {
id,
ref_count: 0,
expression,
}
}

/// Reset the reference counters for the entire subtree
fn reset_ref_count_for_tree(&mut self) {
self.ref_count = 0;
Expand Down Expand Up @@ -429,7 +433,7 @@ impl<II: InputIndicator> ConstraintCircuit<II> {
false => degree_lhs + degree_rhs,
};
match binop {
BinOp::Add | BinOp::Sub => degree_additive,
BinOp::Add => degree_additive,
BinOp::Mul => degree_multiplicative,
}
}
Expand All @@ -456,8 +460,8 @@ impl<II: InputIndicator> ConstraintCircuit<II> {
!matches!(&self.expression, BinaryOperation(_, _, _))
}

/// Return true if this node represents a constant value of zero, does not catch composite
/// expressions that will always evaluate to zero.
/// Is the node the constant 0?
/// Does not catch composite expressions that will always evaluate to zero, like `0·a`.
pub fn is_zero(&self) -> bool {
match self.expression {
BConstant(bfe) => bfe.is_zero(),
Expand All @@ -466,8 +470,8 @@ impl<II: InputIndicator> ConstraintCircuit<II> {
}
}

/// Return true if this node represents a constant value of one, does not catch composite
/// expressions that will always evaluate to one.
/// Is the node the constant 1?
/// Does not catch composite expressions that will always evaluate to one, like `1·1`.
pub fn is_one(&self) -> bool {
match self.expression {
BConstant(bfe) => bfe.is_one(),
Expand All @@ -476,6 +480,14 @@ impl<II: InputIndicator> ConstraintCircuit<II> {
}
}

pub fn is_neg_one(&self) -> bool {
match self.expression {
BConstant(bfe) => (-bfe).is_one(),
XConstant(xfe) => (-xfe).is_one(),
_ => false,
}
}

/// Recursively check whether this node is composed of only BFieldElements, i.e., only uses
/// 1. inputs from base rows,
/// 2. constants from the B-field, and
Expand Down Expand Up @@ -587,46 +599,34 @@ fn binop<II: InputIndicator>(
lhs: ConstraintCircuitMonad<II>,
rhs: ConstraintCircuitMonad<II>,
) -> ConstraintCircuitMonad<II> {
let id = lhs.builder.id_counter.borrow().to_owned();
let expression = BinaryOperation(binop, lhs.circuit.clone(), rhs.circuit.clone());
let circuit = ConstraintCircuit {
id,
ref_count: 0,
expression,
};
let circuit = Rc::new(RefCell::new(circuit));
let new_node = lhs.new_monad_same_context(circuit);

let mut all_nodes = lhs.builder.all_nodes.borrow_mut();
if let Some(same_node) = all_nodes.get(&new_node) {
return same_node.to_owned();
}

// If the operator commutes, check if the switched node has already been constructed.
// If it has, return it instead. Do not allow a new one to be built.
if matches!(binop, BinOp::Add | BinOp::Mul) {
let expression_switched = BinaryOperation(binop, rhs.circuit, lhs.circuit);
let circuit_switched = ConstraintCircuit {
id,
ref_count: 0,
expression: expression_switched,
};
let circuit_switched = Rc::new(RefCell::new(circuit_switched));
let new_node_switched = ConstraintCircuitMonad {
circuit: circuit_switched,
builder: lhs.builder.clone(),
};
if let Some(same_node) = all_nodes.get(&new_node_switched) {
return same_node.to_owned();
}
// all `BinOp`s are commutative – try both orders of the operands
let new_node = binop_new_node(binop, &rhs, &lhs);
if let Some(node) = lhs.builder.all_nodes.borrow().get(&new_node) {
return node.to_owned();
}

let new_node = binop_new_node(binop, &lhs, &rhs);
if let Some(node) = lhs.builder.all_nodes.borrow().get(&new_node) {
return node.to_owned();
}

*lhs.builder.id_counter.borrow_mut() += 1;
let was_inserted = all_nodes.insert(new_node.clone());
let was_inserted = lhs.builder.all_nodes.borrow_mut().insert(new_node.clone());
assert!(was_inserted, "Binop-created value must be new");
new_node
}

fn binop_new_node<II: InputIndicator>(
binop: BinOp,
lhs: &ConstraintCircuitMonad<II>,
rhs: &ConstraintCircuitMonad<II>,
) -> ConstraintCircuitMonad<II> {
let id = lhs.builder.id_counter.borrow().to_owned();
let expression = BinaryOperation(binop, lhs.circuit.clone(), rhs.circuit.clone());
let circuit = ConstraintCircuit::new(id, expression);
lhs.builder.new_monad(circuit)
}

impl<II: InputIndicator> Add for ConstraintCircuitMonad<II> {
type Output = ConstraintCircuitMonad<II>;

Expand All @@ -639,7 +639,7 @@ impl<II: InputIndicator> Sub for ConstraintCircuitMonad<II> {
type Output = ConstraintCircuitMonad<II>;

fn sub(self, rhs: Self) -> Self::Output {
binop(BinOp::Sub, self, rhs)
binop(BinOp::Add, self, -rhs)
}
}

Expand All @@ -651,6 +651,14 @@ impl<II: InputIndicator> Mul for ConstraintCircuitMonad<II> {
}
}

impl<II: InputIndicator> Neg for ConstraintCircuitMonad<II> {
type Output = ConstraintCircuitMonad<II>;

fn neg(self) -> Self::Output {
binop(BinOp::Mul, self.builder.minus_one(), self)
}
}

/// This will panic if the iterator is empty because the neutral element needs a unique ID, and
/// we have no way of getting that here.
impl<II: InputIndicator> Sum for ConstraintCircuitMonad<II> {
Expand Down Expand Up @@ -692,7 +700,7 @@ impl<II: InputIndicator> ConstraintCircuitMonad<II> {
};

match (op, lhs, rhs) {
(BinOp::Add | BinOp::Sub, l, r) if r.borrow().is_zero() => return Some(l.clone()),
(BinOp::Add, l, r) if r.borrow().is_zero() => return Some(l.clone()),
(BinOp::Add, l, r) if l.borrow().is_zero() => return Some(r.clone()),
(BinOp::Mul, l, r) if r.borrow().is_one() => return Some(l.clone()),
(BinOp::Mul, l, r) if l.borrow().is_one() => return Some(r.clone()),
Expand Down Expand Up @@ -960,6 +968,14 @@ impl<II: InputIndicator> ConstraintCircuitBuilder<II> {
}
}

fn new_monad(&self, circuit: ConstraintCircuit<II>) -> ConstraintCircuitMonad<II> {
let circuit = Rc::new(RefCell::new(circuit));
ConstraintCircuitMonad {
circuit,
builder: self.clone(),
}
}

pub fn get_node_by_id(&self, id: usize) -> Option<ConstraintCircuitMonad<II>> {
self.all_nodes
.borrow()
Expand All @@ -968,6 +984,21 @@ impl<II: InputIndicator> ConstraintCircuitBuilder<II> {
.cloned()
}

/// The unique monad representing the constant value 0.
pub fn zero(&self) -> ConstraintCircuitMonad<II> {
self.b_constant(BFieldElement::zero())
}

/// The unique monad representing the constant value 1.
pub fn one(&self) -> ConstraintCircuitMonad<II> {
self.b_constant(BFieldElement::one())
}

/// The unique monad representing the constant value -1.
pub fn minus_one(&self) -> ConstraintCircuitMonad<II> {
self.b_constant(BFieldElement::one().neg())
}

/// Create constant leaf node.
pub fn x_constant(&self, xfe: XFieldElement) -> ConstraintCircuitMonad<II> {
self.make_leaf(XConstant(xfe))
Expand Down Expand Up @@ -997,25 +1028,16 @@ impl<II: InputIndicator> ConstraintCircuitBuilder<II> {
}

let id = self.id_counter.borrow().to_owned();
let circuit = ConstraintCircuit {
id,
ref_count: 0,
expression,
};
let circuit = Rc::new(RefCell::new(circuit));
let new_node = ConstraintCircuitMonad {
circuit,
builder: self.clone(),
};
let circuit = ConstraintCircuit::new(id, expression);
let new_node = self.new_monad(circuit);

let mut all_nodes = self.all_nodes.borrow_mut();
if let Some(same_node) = all_nodes.get(&new_node) {
same_node.to_owned()
} else {
*self.id_counter.borrow_mut() += 1;
all_nodes.insert(new_node.clone());
new_node
if let Some(same_node) = self.all_nodes.borrow().get(&new_node) {
return same_node.to_owned();
}

*self.id_counter.borrow_mut() += 1;
self.all_nodes.borrow_mut().insert(new_node.clone());
new_node
}

/// Substitute all nodes with ID `old_id` with the given `new` node.
Expand Down Expand Up @@ -1062,7 +1084,6 @@ mod tests {
use crate::table::hash_table::ExtHashTable;
use crate::table::jump_stack_table::ExtJumpStackTable;
use crate::table::lookup_table::ExtLookupTable;
use crate::table::master_table;
use crate::table::master_table::*;
use crate::table::op_stack_table::ExtOpStackTable;
use crate::table::processor_table::ExtProcessorTable;
Expand Down Expand Up @@ -1418,8 +1439,8 @@ mod tests {
let challenges = Challenges::new(challenges, &dummy_claim);

let num_rows = 2;
let base_shape = [num_rows, master_table::NUM_BASE_COLUMNS];
let ext_shape = [num_rows, master_table::NUM_EXT_COLUMNS];
let base_shape = [num_rows, NUM_BASE_COLUMNS];
let ext_shape = [num_rows, NUM_EXT_COLUMNS];
let base_rows = Array2::from_shape_simple_fn(base_shape, || rng.gen::<BFieldElement>());
let ext_rows = Array2::from_shape_simple_fn(ext_shape, || rng.gen::<XFieldElement>());
let base_rows = base_rows.view();
Expand Down Expand Up @@ -2046,7 +2067,7 @@ mod tests {
] {
for (i, constraint) in constraints.iter().enumerate() {
let expression = constraint.circuit.borrow().expression.clone();
let BinaryOperation(BinOp::Sub, lhs, rhs) = expression else {
let BinaryOperation(BinOp::Add, lhs, rhs) = expression else {
panic!("New {constraint_type} constraint {i} must be a subtraction.");
};
let Input(input_indicator) = lhs.borrow().expression.clone() else {
Expand All @@ -2067,8 +2088,8 @@ mod tests {
let num_rows = 2;
let num_new_base_constraints = new_base_constraints.len();
let num_new_ext_constraints = new_ext_constraints.len();
let num_base_cols = master_table::NUM_BASE_COLUMNS + num_new_base_constraints;
let num_ext_cols = master_table::NUM_EXT_COLUMNS + num_new_ext_constraints;
let num_base_cols = NUM_BASE_COLUMNS + num_new_base_constraints;
let num_ext_cols = NUM_EXT_COLUMNS + num_new_ext_constraints;
let base_shape = [num_rows, num_base_cols];
let ext_shape = [num_rows, num_ext_cols];
let base_rows = Array2::from_shape_simple_fn(base_shape, || rng.gen::<BFieldElement>());
Expand Down

0 comments on commit 675acc6

Please sign in to comment.