Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(acir)!: Add predicate to call opcode #5616

Merged
merged 20 commits into from
Apr 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1137,6 +1137,7 @@ struct Opcode {
uint32_t id;
std::vector<Program::Witness> inputs;
std::vector<Program::Witness> outputs;
std::optional<Program::Expression> predicate;

friend bool operator==(const Call&, const Call&);
std::vector<uint8_t> bincodeSerialize() const;
Expand Down Expand Up @@ -7491,6 +7492,9 @@ inline bool operator==(const Opcode::Call& lhs, const Opcode::Call& rhs)
if (!(lhs.outputs == rhs.outputs)) {
return false;
}
if (!(lhs.predicate == rhs.predicate)) {
return false;
}
return true;
}

Expand Down Expand Up @@ -7520,6 +7524,7 @@ void serde::Serializable<Program::Opcode::Call>::serialize(const Program::Opcode
serde::Serializable<decltype(obj.id)>::serialize(obj.id, serializer);
serde::Serializable<decltype(obj.inputs)>::serialize(obj.inputs, serializer);
serde::Serializable<decltype(obj.outputs)>::serialize(obj.outputs, serializer);
serde::Serializable<decltype(obj.predicate)>::serialize(obj.predicate, serializer);
}

template <>
Expand All @@ -7530,6 +7535,7 @@ Program::Opcode::Call serde::Deserializable<Program::Opcode::Call>::deserialize(
obj.id = serde::Deserializable<decltype(obj.id)>::deserialize(deserializer);
obj.inputs = serde::Deserializable<decltype(obj.inputs)>::deserialize(deserializer);
obj.outputs = serde::Deserializable<decltype(obj.outputs)>::deserialize(deserializer);
obj.predicate = serde::Deserializable<decltype(obj.predicate)>::deserialize(deserializer);
return obj;
}

Expand Down
4 changes: 4 additions & 0 deletions noir/noir-repo/acvm-repo/acir/codegen/acir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1074,6 +1074,7 @@ namespace Program {
uint32_t id;
std::vector<Program::Witness> inputs;
std::vector<Program::Witness> outputs;
std::optional<Program::Expression> predicate;

friend bool operator==(const Call&, const Call&);
std::vector<uint8_t> bincodeSerialize() const;
Expand Down Expand Up @@ -6173,6 +6174,7 @@ namespace Program {
if (!(lhs.id == rhs.id)) { return false; }
if (!(lhs.inputs == rhs.inputs)) { return false; }
if (!(lhs.outputs == rhs.outputs)) { return false; }
if (!(lhs.predicate == rhs.predicate)) { return false; }
return true;
}

Expand All @@ -6199,6 +6201,7 @@ void serde::Serializable<Program::Opcode::Call>::serialize(const Program::Opcode
serde::Serializable<decltype(obj.id)>::serialize(obj.id, serializer);
serde::Serializable<decltype(obj.inputs)>::serialize(obj.inputs, serializer);
serde::Serializable<decltype(obj.outputs)>::serialize(obj.outputs, serializer);
serde::Serializable<decltype(obj.predicate)>::serialize(obj.predicate, serializer);
}

template <>
Expand All @@ -6208,6 +6211,7 @@ Program::Opcode::Call serde::Deserializable<Program::Opcode::Call>::deserialize(
obj.id = serde::Deserializable<decltype(obj.id)>::deserialize(deserializer);
obj.inputs = serde::Deserializable<decltype(obj.inputs)>::deserialize(deserializer);
obj.outputs = serde::Deserializable<decltype(obj.outputs)>::deserialize(deserializer);
obj.predicate = serde::Deserializable<decltype(obj.predicate)>::deserialize(deserializer);
return obj;
}

Expand Down
7 changes: 6 additions & 1 deletion noir/noir-repo/acvm-repo/acir/src/circuit/opcodes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ pub enum Opcode {
inputs: Vec<Witness>,
/// Outputs of the function call
outputs: Vec<Witness>,
/// Predicate of the circuit execution - indicates if it should be skipped
predicate: Option<Expression>,
},
}

