Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Parallelize Filling of Degree-Lowering Table #284

Merged
merged 12 commits into from
May 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
193 changes: 105 additions & 88 deletions constraint-evaluation-generator/src/substitution.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,12 @@ impl AllSubstitutions {
//! To re-generate, execute:
//! `cargo run --bin constraint-evaluation-generator`

use ndarray::Array1;
use ndarray::s;
use ndarray::ArrayView2;
use ndarray::ArrayViewMut2;
use ndarray::Axis;
use ndarray::Zip;
use strum::Display;
use strum::EnumCount;
use strum::EnumIter;
Expand Down Expand Up @@ -118,32 +121,19 @@ impl Substitutions {
let derived_section_tran_start = derived_section_cons_start + self.cons.len();
let derived_section_term_start = derived_section_tran_start + self.tran.len();

let init_col_indices = (0..self.init.len())
.map(|i| i + derived_section_init_start)
.collect_vec();
let cons_col_indices = (0..self.cons.len())
.map(|i| i + derived_section_cons_start)
.collect_vec();
let tran_col_indices = (0..self.tran.len())
.map(|i| i + derived_section_tran_start)
.collect_vec();
let term_col_indices = (0..self.term.len())
.map(|i| i + derived_section_term_start)
.collect_vec();

let init_substitutions = Self::several_substitution_rules_to_code(&self.init);
let cons_substitutions = Self::several_substitution_rules_to_code(&self.cons);
let tran_substitutions = Self::several_substitution_rules_to_code(&self.tran);
let term_substitutions = Self::several_substitution_rules_to_code(&self.term);

let init_substitutions =
Self::base_single_row_substitutions(&init_col_indices, &init_substitutions);
Self::base_single_row_substitutions(derived_section_init_start, &init_substitutions);
let cons_substitutions =
Self::base_single_row_substitutions(&cons_col_indices, &cons_substitutions);
Self::base_single_row_substitutions(derived_section_cons_start, &cons_substitutions);
let tran_substitutions =
Self::base_dual_row_substitutions(&tran_col_indices, &tran_substitutions);
Self::base_dual_row_substitutions(derived_section_tran_start, &tran_substitutions);
let term_substitutions =
Self::base_single_row_substitutions(&term_col_indices, &term_substitutions);
Self::base_single_row_substitutions(derived_section_term_start, &term_substitutions);

quote!(
#[allow(unused_variables)]
Expand All @@ -163,32 +153,19 @@ impl Substitutions {
let derived_section_tran_start = derived_section_cons_start + self.cons.len();
let derived_section_term_start = derived_section_tran_start + self.tran.len();

let init_col_indices = (0..self.init.len())
.map(|i| i + derived_section_init_start)
.collect_vec();
let cons_col_indices = (0..self.cons.len())
.map(|i| i + derived_section_cons_start)
.collect_vec();
let tran_col_indices = (0..self.tran.len())
.map(|i| i + derived_section_tran_start)
.collect_vec();
let term_col_indices = (0..self.term.len())
.map(|i| i + derived_section_term_start)
.collect_vec();

let init_substitutions = Self::several_substitution_rules_to_code(&self.init);
let cons_substitutions = Self::several_substitution_rules_to_code(&self.cons);
let tran_substitutions = Self::several_substitution_rules_to_code(&self.tran);
let term_substitutions = Self::several_substitution_rules_to_code(&self.term);

let init_substitutions =
Self::ext_single_row_substitutions(&init_col_indices, &init_substitutions);
Self::ext_single_row_substitutions(derived_section_init_start, &init_substitutions);
let cons_substitutions =
Self::ext_single_row_substitutions(&cons_col_indices, &cons_substitutions);
Self::ext_single_row_substitutions(derived_section_cons_start, &cons_substitutions);
let tran_substitutions =
Self::ext_dual_row_substitutions(&tran_col_indices, &tran_substitutions);
Self::ext_dual_row_substitutions(derived_section_tran_start, &tran_substitutions);
let term_substitutions =
Self::ext_single_row_substitutions(&term_col_indices, &term_substitutions);
Self::ext_single_row_substitutions(derived_section_term_start, &term_substitutions);

quote!(
#[allow(unused_variables)]
Expand Down Expand Up @@ -241,93 +218,133 @@ impl Substitutions {
}

fn base_single_row_substitutions(
indices: &[usize],
section_start_index: usize,
substitutions: &[TokenStream],
) -> TokenStream {
assert_eq!(indices.len(), substitutions.len());
let num_substitutions = substitutions.len();
let indices = (0..num_substitutions).collect_vec();
aszepieniec marked this conversation as resolved.
Show resolved Hide resolved
if indices.is_empty() {
return quote!();
}
quote!(
master_base_table.rows_mut().into_iter().for_each(|mut row| {
#(
let (base_row, mut det_col) =
row.multi_slice_mut((s![..#indices],s![#indices..=#indices]));
det_col[0] = #substitutions;
)*
});
let (original_part, mut current_section) =
master_base_table.multi_slice_mut(
(
s![.., 0..#section_start_index],
s![.., #section_start_index..#section_start_index+#num_substitutions],
)
);
Zip::from(original_part.rows())
.and(current_section.rows_mut())
.par_for_each(|original_row, mut section_row| {
let mut base_row = original_row.to_owned();
#(
section_row[#indices] = #substitutions;
base_row.push(Axis(0), section_row.slice(s![#indices])).unwrap();
)*
});
)
}

fn base_dual_row_substitutions(
indices: &[usize],
section_start_index: usize,
substitutions: &[TokenStream],
) -> TokenStream {
assert_eq!(indices.len(), substitutions.len());
let num_substitutions = substitutions.len();
let indices = (0..substitutions.len()).collect_vec();
if indices.is_empty() {
return quote!();
}
quote!(
for curr_row_idx in 0..master_base_table.nrows() - 1 {
let next_row_idx = curr_row_idx + 1;
let (mut curr_base_row, next_base_row) = master_base_table.multi_slice_mut((
s![curr_row_idx..=curr_row_idx, ..],
s![next_row_idx..=next_row_idx, ..],
));
let mut curr_base_row = curr_base_row.row_mut(0);
let next_base_row = next_base_row.row(0);
#(
let (current_base_row, mut det_col) =
curr_base_row.multi_slice_mut((s![..#indices], s![#indices..=#indices]));
det_col[0] = #substitutions;
)*
}
let num_rows = master_base_table.nrows();
let (original_part, mut current_section) =
master_base_table.multi_slice_mut(
(
s![.., 0..#section_start_index],
s![.., #section_start_index..#section_start_index+#num_substitutions],
)
);
let row_indices = Array1::from_vec((0..num_rows - 1).collect::<Vec<_>>());
Zip::from(current_section.slice_mut(s![0..num_rows-1, ..]).rows_mut())
.and(row_indices.view())
.par_for_each( |mut section_row, &current_row_index| {
let next_row_index = current_row_index + 1;
let current_base_row_slice = original_part.slice(s![current_row_index..=current_row_index, ..]);
let next_base_row_slice = original_part.slice(s![next_row_index..=next_row_index, ..]);
let mut current_base_row = current_base_row_slice.row(0).to_owned();
let next_base_row = next_base_row_slice.row(0);
#(
section_row[#indices] = #substitutions;
current_base_row.push(Axis(0), section_row.slice(s![#indices])).unwrap();
)*
});
)
}

fn ext_single_row_substitutions(
indices: &[usize],
section_start_index: usize,
substitutions: &[TokenStream],
) -> TokenStream {
assert_eq!(indices.len(), substitutions.len());
let num_substitutions = substitutions.len();
let indices = (0..substitutions.len()).collect_vec();
if indices.is_empty() {
return quote!();
}
quote!(
for row_idx in 0..master_base_table.nrows() - 1 {
let base_row = master_base_table.row(row_idx);
let mut extension_row = master_ext_table.row_mut(row_idx);
#(
let (ext_row, mut det_col) =
extension_row.multi_slice_mut((s![..#indices],s![#indices..=#indices]));
det_col[0] = #substitutions;
)*
}
let (original_part, mut current_section) = master_ext_table.multi_slice_mut(
(
s![.., 0..#section_start_index],
s![.., #section_start_index..#section_start_index+#num_substitutions],
)
);
Zip::from(master_base_table.rows())
.and(original_part.rows())
.and(current_section.rows_mut())
.par_for_each(
|base_table_row, original_row, mut section_row| {
let mut extension_row = original_row.to_owned();
#(
let (original_row_extension_row, mut det_col) =
section_row.multi_slice_mut((s![..#indices],s![#indices..=#indices]));
det_col[0] = #substitutions;
extension_row.push(Axis(0), det_col.slice(s![0])).unwrap();
)*
}
);
)
}

fn ext_dual_row_substitutions(indices: &[usize], substitutions: &[TokenStream]) -> TokenStream {
assert_eq!(indices.len(), substitutions.len());
fn ext_dual_row_substitutions(
section_start_index: usize,
substitutions: &[TokenStream],
) -> TokenStream {
let num_substitutions = substitutions.len();
let indices = (0..substitutions.len()).collect_vec();
if indices.is_empty() {
return quote!();
}
quote!(
for curr_row_idx in 0..master_base_table.nrows() - 1 {
jan-ferdinand marked this conversation as resolved.
Show resolved Hide resolved
let next_row_idx = curr_row_idx + 1;
let current_base_row = master_base_table.row(curr_row_idx);
let next_base_row = master_base_table.row(next_row_idx);
let (mut curr_ext_row, next_ext_row) = master_ext_table.multi_slice_mut((
s![curr_row_idx..=curr_row_idx, ..],
s![next_row_idx..=next_row_idx, ..],
));
let mut curr_ext_row = curr_ext_row.row_mut(0);
let next_ext_row = next_ext_row.row(0);
#(
let (current_ext_row, mut det_col) =
curr_ext_row.multi_slice_mut((s![..#indices], s![#indices..=#indices]));
det_col[0] = #substitutions;
)*
}
let num_rows = master_base_table.nrows();
let (original_part, mut current_section) = master_ext_table.multi_slice_mut(
(
s![.., 0..#section_start_index],
s![.., #section_start_index..#section_start_index+#num_substitutions],
)
);
let row_indices = Array1::from_vec((0..num_rows - 1).collect::<Vec<_>>());
Zip::from(current_section.slice_mut(s![0..num_rows-1, ..]).rows_mut())
.and(row_indices.view())
.par_for_each(|mut section_row, &current_row_index| {
let next_row_index = current_row_index + 1;
let current_base_row = master_base_table.row(current_row_index);
let next_base_row = master_base_table.row(next_row_index);
let mut current_ext_row = original_part.row(current_row_index).to_owned();
let next_ext_row = original_part.row(next_row_index);
#(
section_row[#indices]= #substitutions;
current_ext_row.push(Axis(0), section_row.slice(s![#indices])).unwrap();
)*
});
)
}
}
4 changes: 4 additions & 0 deletions triton-vm/src/table/master_table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -844,15 +844,19 @@ impl MasterBaseTable {
u32_table,
];

profiler!(start "pad original tables");
Self::all_pad_functions()
.into_par_iter()
.zip_eq(base_tables.into_par_iter())
.zip_eq(table_lengths.into_par_iter())
.for_each(|((pad, base_table), table_length)| {
pad(base_table, table_length);
});
profiler!(stop "pad original tables");

profiler!(start "fill degree-lowering table");
DegreeLoweringTable::fill_derived_base_columns(self.trace_table_mut());
profiler!(stop "fill degree-lowering table");
}

fn all_pad_functions() -> [PadFunction; NUM_TABLES_WITHOUT_DEGREE_LOWERING] {
Expand Down