From a9bb83372149c49ce125ca4623b7b38ab92a55af Mon Sep 17 00:00:00 2001 From: guipublic Date: Mon, 6 Feb 2023 17:42:45 +0000 Subject: [PATCH 1/5] Constrains for sorting network --- .../tests/test_data/array_len/src/main.nr | 8 + crates/noirc_driver/src/lib.rs | 2 +- crates/noirc_evaluator/Cargo.toml | 4 + crates/noirc_evaluator/src/ssa/acir_gen.rs | 199 ++++++++++++++++++ crates/noirc_evaluator/src/ssa/builtin.rs | 16 +- crates/noirc_evaluator/src/ssa/function.rs | 3 +- .../noirc_frontend/src/hir/type_check/expr.rs | 10 + crates/noirc_frontend/src/node_interner.rs | 2 +- noir_stdlib/src/array.nr | 15 +- 9 files changed, 239 insertions(+), 20 deletions(-) diff --git a/crates/nargo/tests/test_data/array_len/src/main.nr b/crates/nargo/tests/test_data/array_len/src/main.nr index 45c474f9dca..2531bbc42fb 100644 --- a/crates/nargo/tests/test_data/array_len/src/main.nr +++ b/crates/nargo/tests/test_data/array_len/src/main.nr @@ -20,4 +20,12 @@ fn main(len3: [u8; 3], len4: [Field; 4]) { // std::array::len returns a comptime value constrain len4[std::array::len(len3)] == 4; + + // test for std::array::sort + let mut unsorted = len3; + unsorted[0] = len3[1]; + unsorted[1] = len3[0]; + constrain unsorted[0] > unsorted[1]; + let sorted = std::array::sort(unsorted); + constrain sorted[0] < sorted[1]; } diff --git a/crates/noirc_driver/src/lib.rs b/crates/noirc_driver/src/lib.rs index 0f97ab351ce..57b448e7d38 100644 --- a/crates/noirc_driver/src/lib.rs +++ b/crates/noirc_driver/src/lib.rs @@ -193,7 +193,7 @@ impl Driver { let circuit = match create_circuit( program, np_language.clone(), - acvm::default_is_blackbox_supported(np_language), + acvm::default_is_black_box_supported(np_language), show_ssa, ) { Ok(circuit) => circuit, diff --git a/crates/noirc_evaluator/Cargo.toml b/crates/noirc_evaluator/Cargo.toml index 4f1e10ea60b..75ca9a7fc66 100644 --- a/crates/noirc_evaluator/Cargo.toml +++ b/crates/noirc_evaluator/Cargo.toml @@ -18,3 +18,7 @@ lazy_static.workspace = true thiserror.workspace = true num-bigint.workspace = true num-traits.workspace = true + + +[dev-dependencies] +rand="0.8.5" \ No newline at end of file diff --git a/crates/noirc_evaluator/src/ssa/acir_gen.rs b/crates/noirc_evaluator/src/ssa/acir_gen.rs index 8858d048032..18c7252919c 100644 --- a/crates/noirc_evaluator/src/ssa/acir_gen.rs +++ b/crates/noirc_evaluator/src/ssa/acir_gen.rs @@ -596,6 +596,48 @@ impl Acir { }; evaluator.opcodes.push(AcirOpcode::BlackBoxFuncCall(call_gate)); } + Opcode::Sort => { + let mut in_expr = Vec::new(); + let array_id = Memory::deref(ctx, args[0]).unwrap(); + let array = &ctx.mem[array_id]; + let num_bits = array.element_type.bits(); + for i in 0..array.len { + let address = array.adr + i; + if self.memory_map.contains_key(&address) { + if let Some(wit) = self.memory_map[&address].witness { + in_expr.push(from_witness(wit)) + } else { + in_expr.push(self.memory_map[&address].expression.clone()); + } + } else { + in_expr.push(from_witness(array.values[i as usize].witness.unwrap())); + } + } + outputs = self.prepare_outputs(instruction_id, array.len, ctx, evaluator); + let out_expr: Vec = outputs.iter().map(|w| from_witness(*w)).collect(); + for i in 0..(out_expr.len() - 1) { + bound_constraint_with_offset( + &out_expr[i], + &out_expr[i + 1], + &Expression::zero(), + num_bits, + evaluator, + ); + } + let bits = evaluate_permutation(&in_expr, &out_expr, evaluator); + let inputs = in_expr.iter().map(|a| vec![a.clone()]).collect(); + evaluator.opcodes.push(AcirOpcode::Directive(Directive::PermutationSort { + inputs, + tuple: 1, + bits, + sort_by: vec![0], + })); + if let node::ObjectType::Pointer(a) = res_type { + self.map_array(a, &outputs, ctx); + } else { + unreachable!(); + } + } } if outputs.len() == 1 { @@ -1428,3 +1470,160 @@ pub fn from_witness(witness: Witness) -> Expression { q_c: FieldElement::zero(), } } + +// Generate gates which ensure that out_expr is a permutation of in_expr +// Returns the control bits of the sorting network used to generate the constrains +pub fn evaluate_permutation( + in_expr: &Vec, + out_expr: &Vec, + evaluator: &mut Evaluator, +) -> Vec { + let (w, b) = permutation_layer(in_expr, evaluator); + // we contrain the network output to out_expr + for (b, o) in b.iter().zip(out_expr) { + evaluator.opcodes.push(AcirOpcode::Arithmetic(subtract(b, FieldElement::one(), o))); + } + w +} + +// Generates gates for a sorting network +// returns witness corresponding to the network configuration and the expressions corresponding to the network output +// in_expr: inputs of the sorting network +pub fn permutation_layer( + in_expr: &Vec, + evaluator: &mut Evaluator, +) -> (Vec, Vec) { + let n = in_expr.len(); + if n == 1 { + return (Vec::new(), in_expr.clone()); + } + let n1 = n / 2; + let mut conf = Vec::new(); + // witness for the input switches + for _ in 0..n1 { + conf.push(evaluator.add_witness_to_cs()); + } + // compute expressions after the input switches + // If inputs are a1,a2, and the switch value is c, then we compute expresions b1,b2 where + // b1 = a1+q, b2 = a2-q, q = c(a2-a1) + let mut in_sub1 = Vec::new(); + let mut in_sub2 = Vec::new(); + for i in 0..n1 { + //q = c*(a2-a1); + let intermediate = mul_with_witness( + evaluator, + &from_witness(conf[i]), + &subtract(&in_expr[2 * i + 1], FieldElement::one(), &in_expr[2 * i]), + ); + //b1=a1+q + in_sub1.push(add(&intermediate, FieldElement::one(), &in_expr[2 * i])); + //b2=a2-q + in_sub2.push(subtract(&in_expr[2 * i + 1], FieldElement::one(), &intermediate)); + } + if n % 2 == 1 { + in_sub2.push(in_expr.last().unwrap().clone()); + } + let mut out_expr = Vec::new(); + // compute results for the sub networks + let (w1, b1) = permutation_layer(&in_sub1, evaluator); + let (w2, b2) = permutation_layer(&in_sub2, evaluator); + // apply the output swithces + for i in 0..(n - 1) / 2 { + let c = evaluator.add_witness_to_cs(); + conf.push(c); + let intermediate = mul_with_witness( + evaluator, + &from_witness(c), + &subtract(&b2[i], FieldElement::one(), &b1[i]), + ); + out_expr.push(add(&intermediate, FieldElement::one(), &b1[i])); + out_expr.push(subtract(&b2[i], FieldElement::one(), &intermediate)); + } + if n % 2 == 0 { + out_expr.push(b1.last().unwrap().clone()); + } + out_expr.push(b2.last().unwrap().clone()); + conf.extend(w1); + conf.extend(w2); + (conf, out_expr) +} + +#[cfg(test)] +mod test { + use std::collections::BTreeMap; + + use acvm::{ + acir::{circuit::opcodes::BlackBoxFuncCall, native_types::Witness}, + FieldElement, OpcodeResolutionError, PartialWitnessGenerator, + }; + + use crate::{ssa::acir_gen::evaluate_permutation, Evaluator}; + use rand::prelude::*; + + use super::from_witness; + + struct MockBackend {} + impl PartialWitnessGenerator for MockBackend { + fn solve_black_box_function_call( + _initial_witness: &mut BTreeMap, + _func_call: &BlackBoxFuncCall, + ) -> Result<(), OpcodeResolutionError> { + unreachable!(); + } + } + + // Check that a random network constrains its output to be a permutation of any random input + #[test] + fn test_permutation() { + let mut rng = rand::thread_rng(); + for n in 2..50 { + let mut eval = Evaluator { + current_witness_index: 0, + public_inputs: Vec::new(), + opcodes: Vec::new(), + }; + + //we generate random inputs + let mut input = Vec::new(); + let mut a_val = Vec::new(); + let mut b_wit = Vec::new(); + let mut solved_witness: BTreeMap = BTreeMap::new(); + for i in 0..n { + let w = eval.add_witness_to_cs(); + input.push(from_witness(w)); + a_val.push(FieldElement::from(rng.next_u32() as i128)); + solved_witness.insert(w, a_val[i]); + } + + let mut output = Vec::new(); + for _i in 0..n { + let w = eval.add_witness_to_cs(); + b_wit.push(w); + output.push(from_witness(w)); + } + //generate constraints for the inputs + let w = evaluate_permutation(&input, &output, &mut eval); + + //we generate random network + let mut c = Vec::new(); + for _i in 0..w.len() { + c.push(rng.next_u32() % 2 != 0); + } + // intialise bits + for i in 0..w.len() { + solved_witness.insert(w[i], FieldElement::from(c[i] as i128)); + } + // compute the network output by solving the constraints + let backend = MockBackend {}; + backend + .solve(&mut solved_witness, eval.opcodes.clone()) + .expect("Could not solve permutation constraints"); + let mut b_val = Vec::new(); + for i in 0..output.len() { + b_val.push(solved_witness[&b_wit[i]]); + } + // ensure the outputs are a permutation of the inputs + assert_eq!(a_val.sort(), b_val.sort()); + } + } +} diff --git a/crates/noirc_evaluator/src/ssa/builtin.rs b/crates/noirc_evaluator/src/ssa/builtin.rs index b6e4eebb772..4c8c45021b0 100644 --- a/crates/noirc_evaluator/src/ssa/builtin.rs +++ b/crates/noirc_evaluator/src/ssa/builtin.rs @@ -2,13 +2,17 @@ use acvm::{acir::BlackBoxFunc, FieldElement}; use num_bigint::BigUint; use num_traits::{One, Zero}; -use super::node::ObjectType; +use super::{ + context::SsaContext, + node::{NodeId, ObjectType}, +}; #[derive(Clone, Debug, Hash, Copy, PartialEq, Eq)] pub enum Opcode { LowLevel(BlackBoxFunc), ToBits, ToRadix, + Sort, } impl std::fmt::Display for Opcode { @@ -22,6 +26,7 @@ impl Opcode { match op_name { "to_bits" => Some(Opcode::ToBits), "to_radix" => Some(Opcode::ToRadix), + "arraysort" => Some(Opcode::Sort), _ => BlackBoxFunc::lookup(op_name).map(Opcode::LowLevel), } } @@ -31,6 +36,7 @@ impl Opcode { Opcode::LowLevel(op) => op.name(), Opcode::ToBits => "to_bits", Opcode::ToRadix => "to_radix", + Opcode::Sort => "arraysort", } } @@ -49,12 +55,12 @@ impl Opcode { _ => todo!("max value must be implemented for opcode {} ", op), } } - Opcode::ToBits | Opcode::ToRadix => BigUint::zero(), //pointers do not overflow + Opcode::ToBits | Opcode::ToRadix | Opcode::Sort => BigUint::zero(), //pointers do not overflow } } //Returns the number of elements and their type, of the output result corresponding to the OPCODE function. - pub fn get_result_type(&self) -> (u32, ObjectType) { + pub fn get_result_type(&self, args: &Vec, ctx: &SsaContext) -> (u32, ObjectType) { match self { Opcode::LowLevel(op) => { match op { @@ -74,6 +80,10 @@ impl Opcode { } Opcode::ToBits => (FieldElement::max_num_bits(), ObjectType::Boolean), Opcode::ToRadix => (FieldElement::max_num_bits(), ObjectType::NativeField), + Opcode::Sort => { + let a = super::mem::Memory::deref(ctx, args[0]).unwrap(); + (ctx.mem[a].len, ctx.mem[a].element_type) + } } } } diff --git a/crates/noirc_evaluator/src/ssa/function.rs b/crates/noirc_evaluator/src/ssa/function.rs index f14d2f41d12..50bec734c17 100644 --- a/crates/noirc_evaluator/src/ssa/function.rs +++ b/crates/noirc_evaluator/src/ssa/function.rs @@ -338,7 +338,7 @@ impl IRGenerator { op: builtin::Opcode, args: Vec, ) -> Result, RuntimeError> { - let (len, elem_type) = op.get_result_type(); + let (len, elem_type) = op.get_result_type(&args, &self.context); let result_type = if len > 1 { //We create an array that will contain the result and set the res_type to point to that array @@ -347,7 +347,6 @@ impl IRGenerator { } else { elem_type }; - //when the function returns an array, we use ins.res_type(array) //else we map ins.id to the returned witness let id = self.context.new_instruction(node::Operation::Intrinsic(op, args), result_type)?; diff --git a/crates/noirc_frontend/src/hir/type_check/expr.rs b/crates/noirc_frontend/src/hir/type_check/expr.rs index 501c92f9d89..009ad040083 100644 --- a/crates/noirc_frontend/src/hir/type_check/expr.rs +++ b/crates/noirc_frontend/src/hir/type_check/expr.rs @@ -775,6 +775,16 @@ pub fn comparator_operand_type_rules( Ok(Bool(Comptime::No(Some(op.location.span)))) } + (TypeVariable(l_binding), TypeVariable(r_binding)) => { + if let TypeBinding::Bound(l_link) = &*l_binding.borrow() { + if let TypeBinding::Bound(r_link) = &*r_binding.borrow() { + return comparator_operand_type_rules(l_link, r_link, op, errors); + } + } + let l_typ = TypeVariable(l_binding.clone()); + let r_typ = TypeVariable(r_binding.clone()); + Err(format!("Unsupported types for comparison: {l_typ} and {r_typ}")) + } (lhs, rhs) => Err(format!("Unsupported types for comparison: {lhs} and {rhs}")), } } diff --git a/crates/noirc_frontend/src/node_interner.rs b/crates/noirc_frontend/src/node_interner.rs index 5b1c5aa2efb..ea334afbae1 100644 --- a/crates/noirc_frontend/src/node_interner.rs +++ b/crates/noirc_frontend/src/node_interner.rs @@ -553,7 +553,7 @@ impl NodeInterner { #[allow(deprecated)] pub fn foreign(&self, opcode: &str) -> bool { - let is_supported = acvm::default_is_blackbox_supported(self.language.clone()); + let is_supported = acvm::default_is_black_box_supported(self.language.clone()); let black_box_func = match acvm::acir::BlackBoxFunc::lookup(opcode) { Some(black_box_func) => black_box_func, None => return false, diff --git a/noir_stdlib/src/array.nr b/noir_stdlib/src/array.nr index 30b146cdab8..4a27320c91b 100644 --- a/noir_stdlib/src/array.nr +++ b/noir_stdlib/src/array.nr @@ -1,16 +1,5 @@ #[builtin(arraylen)] fn len(_input : [T]) -> comptime Field {} -// insertion sort - n.b. it is a quadratic sort -fn sort(mut a: [T]) -> [T] { - for i in 1..len(a) { - for j in 0..i { - if(a[i] < a[j]) { - let c = a[j]; - a[j] = a[i]; - a[i]= c; - } - }; - }; - a -} +#[builtin(arraysort)] + fn sort(_a: [T]) -> [T] {} From b5667432b06671df3e214c27fe9b079aeba3a0e7 Mon Sep 17 00:00:00 2001 From: guipublic Date: Mon, 6 Feb 2023 18:04:52 +0000 Subject: [PATCH 2/5] cargo fmt + (some) clippy --- crates/noirc_evaluator/src/ssa/builtin.rs | 5 ++++- crates/noirc_frontend/src/node_interner.rs | 2 +- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/crates/noirc_evaluator/src/ssa/builtin.rs b/crates/noirc_evaluator/src/ssa/builtin.rs index ce1e3dbd461..e9ffeed3f10 100644 --- a/crates/noirc_evaluator/src/ssa/builtin.rs +++ b/crates/noirc_evaluator/src/ssa/builtin.rs @@ -1,4 +1,7 @@ -use crate::ssa::{node::{NodeId, ObjectType}, context::SsaContext}; +use crate::ssa::{ + context::SsaContext, + node::{NodeId, ObjectType}, +}; use acvm::{acir::BlackBoxFunc, FieldElement}; use num_bigint::BigUint; use num_traits::{One, Zero}; diff --git a/crates/noirc_frontend/src/node_interner.rs b/crates/noirc_frontend/src/node_interner.rs index 16bfa93332c..6b05ebed6ec 100644 --- a/crates/noirc_frontend/src/node_interner.rs +++ b/crates/noirc_frontend/src/node_interner.rs @@ -554,7 +554,7 @@ impl NodeInterner { #[allow(deprecated)] pub fn foreign(&self, opcode: &str) -> bool { - let is_supported = acvm::default_is_black_box_supported(self.language.clone()); + let is_supported = acvm::default_is_blackbox_supported(self.language.clone()); let black_box_func = match acvm::acir::BlackBoxFunc::lookup(opcode) { Some(black_box_func) => black_box_func, None => return false, From f3ada0e9147d1c3cac60318de5eceac1fd670499 Mon Sep 17 00:00:00 2001 From: guipublic Date: Thu, 9 Feb 2023 14:45:00 +0000 Subject: [PATCH 3/5] Removing unreachable case --- crates/noirc_frontend/src/hir/type_check/expr.rs | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/crates/noirc_frontend/src/hir/type_check/expr.rs b/crates/noirc_frontend/src/hir/type_check/expr.rs index e077be3f007..ad0b4a3710a 100644 --- a/crates/noirc_frontend/src/hir/type_check/expr.rs +++ b/crates/noirc_frontend/src/hir/type_check/expr.rs @@ -818,16 +818,6 @@ pub fn comparator_operand_type_rules( Ok(Bool(CompTime::No(Some(op.location.span)))) } - (TypeVariable(l_binding), TypeVariable(r_binding)) => { - if let TypeBinding::Bound(l_link) = &*l_binding.borrow() { - if let TypeBinding::Bound(r_link) = &*r_binding.borrow() { - return comparator_operand_type_rules(l_link, r_link, op, errors); - } - } - let l_typ = TypeVariable(l_binding.clone()); - let r_typ = TypeVariable(r_binding.clone()); - Err(format!("Unsupported types for comparison: {l_typ} and {r_typ}")) - } (lhs, rhs) => Err(format!("Unsupported types for comparison: {lhs} and {rhs}")), } } From 4968f3b53eb53ecdac1673a9b00d4b0ec985d101 Mon Sep 17 00:00:00 2001 From: guipublic Date: Mon, 13 Feb 2023 13:35:31 +0000 Subject: [PATCH 4/5] rand in cargo.lock --- Cargo.lock | 1 + 1 file changed, 1 insertion(+) diff --git a/Cargo.lock b/Cargo.lock index 01d054818fe..55c25dff231 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1840,6 +1840,7 @@ dependencies = [ "noirc_frontend", "num-bigint", "num-traits", + "rand 0.8.5", "thiserror", ] From 0e3b972dd226ad96be00b5340f0569ba631509b8 Mon Sep 17 00:00:00 2001 From: guipublic Date: Mon, 13 Feb 2023 15:29:22 +0000 Subject: [PATCH 5/5] add length to array sort --- noir_stdlib/src/array.nr | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/noir_stdlib/src/array.nr b/noir_stdlib/src/array.nr index 48d84197e2d..b6256793c25 100644 --- a/noir_stdlib/src/array.nr +++ b/noir_stdlib/src/array.nr @@ -2,4 +2,4 @@ fn len(_input : [T]) -> comptime Field {} #[builtin(arraysort)] - fn sort(_a: [T]) -> [T] {} +fn sort(_a: [T; N]) -> [T; N] {}