Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore(ssa refactor): enable_side_effects instruction #1547

Merged
merged 7 commits into from
Jun 12, 2023
Merged
Original file line number Diff line number Diff line change
Expand Up @@ -484,7 +484,12 @@ impl AcirContext {
}
/// Returns an `AcirVar` which will be `1` if lhs >= rhs
/// and `0` otherwise.
fn more_than_eq_var(&mut self, lhs: AcirVar, rhs: AcirVar) -> Result<AcirVar, AcirGenError> {
fn more_than_eq_var(
&mut self,
lhs: AcirVar,
rhs: AcirVar,
predicate: Option<AcirVar>,
) -> Result<AcirVar, AcirGenError> {
let lhs_data = &self.data[&lhs];
let rhs_data = &self.data[&rhs];

Expand All @@ -498,8 +503,17 @@ impl AcirContext {
// TODO: The frontend should shout in this case
assert_eq!(lhs_type, rhs_type, "types in a more than eq comparison should be the same");

let is_greater_than_eq =
self.acir_ir.more_than_eq_comparison(&lhs_expr, &rhs_expr, lhs_type.bit_size())?;
let predicate = predicate.map(|acir_var| {
let predicate_data = &self.data[&acir_var];
predicate_data.to_expression().into_owned()
});

let is_greater_than_eq = self.acir_ir.more_than_eq_comparison(
&lhs_expr,
&rhs_expr,
lhs_type.bit_size(),
predicate,
)?;

Ok(self.add_data(AcirVarData::Witness(is_greater_than_eq)))
}
Expand All @@ -510,10 +524,11 @@ impl AcirContext {
&mut self,
lhs: AcirVar,
rhs: AcirVar,
predicate: Option<AcirVar>,
) -> Result<AcirVar, AcirGenError> {
// Flip the result of calling more than equal method to
// compute less than.
let comparison = self.more_than_eq_var(lhs, rhs)?;
let comparison = self.more_than_eq_var(lhs, rhs, predicate)?;

let one = self.add_constant(FieldElement::one());
let comparison_negated = self.sub_var(one, comparison);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -421,6 +421,7 @@ impl GeneratedAcir {
a: &Expression,
b: &Expression,
max_bits: u32,
predicate: Option<Expression>,
) -> Result<Witness, AcirGenError> {
// Ensure that 2^{max_bits + 1} is less than the field size
//
Expand All @@ -447,7 +448,7 @@ impl GeneratedAcir {
b: Expression::from_field(two_max_bits),
q: q_witness,
r: r_witness,
predicate: None,
predicate,
})));

// Add constraint to ensure `r` is correctly bounded
Expand Down
10 changes: 9 additions & 1 deletion crates/noirc_evaluator/src/ssa_refactor/acir_gen/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@ struct Context {
/// of such instructions are stored, in effect capturing any further values that refer to
/// addresses.
ssa_value_to_array_address: HashMap<ValueId, (ArrayId, usize)>,
/// The `AcirVar` that describes the condition belonging to the most recently invoked
/// `SideEffectsEnabled` instruction.
current_side_effects_enabled_var: Option<AcirVar>,
/// Manages and builds the `AcirVar`s to which the converted SSA values refer.
acir_context: AcirContext,
}
Expand Down Expand Up @@ -281,6 +284,11 @@ impl Context {

(vec![result_ids[0]], vec![result_acir_var])
}
Instruction::EnableSideEffects { condition } => {
let acir_var = self.convert_ssa_value(*condition, dfg);
self.current_side_effects_enabled_var = Some(acir_var);
(Vec::new(), Vec::new())
}
};

// Map the results of the instructions to Acir variables
Expand Down Expand Up @@ -357,7 +365,7 @@ impl Context {
BinaryOp::Eq => self.acir_context.eq_var(lhs, rhs),
BinaryOp::Lt => self
.acir_context
.less_than_var(lhs, rhs)
.less_than_var(lhs, rhs, self.current_side_effects_enabled_var)
.expect("add Result types to all methods so errors bubble up"),
BinaryOp::Shl => self.acir_context.shift_left_var(lhs, rhs, binary_type.into()),
BinaryOp::Shr => self.acir_context.shift_right_var(lhs, rhs, binary_type.into()),
Expand Down
17 changes: 16 additions & 1 deletion crates/noirc_evaluator/src/ssa_refactor/ir/instruction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,15 @@ pub(crate) enum Instruction {

/// Writes a value to memory.
Store { address: ValueId, value: ValueId },

/// Provides a context for all instructions that follow up until the next
/// `EnableSideEffects` is encountered, for stating a condition that determines whether
/// such instructions are allowed to have side-effects.
///
/// This instruction is only emitted after the cfg flattening pass, and is used to annotate
/// instruction regions with an condition that corresponds to their position in the CFG's
/// if-branching structure.
EnableSideEffects { condition: ValueId },
}

impl Instruction {
Expand All @@ -122,7 +131,9 @@ impl Instruction {
Instruction::Not(value) | Instruction::Truncate { value, .. } => {
InstructionResultType::Operand(*value)
}
Instruction::Constrain(_) | Instruction::Store { .. } => InstructionResultType::None,
Instruction::Constrain(_)
| Instruction::Store { .. }
| Instruction::EnableSideEffects { .. } => InstructionResultType::None,
Instruction::Load { .. } | Instruction::Call { .. } => InstructionResultType::Unknown,
}
}
Expand Down Expand Up @@ -160,6 +171,9 @@ impl Instruction {
Instruction::Store { address, value } => {
Instruction::Store { address: f(*address), value: f(*value) }
}
Instruction::EnableSideEffects { condition } => {
Instruction::EnableSideEffects { condition: f(*condition) }
}
}
}

