Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Various fixes for defunctionalization & brillig gen #1973

Merged
merged 6 commits into from
Jul 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,19 @@ fn main(x: u32) {
assert(entry_point(x) == 2);
swap_entry_point(x, x + 1);
assert(deep_entry_point(x) == 4);
multiple_values_entry_point(x);
}

unconstrained fn returns_multiple_values(x : u32) -> (u32, u32, u32, u32) {
(x + 1, x + 2, x + 3, x + 4)
}

unconstrained fn multiple_values_entry_point(x: u32) {
let (a, b, c, d) = returns_multiple_values(x);
assert(a == x + 1);
assert(b == x + 2);
assert(c == x + 3);
assert(d == x + 4);
}

unconstrained fn inner(x : u32) -> u32 {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,12 @@ struct MyStruct {

fn main(x: u32) {
assert(wrapper(increment, x) == x + 1);
assert(wrapper(increment_acir, x) == x + 1);
assert(wrapper(decrement, x) == x - 1);
assert(wrapper_with_struct(MyStruct { operation: increment }, x) == x + 1);
assert(wrapper_with_struct(MyStruct { operation: decrement }, x) == x - 1);
// https://github.com/noir-lang/noir/issues/1975
assert(increment(x) == x + 1);
}

unconstrained fn wrapper(func: fn (u32) -> u32, param: u32) -> u32 {
Expand All @@ -26,3 +29,8 @@ unconstrained fn wrapper_with_struct(my_struct: MyStruct, param: u32) -> u32 {
func(param)
}



fn increment_acir(x: u32) -> u32 {
x + 1
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
// The features being tested is basic looping on brillig
fn main(sum: u32){
assert(loop(4) == sum);
assert(plain_loop() == sum);
}

unconstrained fn loop(x: u32) -> u32 {
Expand All @@ -12,3 +13,11 @@ unconstrained fn loop(x: u32) -> u32 {
}
sum
}

unconstrained fn plain_loop() -> u32 {
let mut sum = 0;
for i in 0..4 {
sum = sum + i;
}
sum
}
144 changes: 90 additions & 54 deletions crates/noirc_evaluator/src/brillig/brillig_gen/brillig_block.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
use crate::brillig::brillig_gen::brillig_slice_ops::{
convert_array_or_vector_to_vector, slice_push_back_operation,
};
use crate::brillig::brillig_ir::{BrilligBinaryOp, BrilligContext};
use crate::brillig::brillig_ir::{
BrilligBinaryOp, BrilligContext, BRILLIG_INTEGER_ARITHMETIC_BIT_SIZE,
};
use crate::ssa_refactor::ir::function::FunctionId;
use crate::ssa_refactor::ir::instruction::Intrinsic;
use crate::ssa_refactor::ir::{
Expand Down Expand Up @@ -285,45 +287,7 @@
}
}
Value::Function(func_id) => {
let argument_registers: Vec<RegisterIndex> = arguments
.iter()
.flat_map(|arg| {
let arg = self.convert_ssa_value(*arg, dfg);
self.function_context.extract_registers(arg)
})
.collect();
let result_ids = dfg.instruction_results(instruction_id);

// Create label for the function that will be called
let label_of_function_to_call =
FunctionContext::function_id_to_function_label(*func_id);

let saved_registers =
self.brillig_context.pre_call_save_registers_prep_args(&argument_registers);

// Call instruction, which will interpret above registers 0..num args
self.brillig_context.add_external_call_instruction(label_of_function_to_call);

// Important: resolve after pre_call_save_registers_prep_args
// This ensures we don't save the results to registers unnecessarily.
let result_registers: Vec<RegisterIndex> = result_ids
.iter()
.flat_map(|arg| {
let arg = self.function_context.create_variable(
self.brillig_context,
*arg,
dfg,
);
self.function_context.extract_registers(arg)
})
.collect();

assert!(
!saved_registers.iter().any(|x| result_registers.contains(x)),
"should not save registers used as function results"
);
self.brillig_context
.post_call_prep_returns_load_registers(&result_registers, &saved_registers);
self.convert_ssa_function_call(*func_id, arguments, dfg, instruction_id);
}
Value::Intrinsic(Intrinsic::BlackBox(bb_func)) => {
let function_arguments =
Expand Down Expand Up @@ -430,6 +394,59 @@
};
}

fn convert_ssa_function_call(
&mut self,
func_id: FunctionId,
arguments: &[ValueId],
dfg: &DataFlowGraph,
instruction_id: InstructionId,
) {
// Convert the arguments to registers casting those to the types of the receiving function
let argument_registers: Vec<RegisterIndex> = arguments
.iter()
.flat_map(|argument_id| {
let variable_to_pass = self.convert_ssa_value(*argument_id, dfg);
self.function_context.extract_registers(variable_to_pass)
})
.collect();

let result_ids = dfg.instruction_results(instruction_id);

// Create label for the function that will be called
let label_of_function_to_call = FunctionContext::function_id_to_function_label(func_id);

let saved_registers =
self.brillig_context.pre_call_save_registers_prep_args(&argument_registers);

// Call instruction, which will interpret above registers 0..num args
self.brillig_context.add_external_call_instruction(label_of_function_to_call);

// Important: resolve after pre_call_save_registers_prep_args
// This ensures we don't save the results to registers unnecessarily.

// Allocate the registers for the variables where we are assigning the returns
let variables_assigned_to = vecmap(result_ids, |result_id| {
self.function_context.create_variable(self.brillig_context, *result_id, dfg)
});

// Collect the registers that should have been returned
let returned_registers: Vec<RegisterIndex> = variables_assigned_to
.iter()
.flat_map(|returned_variable| {
self.function_context.extract_registers(*returned_variable)
})
.collect();

assert!(
!saved_registers.iter().any(|x| returned_registers.contains(x)),
"should not save registers used as function results"
);

// puts the returns into the returned_registers and restores saved_registers
self.brillig_context
.post_call_prep_returns_load_registers(&returned_registers, &saved_registers);
}

/// Array set operation in SSA returns a new array or slice that is a copy of the parameter array or slice
/// With a specific value changed.
fn convert_ssa_array_set(
Expand Down Expand Up @@ -642,7 +659,7 @@
);
}
}
_ => unimplemented!("ICE: Value {:?} not storeable in memory", value),

