Skip to content

Commit

Permalink
feat: Added cast opcode and cast calldata (AztecProtocol#4423)
Browse files Browse the repository at this point in the history
  • Loading branch information
sirasistant authored Feb 9, 2024
1 parent 9d50e24 commit e58eda8
Show file tree
Hide file tree
Showing 14 changed files with 341 additions and 83 deletions.
68 changes: 68 additions & 0 deletions barretenberg/cpp/src/barretenberg/dsl/acir_format/serde/acir.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -488,6 +488,16 @@ struct BrilligOpcode {
static BinaryIntOp bincodeDeserialize(std::vector<uint8_t>);
};

struct Cast {
Circuit::MemoryAddress destination;
Circuit::MemoryAddress source;
uint32_t bit_size;

friend bool operator==(const Cast&, const Cast&);
std::vector<uint8_t> bincodeSerialize() const;
static Cast bincodeDeserialize(std::vector<uint8_t>);
};

struct JumpIfNot {
Circuit::MemoryAddress condition;
uint64_t location;
Expand Down Expand Up @@ -612,6 +622,7 @@ struct BrilligOpcode {

std::variant<BinaryFieldOp,
BinaryIntOp,
Cast,
JumpIfNot,
JumpIf,
Jump,
Expand Down Expand Up @@ -5192,6 +5203,63 @@ Circuit::BrilligOpcode::BinaryIntOp serde::Deserializable<Circuit::BrilligOpcode

namespace Circuit {

inline bool operator==(const BrilligOpcode::Cast& lhs, const BrilligOpcode::Cast& rhs)
{
if (!(lhs.destination == rhs.destination)) {
return false;
}
if (!(lhs.source == rhs.source)) {
return false;
}
if (!(lhs.bit_size == rhs.bit_size)) {
return false;
}
return true;
}

inline std::vector<uint8_t> BrilligOpcode::Cast::bincodeSerialize() const
{
auto serializer = serde::BincodeSerializer();
serde::Serializable<BrilligOpcode::Cast>::serialize(*this, serializer);
return std::move(serializer).bytes();
}

inline BrilligOpcode::Cast BrilligOpcode::Cast::bincodeDeserialize(std::vector<uint8_t> input)
{
auto deserializer = serde::BincodeDeserializer(input);
auto value = serde::Deserializable<BrilligOpcode::Cast>::deserialize(deserializer);
if (deserializer.get_buffer_offset() < input.size()) {
throw_or_abort("Some input bytes were not read");
}
return value;
}

} // end of namespace Circuit

template <>
template <typename Serializer>
void serde::Serializable<Circuit::BrilligOpcode::Cast>::serialize(const Circuit::BrilligOpcode::Cast& obj,
Serializer& serializer)
{
serde::Serializable<decltype(obj.destination)>::serialize(obj.destination, serializer);
serde::Serializable<decltype(obj.source)>::serialize(obj.source, serializer);
serde::Serializable<decltype(obj.bit_size)>::serialize(obj.bit_size, serializer);
}

template <>
template <typename Deserializer>
Circuit::BrilligOpcode::Cast serde::Deserializable<Circuit::BrilligOpcode::Cast>::deserialize(
Deserializer& deserializer)
{
Circuit::BrilligOpcode::Cast obj;
obj.destination = serde::Deserializable<decltype(obj.destination)>::deserialize(deserializer);
obj.source = serde::Deserializable<decltype(obj.source)>::deserialize(deserializer);
obj.bit_size = serde::Deserializable<decltype(obj.bit_size)>::deserialize(deserializer);
return obj;
}

namespace Circuit {

inline bool operator==(const BrilligOpcode::JumpIfNot& lhs, const BrilligOpcode::JumpIfNot& rhs)
{
if (!(lhs.condition == rhs.condition)) {
Expand Down
56 changes: 55 additions & 1 deletion noir/acvm-repo/acir/codegen/acir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -468,6 +468,16 @@ namespace Circuit {
static BinaryIntOp bincodeDeserialize(std::vector<uint8_t>);
};

struct Cast {
Circuit::MemoryAddress destination;
Circuit::MemoryAddress source;
uint32_t bit_size;

friend bool operator==(const Cast&, const Cast&);
std::vector<uint8_t> bincodeSerialize() const;
static Cast bincodeDeserialize(std::vector<uint8_t>);
};

struct JumpIfNot {
Circuit::MemoryAddress condition;
uint64_t location;
Expand Down Expand Up @@ -590,7 +600,7 @@ namespace Circuit {
static Stop bincodeDeserialize(std::vector<uint8_t>);
};

std::variant<BinaryFieldOp, BinaryIntOp, JumpIfNot, JumpIf, Jump, CalldataCopy, Call, Const, Return, ForeignCall, Mov, Load, Store, BlackBox, Trap, Stop> value;
std::variant<BinaryFieldOp, BinaryIntOp, Cast, JumpIfNot, JumpIf, Jump, CalldataCopy, Call, Const, Return, ForeignCall, Mov, Load, Store, BlackBox, Trap, Stop> value;

friend bool operator==(const BrilligOpcode&, const BrilligOpcode&);
std::vector<uint8_t> bincodeSerialize() const;
Expand Down Expand Up @@ -4311,6 +4321,50 @@ Circuit::BrilligOpcode::BinaryIntOp serde::Deserializable<Circuit::BrilligOpcode
return obj;
}

namespace Circuit {

inline bool operator==(const BrilligOpcode::Cast &lhs, const BrilligOpcode::Cast &rhs) {
if (!(lhs.destination == rhs.destination)) { return false; }
if (!(lhs.source == rhs.source)) { return false; }
if (!(lhs.bit_size == rhs.bit_size)) { return false; }
return true;
}

inline std::vector<uint8_t> BrilligOpcode::Cast::bincodeSerialize() const {
auto serializer = serde::BincodeSerializer();
serde::Serializable<BrilligOpcode::Cast>::serialize(*this, serializer);
return std::move(serializer).bytes();
}

inline BrilligOpcode::Cast BrilligOpcode::Cast::bincodeDeserialize(std::vector<uint8_t> input) {
auto deserializer = serde::BincodeDeserializer(input);
auto value = serde::Deserializable<BrilligOpcode::Cast>::deserialize(deserializer);
if (deserializer.get_buffer_offset() < input.size()) {
throw serde::deserialization_error("Some input bytes were not read");
}
return value;
}

} // end of namespace Circuit

template <>
template <typename Serializer>
void serde::Serializable<Circuit::BrilligOpcode::Cast>::serialize(const Circuit::BrilligOpcode::Cast &obj, Serializer &serializer) {
serde::Serializable<decltype(obj.destination)>::serialize(obj.destination, serializer);
serde::Serializable<decltype(obj.source)>::serialize(obj.source, serializer);
serde::Serializable<decltype(obj.bit_size)>::serialize(obj.bit_size, serializer);
}

template <>
template <typename Deserializer>
Circuit::BrilligOpcode::Cast serde::Deserializable<Circuit::BrilligOpcode::Cast>::deserialize(Deserializer &deserializer) {
Circuit::BrilligOpcode::Cast obj;
obj.destination = serde::Deserializable<decltype(obj.destination)>::deserialize(deserializer);
obj.source = serde::Deserializable<decltype(obj.source)>::deserialize(deserializer);
obj.bit_size = serde::Deserializable<decltype(obj.bit_size)>::deserialize(deserializer);
return obj;
}

namespace Circuit {

inline bool operator==(const BrilligOpcode::JumpIfNot &lhs, const BrilligOpcode::JumpIfNot &rhs) {
Expand Down
28 changes: 14 additions & 14 deletions noir/acvm-repo/acir/tests/test_program_serialization.rs
Original file line number Diff line number Diff line change
Expand Up @@ -206,11 +206,11 @@ fn simple_brillig_foreign_call() {
let bytes = Circuit::serialize_circuit(&circuit);

let expected_serialization: Vec<u8> = vec![
31, 139, 8, 0, 0, 0, 0, 0, 0, 255, 173, 143, 65, 10, 192, 32, 12, 4, 215, 148, 150, 246,
212, 175, 216, 31, 244, 51, 61, 244, 226, 65, 196, 247, 171, 24, 33, 136, 122, 209, 129,
144, 176, 132, 101, 247, 4, 160, 144, 217, 196, 45, 41, 218, 203, 91, 207, 241, 168, 117,
94, 90, 230, 37, 238, 144, 216, 27, 249, 11, 87, 156, 131, 239, 223, 248, 207, 186, 81,
235, 150, 67, 173, 221, 189, 95, 18, 34, 97, 64, 0, 116, 135, 40, 214, 136, 1, 0, 0,
31, 139, 8, 0, 0, 0, 0, 0, 0, 255, 173, 143, 177, 10, 192, 32, 16, 67, 227, 21, 74, 233,
212, 79, 177, 127, 208, 159, 233, 224, 226, 32, 226, 247, 139, 168, 16, 68, 93, 244, 45,
119, 228, 142, 144, 92, 0, 20, 50, 7, 237, 76, 213, 190, 50, 245, 26, 175, 218, 231, 165,
57, 175, 148, 14, 137, 179, 147, 191, 114, 211, 221, 216, 240, 59, 63, 107, 221, 115, 104,
181, 103, 244, 43, 36, 10, 38, 68, 108, 25, 253, 238, 136, 1, 0, 0,
];

assert_eq!(bytes, expected_serialization)
Expand Down Expand Up @@ -305,15 +305,15 @@ fn complex_brillig_foreign_call() {
let bytes = Circuit::serialize_circuit(&circuit);

let expected_serialization: Vec<u8> = vec![
31, 139, 8, 0, 0, 0, 0, 0, 0, 255, 213, 84, 73, 14, 131, 48, 12, 28, 147, 166, 165, 167,
126, 161, 82, 251, 128, 180, 47, 224, 47, 85, 111, 32, 56, 242, 124, 130, 24, 68, 176, 2,
23, 130, 4, 35, 89, 206, 50, 137, 71, 182, 147, 28, 128, 96, 128, 241, 150, 113, 44, 156,
135, 24, 121, 5, 189, 219, 134, 143, 164, 187, 203, 237, 165, 49, 59, 129, 70, 179, 131,
198, 177, 31, 14, 90, 239, 148, 117, 73, 154, 63, 19, 121, 63, 23, 111, 214, 219, 149, 243,
27, 125, 206, 117, 208, 63, 85, 222, 161, 248, 32, 167, 72, 162, 245, 235, 44, 166, 94, 20,
21, 251, 30, 196, 253, 213, 85, 83, 254, 91, 163, 168, 90, 234, 43, 24, 191, 213, 190, 172,
156, 235, 17, 126, 59, 49, 142, 68, 120, 75, 220, 7, 166, 84, 90, 68, 72, 194, 139, 180,
136, 25, 58, 46, 103, 45, 188, 25, 5, 0, 0,
31, 139, 8, 0, 0, 0, 0, 0, 0, 255, 213, 84, 75, 10, 132, 48, 12, 125, 177, 163, 35, 179,
154, 35, 8, 51, 7, 232, 204, 9, 188, 139, 184, 83, 116, 233, 241, 173, 152, 98, 12, 213,
141, 21, 244, 65, 232, 39, 175, 233, 35, 73, 155, 3, 32, 204, 48, 206, 18, 158, 19, 175,
37, 60, 175, 228, 209, 30, 195, 143, 226, 197, 178, 103, 105, 76, 110, 160, 209, 156, 160,
209, 247, 195, 69, 235, 29, 179, 46, 81, 243, 103, 2, 239, 231, 225, 44, 117, 150, 241,
250, 201, 99, 206, 251, 96, 95, 161, 242, 14, 193, 243, 40, 162, 105, 253, 219, 12, 75, 47,
146, 186, 251, 37, 116, 86, 93, 219, 55, 245, 96, 20, 85, 75, 253, 136, 249, 87, 249, 105,
231, 220, 4, 249, 237, 132, 56, 20, 224, 109, 113, 223, 88, 82, 153, 34, 64, 34, 14, 164,
69, 172, 48, 2, 23, 243, 6, 31, 25, 5, 0, 0,
];

assert_eq!(bytes, expected_serialization)
Expand Down
14 changes: 7 additions & 7 deletions noir/acvm-repo/acvm_js/test/shared/complex_foreign_call.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@ import { WitnessMap } from '@noir-lang/acvm_js';

// See `complex_brillig_foreign_call` integration test in `acir/tests/test_program_serialization.rs`.
export const bytecode = Uint8Array.from([
31, 139, 8, 0, 0, 0, 0, 0, 0, 255, 213, 84, 73, 14, 131, 48, 12, 28, 147, 166, 165, 167, 126, 161, 82, 251, 128, 180,
47, 224, 47, 85, 111, 32, 56, 242, 124, 130, 24, 68, 176, 2, 23, 130, 4, 35, 89, 206, 50, 137, 71, 182, 147, 28, 128,
96, 128, 241, 150, 113, 44, 156, 135, 24, 121, 5, 189, 219, 134, 143, 164, 187, 203, 237, 165, 49, 59, 129, 70, 179,
131, 198, 177, 31, 14, 90, 239, 148, 117, 73, 154, 63, 19, 121, 63, 23, 111, 214, 219, 149, 243, 27, 125, 206, 117,
208, 63, 85, 222, 161, 248, 32, 167, 72, 162, 245, 235, 44, 166, 94, 20, 21, 251, 30, 196, 253, 213, 85, 83, 254, 91,
163, 168, 90, 234, 43, 24, 191, 213, 190, 172, 156, 235, 17, 126, 59, 49, 142, 68, 120, 75, 220, 7, 166, 84, 90, 68,
72, 194, 139, 180, 136, 25, 58, 46, 103, 45, 188, 25, 5, 0, 0,
31, 139, 8, 0, 0, 0, 0, 0, 0, 255, 213, 84, 75, 10, 132, 48, 12, 125, 177, 163, 35, 179, 154, 35, 8, 51, 7, 232, 204,
9, 188, 139, 184, 83, 116, 233, 241, 173, 152, 98, 12, 213, 141, 21, 244, 65, 232, 39, 175, 233, 35, 73, 155, 3, 32,
204, 48, 206, 18, 158, 19, 175, 37, 60, 175, 228, 209, 30, 195, 143, 226, 197, 178, 103, 105, 76, 110, 160, 209, 156,
160, 209, 247, 195, 69, 235, 29, 179, 46, 81, 243, 103, 2, 239, 231, 225, 44, 117, 150, 241, 250, 201, 99, 206, 251,
96, 95, 161, 242, 14, 193, 243, 40, 162, 105, 253, 219, 12, 75, 47, 146, 186, 251, 37, 116, 86, 93, 219, 55, 245, 96,
20, 85, 75, 253, 136, 249, 87, 249, 105, 231, 220, 4, 249, 237, 132, 56, 20, 224, 109, 113, 223, 88, 82, 153, 34, 64,
34, 14, 164, 69, 172, 48, 2, 23, 243, 6, 31, 25, 5, 0, 0,
]);
export const initialWitnessMap: WitnessMap = new Map([
[1, '0x0000000000000000000000000000000000000000000000000000000000000001'],
Expand Down
8 changes: 4 additions & 4 deletions noir/acvm-repo/acvm_js/test/shared/foreign_call.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@ import { WitnessMap } from '@noir-lang/acvm_js';

// See `simple_brillig_foreign_call` integration test in `acir/tests/test_program_serialization.rs`.
export const bytecode = Uint8Array.from([
31, 139, 8, 0, 0, 0, 0, 0, 0, 255, 173, 143, 65, 10, 192, 32, 12, 4, 215, 148, 150, 246, 212, 175, 216, 31, 244, 51,
61, 244, 226, 65, 196, 247, 171, 24, 33, 136, 122, 209, 129, 144, 176, 132, 101, 247, 4, 160, 144, 217, 196, 45, 41,
218, 203, 91, 207, 241, 168, 117, 94, 90, 230, 37, 238, 144, 216, 27, 249, 11, 87, 156, 131, 239, 223, 248, 207, 186,
81, 235, 150, 67, 173, 221, 189, 95, 18, 34, 97, 64, 0, 116, 135, 40, 214, 136, 1, 0, 0,
31, 139, 8, 0, 0, 0, 0, 0, 0, 255, 173, 143, 177, 10, 192, 32, 16, 67, 227, 21, 74, 233, 212, 79, 177, 127, 208, 159,
233, 224, 226, 32, 226, 247, 139, 168, 16, 68, 93, 244, 45, 119, 228, 142, 144, 92, 0, 20, 50, 7, 237, 76, 213, 190,
50, 245, 26, 175, 218, 231, 165, 57, 175, 148, 14, 137, 179, 147, 191, 114, 211, 221, 216, 240, 59, 63, 107, 221, 115,
104, 181, 103, 244, 43, 36, 10, 38, 68, 108, 25, 253, 238, 136, 1, 0, 0,
]);
export const initialWitnessMap: WitnessMap = new Map([
[1, '0x0000000000000000000000000000000000000000000000000000000000000005'],
Expand Down
5 changes: 5 additions & 0 deletions noir/acvm-repo/brillig/src/opcodes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,11 @@ pub enum BrilligOpcode {
lhs: MemoryAddress,
rhs: MemoryAddress,
},
Cast {
destination: MemoryAddress,
source: MemoryAddress,
bit_size: u32,
},
JumpIfNot {
condition: MemoryAddress,
location: Label,
Expand Down
47 changes: 47 additions & 0 deletions noir/acvm-repo/brillig_vm/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,12 @@ impl<'a, B: BlackBoxFunctionSolver> VM<'a, B> {
self.increment_program_counter()
}
}
Opcode::Cast { destination: destination_address, source: source_address, bit_size } => {
let source_value = self.memory.read(*source_address);
let casted_value = self.cast(*bit_size, source_value);
self.memory.write(*destination_address, casted_value);
self.increment_program_counter()
}
Opcode::Jump { location: destination } => self.set_program_counter(*destination),
Opcode::JumpIf { condition, location: destination } => {
// Check if condition is true
Expand Down Expand Up @@ -501,6 +507,13 @@ impl<'a, B: BlackBoxFunctionSolver> VM<'a, B> {
.write(result, FieldElement::from_be_bytes_reduce(&result_value.to_bytes_be()).into());
Ok(())
}

/// Casts a value to a different bit size.
fn cast(&self, bit_size: u32, value: Value) -> Value {
let lhs_big = BigUint::from_bytes_be(&value.to_field().to_be_bytes());
let mask = BigUint::from(2_u32).pow(bit_size) - 1_u32;
FieldElement::from_be_bytes_reduce(&(lhs_big & mask).to_bytes_be()).into()
}
}

pub(crate) struct DummyBlackBoxSolver;
Expand Down Expand Up @@ -698,6 +711,40 @@ mod tests {
assert_eq!(output_value, Value::from(false));
}

#[test]
fn cast_opcode() {
let calldata = vec![Value::from((2_u128.pow(32)) - 1)];

let opcodes = &[
Opcode::CalldataCopy {
destination_address: MemoryAddress::from(0),
size: 1,
offset: 0,
},
Opcode::Cast {
destination: MemoryAddress::from(1),
source: MemoryAddress::from(0),
bit_size: 8,
},
Opcode::Stop { return_data_offset: 1, return_data_size: 1 },
];
let mut vm = VM::new(calldata, opcodes, vec![], &DummyBlackBoxSolver);

let status = vm.process_opcode();
assert_eq!(status, VMStatus::InProgress);

let status = vm.process_opcode();
assert_eq!(status, VMStatus::InProgress);

let status = vm.process_opcode();
assert_eq!(status, VMStatus::Finished { return_data_offset: 1, return_data_size: 1 });

let VM { memory, .. } = vm;

let casted_value = memory.read(MemoryAddress::from(1));
assert_eq!(casted_value, Value::from(2_u128.pow(8) - 1));
}

#[test]
fn mov_opcode() {
let calldata = vec![Value::from(1u128), Value::from(2u128), Value::from(3u128)];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ use num_bigint::BigUint;

use super::brillig_black_box::convert_black_box_call;
use super::brillig_block_variables::BlockVariables;
use super::brillig_fn::FunctionContext;
use super::brillig_fn::{get_bit_size_from_ssa_type, FunctionContext};

/// Generate the compilation artifacts for compiling a function into brillig bytecode.
pub(crate) struct BrilligBlock<'block> {
Expand Down Expand Up @@ -87,16 +87,6 @@ impl<'block> BrilligBlock<'block> {
self.convert_ssa_terminator(terminator_instruction, dfg);
}

fn get_bit_size_from_ssa_type(typ: &Type) -> u32 {
match typ {
Type::Numeric(num_type) => match num_type {
NumericType::Signed { bit_size } | NumericType::Unsigned { bit_size } => *bit_size,
NumericType::NativeField => FieldElement::max_num_bits(),
},
_ => unreachable!("ICE bitwise not on a non numeric type"),
}
}

/// Creates a unique global label for a block.
///
/// This uses the current functions's function ID and the block ID
Expand Down Expand Up @@ -324,7 +314,7 @@ impl<'block> BrilligBlock<'block> {
dfg.instruction_results(instruction_id)[0],
dfg,
);
let bit_size = Self::get_bit_size_from_ssa_type(&dfg.type_of_value(*value));
let bit_size = get_bit_size_from_ssa_type(&dfg.type_of_value(*value));
self.brillig_context.not_instruction(condition_register, bit_size, result_register);
}
Instruction::Call { func, arguments } => match &dfg[*func] {
Expand Down Expand Up @@ -547,7 +537,7 @@ impl<'block> BrilligBlock<'block> {
*bit_size,
);
}
Instruction::Cast(value, _) => {
Instruction::Cast(value, typ) => {
let result_ids = dfg.instruction_results(instruction_id);
let destination_register = self.variables.define_register_variable(
self.function_context,
Expand All @@ -556,7 +546,7 @@ impl<'block> BrilligBlock<'block> {
dfg,
);
let source_register = self.convert_ssa_register_value(*value, dfg);
self.convert_cast(destination_register, source_register);
self.convert_cast(destination_register, source_register, typ);
}
Instruction::ArrayGet { array, index } => {
let result_ids = dfg.instruction_results(instruction_id);
Expand Down Expand Up @@ -1136,11 +1126,11 @@ impl<'block> BrilligBlock<'block> {

/// Converts an SSA cast to a sequence of Brillig opcodes.
/// Casting is only necessary when shrinking the bit size of a numeric value.
fn convert_cast(&mut self, destination: MemoryAddress, source: MemoryAddress) {
fn convert_cast(&mut self, destination: MemoryAddress, source: MemoryAddress, typ: &Type) {
// We assume that `source` is a valid `target_type` as it's expected that a truncate instruction was emitted
// to ensure this is the case.

self.brillig_context.mov_instruction(destination, source);
self.brillig_context.cast_instruction(destination, source, get_bit_size_from_ssa_type(typ));
}

/// Converts the Binary instruction into a sequence of Brillig opcodes.
Expand Down Expand Up @@ -1186,7 +1176,7 @@ impl<'block> BrilligBlock<'block> {
self.brillig_context.const_instruction(
register_index,
(*constant).into(),
Self::get_bit_size_from_ssa_type(typ),
get_bit_size_from_ssa_type(typ),
);
new_variable
}
Expand Down
Loading

0 comments on commit e58eda8

Please sign in to comment.