Expand Down Expand Up @@ -207,6 +221,7 @@ impl Instruction {
Instruction::Allocate { .. } => None,
Instruction::Load { .. } => None,
Instruction::Store { .. } => None,
Instruction::EnableSideEffects { .. } => None,
}
}
}
Expand Down
3 changes: 3 additions & 0 deletions crates/noirc_evaluator/src/ssa_refactor/ir/printer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -149,5 +149,8 @@ pub(crate) fn display_instruction(
Instruction::Store { address, value } => {
writeln!(f, "store {} at {}", show(*value), show(*address))
}
Instruction::EnableSideEffects { condition } => {
writeln!(f, "enable_side_effects {}", show(*condition))
}
}
}
162 changes: 112 additions & 50 deletions crates/noirc_evaluator/src/ssa_refactor/opt/flatten_cfg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,26 @@
//! while merging branches. These extra instructions can be cleaned up by a later dead instruction
//! elimination (DIE) pass.
//!
//! Though CFG information is lost during this pass, some key information is retained in the form
//! of `EnableSideEffect` instructions. Each time the flattening pass enters and exits a branch of
//! a jmpif, an instruction is inserted to capture a condition that is analogous to the activeness
//! of the program point. For example:
//!
//! b0(v0: u1):
//! jmpif v0, then: b1, else: b2
//! b1():
//! v1 = call f0
//! jmp b3(v1)
//! ... blocks b2 & b3 ...
//!
//! Would brace the call instruction as such:
//! enable_side_effects v0
//! v1 = call f0
//! enable_side_effects u1 1
//!
//! (Note: we restore to "true" to indicate that this program point is not nested within any
//! other branches.)
//!
//! When we are flattening a block that was reached via a jmpif with a non-constant condition c,
//! the following transformations of certain instructions within the block are expected:
//!
Expand Down Expand Up @@ -359,6 +379,20 @@ impl<'f> Context<'f> {
self.function.dfg.insert_instruction_and_results(instruction, block, ctrl_typevars)
}

/// Checks the branch condition on the top of the stack and uses it to build and insert an
/// `EnableSideEffects` instruction into the entry block.
///
/// If the stack is empty, a "true" u1 constant is taken to be the active condition. This is
/// necessary for re-enabling side-effects when re-emerging to a branch depth of 0.
fn insert_current_side_effects_enabled(&mut self) {
let condition = match self.conditions.last() {
Some((_, cond)) => *cond,
None => self.function.dfg.make_constant(FieldElement::one(), Type::unsigned(1)),
};
let enable_side_effects = Instruction::EnableSideEffects { condition };
self.insert_instruction_with_typevars(enable_side_effects, None);
}

