From e58eda804cbdd8a7220013ac8befacbef243b856 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=81lvaro=20Rodr=C3=ADguez?= Date: Fri, 9 Feb 2024 09:47:01 +0100 Subject: [PATCH] feat: Added cast opcode and cast calldata (#4423) Resolves https://github.com/noir-lang/noir/issues/4160 --- .../dsl/acir_format/serde/acir.hpp | 68 +++++++++++++++++++ noir/acvm-repo/acir/codegen/acir.cpp | 56 ++++++++++++++- .../acir/tests/test_program_serialization.rs | 28 ++++---- .../test/shared/complex_foreign_call.ts | 14 ++-- .../acvm_js/test/shared/foreign_call.ts | 8 +-- noir/acvm-repo/brillig/src/opcodes.rs | 5 ++ noir/acvm-repo/brillig_vm/src/lib.rs | 47 +++++++++++++ .../src/brillig/brillig_gen/brillig_block.rs | 24 ++----- .../src/brillig/brillig_gen/brillig_fn.rs | 20 +++++- .../brillig/brillig_gen/brillig_slice_ops.rs | 57 +++++++++++----- .../noirc_evaluator/src/brillig/brillig_ir.rs | 11 +++ .../src/brillig/brillig_ir/artifact.rs | 5 +- .../src/brillig/brillig_ir/debug_show.rs | 16 +++++ .../src/brillig/brillig_ir/entry_point.rs | 65 +++++++++++++----- 14 files changed, 341 insertions(+), 83 deletions(-) diff --git a/barretenberg/cpp/src/barretenberg/dsl/acir_format/serde/acir.hpp b/barretenberg/cpp/src/barretenberg/dsl/acir_format/serde/acir.hpp index da1c0ddb68a..2c6bea75698 100644 --- a/barretenberg/cpp/src/barretenberg/dsl/acir_format/serde/acir.hpp +++ b/barretenberg/cpp/src/barretenberg/dsl/acir_format/serde/acir.hpp @@ -488,6 +488,16 @@ struct BrilligOpcode { static BinaryIntOp bincodeDeserialize(std::vector); }; + struct Cast { + Circuit::MemoryAddress destination; + Circuit::MemoryAddress source; + uint32_t bit_size; + + friend bool operator==(const Cast&, const Cast&); + std::vector bincodeSerialize() const; + static Cast bincodeDeserialize(std::vector); + }; + struct JumpIfNot { Circuit::MemoryAddress condition; uint64_t location; @@ -612,6 +622,7 @@ struct BrilligOpcode { std::variant BrilligOpcode::Cast::bincodeSerialize() const +{ + auto serializer = serde::BincodeSerializer(); + serde::Serializable::serialize(*this, serializer); + return std::move(serializer).bytes(); +} + +inline BrilligOpcode::Cast BrilligOpcode::Cast::bincodeDeserialize(std::vector input) +{ + auto deserializer = serde::BincodeDeserializer(input); + auto value = serde::Deserializable::deserialize(deserializer); + if (deserializer.get_buffer_offset() < input.size()) { + throw_or_abort("Some input bytes were not read"); + } + return value; +} + +} // end of namespace Circuit + +template <> +template +void serde::Serializable::serialize(const Circuit::BrilligOpcode::Cast& obj, + Serializer& serializer) +{ + serde::Serializable::serialize(obj.destination, serializer); + serde::Serializable::serialize(obj.source, serializer); + serde::Serializable::serialize(obj.bit_size, serializer); +} + +template <> +template +Circuit::BrilligOpcode::Cast serde::Deserializable::deserialize( + Deserializer& deserializer) +{ + Circuit::BrilligOpcode::Cast obj; + obj.destination = serde::Deserializable::deserialize(deserializer); + obj.source = serde::Deserializable::deserialize(deserializer); + obj.bit_size = serde::Deserializable::deserialize(deserializer); + return obj; +} + +namespace Circuit { + inline bool operator==(const BrilligOpcode::JumpIfNot& lhs, const BrilligOpcode::JumpIfNot& rhs) { if (!(lhs.condition == rhs.condition)) { diff --git a/noir/acvm-repo/acir/codegen/acir.cpp b/noir/acvm-repo/acir/codegen/acir.cpp index b1c0d12272c..3ce63ecfa94 100644 --- a/noir/acvm-repo/acir/codegen/acir.cpp +++ b/noir/acvm-repo/acir/codegen/acir.cpp @@ -468,6 +468,16 @@ namespace Circuit { static BinaryIntOp bincodeDeserialize(std::vector); }; + struct Cast { + Circuit::MemoryAddress destination; + Circuit::MemoryAddress source; + uint32_t bit_size; + + friend bool operator==(const Cast&, const Cast&); + std::vector bincodeSerialize() const; + static Cast bincodeDeserialize(std::vector); + }; + struct JumpIfNot { Circuit::MemoryAddress condition; uint64_t location; @@ -590,7 +600,7 @@ namespace Circuit { static Stop bincodeDeserialize(std::vector); }; - std::variant value; + std::variant value; friend bool operator==(const BrilligOpcode&, const BrilligOpcode&); std::vector bincodeSerialize() const; @@ -4311,6 +4321,50 @@ Circuit::BrilligOpcode::BinaryIntOp serde::Deserializable BrilligOpcode::Cast::bincodeSerialize() const { + auto serializer = serde::BincodeSerializer(); + serde::Serializable::serialize(*this, serializer); + return std::move(serializer).bytes(); + } + + inline BrilligOpcode::Cast BrilligOpcode::Cast::bincodeDeserialize(std::vector input) { + auto deserializer = serde::BincodeDeserializer(input); + auto value = serde::Deserializable::deserialize(deserializer); + if (deserializer.get_buffer_offset() < input.size()) { + throw serde::deserialization_error("Some input bytes were not read"); + } + return value; + } + +} // end of namespace Circuit + +template <> +template +void serde::Serializable::serialize(const Circuit::BrilligOpcode::Cast &obj, Serializer &serializer) { + serde::Serializable::serialize(obj.destination, serializer); + serde::Serializable::serialize(obj.source, serializer); + serde::Serializable::serialize(obj.bit_size, serializer); +} + +template <> +template +Circuit::BrilligOpcode::Cast serde::Deserializable::deserialize(Deserializer &deserializer) { + Circuit::BrilligOpcode::Cast obj; + obj.destination = serde::Deserializable::deserialize(deserializer); + obj.source = serde::Deserializable::deserialize(deserializer); + obj.bit_size = serde::Deserializable::deserialize(deserializer); + return obj; +} + namespace Circuit { inline bool operator==(const BrilligOpcode::JumpIfNot &lhs, const BrilligOpcode::JumpIfNot &rhs) { diff --git a/noir/acvm-repo/acir/tests/test_program_serialization.rs b/noir/acvm-repo/acir/tests/test_program_serialization.rs index 6a73522c822..2c8ad2b9986 100644 --- a/noir/acvm-repo/acir/tests/test_program_serialization.rs +++ b/noir/acvm-repo/acir/tests/test_program_serialization.rs @@ -206,11 +206,11 @@ fn simple_brillig_foreign_call() { let bytes = Circuit::serialize_circuit(&circuit); let expected_serialization: Vec = vec![ - 31, 139, 8, 0, 0, 0, 0, 0, 0, 255, 173, 143, 65, 10, 192, 32, 12, 4, 215, 148, 150, 246, - 212, 175, 216, 31, 244, 51, 61, 244, 226, 65, 196, 247, 171, 24, 33, 136, 122, 209, 129, - 144, 176, 132, 101, 247, 4, 160, 144, 217, 196, 45, 41, 218, 203, 91, 207, 241, 168, 117, - 94, 90, 230, 37, 238, 144, 216, 27, 249, 11, 87, 156, 131, 239, 223, 248, 207, 186, 81, - 235, 150, 67, 173, 221, 189, 95, 18, 34, 97, 64, 0, 116, 135, 40, 214, 136, 1, 0, 0, + 31, 139, 8, 0, 0, 0, 0, 0, 0, 255, 173, 143, 177, 10, 192, 32, 16, 67, 227, 21, 74, 233, + 212, 79, 177, 127, 208, 159, 233, 224, 226, 32, 226, 247, 139, 168, 16, 68, 93, 244, 45, + 119, 228, 142, 144, 92, 0, 20, 50, 7, 237, 76, 213, 190, 50, 245, 26, 175, 218, 231, 165, + 57, 175, 148, 14, 137, 179, 147, 191, 114, 211, 221, 216, 240, 59, 63, 107, 221, 115, 104, + 181, 103, 244, 43, 36, 10, 38, 68, 108, 25, 253, 238, 136, 1, 0, 0, ]; assert_eq!(bytes, expected_serialization) @@ -305,15 +305,15 @@ fn complex_brillig_foreign_call() { let bytes = Circuit::serialize_circuit(&circuit); let expected_serialization: Vec = vec![ - 31, 139, 8, 0, 0, 0, 0, 0, 0, 255, 213, 84, 73, 14, 131, 48, 12, 28, 147, 166, 165, 167, - 126, 161, 82, 251, 128, 180, 47, 224, 47, 85, 111, 32, 56, 242, 124, 130, 24, 68, 176, 2, - 23, 130, 4, 35, 89, 206, 50, 137, 71, 182, 147, 28, 128, 96, 128, 241, 150, 113, 44, 156, - 135, 24, 121, 5, 189, 219, 134, 143, 164, 187, 203, 237, 165, 49, 59, 129, 70, 179, 131, - 198, 177, 31, 14, 90, 239, 148, 117, 73, 154, 63, 19, 121, 63, 23, 111, 214, 219, 149, 243, - 27, 125, 206, 117, 208, 63, 85, 222, 161, 248, 32, 167, 72, 162, 245, 235, 44, 166, 94, 20, - 21, 251, 30, 196, 253, 213, 85, 83, 254, 91, 163, 168, 90, 234, 43, 24, 191, 213, 190, 172, - 156, 235, 17, 126, 59, 49, 142, 68, 120, 75, 220, 7, 166, 84, 90, 68, 72, 194, 139, 180, - 136, 25, 58, 46, 103, 45, 188, 25, 5, 0, 0, + 31, 139, 8, 0, 0, 0, 0, 0, 0, 255, 213, 84, 75, 10, 132, 48, 12, 125, 177, 163, 35, 179, + 154, 35, 8, 51, 7, 232, 204, 9, 188, 139, 184, 83, 116, 233, 241, 173, 152, 98, 12, 213, + 141, 21, 244, 65, 232, 39, 175, 233, 35, 73, 155, 3, 32, 204, 48, 206, 18, 158, 19, 175, + 37, 60, 175, 228, 209, 30, 195, 143, 226, 197, 178, 103, 105, 76, 110, 160, 209, 156, 160, + 209, 247, 195, 69, 235, 29, 179, 46, 81, 243, 103, 2, 239, 231, 225, 44, 117, 150, 241, + 250, 201, 99, 206, 251, 96, 95, 161, 242, 14, 193, 243, 40, 162, 105, 253, 219, 12, 75, 47, + 146, 186, 251, 37, 116, 86, 93, 219, 55, 245, 96, 20, 85, 75, 253, 136, 249, 87, 249, 105, + 231, 220, 4, 249, 237, 132, 56, 20, 224, 109, 113, 223, 88, 82, 153, 34, 64, 34, 14, 164, + 69, 172, 48, 2, 23, 243, 6, 31, 25, 5, 0, 0, ]; assert_eq!(bytes, expected_serialization) diff --git a/noir/acvm-repo/acvm_js/test/shared/complex_foreign_call.ts b/noir/acvm-repo/acvm_js/test/shared/complex_foreign_call.ts index 07c343c5ba0..27abd72305f 100644 --- a/noir/acvm-repo/acvm_js/test/shared/complex_foreign_call.ts +++ b/noir/acvm-repo/acvm_js/test/shared/complex_foreign_call.ts @@ -2,13 +2,13 @@ import { WitnessMap } from '@noir-lang/acvm_js'; // See `complex_brillig_foreign_call` integration test in `acir/tests/test_program_serialization.rs`. export const bytecode = Uint8Array.from([ - 31, 139, 8, 0, 0, 0, 0, 0, 0, 255, 213, 84, 73, 14, 131, 48, 12, 28, 147, 166, 165, 167, 126, 161, 82, 251, 128, 180, - 47, 224, 47, 85, 111, 32, 56, 242, 124, 130, 24, 68, 176, 2, 23, 130, 4, 35, 89, 206, 50, 137, 71, 182, 147, 28, 128, - 96, 128, 241, 150, 113, 44, 156, 135, 24, 121, 5, 189, 219, 134, 143, 164, 187, 203, 237, 165, 49, 59, 129, 70, 179, - 131, 198, 177, 31, 14, 90, 239, 148, 117, 73, 154, 63, 19, 121, 63, 23, 111, 214, 219, 149, 243, 27, 125, 206, 117, - 208, 63, 85, 222, 161, 248, 32, 167, 72, 162, 245, 235, 44, 166, 94, 20, 21, 251, 30, 196, 253, 213, 85, 83, 254, 91, - 163, 168, 90, 234, 43, 24, 191, 213, 190, 172, 156, 235, 17, 126, 59, 49, 142, 68, 120, 75, 220, 7, 166, 84, 90, 68, - 72, 194, 139, 180, 136, 25, 58, 46, 103, 45, 188, 25, 5, 0, 0, + 31, 139, 8, 0, 0, 0, 0, 0, 0, 255, 213, 84, 75, 10, 132, 48, 12, 125, 177, 163, 35, 179, 154, 35, 8, 51, 7, 232, 204, + 9, 188, 139, 184, 83, 116, 233, 241, 173, 152, 98, 12, 213, 141, 21, 244, 65, 232, 39, 175, 233, 35, 73, 155, 3, 32, + 204, 48, 206, 18, 158, 19, 175, 37, 60, 175, 228, 209, 30, 195, 143, 226, 197, 178, 103, 105, 76, 110, 160, 209, 156, + 160, 209, 247, 195, 69, 235, 29, 179, 46, 81, 243, 103, 2, 239, 231, 225, 44, 117, 150, 241, 250, 201, 99, 206, 251, + 96, 95, 161, 242, 14, 193, 243, 40, 162, 105, 253, 219, 12, 75, 47, 146, 186, 251, 37, 116, 86, 93, 219, 55, 245, 96, + 20, 85, 75, 253, 136, 249, 87, 249, 105, 231, 220, 4, 249, 237, 132, 56, 20, 224, 109, 113, 223, 88, 82, 153, 34, 64, + 34, 14, 164, 69, 172, 48, 2, 23, 243, 6, 31, 25, 5, 0, 0, ]); export const initialWitnessMap: WitnessMap = new Map([ [1, '0x0000000000000000000000000000000000000000000000000000000000000001'], diff --git a/noir/acvm-repo/acvm_js/test/shared/foreign_call.ts b/noir/acvm-repo/acvm_js/test/shared/foreign_call.ts index cfa7c679b18..0be8937b57d 100644 --- a/noir/acvm-repo/acvm_js/test/shared/foreign_call.ts +++ b/noir/acvm-repo/acvm_js/test/shared/foreign_call.ts @@ -2,10 +2,10 @@ import { WitnessMap } from '@noir-lang/acvm_js'; // See `simple_brillig_foreign_call` integration test in `acir/tests/test_program_serialization.rs`. export const bytecode = Uint8Array.from([ - 31, 139, 8, 0, 0, 0, 0, 0, 0, 255, 173, 143, 65, 10, 192, 32, 12, 4, 215, 148, 150, 246, 212, 175, 216, 31, 244, 51, - 61, 244, 226, 65, 196, 247, 171, 24, 33, 136, 122, 209, 129, 144, 176, 132, 101, 247, 4, 160, 144, 217, 196, 45, 41, - 218, 203, 91, 207, 241, 168, 117, 94, 90, 230, 37, 238, 144, 216, 27, 249, 11, 87, 156, 131, 239, 223, 248, 207, 186, - 81, 235, 150, 67, 173, 221, 189, 95, 18, 34, 97, 64, 0, 116, 135, 40, 214, 136, 1, 0, 0, + 31, 139, 8, 0, 0, 0, 0, 0, 0, 255, 173, 143, 177, 10, 192, 32, 16, 67, 227, 21, 74, 233, 212, 79, 177, 127, 208, 159, + 233, 224, 226, 32, 226, 247, 139, 168, 16, 68, 93, 244, 45, 119, 228, 142, 144, 92, 0, 20, 50, 7, 237, 76, 213, 190, + 50, 245, 26, 175, 218, 231, 165, 57, 175, 148, 14, 137, 179, 147, 191, 114, 211, 221, 216, 240, 59, 63, 107, 221, 115, + 104, 181, 103, 244, 43, 36, 10, 38, 68, 108, 25, 253, 238, 136, 1, 0, 0, ]); export const initialWitnessMap: WitnessMap = new Map([ [1, '0x0000000000000000000000000000000000000000000000000000000000000005'], diff --git a/noir/acvm-repo/brillig/src/opcodes.rs b/noir/acvm-repo/brillig/src/opcodes.rs index 06c8fdc04eb..51df1f90941 100644 --- a/noir/acvm-repo/brillig/src/opcodes.rs +++ b/noir/acvm-repo/brillig/src/opcodes.rs @@ -98,6 +98,11 @@ pub enum BrilligOpcode { lhs: MemoryAddress, rhs: MemoryAddress, }, + Cast { + destination: MemoryAddress, + source: MemoryAddress, + bit_size: u32, + }, JumpIfNot { condition: MemoryAddress, location: Label, diff --git a/noir/acvm-repo/brillig_vm/src/lib.rs b/noir/acvm-repo/brillig_vm/src/lib.rs index 4292a623cdb..081ecd33cb6 100644 --- a/noir/acvm-repo/brillig_vm/src/lib.rs +++ b/noir/acvm-repo/brillig_vm/src/lib.rs @@ -183,6 +183,12 @@ impl<'a, B: BlackBoxFunctionSolver> VM<'a, B> { self.increment_program_counter() } } + Opcode::Cast { destination: destination_address, source: source_address, bit_size } => { + let source_value = self.memory.read(*source_address); + let casted_value = self.cast(*bit_size, source_value); + self.memory.write(*destination_address, casted_value); + self.increment_program_counter() + } Opcode::Jump { location: destination } => self.set_program_counter(*destination), Opcode::JumpIf { condition, location: destination } => { // Check if condition is true @@ -501,6 +507,13 @@ impl<'a, B: BlackBoxFunctionSolver> VM<'a, B> { .write(result, FieldElement::from_be_bytes_reduce(&result_value.to_bytes_be()).into()); Ok(()) } + + /// Casts a value to a different bit size. + fn cast(&self, bit_size: u32, value: Value) -> Value { + let lhs_big = BigUint::from_bytes_be(&value.to_field().to_be_bytes()); + let mask = BigUint::from(2_u32).pow(bit_size) - 1_u32; + FieldElement::from_be_bytes_reduce(&(lhs_big & mask).to_bytes_be()).into() + } } pub(crate) struct DummyBlackBoxSolver; @@ -698,6 +711,40 @@ mod tests { assert_eq!(output_value, Value::from(false)); } + #[test] + fn cast_opcode() { + let calldata = vec![Value::from((2_u128.pow(32)) - 1)]; + + let opcodes = &[ + Opcode::CalldataCopy { + destination_address: MemoryAddress::from(0), + size: 1, + offset: 0, + }, + Opcode::Cast { + destination: MemoryAddress::from(1), + source: MemoryAddress::from(0), + bit_size: 8, + }, + Opcode::Stop { return_data_offset: 1, return_data_size: 1 }, + ]; + let mut vm = VM::new(calldata, opcodes, vec![], &DummyBlackBoxSolver); + + let status = vm.process_opcode(); + assert_eq!(status, VMStatus::InProgress); + + let status = vm.process_opcode(); + assert_eq!(status, VMStatus::InProgress); + + let status = vm.process_opcode(); + assert_eq!(status, VMStatus::Finished { return_data_offset: 1, return_data_size: 1 }); + + let VM { memory, .. } = vm; + + let casted_value = memory.read(MemoryAddress::from(1)); + assert_eq!(casted_value, Value::from(2_u128.pow(8) - 1)); + } + #[test] fn mov_opcode() { let calldata = vec![Value::from(1u128), Value::from(2u128), Value::from(3u128)]; diff --git a/noir/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_block.rs b/noir/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_block.rs index b19ee97bc4e..e0630655253 100644 --- a/noir/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_block.rs +++ b/noir/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_block.rs @@ -24,7 +24,7 @@ use num_bigint::BigUint; use super::brillig_black_box::convert_black_box_call; use super::brillig_block_variables::BlockVariables; -use super::brillig_fn::FunctionContext; +use super::brillig_fn::{get_bit_size_from_ssa_type, FunctionContext}; /// Generate the compilation artifacts for compiling a function into brillig bytecode. pub(crate) struct BrilligBlock<'block> { @@ -87,16 +87,6 @@ impl<'block> BrilligBlock<'block> { self.convert_ssa_terminator(terminator_instruction, dfg); } - fn get_bit_size_from_ssa_type(typ: &Type) -> u32 { - match typ { - Type::Numeric(num_type) => match num_type { - NumericType::Signed { bit_size } | NumericType::Unsigned { bit_size } => *bit_size, - NumericType::NativeField => FieldElement::max_num_bits(), - }, - _ => unreachable!("ICE bitwise not on a non numeric type"), - } - } - /// Creates a unique global label for a block. /// /// This uses the current functions's function ID and the block ID @@ -324,7 +314,7 @@ impl<'block> BrilligBlock<'block> { dfg.instruction_results(instruction_id)[0], dfg, ); - let bit_size = Self::get_bit_size_from_ssa_type(&dfg.type_of_value(*value)); + let bit_size = get_bit_size_from_ssa_type(&dfg.type_of_value(*value)); self.brillig_context.not_instruction(condition_register, bit_size, result_register); } Instruction::Call { func, arguments } => match &dfg[*func] { @@ -547,7 +537,7 @@ impl<'block> BrilligBlock<'block> { *bit_size, ); } - Instruction::Cast(value, _) => { + Instruction::Cast(value, typ) => { let result_ids = dfg.instruction_results(instruction_id); let destination_register = self.variables.define_register_variable( self.function_context, @@ -556,7 +546,7 @@ impl<'block> BrilligBlock<'block> { dfg, ); let source_register = self.convert_ssa_register_value(*value, dfg); - self.convert_cast(destination_register, source_register); + self.convert_cast(destination_register, source_register, typ); } Instruction::ArrayGet { array, index } => { let result_ids = dfg.instruction_results(instruction_id); @@ -1136,11 +1126,11 @@ impl<'block> BrilligBlock<'block> { /// Converts an SSA cast to a sequence of Brillig opcodes. /// Casting is only necessary when shrinking the bit size of a numeric value. - fn convert_cast(&mut self, destination: MemoryAddress, source: MemoryAddress) { + fn convert_cast(&mut self, destination: MemoryAddress, source: MemoryAddress, typ: &Type) { // We assume that `source` is a valid `target_type` as it's expected that a truncate instruction was emitted // to ensure this is the case. - self.brillig_context.mov_instruction(destination, source); + self.brillig_context.cast_instruction(destination, source, get_bit_size_from_ssa_type(typ)); } /// Converts the Binary instruction into a sequence of Brillig opcodes. @@ -1186,7 +1176,7 @@ impl<'block> BrilligBlock<'block> { self.brillig_context.const_instruction( register_index, (*constant).into(), - Self::get_bit_size_from_ssa_type(typ), + get_bit_size_from_ssa_type(typ), ); new_variable } diff --git a/noir/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_fn.rs b/noir/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_fn.rs index 026def4ef11..e96a756a9ee 100644 --- a/noir/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_fn.rs +++ b/noir/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_fn.rs @@ -1,16 +1,17 @@ +use acvm::FieldElement; use iter_extended::vecmap; use crate::{ brillig::brillig_ir::{ artifact::{BrilligParameter, Label}, brillig_variable::BrilligVariable, - BrilligContext, + BrilligContext, BRILLIG_MEMORY_ADDRESSING_BIT_SIZE, }, ssa::ir::{ basic_block::BasicBlockId, function::{Function, FunctionId}, post_order::PostOrder, - types::Type, + types::{NumericType, Type}, value::ValueId, }, }; @@ -72,7 +73,9 @@ impl FunctionContext { fn ssa_type_to_parameter(typ: &Type) -> BrilligParameter { match typ { - Type::Numeric(_) | Type::Reference(_) => BrilligParameter::Simple, + Type::Numeric(_) | Type::Reference(_) => { + BrilligParameter::Simple(get_bit_size_from_ssa_type(typ)) + } Type::Array(item_type, size) => BrilligParameter::Array( vecmap(item_type.iter(), |item_typ| { FunctionContext::ssa_type_to_parameter(item_typ) @@ -110,3 +113,14 @@ impl FunctionContext { .collect() } } + +pub(crate) fn get_bit_size_from_ssa_type(typ: &Type) -> u32 { + match typ { + Type::Numeric(num_type) => match num_type { + NumericType::Signed { bit_size } | NumericType::Unsigned { bit_size } => *bit_size, + NumericType::NativeField => FieldElement::max_num_bits(), + }, + Type::Reference(_) => BRILLIG_MEMORY_ADDRESSING_BIT_SIZE, + _ => unreachable!("ICE bitwise not on a non numeric type"), + } +} diff --git a/noir/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_slice_ops.rs b/noir/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_slice_ops.rs index 3fed8ee91d9..933396be0cb 100644 --- a/noir/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_slice_ops.rs +++ b/noir/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_slice_ops.rs @@ -339,7 +339,7 @@ mod tests { use crate::brillig::brillig_ir::tests::{ create_and_run_vm, create_context, create_entry_point_bytecode, }; - use crate::brillig::brillig_ir::BrilligContext; + use crate::brillig::brillig_ir::{BrilligContext, BRILLIG_MEMORY_ADDRESSING_BIT_SIZE}; use crate::ssa::function_builder::FunctionBuilder; use crate::ssa::ir::function::RuntimeType; use crate::ssa::ir::map::Id; @@ -378,11 +378,16 @@ mod tests { expected_return: Vec, ) { let arguments = vec![ - BrilligParameter::Array(vec![BrilligParameter::Simple], array.len()), - BrilligParameter::Simple, + BrilligParameter::Array( + vec![BrilligParameter::Simple(BRILLIG_MEMORY_ADDRESSING_BIT_SIZE)], + array.len(), + ), + BrilligParameter::Simple(BRILLIG_MEMORY_ADDRESSING_BIT_SIZE), ]; - let returns = - vec![BrilligParameter::Array(vec![BrilligParameter::Simple], array.len() + 1)]; + let returns = vec![BrilligParameter::Array( + vec![BrilligParameter::Simple(BRILLIG_MEMORY_ADDRESSING_BIT_SIZE)], + array.len() + 1, + )]; let (_, mut function_context, mut context) = create_test_environment(); @@ -466,11 +471,16 @@ mod tests { expected_return_array: Vec, expected_return_item: Value, ) { - let arguments = - vec![BrilligParameter::Array(vec![BrilligParameter::Simple], array.len())]; + let arguments = vec![BrilligParameter::Array( + vec![BrilligParameter::Simple(BRILLIG_MEMORY_ADDRESSING_BIT_SIZE)], + array.len(), + )]; let returns = vec![ - BrilligParameter::Array(vec![BrilligParameter::Simple], array.len() - 1), - BrilligParameter::Simple, + BrilligParameter::Array( + vec![BrilligParameter::Simple(BRILLIG_MEMORY_ADDRESSING_BIT_SIZE)], + array.len() - 1, + ), + BrilligParameter::Simple(BRILLIG_MEMORY_ADDRESSING_BIT_SIZE), ]; let (_, mut function_context, mut context) = create_test_environment(); @@ -548,12 +558,17 @@ mod tests { expected_return: Vec, ) { let arguments = vec![ - BrilligParameter::Array(vec![BrilligParameter::Simple], array.len()), - BrilligParameter::Simple, - BrilligParameter::Simple, + BrilligParameter::Array( + vec![BrilligParameter::Simple(BRILLIG_MEMORY_ADDRESSING_BIT_SIZE)], + array.len(), + ), + BrilligParameter::Simple(BRILLIG_MEMORY_ADDRESSING_BIT_SIZE), + BrilligParameter::Simple(BRILLIG_MEMORY_ADDRESSING_BIT_SIZE), ]; - let returns = - vec![BrilligParameter::Array(vec![BrilligParameter::Simple], array.len() + 1)]; + let returns = vec![BrilligParameter::Array( + vec![BrilligParameter::Simple(BRILLIG_MEMORY_ADDRESSING_BIT_SIZE)], + array.len() + 1, + )]; let (_, mut function_context, mut context) = create_test_environment(); @@ -660,12 +675,18 @@ mod tests { expected_removed_item: Value, ) { let arguments = vec![ - BrilligParameter::Array(vec![BrilligParameter::Simple], array.len()), - BrilligParameter::Simple, + BrilligParameter::Array( + vec![BrilligParameter::Simple(BRILLIG_MEMORY_ADDRESSING_BIT_SIZE)], + array.len(), + ), + BrilligParameter::Simple(BRILLIG_MEMORY_ADDRESSING_BIT_SIZE), ]; let returns = vec![ - BrilligParameter::Array(vec![BrilligParameter::Simple], array.len() - 1), - BrilligParameter::Simple, + BrilligParameter::Array( + vec![BrilligParameter::Simple(BRILLIG_MEMORY_ADDRESSING_BIT_SIZE)], + array.len() - 1, + ), + BrilligParameter::Simple(BRILLIG_MEMORY_ADDRESSING_BIT_SIZE), ]; let (_, mut function_context, mut context) = create_test_environment(); diff --git a/noir/compiler/noirc_evaluator/src/brillig/brillig_ir.rs b/noir/compiler/noirc_evaluator/src/brillig/brillig_ir.rs index 035b9b01500..a0e5ca080bd 100644 --- a/noir/compiler/noirc_evaluator/src/brillig/brillig_ir.rs +++ b/noir/compiler/noirc_evaluator/src/brillig/brillig_ir.rs @@ -501,6 +501,17 @@ impl BrilligContext { self.push_opcode(BrilligOpcode::Mov { destination, source }); } + /// Cast truncates the value to the given bit size and converts the type of the value in memory to that bit size. + pub(crate) fn cast_instruction( + &mut self, + destination: MemoryAddress, + source: MemoryAddress, + bit_size: u32, + ) { + self.debug_show.cast_instruction(destination, source, bit_size); + self.push_opcode(BrilligOpcode::Cast { destination, source, bit_size }); + } + /// Processes a binary instruction according `operation`. /// /// This method will compute lhs rhs diff --git a/noir/compiler/noirc_evaluator/src/brillig/brillig_ir/artifact.rs b/noir/compiler/noirc_evaluator/src/brillig/brillig_ir/artifact.rs index 437774da157..4ef8c9d1dfc 100644 --- a/noir/compiler/noirc_evaluator/src/brillig/brillig_ir/artifact.rs +++ b/noir/compiler/noirc_evaluator/src/brillig/brillig_ir/artifact.rs @@ -6,8 +6,11 @@ use crate::ssa::ir::dfg::CallStack; /// Represents a parameter or a return value of a function. #[derive(Debug, Clone)] pub(crate) enum BrilligParameter { - Simple, + /// A simple parameter or return value. Holds the bit size of the parameter. + Simple(u32), + /// An array parameter or return value. Holds the type of an array item and its size. Array(Vec, usize), + /// A slice parameter or return value. Holds the type of a slice item. Slice(Vec), } diff --git a/noir/compiler/noirc_evaluator/src/brillig/brillig_ir/debug_show.rs b/noir/compiler/noirc_evaluator/src/brillig/brillig_ir/debug_show.rs index 280cd12e91d..6ee2e0c0b9f 100644 --- a/noir/compiler/noirc_evaluator/src/brillig/brillig_ir/debug_show.rs +++ b/noir/compiler/noirc_evaluator/src/brillig/brillig_ir/debug_show.rs @@ -172,6 +172,22 @@ impl DebugShow { debug_println!(self.enable_debug_trace, " MOV {}, {}", destination, source); } + /// Emits a `cast` instruction. + pub(crate) fn cast_instruction( + &self, + destination: MemoryAddress, + source: MemoryAddress, + bit_size: u32, + ) { + debug_println!( + self.enable_debug_trace, + " CAST {}, {} as u{}", + destination, + source, + bit_size + ); + } + /// Processes a binary instruction according `operation`. pub(crate) fn binary_instruction( &self, diff --git a/noir/compiler/noirc_evaluator/src/brillig/brillig_ir/entry_point.rs b/noir/compiler/noirc_evaluator/src/brillig/brillig_ir/entry_point.rs index fc4ac36d7fd..0eb4c8c31bd 100644 --- a/noir/compiler/noirc_evaluator/src/brillig/brillig_ir/entry_point.rs +++ b/noir/compiler/noirc_evaluator/src/brillig/brillig_ir/entry_point.rs @@ -5,7 +5,10 @@ use super::{ registers::BrilligRegistersContext, BrilligContext, ReservedRegisters, BRILLIG_MEMORY_ADDRESSING_BIT_SIZE, }; -use acvm::acir::brillig::{MemoryAddress, Opcode as BrilligOpcode}; +use acvm::{ + acir::brillig::{MemoryAddress, Opcode as BrilligOpcode}, + FieldElement, +}; pub(crate) const MAX_STACK_SIZE: usize = 1024; @@ -52,11 +55,7 @@ impl BrilligContext { }); // Copy calldata - self.push_opcode(BrilligOpcode::CalldataCopy { - destination_address: MemoryAddress(MAX_STACK_SIZE), - size: calldata_size, - offset: 0, - }); + self.copy_and_cast_calldata(arguments); // Allocate the variables for every argument: let mut current_calldata_pointer = MAX_STACK_SIZE; @@ -64,7 +63,7 @@ impl BrilligContext { let mut argument_variables: Vec<_> = arguments .iter() .map(|argument| match argument { - BrilligParameter::Simple => { + BrilligParameter::Simple(_) => { let simple_address = self.allocate_register(); let var = BrilligVariable::Simple(simple_address); self.mov_instruction(simple_address, MemoryAddress(current_calldata_pointer)); @@ -107,10 +106,40 @@ impl BrilligContext { } } + fn copy_and_cast_calldata(&mut self, arguments: &[BrilligParameter]) { + let calldata_size = BrilligContext::flattened_tuple_size(arguments); + self.push_opcode(BrilligOpcode::CalldataCopy { + destination_address: MemoryAddress(MAX_STACK_SIZE), + size: calldata_size, + offset: 0, + }); + + fn flat_bit_sizes(param: &BrilligParameter) -> Box + '_> { + match param { + BrilligParameter::Simple(bit_size) => Box::new(std::iter::once(*bit_size)), + BrilligParameter::Array(item_types, item_count) => Box::new( + (0..*item_count).flat_map(move |_| item_types.iter().flat_map(flat_bit_sizes)), + ), + BrilligParameter::Slice(..) => unimplemented!("Unsupported slices as parameter"), + } + } + + for (i, bit_size) in arguments.iter().flat_map(flat_bit_sizes).enumerate() { + // Calldatacopy tags everything with field type, so when downcast when necessary + if bit_size < FieldElement::max_num_bits() { + self.push_opcode(BrilligOpcode::Cast { + destination: MemoryAddress(MAX_STACK_SIZE + i), + source: MemoryAddress(MAX_STACK_SIZE + i), + bit_size, + }); + } + } + } + /// Computes the size of a parameter if it was flattened fn flattened_size(param: &BrilligParameter) -> usize { match param { - BrilligParameter::Simple => 1, + BrilligParameter::Simple(_) => 1, BrilligParameter::Array(item_types, item_count) => { let item_size: usize = item_types.iter().map(BrilligContext::flattened_size).sum(); item_count * item_size @@ -128,7 +157,7 @@ impl BrilligContext { /// Computes the size of a parameter if it was flattened fn has_nested_arrays(tuple: &[BrilligParameter]) -> bool { - tuple.iter().any(|param| !matches!(param, BrilligParameter::Simple)) + tuple.iter().any(|param| !matches!(param, BrilligParameter::Simple(_))) } /// Deflatten an array by recursively allocating nested arrays and copying the plain values. @@ -165,7 +194,7 @@ impl BrilligContext { self.make_usize_constant((target_item_base_index + subitem_index).into()); match subitem { - BrilligParameter::Simple => { + BrilligParameter::Simple(_) => { self.array_get( flattened_array_pointer, source_index, @@ -250,7 +279,7 @@ impl BrilligContext { let returned_variables: Vec<_> = return_parameters .iter() .map(|return_parameter| match return_parameter { - BrilligParameter::Simple => BrilligVariable::Simple(self.allocate_register()), + BrilligParameter::Simple(_) => BrilligVariable::Simple(self.allocate_register()), BrilligParameter::Array(item_types, item_count) => { BrilligVariable::BrilligArray(BrilligArray { pointer: self.allocate_register(), @@ -272,7 +301,7 @@ impl BrilligContext { for (return_param, returned_variable) in return_parameters.iter().zip(&returned_variables) { match return_param { - BrilligParameter::Simple => { + BrilligParameter::Simple(_) => { self.mov_instruction( MemoryAddress(return_data_index), returned_variable.extract_register(), @@ -330,7 +359,7 @@ impl BrilligContext { self.make_usize_constant((target_item_base_index + target_offset).into()); match subitem { - BrilligParameter::Simple => { + BrilligParameter::Simple(_) => { self.array_get( deflattened_array_pointer, source_index, @@ -439,12 +468,12 @@ mod tests { ]; let arguments = vec![BrilligParameter::Array( vec![ - BrilligParameter::Array(vec![BrilligParameter::Simple], 2), - BrilligParameter::Simple, + BrilligParameter::Array(vec![BrilligParameter::Simple(8)], 2), + BrilligParameter::Simple(8), ], 2, )]; - let returns = vec![BrilligParameter::Simple]; + let returns = vec![BrilligParameter::Simple(8)]; let mut context = create_context(); @@ -477,8 +506,8 @@ mod tests { ]; let array_param = BrilligParameter::Array( vec![ - BrilligParameter::Array(vec![BrilligParameter::Simple], 2), - BrilligParameter::Simple, + BrilligParameter::Array(vec![BrilligParameter::Simple(8)], 2), + BrilligParameter::Simple(8), ], 2, );