Check warning on line 662 in crates/noirc_evaluator/src/brillig/brillig_gen/brillig_block.rs

View workflow job for this annotation

GitHub Actions / Spellcheck / Spellcheck

Unknown word (storeable)
}
}

Expand Down Expand Up @@ -697,11 +714,28 @@
let binary_type =
type_of_binary_operation(dfg[binary.lhs].get_type(), dfg[binary.rhs].get_type());

let left = self.convert_ssa_register_value(binary.lhs, dfg);
let right = self.convert_ssa_register_value(binary.rhs, dfg);
let mut left = self.convert_ssa_register_value(binary.lhs, dfg);
let mut right = self.convert_ssa_register_value(binary.rhs, dfg);

let brillig_binary_op =
convert_ssa_binary_op_to_brillig_binary_op(binary.operator, binary_type);
convert_ssa_binary_op_to_brillig_binary_op(binary.operator, &binary_type);

// Some binary operations with fields are issued by the compiler, such as loop comparisons, cast those to the bit size here
jfecher marked this conversation as resolved.
Show resolved Hide resolved
// TODO Remove after fixing https://github.com/noir-lang/noir/issues/1979
if let (
BrilligBinaryOp::Integer { bit_size, .. },
Type::Numeric(NumericType::NativeField),
) = (&brillig_binary_op, &binary_type)
{
let new_lhs = self.brillig_context.allocate_register();
let new_rhs = self.brillig_context.allocate_register();

self.brillig_context.cast_instruction(new_lhs, left, *bit_size);
self.brillig_context.cast_instruction(new_rhs, right, *bit_size);

left = new_lhs;
right = new_rhs;
}

self.brillig_context.binary_instruction(left, right, result_register, brillig_binary_op);
}
Expand Down Expand Up @@ -830,7 +864,7 @@
/// - Brillig Binary Field Op, if it is a field type
pub(crate) fn convert_ssa_binary_op_to_brillig_binary_op(
ssa_op: BinaryOp,
typ: Type,
typ: &Type,
) -> BrilligBinaryOp {
// First get the bit size and whether its a signed integer, if it is a numeric type
// if it is not,then we return None, indicating that
Expand All @@ -845,18 +879,20 @@
};