Expand Down Expand Up @@ -97,8 +99,11 @@ impl std::fmt::Display for Opcode {
write!(f, "INIT ")?;
write!(f, "(id: {}, len: {}) ", block_id.0, init.len())
}
Opcode::Call { id, inputs, outputs } => {
Opcode::Call { id, inputs, outputs, predicate } => {
write!(f, "CALL func {}: ", id)?;
if let Some(pred) = predicate {
writeln!(f, "PREDICATE = {pred}")?;
}
write!(f, "inputs: {:?}, ", inputs)?;
write!(f, "outputs: {:?}", outputs)
}
Expand Down
43 changes: 28 additions & 15 deletions noir/noir-repo/acvm-repo/acir/tests/test_program_serialization.rs
Original file line number Diff line number Diff line change
Expand Up @@ -380,10 +380,18 @@ fn nested_acir_call_circuit() {
// assert(x == y);
// x
// }
let nested_call =
Opcode::Call { id: 1, inputs: vec![Witness(0), Witness(1)], outputs: vec![Witness(2)] };
let nested_call_two =
Opcode::Call { id: 1, inputs: vec![Witness(0), Witness(1)], outputs: vec![Witness(3)] };
let nested_call = Opcode::Call {
id: 1,
inputs: vec![Witness(0), Witness(1)],
outputs: vec![Witness(2)],
predicate: Some(Expression::one()),
};
let nested_call_two = Opcode::Call {
id: 1,
inputs: vec![Witness(0), Witness(1)],
outputs: vec![Witness(3)],
predicate: Some(Expression::one()),
};

let assert_nested_call_results = Opcode::AssertZero(Expression {
mul_terms: Vec::new(),
Expand All @@ -410,8 +418,12 @@ fn nested_acir_call_circuit() {
],
q_c: FieldElement::one() + FieldElement::one(),
});
let call =
Opcode::Call { id: 2, inputs: vec![Witness(2), Witness(1)], outputs: vec![Witness(3)] };
let call = Opcode::Call {
id: 2,
inputs: vec![Witness(2), Witness(1)],
outputs: vec![Witness(3)],
predicate: Some(Expression::one()),
};

let nested_call = Circuit {
current_witness_index: 3,
Expand Down Expand Up @@ -443,15 +455,16 @@ fn nested_acir_call_circuit() {
let bytes = Program::serialize_program(&program);

let expected_serialization: Vec<u8> = vec![
31, 139, 8, 0, 0, 0, 0, 0, 0, 255, 205, 146, 97, 10, 195, 32, 12, 133, 163, 66, 207, 147,
24, 109, 227, 191, 93, 101, 50, 123, 255, 35, 172, 99, 25, 83, 17, 250, 99, 14, 250, 224,
97, 144, 16, 146, 143, 231, 224, 45, 167, 126, 105, 57, 108, 14, 91, 248, 202, 168, 65,
255, 207, 122, 28, 180, 250, 244, 221, 244, 197, 223, 68, 182, 154, 197, 184, 134, 80, 54,
95, 136, 233, 142, 62, 101, 137, 24, 98, 94, 133, 132, 162, 196, 135, 23, 230, 34, 65, 182,
148, 211, 134, 137, 2, 23, 218, 99, 226, 93, 135, 185, 121, 123, 33, 84, 12, 234, 218, 192,
64, 174, 3, 248, 47, 88, 48, 17, 150, 157, 183, 151, 95, 244, 86, 91, 221, 61, 10, 81, 31,
178, 190, 110, 194, 102, 96, 76, 251, 202, 80, 13, 204, 77, 224, 25, 176, 70, 79, 197, 128,
18, 64, 3, 4, 0, 0,
31, 139, 8, 0, 0, 0, 0, 0, 0, 255, 213, 83, 65, 14, 3, 33, 8, 68, 77, 246, 61, 32, 186,
139, 183, 126, 165, 166, 238, 255, 159, 208, 109, 74, 82, 229, 178, 135, 186, 77, 58, 201,
4, 195, 97, 128, 1, 3, 188, 17, 148, 47, 44, 7, 221, 65, 15, 31, 56, 37, 104, 222, 129,
193, 77, 35, 126, 7, 58, 43, 30, 174, 44, 222, 107, 250, 201, 218, 190, 211, 98, 92, 83,
106, 91, 108, 196, 116, 199, 88, 170, 100, 76, 185, 174, 66, 66, 89, 242, 35, 10, 115, 147,
36, 91, 169, 101, 195, 66, 137, 27, 237, 185, 240, 174, 98, 97, 94, 95, 8, 198, 80, 103,
226, 128, 0, 227, 102, 174, 50, 11, 38, 154, 229, 231, 245, 21, 23, 157, 213, 119, 115,
255, 244, 58, 237, 183, 176, 239, 0, 38, 233, 254, 108, 91, 14, 230, 158, 246, 153, 97, 3,
158, 188, 79, 135, 232, 14, 5, 0, 0,
];

assert_eq!(bytes, expected_serialization);
}
15 changes: 7 additions & 8 deletions noir/noir-repo/acvm-repo/acvm/src/pwg/memory_op.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@ use acir::{
FieldElement,
};

use super::{arithmetic::ExpressionSolver, get_value, insert_value, witness_to_value};
use super::{
arithmetic::ExpressionSolver, get_value, insert_value, is_predicate_false, witness_to_value,
};
use super::{ErrorLocation, OpcodeResolutionError};

type MemoryIndex = u32;
Expand Down Expand Up @@ -80,11 +82,8 @@ impl MemoryOpSolver {
// `operation == 0` implies a read operation. (`operation == 1` implies write operation).
let is_read_operation = operation.is_zero();

// If the predicate is `None`, then we simply return the value 1
let pred_value = match predicate {
Some(pred) => get_value(pred, initial_witness),
None => Ok(FieldElement::one()),
}?;
// Fetch whether or not the predicate is false (e.g. equal to zero)
let skip_operation = is_predicate_false(initial_witness, predicate)?;

if is_read_operation {
// `value_read = arr[memory_index]`
Expand All @@ -97,7 +96,7 @@ impl MemoryOpSolver {

// A zero predicate indicates that we should skip the read operation
// and zero out the operation's output.
let value_in_array = if pred_value.is_zero() {
let value_in_array = if skip_operation {
FieldElement::zero()
} else {
self.read_memory_index(memory_index)?
Expand All @@ -111,7 +110,7 @@ impl MemoryOpSolver {
let value_write = value;

// A zero predicate indicates that we should skip the write operation.
if pred_value.is_zero() {
if skip_operation {
// We only want to write to already initialized memory.
// Do nothing if the predicate is zero.
Ok(())
Expand Down
28 changes: 26 additions & 2 deletions noir/noir-repo/acvm-repo/acvm/src/pwg/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -377,7 +377,7 @@ impl<'a, B: BlackBoxFunctionSolver> ACVM<'a, B> {
};

let witness = &mut self.witness_map;
if BrilligSolver::<B>::should_skip(witness, brillig)? {
if is_predicate_false(witness, &brillig.predicate)? {
return BrilligSolver::<B>::zero_out_brillig_outputs(witness, brillig).map(|_| None);
}

Expand Down Expand Up @@ -448,7 +448,9 @@ impl<'a, B: BlackBoxFunctionSolver> ACVM<'a, B> {
}

pub fn solve_call_opcode(&mut self) -> Result<Option<AcirCallWaitInfo>, OpcodeResolutionError> {
let Opcode::Call { id, inputs, outputs } = &self.opcodes[self.instruction_pointer] else {
let Opcode::Call { id, inputs, outputs, predicate } =
&self.opcodes[self.instruction_pointer]
else {
unreachable!("Not executing a Call opcode");
};
if *id == 0 {
Expand All @@ -459,6 +461,14 @@ impl<'a, B: BlackBoxFunctionSolver> ACVM<'a, B> {
});
}

if is_predicate_false(&self.witness_map, predicate)? {
// Zero out the outputs if we have a false predicate
for output in outputs {
insert_value(output, FieldElement::zero(), &mut self.witness_map)?;
}
return Ok(None);
}

if self.acir_call_counter >= self.acir_call_results.len() {
let mut initial_witness = WitnessMap::default();
for (i, input_witness) in inputs.iter().enumerate() {
Expand Down Expand Up @@ -556,6 +566,20 @@ fn any_witness_from_expression(expr: &Expression) -> Option<Witness> {
}
}

/// Returns `true` if the predicate is zero
/// A predicate is used to indicate whether we should skip a certain operation.
/// If we have a zero predicate it means the operation should be skipped.
pub(crate) fn is_predicate_false(
witness: &WitnessMap,
predicate: &Option<Expression>,
) -> Result<bool, OpcodeResolutionError> {
match predicate {
Some(pred) => get_value(pred, witness).map(|pred_value| pred_value.is_zero()),
// If the predicate is `None`, then we treat it as an unconditional `true`
None => Ok(false),
}
}

#[derive(Debug, Clone, PartialEq)]
pub struct AcirCallWaitInfo {
/// Index in the list of ACIR function's that should be called
Expand Down
14 changes: 7 additions & 7 deletions noir/noir-repo/acvm-repo/acvm_js/test/shared/nested_acir_call.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@ import { WitnessMap, StackItem, WitnessStack } from '@noir-lang/acvm_js';

// See `nested_acir_call_circuit` integration test in `acir/tests/test_program_serialization.rs`.
export const bytecode = Uint8Array.from([
31, 139, 8, 0, 0, 0, 0, 0, 0, 255, 205, 146, 97, 10, 195, 32, 12, 133, 163, 66, 207, 147, 24, 109, 227, 191, 93, 101,
50, 123, 255, 35, 172, 99, 25, 83, 17, 250, 99, 14, 250, 224, 97, 144, 16, 146, 143, 231, 224, 45, 167, 126, 105, 57,
108, 14, 91, 248, 202, 168, 65, 255, 207, 122, 28, 180, 250, 244, 221, 244, 197, 223, 68, 182, 154, 197, 184, 134, 80,
54, 95, 136, 233, 142, 62, 101, 137, 24, 98, 94, 133, 132, 162, 196, 135, 23, 230, 34, 65, 182, 148, 211, 134, 137, 2,
23, 218, 99, 226, 93, 135, 185, 121, 123, 33, 84, 12, 234, 218, 192, 64, 174, 3, 248, 47, 88, 48, 17, 150, 157, 183,
151, 95, 244, 86, 91, 221, 61, 10, 81, 31, 178, 190, 110, 194, 102, 96, 76, 251, 202, 80, 13, 204, 77, 224, 25, 176,
70, 79, 197, 128, 18, 64, 3, 4, 0, 0,
31, 139, 8, 0, 0, 0, 0, 0, 0, 255, 213, 83, 65, 14, 3, 33, 8, 68, 77, 246, 61, 32, 186, 139, 183, 126, 165, 166, 238,
255, 159, 208, 109, 74, 82, 229, 178, 135, 186, 77, 58, 201, 4, 195, 97, 128, 1, 3, 188, 17, 148, 47, 44, 7, 221, 65,
15, 31, 56, 37, 104, 222, 129, 193, 77, 35, 126, 7, 58, 43, 30, 174, 44, 222, 107, 250, 201, 218, 190, 211, 98, 92,
83, 106, 91, 108, 196, 116, 199, 88, 170, 100, 76, 185, 174, 66, 66, 89, 242, 35, 10, 115, 147, 36, 91, 169, 101, 195,
66, 137, 27, 237, 185, 240, 174, 98, 97, 94, 95, 8, 198, 80, 103, 226, 128, 0, 227, 102, 174, 50, 11, 38, 154, 229,
231, 245, 21, 23, 157, 213, 119, 115, 255, 244, 58, 237, 183, 176, 239, 0, 38, 233, 254, 108, 91, 14, 230, 158, 246,
153, 97, 3, 158, 188, 79, 135, 232, 14, 5, 0, 0,
]);

export const initialWitnessMap: WitnessMap = new Map([
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1763,6 +1763,7 @@ impl AcirContext {
id: u32,
inputs: Vec<AcirValue>,
output_count: usize,
predicate: AcirVar,
) -> Result<Vec<AcirVar>, RuntimeError> {
let inputs = self.prepare_inputs_for_black_box_func_call(inputs)?;
let inputs = inputs
Expand All @@ -1778,7 +1779,8 @@ impl AcirContext {
let results =
vecmap(&outputs, |witness_index| self.add_data(AcirVarData::Witness(*witness_index)));

self.acir_ir.push_opcode(Opcode::Call { id, inputs, outputs });
let predicate = Some(self.var_to_expression(predicate)?);
self.acir_ir.push_opcode(Opcode::Call { id, inputs, outputs, predicate });
Ok(results)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -576,6 +576,7 @@ impl Context {
*acir_program_id,
inputs,
output_count,
self.current_side_effects_enabled_var,
)?;
let output_values =
self.convert_vars_to_values(output_vars, dfg, result_ids);
Expand Down Expand Up @@ -2713,7 +2714,7 @@ mod test {
expected_outputs: Vec<Witness>,
) {
match opcode {
Opcode::Call { id, inputs, outputs } => {
Opcode::Call { id, inputs, outputs, .. } => {
assert_eq!(
*id, expected_id,
"Main was expected to call {expected_id} but got {}",
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
# TODO(https://github.com/noir-lang/noir/issues/4707): Change these inputs to fail the assertion in `fn return_value`
# and change `enable` to false. For now we need the inputs to pass as we do not handle predicates with ACIR calls
x = "5"
x = "10"
y = "10"
enable = true
enable = false
Loading