Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(ssa): array sort #754

Merged
merged 10 commits into from
Feb 13, 2023
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions crates/nargo/tests/test_data/array_len/src/main.nr
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,12 @@ fn main(len3: [u8; 3], len4: [Field; 4]) {

// std::array::len returns a comptime value
constrain len4[std::array::len(len3)] == 4;

// test for std::array::sort
let mut unsorted = len3;
unsorted[0] = len3[1];
unsorted[1] = len3[0];
constrain unsorted[0] > unsorted[1];
let sorted = std::array::sort(unsorted);
constrain sorted[0] < sorted[1];
}
3 changes: 3 additions & 0 deletions crates/noirc_evaluator/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,6 @@ iter-extended.workspace = true
thiserror.workspace = true
num-bigint = "0.4"
num-traits = "0.2.8"

[dev-dependencies]
rand="0.8.5"
199 changes: 199 additions & 0 deletions crates/noirc_evaluator/src/ssa/acir_gen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -664,6 +664,48 @@ impl Acir {
};
evaluator.opcodes.push(AcirOpcode::BlackBoxFuncCall(call_gate));
}
Opcode::Sort => {
let mut in_expr = Vec::new();
let array_id = Memory::deref(ctx, args[0]).unwrap();
let array = &ctx.mem[array_id];
let num_bits = array.element_type.bits();
for i in 0..array.len {
let address = array.adr + i;
if self.memory_map.contains_key(&address) {
if let Some(wit) = self.memory_map[&address].witness {
in_expr.push(from_witness(wit))
} else {
in_expr.push(self.memory_map[&address].expression.clone());
}
} else {
in_expr.push(from_witness(array.values[i as usize].witness.unwrap()));
}
}
outputs = self.prepare_outputs(instruction_id, array.len, ctx, evaluator);
let out_expr: Vec<Expression> = outputs.iter().map(|w| from_witness(*w)).collect();
for i in 0..(out_expr.len() - 1) {
bound_constraint_with_offset(
&out_expr[i],
&out_expr[i + 1],
&Expression::zero(),
num_bits,
evaluator,
);
}
let bits = evaluate_permutation(&in_expr, &out_expr, evaluator);
let inputs = in_expr.iter().map(|a| vec![a.clone()]).collect();
evaluator.opcodes.push(AcirOpcode::Directive(Directive::PermutationSort {
inputs,
tuple: 1,
bits,
sort_by: vec![0],
}));
if let node::ObjectType::Pointer(a) = res_type {
self.map_array(a, &outputs, ctx);
} else {
unreachable!();
}
}
}

if outputs.len() == 1 {
Expand Down Expand Up @@ -1460,3 +1502,160 @@ pub fn from_witness(witness: Witness) -> Expression {
q_c: FieldElement::zero(),
}
}

// Generate gates which ensure that out_expr is a permutation of in_expr
// Returns the control bits of the sorting network used to generate the constrains
pub fn evaluate_permutation(
in_expr: &Vec<Expression>,
out_expr: &Vec<Expression>,
evaluator: &mut Evaluator,
) -> Vec<Witness> {
let (w, b) = permutation_layer(in_expr, evaluator);
// we contrain the network output to out_expr
for (b, o) in b.iter().zip(out_expr) {
evaluator.opcodes.push(AcirOpcode::Arithmetic(subtract(b, FieldElement::one(), o)));
}
w
}

