Skip to content

Commit

Permalink
Parallelize Filling of Degree-Lowering Table (#284)
Browse files Browse the repository at this point in the history
This PR modifies the auto-generated code for filling the degree-lowering
table, essentially upgrading it from sequential to parallel iteration.
To achieve this, the master base and extension tables are split into
left and right parts at the column index separating original columns
from degree-lowering columns. Then parallel iteration over the rows
allows filling in the degree-lowering rows left to right. No value in
any degree-lowering row depends on values to its right.

A complication arises when filling the degree-lowering columns for
*transition constraints* as the degree-lowering values depend on two
rows, current and next. The solution is already implied by the AIR
constraints, which ensures that all degree-lowering values live in the
current row. Therefore, by parallel iteration over *single* rows of the
degree-lowering part, and *overlapping row pairs* of the original part,
one can fill the former left to right. Note that the overlapping row
pairs do not interfere with parallel execution because these are
*immutable* references. Rust disallows multiple *mutable* references to
the same data, but in this case only the right half of the table is
mutable, and there we select one row per iteration.

Much of the kudos goes to @jan-ferdinand who came up with the blueprint
for the solution.

On my desktop machine, benchmarking `prove_halt` --

sequential: 

![image](https://github.com/TritonVM/triton-vm/assets/1583170/42bf1900-3b4f-4d52-be7d-060b8090a701)

parallel:

![image](https://github.com/TritonVM/triton-vm/assets/1583170/d6072176-fb03-4a25-99e1-369765976350)
  • Loading branch information
aszepieniec authored May 22, 2024
2 parents f7b13e7 + f617e60 commit 9c02c64
Show file tree
Hide file tree
Showing 2 changed files with 109 additions and 88 deletions.
193 changes: 105 additions & 88 deletions constraint-evaluation-generator/src/substitution.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,12 @@ impl AllSubstitutions {
//! To re-generate, execute:
//! `cargo run --bin constraint-evaluation-generator`

use ndarray::Array1;
use ndarray::s;
use ndarray::ArrayView2;
use ndarray::ArrayViewMut2;
use ndarray::Axis;
use ndarray::Zip;
use strum::Display;
use strum::EnumCount;
use strum::EnumIter;
Expand Down Expand Up @@ -118,32 +121,19 @@ impl Substitutions {
let derived_section_tran_start = derived_section_cons_start + self.cons.len();
let derived_section_term_start = derived_section_tran_start + self.tran.len();

let init_col_indices = (0..self.init.len())
.map(|i| i + derived_section_init_start)
.collect_vec();
let cons_col_indices = (0..self.cons.len())
.map(|i| i + derived_section_cons_start)
.collect_vec();
let tran_col_indices = (0..self.tran.len())
.map(|i| i + derived_section_tran_start)
.collect_vec();
let term_col_indices = (0..self.term.len())
.map(|i| i + derived_section_term_start)
.collect_vec();

let init_substitutions = Self::several_substitution_rules_to_code(&self.init);
let cons_substitutions = Self::several_substitution_rules_to_code(&self.cons);
let tran_substitutions = Self::several_substitution_rules_to_code(&self.tran);
let term_substitutions = Self::several_substitution_rules_to_code(&self.term);

let init_substitutions =
Self::base_single_row_substitutions(&init_col_indices, &init_substitutions);
Self::base_single_row_substitutions(derived_section_init_start, &init_substitutions);
let cons_substitutions =
Self::base_single_row_substitutions(&cons_col_indices, &cons_substitutions);
Self::base_single_row_substitutions(derived_section_cons_start, &cons_substitutions);
let tran_substitutions =
Self::base_dual_row_substitutions(&tran_col_indices, &tran_substitutions);
Self::base_dual_row_substitutions(derived_section_tran_start, &tran_substitutions);
let term_substitutions =
Self::base_single_row_substitutions(&term_col_indices, &term_substitutions);
Self::base_single_row_substitutions(derived_section_term_start, &term_substitutions);

quote!(
#[allow(unused_variables)]
Expand All @@ -163,32 +153,19 @@ impl Substitutions {
let derived_section_tran_start = derived_section_cons_start + self.cons.len();
let derived_section_term_start = derived_section_tran_start + self.tran.len();

let init_col_indices = (0..self.init.len())
.map(|i| i + derived_section_init_start)
.collect_vec();
let cons_col_indices = (0..self.cons.len())
.map(|i| i + derived_section_cons_start)
.collect_vec();
let tran_col_indices = (0..self.tran.len())
.map(|i| i + derived_section_tran_start)
.collect_vec();
let term_col_indices = (0..self.term.len())
.map(|i| i + derived_section_term_start)
.collect_vec();

let init_substitutions = Self::several_substitution_rules_to_code(&self.init);
let cons_substitutions = Self::several_substitution_rules_to_code(&self.cons);
let tran_substitutions = Self::several_substitution_rules_to_code(&self.tran);
let term_substitutions = Self::several_substitution_rules_to_code(&self.term);

let init_substitutions =
Self::ext_single_row_substitutions(&init_col_indices, &init_substitutions);
Self::ext_single_row_substitutions(derived_section_init_start, &init_substitutions);
let cons_substitutions =
Self::ext_single_row_substitutions(&cons_col_indices, &cons_substitutions);
Self::ext_single_row_substitutions(derived_section_cons_start, &cons_substitutions);
let tran_substitutions =
Self::ext_dual_row_substitutions(&tran_col_indices, &tran_substitutions);
Self::ext_dual_row_substitutions(derived_section_tran_start, &tran_substitutions);
let term_substitutions =
Self::ext_single_row_substitutions(&term_col_indices, &term_substitutions);
Self::ext_single_row_substitutions(derived_section_term_start, &term_substitutions);

quote!(
#[allow(unused_variables)]
Expand Down Expand Up @@ -241,93 +218,133 @@ impl Substitutions {
}

fn base_single_row_substitutions(
indices: &[usize],
section_start_index: usize,
substitutions: &[TokenStream],
) -> TokenStream {
assert_eq!(indices.len(), substitutions.len());
let num_substitutions = substitutions.len();
let indices = (0..num_substitutions).collect_vec();
if indices.is_empty() {
return quote!();
}
quote!(
master_base_table.rows_mut().into_iter().for_each(|mut row| {
#(
let (base_row, mut det_col) =
row.multi_slice_mut((s![..#indices],s![#indices..=#indices]));
det_col[0] = #substitutions;
)*
});
let (original_part, mut current_section) =
master_base_table.multi_slice_mut(
(
s![.., 0..#section_start_index],
s![.., #section_start_index..#section_start_index+#num_substitutions],
)
);
Zip::from(original_part.rows())
.and(current_section.rows_mut())
.par_for_each(|original_row, mut section_row| {
let mut base_row = original_row.to_owned();
#(
section_row[#indices] = #substitutions;
base_row.push(Axis(0), section_row.slice(s![#indices])).unwrap();
)*
});
)
}

fn base_dual_row_substitutions(
indices: &[usize],
section_start_index: usize,
substitutions: &[TokenStream],
) -> TokenStream {
assert_eq!(indices.len(), substitutions.len());
let num_substitutions = substitutions.len();
let indices = (0..substitutions.len()).collect_vec();
if indices.is_empty() {
return quote!();
}
quote!(
for curr_row_idx in 0..master_base_table.nrows() - 1 {
let next_row_idx = curr_row_idx + 1;
let (mut curr_base_row, next_base_row) = master_base_table.multi_slice_mut((
s![curr_row_idx..=curr_row_idx, ..],
s![next_row_idx..=next_row_idx, ..],
));
let mut curr_base_row = curr_base_row.row_mut(0);
let next_base_row = next_base_row.row(0);
#(
let (current_base_row, mut det_col) =
curr_base_row.multi_slice_mut((s![..#indices], s![#indices..=#indices]));
det_col[0] = #substitutions;
)*
}
let num_rows = master_base_table.nrows();
let (original_part, mut current_section) =
master_base_table.multi_slice_mut(
(
s![.., 0..#section_start_index],
s![.., #section_start_index..#section_start_index+#num_substitutions],
)
);
let row_indices = Array1::from_vec((0..num_rows - 1).collect::<Vec<_>>());
Zip::from(current_section.slice_mut(s![0..num_rows-1, ..]).rows_mut())
.and(row_indices.view())
.par_for_each( |mut section_row, &current_row_index| {
let next_row_index = current_row_index + 1;
let current_base_row_slice = original_part.slice(s![current_row_index..=current_row_index, ..]);
let next_base_row_slice = original_part.slice(s![next_row_index..=next_row_index, ..]);
let mut current_base_row = current_base_row_slice.row(0).to_owned();
let next_base_row = next_base_row_slice.row(0);
#(
section_row[#indices] = #substitutions;
current_base_row.push(Axis(0), section_row.slice(s![#indices])).unwrap();
)*
});
)
}

fn ext_single_row_substitutions(
indices: &[usize],
section_start_index: usize,
substitutions: &[TokenStream],
) -> TokenStream {
assert_eq!(indices.len(), substitutions.len());
let num_substitutions = substitutions.len();
let indices = (0..substitutions.len()).collect_vec();
if indices.is_empty() {
return quote!();
}
quote!(
for row_idx in 0..master_base_table.nrows() - 1 {
let base_row = master_base_table.row(row_idx);
let mut extension_row = master_ext_table.row_mut(row_idx);
#(
let (ext_row, mut det_col) =
extension_row.multi_slice_mut((s![..#indices],s![#indices..=#indices]));
det_col[0] = #substitutions;
)*
}
let (original_part, mut current_section) = master_ext_table.multi_slice_mut(
(
s![.., 0..#section_start_index],
s![.., #section_start_index..#section_start_index+#num_substitutions],
)
);
Zip::from(master_base_table.rows())
.and(original_part.rows())
.and(current_section.rows_mut())
.par_for_each(
|base_table_row, original_row, mut section_row| {
let mut extension_row = original_row.to_owned();
#(
let (original_row_extension_row, mut det_col) =
section_row.multi_slice_mut((s![..#indices],s![#indices..=#indices]));
det_col[0] = #substitutions;
extension_row.push(Axis(0), det_col.slice(s![0])).unwrap();
)*
}
);
)
}

fn ext_dual_row_substitutions(indices: &[usize], substitutions: &[TokenStream]) -> TokenStream {
assert_eq!(indices.len(), substitutions.len());
fn ext_dual_row_substitutions(
section_start_index: usize,
substitutions: &[TokenStream],
) -> TokenStream {
let num_substitutions = substitutions.len();
let indices = (0..substitutions.len()).collect_vec();
if indices.is_empty() {
return quote!();
}
quote!(
for curr_row_idx in 0..master_base_table.nrows() - 1 {
let next_row_idx = curr_row_idx + 1;
let current_base_row = master_base_table.row(curr_row_idx);
let next_base_row = master_base_table.row(next_row_idx);
let (mut curr_ext_row, next_ext_row) = master_ext_table.multi_slice_mut((
s![curr_row_idx..=curr_row_idx, ..],
s![next_row_idx..=next_row_idx, ..],
));
let mut curr_ext_row = curr_ext_row.row_mut(0);
let next_ext_row = next_ext_row.row(0);
#(
let (current_ext_row, mut det_col) =
curr_ext_row.multi_slice_mut((s![..#indices], s![#indices..=#indices]));
det_col[0] = #substitutions;
)*
}
let num_rows = master_base_table.nrows();
let (original_part, mut current_section) = master_ext_table.multi_slice_mut(
(
s![.., 0..#section_start_index],
s![.., #section_start_index..#section_start_index+#num_substitutions],
)
);
let row_indices = Array1::from_vec((0..num_rows - 1).collect::<Vec<_>>());
Zip::from(current_section.slice_mut(s![0..num_rows-1, ..]).rows_mut())
.and(row_indices.view())
.par_for_each(|mut section_row, &current_row_index| {
let next_row_index = current_row_index + 1;
let current_base_row = master_base_table.row(current_row_index);
let next_base_row = master_base_table.row(next_row_index);
let mut current_ext_row = original_part.row(current_row_index).to_owned();
let next_ext_row = original_part.row(next_row_index);
#(
section_row[#indices]= #substitutions;
current_ext_row.push(Axis(0), section_row.slice(s![#indices])).unwrap();
)*
});
)
}
}
4 changes: 4 additions & 0 deletions triton-vm/src/table/master_table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -875,15 +875,19 @@ impl MasterBaseTable {
u32_table,
];

profiler!(start "pad original tables");
Self::all_pad_functions()
.into_par_iter()
.zip_eq(base_tables.into_par_iter())
.zip_eq(table_lengths.into_par_iter())
.for_each(|((pad, base_table), table_length)| {
pad(base_table, table_length);
});
profiler!(stop "pad original tables");

profiler!(start "fill degree-lowering table");
DegreeLoweringTable::fill_derived_base_columns(self.trace_table_mut());
profiler!(stop "fill degree-lowering table");
}

fn all_pad_functions() -> [PadFunction; NUM_TABLES_WITHOUT_DEGREE_LOWERING] {
Expand Down

0 comments on commit 9c02c64

Please sign in to comment.