Skip to content

Commit

Permalink
feat: Allow slices to brillig entry points (#4713)
Browse files Browse the repository at this point in the history
# Description

## Problem\*

Allows passing slices to brillig entry points, since ACIR knows the
initial length of the slice and is fixed at compile time.

## Summary\*



## Additional Context



## Documentation\*

Check one:
- [x] No documentation needed.
- [ ] Documentation included in this PR.
- [ ] **[For Experimental Features]** Documentation to be submitted in a
separate PR.

# PR Checklist\*

- [x] I have tested the changes locally.
- [x] I have formatted the changes with [Prettier](https://prettier.io/)
and/or `cargo fmt` on default settings.
  • Loading branch information
sirasistant authored Apr 4, 2024
1 parent 68f9eeb commit 62423d5
Show file tree
Hide file tree
Showing 6 changed files with 139 additions and 37 deletions.
19 changes: 3 additions & 16 deletions compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ impl FunctionContext {
function_id.to_string()
}

fn ssa_type_to_parameter(typ: &Type) -> BrilligParameter {
pub(crate) fn ssa_type_to_parameter(typ: &Type) -> BrilligParameter {
match typ {
Type::Numeric(_) | Type::Reference(_) => {
BrilligParameter::SingleAddr(get_bit_size_from_ssa_type(typ))
Expand All @@ -81,26 +81,13 @@ impl FunctionContext {
}),
*size,
),
Type::Slice(item_type) => {
BrilligParameter::Slice(vecmap(item_type.iter(), |item_typ| {
FunctionContext::ssa_type_to_parameter(item_typ)
}))
Type::Slice(_) => {
panic!("ICE: Slice parameters cannot be derived from type information")
}
_ => unimplemented!("Unsupported function parameter/return type {typ:?}"),
}
}

/// Collects the parameters of a given function
pub(crate) fn parameters(func: &Function) -> Vec<BrilligParameter> {
func.parameters()
.iter()
.map(|&value_id| {
let typ = func.dfg.type_of_value(value_id);
FunctionContext::ssa_type_to_parameter(&typ)
})
.collect()
}

/// Collects the return values of a given function
pub(crate) fn return_values(func: &Function) -> Vec<BrilligParameter> {
func.returns()
Expand Down
5 changes: 3 additions & 2 deletions compiler/noirc_evaluator/src/brillig/brillig_ir/artifact.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,16 @@ use std::collections::{BTreeMap, HashMap};

use crate::ssa::ir::dfg::CallStack;

/// Represents a parameter or a return value of a function.
/// Represents a parameter or a return value of an entry point function.
#[derive(Debug, Clone)]
pub(crate) enum BrilligParameter {
/// A single address parameter or return value. Holds the bit size of the parameter.
SingleAddr(u32),
/// An array parameter or return value. Holds the type of an array item and its size.
Array(Vec<BrilligParameter>, usize),
/// A slice parameter or return value. Holds the type of a slice item.
Slice(Vec<BrilligParameter>),
/// Only known-length slices can be passed to brillig entry points, so the size is available as well.
Slice(Vec<BrilligParameter>, usize),
}

/// The result of compiling and linking brillig artifacts.
Expand Down
60 changes: 45 additions & 15 deletions compiler/noirc_evaluator/src/brillig/brillig_ir/entry_point.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use super::{
artifact::{BrilligArtifact, BrilligParameter},
brillig_variable::{BrilligArray, BrilligVariable, SingleAddrVariable},
brillig_variable::{BrilligArray, BrilligVariable, BrilligVector, SingleAddrVariable},
debug_show::DebugShow,
registers::BrilligRegistersContext,
BrilligBinaryOp, BrilligContext, ReservedRegisters,
Expand Down Expand Up @@ -83,24 +83,56 @@ impl BrilligContext {
current_calldata_pointer += flattened_size;
var
}
BrilligParameter::Slice(_) => unimplemented!("Unsupported slices as parameter"),
BrilligParameter::Slice(_, _) => {
let pointer_to_the_array_in_calldata =
self.make_usize_constant_instruction(current_calldata_pointer.into());

let flattened_size = BrilligContext::flattened_size(argument);
let size_register = self.make_usize_constant_instruction(flattened_size.into());
let rc_register = self.make_usize_constant_instruction(1_usize.into());

let var = BrilligVariable::BrilligVector(BrilligVector {
pointer: pointer_to_the_array_in_calldata.address,
size: size_register.address,
rc: rc_register.address,
});

current_calldata_pointer += flattened_size;
var
}
})
.collect();

// Deflatten arrays
for (argument_variable, argument) in argument_variables.iter_mut().zip(arguments) {
if let (
BrilligVariable::BrilligArray(array),
BrilligParameter::Array(item_type, item_count),
) = (argument_variable, argument)
{
if BrilligContext::has_nested_arrays(item_type) {
match (argument_variable, argument) {
(
BrilligVariable::BrilligArray(array),
BrilligParameter::Array(item_type, item_count),
) => {
let deflattened_address =
self.deflatten_array(item_type, array.size, array.pointer);
self.mov_instruction(array.pointer, deflattened_address);
array.size = item_type.len() * item_count;
self.deallocate_register(deflattened_address);
}
(
BrilligVariable::BrilligVector(vector),
BrilligParameter::Slice(item_type, item_count),
) => {
let flattened_size = BrilligContext::flattened_size(argument);

let deflattened_address =
self.deflatten_array(item_type, flattened_size, vector.pointer);
self.mov_instruction(vector.pointer, deflattened_address);
self.usize_const_instruction(
vector.size,
(item_type.len() * item_count).into(),
);

self.deallocate_register(deflattened_address);
}
_ => {}
}
}
}
Expand All @@ -112,10 +144,10 @@ impl BrilligContext {
fn flat_bit_sizes(param: &BrilligParameter) -> Box<dyn Iterator<Item = u32> + '_> {
match param {
BrilligParameter::SingleAddr(bit_size) => Box::new(std::iter::once(*bit_size)),
BrilligParameter::Array(item_types, item_count) => Box::new(
BrilligParameter::Array(item_types, item_count)
| BrilligParameter::Slice(item_types, item_count) => Box::new(
(0..*item_count).flat_map(move |_| item_types.iter().flat_map(flat_bit_sizes)),
),
BrilligParameter::Slice(..) => unimplemented!("Unsupported slices as parameter"),
}
}

Expand All @@ -134,13 +166,11 @@ impl BrilligContext {
fn flattened_size(param: &BrilligParameter) -> usize {
match param {
BrilligParameter::SingleAddr(_) => 1,
BrilligParameter::Array(item_types, item_count) => {
BrilligParameter::Array(item_types, item_count)
| BrilligParameter::Slice(item_types, item_count) => {
let item_size: usize = item_types.iter().map(BrilligContext::flattened_size).sum();
item_count * item_size
}
BrilligParameter::Slice(_) => {
unreachable!("ICE: Slices cannot be passed as entry point arguments")
}
}
}

Expand Down Expand Up @@ -457,8 +487,8 @@ mod tests {
use acvm::FieldElement;

use crate::brillig::brillig_ir::{
artifact::BrilligParameter,
brillig_variable::BrilligArray,
entry_point::BrilligParameter,
tests::{create_and_run_vm, create_context, create_entry_point_bytecode},
};

Expand Down
46 changes: 42 additions & 4 deletions compiler/noirc_evaluator/src/ssa/acir_gen/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use super::{
},
ssa_gen::Ssa,
};
use crate::brillig::brillig_ir::artifact::GeneratedBrillig;
use crate::brillig::brillig_ir::artifact::{BrilligParameter, GeneratedBrillig};
use crate::brillig::brillig_ir::BrilligContext;
use crate::brillig::{brillig_gen::brillig_fn::FunctionContext as BrilligFunctionContext, Brillig};
use crate::errors::{InternalError, InternalWarning, RuntimeError, SsaReport};
Expand Down Expand Up @@ -297,12 +297,14 @@ impl Context {
let typ = dfg.type_of_value(*param_id);
self.create_value_from_type(&typ, &mut |this, _| Ok(this.acir_context.add_variable()))
})?;
let arguments = self.gen_brillig_parameters(dfg[main_func.entry_block()].parameters(), dfg);

let witness_inputs = self.acir_context.extract_witness(&inputs);

let outputs: Vec<AcirType> =
vecmap(main_func.returns(), |result_id| dfg.type_of_value(*result_id).into());

let code = self.gen_brillig_for(main_func, brillig)?;
let code = self.gen_brillig_for(main_func, arguments, brillig)?;

// We specifically do not attempt execution of the brillig code being generated as this can result in it being
// replaced with constraints on witnesses to the program outputs.
Expand Down Expand Up @@ -594,8 +596,9 @@ impl Context {
}

let inputs = vecmap(arguments, |arg| self.convert_value(*arg, dfg));
let arguments = self.gen_brillig_parameters(arguments, dfg);

let code = self.gen_brillig_for(func, brillig)?;
let code = self.gen_brillig_for(func, arguments, brillig)?;

let outputs: Vec<AcirType> = vecmap(result_ids, |result_id| {
dfg.type_of_value(*result_id).into()
Expand Down Expand Up @@ -673,14 +676,49 @@ impl Context {
Ok(())
}

fn gen_brillig_parameters(
&self,
values: &[ValueId],
dfg: &DataFlowGraph,
) -> Vec<BrilligParameter> {
values
.iter()
.map(|&value_id| {
let typ = dfg.type_of_value(value_id);
if let Type::Slice(item_types) = typ {
let len = match self
.ssa_values
.get(&value_id)
.expect("ICE: Unknown slice input to brillig")
{
AcirValue::DynamicArray(AcirDynamicArray { len, .. }) => *len,
AcirValue::Array(array) => array.len(),
_ => unreachable!("ICE: Slice value is not an array"),
};

BrilligParameter::Slice(
item_types
.iter()
.map(BrilligFunctionContext::ssa_type_to_parameter)
.collect(),
len / item_types.len(),
)
} else {
BrilligFunctionContext::ssa_type_to_parameter(&typ)
}
})
.collect()
}

fn gen_brillig_for(
&self,
func: &Function,
arguments: Vec<BrilligParameter>,
brillig: &Brillig,
) -> Result<GeneratedBrillig, InternalError> {
// Create the entry point artifact
let mut entry_point = BrilligContext::new_entry_point_artifact(
BrilligFunctionContext::parameters(func),
arguments,
BrilligFunctionContext::return_values(func),
BrilligFunctionContext::function_id_to_function_label(func.id()),
);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
[package]
name = "brillig_slice_input"
type = "bin"
authors = [""]

[dependencies]
40 changes: 40 additions & 0 deletions test_programs/execution_success/brillig_slice_input/src/main.nr
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
struct Point {
x: Field,
y: Field,
}

unconstrained fn sum_slice(slice: [[Point; 2]]) -> Field {
let mut sum = 0;
for i in 0..slice.len() {
for j in 0..slice[i].len() {
sum += slice[i][j].x + slice[i][j].y;
}
}
sum
}

fn main() {
let mut slice = &[];
slice = slice.push_back([
Point {
x: 13,
y: 14,
},
Point {
x: 20,
y: 8,
}
]);
slice = slice.push_back([
Point {
x: 15,
y: 5,
},
Point {
x: 12,
y: 13,
}
]);
let brillig_sum = sum_slice(slice);
assert_eq(brillig_sum, 100);
}

0 comments on commit 62423d5

Please sign in to comment.