// Generates gates for a sorting network
// returns witness corresponding to the network configuration and the expressions corresponding to the network output
// in_expr: inputs of the sorting network
pub fn permutation_layer(
in_expr: &Vec<Expression>,
evaluator: &mut Evaluator,
) -> (Vec<Witness>, Vec<Expression>) {
let n = in_expr.len();
if n == 1 {
return (Vec::new(), in_expr.clone());
}
let n1 = n / 2;
let mut conf = Vec::new();
// witness for the input switches
for _ in 0..n1 {
conf.push(evaluator.add_witness_to_cs());
}
// compute expressions after the input switches
// If inputs are a1,a2, and the switch value is c, then we compute expresions b1,b2 where
// b1 = a1+q, b2 = a2-q, q = c(a2-a1)
let mut in_sub1 = Vec::new();
let mut in_sub2 = Vec::new();
for i in 0..n1 {
//q = c*(a2-a1);
let intermediate = mul_with_witness(
evaluator,
&from_witness(conf[i]),
&subtract(&in_expr[2 * i + 1], FieldElement::one(), &in_expr[2 * i]),
);
//b1=a1+q
in_sub1.push(add(&intermediate, FieldElement::one(), &in_expr[2 * i]));
//b2=a2-q
in_sub2.push(subtract(&in_expr[2 * i + 1], FieldElement::one(), &intermediate));
}
if n % 2 == 1 {
in_sub2.push(in_expr.last().unwrap().clone());
}
let mut out_expr = Vec::new();
// compute results for the sub networks
let (w1, b1) = permutation_layer(&in_sub1, evaluator);
let (w2, b2) = permutation_layer(&in_sub2, evaluator);
// apply the output swithces
for i in 0..(n - 1) / 2 {
let c = evaluator.add_witness_to_cs();
conf.push(c);
let intermediate = mul_with_witness(
evaluator,
&from_witness(c),
&subtract(&b2[i], FieldElement::one(), &b1[i]),
);
out_expr.push(add(&intermediate, FieldElement::one(), &b1[i]));
out_expr.push(subtract(&b2[i], FieldElement::one(), &intermediate));
}
if n % 2 == 0 {
out_expr.push(b1.last().unwrap().clone());
}
out_expr.push(b2.last().unwrap().clone());
conf.extend(w1);
conf.extend(w2);
(conf, out_expr)
}

#[cfg(test)]
mod test {
use std::collections::BTreeMap;

use acvm::{
acir::{circuit::opcodes::BlackBoxFuncCall, native_types::Witness},
FieldElement, OpcodeResolutionError, PartialWitnessGenerator,
};

use crate::{ssa::acir_gen::evaluate_permutation, Evaluator};
use rand::prelude::*;

use super::from_witness;

struct MockBackend {}
impl PartialWitnessGenerator for MockBackend {
fn solve_black_box_function_call(
_initial_witness: &mut BTreeMap<Witness, FieldElement>,
_func_call: &BlackBoxFuncCall,
) -> Result<(), OpcodeResolutionError> {
unreachable!();
}
}

// Check that a random network constrains its output to be a permutation of any random input
#[test]
fn test_permutation() {
let mut rng = rand::thread_rng();
for n in 2..50 {
let mut eval = Evaluator {
current_witness_index: 0,
public_inputs: Vec::new(),
opcodes: Vec::new(),
};

//we generate random inputs
let mut input = Vec::new();
let mut a_val = Vec::new();
let mut b_wit = Vec::new();
let mut solved_witness: BTreeMap<Witness, FieldElement> = BTreeMap::new();
for i in 0..n {
let w = eval.add_witness_to_cs();
input.push(from_witness(w));
a_val.push(FieldElement::from(rng.next_u32() as i128));
solved_witness.insert(w, a_val[i]);
}

let mut output = Vec::new();
for _i in 0..n {
let w = eval.add_witness_to_cs();
b_wit.push(w);
output.push(from_witness(w));
}
//generate constraints for the inputs
let w = evaluate_permutation(&input, &output, &mut eval);

//we generate random network
let mut c = Vec::new();
for _i in 0..w.len() {
c.push(rng.next_u32() % 2 != 0);
}
// intialise bits
for i in 0..w.len() {
solved_witness.insert(w[i], FieldElement::from(c[i] as i128));
}
// compute the network output by solving the constraints
let backend = MockBackend {};
backend
.solve(&mut solved_witness, eval.opcodes.clone())
.expect("Could not solve permutation constraints");
let mut b_val = Vec::new();
for i in 0..output.len() {
b_val.push(solved_witness[&b_wit[i]]);
}
// ensure the outputs are a permutation of the inputs
assert_eq!(a_val.sort(), b_val.sort());
}
}
}
16 changes: 13 additions & 3 deletions crates/noirc_evaluator/src/ssa/builtin.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
use crate::ssa::node::ObjectType;
use crate::ssa::{
context::SsaContext,
node::{NodeId, ObjectType},
};
use acvm::{acir::BlackBoxFunc, FieldElement};
use num_bigint::BigUint;
use num_traits::{One, Zero};
Expand All @@ -8,6 +11,7 @@ pub enum Opcode {
LowLevel(BlackBoxFunc),
ToBits,
ToRadix,
Sort,
}

