Skip to content

Commit

Permalink
feat(acir)!: Add predicate to call opcode (#5616)
Browse files Browse the repository at this point in the history
Resolves this issue noir-lang/noir#4707.

We simply add a `predicate` onto the Call opcode and handle it how we do
other predicates such as for Brillig opcodes and memory ops. I also made
a general utility method for checking a predicate in the VM as it now
happens for three different opcodes.

---------

Co-authored-by: AztecBot <[email protected]>
Co-authored-by: ludamad <[email protected]>
Co-authored-by: sirasistant <[email protected]>
Co-authored-by: Tom French <[email protected]>
  • Loading branch information
5 people authored Apr 8, 2024
1 parent a974ec8 commit e8cec0a
Show file tree
Hide file tree
Showing 10 changed files with 91 additions and 39 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1137,6 +1137,7 @@ struct Opcode {
uint32_t id;
std::vector<Program::Witness> inputs;
std::vector<Program::Witness> outputs;
std::optional<Program::Expression> predicate;

friend bool operator==(const Call&, const Call&);
std::vector<uint8_t> bincodeSerialize() const;
Expand Down Expand Up @@ -7491,6 +7492,9 @@ inline bool operator==(const Opcode::Call& lhs, const Opcode::Call& rhs)
if (!(lhs.outputs == rhs.outputs)) {
return false;
}
if (!(lhs.predicate == rhs.predicate)) {
return false;
}
return true;
}

Expand Down Expand Up @@ -7520,6 +7524,7 @@ void serde::Serializable<Program::Opcode::Call>::serialize(const Program::Opcode
serde::Serializable<decltype(obj.id)>::serialize(obj.id, serializer);
serde::Serializable<decltype(obj.inputs)>::serialize(obj.inputs, serializer);
serde::Serializable<decltype(obj.outputs)>::serialize(obj.outputs, serializer);
serde::Serializable<decltype(obj.predicate)>::serialize(obj.predicate, serializer);
}

template <>
Expand All @@ -7530,6 +7535,7 @@ Program::Opcode::Call serde::Deserializable<Program::Opcode::Call>::deserialize(
obj.id = serde::Deserializable<decltype(obj.id)>::deserialize(deserializer);
obj.inputs = serde::Deserializable<decltype(obj.inputs)>::deserialize(deserializer);
obj.outputs = serde::Deserializable<decltype(obj.outputs)>::deserialize(deserializer);
obj.predicate = serde::Deserializable<decltype(obj.predicate)>::deserialize(deserializer);
return obj;
}

Expand Down
4 changes: 4 additions & 0 deletions noir/noir-repo/acvm-repo/acir/codegen/acir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1074,6 +1074,7 @@ namespace Program {
uint32_t id;
std::vector<Program::Witness> inputs;
std::vector<Program::Witness> outputs;
std::optional<Program::Expression> predicate;

friend bool operator==(const Call&, const Call&);
std::vector<uint8_t> bincodeSerialize() const;
Expand Down Expand Up @@ -6173,6 +6174,7 @@ namespace Program {
if (!(lhs.id == rhs.id)) { return false; }
if (!(lhs.inputs == rhs.inputs)) { return false; }
if (!(lhs.outputs == rhs.outputs)) { return false; }
if (!(lhs.predicate == rhs.predicate)) { return false; }
return true;
}

Expand All @@ -6199,6 +6201,7 @@ void serde::Serializable<Program::Opcode::Call>::serialize(const Program::Opcode
serde::Serializable<decltype(obj.id)>::serialize(obj.id, serializer);
serde::Serializable<decltype(obj.inputs)>::serialize(obj.inputs, serializer);
serde::Serializable<decltype(obj.outputs)>::serialize(obj.outputs, serializer);
serde::Serializable<decltype(obj.predicate)>::serialize(obj.predicate, serializer);
}

template <>
Expand All @@ -6208,6 +6211,7 @@ Program::Opcode::Call serde::Deserializable<Program::Opcode::Call>::deserialize(
obj.id = serde::Deserializable<decltype(obj.id)>::deserialize(deserializer);
obj.inputs = serde::Deserializable<decltype(obj.inputs)>::deserialize(deserializer);
obj.outputs = serde::Deserializable<decltype(obj.outputs)>::deserialize(deserializer);
obj.predicate = serde::Deserializable<decltype(obj.predicate)>::deserialize(deserializer);
return obj;
}

Expand Down
7 changes: 6 additions & 1 deletion noir/noir-repo/acvm-repo/acir/src/circuit/opcodes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ pub enum Opcode {
inputs: Vec<Witness>,
/// Outputs of the function call
outputs: Vec<Witness>,
/// Predicate of the circuit execution - indicates if it should be skipped
predicate: Option<Expression>,
},
}

Expand Down Expand Up @@ -97,8 +99,11 @@ impl std::fmt::Display for Opcode {
write!(f, "INIT ")?;
write!(f, "(id: {}, len: {}) ", block_id.0, init.len())
}
Opcode::Call { id, inputs, outputs } => {
Opcode::Call { id, inputs, outputs, predicate } => {
write!(f, "CALL func {}: ", id)?;
if let Some(pred) = predicate {
writeln!(f, "PREDICATE = {pred}")?;
}
write!(f, "inputs: {:?}, ", inputs)?;
write!(f, "outputs: {:?}", outputs)
}
Expand Down
43 changes: 28 additions & 15 deletions noir/noir-repo/acvm-repo/acir/tests/test_program_serialization.rs
Original file line number Diff line number Diff line change
Expand Up @@ -380,10 +380,18 @@ fn nested_acir_call_circuit() {
// assert(x == y);
// x
// }
let nested_call =
Opcode::Call { id: 1, inputs: vec![Witness(0), Witness(1)], outputs: vec![Witness(2)] };
let nested_call_two =
Opcode::Call { id: 1, inputs: vec![Witness(0), Witness(1)], outputs: vec![Witness(3)] };
let nested_call = Opcode::Call {
id: 1,
inputs: vec![Witness(0), Witness(1)],
outputs: vec![Witness(2)],
predicate: Some(Expression::one()),
};
let nested_call_two = Opcode::Call {
id: 1,
inputs: vec![Witness(0), Witness(1)],
outputs: vec![Witness(3)],
predicate: Some(Expression::one()),
};

let assert_nested_call_results = Opcode::AssertZero(Expression {
mul_terms: Vec::new(),
Expand All @@ -410,8 +418,12 @@ fn nested_acir_call_circuit() {
],
q_c: FieldElement::one() + FieldElement::one(),
});
let call =
Opcode::Call { id: 2, inputs: vec![Witness(2), Witness(1)], outputs: vec![Witness(3)] };
let call = Opcode::Call {
id: 2,
inputs: vec![Witness(2), Witness(1)],
outputs: vec![Witness(3)],
predicate: Some(Expression::one()),
};

let nested_call = Circuit {
current_witness_index: 3,
Expand Down Expand Up @@ -443,15 +455,16 @@ fn nested_acir_call_circuit() {
let bytes = Program::serialize_program(&program);

let expected_serialization: Vec<u8> = vec![
31, 139, 8, 0, 0, 0, 0, 0, 0, 255, 205, 146, 97, 10, 195, 32, 12, 133, 163, 66, 207, 147,
24, 109, 227, 191, 93, 101, 50, 123, 255, 35, 172, 99, 25, 83, 17, 250, 99, 14, 250, 224,
97, 144, 16, 146, 143, 231, 224, 45, 167, 126, 105, 57, 108, 14, 91, 248, 202, 168, 65,
255, 207, 122, 28, 180, 250, 244, 221, 244, 197, 223, 68, 182, 154, 197, 184, 134, 80, 54,
95, 136, 233, 142, 62, 101, 137, 24, 98, 94, 133, 132, 162, 196, 135, 23, 230, 34, 65, 182,
148, 211, 134, 137, 2, 23, 218, 99, 226, 93, 135, 185, 121, 123, 33, 84, 12, 234, 218, 192,
64, 174, 3, 248, 47, 88, 48, 17, 150, 157, 183, 151, 95, 244, 86, 91, 221, 61, 10, 81, 31,
178, 190, 110, 194, 102, 96, 76, 251, 202, 80, 13, 204, 77, 224, 25, 176, 70, 79, 197, 128,
18, 64, 3, 4, 0, 0,
31, 139, 8, 0, 0, 0, 0, 0, 0, 255, 213, 83, 65, 14, 3, 33, 8, 68, 77, 246, 61, 32, 186,
139, 183, 126, 165, 166, 238, 255, 159, 208, 109, 74, 82, 229, 178, 135, 186, 77, 58, 201,
4, 195, 97, 128, 1, 3, 188, 17, 148, 47, 44, 7, 221, 65, 15, 31, 56, 37, 104, 222, 129,
193, 77, 35, 126, 7, 58, 43, 30, 174, 44, 222, 107, 250, 201, 218, 190, 211, 98, 92, 83,
106, 91, 108, 196, 116, 199, 88, 170, 100, 76, 185, 174, 66, 66, 89, 242, 35, 10, 115, 147,
36, 91, 169, 101, 195, 66, 137, 27, 237, 185, 240, 174, 98, 97, 94, 95, 8, 198, 80, 103,
226, 128, 0, 227, 102, 174, 50, 11, 38, 154, 229, 231, 245, 21, 23, 157, 213, 119, 115,
255, 244, 58, 237, 183, 176, 239, 0, 38, 233, 254, 108, 91, 14, 230, 158, 246, 153, 97, 3,
158, 188, 79, 135, 232, 14, 5, 0, 0,
];

assert_eq!(bytes, expected_serialization);
}
15 changes: 7 additions & 8 deletions noir/noir-repo/acvm-repo/acvm/src/pwg/memory_op.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@ use acir::{
FieldElement,
};

use super::{arithmetic::ExpressionSolver, get_value, insert_value, witness_to_value};
use super::{
arithmetic::ExpressionSolver, get_value, insert_value, is_predicate_false, witness_to_value,
};
use super::{ErrorLocation, OpcodeResolutionError};

type MemoryIndex = u32;
Expand Down Expand Up @@ -80,11 +82,8 @@ impl MemoryOpSolver {
// `operation == 0` implies a read operation. (`operation == 1` implies write operation).
let is_read_operation = operation.is_zero();

// If the predicate is `None`, then we simply return the value 1
let pred_value = match predicate {
Some(pred) => get_value(pred, initial_witness),
None => Ok(FieldElement::one()),
}?;
// Fetch whether or not the predicate is false (e.g. equal to zero)
let skip_operation = is_predicate_false(initial_witness, predicate)?;

if is_read_operation {
// `value_read = arr[memory_index]`
Expand All @@ -97,7 +96,7 @@ impl MemoryOpSolver {

// A zero predicate indicates that we should skip the read operation
// and zero out the operation's output.
let value_in_array = if pred_value.is_zero() {
let value_in_array = if skip_operation {
FieldElement::zero()
} else {
self.read_memory_index(memory_index)?
Expand All @@ -111,7 +110,7 @@ impl MemoryOpSolver {
let value_write = value;

// A zero predicate indicates that we should skip the write operation.
if pred_value.is_zero() {
if skip_operation {
// We only want to write to already initialized memory.
// Do nothing if the predicate is zero.
Ok(())
Expand Down
28 changes: 26 additions & 2 deletions noir/noir-repo/acvm-repo/acvm/src/pwg/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -377,7 +377,7 @@ impl<'a, B: BlackBoxFunctionSolver> ACVM<'a, B> {
};

let witness = &mut self.witness_map;
if BrilligSolver::<B>::should_skip(witness, brillig)? {
if is_predicate_false(witness, &brillig.predicate)? {
return BrilligSolver::<B>::zero_out_brillig_outputs(witness, brillig).map(|_| None);
}

Expand Down Expand Up @@ -448,7 +448,9 @@ impl<'a, B: BlackBoxFunctionSolver> ACVM<'a, B> {
}

pub fn solve_call_opcode(&mut self) -> Result<Option<AcirCallWaitInfo>, OpcodeResolutionError> {
let Opcode::Call { id, inputs, outputs } = &self.opcodes[self.instruction_pointer] else {
let Opcode::Call { id, inputs, outputs, predicate } =
&self.opcodes[self.instruction_pointer]
else {
unreachable!("Not executing a Call opcode");
};
if *id == 0 {
Expand All @@ -459,6 +461,14 @@ impl<'a, B: BlackBoxFunctionSolver> ACVM<'a, B> {
});
}

if is_predicate_false(&self.witness_map, predicate)? {
// Zero out the outputs if we have a false predicate
for output in outputs {
insert_value(output, FieldElement::zero(), &mut self.witness_map)?;
}
return Ok(None);
}

if self.acir_call_counter >= self.acir_call_results.len() {
let mut initial_witness = WitnessMap::default();
for (i, input_witness) in inputs.iter().enumerate() {
Expand Down Expand Up @@ -556,6 +566,20 @@ fn any_witness_from_expression(expr: &Expression) -> Option<Witness> {
}
}

/// Returns `true` if the predicate is zero
/// A predicate is used to indicate whether we should skip a certain operation.
/// If we have a zero predicate it means the operation should be skipped.
pub(crate) fn is_predicate_false(
witness: &WitnessMap,
predicate: &Option<Expression>,
) -> Result<bool, OpcodeResolutionError> {
match predicate {
Some(pred) => get_value(pred, witness).map(|pred_value| pred_value.is_zero()),
// If the predicate is `None`, then we treat it as an unconditional `true`
None => Ok(false),
}
}

#[derive(Debug, Clone, PartialEq)]
pub struct AcirCallWaitInfo {
/// Index in the list of ACIR function's that should be called
Expand Down
14 changes: 7 additions & 7 deletions noir/noir-repo/acvm-repo/acvm_js/test/shared/nested_acir_call.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@ import { WitnessMap, StackItem, WitnessStack } from '@noir-lang/acvm_js';

// See `nested_acir_call_circuit` integration test in `acir/tests/test_program_serialization.rs`.
export const bytecode = Uint8Array.from([
31, 139, 8, 0, 0, 0, 0, 0, 0, 255, 205, 146, 97, 10, 195, 32, 12, 133, 163, 66, 207, 147, 24, 109, 227, 191, 93, 101,
50, 123, 255, 35, 172, 99, 25, 83, 17, 250, 99, 14, 250, 224, 97, 144, 16, 146, 143, 231, 224, 45, 167, 126, 105, 57,
108, 14, 91, 248, 202, 168, 65, 255, 207, 122, 28, 180, 250, 244, 221, 244, 197, 223, 68, 182, 154, 197, 184, 134, 80,
54, 95, 136, 233, 142, 62, 101, 137, 24, 98, 94, 133, 132, 162, 196, 135, 23, 230, 34, 65, 182, 148, 211, 134, 137, 2,
23, 218, 99, 226, 93, 135, 185, 121, 123, 33, 84, 12, 234, 218, 192, 64, 174, 3, 248, 47, 88, 48, 17, 150, 157, 183,
151, 95, 244, 86, 91, 221, 61, 10, 81, 31, 178, 190, 110, 194, 102, 96, 76, 251, 202, 80, 13, 204, 77, 224, 25, 176,
70, 79, 197, 128, 18, 64, 3, 4, 0, 0,
31, 139, 8, 0, 0, 0, 0, 0, 0, 255, 213, 83, 65, 14, 3, 33, 8, 68, 77, 246, 61, 32, 186, 139, 183, 126, 165, 166, 238,
255, 159, 208, 109, 74, 82, 229, 178, 135, 186, 77, 58, 201, 4, 195, 97, 128, 1, 3, 188, 17, 148, 47, 44, 7, 221, 65,
15, 31, 56, 37, 104, 222, 129, 193, 77, 35, 126, 7, 58, 43, 30, 174, 44, 222, 107, 250, 201, 218, 190, 211, 98, 92,
83, 106, 91, 108, 196, 116, 199, 88, 170, 100, 76, 185, 174, 66, 66, 89, 242, 35, 10, 115, 147, 36, 91, 169, 101, 195,
66, 137, 27, 237, 185, 240, 174, 98, 97, 94, 95, 8, 198, 80, 103, 226, 128, 0, 227, 102, 174, 50, 11, 38, 154, 229,
231, 245, 21, 23, 157, 213, 119, 115, 255, 244, 58, 237, 183, 176, 239, 0, 38, 233, 254, 108, 91, 14, 230, 158, 246,
153, 97, 3, 158, 188, 79, 135, 232, 14, 5, 0, 0,
]);

export const initialWitnessMap: WitnessMap = new Map([
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1763,6 +1763,7 @@ impl AcirContext {
id: u32,
inputs: Vec<AcirValue>,
output_count: usize,
predicate: AcirVar,
) -> Result<Vec<AcirVar>, RuntimeError> {
let inputs = self.prepare_inputs_for_black_box_func_call(inputs)?;
let inputs = inputs
Expand All @@ -1778,7 +1779,8 @@ impl AcirContext {
let results =
vecmap(&outputs, |witness_index| self.add_data(AcirVarData::Witness(*witness_index)));

self.acir_ir.push_opcode(Opcode::Call { id, inputs, outputs });
let predicate = Some(self.var_to_expression(predicate)?);
self.acir_ir.push_opcode(Opcode::Call { id, inputs, outputs, predicate });
Ok(results)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -576,6 +576,7 @@ impl Context {
*acir_program_id,
inputs,
output_count,
self.current_side_effects_enabled_var,
)?;
let output_values =
self.convert_vars_to_values(output_vars, dfg, result_ids);
Expand Down Expand Up @@ -2713,7 +2714,7 @@ mod test {
expected_outputs: Vec<Witness>,
) {
match opcode {
Opcode::Call { id, inputs, outputs } => {
Opcode::Call { id, inputs, outputs, .. } => {
assert_eq!(
*id, expected_id,
"Main was expected to call {expected_id} but got {}",
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
# TODO(https://github.com/noir-lang/noir/issues/4707): Change these inputs to fail the assertion in `fn return_value`
# and change `enable` to false. For now we need the inputs to pass as we do not handle predicates with ACIR calls
x = "5"
x = "10"
y = "10"
enable = true
enable = false

0 comments on commit e8cec0a

Please sign in to comment.