Skip to content

Commit

Permalink
chore(ssa refactor): enable_side_effects instruction (#1547)
Browse files Browse the repository at this point in the history
* chore(ssa refactor): enable_side_effects instruction

* chore(ssa refactor): fix and document enable_side_effects insertions

* chore(ssa refactor): rm comments

* fix(ssa refactor): redundant EnableSideEffects
  • Loading branch information
joss-aztec authored Jun 12, 2023
1 parent 08ca847 commit 2918155
Show file tree
Hide file tree
Showing 6 changed files with 149 additions and 55 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -514,6 +514,7 @@ impl AcirContext {
lhs: AcirVar,
rhs: AcirVar,
bit_size: u32,
predicate: Option<AcirVar>,
) -> Result<AcirVar, AcirGenError> {
let lhs_data = &self.vars[lhs];
let rhs_data = &self.vars[rhs];
Expand All @@ -523,8 +524,13 @@ impl AcirContext {

// TODO: check what happens when we do (a as u8) >= (b as u32)
// TODO: The frontend should shout in this case

let predicate = predicate.map(|acir_var| {
let predicate_data = &self.vars[acir_var];
predicate_data.to_expression().into_owned()
});
let is_greater_than_eq =
self.acir_ir.more_than_eq_comparison(&lhs_expr, &rhs_expr, bit_size)?;
self.acir_ir.more_than_eq_comparison(&lhs_expr, &rhs_expr, bit_size, predicate)?;

Ok(self.add_data(AcirVarData::Witness(is_greater_than_eq)))
}
Expand All @@ -536,10 +542,11 @@ impl AcirContext {
lhs: AcirVar,
rhs: AcirVar,
bit_size: u32,
predicate: Option<AcirVar>,
) -> Result<AcirVar, AcirGenError> {
// Flip the result of calling more than equal method to
// compute less than.
let comparison = self.more_than_eq_var(lhs, rhs, bit_size)?;
let comparison = self.more_than_eq_var(lhs, rhs, bit_size, predicate)?;

let one = self.add_constant(FieldElement::one());
self.sub_var(one, comparison) // comparison_negated
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -622,6 +622,7 @@ impl GeneratedAcir {
a: &Expression,
b: &Expression,
max_bits: u32,
predicate: Option<Expression>,
) -> Result<Witness, AcirGenError> {
// Ensure that 2^{max_bits + 1} is less than the field size
//
Expand Down Expand Up @@ -667,7 +668,7 @@ impl GeneratedAcir {
b: Expression::from_field(two_max_bits),
q: q_witness,
r: r_witness,
predicate: None,
predicate,
})));

// Add constraint to ensure `r` is correctly bounded
Expand Down
15 changes: 14 additions & 1 deletion crates/noirc_evaluator/src/ssa_refactor/acir_gen/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@ struct Context {
/// already exists for this Value, we return the `AcirVar`.
ssa_values: HashMap<Id<Value>, AcirValue>,

/// The `AcirVar` that describes the condition belonging to the most recently invoked
/// `SideEffectsEnabled` instruction.
current_side_effects_enabled_var: Option<AcirVar>,

/// Manages and builds the `AcirVar`s to which the converted SSA values refer.
acir_context: AcirContext,
}
Expand Down Expand Up @@ -217,6 +221,10 @@ impl Context {
.expect("add Result types to all methods so errors bubble up");
self.define_result_var(dfg, instruction_id, result_acir_var);
}
Instruction::EnableSideEffects { condition } => {
let acir_var = self.convert_numeric_value(*condition, dfg);
self.current_side_effects_enabled_var = Some(acir_var);
}
Instruction::ArrayGet { array, index } => {
self.handle_array_operation(instruction_id, *array, *index, None, dfg);
}
Expand Down Expand Up @@ -405,7 +413,12 @@ impl Context {
// Note: that this produces unnecessary constraints when
// this Eq instruction is being used for a constrain statement
BinaryOp::Eq => self.acir_context.eq_var(lhs, rhs),
BinaryOp::Lt => self.acir_context.less_than_var(lhs, rhs, bit_count),
BinaryOp::Lt => self.acir_context.less_than_var(
lhs,
rhs,
bit_count,
self.current_side_effects_enabled_var,
),
BinaryOp::Shl => self.acir_context.shift_left_var(lhs, rhs, binary_type),
BinaryOp::Shr => self.acir_context.shift_right_var(lhs, rhs, binary_type),
BinaryOp::Xor => self.acir_context.xor_var(lhs, rhs, binary_type),
Expand Down
17 changes: 16 additions & 1 deletion crates/noirc_evaluator/src/ssa_refactor/ir/instruction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,15 @@ pub(crate) enum Instruction {
/// Writes a value to memory.
Store { address: ValueId, value: ValueId },

/// Provides a context for all instructions that follow up until the next
/// `EnableSideEffects` is encountered, for stating a condition that determines whether
/// such instructions are allowed to have side-effects.
///
/// This instruction is only emitted after the cfg flattening pass, and is used to annotate
/// instruction regions with an condition that corresponds to their position in the CFG's
/// if-branching structure.
EnableSideEffects { condition: ValueId },

/// Retrieve a value from an array at the given index
ArrayGet { array: ValueId, index: ValueId },

Expand All @@ -127,7 +136,9 @@ impl Instruction {
InstructionResultType::Operand(*value)
}
Instruction::ArraySet { array, .. } => InstructionResultType::Operand(*array),
Instruction::Constrain(_) | Instruction::Store { .. } => InstructionResultType::None,
Instruction::Constrain(_)
| Instruction::Store { .. }
| Instruction::EnableSideEffects { .. } => InstructionResultType::None,
Instruction::Load { .. } | Instruction::ArrayGet { .. } | Instruction::Call { .. } => {
InstructionResultType::Unknown
}
Expand Down Expand Up @@ -167,6 +178,9 @@ impl Instruction {
Instruction::Store { address, value } => {
Instruction::Store { address: f(*address), value: f(*value) }
}
Instruction::EnableSideEffects { condition } => {
Instruction::EnableSideEffects { condition: f(*condition) }
}
Instruction::ArrayGet { array, index } => {
Instruction::ArrayGet { array: f(*array), index: f(*index) }
}
Expand Down Expand Up @@ -256,6 +270,7 @@ impl Instruction {
Instruction::Allocate { .. } => None,
Instruction::Load { .. } => None,
Instruction::Store { .. } => None,
Instruction::EnableSideEffects { .. } => None,
}
}
}
Expand Down
3 changes: 3 additions & 0 deletions crates/noirc_evaluator/src/ssa_refactor/ir/printer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,9 @@ pub(crate) fn display_instruction(
Instruction::Store { address, value } => {
writeln!(f, "store {} at {}", show(*value), show(*address))
}
Instruction::EnableSideEffects { condition } => {
writeln!(f, "enable_side_effects {}", show(*condition))
}
Instruction::ArrayGet { array, index } => {
writeln!(f, "array_get {}, index {}", show(*array), show(*index))
}
Expand Down
155 changes: 105 additions & 50 deletions crates/noirc_evaluator/src/ssa_refactor/opt/flatten_cfg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,26 @@
//! while merging branches. These extra instructions can be cleaned up by a later dead instruction
//! elimination (DIE) pass.
//!
//! Though CFG information is lost during this pass, some key information is retained in the form
//! of `EnableSideEffect` instructions. Each time the flattening pass enters and exits a branch of
//! a jmpif, an instruction is inserted to capture a condition that is analogous to the activeness
//! of the program point. For example:
//!
//! b0(v0: u1):
//! jmpif v0, then: b1, else: b2
//! b1():
//! v1 = call f0
//! jmp b3(v1)
//! ... blocks b2 & b3 ...
//!
//! Would brace the call instruction as such:
//! enable_side_effects v0
//! v1 = call f0
//! enable_side_effects u1 1
//!
//! (Note: we restore to "true" to indicate that this program point is not nested within any
//! other branches.)
//!
//! When we are flattening a block that was reached via a jmpif with a non-constant condition c,
//! the following transformations of certain instructions within the block are expected:
//!
Expand Down Expand Up @@ -297,6 +317,8 @@ impl<'f> Context<'f> {
let else_branch =
self.inline_branch(block, else_block, old_condition, else_condition, zero);

self.insert_current_side_effects_enabled();

// While there is a condition on the stack we don't compile outside the condition
// until it is popped. This ensures we inline the full then and else branches
// before continuing from the end of the conditional here where they can be merged properly.
Expand Down Expand Up @@ -368,6 +390,20 @@ impl<'f> Context<'f> {
self.function.dfg.insert_instruction_and_results(instruction, block, ctrl_typevars)
}

/// Checks the branch condition on the top of the stack and uses it to build and insert an
/// `EnableSideEffects` instruction into the entry block.
///
/// If the stack is empty, a "true" u1 constant is taken to be the active condition. This is
/// necessary for re-enabling side-effects when re-emerging to a branch depth of 0.
fn insert_current_side_effects_enabled(&mut self) {
let condition = match self.conditions.last() {
Some((_, cond)) => *cond,
None => self.function.dfg.make_constant(FieldElement::one(), Type::unsigned(1)),
};
let enable_side_effects = Instruction::EnableSideEffects { condition };
self.insert_instruction_with_typevars(enable_side_effects, None);
}

/// Merge two values a and b from separate basic blocks to a single value. This
/// function would return the result of `if c { a } else { b }` as `c*a + (!c)*b`.
fn merge_values(
Expand Down Expand Up @@ -406,6 +442,7 @@ impl<'f> Context<'f> {
condition_value: FieldElement,
) -> Branch {
self.push_condition(jmpif_block, new_condition);
self.insert_current_side_effects_enabled();
let old_stores = std::mem::take(&mut self.store_values);

// Remember the old condition value is now known to be true/false within this branch
Expand Down Expand Up @@ -653,11 +690,13 @@ mod test {
// Expected output:
// fn main f0 {
// b0(v0: u1):
// v4 = not v0
// v5 = mul v0, Field 3
// v7 = not v0
// v8 = mul v7, Field 4
// v9 = add v5, v8
// enable_side_effects v0
// v5 = not v0
// enable_side_effects v5
// enable_side_effects u1 1
// v7 = mul v0, Field 3
// v8 = mul v5, Field 4
// v9 = add v7, v8
// return v9
// }
let ssa = ssa.flatten_cfg();
Expand Down Expand Up @@ -696,13 +735,17 @@ mod test {
let ssa = builder.finish();
assert_eq!(ssa.main().reachable_blocks().len(), 3);

// Expected output (sans useless extra 'not' instruction):
// Expected output:
// fn main f0 {
// b0(v0: u1, v1: u1):
// v2 = mul v1, v0
// v3 = eq v2, v0
// constrain v3
// return v1
// enable_side_effects v0
// v3 = mul v1, v0
// v4 = eq v3, v0
// constrain v4
// v5 = not v0
// enable_side_effects v5
// enable_side_effects u1 1
// return
// }
let ssa = ssa.flatten_cfg();
assert_eq!(ssa.main().reachable_blocks().len(), 1);
Expand Down Expand Up @@ -743,14 +786,16 @@ mod test {
// Expected output:
// fn main f0 {
// b0(v0: u1, v1: reference):
// enable_side_effects v0
// v4 = load v1
// store Field 5 at v1
// v5 = not v0
// enable_side_effects v5
// enable_side_effects u1 1
// v7 = mul v0, Field 5
// v8 = not v0
// v9 = mul v8, v4
// v10 = add v7, v9
// store v10 at v1
// v8 = mul v5, v4
// v9 = add v7, v8
// store v9 at v1
// return
// }
let ssa = ssa.flatten_cfg();
Expand Down Expand Up @@ -817,21 +862,24 @@ mod test {
// Expected output:
// fn main f0 {
// b0(v0: u1, v1: reference):
// v8 = add v1, Field 1
// v9 = load v8
// store Field 5 at v8
// v10 = not v0
// v12 = add v1, Field 1
// v13 = load v12
// store Field 6 at v12
// v14 = mul v0, Field 5
// v15 = mul v10, v9
// v16 = add v14, v15
// store v16 at v8
// v17 = mul v0, v13
// v18 = mul v10, Field 6
// v19 = add v17, v18
// store v19 at v12
// enable_side_effects v0
// v7 = add v1, Field 1
// v8 = load v7
// store Field 5 at v7
// v9 = not v0
// enable_side_effects v9
// v11 = add v1, Field 1
// v12 = load v11
// store Field 6 at v11
// enable_side_effects Field 1
// v13 = mul v0, Field 5
// v14 = mul v9, v8
// v15 = add v13, v14
// store v15 at v7
// v16 = mul v0, v12
// v17 = mul v9, Field 6
// v18 = add v16, v17
// store v18 at v11
// return
// }
let ssa = ssa.flatten_cfg();
Expand Down Expand Up @@ -1023,31 +1071,38 @@ mod test {
// b0(v0: u1, v1: u1):
// call println(Field 0, Field 0)
// call println(Field 1, Field 1)
// enable_side_effects v0
// call println(Field 2, Field 2)
// call println(Field 4, Field 2) ; block 4 does not store a value
// v45 = and v0, v1
// call println(Field 4, Field 2)
// v29 = and v0, v1
// enable_side_effects v29
// call println(Field 5, Field 5)
// v49 = not v1
// v50 = and v0, v49
// v32 = not v1
// v33 = and v0, v32
// enable_side_effects v33
// call println(Field 6, Field 6)
// v54 = mul v1, Field 5
// v55 = mul v49, Field 2
// v56 = add v54, v55
// v57 = mul v1, Field 5
// v58 = mul v49, Field 6
// v59 = add v57, v58
// call println(Field 7, v59) ; v59 = 5 and 6 merged
// v61 = not v0
// enable_side_effects v0
// v36 = mul v1, Field 5
// v37 = mul v32, Field 2
// v38 = add v36, v37
// v39 = mul v1, Field 5
// v40 = mul v32, Field 6
// v41 = add v39, v40
// call println(Field 7, v42)
// v43 = not v0
// enable_side_effects v43
// store Field 3 at v2
// call println(Field 3, Field 3)
// call println(Field 8, Field 3) ; block 8 does not store a value
// v66 = mul v0, v59
// v67 = mul v61, Field 1
// v68 = add v66, v67 ; This was from an unused store.
// v69 = mul v0, v59
// v70 = mul v61, Field 3
// v71 = add v69, v70
// call println(Field 9, v71) ; v71 = 3, 5, and 6 merged
// return v71
// call println(Field 8, Field 3)
// enable_side_effects Field 1
// v47 = mul v0, v41
// v48 = mul v43, Field 1
// v49 = add v47, v48
// v50 = mul v0, v44
// v51 = mul v43, Field 3
// v52 = add v50, v51
// call println(Field 9, v53)
// return v54
// }

let main = ssa.main();
Expand Down

0 comments on commit 2918155

Please sign in to comment.