impl std::fmt::Display for Opcode {
Expand All @@ -21,6 +25,7 @@ impl Opcode {
match op_name {
"to_le_bits" => Some(Opcode::ToBits),
"to_radix" => Some(Opcode::ToRadix),
"arraysort" => Some(Opcode::Sort),
_ => BlackBoxFunc::lookup(op_name).map(Opcode::LowLevel),
}
}
Expand All @@ -30,6 +35,7 @@ impl Opcode {
Opcode::LowLevel(op) => op.name(),
Opcode::ToBits => "to_le_bits",
Opcode::ToRadix => "to_radix",
Opcode::Sort => "arraysort",
}
}

Expand All @@ -48,12 +54,12 @@ impl Opcode {
_ => todo!("max value must be implemented for opcode {} ", op),
}
}
Opcode::ToBits | Opcode::ToRadix => BigUint::zero(), //pointers do not overflow
Opcode::ToBits | Opcode::ToRadix | Opcode::Sort => BigUint::zero(), //pointers do not overflow
}
}

//Returns the number of elements and their type, of the output result corresponding to the OPCODE function.
pub fn get_result_type(&self) -> (u32, ObjectType) {
pub fn get_result_type(&self, args: &Vec<NodeId>, ctx: &SsaContext) -> (u32, ObjectType) {
match self {
Opcode::LowLevel(op) => {
match op {
Expand All @@ -73,6 +79,10 @@ impl Opcode {
}
Opcode::ToBits => (FieldElement::max_num_bits(), ObjectType::Boolean),
Opcode::ToRadix => (FieldElement::max_num_bits(), ObjectType::NativeField),
Opcode::Sort => {
let a = super::mem::Memory::deref(ctx, args[0]).unwrap();
(ctx.mem[a].len, ctx.mem[a].element_type)
}
}
}
}
3 changes: 1 addition & 2 deletions crates/noirc_evaluator/src/ssa/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,7 @@ impl IRGenerator {
op: builtin::Opcode,
args: Vec<NodeId>,
) -> Result<Vec<NodeId>, RuntimeError> {
let (len, elem_type) = op.get_result_type();
let (len, elem_type) = op.get_result_type(&args, &self.context);

let result_type = if len > 1 {
//We create an array that will contain the result and set the res_type to point to that array
Expand All @@ -342,7 +342,6 @@ impl IRGenerator {
} else {
elem_type
};

//when the function returns an array, we use ins.res_type(array)
//else we map ins.id to the returned witness
let id = self.context.new_instruction(node::Operation::Intrinsic(op, args), result_type)?;
Expand Down
10 changes: 10 additions & 0 deletions crates/noirc_frontend/src/hir/type_check/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -818,6 +818,16 @@ pub fn comparator_operand_type_rules(

Ok(Bool(CompTime::No(Some(op.location.span))))
}
(TypeVariable(l_binding), TypeVariable(r_binding)) => {
if let TypeBinding::Bound(l_link) = &*l_binding.borrow() {
if let TypeBinding::Bound(r_link) = &*r_binding.borrow() {
return comparator_operand_type_rules(l_link, r_link, op, errors);
jfecher marked this conversation as resolved.
Show resolved Hide resolved
}
}
let l_typ = TypeVariable(l_binding.clone());
let r_typ = TypeVariable(r_binding.clone());
Err(format!("Unsupported types for comparison: {l_typ} and {r_typ}"))
}
(lhs, rhs) => Err(format!("Unsupported types for comparison: {lhs} and {rhs}")),
}
}
15 changes: 2 additions & 13 deletions noir_stdlib/src/array.nr
Original file line number Diff line number Diff line change
@@ -1,16 +1,5 @@
#[builtin(array_len)]
fn len<T>(_input : [T]) -> comptime Field {}

// insertion sort - n.b. it is a quadratic sort
fn sort<T>(mut a: [T]) -> [T] {
for i in 1..len(a) {
for j in 0..i {
if(a[i] < a[j]) {
let c = a[j];
a[j] = a[i];
a[i]= c;
}
};
};
a
}
#[builtin(arraysort)]
fn sort<T>(_a: [T]) -> [T] {}
guipublic marked this conversation as resolved.
Show resolved Hide resolved
guipublic marked this conversation as resolved.
Show resolved Hide resolved