fn binary_op_to_field_op(op: BinaryOp) -> BrilligBinaryOp {
let operation = match op {
BinaryOp::Add => BinaryFieldOp::Add,
BinaryOp::Sub => BinaryFieldOp::Sub,
BinaryOp::Mul => BinaryFieldOp::Mul,
BinaryOp::Div => BinaryFieldOp::Div,
BinaryOp::Eq => BinaryFieldOp::Equals,
match op {
BinaryOp::Add => BrilligBinaryOp::Field { op: BinaryFieldOp::Add },
BinaryOp::Sub => BrilligBinaryOp::Field { op: BinaryFieldOp::Sub },
BinaryOp::Mul => BrilligBinaryOp::Field { op: BinaryFieldOp::Mul },
BinaryOp::Div => BrilligBinaryOp::Field { op: BinaryFieldOp::Div },
BinaryOp::Eq => BrilligBinaryOp::Field { op: BinaryFieldOp::Equals },
BinaryOp::Lt => BrilligBinaryOp::Integer {
op: BinaryIntOp::LessThan,
bit_size: BRILLIG_INTEGER_ARITHMETIC_BIT_SIZE,
},
jfecher marked this conversation as resolved.
Show resolved Hide resolved
_ => unreachable!(
"Field type cannot be used with {op}. This should have been caught by the frontend"
),
};

BrilligBinaryOp::Field { op: operation }
}
}

fn binary_op_to_int_op(op: BinaryOp, bit_size: u32, is_signed: bool) -> BrilligBinaryOp {
Expand Down Expand Up @@ -888,7 +924,7 @@

// If bit size is available then it is a binary integer operation
match bit_size_signedness {
Some((bit_size, is_signed)) => binary_op_to_int_op(ssa_op, bit_size, is_signed),
Some((bit_size, is_signed)) => binary_op_to_int_op(ssa_op, *bit_size, is_signed),
None => binary_op_to_field_op(ssa_op),
}
}
10 changes: 7 additions & 3 deletions crates/noirc_evaluator/src/brillig/brillig_ir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -394,6 +394,7 @@ impl BrilligContext {
.collect();
for (new_source, destination) in new_sources.iter().zip(destinations.iter()) {
self.mov_instruction(*destination, *new_source);
self.deallocate_register(*new_source);
}
}

Expand Down Expand Up @@ -821,9 +822,12 @@ impl BrilligContext {
) {
// Allocate our result registers and write into them
// We assume the return values of our call are held in 0..num results register indices
for (i, result_register) in result_registers.iter().enumerate() {
self.mov_instruction(*result_register, self.register(i));
}
let (sources, destinations) = result_registers
.iter()
.enumerate()
.map(|(i, result_register)| (self.register(i), *result_register))
.unzip();
self.mov_registers_to_registers_instruction(sources, destinations);

// Restore all the same registers we have, in exact reverse order.
// Note that we have allocated some registers above, which we will not be handling here,
Expand Down
10 changes: 9 additions & 1 deletion crates/noirc_evaluator/src/ssa_refactor/ir/function.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
use std::collections::HashSet;

use iter_extended::vecmap;

use super::basic_block::BasicBlockId;
use super::dfg::DataFlowGraph;
use super::instruction::TerminatorInstruction;
Expand Down Expand Up @@ -116,6 +118,12 @@ impl Function {
}
blocks
}

pub(crate) fn signature(&self) -> Signature {
let params = vecmap(self.parameters(), |param| self.dfg.type_of_value(*param));
let returns = vecmap(self.returns(), |ret| self.dfg.type_of_value(*ret));
Signature { params, returns }
}
}

impl std::fmt::Display for RuntimeType {
Expand All @@ -133,7 +141,7 @@ impl std::fmt::Display for RuntimeType {
/// within Call instructions.
pub(crate) type FunctionId = Id<Function>;

#[derive(Debug, Default, Clone)]
#[derive(Debug, Default, Clone, PartialEq, Eq, Hash)]
pub(crate) struct Signature {
pub(crate) params: Vec<Type>,
pub(crate) returns: Vec<Type>,
Expand Down
Loading