/// Merge two values a and b from separate basic blocks to a single value. This
/// function would return the result of `if c { a } else { b }` as `c*a + (!c)*b`.
fn merge_values(
Expand Down Expand Up @@ -397,6 +431,7 @@ impl<'f> Context<'f> {
condition_value: FieldElement,
) -> Branch {
self.push_condition(jmpif_block, new_condition);
self.insert_current_side_effects_enabled(); // Instruction to annotate condition was pushed
let old_stores = std::mem::take(&mut self.store_values);

// Remember the old condition value is now known to be true/false within this branch
Expand All @@ -406,6 +441,7 @@ impl<'f> Context<'f> {
let final_block = self.inline_block(destination, &[]);

self.conditions.pop();
self.insert_current_side_effects_enabled(); // Instruction to annotate condition was popped
kevaundray marked this conversation as resolved.
Show resolved Hide resolved
let stores_in_branch = std::mem::replace(&mut self.store_values, old_stores);

Branch { condition: new_condition, last_block: final_block, store_values: stores_in_branch }
Expand Down Expand Up @@ -643,11 +679,14 @@ mod test {
// Expected output:
// fn main f0 {
// b0(v0: u1):
// v4 = not v0
// v5 = mul v0, Field 3
// v7 = not v0
// v8 = mul v7, Field 4
// v9 = add v5, v8
// enable_side_effects v0
// enable_side_effects u1 1
// v5 = not v0
// enable_side_effects v5
// enable_side_effects u1 1
// v7 = mul v0, Field 3
// v8 = mul v5, Field 4
// v9 = add v7, v8
// return v9
// }
let ssa = ssa.flatten_cfg();
Expand Down Expand Up @@ -686,15 +725,21 @@ mod test {
let ssa = builder.finish();
assert_eq!(ssa.main().reachable_blocks().len(), 3);

// Expected output (sans useless extra 'not' instruction):
// Expected output:
// fn main f0 {
// b0(v0: u1, v1: u1):
// v2 = mul v1, v0
// v3 = eq v2, v0
// constrain v3
// return v1
// enable_side_effects v0
// v3 = mul v1, v0
// v4 = eq v3, v0
// constrain v4
// enable_side_effects u1 1
// v5 = not v0
// enable_side_effects v5
// enable_side_effects u1 1
// return
// }
let ssa = ssa.flatten_cfg();
println!("{ssa}");
kevaundray marked this conversation as resolved.
Show resolved Hide resolved
assert_eq!(ssa.main().reachable_blocks().len(), 1);
}

Expand Down Expand Up @@ -733,17 +778,21 @@ mod test {
// Expected output:
// fn main f0 {
// b0(v0: u1, v1: reference):
// enable_side_effects v0
// v4 = load v1
// store Field 5 at v1
// enable_side_effects u1 1
// v5 = not v0
// enable_side_effects v5
// enable_side_effects u1 1
// v7 = mul v0, Field 5
// v8 = not v0
// v9 = mul v8, v4
// v10 = add v7, v9
// store v10 at v1
// v8 = mul v5, v4
// v9 = add v7, v8
// store v9 at v1
// return
// }
let ssa = ssa.flatten_cfg();
println!("{ssa}");
let main = ssa.main();
assert_eq!(main.reachable_blocks().len(), 1);

Expand Down Expand Up @@ -807,21 +856,25 @@ mod test {
// Expected output:
// fn main f0 {
// b0(v0: u1, v1: reference):
// v8 = add v1, Field 1
// v9 = load v8
// store Field 5 at v8
// v10 = not v0
// v12 = add v1, Field 1
// v13 = load v12
// store Field 6 at v12
// v14 = mul v0, Field 5
// v15 = mul v10, v9
// v16 = add v14, v15
// store v16 at v8
// v17 = mul v0, v13
// v18 = mul v10, Field 6
// v19 = add v17, v18
// store v19 at v12
// enable_side_effects v0
// v7 = add v1, Field 1
// v8 = load v7
// store Field 5 at v7
// enable_side_effects Field 1
// v9 = not v0
// enable_side_effects v9
// v11 = add v1, Field 1
// v12 = load v11
// store Field 6 at v11
// enable_side_effects Field 1
// v13 = mul v0, Field 5
// v14 = mul v9, v8
// v15 = add v13, v14
// store v15 at v7
// v16 = mul v0, v12
// v17 = mul v9, Field 6
// v18 = add v16, v17
// store v18 at v11
// return
// }
let ssa = ssa.flatten_cfg();
Expand Down Expand Up @@ -1017,31 +1070,40 @@ mod test {
// b0(v0: u1, v1: u1):
// call println(Field 0, Field 0)
// call println(Field 1, Field 1)
// enable_side_effects v0
// call println(Field 2, Field 2)
// call println(Field 4, Field 2) ; block 4 does not store a value
// v45 = and v0, v1
// call println(Field 4, Field 2)
// v29 = and v0, v1
// enable_side_effects v29
// call println(Field 5, Field 5)
// v49 = not v1
// v50 = and v0, v49
// enable_side_effects v0
// v32 = not v1
// v33 = and v0, v32
// enable_side_effects v33
// call println(Field 6, Field 6)
// v54 = mul v1, Field 5
// v55 = mul v49, Field 2
// v56 = add v54, v55
// v57 = mul v1, Field 5
// v58 = mul v49, Field 6
// v59 = add v57, v58
// call println(Field 7, v59) ; v59 = 5 and 6 merged
// v61 = not v0
// enable_side_effects v0
// v36 = mul v1, Field 5
// v37 = mul v32, Field 2
// v38 = add v36, v37
// v39 = mul v1, Field 5
// v40 = mul v32, Field 6
// v41 = add v39, v40
// call println(Field 7, v42)
// enable_side_effects Field 1
// v43 = not v0
// enable_side_effects v43
// store Field 3 at v2
// call println(Field 3, Field 3)
// call println(Field 8, Field 3) ; block 8 does not store a value
// v66 = mul v0, v59
// v67 = mul v61, Field 1
// v68 = add v66, v67 ; This was from an unused store.
// v69 = mul v0, v59
// v70 = mul v61, Field 3
// v71 = add v69, v70
// call println(Field 9, v71) ; v71 = 3, 5, and 6 merged
// return v71
// call println(Field 8, Field 3)
// enable_side_effects Field 1
// v47 = mul v0, v41
// v48 = mul v43, Field 1
// v49 = add v47, v48
// v50 = mul v0, v44
// v51 = mul v43, Field 3
// v52 = add v50, v51
// call println(Field 9, v53)
// return v54
// }

let main = ssa.main();
Expand Down