diff --git a/barretenberg/cpp/src/barretenberg/dsl/acir_format/serde/acir.hpp b/barretenberg/cpp/src/barretenberg/dsl/acir_format/serde/acir.hpp index 9ffda938c7e..35db8753546 100644 --- a/barretenberg/cpp/src/barretenberg/dsl/acir_format/serde/acir.hpp +++ b/barretenberg/cpp/src/barretenberg/dsl/acir_format/serde/acir.hpp @@ -1137,6 +1137,7 @@ struct Opcode { uint32_t id; std::vector inputs; std::vector outputs; + std::optional predicate; friend bool operator==(const Call&, const Call&); std::vector bincodeSerialize() const; @@ -7491,6 +7492,9 @@ inline bool operator==(const Opcode::Call& lhs, const Opcode::Call& rhs) if (!(lhs.outputs == rhs.outputs)) { return false; } + if (!(lhs.predicate == rhs.predicate)) { + return false; + } return true; } @@ -7520,6 +7524,7 @@ void serde::Serializable::serialize(const Program::Opcode serde::Serializable::serialize(obj.id, serializer); serde::Serializable::serialize(obj.inputs, serializer); serde::Serializable::serialize(obj.outputs, serializer); + serde::Serializable::serialize(obj.predicate, serializer); } template <> @@ -7530,6 +7535,7 @@ Program::Opcode::Call serde::Deserializable::deserialize( obj.id = serde::Deserializable::deserialize(deserializer); obj.inputs = serde::Deserializable::deserialize(deserializer); obj.outputs = serde::Deserializable::deserialize(deserializer); + obj.predicate = serde::Deserializable::deserialize(deserializer); return obj; } diff --git a/noir/noir-repo/acvm-repo/acir/codegen/acir.cpp b/noir/noir-repo/acvm-repo/acir/codegen/acir.cpp index d7ef849ab75..e4203b579b0 100644 --- a/noir/noir-repo/acvm-repo/acir/codegen/acir.cpp +++ b/noir/noir-repo/acvm-repo/acir/codegen/acir.cpp @@ -1074,6 +1074,7 @@ namespace Program { uint32_t id; std::vector inputs; std::vector outputs; + std::optional predicate; friend bool operator==(const Call&, const Call&); std::vector bincodeSerialize() const; @@ -6173,6 +6174,7 @@ namespace Program { if (!(lhs.id == rhs.id)) { return false; } if (!(lhs.inputs == rhs.inputs)) { return false; } if (!(lhs.outputs == rhs.outputs)) { return false; } + if (!(lhs.predicate == rhs.predicate)) { return false; } return true; } @@ -6199,6 +6201,7 @@ void serde::Serializable::serialize(const Program::Opcode serde::Serializable::serialize(obj.id, serializer); serde::Serializable::serialize(obj.inputs, serializer); serde::Serializable::serialize(obj.outputs, serializer); + serde::Serializable::serialize(obj.predicate, serializer); } template <> @@ -6208,6 +6211,7 @@ Program::Opcode::Call serde::Deserializable::deserialize( obj.id = serde::Deserializable::deserialize(deserializer); obj.inputs = serde::Deserializable::deserialize(deserializer); obj.outputs = serde::Deserializable::deserialize(deserializer); + obj.predicate = serde::Deserializable::deserialize(deserializer); return obj; } diff --git a/noir/noir-repo/acvm-repo/acir/src/circuit/opcodes.rs b/noir/noir-repo/acvm-repo/acir/src/circuit/opcodes.rs index 68d28b287e6..d8204132b3e 100644 --- a/noir/noir-repo/acvm-repo/acir/src/circuit/opcodes.rs +++ b/noir/noir-repo/acvm-repo/acir/src/circuit/opcodes.rs @@ -39,6 +39,8 @@ pub enum Opcode { inputs: Vec, /// Outputs of the function call outputs: Vec, + /// Predicate of the circuit execution - indicates if it should be skipped + predicate: Option, }, } @@ -97,8 +99,11 @@ impl std::fmt::Display for Opcode { write!(f, "INIT ")?; write!(f, "(id: {}, len: {}) ", block_id.0, init.len()) } - Opcode::Call { id, inputs, outputs } => { + Opcode::Call { id, inputs, outputs, predicate } => { write!(f, "CALL func {}: ", id)?; + if let Some(pred) = predicate { + writeln!(f, "PREDICATE = {pred}")?; + } write!(f, "inputs: {:?}, ", inputs)?; write!(f, "outputs: {:?}", outputs) } diff --git a/noir/noir-repo/acvm-repo/acir/tests/test_program_serialization.rs b/noir/noir-repo/acvm-repo/acir/tests/test_program_serialization.rs index a5b683c15e1..2feb2737d28 100644 --- a/noir/noir-repo/acvm-repo/acir/tests/test_program_serialization.rs +++ b/noir/noir-repo/acvm-repo/acir/tests/test_program_serialization.rs @@ -380,10 +380,18 @@ fn nested_acir_call_circuit() { // assert(x == y); // x // } - let nested_call = - Opcode::Call { id: 1, inputs: vec![Witness(0), Witness(1)], outputs: vec![Witness(2)] }; - let nested_call_two = - Opcode::Call { id: 1, inputs: vec![Witness(0), Witness(1)], outputs: vec![Witness(3)] }; + let nested_call = Opcode::Call { + id: 1, + inputs: vec![Witness(0), Witness(1)], + outputs: vec![Witness(2)], + predicate: Some(Expression::one()), + }; + let nested_call_two = Opcode::Call { + id: 1, + inputs: vec![Witness(0), Witness(1)], + outputs: vec![Witness(3)], + predicate: Some(Expression::one()), + }; let assert_nested_call_results = Opcode::AssertZero(Expression { mul_terms: Vec::new(), @@ -410,8 +418,12 @@ fn nested_acir_call_circuit() { ], q_c: FieldElement::one() + FieldElement::one(), }); - let call = - Opcode::Call { id: 2, inputs: vec![Witness(2), Witness(1)], outputs: vec![Witness(3)] }; + let call = Opcode::Call { + id: 2, + inputs: vec![Witness(2), Witness(1)], + outputs: vec![Witness(3)], + predicate: Some(Expression::one()), + }; let nested_call = Circuit { current_witness_index: 3, @@ -443,15 +455,16 @@ fn nested_acir_call_circuit() { let bytes = Program::serialize_program(&program); let expected_serialization: Vec = vec![ - 31, 139, 8, 0, 0, 0, 0, 0, 0, 255, 205, 146, 97, 10, 195, 32, 12, 133, 163, 66, 207, 147, - 24, 109, 227, 191, 93, 101, 50, 123, 255, 35, 172, 99, 25, 83, 17, 250, 99, 14, 250, 224, - 97, 144, 16, 146, 143, 231, 224, 45, 167, 126, 105, 57, 108, 14, 91, 248, 202, 168, 65, - 255, 207, 122, 28, 180, 250, 244, 221, 244, 197, 223, 68, 182, 154, 197, 184, 134, 80, 54, - 95, 136, 233, 142, 62, 101, 137, 24, 98, 94, 133, 132, 162, 196, 135, 23, 230, 34, 65, 182, - 148, 211, 134, 137, 2, 23, 218, 99, 226, 93, 135, 185, 121, 123, 33, 84, 12, 234, 218, 192, - 64, 174, 3, 248, 47, 88, 48, 17, 150, 157, 183, 151, 95, 244, 86, 91, 221, 61, 10, 81, 31, - 178, 190, 110, 194, 102, 96, 76, 251, 202, 80, 13, 204, 77, 224, 25, 176, 70, 79, 197, 128, - 18, 64, 3, 4, 0, 0, + 31, 139, 8, 0, 0, 0, 0, 0, 0, 255, 213, 83, 65, 14, 3, 33, 8, 68, 77, 246, 61, 32, 186, + 139, 183, 126, 165, 166, 238, 255, 159, 208, 109, 74, 82, 229, 178, 135, 186, 77, 58, 201, + 4, 195, 97, 128, 1, 3, 188, 17, 148, 47, 44, 7, 221, 65, 15, 31, 56, 37, 104, 222, 129, + 193, 77, 35, 126, 7, 58, 43, 30, 174, 44, 222, 107, 250, 201, 218, 190, 211, 98, 92, 83, + 106, 91, 108, 196, 116, 199, 88, 170, 100, 76, 185, 174, 66, 66, 89, 242, 35, 10, 115, 147, + 36, 91, 169, 101, 195, 66, 137, 27, 237, 185, 240, 174, 98, 97, 94, 95, 8, 198, 80, 103, + 226, 128, 0, 227, 102, 174, 50, 11, 38, 154, 229, 231, 245, 21, 23, 157, 213, 119, 115, + 255, 244, 58, 237, 183, 176, 239, 0, 38, 233, 254, 108, 91, 14, 230, 158, 246, 153, 97, 3, + 158, 188, 79, 135, 232, 14, 5, 0, 0, ]; + assert_eq!(bytes, expected_serialization); } diff --git a/noir/noir-repo/acvm-repo/acvm/src/pwg/memory_op.rs b/noir/noir-repo/acvm-repo/acvm/src/pwg/memory_op.rs index e51797707a7..672c13e11c2 100644 --- a/noir/noir-repo/acvm-repo/acvm/src/pwg/memory_op.rs +++ b/noir/noir-repo/acvm-repo/acvm/src/pwg/memory_op.rs @@ -6,7 +6,9 @@ use acir::{ FieldElement, }; -use super::{arithmetic::ExpressionSolver, get_value, insert_value, witness_to_value}; +use super::{ + arithmetic::ExpressionSolver, get_value, insert_value, is_predicate_false, witness_to_value, +}; use super::{ErrorLocation, OpcodeResolutionError}; type MemoryIndex = u32; @@ -80,11 +82,8 @@ impl MemoryOpSolver { // `operation == 0` implies a read operation. (`operation == 1` implies write operation). let is_read_operation = operation.is_zero(); - // If the predicate is `None`, then we simply return the value 1 - let pred_value = match predicate { - Some(pred) => get_value(pred, initial_witness), - None => Ok(FieldElement::one()), - }?; + // Fetch whether or not the predicate is false (e.g. equal to zero) + let skip_operation = is_predicate_false(initial_witness, predicate)?; if is_read_operation { // `value_read = arr[memory_index]` @@ -97,7 +96,7 @@ impl MemoryOpSolver { // A zero predicate indicates that we should skip the read operation // and zero out the operation's output. - let value_in_array = if pred_value.is_zero() { + let value_in_array = if skip_operation { FieldElement::zero() } else { self.read_memory_index(memory_index)? @@ -111,7 +110,7 @@ impl MemoryOpSolver { let value_write = value; // A zero predicate indicates that we should skip the write operation. - if pred_value.is_zero() { + if skip_operation { // We only want to write to already initialized memory. // Do nothing if the predicate is zero. Ok(()) diff --git a/noir/noir-repo/acvm-repo/acvm/src/pwg/mod.rs b/noir/noir-repo/acvm-repo/acvm/src/pwg/mod.rs index 3cedcfc0399..bb98eda2689 100644 --- a/noir/noir-repo/acvm-repo/acvm/src/pwg/mod.rs +++ b/noir/noir-repo/acvm-repo/acvm/src/pwg/mod.rs @@ -377,7 +377,7 @@ impl<'a, B: BlackBoxFunctionSolver> ACVM<'a, B> { }; let witness = &mut self.witness_map; - if BrilligSolver::::should_skip(witness, brillig)? { + if is_predicate_false(witness, &brillig.predicate)? { return BrilligSolver::::zero_out_brillig_outputs(witness, brillig).map(|_| None); } @@ -448,7 +448,9 @@ impl<'a, B: BlackBoxFunctionSolver> ACVM<'a, B> { } pub fn solve_call_opcode(&mut self) -> Result, OpcodeResolutionError> { - let Opcode::Call { id, inputs, outputs } = &self.opcodes[self.instruction_pointer] else { + let Opcode::Call { id, inputs, outputs, predicate } = + &self.opcodes[self.instruction_pointer] + else { unreachable!("Not executing a Call opcode"); }; if *id == 0 { @@ -459,6 +461,14 @@ impl<'a, B: BlackBoxFunctionSolver> ACVM<'a, B> { }); } + if is_predicate_false(&self.witness_map, predicate)? { + // Zero out the outputs if we have a false predicate + for output in outputs { + insert_value(output, FieldElement::zero(), &mut self.witness_map)?; + } + return Ok(None); + } + if self.acir_call_counter >= self.acir_call_results.len() { let mut initial_witness = WitnessMap::default(); for (i, input_witness) in inputs.iter().enumerate() { @@ -556,6 +566,20 @@ fn any_witness_from_expression(expr: &Expression) -> Option { } } +/// Returns `true` if the predicate is zero +/// A predicate is used to indicate whether we should skip a certain operation. +/// If we have a zero predicate it means the operation should be skipped. +pub(crate) fn is_predicate_false( + witness: &WitnessMap, + predicate: &Option, +) -> Result { + match predicate { + Some(pred) => get_value(pred, witness).map(|pred_value| pred_value.is_zero()), + // If the predicate is `None`, then we treat it as an unconditional `true` + None => Ok(false), + } +} + #[derive(Debug, Clone, PartialEq)] pub struct AcirCallWaitInfo { /// Index in the list of ACIR function's that should be called diff --git a/noir/noir-repo/acvm-repo/acvm_js/test/shared/nested_acir_call.ts b/noir/noir-repo/acvm-repo/acvm_js/test/shared/nested_acir_call.ts index ce91282a681..9f9b94b7a44 100644 --- a/noir/noir-repo/acvm-repo/acvm_js/test/shared/nested_acir_call.ts +++ b/noir/noir-repo/acvm-repo/acvm_js/test/shared/nested_acir_call.ts @@ -2,13 +2,13 @@ import { WitnessMap, StackItem, WitnessStack } from '@noir-lang/acvm_js'; // See `nested_acir_call_circuit` integration test in `acir/tests/test_program_serialization.rs`. export const bytecode = Uint8Array.from([ - 31, 139, 8, 0, 0, 0, 0, 0, 0, 255, 205, 146, 97, 10, 195, 32, 12, 133, 163, 66, 207, 147, 24, 109, 227, 191, 93, 101, - 50, 123, 255, 35, 172, 99, 25, 83, 17, 250, 99, 14, 250, 224, 97, 144, 16, 146, 143, 231, 224, 45, 167, 126, 105, 57, - 108, 14, 91, 248, 202, 168, 65, 255, 207, 122, 28, 180, 250, 244, 221, 244, 197, 223, 68, 182, 154, 197, 184, 134, 80, - 54, 95, 136, 233, 142, 62, 101, 137, 24, 98, 94, 133, 132, 162, 196, 135, 23, 230, 34, 65, 182, 148, 211, 134, 137, 2, - 23, 218, 99, 226, 93, 135, 185, 121, 123, 33, 84, 12, 234, 218, 192, 64, 174, 3, 248, 47, 88, 48, 17, 150, 157, 183, - 151, 95, 244, 86, 91, 221, 61, 10, 81, 31, 178, 190, 110, 194, 102, 96, 76, 251, 202, 80, 13, 204, 77, 224, 25, 176, - 70, 79, 197, 128, 18, 64, 3, 4, 0, 0, + 31, 139, 8, 0, 0, 0, 0, 0, 0, 255, 213, 83, 65, 14, 3, 33, 8, 68, 77, 246, 61, 32, 186, 139, 183, 126, 165, 166, 238, + 255, 159, 208, 109, 74, 82, 229, 178, 135, 186, 77, 58, 201, 4, 195, 97, 128, 1, 3, 188, 17, 148, 47, 44, 7, 221, 65, + 15, 31, 56, 37, 104, 222, 129, 193, 77, 35, 126, 7, 58, 43, 30, 174, 44, 222, 107, 250, 201, 218, 190, 211, 98, 92, + 83, 106, 91, 108, 196, 116, 199, 88, 170, 100, 76, 185, 174, 66, 66, 89, 242, 35, 10, 115, 147, 36, 91, 169, 101, 195, + 66, 137, 27, 237, 185, 240, 174, 98, 97, 94, 95, 8, 198, 80, 103, 226, 128, 0, 227, 102, 174, 50, 11, 38, 154, 229, + 231, 245, 21, 23, 157, 213, 119, 115, 255, 244, 58, 237, 183, 176, 239, 0, 38, 233, 254, 108, 91, 14, 230, 158, 246, + 153, 97, 3, 158, 188, 79, 135, 232, 14, 5, 0, 0, ]); export const initialWitnessMap: WitnessMap = new Map([ diff --git a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/acir_gen/acir_ir/acir_variable.rs b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/acir_gen/acir_ir/acir_variable.rs index bcd62e3b062..53d9e2530cc 100644 --- a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/acir_gen/acir_ir/acir_variable.rs +++ b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/acir_gen/acir_ir/acir_variable.rs @@ -1763,6 +1763,7 @@ impl AcirContext { id: u32, inputs: Vec, output_count: usize, + predicate: AcirVar, ) -> Result, RuntimeError> { let inputs = self.prepare_inputs_for_black_box_func_call(inputs)?; let inputs = inputs @@ -1778,7 +1779,8 @@ impl AcirContext { let results = vecmap(&outputs, |witness_index| self.add_data(AcirVarData::Witness(*witness_index))); - self.acir_ir.push_opcode(Opcode::Call { id, inputs, outputs }); + let predicate = Some(self.var_to_expression(predicate)?); + self.acir_ir.push_opcode(Opcode::Call { id, inputs, outputs, predicate }); Ok(results) } } diff --git a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/acir_gen/mod.rs b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/acir_gen/mod.rs index d4760c29eda..fc81ae1f901 100644 --- a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/acir_gen/mod.rs +++ b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/acir_gen/mod.rs @@ -576,6 +576,7 @@ impl Context { *acir_program_id, inputs, output_count, + self.current_side_effects_enabled_var, )?; let output_values = self.convert_vars_to_values(output_vars, dfg, result_ids); @@ -2713,7 +2714,7 @@ mod test { expected_outputs: Vec, ) { match opcode { - Opcode::Call { id, inputs, outputs } => { + Opcode::Call { id, inputs, outputs, .. } => { assert_eq!( *id, expected_id, "Main was expected to call {expected_id} but got {}", diff --git a/noir/noir-repo/test_programs/execution_success/fold_call_witness_condition/Prover.toml b/noir/noir-repo/test_programs/execution_success/fold_call_witness_condition/Prover.toml index 8481ce25648..a4d6339b661 100644 --- a/noir/noir-repo/test_programs/execution_success/fold_call_witness_condition/Prover.toml +++ b/noir/noir-repo/test_programs/execution_success/fold_call_witness_condition/Prover.toml @@ -1,5 +1,3 @@ -# TODO(https://github.com/noir-lang/noir/issues/4707): Change these inputs to fail the assertion in `fn return_value` -# and change `enable` to false. For now we need the inputs to pass as we do not handle predicates with ACIR calls -x = "5" +x = "10" y = "10" -enable = true \ No newline at end of file +enable = false