From ffad9ebd40d85a0fb07dffaa8b68a18901c9b425 Mon Sep 17 00:00:00 2001 From: Jan Ferdinand Sauer Date: Mon, 7 Aug 2023 00:32:40 +0200 Subject: [PATCH] slightly simplify constraint code generation --- constraint-evaluation-generator/src/main.rs | 644 ++++++++++---------- 1 file changed, 317 insertions(+), 327 deletions(-) diff --git a/constraint-evaluation-generator/src/main.rs b/constraint-evaluation-generator/src/main.rs index 8fe681423..880baabed 100644 --- a/constraint-evaluation-generator/src/main.rs +++ b/constraint-evaluation-generator/src/main.rs @@ -1,5 +1,4 @@ use std::collections::HashSet; -use std::process::Command; use itertools::Itertools; use proc_macro2::TokenStream; @@ -30,9 +29,58 @@ use triton_vm::table::program_table::ExtProgramTable; use triton_vm::table::ram_table::ExtRamTable; use triton_vm::table::u32_table::ExtU32Table; +struct AllSubstitutions { + base: Substitutions, + ext: Substitutions, +} + +struct Substitutions { + init: Vec>, + cons: Vec>, + tran: Vec>, + term: Vec>, +} + +impl Substitutions { + fn len(&self) -> usize { + self.init.len() + self.cons.len() + self.tran.len() + self.term.len() + } +} + +struct Constraints { + init: Vec>, + cons: Vec>, + tran: Vec>, + term: Vec>, +} + fn main() { + let mut constraints = all_constraints(); + let substitutions = lower_to_target_degree_through_substitutions(&mut constraints); + let degree_lowering_table_code = generate_degree_lowering_table_code(&substitutions); + + let constraints = + combine_existing_and_substitution_induced_constraints(constraints, substitutions); + let constraint_code = generate_constraint_code(constraints); + + write_code_to_file(degree_lowering_table_code, "degree_lowering_table"); + write_code_to_file(constraint_code, "constraints"); +} + +fn all_constraints() -> Constraints { + let mut constraints = Constraints { + init: all_initial_constraints(), + cons: all_consistency_constraints(), + tran: all_transition_constraints(), + term: all_terminal_constraints(), + }; + constant_fold_all_constraints(&mut constraints); + constraints +} + +fn all_initial_constraints() -> Vec> { let circuit_builder = ConstraintCircuitBuilder::new(); - let mut initial_constraints = vec![ + vec![ ExtProgramTable::initial_constraints(&circuit_builder), ExtProcessorTable::initial_constraints(&circuit_builder), ExtOpStackTable::initial_constraints(&circuit_builder), @@ -44,10 +92,12 @@ fn main() { ExtU32Table::initial_constraints(&circuit_builder), GrandCrossTableArg::initial_constraints(&circuit_builder), ] - .concat(); + .concat() +} +fn all_consistency_constraints() -> Vec> { let circuit_builder = ConstraintCircuitBuilder::new(); - let mut consistency_constraints = vec![ + vec![ ExtProgramTable::consistency_constraints(&circuit_builder), ExtProcessorTable::consistency_constraints(&circuit_builder), ExtOpStackTable::consistency_constraints(&circuit_builder), @@ -59,10 +109,12 @@ fn main() { ExtU32Table::consistency_constraints(&circuit_builder), GrandCrossTableArg::consistency_constraints(&circuit_builder), ] - .concat(); + .concat() +} +fn all_transition_constraints() -> Vec> { let circuit_builder = ConstraintCircuitBuilder::new(); - let mut transition_constraints = vec![ + vec![ ExtProgramTable::transition_constraints(&circuit_builder), ExtProcessorTable::transition_constraints(&circuit_builder), ExtOpStackTable::transition_constraints(&circuit_builder), @@ -74,10 +126,12 @@ fn main() { ExtU32Table::transition_constraints(&circuit_builder), GrandCrossTableArg::transition_constraints(&circuit_builder), ] - .concat(); + .concat() +} +fn all_terminal_constraints() -> Vec> { let circuit_builder = ConstraintCircuitBuilder::new(); - let mut terminal_constraints = vec![ + vec![ ExtProgramTable::terminal_constraints(&circuit_builder), ExtProcessorTable::terminal_constraints(&circuit_builder), ExtOpStackTable::terminal_constraints(&circuit_builder), @@ -89,19 +143,25 @@ fn main() { ExtU32Table::terminal_constraints(&circuit_builder), GrandCrossTableArg::terminal_constraints(&circuit_builder), ] - .concat(); + .concat() +} - ConstraintCircuitMonad::constant_folding(&mut initial_constraints); - ConstraintCircuitMonad::constant_folding(&mut consistency_constraints); - ConstraintCircuitMonad::constant_folding(&mut transition_constraints); - ConstraintCircuitMonad::constant_folding(&mut terminal_constraints); +fn constant_fold_all_constraints(constraints: &mut Constraints) { + ConstraintCircuitMonad::constant_folding(&mut constraints.init); + ConstraintCircuitMonad::constant_folding(&mut constraints.cons); + ConstraintCircuitMonad::constant_folding(&mut constraints.tran); + ConstraintCircuitMonad::constant_folding(&mut constraints.term); +} +fn lower_to_target_degree_through_substitutions( + all_constraints: &mut Constraints, +) -> AllSubstitutions { // Subtract the degree lowering table's width from the total number of columns to guarantee // the same number of columns even for repeated runs of the constraint evaluation generator. let mut num_base_cols = master_table::NUM_BASE_COLUMNS - degree_lowering_table::BASE_WIDTH; let mut num_ext_cols = master_table::NUM_EXT_COLUMNS - degree_lowering_table::EXT_WIDTH; let (init_base_substitutions, init_ext_substitutions) = ConstraintCircuitMonad::lower_to_degree( - &mut initial_constraints, + &mut all_constraints.init, master_table::AIR_TARGET_DEGREE, num_base_cols, num_ext_cols, @@ -110,7 +170,7 @@ fn main() { num_ext_cols += init_ext_substitutions.len(); let (cons_base_substitutions, cons_ext_substitutions) = ConstraintCircuitMonad::lower_to_degree( - &mut consistency_constraints, + &mut all_constraints.cons, master_table::AIR_TARGET_DEGREE, num_base_cols, num_ext_cols, @@ -119,7 +179,7 @@ fn main() { num_ext_cols += cons_ext_substitutions.len(); let (tran_base_substitutions, tran_ext_substitutions) = ConstraintCircuitMonad::lower_to_degree( - &mut transition_constraints, + &mut all_constraints.tran, master_table::AIR_TARGET_DEGREE, num_base_cols, num_ext_cols, @@ -128,111 +188,79 @@ fn main() { num_ext_cols += tran_ext_substitutions.len(); let (term_base_substitutions, term_ext_substitutions) = ConstraintCircuitMonad::lower_to_degree( - &mut terminal_constraints, + &mut all_constraints.term, master_table::AIR_TARGET_DEGREE, num_base_cols, num_ext_cols, ); - let table_code = generate_degree_lowering_table_code( - &init_base_substitutions, - &cons_base_substitutions, - &tran_base_substitutions, - &term_base_substitutions, - &init_ext_substitutions, - &cons_ext_substitutions, - &tran_ext_substitutions, - &term_ext_substitutions, - ); - - let initial_constraints = vec![ - initial_constraints, - init_base_substitutions, - init_ext_substitutions, - ] - .concat(); - let consistency_constraints = vec![ - consistency_constraints, - cons_base_substitutions, - cons_ext_substitutions, - ] - .concat(); - let transition_constraints = vec![ - transition_constraints, - tran_base_substitutions, - tran_ext_substitutions, - ] - .concat(); - let terminal_constraints = vec![ - terminal_constraints, - term_base_substitutions, - term_ext_substitutions, - ] - .concat(); - - let mut initial_constraints = consume(initial_constraints); - let mut consistency_constraints = consume(consistency_constraints); - let mut transition_constraints = consume(transition_constraints); - let mut terminal_constraints = consume(terminal_constraints); - - let constraint_code = generate_constraint_code( - &mut initial_constraints, - &mut consistency_constraints, - &mut transition_constraints, - &mut terminal_constraints, - ); - - let table_syntax_tree = syn::parse2(table_code).unwrap(); - let table_code = prettyplease::unparse(&table_syntax_tree); - match std::fs::write("triton-vm/src/table/degree_lowering_table.rs", table_code) { - Ok(_) => (), - Err(err) => panic!("Writing to disk has failed: {err}"), - } - - let constraint_syntax_tree = syn::parse2(constraint_code).unwrap(); - let constraint_code = prettyplease::unparse(&constraint_syntax_tree); - match std::fs::write("triton-vm/src/table/constraints.rs", constraint_code) { - Ok(_) => (), - Err(err) => panic!("Writing to disk has failed: {err}"), + AllSubstitutions { + base: Substitutions { + init: init_base_substitutions, + cons: cons_base_substitutions, + tran: tran_base_substitutions, + term: term_base_substitutions, + }, + ext: Substitutions { + init: init_ext_substitutions, + cons: cons_ext_substitutions, + tran: tran_ext_substitutions, + term: term_ext_substitutions, + }, } +} - match Command::new("cargo") - .arg("clippy") - .arg("--workspace") - .arg("--all-targets") - .output() - { - Ok(_) => (), - Err(err) => panic!("cargo clippy failed: {err}"), +fn combine_existing_and_substitution_induced_constraints( + constraints: Constraints, + substitutions: AllSubstitutions, +) -> Constraints { + let init = [ + constraints.init, + substitutions.base.init, + substitutions.ext.init, + ]; + let cons = [ + constraints.cons, + substitutions.base.cons, + substitutions.ext.cons, + ]; + let tran = [ + constraints.tran, + substitutions.base.tran, + substitutions.ext.tran, + ]; + let term = [ + constraints.term, + substitutions.base.term, + substitutions.ext.term, + ]; + Constraints { + init: init.concat(), + cons: cons.concat(), + tran: tran.concat(), + term: term.concat(), } } -/// Consumes every `ConstraintCircuitMonad`, returning their corresponding `ConstraintCircuit`s. -fn consume( - constraints: Vec>, -) -> Vec> { - constraints.into_iter().map(|c| c.consume()).collect() -} +fn generate_constraint_code(constraints: Constraints) -> TokenStream { + let num_init_constraints = constraints.init.len(); + let num_cons_constraints = constraints.cons.len(); + let num_tran_constraints = constraints.tran.len(); + let num_term_constraints = constraints.term.len(); -fn generate_constraint_code( - init_constraint_circuits: &mut [ConstraintCircuit], - cons_constraint_circuits: &mut [ConstraintCircuit], - tran_constraint_circuits: &mut [ConstraintCircuit], - term_constraint_circuits: &mut [ConstraintCircuit], -) -> TokenStream { - let num_init_constraints = init_constraint_circuits.len(); - let num_cons_constraints = cons_constraint_circuits.len(); - let num_tran_constraints = tran_constraint_circuits.len(); - let num_term_constraints = term_constraint_circuits.len(); + let mut init_constraint_circuits = consume(constraints.init); + let mut cons_constraint_circuits = consume(constraints.cons); + let mut tran_constraint_circuits = consume(constraints.tran); + let mut term_constraint_circuits = consume(constraints.term); let (init_constraint_degrees, init_constraints_bfe, init_constraints_xfe) = - tokenize_circuits(init_constraint_circuits); + tokenize_circuits(&mut init_constraint_circuits); let (cons_constraint_degrees, cons_constraints_bfe, cons_constraints_xfe) = - tokenize_circuits(cons_constraint_circuits); + tokenize_circuits(&mut cons_constraint_circuits); let (tran_constraint_degrees, tran_constraints_bfe, tran_constraints_xfe) = - tokenize_circuits(tran_constraint_circuits); + tokenize_circuits(&mut tran_constraint_circuits); let (term_constraint_degrees, term_constraints_bfe, term_constraints_xfe) = - tokenize_circuits(term_constraint_circuits); + tokenize_circuits(&mut term_constraint_circuits); quote!( use ndarray::ArrayView1; @@ -391,6 +419,13 @@ fn generate_constraint_code( ) } +/// Consumes every `ConstraintCircuitMonad`, returning their corresponding `ConstraintCircuit`s. +fn consume( + constraints: Vec>, +) -> Vec> { + constraints.into_iter().map(|c| c.consume()).collect() +} + /// Given a slice of constraint circuits, return a tuple of [`TokenStream`]s corresponding to code /// evaluating these constraints as well as their degrees. In particular: /// 1. The first stream contains code that, when evaluated, produces the constraints' degrees, @@ -617,41 +652,11 @@ fn evaluate_single_node( quote!((#evaluated_lhs) #binop (#evaluated_rhs)) } -/// Given a substitution rule, i.e., a `ConstraintCircuit` of the form `x - expr`, generate code -/// that evaluates `expr`. -fn substitution_rule_to_code(circuit: ConstraintCircuit) -> TokenStream { - let BinaryOperation(BinOp::Sub, new_var, expr) = circuit.expression else { - panic!("Substitution rule must be a subtraction."); - }; - let Input(_) = new_var.as_ref().borrow().expression else { - panic!("Substitution rule must be a simple substitution."); - }; - - let expr = expr.as_ref().borrow().to_owned(); - evaluate_single_node(usize::MAX, &expr, &HashSet::new()) -} - /// Given all substitution rules, generate the code that evaluates them in order. /// This includes generating the columns that are to be filled using the substitution rules. -#[allow(clippy::too_many_arguments)] -fn generate_degree_lowering_table_code( - init_base_substitutions: &[ConstraintCircuitMonad], - cons_base_substitutions: &[ConstraintCircuitMonad], - tran_base_substitutions: &[ConstraintCircuitMonad], - term_base_substitutions: &[ConstraintCircuitMonad], - init_ext_substitutions: &[ConstraintCircuitMonad], - cons_ext_substitutions: &[ConstraintCircuitMonad], - tran_ext_substitutions: &[ConstraintCircuitMonad], - term_ext_substitutions: &[ConstraintCircuitMonad], -) -> TokenStream { - let num_new_base_cols = init_base_substitutions.len() - + cons_base_substitutions.len() - + tran_base_substitutions.len() - + term_base_substitutions.len(); - let num_new_ext_cols = init_ext_substitutions.len() - + cons_ext_substitutions.len() - + tran_ext_substitutions.len() - + term_ext_substitutions.len(); +fn generate_degree_lowering_table_code(substitutions: &AllSubstitutions) -> TokenStream { + let num_new_base_cols = substitutions.base.len(); + let num_new_ext_cols = substitutions.ext.len(); // A zero-variant enum cannot be annotated with `repr(usize)`. let base_repr_usize = match num_new_base_cols == 0 { @@ -678,18 +683,8 @@ fn generate_degree_lowering_table_code( .map(|ident| quote!(#ident)) .collect_vec(); - let fill_base_columns_code = generate_fill_base_columns_code( - init_base_substitutions, - cons_base_substitutions, - tran_base_substitutions, - term_base_substitutions, - ); - let fill_ext_columns_code = generate_fill_ext_columns_code( - init_ext_substitutions, - cons_ext_substitutions, - tran_ext_substitutions, - term_ext_substitutions, - ); + let fill_base_columns_code = generate_fill_base_columns_code(&substitutions.base); + let fill_ext_columns_code = generate_fill_ext_columns_code(&substitutions.ext); quote!( //! The degree lowering table contains the introduced variables that allow @@ -742,96 +737,35 @@ fn generate_degree_lowering_table_code( ) } -fn generate_fill_base_columns_code( - init_substitutions: &[ConstraintCircuitMonad], - cons_substitutions: &[ConstraintCircuitMonad], - tran_substitutions: &[ConstraintCircuitMonad], - term_substitutions: &[ConstraintCircuitMonad], -) -> TokenStream { - let derived_section_start = master_table::NUM_BASE_COLUMNS - degree_lowering_table::BASE_WIDTH; - - let num_init_substitutions = init_substitutions.len(); - let num_cons_substitutions = cons_substitutions.len(); - let num_tran_substitutions = tran_substitutions.len(); - let num_term_substitutions = term_substitutions.len(); - - let init_col_indices = (0..num_init_substitutions) - .map(|i| i + derived_section_start) - .collect_vec(); - let cons_col_indices = (0..num_cons_substitutions) - .map(|i| i + derived_section_start + num_init_substitutions) - .collect_vec(); - let tran_col_indices = (0..num_tran_substitutions) - .map(|i| i + derived_section_start + num_init_substitutions + num_cons_substitutions) - .collect_vec(); - let term_col_indices = (0..num_term_substitutions) - .map(|i| { - i + derived_section_start - + num_init_substitutions - + num_cons_substitutions - + num_tran_substitutions - }) - .collect_vec(); - - let init_substitutions = init_substitutions - .iter() - .map(|c| substitution_rule_to_code(c.circuit.as_ref().borrow().to_owned())) - .collect_vec(); - let cons_substitutions = cons_substitutions - .iter() - .map(|c| substitution_rule_to_code(c.circuit.as_ref().borrow().to_owned())) - .collect_vec(); - let tran_substitutions = tran_substitutions - .iter() - .map(|c| substitution_rule_to_code(c.circuit.as_ref().borrow().to_owned())) - .collect_vec(); - let term_substitutions = term_substitutions - .iter() - .map(|c| substitution_rule_to_code(c.circuit.as_ref().borrow().to_owned())) - .collect_vec(); - - let single_row_substitutions = |indices: Vec, substitutions: Vec| { - assert_eq!(indices.len(), substitutions.len()); - if indices.is_empty() { - return quote!(); - } - quote!( - master_base_table.rows_mut().into_iter().for_each(|mut row| { - #( - let (base_row, mut det_col) = - row.multi_slice_mut((s![..#indices],s![#indices..#indices + 1])); - det_col[0] = #substitutions; - )* - }); - ) - }; - let dual_row_substitutions = |indices: Vec, substitutions: Vec| { - assert_eq!(indices.len(), substitutions.len()); - if indices.is_empty() { - return quote!(); - } - quote!( - for curr_row_idx in 0..master_base_table.nrows() - 1 { - let next_row_idx = curr_row_idx + 1; - let (mut curr_base_row, next_base_row) = master_base_table.multi_slice_mut(( - s![curr_row_idx..curr_row_idx + 1, ..], - s![next_row_idx..next_row_idx + 1, ..], - )); - let mut curr_base_row = curr_base_row.row_mut(0); - let next_base_row = next_base_row.row(0); - #( - let (current_base_row, mut det_col) = - curr_base_row.multi_slice_mut((s![..#indices], s![#indices..#indices + 1])); - det_col[0] = #substitutions; - )* - } - ) - }; - - let init_substitutions = single_row_substitutions(init_col_indices, init_substitutions); - let cons_substitutions = single_row_substitutions(cons_col_indices, cons_substitutions); - let tran_substitutions = dual_row_substitutions(tran_col_indices, tran_substitutions); - let term_substitutions = single_row_substitutions(term_col_indices, term_substitutions); +fn generate_fill_base_columns_code(substitutions: &Substitutions) -> TokenStream { + let derived_section_init_start = + master_table::NUM_BASE_COLUMNS - degree_lowering_table::BASE_WIDTH; + let derived_section_cons_start = derived_section_init_start + substitutions.init.len(); + let derived_section_tran_start = derived_section_cons_start + substitutions.cons.len(); + let derived_section_term_start = derived_section_tran_start + substitutions.tran.len(); + + let init_col_indices = (0..substitutions.init.len()) + .map(|i| i + derived_section_init_start) + .collect(); + let cons_col_indices = (0..substitutions.cons.len()) + .map(|i| i + derived_section_cons_start) + .collect(); + let tran_col_indices = (0..substitutions.tran.len()) + .map(|i| i + derived_section_tran_start) + .collect(); + let term_col_indices = (0..substitutions.term.len()) + .map(|i| i + derived_section_term_start) + .collect(); + + let init_substitutions = several_substitution_rules_to_code(&substitutions.init); + let cons_substitutions = several_substitution_rules_to_code(&substitutions.cons); + let tran_substitutions = several_substitution_rules_to_code(&substitutions.tran); + let term_substitutions = several_substitution_rules_to_code(&substitutions.term); + + let init_substitutions = base_single_row_substitutions(init_col_indices, init_substitutions); + let cons_substitutions = base_single_row_substitutions(cons_col_indices, cons_substitutions); + let tran_substitutions = base_dual_row_substitutions(tran_col_indices, tran_substitutions); + let term_substitutions = base_single_row_substitutions(term_col_indices, term_substitutions); quote!( #[allow(unused_variables)] @@ -845,100 +779,35 @@ fn generate_fill_base_columns_code( ) } -fn generate_fill_ext_columns_code( - init_substitutions: &[ConstraintCircuitMonad], - cons_substitutions: &[ConstraintCircuitMonad], - tran_substitutions: &[ConstraintCircuitMonad], - term_substitutions: &[ConstraintCircuitMonad], -) -> TokenStream { - let derived_section_start = master_table::NUM_EXT_COLUMNS - degree_lowering_table::EXT_WIDTH; +fn generate_fill_ext_columns_code(substitutions: &Substitutions) -> TokenStream { + let derived_section_init_start = + master_table::NUM_EXT_COLUMNS - degree_lowering_table::EXT_WIDTH; + let derived_section_cons_start = derived_section_init_start + substitutions.init.len(); + let derived_section_tran_start = derived_section_cons_start + substitutions.cons.len(); + let derived_section_term_start = derived_section_tran_start + substitutions.tran.len(); - let num_init_substitutions = init_substitutions.len(); - let num_cons_substitutions = cons_substitutions.len(); - let num_tran_substitutions = tran_substitutions.len(); - let num_term_substitutions = term_substitutions.len(); - - let init_col_indices = (0..num_init_substitutions) - .map(|i| i + derived_section_start) + let init_col_indices = (0..substitutions.init.len()) + .map(|i| i + derived_section_init_start) .collect_vec(); - let cons_col_indices = (0..num_cons_substitutions) - .map(|i| i + derived_section_start + num_init_substitutions) + let cons_col_indices = (0..substitutions.cons.len()) + .map(|i| i + derived_section_cons_start) .collect_vec(); - let tran_col_indices = (0..num_tran_substitutions) - .map(|i| i + derived_section_start + num_init_substitutions + num_cons_substitutions) + let tran_col_indices = (0..substitutions.tran.len()) + .map(|i| i + derived_section_tran_start) .collect_vec(); - let term_col_indices = (0..num_term_substitutions) - .map(|i| { - i + derived_section_start - + num_init_substitutions - + num_cons_substitutions - + num_tran_substitutions - }) + let term_col_indices = (0..substitutions.term.len()) + .map(|i| i + derived_section_term_start) .collect_vec(); - let init_substitutions = init_substitutions - .iter() - .map(|c| substitution_rule_to_code(c.circuit.as_ref().borrow().to_owned())) - .collect_vec(); - let cons_substitutions = cons_substitutions - .iter() - .map(|c| substitution_rule_to_code(c.circuit.as_ref().borrow().to_owned())) - .collect_vec(); - let tran_substitutions = tran_substitutions - .iter() - .map(|c| substitution_rule_to_code(c.circuit.as_ref().borrow().to_owned())) - .collect_vec(); - let term_substitutions = term_substitutions - .iter() - .map(|c| substitution_rule_to_code(c.circuit.as_ref().borrow().to_owned())) - .collect_vec(); + let init_substitutions = several_substitution_rules_to_code(&substitutions.init); + let cons_substitutions = several_substitution_rules_to_code(&substitutions.cons); + let tran_substitutions = several_substitution_rules_to_code(&substitutions.tran); + let term_substitutions = several_substitution_rules_to_code(&substitutions.term); - let single_row_substitutions = |indices: Vec, substitutions: Vec| { - assert_eq!(indices.len(), substitutions.len()); - if indices.is_empty() { - return quote!(); - } - quote!( - for row_idx in 0..master_base_table.nrows() - 1 { - let base_row = master_base_table.row(row_idx); - let mut extension_row = master_ext_table.row_mut(row_idx); - #( - let (ext_row, mut det_col) = - extension_row.multi_slice_mut((s![..#indices],s![#indices..#indices + 1])); - det_col[0] = #substitutions; - )* - } - ) - }; - let dual_row_substitutions = |indices: Vec, substitutions: Vec| { - assert_eq!(indices.len(), substitutions.len()); - if indices.is_empty() { - return quote!(); - } - quote!( - for curr_row_idx in 0..master_base_table.nrows() - 1 { - let next_row_idx = curr_row_idx + 1; - let current_base_row = master_base_table.row(curr_row_idx); - let next_base_row = master_base_table.row(next_row_idx); - let (mut curr_ext_row, next_ext_row) = master_ext_table.multi_slice_mut(( - s![curr_row_idx..curr_row_idx + 1, ..], - s![next_row_idx..next_row_idx + 1, ..], - )); - let mut curr_ext_row = curr_ext_row.row_mut(0); - let next_ext_row = next_ext_row.row(0); - #( - let (current_ext_row, mut det_col) = - curr_ext_row.multi_slice_mut((s![..#indices], s![#indices..#indices + 1])); - det_col[0] = #substitutions; - )* - } - ) - }; - - let init_substitutions = single_row_substitutions(init_col_indices, init_substitutions); - let cons_substitutions = single_row_substitutions(cons_col_indices, cons_substitutions); - let tran_substitutions = dual_row_substitutions(tran_col_indices, tran_substitutions); - let term_substitutions = single_row_substitutions(term_col_indices, term_substitutions); + let init_substitutions = ext_single_row_substitutions(init_col_indices, init_substitutions); + let cons_substitutions = ext_single_row_substitutions(cons_col_indices, cons_substitutions); + let tran_substitutions = ext_dual_row_substitutions(tran_col_indices, tran_substitutions); + let term_substitutions = ext_single_row_substitutions(term_col_indices, term_substitutions); quote!( #[allow(unused_variables)] @@ -957,3 +826,124 @@ fn generate_fill_ext_columns_code( } ) } + +fn several_substitution_rules_to_code( + substitution_rules: &[ConstraintCircuitMonad], +) -> Vec { + substitution_rules + .iter() + .map(|c| substitution_rule_to_code(c.circuit.as_ref().borrow().to_owned())) + .collect() +} + +/// Given a substitution rule, i.e., a `ConstraintCircuit` of the form `x - expr`, generate code +/// that evaluates `expr`. +fn substitution_rule_to_code(circuit: ConstraintCircuit) -> TokenStream { + let BinaryOperation(BinOp::Sub, new_var, expr) = circuit.expression else { + panic!("Substitution rule must be a subtraction."); + }; + let Input(_) = new_var.as_ref().borrow().expression else { + panic!("Substitution rule must be a simple substitution."); + }; + + let expr = expr.as_ref().borrow().to_owned(); + evaluate_single_node(usize::MAX, &expr, &HashSet::new()) +} + +fn base_single_row_substitutions( + indices: Vec, + substitutions: Vec, +) -> TokenStream { + assert_eq!(indices.len(), substitutions.len()); + if indices.is_empty() { + return quote!(); + } + quote!( + master_base_table.rows_mut().into_iter().for_each(|mut row| { + #( + let (base_row, mut det_col) = + row.multi_slice_mut((s![..#indices],s![#indices..#indices + 1])); + det_col[0] = #substitutions; + )* + }); + ) +} + +fn base_dual_row_substitutions( + indices: Vec, + substitutions: Vec, +) -> TokenStream { + assert_eq!(indices.len(), substitutions.len()); + if indices.is_empty() { + return quote!(); + } + quote!( + for curr_row_idx in 0..master_base_table.nrows() - 1 { + let next_row_idx = curr_row_idx + 1; + let (mut curr_base_row, next_base_row) = master_base_table.multi_slice_mut(( + s![curr_row_idx..curr_row_idx + 1, ..], + s![next_row_idx..next_row_idx + 1, ..], + )); + let mut curr_base_row = curr_base_row.row_mut(0); + let next_base_row = next_base_row.row(0); + #( + let (current_base_row, mut det_col) = + curr_base_row.multi_slice_mut((s![..#indices], s![#indices..#indices + 1])); + det_col[0] = #substitutions; + )* + } + ) +} + +fn ext_single_row_substitutions( + indices: Vec, + substitutions: Vec, +) -> TokenStream { + assert_eq!(indices.len(), substitutions.len()); + if indices.is_empty() { + return quote!(); + } + quote!( + for row_idx in 0..master_base_table.nrows() - 1 { + let base_row = master_base_table.row(row_idx); + let mut extension_row = master_ext_table.row_mut(row_idx); + #( + let (ext_row, mut det_col) = + extension_row.multi_slice_mut((s![..#indices],s![#indices..#indices + 1])); + det_col[0] = #substitutions; + )* + } + ) +} + +fn ext_dual_row_substitutions(indices: Vec, substitutions: Vec) -> TokenStream { + assert_eq!(indices.len(), substitutions.len()); + if indices.is_empty() { + return quote!(); + } + quote!( + for curr_row_idx in 0..master_base_table.nrows() - 1 { + let next_row_idx = curr_row_idx + 1; + let current_base_row = master_base_table.row(curr_row_idx); + let next_base_row = master_base_table.row(next_row_idx); + let (mut curr_ext_row, next_ext_row) = master_ext_table.multi_slice_mut(( + s![curr_row_idx..curr_row_idx + 1, ..], + s![next_row_idx..next_row_idx + 1, ..], + )); + let mut curr_ext_row = curr_ext_row.row_mut(0); + let next_ext_row = next_ext_row.row(0); + #( + let (current_ext_row, mut det_col) = + curr_ext_row.multi_slice_mut((s![..#indices], s![#indices..#indices + 1])); + det_col[0] = #substitutions; + )* + } + ) +} + +fn write_code_to_file(code: TokenStream, file_name: &str) { + let syntax_tree = syn::parse2(code).unwrap(); + let code = prettyplease::unparse(&syntax_tree); + let path = format!("triton-vm/src/table/{file_name}.rs"); + std::fs::write(path, code).unwrap(); +}