diff --git a/.noir-sync-commit b/.noir-sync-commit index 5fe0fbedd16..e7c73939ac6 100644 --- a/.noir-sync-commit +++ b/.noir-sync-commit @@ -1 +1 @@ -95d4d133d1eb5e0eb44cd928d8183d890e970a13 +b541e793e20fa3c991e0328ec2ff7926bdcdfd45 diff --git a/avm-transpiler/src/transpile.rs b/avm-transpiler/src/transpile.rs index 314dc77414b..11007c9497e 100644 --- a/avm-transpiler/src/transpile.rs +++ b/avm-transpiler/src/transpile.rs @@ -968,6 +968,34 @@ fn handle_black_box_function(avm_instrs: &mut Vec, operation: &B ..Default::default() }); } + BlackBoxOp::ToRadix { + input, + radix, + output, + } => { + let num_limbs = output.size; + let input_offset = input.0; + let output_offset = output.pointer.0; + assert!(radix <= &256u32, "Radix must be less than or equal to 256"); + + avm_instrs.push(AvmInstruction { + opcode: AvmOpcode::TORADIXLE, + indirect: Some(FIRST_OPERAND_INDIRECT), + tag: None, + operands: vec![ + AvmOperand::U32 { + value: input_offset as u32, + }, + AvmOperand::U32 { + value: output_offset as u32, + }, + AvmOperand::U32 { value: *radix }, + AvmOperand::U32 { + value: num_limbs as u32, + }, + ], + }) + } _ => panic!("Transpiler doesn't know how to process {:?}", operation), } } diff --git a/barretenberg/acir_tests/flows/all_cmds.sh b/barretenberg/acir_tests/flows/all_cmds.sh index a65159351ed..c912613302c 100755 --- a/barretenberg/acir_tests/flows/all_cmds.sh +++ b/barretenberg/acir_tests/flows/all_cmds.sh @@ -2,7 +2,7 @@ set -eu VFLAG=${VERBOSE:+-v} -BFLAG="-b ./target/acir.gz" +BFLAG="-b ./target/program.json" FLAGS="-c $CRS_PATH $VFLAG" # Test we can perform the proof/verify flow. @@ -19,4 +19,4 @@ $BIN contract -k vk $BFLAG -o - | grep "Verification Key Hash" > /dev/null OUTPUT=$($BIN proof_as_fields -k vk -p proof -o - | jq .) [ -n "$OUTPUT" ] || exit 1 OUTPUT=$($BIN vk_as_fields -k vk -o - | jq .) -[ -n "$OUTPUT" ] || exit 1 \ No newline at end of file +[ -n "$OUTPUT" ] || exit 1 diff --git a/barretenberg/acir_tests/flows/prove_and_verify.sh b/barretenberg/acir_tests/flows/prove_and_verify.sh index 091a6d57946..4d905538991 100755 --- a/barretenberg/acir_tests/flows/prove_and_verify.sh +++ b/barretenberg/acir_tests/flows/prove_and_verify.sh @@ -5,4 +5,4 @@ VFLAG=${VERBOSE:+-v} # This is the fastest flow, because it only generates pk/vk once, gate count once, etc. # It may not catch all class of bugs. -$BIN prove_and_verify $VFLAG -c $CRS_PATH -b ./target/acir.gz \ No newline at end of file +$BIN prove_and_verify $VFLAG -c $CRS_PATH -b ./target/program.json diff --git a/barretenberg/acir_tests/flows/prove_and_verify_goblin.sh b/barretenberg/acir_tests/flows/prove_and_verify_goblin.sh index 68a003d685b..23340df80a1 100755 --- a/barretenberg/acir_tests/flows/prove_and_verify_goblin.sh +++ b/barretenberg/acir_tests/flows/prove_and_verify_goblin.sh @@ -3,7 +3,7 @@ set -eu VFLAG=${VERBOSE:+-v} -$BIN prove_and_verify_goblin $VFLAG -c $CRS_PATH -b ./target/acir.gz +$BIN prove_and_verify_goblin $VFLAG -c $CRS_PATH -b ./target/program.json # This command can be used to run all of the tests in sequence with the debugger -# lldb-16 -o run -b -- $BIN prove_and_verify_goblin $VFLAG -c $CRS_PATH -b ./target/acir.gz \ No newline at end of file +# lldb-16 -o run -b -- $BIN prove_and_verify_goblin $VFLAG -c $CRS_PATH -b ./target/program.json diff --git a/barretenberg/acir_tests/flows/prove_and_verify_goblin_ultra_honk.sh b/barretenberg/acir_tests/flows/prove_and_verify_goblin_ultra_honk.sh index a8a72924898..2d79f15c212 100755 --- a/barretenberg/acir_tests/flows/prove_and_verify_goblin_ultra_honk.sh +++ b/barretenberg/acir_tests/flows/prove_and_verify_goblin_ultra_honk.sh @@ -3,4 +3,4 @@ set -eu VFLAG=${VERBOSE:+-v} -$BIN prove_and_verify_goblin_ultra_honk $VFLAG -c $CRS_PATH -b ./target/acir.gz \ No newline at end of file +$BIN prove_and_verify_goblin_ultra_honk $VFLAG -c $CRS_PATH -b ./target/program.json diff --git a/barretenberg/acir_tests/flows/prove_and_verify_ultra_honk.sh b/barretenberg/acir_tests/flows/prove_and_verify_ultra_honk.sh index 7b6f0384796..16f2fd7f398 100755 --- a/barretenberg/acir_tests/flows/prove_and_verify_ultra_honk.sh +++ b/barretenberg/acir_tests/flows/prove_and_verify_ultra_honk.sh @@ -3,4 +3,4 @@ set -eu VFLAG=${VERBOSE:+-v} -$BIN prove_and_verify_ultra_honk $VFLAG -c $CRS_PATH -b ./target/acir.gz \ No newline at end of file +$BIN prove_and_verify_ultra_honk $VFLAG -c $CRS_PATH -b ./target/program.json diff --git a/barretenberg/acir_tests/flows/prove_and_verify_ultra_honk_program.sh b/barretenberg/acir_tests/flows/prove_and_verify_ultra_honk_program.sh index 53c9d08e1a6..65a6e400226 100755 --- a/barretenberg/acir_tests/flows/prove_and_verify_ultra_honk_program.sh +++ b/barretenberg/acir_tests/flows/prove_and_verify_ultra_honk_program.sh @@ -3,4 +3,4 @@ set -eu VFLAG=${VERBOSE:+-v} -$BIN prove_and_verify_ultra_honk_program $VFLAG -c $CRS_PATH -b ./target/acir.gz \ No newline at end of file +$BIN prove_and_verify_ultra_honk_program $VFLAG -c $CRS_PATH -b ./target/program.json diff --git a/barretenberg/acir_tests/flows/prove_then_verify.sh b/barretenberg/acir_tests/flows/prove_then_verify.sh index 9c35b981a1a..08d8ea21057 100755 --- a/barretenberg/acir_tests/flows/prove_then_verify.sh +++ b/barretenberg/acir_tests/flows/prove_then_verify.sh @@ -2,7 +2,7 @@ set -eu VFLAG=${VERBOSE:+-v} -BFLAG="-b ./target/acir.gz" +BFLAG="-b ./target/program.json" FLAGS="-c $CRS_PATH $VFLAG" # Test we can perform the proof/verify flow. diff --git a/barretenberg/acir_tests/flows/prove_then_verify_goblin_ultra_honk.sh b/barretenberg/acir_tests/flows/prove_then_verify_goblin_ultra_honk.sh index 9586e6841eb..fa33cefe5d8 100755 --- a/barretenberg/acir_tests/flows/prove_then_verify_goblin_ultra_honk.sh +++ b/barretenberg/acir_tests/flows/prove_then_verify_goblin_ultra_honk.sh @@ -2,7 +2,7 @@ set -eu VFLAG=${VERBOSE:+-v} -BFLAG="-b ./target/acir.gz" +BFLAG="-b ./target/program.json" FLAGS="-c $CRS_PATH $VFLAG" # Test we can perform the proof/verify flow. diff --git a/barretenberg/acir_tests/flows/prove_then_verify_ultra_honk.sh b/barretenberg/acir_tests/flows/prove_then_verify_ultra_honk.sh index bfb20c27cf0..fd559e256c6 100755 --- a/barretenberg/acir_tests/flows/prove_then_verify_ultra_honk.sh +++ b/barretenberg/acir_tests/flows/prove_then_verify_ultra_honk.sh @@ -2,7 +2,7 @@ set -eu VFLAG=${VERBOSE:+-v} -BFLAG="-b ./target/acir.gz" +BFLAG="-b ./target/program.json" FLAGS="-c $CRS_PATH $VFLAG" # Test we can perform the proof/verify flow. diff --git a/barretenberg/acir_tests/flows/sol.sh b/barretenberg/acir_tests/flows/sol.sh index d95c9039eea..66f7833fb70 100755 --- a/barretenberg/acir_tests/flows/sol.sh +++ b/barretenberg/acir_tests/flows/sol.sh @@ -8,7 +8,7 @@ export PROOF_AS_FIELDS="$(pwd)/proof_fields.json" $BIN prove -o proof $BIN write_vk -o vk $BIN proof_as_fields -k vk -c $CRS_PATH -p $PROOF -$BIN contract -k vk -c $CRS_PATH -b ./target/acir.gz -o Key.sol +$BIN contract -k vk -c $CRS_PATH -b ./target/program.json -o Key.sol # Export the paths to the environment variables for the js test runner export KEY_PATH="$(pwd)/Key.sol" @@ -20,4 +20,4 @@ export BASE_PATH=$(realpath "../../../sol/src/ultra/BaseUltraVerifier.sol") # index.js will start an anvil, on a random port # Deploy the verifier then send a test transaction export TEST_NAME=$(basename $(pwd)) -node ../../sol-test/src/index.js \ No newline at end of file +node ../../sol-test/src/index.js diff --git a/barretenberg/acir_tests/flows/write_contract.sh b/barretenberg/acir_tests/flows/write_contract.sh index 4f483395c58..52669901035 100755 --- a/barretenberg/acir_tests/flows/write_contract.sh +++ b/barretenberg/acir_tests/flows/write_contract.sh @@ -4,4 +4,4 @@ set -eu export TEST_NAME=$(basename $(pwd)) $BIN write_vk -o vk -$BIN contract -k vk -c $CRS_PATH -b ./target/acir.gz -o $TEST_NAME.sol +$BIN contract -k vk -c $CRS_PATH -b ./target/program.json -o $TEST_NAME.sol diff --git a/barretenberg/acir_tests/gen_inner_proof_inputs.sh b/barretenberg/acir_tests/gen_inner_proof_inputs.sh index ade57bcea4f..ea3a2ced8e0 100755 --- a/barretenberg/acir_tests/gen_inner_proof_inputs.sh +++ b/barretenberg/acir_tests/gen_inner_proof_inputs.sh @@ -36,7 +36,7 @@ $BIN vk_as_fields $VFLAG -c $CRS_PATH echo "Generate proof to file..." [ -d "$PROOF_DIR" ] || mkdir $PWD/proofs [ -e "$PROOF_PATH" ] || touch $PROOF_PATH -$BIN prove $VFLAG -c $CRS_PATH -b ./target/acir.gz -o "./proofs/$PROOF_NAME" $RFLAG +$BIN prove $VFLAG -c $CRS_PATH -b ./target/program.json -o "./proofs/$PROOF_NAME" $RFLAG echo "Write proof as fields for recursion..." $BIN proof_as_fields $VFLAG -c $CRS_PATH -p "./proofs/$PROOF_NAME" diff --git a/barretenberg/acir_tests/headless-test/src/index.ts b/barretenberg/acir_tests/headless-test/src/index.ts index 1902e510e87..bb906a6db2d 100644 --- a/barretenberg/acir_tests/headless-test/src/index.ts +++ b/barretenberg/acir_tests/headless-test/src/index.ts @@ -38,9 +38,17 @@ function formatAndPrintLog(message: string): void { } const readBytecodeFile = (path: string): Uint8Array => { - const data = fs.readFileSync(path); - const buffer = gunzipSync(data); - return buffer; + const extension = path.substring(path.lastIndexOf('.') + 1); + + if (extension == 'json') { + const encodedCircuit = JSON.parse(fs.readFileSync(path, 'utf8')); + const decompressed = gunzipSync(Uint8Array.from(atob(encodedCircuit.bytecode), c => c.charCodeAt(0))); + return decompressed; + } + + const encodedCircuit = fs.readFileSync(path); + const decompressed = gunzipSync(encodedCircuit); + return decompressed; }; const readWitnessFile = (path: string): Uint8Array => { diff --git a/barretenberg/acir_tests/run_acir_tests.sh b/barretenberg/acir_tests/run_acir_tests.sh index 88189a43438..0e360a89551 100755 --- a/barretenberg/acir_tests/run_acir_tests.sh +++ b/barretenberg/acir_tests/run_acir_tests.sh @@ -73,7 +73,7 @@ else continue fi - if [[ ! -f ./$TEST_NAME/target/acir.gz || ! -f ./$TEST_NAME/target/witness.gz ]]; then + if [[ ! -f ./$TEST_NAME/target/program.json || ! -f ./$TEST_NAME/target/witness.gz ]]; then echo -e "\033[33mSKIPPED\033[0m (uncompiled)" continue fi diff --git a/barretenberg/cpp/src/barretenberg/bb/get_bytecode.hpp b/barretenberg/cpp/src/barretenberg/bb/get_bytecode.hpp index 2c7af46cfca..84d06a74953 100644 --- a/barretenberg/cpp/src/barretenberg/bb/get_bytecode.hpp +++ b/barretenberg/cpp/src/barretenberg/bb/get_bytecode.hpp @@ -1,11 +1,20 @@ #pragma once #include "exec_pipe.hpp" +#include /** * We can assume for now we're running on a unix like system and use the following to extract the bytecode. */ inline std::vector get_bytecode(const std::string& bytecodePath) { + std::filesystem::path filePath = bytecodePath; + if (filePath.extension() == ".json") { + // Try reading json files as if they are a Nargo build artifact + std::string command = "jq -r '.bytecode' \"" + bytecodePath + "\" | base64 -d | gunzip -c"; + return exec_pipe(command); + } + + // For other extensions, assume file is a raw ACIR program std::string command = "gunzip -c \"" + bytecodePath + "\""; return exec_pipe(command); -} \ No newline at end of file +} diff --git a/barretenberg/cpp/src/barretenberg/bb/main.cpp b/barretenberg/cpp/src/barretenberg/bb/main.cpp index 674e4e67e92..9b1df668943 100644 --- a/barretenberg/cpp/src/barretenberg/bb/main.cpp +++ b/barretenberg/cpp/src/barretenberg/bb/main.cpp @@ -754,7 +754,7 @@ int main(int argc, char* argv[]) std::string command = args[0]; - std::string bytecode_path = get_option(args, "-b", "./target/acir.gz"); + std::string bytecode_path = get_option(args, "-b", "./target/program.json"); std::string witness_path = get_option(args, "-w", "./target/witness.gz"); std::string proof_path = get_option(args, "-p", "./proofs/proof"); std::string vk_path = get_option(args, "-k", "./target/vk"); diff --git a/barretenberg/cpp/src/barretenberg/dsl/acir_format/serde/acir.hpp b/barretenberg/cpp/src/barretenberg/dsl/acir_format/serde/acir.hpp index 9fb0e2b3a35..683e4c62407 100644 --- a/barretenberg/cpp/src/barretenberg/dsl/acir_format/serde/acir.hpp +++ b/barretenberg/cpp/src/barretenberg/dsl/acir_format/serde/acir.hpp @@ -686,7 +686,6 @@ struct BlackBoxOp { Program::HeapVector inputs; Program::HeapArray iv; Program::HeapArray key; - Program::MemoryAddress length; Program::HeapVector outputs; friend bool operator==(const AES128Encrypt&, const AES128Encrypt&); @@ -896,6 +895,16 @@ struct BlackBoxOp { static Sha256Compression bincodeDeserialize(std::vector); }; + struct ToRadix { + Program::MemoryAddress input; + uint32_t radix; + Program::HeapArray output; + + friend bool operator==(const ToRadix&, const ToRadix&); + std::vector bincodeSerialize() const; + static ToRadix bincodeDeserialize(std::vector); + }; + std::variant + Sha256Compression, + ToRadix> value; friend bool operator==(const BlackBoxOp&, const BlackBoxOp&); @@ -3939,9 +3949,6 @@ inline bool operator==(const BlackBoxOp::AES128Encrypt& lhs, const BlackBoxOp::A if (!(lhs.key == rhs.key)) { return false; } - if (!(lhs.length == rhs.length)) { - return false; - } if (!(lhs.outputs == rhs.outputs)) { return false; } @@ -5141,6 +5148,63 @@ Program::BlackBoxOp::Sha256Compression serde::Deserializable BlackBoxOp::ToRadix::bincodeSerialize() const +{ + auto serializer = serde::BincodeSerializer(); + serde::Serializable::serialize(*this, serializer); + return std::move(serializer).bytes(); +} + +inline BlackBoxOp::ToRadix BlackBoxOp::ToRadix::bincodeDeserialize(std::vector input) +{ + auto deserializer = serde::BincodeDeserializer(input); + auto value = serde::Deserializable::deserialize(deserializer); + if (deserializer.get_buffer_offset() < input.size()) { + throw_or_abort("Some input bytes were not read"); + } + return value; +} + +} // end of namespace Program + +template <> +template +void serde::Serializable::serialize(const Program::BlackBoxOp::ToRadix& obj, + Serializer& serializer) +{ + serde::Serializable::serialize(obj.input, serializer); + serde::Serializable::serialize(obj.radix, serializer); + serde::Serializable::serialize(obj.output, serializer); +} + +template <> +template +Program::BlackBoxOp::ToRadix serde::Deserializable::deserialize( + Deserializer& deserializer) +{ + Program::BlackBoxOp::ToRadix obj; + obj.input = serde::Deserializable::deserialize(deserializer); + obj.radix = serde::Deserializable::deserialize(deserializer); + obj.output = serde::Deserializable::deserialize(deserializer); + return obj; +} + +namespace Program { + inline bool operator==(const BlockId& lhs, const BlockId& rhs) { if (!(lhs.value == rhs.value)) { diff --git a/barretenberg/ts/src/main.ts b/barretenberg/ts/src/main.ts index 2060523db76..2ac32e23544 100755 --- a/barretenberg/ts/src/main.ts +++ b/barretenberg/ts/src/main.ts @@ -21,6 +21,14 @@ const MAX_CIRCUIT_SIZE = 2 ** 19; const threads = +process.env.HARDWARE_CONCURRENCY! || undefined; function getBytecode(bytecodePath: string) { + const extension = bytecodePath.substring(bytecodePath.lastIndexOf('.') + 1); + + if (extension == 'json') { + const encodedCircuit = JSON.parse(readFileSync(bytecodePath, 'utf8')); + const decompressed = gunzipSync(Buffer.from(encodedCircuit.bytecode, 'base64')); + return decompressed; + } + const encodedCircuit = readFileSync(bytecodePath); const decompressed = gunzipSync(encodedCircuit); return decompressed; diff --git a/docs/docs/migration_notes.md b/docs/docs/migration_notes.md index 902eae67a4d..7f990738590 100644 --- a/docs/docs/migration_notes.md +++ b/docs/docs/migration_notes.md @@ -12,9 +12,22 @@ Aztec is in full-speed development. Literally every version breaks compatibility The type signature for `SharedMutable` changed from `SharedMutable` to `SharedMutable`. The behavior is the same as before, except the delay can now be changed after deployment by calling `schedule_delay_change`. +### [Aztec.nr] get_public_key oracle replaced with get_ivpk_m + +When implementing changes according to a [new key scheme](https://yp-aztec.netlify.app/docs/addresses-and-keys/keys) we had to change oracles. +What used to be called encryption public key is now master incoming viewing public key. + +```diff +- use dep::aztec::oracles::get_public_key::get_public_key; ++ use dep::aztec::keys::getters::get_ivpk_m; + +- let encryption_pub_key = get_public_key(self.owner); ++ let ivpk_m = get_ivpk_m(context, self.owner); +``` + ## 0.38.0 -### [Aztec.nr] Emmiting encrypted logs +### [Aztec.nr] Emitting encrypted logs The `emit_encrypted_log` function is now a context method. diff --git a/docs/docs/protocol-specs/transactions/tx-object.md b/docs/docs/protocol-specs/transactions/tx-object.md index d60f2a0b184..9bec884dcfd 100644 --- a/docs/docs/protocol-specs/transactions/tx-object.md +++ b/docs/docs/protocol-specs/transactions/tx-object.md @@ -40,7 +40,6 @@ Output of the last iteration of the private kernel circuit. Includes _accumulate | Field | Type | Description | |-------|------|-------------| -| aggregationObject | AggregationObject | Aggregated proof of all the previous kernel iterations. | | newNoteHashes | Field[] | The new note hashes made in this transaction. | | newNullifiers | Field[] | The new nullifiers made in this transaction. | | nullifiedNoteHashes | Field[] | The note hashes which are nullified by a nullifier in the above list. | diff --git a/noir-projects/aztec-nr/address-note/src/address_note.nr b/noir-projects/aztec-nr/address-note/src/address_note.nr index d784a0fb7e7..002c56de934 100644 --- a/noir-projects/aztec-nr/address-note/src/address_note.nr +++ b/noir-projects/aztec-nr/address-note/src/address_note.nr @@ -1,7 +1,8 @@ use dep::aztec::{ + keys::getters::get_ivpk_m, protocol_types::{address::AztecAddress, traits::Empty, constants::GENERATOR_INDEX__NOTE_NULLIFIER}, note::{note_header::NoteHeader, note_interface::NoteInterface, utils::compute_note_hash_for_consumption}, - oracle::{unsafe_rand::unsafe_rand, nullifier_key::get_app_nullifier_secret_key, get_public_key::get_public_key}, + oracle::{unsafe_rand::unsafe_rand, nullifier_key::get_app_nullifier_secret_key}, context::PrivateContext, hash::poseidon2_hash }; @@ -40,13 +41,13 @@ impl NoteInterface for AddressNote { // Broadcasts the note as an encrypted log on L1. fn broadcast(self, context: &mut PrivateContext, slot: Field) { - let encryption_pub_key = get_public_key(self.owner); + let ivpk_m = get_ivpk_m(context, self.owner); // docs:start:encrypted context.emit_encrypted_log( (*context).this_address(), slot, Self::get_note_type_id(), - encryption_pub_key, + ivpk_m, self.serialize_content(), ); // docs:end:encrypted diff --git a/noir-projects/aztec-nr/aztec/src/context/private_context.nr b/noir-projects/aztec-nr/aztec/src/context/private_context.nr index f6fe853c123..f18b33f0f11 100644 --- a/noir-projects/aztec-nr/aztec/src/context/private_context.nr +++ b/noir-projects/aztec-nr/aztec/src/context/private_context.nr @@ -286,7 +286,7 @@ impl PrivateContext { contract_address: AztecAddress, storage_slot: Field, note_type_id: Field, - encryption_pub_key: GrumpkinPoint, + ivpk_m: GrumpkinPoint, preimage: [Field; N] ) where [Field; N]: LensForEncryptedLog { // TODO(1139): perform encryption in the circuit @@ -296,7 +296,7 @@ impl PrivateContext { contract_address, storage_slot, note_type_id, - encryption_pub_key, + ivpk_m, preimage, counter ); diff --git a/noir-projects/aztec-nr/aztec/src/encrypted_logs/body.nr b/noir-projects/aztec-nr/aztec/src/encrypted_logs/body.nr index 4393d9da16c..9f490c768e0 100644 --- a/noir-projects/aztec-nr/aztec/src/encrypted_logs/body.nr +++ b/noir-projects/aztec-nr/aztec/src/encrypted_logs/body.nr @@ -67,7 +67,7 @@ mod test { use crate::{ note::{note_header::NoteHeader, note_interface::NoteInterface, utils::compute_note_hash_for_consumption}, - oracle::{unsafe_rand::unsafe_rand, nullifier_key::get_app_nullifier_secret_key, get_public_key::get_public_key}, + oracle::{unsafe_rand::unsafe_rand, nullifier_key::get_app_nullifier_secret_key}, context::PrivateContext, hash::poseidon2_hash }; diff --git a/noir-projects/aztec-nr/aztec/src/keys/getters.nr b/noir-projects/aztec-nr/aztec/src/keys/getters.nr index b6fc2759fb7..0e531da1028 100644 --- a/noir-projects/aztec-nr/aztec/src/keys/getters.nr +++ b/noir-projects/aztec-nr/aztec/src/keys/getters.nr @@ -1,4 +1,7 @@ -use dep::protocol_types::{address::AztecAddress, constants::CANONICAL_KEY_REGISTRY_ADDRESS, grumpkin_point::GrumpkinPoint}; +use dep::protocol_types::{ + address::{AztecAddress, PublicKeysHash}, constants::CANONICAL_KEY_REGISTRY_ADDRESS, + grumpkin_point::GrumpkinPoint +}; use crate::{ context::PrivateContext, oracle::keys::get_public_keys_and_partial_address, state_vars::{ @@ -80,20 +83,15 @@ fn fetch_key_from_registry( fn fetch_and_constrain_keys(address: AztecAddress) -> [GrumpkinPoint; 4] { let (public_keys, partial_address) = get_public_keys_and_partial_address(address); - let nullifier_pub_key = public_keys[0]; - let incoming_pub_key = public_keys[1]; - let outgoing_pub_key = public_keys[2]; - let tagging_pub_key = public_keys[3]; + let npk_m = public_keys[0]; + let ivpk_m = public_keys[1]; + let ovpk_m = public_keys[2]; + let tpk_m = public_keys[3]; - let computed_address = AztecAddress::compute_from_public_keys_and_partial_address( - nullifier_pub_key, - incoming_pub_key, - outgoing_pub_key, - tagging_pub_key, - partial_address - ); + let public_keys_hash = PublicKeysHash::compute(npk_m, ivpk_m, ovpk_m, tpk_m); + let computed_address = AztecAddress::compute(public_keys_hash, partial_address); assert(computed_address.eq(address)); - [nullifier_pub_key, incoming_pub_key, outgoing_pub_key, tagging_pub_key] + [npk_m, ivpk_m, ovpk_m, tpk_m] } diff --git a/noir-projects/aztec-nr/aztec/src/oracle.nr b/noir-projects/aztec-nr/aztec/src/oracle.nr index 64c449d61f8..59031fcdb9f 100644 --- a/noir-projects/aztec-nr/aztec/src/oracle.nr +++ b/noir-projects/aztec-nr/aztec/src/oracle.nr @@ -10,7 +10,6 @@ mod get_l1_to_l2_membership_witness; mod get_nullifier_membership_witness; mod get_public_data_witness; mod get_membership_witness; -mod get_public_key; mod keys; mod nullifier_key; mod get_sibling_path; diff --git a/noir-projects/aztec-nr/aztec/src/oracle/get_public_key.nr b/noir-projects/aztec-nr/aztec/src/oracle/get_public_key.nr deleted file mode 100644 index a509e8c1b54..00000000000 --- a/noir-projects/aztec-nr/aztec/src/oracle/get_public_key.nr +++ /dev/null @@ -1,20 +0,0 @@ -use dep::protocol_types::{address::{AztecAddress, PartialAddress, PublicKeysHash}, grumpkin_point::GrumpkinPoint}; - -#[oracle(getPublicKeyAndPartialAddress)] -fn get_public_key_and_partial_address_oracle(_address: AztecAddress) -> [Field; 3] {} - -unconstrained fn get_public_key_and_partial_address_internal(address: AztecAddress) -> [Field; 3] { - get_public_key_and_partial_address_oracle(address) -} - -pub fn get_public_key(address: AztecAddress) -> GrumpkinPoint { - let result = get_public_key_and_partial_address_internal(address); - let pub_key = GrumpkinPoint::new(result[0], result[1]); - let partial_address = PartialAddress::from_field(result[2]); - - // TODO(#5830): disabling the following constraint until we update the oracle according to the new key scheme - // let calculated_address = AztecAddress::compute(PublicKeysHash::compute(pub_key), partial_address); - // assert(calculated_address.eq(address)); - - pub_key -} diff --git a/noir-projects/aztec-nr/aztec/src/oracle/keys.nr b/noir-projects/aztec-nr/aztec/src/oracle/keys.nr index a985e385e81..173ad34aad2 100644 --- a/noir-projects/aztec-nr/aztec/src/oracle/keys.nr +++ b/noir-projects/aztec-nr/aztec/src/oracle/keys.nr @@ -1,7 +1,5 @@ use dep::protocol_types::{address::{AztecAddress, PartialAddress}, grumpkin_point::GrumpkinPoint}; -use crate::hash::poseidon2_hash; - #[oracle(getPublicKeysAndPartialAddress)] fn get_public_keys_and_partial_address_oracle(_address: AztecAddress) -> [Field; 9] {} diff --git a/noir-projects/aztec-nr/aztec/src/oracle/logs.nr b/noir-projects/aztec-nr/aztec/src/oracle/logs.nr index d692329a82f..a1d933915ee 100644 --- a/noir-projects/aztec-nr/aztec/src/oracle/logs.nr +++ b/noir-projects/aztec-nr/aztec/src/oracle/logs.nr @@ -17,7 +17,7 @@ unconstrained pub fn emit_encrypted_log( contract_address: AztecAddress, storage_slot: Field, note_type_id: Field, - encryption_pub_key: GrumpkinPoint, + ivpk_m: GrumpkinPoint, preimage: [Field; N], counter: u32 ) -> [Field; M] { @@ -25,7 +25,7 @@ unconstrained pub fn emit_encrypted_log( contract_address, storage_slot, note_type_id, - encryption_pub_key, + ivpk_m, preimage, counter ) diff --git a/noir-projects/aztec-nr/value-note/src/utils.nr b/noir-projects/aztec-nr/value-note/src/utils.nr index 5cb4b75b6c7..8d88fc6ef0a 100644 --- a/noir-projects/aztec-nr/value-note/src/utils.nr +++ b/noir-projects/aztec-nr/value-note/src/utils.nr @@ -1,6 +1,5 @@ use dep::aztec::prelude::{AztecAddress, PrivateContext, PrivateSet, NoteGetterOptions}; use dep::aztec::note::note_getter_options::SortOrder; -use dep::aztec::oracle::get_public_key::get_public_key; use crate::{filter::filter_notes_min_sum, value_note::{ValueNote, VALUE_NOTE_LEN}}; // Sort the note values (0th field) in descending order. diff --git a/noir-projects/aztec-nr/value-note/src/value_note.nr b/noir-projects/aztec-nr/value-note/src/value_note.nr index 019ea4bf543..ac790864aa8 100644 --- a/noir-projects/aztec-nr/value-note/src/value_note.nr +++ b/noir-projects/aztec-nr/value-note/src/value_note.nr @@ -1,7 +1,8 @@ use dep::aztec::{ + keys::getters::get_ivpk_m, protocol_types::{address::AztecAddress, traits::{Deserialize, Serialize}, constants::GENERATOR_INDEX__NOTE_NULLIFIER}, note::{note_header::NoteHeader, note_interface::NoteInterface, utils::compute_note_hash_for_consumption}, - oracle::{unsafe_rand::unsafe_rand, nullifier_key::get_app_nullifier_secret_key, get_public_key::get_public_key}, + oracle::{unsafe_rand::unsafe_rand, nullifier_key::get_app_nullifier_secret_key}, hash::poseidon2_hash, context::PrivateContext }; @@ -43,12 +44,12 @@ impl NoteInterface for ValueNote { // Broadcasts the note as an encrypted log on L1. fn broadcast(self, context: &mut PrivateContext, slot: Field) { - let encryption_pub_key = get_public_key(self.owner); + let ivpk_m = get_ivpk_m(context, self.owner); context.emit_encrypted_log( (*context).this_address(), slot, Self::get_note_type_id(), - encryption_pub_key, + ivpk_m, self.serialize_content(), ); } diff --git a/noir-projects/noir-contracts/contracts/app_subscription_contract/src/main.nr b/noir-projects/noir-contracts/contracts/app_subscription_contract/src/main.nr index e0532007937..74cc3d1861b 100644 --- a/noir-projects/noir-contracts/contracts/app_subscription_contract/src/main.nr +++ b/noir-projects/noir-contracts/contracts/app_subscription_contract/src/main.nr @@ -2,24 +2,19 @@ mod subscription_note; mod dapp_payload; contract AppSubscription { - use dep::std; - use crate::dapp_payload::DAppPayload; - - use dep::aztec::prelude::{ + use crate::{dapp_payload::DAppPayload, subscription_note::{SubscriptionNote, SUBSCRIPTION_NOTE_LEN}}; + use dep::{ + aztec::{ + prelude::{ AztecAddress, FunctionSelector, PrivateContext, NoteHeader, Map, PrivateMutable, PublicMutable, SharedImmutable + }, + protocol_types::traits::is_empty + }, + authwit::{account::AccountActions, auth_witness::get_auth_witness, auth::assert_current_call_valid_authwit}, + gas_token::GasToken, token::Token }; - use dep::aztec::protocol_types::traits::is_empty; - - use dep::aztec::{context::Context, oracle::get_public_key::get_public_key}; - use dep::authwit::{account::AccountActions, auth_witness::get_auth_witness, auth::assert_current_call_valid_authwit}; - - use crate::subscription_note::{SubscriptionNote, SUBSCRIPTION_NOTE_LEN}; - - use dep::gas_token::GasToken; - use dep::token::Token; - #[aztec(storage)] struct Storage { // The following is only needed in private but we use ShareImmutable here instead of PrivateImmutable because diff --git a/noir-projects/noir-contracts/contracts/app_subscription_contract/src/subscription_note.nr b/noir-projects/noir-contracts/contracts/app_subscription_contract/src/subscription_note.nr index c2543a14707..665393b166f 100644 --- a/noir-projects/noir-contracts/contracts/app_subscription_contract/src/subscription_note.nr +++ b/noir-projects/noir-contracts/contracts/app_subscription_contract/src/subscription_note.nr @@ -1,8 +1,8 @@ use dep::aztec::prelude::{AztecAddress, PrivateContext, NoteHeader, NoteInterface}; use dep::aztec::{ - protocol_types::constants::GENERATOR_INDEX__NOTE_NULLIFIER, + keys::getters::get_ivpk_m, protocol_types::constants::GENERATOR_INDEX__NOTE_NULLIFIER, note::utils::compute_note_hash_for_consumption, hash::poseidon2_hash, - oracle::{nullifier_key::get_app_nullifier_secret_key, get_public_key::get_public_key} + oracle::{nullifier_key::get_app_nullifier_secret_key} }; global SUBSCRIPTION_NOTE_LEN: Field = 3; @@ -39,12 +39,12 @@ impl NoteInterface for SubscriptionNote { // Broadcasts the note as an encrypted log on L1. fn broadcast(self, context: &mut PrivateContext, slot: Field) { - let encryption_pub_key = get_public_key(self.owner); + let ivpk_m = get_ivpk_m(context, self.owner); context.emit_encrypted_log( (*context).this_address(), slot, Self::get_note_type_id(), - encryption_pub_key, + ivpk_m, self.serialize_content(), ); } diff --git a/noir-projects/noir-contracts/contracts/avm_test_contract/src/main.nr b/noir-projects/noir-contracts/contracts/avm_test_contract/src/main.nr index e71861ffbef..a2f148e4e23 100644 --- a/noir-projects/noir-contracts/contracts/avm_test_contract/src/main.nr +++ b/noir-projects/noir-contracts/contracts/avm_test_contract/src/main.nr @@ -340,4 +340,10 @@ contract AvmTest { fn send_l2_to_l1_msg(recipient: EthAddress, content: Field) { context.message_portal(recipient, content) } + + #[aztec(public-vm)] + fn to_radix_le(input: Field) -> [u8; 10] { + let result: [u8] = input.to_le_radix(/*base=*/ 2, /*limbs=*/ 10); + result.as_array() + } } diff --git a/noir-projects/noir-contracts/contracts/docs_example_contract/src/types/card_note.nr b/noir-projects/noir-contracts/contracts/docs_example_contract/src/types/card_note.nr index 3f952146c2b..aa2fea463d4 100644 --- a/noir-projects/noir-contracts/contracts/docs_example_contract/src/types/card_note.nr +++ b/noir-projects/noir-contracts/contracts/docs_example_contract/src/types/card_note.nr @@ -1,8 +1,8 @@ use dep::aztec::prelude::{AztecAddress, NoteInterface, NoteHeader, PrivateContext}; use dep::aztec::{ - note::{utils::compute_note_hash_for_consumption}, - oracle::{nullifier_key::get_app_nullifier_secret_key, get_public_key::get_public_key}, - hash::poseidon2_hash, protocol_types::{traits::Empty, constants::GENERATOR_INDEX__NOTE_NULLIFIER} + keys::getters::get_ivpk_m, note::{utils::compute_note_hash_for_consumption}, + oracle::nullifier_key::get_app_nullifier_secret_key, hash::poseidon2_hash, + protocol_types::{traits::Empty, constants::GENERATOR_INDEX__NOTE_NULLIFIER} }; // Shows how to create a custom note @@ -47,12 +47,12 @@ impl NoteInterface for CardNote { // Broadcasts the note as an encrypted log on L1. fn broadcast(self, context: &mut PrivateContext, slot: Field) { - let encryption_pub_key = get_public_key(self.owner); + let ivpk_m = get_ivpk_m(context, self.owner); context.emit_encrypted_log( (*context).this_address(), slot, Self::get_note_type_id(), - encryption_pub_key, + ivpk_m, self.serialize_content(), ); } diff --git a/noir-projects/noir-contracts/contracts/ecdsa_account_contract/src/ecdsa_public_key_note.nr b/noir-projects/noir-contracts/contracts/ecdsa_account_contract/src/ecdsa_public_key_note.nr index 20fd400e967..eaadbbc60ac 100644 --- a/noir-projects/noir-contracts/contracts/ecdsa_account_contract/src/ecdsa_public_key_note.nr +++ b/noir-projects/noir-contracts/contracts/ecdsa_account_contract/src/ecdsa_public_key_note.nr @@ -1,9 +1,9 @@ use dep::aztec::prelude::{AztecAddress, FunctionSelector, NoteHeader, NoteInterface, NoteGetterOptions, PrivateContext}; use dep::aztec::{ - note::utils::compute_note_hash_for_consumption, - oracle::{nullifier_key::get_app_nullifier_secret_key, get_public_key::get_public_key}, - hash::poseidon2_hash, protocol_types::constants::GENERATOR_INDEX__NOTE_NULLIFIER + keys::getters::get_ivpk_m, note::utils::compute_note_hash_for_consumption, + oracle::nullifier_key::get_app_nullifier_secret_key, hash::poseidon2_hash, + protocol_types::constants::GENERATOR_INDEX__NOTE_NULLIFIER }; global ECDSA_PUBLIC_KEY_NOTE_LEN: Field = 5; @@ -85,12 +85,12 @@ impl NoteInterface for EcdsaPublicKeyNote { // Broadcasts the note as an encrypted log on L1. fn broadcast(self, context: &mut PrivateContext, slot: Field) { - let encryption_pub_key = get_public_key(self.owner); + let ivpk_m = get_ivpk_m(context, self.owner); context.emit_encrypted_log( (*context).this_address(), slot, Self::get_note_type_id(), - encryption_pub_key, + ivpk_m, self.serialize_content(), ); } diff --git a/noir-projects/noir-contracts/contracts/ecdsa_account_contract/src/main.nr b/noir-projects/noir-contracts/contracts/ecdsa_account_contract/src/main.nr index 25f96128b9d..d992251604d 100644 --- a/noir-projects/noir-contracts/contracts/ecdsa_account_contract/src/main.nr +++ b/noir-projects/noir-contracts/contracts/ecdsa_account_contract/src/main.nr @@ -8,7 +8,7 @@ contract EcdsaAccount { use dep::aztec::protocol_types::abis::call_context::CallContext; use dep::std; - use dep::aztec::{context::{PublicContext, Context}, oracle::get_public_key::get_public_key}; + use dep::aztec::context::Context; use dep::authwit::{ entrypoint::{app::AppPayload, fee::FeePayload}, account::AccountActions, auth_witness::get_auth_witness diff --git a/noir-projects/noir-contracts/contracts/escrow_contract/src/main.nr b/noir-projects/noir-contracts/contracts/escrow_contract/src/main.nr index c1ec425486b..bc17776d83d 100644 --- a/noir-projects/noir-contracts/contracts/escrow_contract/src/main.nr +++ b/noir-projects/noir-contracts/contracts/escrow_contract/src/main.nr @@ -2,7 +2,7 @@ contract Escrow { use dep::aztec::prelude::{AztecAddress, EthAddress, FunctionSelector, NoteHeader, PrivateContext, PrivateImmutable}; - use dep::aztec::{context::{PublicContext, Context}, oracle::get_public_key::get_public_key}; + use dep::aztec::context::{PublicContext, Context}; use dep::address_note::address_note::AddressNote; diff --git a/noir-projects/noir-contracts/contracts/key_registry_contract/src/main.nr b/noir-projects/noir-contracts/contracts/key_registry_contract/src/main.nr index b985c829d26..ca63a68aba3 100644 --- a/noir-projects/noir-contracts/contracts/key_registry_contract/src/main.nr +++ b/noir-projects/noir-contracts/contracts/key_registry_contract/src/main.nr @@ -3,7 +3,7 @@ contract KeyRegistry { use dep::aztec::{ state_vars::{SharedMutable, Map}, - protocol_types::{grumpkin_point::GrumpkinPoint, address::{AztecAddress, PartialAddress}} + protocol_types::{grumpkin_point::GrumpkinPoint, address::{AztecAddress, PartialAddress, PublicKeysHash}} }; global KEY_ROTATION_DELAY = 5; @@ -27,11 +27,7 @@ contract KeyRegistry { } #[aztec(public)] - fn rotate_nullifier_public_key( - address: AztecAddress, - new_nullifier_public_key: GrumpkinPoint, - nonce: Field - ) { + fn rotate_npk_m(address: AztecAddress, new_npk_m: GrumpkinPoint, nonce: Field) { // TODO: (#6137) if (!address.eq(context.msg_sender())) { assert_current_call_valid_authwit_public(&mut context, address); @@ -41,26 +37,21 @@ contract KeyRegistry { let npk_m_x_registry = storage.npk_m_x_registry.at(address); let npk_m_y_registry = storage.npk_m_y_registry.at(address); - npk_m_x_registry.schedule_value_change(new_nullifier_public_key.x); - npk_m_y_registry.schedule_value_change(new_nullifier_public_key.y); + npk_m_x_registry.schedule_value_change(new_npk_m.x); + npk_m_y_registry.schedule_value_change(new_npk_m.y); } #[aztec(public)] fn register( address: AztecAddress, partial_address: PartialAddress, - nullifier_public_key: GrumpkinPoint, - incoming_public_key: GrumpkinPoint, - outgoing_public_key: GrumpkinPoint, - tagging_public_key: GrumpkinPoint + npk_m: GrumpkinPoint, + ivpk_m: GrumpkinPoint, + ovpk_m: GrumpkinPoint, + tpk_m: GrumpkinPoint ) { - let computed_address = AztecAddress::compute_from_public_keys_and_partial_address( - nullifier_public_key, - incoming_public_key, - outgoing_public_key, - tagging_public_key, - partial_address - ); + let public_keys_hash = PublicKeysHash::compute(npk_m, ivpk_m, ovpk_m, tpk_m); + let computed_address = AztecAddress::compute(public_keys_hash, partial_address); assert(computed_address.eq(address), "Computed address does not match supplied address"); @@ -73,14 +64,14 @@ contract KeyRegistry { // let tpk_m_x_registry = storage.tpk_m_x_registry.at(address); // let tpk_m_y_registry = storage.tpk_m_y_registry.at(address); - npk_m_x_registry.schedule_value_change(nullifier_public_key.x); - npk_m_y_registry.schedule_value_change(nullifier_public_key.y); - ivpk_m_x_registry.schedule_value_change(incoming_public_key.x); - ivpk_m_y_registry.schedule_value_change(incoming_public_key.y); + npk_m_x_registry.schedule_value_change(npk_m.x); + npk_m_y_registry.schedule_value_change(npk_m.y); + ivpk_m_x_registry.schedule_value_change(ivpk_m.x); + ivpk_m_y_registry.schedule_value_change(ivpk_m.y); // Commented out as we hit the max enqueued public calls limit when not done so - // ovpk_m_x_registry.schedule_value_change(outgoing_public_key.x); - // ovpk_m_y_registry.schedule_value_change(outgoing_public_key.y); - // tpk_m_x_registry.schedule_value_change(tagging_public_key.x); - // tpk_m_y_registry.schedule_value_change(tagging_public_key.y); + // ovpk_m_x_registry.schedule_value_change(ovpk_m.x); + // ovpk_m_y_registry.schedule_value_change(ovpk_m.y); + // tpk_m_x_registry.schedule_value_change(tpk_m.x); + // tpk_m_y_registry.schedule_value_change(tpk_m.y); } } diff --git a/noir-projects/noir-contracts/contracts/schnorr_account_contract/src/main.nr b/noir-projects/noir-contracts/contracts/schnorr_account_contract/src/main.nr index d42ee2119d6..39cc3384b3b 100644 --- a/noir-projects/noir-contracts/contracts/schnorr_account_contract/src/main.nr +++ b/noir-projects/noir-contracts/contracts/schnorr_account_contract/src/main.nr @@ -7,7 +7,7 @@ contract SchnorrAccount { use dep::aztec::prelude::{AztecAddress, FunctionSelector, NoteHeader, PrivateContext, PrivateImmutable}; use dep::aztec::state_vars::{Map, PublicMutable}; - use dep::aztec::{context::Context, oracle::get_public_key::get_public_key}; + use dep::aztec::context::Context; use dep::authwit::{ entrypoint::{app::AppPayload, fee::FeePayload}, account::AccountActions, auth_witness::get_auth_witness diff --git a/noir-projects/noir-contracts/contracts/schnorr_account_contract/src/public_key_note.nr b/noir-projects/noir-contracts/contracts/schnorr_account_contract/src/public_key_note.nr index 95fbe422f78..74812ec7465 100644 --- a/noir-projects/noir-contracts/contracts/schnorr_account_contract/src/public_key_note.nr +++ b/noir-projects/noir-contracts/contracts/schnorr_account_contract/src/public_key_note.nr @@ -1,7 +1,7 @@ use dep::aztec::prelude::{AztecAddress, NoteHeader, NoteInterface, PrivateContext}; use dep::aztec::{ - note::utils::compute_note_hash_for_consumption, hash::poseidon2_hash, - oracle::{nullifier_key::get_app_nullifier_secret_key, get_public_key::get_public_key}, + keys::getters::get_ivpk_m, note::utils::compute_note_hash_for_consumption, hash::poseidon2_hash, + oracle::{nullifier_key::get_app_nullifier_secret_key}, protocol_types::constants::GENERATOR_INDEX__NOTE_NULLIFIER }; @@ -39,12 +39,12 @@ impl NoteInterface for PublicKeyNote { // Broadcasts the note as an encrypted log on L1. fn broadcast(self, context: &mut PrivateContext, slot: Field) { - let encryption_pub_key = get_public_key(self.owner); + let ivpk_m = get_ivpk_m(context, self.owner); context.emit_encrypted_log( (*context).this_address(), slot, Self::get_note_type_id(), - encryption_pub_key, + ivpk_m, self.serialize_content(), ); } diff --git a/noir-projects/noir-contracts/contracts/schnorr_single_key_account_contract/src/main.nr b/noir-projects/noir-contracts/contracts/schnorr_single_key_account_contract/src/main.nr index bb6aad4b787..5c75c095d2f 100644 --- a/noir-projects/noir-contracts/contracts/schnorr_single_key_account_contract/src/main.nr +++ b/noir-projects/noir-contracts/contracts/schnorr_single_key_account_contract/src/main.nr @@ -6,7 +6,8 @@ contract SchnorrSingleKeyAccount { use dep::authwit::{entrypoint::{app::AppPayload, fee::FeePayload}, account::AccountActions}; - use crate::{util::recover_address, auth_oracle::get_auth_witness}; + // use crate::{util::recover_address, auth_oracle::get_auth_witness}; + use crate::auth_oracle::get_auth_witness; global ACCOUNT_ACTIONS_STORAGE_SLOT = 1; diff --git a/noir-projects/noir-contracts/contracts/schnorr_single_key_account_contract/src/util.nr b/noir-projects/noir-contracts/contracts/schnorr_single_key_account_contract/src/util.nr index f337d688bbd..89f7e2e9b4d 100644 --- a/noir-projects/noir-contracts/contracts/schnorr_single_key_account_contract/src/util.nr +++ b/noir-projects/noir-contracts/contracts/schnorr_single_key_account_contract/src/util.nr @@ -3,18 +3,19 @@ use dep::aztec::protocol_types::address::PublicKeysHash; use dep::std::{schnorr::verify_signature_slice}; use crate::auth_oracle::AuthWitness; -pub fn recover_address(message_hash: Field, witness: AuthWitness) -> AztecAddress { - let message_bytes = message_hash.to_be_bytes(32); - let verification = verify_signature_slice( - witness.owner.x, - witness.owner.y, - witness.signature, - message_bytes - ); - assert(verification == true); +// TODO(#5830): the following is currently broken because we are no longer able to compute public keys hash +// pub fn recover_address(message_hash: Field, witness: AuthWitness) -> AztecAddress { +// let message_bytes = message_hash.to_be_bytes(32); +// let verification = verify_signature_slice( +// witness.owner.x, +// witness.owner.y, +// witness.signature, +// message_bytes +// ); +// assert(verification == true); - AztecAddress::compute( - PublicKeysHash::compute(witness.owner), - witness.partial_address - ) -} +// AztecAddress::compute( +// PublicKeysHash::compute(witness.owner), +// witness.partial_address +// ) +// } diff --git a/noir-projects/noir-contracts/contracts/test_contract/src/main.nr b/noir-projects/noir-contracts/contracts/test_contract/src/main.nr index aba898225ff..2a248a7eaf2 100644 --- a/noir-projects/noir-contracts/contracts/test_contract/src/main.nr +++ b/noir-projects/noir-contracts/contracts/test_contract/src/main.nr @@ -23,18 +23,15 @@ contract Test { use dep::aztec::state_vars::{shared_mutable::SharedMutablePrivateGetter, map::derive_storage_slot_in_map}; use dep::aztec::{ - keys::getters::get_npk_m, + keys::getters::{get_npk_m, get_ivpk_m}, context::{Context, inputs::private_context_inputs::PrivateContextInputs}, - hash::{pedersen_hash, poseidon2_hash, compute_secret_hash, ArgsHasher}, + hash::{pedersen_hash, compute_secret_hash, ArgsHasher}, note::{ lifecycle::{create_note, destroy_note}, note_getter::{get_notes, view_notes}, note_getter_options::NoteStatus }, deploy::deploy_contract as aztec_deploy_contract, - oracle::{ - encryption::aes128_encrypt, get_public_key::get_public_key as get_public_key_oracle, - unsafe_rand::unsafe_rand - } + oracle::{encryption::aes128_encrypt, unsafe_rand::unsafe_rand} }; use dep::token_portal_content_hash_lib::{get_mint_private_content_hash, get_mint_public_content_hash}; use dep::value_note::value_note::ValueNote; @@ -53,8 +50,8 @@ contract Test { } #[aztec(private)] - fn get_public_key(address: AztecAddress) -> [Field; 2] { - let pub_key = get_public_key_oracle(address); + fn get_master_incoming_viewing_public_key(address: AztecAddress) -> [Field; 2] { + let pub_key = get_ivpk_m(&mut context, address); [pub_key.x, pub_key.y] } diff --git a/noir-projects/noir-contracts/contracts/token_blacklist_contract/src/types/token_note.nr b/noir-projects/noir-contracts/contracts/token_blacklist_contract/src/types/token_note.nr index 798a9cfe174..8492b92a1fc 100644 --- a/noir-projects/noir-contracts/contracts/token_blacklist_contract/src/types/token_note.nr +++ b/noir-projects/noir-contracts/contracts/token_blacklist_contract/src/types/token_note.nr @@ -1,8 +1,8 @@ use dep::aztec::{ - prelude::{AztecAddress, NoteHeader, NoteInterface, PrivateContext}, + keys::getters::get_ivpk_m, prelude::{AztecAddress, NoteHeader, NoteInterface, PrivateContext}, protocol_types::constants::GENERATOR_INDEX__NOTE_NULLIFIER, note::utils::compute_note_hash_for_consumption, hash::poseidon2_hash, - oracle::{unsafe_rand::unsafe_rand, nullifier_key::get_app_nullifier_secret_key, get_public_key::get_public_key} + oracle::{unsafe_rand::unsafe_rand, nullifier_key::get_app_nullifier_secret_key} }; trait OwnedNote { @@ -52,12 +52,12 @@ impl NoteInterface for TokenNote { fn broadcast(self, context: &mut PrivateContext, slot: Field) { // We only bother inserting the note if non-empty to save funds on gas. if !(self.amount == U128::from_integer(0)) { - let encryption_pub_key = get_public_key(self.owner); + let ivpk_m = get_ivpk_m(context, self.owner); context.emit_encrypted_log( (*context).this_address(), slot, Self::get_note_type_id(), - encryption_pub_key, + ivpk_m, self.serialize_content(), ); } diff --git a/noir-projects/noir-contracts/contracts/token_contract/src/types/token_note.nr b/noir-projects/noir-contracts/contracts/token_contract/src/types/token_note.nr index 798a9cfe174..8492b92a1fc 100644 --- a/noir-projects/noir-contracts/contracts/token_contract/src/types/token_note.nr +++ b/noir-projects/noir-contracts/contracts/token_contract/src/types/token_note.nr @@ -1,8 +1,8 @@ use dep::aztec::{ - prelude::{AztecAddress, NoteHeader, NoteInterface, PrivateContext}, + keys::getters::get_ivpk_m, prelude::{AztecAddress, NoteHeader, NoteInterface, PrivateContext}, protocol_types::constants::GENERATOR_INDEX__NOTE_NULLIFIER, note::utils::compute_note_hash_for_consumption, hash::poseidon2_hash, - oracle::{unsafe_rand::unsafe_rand, nullifier_key::get_app_nullifier_secret_key, get_public_key::get_public_key} + oracle::{unsafe_rand::unsafe_rand, nullifier_key::get_app_nullifier_secret_key} }; trait OwnedNote { @@ -52,12 +52,12 @@ impl NoteInterface for TokenNote { fn broadcast(self, context: &mut PrivateContext, slot: Field) { // We only bother inserting the note if non-empty to save funds on gas. if !(self.amount == U128::from_integer(0)) { - let encryption_pub_key = get_public_key(self.owner); + let ivpk_m = get_ivpk_m(context, self.owner); context.emit_encrypted_log( (*context).this_address(), slot, Self::get_note_type_id(), - encryption_pub_key, + ivpk_m, self.serialize_content(), ); } diff --git a/noir-projects/noir-protocol-circuits/crates/private-kernel-lib/src/common.nr b/noir-projects/noir-protocol-circuits/crates/private-kernel-lib/src/common.nr deleted file mode 100644 index 8f828e9a6ca..00000000000 --- a/noir-projects/noir-protocol-circuits/crates/private-kernel-lib/src/common.nr +++ /dev/null @@ -1,190 +0,0 @@ -use dep::std; -use dep::types::{ - abis::{ - call_request::CallRequest, accumulated_data::PrivateAccumulatedData, - private_circuit_public_inputs::PrivateCircuitPublicInputs, - private_kernel::private_call_data::PrivateCallData -}, - address::{AztecAddress, PartialAddress}, contract_class_id::ContractClassId, - hash::{private_functions_root_from_siblings, stdlib_recursion_verification_key_compress_native_vk}, - utils::{arrays::{array_length, validate_array}}, traits::{is_empty, is_empty_array} -}; - -fn validate_arrays(app_public_inputs: PrivateCircuitPublicInputs) { - // Each of the following arrays is expected to be zero-padded. - // In addition, some of the following arrays (new_note_hashes, etc...) are passed - // to extend_from_array_to_array() routines which rely on the passed arrays to be well-formed. - - validate_array(app_public_inputs.note_hash_read_requests); - validate_array(app_public_inputs.nullifier_read_requests); - validate_array(app_public_inputs.nullifier_key_validation_requests); - validate_array(app_public_inputs.new_note_hashes); - validate_array(app_public_inputs.new_nullifiers); - validate_array(app_public_inputs.private_call_stack_hashes); - validate_array(app_public_inputs.public_call_stack_hashes); - validate_array(app_public_inputs.new_l2_to_l1_msgs); - validate_array(app_public_inputs.encrypted_logs_hashes); - validate_array(app_public_inputs.unencrypted_logs_hashes); -} - -fn perform_static_call_checks(private_call: PrivateCallData) { - let public_inputs = private_call.call_stack_item.public_inputs; - let is_static_call = public_inputs.call_context.is_static_call; - if is_static_call { - // No state changes are allowed for static calls: - assert( - is_empty_array(public_inputs.new_note_hashes), "new_note_hashes must be empty for static calls" - ); - assert( - is_empty_array(public_inputs.new_nullifiers), "new_nullifiers must be empty for static calls" - ); - - let new_l2_to_l1_msgs_length = array_length(public_inputs.new_l2_to_l1_msgs); - assert(new_l2_to_l1_msgs_length == 0, "new_l2_to_l1_msgs must be empty for static calls"); - - // TODO: reevaluate when implementing https://github.com/AztecProtocol/aztec-packages/issues/1165 - // This 4 magical number is the minimum size of the buffer, since it has to store the total length of all the serialized logs. - assert( - public_inputs.encrypted_log_preimages_length == 4, "No encrypted logs are allowed for static calls" - ); - - assert( - public_inputs.unencrypted_log_preimages_length == 4, "No unencrypted logs are allowed for static calls" - ); - } -} - -fn is_valid_caller(request_from_stack: CallRequest, fn_being_verified: PrivateCallData) -> bool { - let call_context = fn_being_verified.call_stack_item.public_inputs.call_context; - - let valid_caller_context = request_from_stack.caller_context.msg_sender.eq(call_context.msg_sender) - & request_from_stack.caller_context.storage_contract_address.eq(call_context.storage_contract_address); - - request_from_stack.caller_contract_address.eq(fn_being_verified.call_stack_item.contract_address) - & (request_from_stack.caller_context.is_empty() | valid_caller_context) -} - -fn validate_call_request(request: CallRequest, hash: Field, private_call: PrivateCallData) { - if hash != 0 { - assert_eq(request.hash, hash, "call stack hash does not match call request hash"); - assert(is_valid_caller(request, private_call), "invalid caller"); - } else { - assert(is_empty(request), "call requests length does not match the expected length"); - } -} - -fn validate_call_requests(call_requests: [CallRequest; N], hashes: [Field; N], private_call: PrivateCallData) { - for i in 0..N { - let hash = hashes[i]; - let request = call_requests[i]; - validate_call_request(request, hash, private_call); - } -} - -// TODO: Move to a seperate file. -pub fn validate_private_call_data(private_call: PrivateCallData) { - let private_call_public_inputs = private_call.call_stack_item.public_inputs; - - validate_arrays(private_call_public_inputs); - - contract_logic(private_call); - - perform_static_call_checks(private_call); - - // Private call stack. - validate_call_requests( - private_call.private_call_stack, - private_call_public_inputs.private_call_stack_hashes, - private_call - ); - - // Public call stack. - validate_call_requests( - private_call.public_call_stack, - private_call_public_inputs.public_call_stack_hashes, - private_call - ); - - // Teardown call - validate_call_request( - private_call.public_teardown_call_request, - private_call_public_inputs.public_teardown_function_hash, - private_call - ); -} - -fn contract_logic(private_call: PrivateCallData) { - let contract_address = private_call.call_stack_item.contract_address; - - // TODO(https://github.com/AztecProtocol/aztec-packages/issues/3062): Why is this using a hash function from the stdlib::recursion namespace - let private_call_vk_hash = stdlib_recursion_verification_key_compress_native_vk(private_call.vk); - - assert(!contract_address.is_zero(), "contract address cannot be zero"); - // std::println(f"contract_address={contract_address}"); - // std::println(f"private_call_vk_hash={private_call_vk_hash}"); - - // Recompute the contract class id - let computed_private_functions_root = private_functions_root_from_siblings( - private_call.call_stack_item.function_data.selector, - private_call_vk_hash, - private_call.function_leaf_membership_witness.leaf_index, - private_call.function_leaf_membership_witness.sibling_path - ); - // std::println(f"computed_private_functions_root={computed_private_functions_root}"); - - let computed_contract_class_id = ContractClassId::compute( - private_call.contract_class_artifact_hash, - computed_private_functions_root, - private_call.contract_class_public_bytecode_commitment - ); - // std::println(f"computed_contract_class_id={computed_contract_class_id}"); - - // Recompute contract address using the preimage which includes the class_id - let computed_partial_address = PartialAddress::compute_from_salted_initialization_hash( - computed_contract_class_id, - private_call.salted_initialization_hash - ); - // std::println(f"computed_partial_address={computed_partial_address}"); - - let computed_address = AztecAddress::compute(private_call.public_keys_hash, computed_partial_address); - // std::println(f"computed_address={computed_address}"); - - assert(computed_address.eq(contract_address), "computed contract address does not match expected one"); -} - -pub fn validate_previous_kernel_values(end: PrivateAccumulatedData) { - assert( - end.new_nullifiers[0].value() != 0, "The 0th nullifier in the accumulated nullifier array is zero" - ); -} - -pub fn validate_call_against_request(private_call: PrivateCallData, request: CallRequest) { - let call_stack_item = private_call.call_stack_item; - assert( - request.hash == call_stack_item.hash(), "calculated private_call_hash does not match provided private_call_hash at the top of the call stack" - ); - - let call_context = call_stack_item.public_inputs.call_context; - - if call_context.is_delegate_call { - let caller_context = request.caller_context; - assert(!caller_context.is_empty(), "caller context cannot be empty for delegate calls"); - assert( - call_context.msg_sender.eq(caller_context.msg_sender), "call stack msg_sender does not match expected msg_sender for delegate calls" - ); - assert( - call_context.storage_contract_address.eq(caller_context.storage_contract_address), "call stack storage address does not match expected contract address for delegate calls" - ); - assert( - !call_stack_item.contract_address.eq(call_context.storage_contract_address), "curent contract address must not match storage contract address for delegate calls" - ); - } else { - let caller_contract_address = request.caller_contract_address; - assert( - call_context.msg_sender.eq(caller_contract_address), "call stack msg_sender does not match caller contract address" - ); - assert( - call_context.storage_contract_address.eq(call_stack_item.contract_address), "call stack storage address does not match expected contract address" - ); - } -} diff --git a/noir-projects/noir-protocol-circuits/crates/private-kernel-lib/src/kernel_circuit_public_inputs_composer.nr b/noir-projects/noir-protocol-circuits/crates/private-kernel-lib/src/kernel_circuit_public_inputs_composer.nr index f1b37f57318..2dec1046f68 100644 --- a/noir-projects/noir-protocol-circuits/crates/private-kernel-lib/src/kernel_circuit_public_inputs_composer.nr +++ b/noir-projects/noir-protocol-circuits/crates/private-kernel-lib/src/kernel_circuit_public_inputs_composer.nr @@ -129,7 +129,9 @@ impl KernelCircuitPublicInputsComposer { fn silo_note_hashes(&mut self) { let first_nullifier = self.public_inputs.end.new_nullifiers.get_unchecked(0).value(); - assert(first_nullifier != 0, "The 0th nullifier in the accumulated nullifier array is zero"); + + // This check is unnecessary. The 0th nullifier will always be set a non-zero value in private_kernel_init. + // assert(first_nullifier != 0, "The 0th nullifier in the accumulated nullifier array is zero"); let note_hashes = self.public_inputs.end.new_note_hashes.storage; for i in 0..MAX_NEW_NOTE_HASHES_PER_TX { diff --git a/noir-projects/noir-protocol-circuits/crates/private-kernel-lib/src/lib.nr b/noir-projects/noir-protocol-circuits/crates/private-kernel-lib/src/lib.nr index 7ce19ad3c8e..88de299ebac 100644 --- a/noir-projects/noir-protocol-circuits/crates/private-kernel-lib/src/lib.nr +++ b/noir-projects/noir-protocol-circuits/crates/private-kernel-lib/src/lib.nr @@ -1,14 +1,13 @@ mod kernel_circuit_public_inputs_composer; +mod private_call_data_validator; mod private_kernel_circuit_public_inputs_composer; mod private_kernel_init; mod private_kernel_inner; mod private_kernel_tail; mod private_kernel_tail_to_public; +mod tests; use private_kernel_init::PrivateKernelInitCircuitPrivateInputs; use private_kernel_inner::PrivateKernelInnerCircuitPrivateInputs; use private_kernel_tail::PrivateKernelTailCircuitPrivateInputs; use private_kernel_tail_to_public::PrivateKernelTailToPublicCircuitPrivateInputs; - -// TODO: rename to be precise as to what its common to. -mod common; diff --git a/noir-projects/noir-protocol-circuits/crates/private-kernel-lib/src/private_call_data_validator.nr b/noir-projects/noir-protocol-circuits/crates/private-kernel-lib/src/private_call_data_validator.nr new file mode 100644 index 00000000000..774b66b306e --- /dev/null +++ b/noir-projects/noir-protocol-circuits/crates/private-kernel-lib/src/private_call_data_validator.nr @@ -0,0 +1,230 @@ +use dep::types::{ + abis::{ + call_context::CallContext, call_request::CallRequest, private_call_stack_item::PrivateCallStackItem, + private_kernel::private_call_data::PrivateCallData +}, + address::{AztecAddress, PartialAddress}, contract_class_id::ContractClassId, + hash::{private_functions_root_from_siblings, stdlib_recursion_verification_key_compress_native_vk}, + traits::{is_empty, is_empty_array}, transaction::tx_request::TxRequest, + utils::arrays::{array_length, validate_array} +}; + +fn validate_arrays(data: PrivateCallData) -> ArrayLengths { + let public_inputs = data.call_stack_item.public_inputs; + + // Each of the following arrays is expected to be zero-padded. + ArrayLengths { + note_hash_read_requests: validate_array(public_inputs.note_hash_read_requests), + nullifier_read_requests: validate_array(public_inputs.nullifier_read_requests), + nullifier_key_validation_requests: validate_array(public_inputs.nullifier_key_validation_requests), + new_note_hashes: validate_array(public_inputs.new_note_hashes), + new_nullifiers: validate_array(public_inputs.new_nullifiers), + new_l2_to_l1_msgs: validate_array(public_inputs.new_l2_to_l1_msgs), + private_call_stack_hashes: validate_array(public_inputs.private_call_stack_hashes), + public_call_stack_hashes: validate_array(public_inputs.public_call_stack_hashes), + encrypted_logs_hashes: validate_array(public_inputs.encrypted_logs_hashes), + unencrypted_logs_hashes: validate_array(public_inputs.unencrypted_logs_hashes) + } +} + +fn is_valid_caller(request: CallRequest, caller_address: AztecAddress, caller_context: CallContext) -> bool { + let valid_caller_context = request.caller_context.msg_sender.eq(caller_context.msg_sender) + & request.caller_context.storage_contract_address.eq(caller_context.storage_contract_address); + + request.caller_contract_address.eq(caller_address) + & (request.caller_context.is_empty() | valid_caller_context) +} + +fn validate_call_request(request: CallRequest, hash: Field, caller: PrivateCallStackItem) { + if hash != 0 { + assert_eq(request.hash, hash, "call stack hash does not match call request hash"); + assert( + is_valid_caller( + request, + caller.contract_address, + caller.public_inputs.call_context + ), "invalid caller" + ); + } else { + assert(is_empty(request), "call requests length does not match the expected length"); + } +} + +fn validate_call_requests(call_requests: [CallRequest; N], hashes: [Field; N], caller: PrivateCallStackItem) { + for i in 0..N { + validate_call_request(call_requests[i], hashes[i], caller); + } +} + +struct ArrayLengths { + note_hash_read_requests: u64, + nullifier_read_requests: u64, + nullifier_key_validation_requests: u64, + new_note_hashes: u64, + new_nullifiers: u64, + new_l2_to_l1_msgs: u64, + private_call_stack_hashes: u64, + public_call_stack_hashes: u64, + encrypted_logs_hashes: u64, + unencrypted_logs_hashes: u64, +} + +struct PrivateCallDataValidator { + data: PrivateCallData, + array_lengths: ArrayLengths, +} + +impl PrivateCallDataValidator { + pub fn new(data: PrivateCallData) -> Self { + let array_lengths = validate_arrays(data); + PrivateCallDataValidator { data, array_lengths } + } + + pub fn validate(self) { + self.validate_contract_address(); + self.validate_call(); + self.validate_private_call_requests(); + self.validate_public_call_requests(); + self.validate_teardown_call_request(); + } + + // Confirm that the TxRequest (user's intent) matches the private call being executed. + pub fn validate_against_tx_request(self, tx_request: TxRequest) { + let call_stack_item = self.data.call_stack_item; + assert_eq( + tx_request.origin, call_stack_item.contract_address, "origin address does not match call stack items contract address" + ); + assert_eq( + tx_request.function_data.hash(), call_stack_item.function_data.hash(), "tx_request function_data must match call_stack_item function_data" + ); + assert_eq( + tx_request.args_hash, call_stack_item.public_inputs.args_hash, "noir function args passed to tx_request must match args in the call_stack_item" + ); + assert_eq( + tx_request.tx_context, call_stack_item.public_inputs.tx_context, "tx_context in tx_request must match tx_context in call_stack_item" + ); + + // If checking against TxRequest, it must be the first call, which has the following restrictions. + let call_context = call_stack_item.public_inputs.call_context; + assert(call_context.is_delegate_call == false, "Users cannot make a delegatecall"); + assert(call_context.is_static_call == false, "Users cannot make a static call"); + } + + pub fn validate_against_call_request(self, request: CallRequest) { + let call_stack_item = self.data.call_stack_item; + + assert_eq( + request.hash, call_stack_item.hash(), "calculated private_call_hash does not match provided private_call_hash at the top of the call stack" + ); + + let call_context = call_stack_item.public_inputs.call_context; + + if call_context.is_delegate_call { + let caller_context = request.caller_context; + assert(!caller_context.is_empty(), "caller context cannot be empty for delegate calls"); + assert_eq( + call_context.msg_sender, caller_context.msg_sender, "call stack msg_sender does not match expected msg_sender for delegate calls" + ); + assert_eq( + call_context.storage_contract_address, caller_context.storage_contract_address, "call stack storage address does not match expected contract address for delegate calls" + ); + } else { + assert_eq( + call_context.msg_sender, request.caller_contract_address, "call stack msg_sender does not match caller contract address" + ); + } + } + + fn validate_contract_address(self) { + let contract_address = self.data.call_stack_item.contract_address; + + // TODO(https://github.com/AztecProtocol/aztec-packages/issues/3062): Why is this using a hash function from the stdlib::recursion namespace + let private_call_vk_hash = stdlib_recursion_verification_key_compress_native_vk(self.data.vk); + + assert(!contract_address.is_zero(), "contract address cannot be zero"); + // std::println(f"contract_address={contract_address}"); + // std::println(f"private_call_vk_hash={private_call_vk_hash}"); + + // Recompute the contract class id + let computed_private_functions_root = private_functions_root_from_siblings( + self.data.call_stack_item.function_data.selector, + private_call_vk_hash, + self.data.function_leaf_membership_witness.leaf_index, + self.data.function_leaf_membership_witness.sibling_path + ); + // std::println(f"computed_private_functions_root={computed_private_functions_root}"); + + let computed_contract_class_id = ContractClassId::compute( + self.data.contract_class_artifact_hash, + computed_private_functions_root, + self.data.contract_class_public_bytecode_commitment + ); + // std::println(f"computed_contract_class_id={computed_contract_class_id}"); + + // Recompute contract address using the preimage which includes the class_id + let computed_partial_address = PartialAddress::compute_from_salted_initialization_hash( + computed_contract_class_id, + self.data.salted_initialization_hash + ); + // std::println(f"computed_partial_address={computed_partial_address}"); + + let computed_address = AztecAddress::compute(self.data.public_keys_hash, computed_partial_address); + // std::println(f"computed_address={computed_address}"); + + assert( + computed_address.eq(contract_address), "computed contract address does not match expected one" + ); + } + + fn validate_call(self) { + let call_context = self.data.call_stack_item.public_inputs.call_context; + + let is_own_storage = call_context.storage_contract_address == self.data.call_stack_item.contract_address; + if call_context.is_delegate_call { + assert( + !is_own_storage, "current contract address must not match storage contract address for delegate calls" + ); + } else { + assert(is_own_storage, "call stack storage address does not match expected contract address"); + } + + if call_context.is_static_call { + // No state changes are allowed for static calls: + assert_eq(self.array_lengths.new_note_hashes, 0, "new_note_hashes must be empty for static calls"); + assert_eq(self.array_lengths.new_nullifiers, 0, "new_nullifiers must be empty for static calls"); + assert_eq( + self.array_lengths.new_l2_to_l1_msgs, 0, "new_l2_to_l1_msgs must be empty for static calls" + ); + assert_eq( + self.array_lengths.encrypted_logs_hashes, 0, "encrypted_logs_hashes must be empty for static calls" + ); + assert_eq( + self.array_lengths.unencrypted_logs_hashes, 0, "unencrypted_logs_hashes must be empty for static calls" + ); + } + } + + fn validate_private_call_requests(self) { + validate_call_requests( + self.data.private_call_stack, + self.data.call_stack_item.public_inputs.private_call_stack_hashes, + self.data.call_stack_item + ); + } + + fn validate_public_call_requests(self) { + validate_call_requests( + self.data.public_call_stack, + self.data.call_stack_item.public_inputs.public_call_stack_hashes, + self.data.call_stack_item + ); + } + + fn validate_teardown_call_request(self) { + validate_call_request( + self.data.public_teardown_call_request, + self.data.call_stack_item.public_inputs.public_teardown_function_hash, + self.data.call_stack_item + ); + } +} diff --git a/noir-projects/noir-protocol-circuits/crates/private-kernel-lib/src/private_kernel_circuit_public_inputs_composer.nr b/noir-projects/noir-protocol-circuits/crates/private-kernel-lib/src/private_kernel_circuit_public_inputs_composer.nr index ef4e6008eb9..b7f5dc96d85 100644 --- a/noir-projects/noir-protocol-circuits/crates/private-kernel-lib/src/private_kernel_circuit_public_inputs_composer.nr +++ b/noir-projects/noir-protocol-circuits/crates/private-kernel-lib/src/private_kernel_circuit_public_inputs_composer.nr @@ -215,7 +215,7 @@ impl PrivateKernelCircuitPublicInputsComposer { let call_request = source.public_teardown_call_request; if !is_empty(call_request) { assert( - self.public_inputs.public_teardown_call_request.is_empty(), "Public teardown call request already set" + is_empty(self.public_inputs.public_teardown_call_request), "Public teardown call request already set" ); self.public_inputs.public_teardown_call_request = call_request; } diff --git a/noir-projects/noir-protocol-circuits/crates/private-kernel-lib/src/private_kernel_init.nr b/noir-projects/noir-protocol-circuits/crates/private-kernel-lib/src/private_kernel_init.nr index 4c3872bac03..da6ef0cc4ba 100644 --- a/noir-projects/noir-protocol-circuits/crates/private-kernel-lib/src/private_kernel_init.nr +++ b/noir-projects/noir-protocol-circuits/crates/private-kernel-lib/src/private_kernel_init.nr @@ -1,4 +1,7 @@ -use crate::{common, private_kernel_circuit_public_inputs_composer::PrivateKernelCircuitPublicInputsComposer}; +use crate::{ + private_call_data_validator::PrivateCallDataValidator, + private_kernel_circuit_public_inputs_composer::PrivateKernelCircuitPublicInputsComposer +}; use dep::types::{ abis::{ private_kernel::private_call_data::{PrivateCallData, verify_private_call}, @@ -20,63 +23,15 @@ struct PrivateKernelInitCircuitPrivateInputs { } impl PrivateKernelInitCircuitPrivateInputs { - // Confirm that the TxRequest (user's intent) - // matches the private call being executed - fn validate_this_private_call_against_tx_request(self) { - let tx_request = self.tx_request; - // Call stack item for the initial call - let call_stack_item = self.private_call.call_stack_item; - - // Checks to ensure that the user's intent matches the initial private call - // - // We use the word correct to denote whether it matches the user intent. - // - // Ensure we are calling the correct initial contract - let origin_address_matches = tx_request.origin.eq(call_stack_item.contract_address); - assert(origin_address_matches, "origin address does not match call stack items contract address"); - // - // Ensure we are calling the correct initial function in the contract - let entry_point_function_matches = tx_request.function_data.hash() == call_stack_item.function_data.hash(); - assert( - entry_point_function_matches, "tx_request function_data must match call_stack_item function_data" - ); - // - // Ensure we are passing the correct arguments to the function. - let args_match = tx_request.args_hash == call_stack_item.public_inputs.args_hash; - assert(args_match, "noir function args passed to tx_request must match args in the call_stack_item"); - // - // Ensure we are passing the correct tx context - let tx_context_matches = tx_request.tx_context == call_stack_item.public_inputs.tx_context; - assert(tx_context_matches, "tx_context in tx_request must match tx_context in call_stack_item"); - } - - fn validate_inputs(self) { - let call_stack_item = self.private_call.call_stack_item; - - let function_data = call_stack_item.function_data; - assert(function_data.is_private, "Private kernel circuit can only execute a private function"); - - let call_context = call_stack_item.public_inputs.call_context; - assert(call_context.is_delegate_call == false, "Users cannot make a delegatecall"); - assert(call_context.is_static_call == false, "Users cannot make a static call"); - // The below also prevents delegatecall/staticcall in the base case - assert( - call_context.storage_contract_address.eq(call_stack_item.contract_address), "Storage contract address must be that of the called contract" - ); - } - pub fn native_private_kernel_circuit_initial(self) -> PrivateKernelCircuitPublicInputs { - let private_call_public_inputs = self.private_call.call_stack_item.public_inputs; - // verify/aggregate the private call proof verify_private_call(self.private_call); - self.validate_inputs(); - - common::validate_private_call_data(self.private_call); - - self.validate_this_private_call_against_tx_request(); + let privateCallDataValidator = PrivateCallDataValidator::new(self.private_call); + privateCallDataValidator.validate(); + privateCallDataValidator.validate_against_tx_request(self.tx_request); + let private_call_public_inputs = self.private_call.call_stack_item.public_inputs; PrivateKernelCircuitPublicInputsComposer::new_from_tx_request(self.tx_request, private_call_public_inputs).compose( private_call_public_inputs, self.hints.note_hash_nullifier_counters, @@ -167,177 +122,6 @@ mod tests { assert_eq(public_inputs.end.unencrypted_logs_hashes[1].value, unencrypted_logs_hashes[1]); } - #[test(should_fail_with = "invalid array")] - fn input_validation_malformed_arrays_note_hash_read_requests() { - let mut builder = PrivateKernelInitInputsBuilder::new(); - - builder.private_call.public_inputs.note_hash_read_requests.extend_from_array( - [ - ReadRequest::empty(), - ReadRequest { value: 9123, counter: 1 } - ] - ); - - builder.failed(); - } - - #[test(should_fail_with = "invalid array")] - fn input_validation_malformed_arrays_note_hashes() { - let mut builder = PrivateKernelInitInputsBuilder::new(); - - builder.private_call.public_inputs.new_note_hashes.extend_from_array( - [ - NoteHash { value: 0, counter: 0 }, - NoteHash { value: 9123, counter: 1 } - ] - ); - - builder.failed(); - } - - #[test(should_fail_with = "invalid array")] - fn input_validation_malformed_arrays_nullifiers() { - let mut builder = PrivateKernelInitInputsBuilder::new(); - - builder.private_call.public_inputs.new_nullifiers.extend_from_array( - [ - Nullifier { value: 0, note_hash: 0, counter: 0 }, - Nullifier { value: 9123, note_hash: 0, counter: 1 } - ] - ); - - builder.failed(); - } - - #[test(should_fail_with = "invalid array")] - fn input_validation_malformed_arrays_private_call_stack() { - let mut builder = PrivateKernelInitInputsBuilder::new(); - - builder.private_call.public_inputs.private_call_stack_hashes.extend_from_array([0, 9123]); - - builder.failed(); - } - - #[test(should_fail_with = "invalid array")] - fn input_validation_malformed_arrays_public_call_stack() { - let mut builder = PrivateKernelInitInputsBuilder::new(); - - builder.private_call.public_inputs.public_call_stack_hashes.extend_from_array([0, 9123]); - - builder.failed(); - } - - #[test(should_fail_with = "invalid array")] - fn input_validation_malformed_arrays_new_l2_to_l1_msgs() { - let mut builder = PrivateKernelInitInputsBuilder::new(); - - builder.private_call.public_inputs.new_l2_to_l1_msgs.extend_from_array( - [ - L2ToL1Message::empty(), - L2ToL1Message { recipient: EthAddress::from_field(6), content: 9123, counter: 0 } - ] - ); - - builder.failed(); - } - - #[test(should_fail_with = "invalid array")] - fn input_validation_malformed_arrays_logs() { - let mut builder = PrivateKernelInitInputsBuilder::new(); - - builder.private_call.public_inputs.encrypted_logs_hashes.extend_from_array( - [ - SideEffect { value: 0, counter: 0 }, - SideEffect { value: 9123, counter: 1 } - ] - ); - - builder.failed(); - } - - #[test(should_fail_with="Private kernel circuit can only execute a private function")] - fn private_function_is_private_false_fails() { - let mut builder = PrivateKernelInitInputsBuilder::new(); - - // Set is_private in function data to false. - builder.private_call.function_data.is_private = false; - - builder.failed(); - } - - #[test(should_fail_with="Users cannot make a static call")] - fn private_function_static_call_fails() { - let mut builder = PrivateKernelInitInputsBuilder::new(); - - // Set is_static_call to true. - builder.private_call.public_inputs.call_context.is_static_call = true; - - builder.failed(); - } - - #[test(should_fail_with="Users cannot make a delegatecall")] - fn private_function_delegate_call_fails() { - let mut builder = PrivateKernelInitInputsBuilder::new(); - - // Set is_delegate_call to true. - builder.private_call.public_inputs.call_context.is_delegate_call = true; - - builder.failed(); - } - - #[test(should_fail_with="Storage contract address must be that of the called contract")] - fn private_function_incorrect_storage_contract_address_fails() { - let mut builder = PrivateKernelInitInputsBuilder::new(); - - // Set the storage_contract_address to a random scalar. - builder.private_call.public_inputs.call_context.storage_contract_address = AztecAddress::from_field(356); - - builder.failed(); - } - - #[test(should_fail_with="computed contract address does not match expected one")] - fn private_function_incorrect_function_leaf_index_fails() { - let mut builder = PrivateKernelInitInputsBuilder::new(); - - // Set the leaf index of the function leaf to a wrong value (the correct value + 1). - let leaf_index = builder.private_call.function_leaf_membership_witness.leaf_index; - builder.private_call.function_leaf_membership_witness.leaf_index = leaf_index + 1; - - builder.failed(); - } - - #[test(should_fail_with="computed contract address does not match expected one")] - fn private_function_incorrect_function_leaf_sibling_path_fails() { - let mut builder = PrivateKernelInitInputsBuilder::new(); - - // Set the first value of the sibling path to a wrong value (the correct value + 1). - let sibling_path_0 = builder.private_call.function_leaf_membership_witness.sibling_path[0]; - builder.private_call.function_leaf_membership_witness.sibling_path[0] = sibling_path_0 + 1; - - builder.failed(); - } - - #[test(should_fail_with="computed contract address does not match expected one")] - fn private_function_incorrect_contract_class_preimage_fails() { - let mut builder = PrivateKernelInitInputsBuilder::new(); - builder.private_call.contract_class_artifact_hash = builder.private_call.contract_class_artifact_hash + 1; - builder.failed(); - } - - #[test(should_fail_with="computed contract address does not match expected one")] - fn private_function_incorrect_partial_address_preimage_fails() { - let mut builder = PrivateKernelInitInputsBuilder::new(); - builder.private_call.salted_initialization_hash.inner = builder.private_call.salted_initialization_hash.inner + 1; - builder.failed(); - } - - #[test(should_fail_with="computed contract address does not match expected one")] - fn private_function_incorrect_address_preimage_fails() { - let mut builder = PrivateKernelInitInputsBuilder::new(); - builder.private_call.public_keys_hash.inner = builder.private_call.public_keys_hash.inner + 1; - builder.failed(); - } - #[test] fn default_max_block_number() { let mut builder = PrivateKernelInitInputsBuilder::new(); diff --git a/noir-projects/noir-protocol-circuits/crates/private-kernel-lib/src/private_kernel_inner.nr b/noir-projects/noir-protocol-circuits/crates/private-kernel-lib/src/private_kernel_inner.nr index 07eabb0f6e0..20376e9b83a 100644 --- a/noir-projects/noir-protocol-circuits/crates/private-kernel-lib/src/private_kernel_inner.nr +++ b/noir-projects/noir-protocol-circuits/crates/private-kernel-lib/src/private_kernel_inner.nr @@ -1,10 +1,12 @@ -use crate::{common, private_kernel_circuit_public_inputs_composer::PrivateKernelCircuitPublicInputsComposer}; +use crate::{ + private_call_data_validator::PrivateCallDataValidator, + private_kernel_circuit_public_inputs_composer::PrivateKernelCircuitPublicInputsComposer +}; use dep::types::{ abis::{ private_kernel_data::{PrivateKernelData, verify_previous_kernel_proof}, private_kernel::private_call_data::{PrivateCallData, verify_private_call}, - kernel_circuit_public_inputs::{PrivateKernelCircuitPublicInputs, PrivateKernelCircuitPublicInputsBuilder}, - side_effect::SideEffect + kernel_circuit_public_inputs::PrivateKernelCircuitPublicInputs }, constants::MAX_NEW_NOTE_HASHES_PER_CALL, utils::arrays::array_length }; @@ -20,36 +22,24 @@ struct PrivateKernelInnerCircuitPrivateInputs { } impl PrivateKernelInnerCircuitPrivateInputs { - fn validate_inputs(self) { - let this_call_stack_item = self.private_call.call_stack_item; - let function_data = this_call_stack_item.function_data; - assert(function_data.is_private, "Private kernel circuit can only execute a private function"); - } - pub fn native_private_kernel_circuit_inner(self) -> PrivateKernelCircuitPublicInputs { - let private_call_public_inputs = self.private_call.call_stack_item.public_inputs; - let previous_kernel_public_inputs = self.previous_kernel.public_inputs; - // verify/aggregate the private call proof verify_private_call(self.private_call); // verify/aggregate the previous kernel verify_previous_kernel_proof(self.previous_kernel); - common::validate_previous_kernel_values(previous_kernel_public_inputs.end); + let privateCallDataValidator = PrivateCallDataValidator::new(self.private_call); + privateCallDataValidator.validate(); - self.validate_inputs(); - - common::validate_private_call_data(self.private_call); - - let mut private_call_stack = previous_kernel_public_inputs.end.private_call_stack; + let private_call_stack = self.previous_kernel.public_inputs.end.private_call_stack; // TODO: Should be a hint from private inputs. let private_call_stack_size = array_length(private_call_stack); let call_request = private_call_stack[private_call_stack_size - 1]; - common::validate_call_against_request(self.private_call, call_request); + privateCallDataValidator.validate_against_call_request(call_request); PrivateKernelCircuitPublicInputsComposer::new_from_previous_kernel(self.previous_kernel.public_inputs).compose( - private_call_public_inputs, + self.private_call.call_stack_item.public_inputs, self.hints.note_hash_nullifier_counters, self.private_call.private_call_stack, self.private_call.public_call_stack, @@ -122,357 +112,6 @@ mod tests { } } - #[test(should_fail_with = "contract address cannot be zero")] - fn private_function_zero_storage_contract_address_fails() { - let mut builder = PrivateKernelInnerInputsBuilder::new(); - - // Set (storage) contract_address to 0 - builder.private_call.contract_address = AztecAddress::zero(); - builder.private_call.public_inputs.call_context.storage_contract_address = AztecAddress::zero(); - - builder.failed(); - } - - #[test(should_fail_with="computed contract address does not match expected one")] - fn private_function_incorrect_function_leaf_index_fails() { - let mut builder = PrivateKernelInnerInputsBuilder::new(); - - // Set the leaf index of the function leaf to a wrong value (the correct value + 1). - let leaf_index = builder.private_call.function_leaf_membership_witness.leaf_index; - builder.private_call.function_leaf_membership_witness.leaf_index = leaf_index + 1; - - builder.failed(); - } - - #[test(should_fail_with="computed contract address does not match expected one")] - fn private_function_incorrect_function_leaf_sibling_path_fails() { - let mut builder = PrivateKernelInnerInputsBuilder::new(); - - // Set the first value of the sibling path to a wrong value (the correct value + 1). - let sibling_path_0 = builder.private_call.function_leaf_membership_witness.sibling_path[0]; - builder.private_call.function_leaf_membership_witness.sibling_path[0] = sibling_path_0 + 1; - - builder.failed(); - } - - #[test(should_fail_with="computed contract address does not match expected one")] - fn private_function_incorrect_contract_class_preimage_fails() { - let mut builder = PrivateKernelInnerInputsBuilder::new(); - builder.private_call.contract_class_artifact_hash = builder.private_call.contract_class_artifact_hash + 1; - builder.failed(); - } - - #[test(should_fail_with="computed contract address does not match expected one")] - fn private_function_incorrect_partial_address_preimage_fails() { - let mut builder = PrivateKernelInnerInputsBuilder::new(); - builder.private_call.salted_initialization_hash.inner = builder.private_call.salted_initialization_hash.inner + 1; - builder.failed(); - } - - #[test(should_fail_with="computed contract address does not match expected one")] - fn private_function_incorrect_address_preimage_fails() { - let mut builder = PrivateKernelInnerInputsBuilder::new(); - builder.private_call.public_keys_hash.inner = builder.private_call.public_keys_hash.inner + 1; - builder.failed(); - } - - #[test(should_fail_with = "calculated private_call_hash does not match provided private_call_hash at the top of the call stack")] - fn private_function_incorrect_call_stack_item_hash_fails() { - let mut builder = PrivateKernelInnerInputsBuilder::new(); - - let private_call = builder.private_call.finish(); - let hash = private_call.call_stack_item.hash(); - // Set the first call stack hash to a wrong value (the correct hash + 1). - builder.previous_kernel.push_private_call_request(hash + 1, false); - let previous_kernel = builder.previous_kernel.to_private_kernel_data(); - - let kernel = PrivateKernelInnerCircuitPrivateInputs { previous_kernel, private_call, hints: builder.hints }; - - let _ = kernel.native_private_kernel_circuit_inner(); - } - - #[test(should_fail_with="call stack msg_sender does not match caller contract address")] - fn incorrect_msg_sender_for_regular_calls_fails() { - let mut builder = PrivateKernelInnerInputsBuilder::new(); - - // Set the msg_sender to a wrong value. - builder.private_call.public_inputs.call_context.msg_sender.inner += 1; - - builder.failed(); - } - - #[test(should_fail_with="call stack storage address does not match expected contract address")] - fn incorrect_storage_contract_for_regular_calls_fails() { - let mut builder = PrivateKernelInnerInputsBuilder::new(); - - // Set the storage contract address to a wrong value. - builder.private_call.public_inputs.call_context.storage_contract_address.inner += 1; - - builder.failed(); - } - - #[test] - fn delegate_call_succeeds() { - let mut builder = PrivateKernelInnerInputsBuilder::new().is_delegate_call(); - builder.succeeded(); - } - - #[test(should_fail_with="caller context cannot be empty for delegate calls")] - fn empty_caller_context_for_delegate_calls_fails() { - let mut builder = PrivateKernelInnerInputsBuilder::new().is_delegate_call(); - - let private_call = builder.private_call.finish(); - let hash = private_call.call_stack_item.hash(); - // Caller context is empty for regular calls. - let is_delegate_call = false; - builder.previous_kernel.push_private_call_request(hash, is_delegate_call); - let previous_kernel = builder.previous_kernel.to_private_kernel_data(); - - let kernel = PrivateKernelInnerCircuitPrivateInputs { previous_kernel, private_call, hints: builder.hints }; - - let _ = kernel.native_private_kernel_circuit_inner(); - } - - #[test(should_fail_with="call stack msg_sender does not match expected msg_sender for delegate calls")] - fn incorrect_msg_sender_for_delegate_calls_fails() { - let mut builder = PrivateKernelInnerInputsBuilder::new().is_delegate_call(); - - // Set the msg_sender to be the caller contract. - builder.private_call.public_inputs.call_context.msg_sender = builder.previous_kernel.contract_address; - - builder.failed(); - } - - #[test(should_fail_with="call stack storage address does not match expected contract address for delegate calls")] - fn incorrect_storage_address_for_delegate_calls_fails() { - let mut builder = PrivateKernelInnerInputsBuilder::new().is_delegate_call(); - - // Set the storage contract address to be the contract address. - builder.private_call.public_inputs.call_context.storage_contract_address = builder.private_call.contract_address; - - builder.failed(); - } - - #[test(should_fail_with="curent contract address must not match storage contract address for delegate calls")] - fn incorrect_storage_contract_for_delegate_calls_fails() { - let mut builder = PrivateKernelInnerInputsBuilder::new().is_delegate_call(); - - // Change the storage contract address to be the same as the contract address. - builder.private_call.public_inputs.call_context.storage_contract_address = builder.private_call.contract_address; - - let private_call = builder.private_call.finish(); - let hash = private_call.call_stack_item.hash(); - builder.previous_kernel.push_private_call_request(hash, true); - let mut call_request = builder.previous_kernel.private_call_stack.pop(); - // Change the caller's storage contract address to be the same as the contract address. - call_request.caller_context.storage_contract_address = builder.private_call.contract_address; - builder.previous_kernel.private_call_stack.push(call_request); - - let previous_kernel = builder.previous_kernel.to_private_kernel_data(); - let kernel = PrivateKernelInnerCircuitPrivateInputs { previous_kernel, private_call, hints: builder.hints }; - let _ = kernel.native_private_kernel_circuit_inner(); - } - - #[test] - fn call_requests_succeeds() { - let mut builder = PrivateKernelInnerInputsBuilder::new(); - - builder.private_call.append_private_call_requests(2, false); - builder.private_call.append_private_call_requests(1, true); - builder.private_call.append_public_call_requests(1, false); - builder.private_call.append_public_call_requests(2, true); - - builder.succeeded(); - } - - #[test(should_fail_with = "call requests length does not match the expected length")] - fn incorrect_private_call_requests_length_fails() { - let mut builder = PrivateKernelInnerInputsBuilder::new(); - - builder.private_call.append_private_call_requests(2, false); - // Remove one call stack item hash. - let _ = builder.private_call.public_inputs.private_call_stack_hashes.pop(); - - builder.failed(); - } - - #[test(should_fail_with = "call requests length does not match the expected length")] - fn incorrect_public_call_requests_length_fails() { - let mut builder = PrivateKernelInnerInputsBuilder::new(); - - builder.private_call.append_public_call_requests(2, false); - // Remove one call stack item hash. - let _ = builder.private_call.public_inputs.public_call_stack_hashes.pop(); - - builder.failed(); - } - - #[test(should_fail_with = "call stack hash does not match call request hash")] - fn incorrect_private_call_request_hash_fails() { - let mut builder = PrivateKernelInnerInputsBuilder::new(); - - builder.private_call.append_private_call_requests(2, false); - let mut call_request = builder.private_call.private_call_stack.pop(); - // Change the hash to be a different value. - call_request.hash += 1; - builder.private_call.private_call_stack.push(call_request); - - builder.failed(); - } - - #[test(should_fail_with = "call stack hash does not match call request hash")] - fn incorrect_public_call_request_hash_fails() { - let mut builder = PrivateKernelInnerInputsBuilder::new(); - - builder.private_call.append_public_call_requests(2, false); - let mut call_request = builder.private_call.public_call_stack.pop(); - // Change the hash to be a different value. - call_request.hash += 1; - builder.private_call.public_call_stack.push(call_request); - - builder.failed(); - } - - #[test(should_fail_with = "invalid caller")] - fn incorrect_caller_address_for_private_call_request_fails() { - let mut builder = PrivateKernelInnerInputsBuilder::new(); - - builder.private_call.append_private_call_requests(1, false); - let mut call_request = builder.private_call.private_call_stack.pop(); - // Change the caller contract address to be a different value. - call_request.caller_contract_address.inner += 1; - builder.private_call.private_call_stack.push(call_request); - - builder.failed(); - } - - #[test(should_fail_with = "invalid caller")] - fn incorrect_caller_address_for_public_call_request_fails() { - let mut builder = PrivateKernelInnerInputsBuilder::new(); - - builder.private_call.append_public_call_requests(1, false); - let mut call_request = builder.private_call.public_call_stack.pop(); - // Change the caller contract address to be a different value. - call_request.caller_contract_address.inner += 1; - builder.private_call.public_call_stack.push(call_request); - - builder.failed(); - } - - #[test(should_fail_with = "invalid caller")] - fn incorrect_caller_context_for_private_delegate_call_request_fails() { - let mut builder = PrivateKernelInnerInputsBuilder::new(); - - builder.private_call.append_private_call_requests(1, true); - let mut call_request = builder.private_call.private_call_stack.pop(); - // Change the storage contract to be a different value. - call_request.caller_context.storage_contract_address.inner += 1; - builder.private_call.private_call_stack.push(call_request); - - builder.failed(); - } - - #[test(should_fail_with = "invalid caller")] - fn incorrect_caller_context_for_public_delegate_call_request_fails() { - let mut builder = PrivateKernelInnerInputsBuilder::new(); - - builder.private_call.append_public_call_requests(1, true); - let mut call_request = builder.private_call.public_call_stack.pop(); - // Change the storage contract to be a different value. - call_request.caller_context.storage_contract_address.inner += 1; - builder.private_call.public_call_stack.push(call_request); - - builder.failed(); - } - - #[test(should_fail_with = "invalid array")] - fn input_validation_malformed_arrays_read_requests() { - let mut builder = PrivateKernelInnerInputsBuilder::new(); - - builder.private_call.public_inputs.note_hash_read_requests.extend_from_array( - [ - ReadRequest::empty(), - ReadRequest { value: 9123, counter: 1 } - ] - ); - - builder.failed(); - } - - #[test(should_fail_with = "invalid array")] - fn input_validation_malformed_arrays_note_hashes() { - let mut builder = PrivateKernelInnerInputsBuilder::new(); - - builder.private_call.public_inputs.new_note_hashes.extend_from_array( - [ - NoteHash { value: 0, counter: 0 }, - NoteHash { value: 9123, counter: 1 } - ] - ); - - builder.failed(); - } - - #[test(should_fail_with = "invalid array")] - fn input_validation_malformed_arrays_nullifiers() { - let mut builder = PrivateKernelInnerInputsBuilder::new(); - - builder.private_call.public_inputs.new_nullifiers.extend_from_array( - [ - Nullifier { value: 0, note_hash: 0, counter: 0 }, - Nullifier { value: 12, note_hash: 0, counter: 1 } - ] - ); - - builder.failed(); - } - - #[test(should_fail_with = "invalid array")] - fn input_validation_malformed_arrays_private_call_stack() { - let mut builder = PrivateKernelInnerInputsBuilder::new(); - - builder.private_call.public_inputs.private_call_stack_hashes.extend_from_array([0, 888]); - - builder.failed(); - } - - #[test(should_fail_with = "invalid array")] - fn input_validation_malformed_arrays_public_call_stack() { - let mut builder = PrivateKernelInnerInputsBuilder::new(); - - builder.private_call.public_inputs.public_call_stack_hashes.extend_from_array([0, 888]); - - builder.failed(); - } - - #[test(should_fail_with = "invalid array")] - fn input_validation_malformed_arrays_new_l2_to_l1_msgs() { - let mut builder = PrivateKernelInnerInputsBuilder::new(); - - builder.private_call.public_inputs.new_l2_to_l1_msgs.extend_from_array( - [ - L2ToL1Message::empty(), - L2ToL1Message { recipient: EthAddress::from_field(6), content: 888, counter: 0 } - ] - ); - - builder.failed(); - } - - #[test(should_fail_with = "invalid array")] - fn input_validation_malformed_arrays_logs() { - let mut builder = PrivateKernelInnerInputsBuilder::new(); - - builder.private_call.public_inputs.encrypted_logs_hashes.extend_from_array( - [ - SideEffect { value: 0, counter: 0 }, - SideEffect { value: 9123, counter: 1 } - ] - ); - - builder.failed(); - } - #[test(should_fail_with = "push out of bounds")] fn private_kernel_should_fail_if_aggregating_too_many_note_hashes() { let mut builder = PrivateKernelInnerInputsBuilder::new(); @@ -514,15 +153,6 @@ mod tests { builder.failed(); } - #[test(should_fail_with="Private kernel circuit can only execute a private function")] - fn private_function_is_private_false_fails() { - let mut builder = PrivateKernelInnerInputsBuilder::new(); - - builder.private_call.function_data.is_private = false; - - builder.failed(); - } - #[test] fn propagate_previous_kernel_max_block_number() { let mut builder = PrivateKernelInnerInputsBuilder::new(); @@ -618,31 +248,4 @@ mod tests { assert_eq(public_inputs.end.encrypted_logs_hashes[1].value, encrypted_logs_hash); assert_eq(public_inputs.end.unencrypted_logs_hashes[1].value, unencrypted_logs_hash); } - - #[test(should_fail_with="new_note_hashes must be empty for static calls")] - fn creating_new_note_hashes_on_static_call_fails() { - let mut builder = PrivateKernelInnerInputsBuilder::new().is_static_call(); - - builder.private_call.public_inputs.new_note_hashes.push(NoteHash { value: 1, counter: 0 }); - - builder.failed(); - } - - #[test(should_fail_with="new_nullifiers must be empty for static calls")] - fn creating_new_nullifiers_on_static_call_fails() { - let mut builder = PrivateKernelInnerInputsBuilder::new().is_static_call(); - - builder.private_call.public_inputs.new_nullifiers.push(Nullifier { value: 1, note_hash: 0, counter: 0 }); - - builder.failed(); - } - - #[test(should_fail_with="The 0th nullifier in the accumulated nullifier array is zero")] - fn zero_0th_nullifier_fails() { - let mut builder = PrivateKernelInnerInputsBuilder::new(); - - builder.previous_kernel.new_nullifiers = BoundedVec::new(); - - builder.failed(); - } } diff --git a/noir-projects/noir-protocol-circuits/crates/private-kernel-lib/src/private_kernel_tail.nr b/noir-projects/noir-protocol-circuits/crates/private-kernel-lib/src/private_kernel_tail.nr index 98b6b9c08fa..fabb95e9911 100644 --- a/noir-projects/noir-protocol-circuits/crates/private-kernel-lib/src/private_kernel_tail.nr +++ b/noir-projects/noir-protocol-circuits/crates/private-kernel-lib/src/private_kernel_tail.nr @@ -581,13 +581,6 @@ mod tests { builder.failed(); } - #[test(should_fail_with="The 0th nullifier in the accumulated nullifier array is zero")] - unconstrained fn zero_0th_nullifier_fails() { - let mut builder = PrivateKernelTailInputsBuilder::new(); - builder.previous_kernel.new_nullifiers = BoundedVec::new(); - builder.failed(); - } - #[test] unconstrained fn empty_tx_consumes_teardown_limits_plus_fixed_gas() { let mut builder = PrivateKernelTailInputsBuilder::new(); diff --git a/noir-projects/noir-protocol-circuits/crates/private-kernel-lib/src/private_kernel_tail_to_public.nr b/noir-projects/noir-protocol-circuits/crates/private-kernel-lib/src/private_kernel_tail_to_public.nr index 42309b81435..fa73aaf18ae 100644 --- a/noir-projects/noir-protocol-circuits/crates/private-kernel-lib/src/private_kernel_tail_to_public.nr +++ b/noir-projects/noir-protocol-circuits/crates/private-kernel-lib/src/private_kernel_tail_to_public.nr @@ -483,13 +483,6 @@ mod tests { builder.succeeded(); } - #[test(should_fail_with="The 0th nullifier in the accumulated nullifier array is zero")] - unconstrained fn zero_0th_nullifier_fails() { - let mut builder = PrivateKernelTailToPublicInputsBuilder::new(); - builder.previous_kernel.new_nullifiers = BoundedVec::new(); - builder.failed(); - } - #[test] unconstrained fn split_nullifiers_into_non_revertible() { let mut builder = PrivateKernelTailToPublicInputsBuilder::new(); diff --git a/noir-projects/noir-protocol-circuits/crates/private-kernel-lib/src/tests.nr b/noir-projects/noir-protocol-circuits/crates/private-kernel-lib/src/tests.nr new file mode 100644 index 00000000000..e92b7e1016a --- /dev/null +++ b/noir-projects/noir-protocol-circuits/crates/private-kernel-lib/src/tests.nr @@ -0,0 +1 @@ +mod private_call_data_validator_builder; diff --git a/noir-projects/noir-protocol-circuits/crates/private-kernel-lib/src/tests/private_call_data_validator_builder.nr b/noir-projects/noir-protocol-circuits/crates/private-kernel-lib/src/tests/private_call_data_validator_builder.nr new file mode 100644 index 00000000000..8802b64d81e --- /dev/null +++ b/noir-projects/noir-protocol-circuits/crates/private-kernel-lib/src/tests/private_call_data_validator_builder.nr @@ -0,0 +1,790 @@ +use crate::private_call_data_validator::PrivateCallDataValidator; +use dep::types::{ + abis::{ + call_request::CallRequest, caller_context::CallerContext, note_hash::NoteHash, nullifier::Nullifier, + nullifier_key_validation_request::NullifierKeyValidationRequest, read_request::ReadRequest, + side_effect::SideEffect +}, + address::{AztecAddress, EthAddress}, grumpkin_point::GrumpkinPoint, + messaging::l2_to_l1_message::L2ToL1Message, + tests::private_call_data_builder::PrivateCallDataBuilder, transaction::tx_request::TxRequest +}; + +struct PrivateCallDataValidatorBuilder { + private_call: PrivateCallDataBuilder, +} + +impl PrivateCallDataValidatorBuilder { + pub fn new() -> Self { + let private_call = PrivateCallDataBuilder::new(); + PrivateCallDataValidatorBuilder { private_call } + } + + pub fn is_delegate_call(&mut self) -> Self { + let _ = self.private_call.is_delegate_call(); + *self + } + + pub fn is_static_call(&mut self) -> Self { + let _ = self.private_call.is_static_call(); + *self + } + + pub fn validate(self) { + let private_call = self.private_call.finish(); + PrivateCallDataValidator::new(private_call).validate(); + } + + pub fn validate_against_tx_request(self, request: TxRequest) { + let private_call = self.private_call.finish(); + PrivateCallDataValidator::new(private_call).validate_against_tx_request(request); + } + + pub fn validate_against_call_request(self, request: CallRequest) { + let private_call = self.private_call.finish(); + PrivateCallDataValidator::new(private_call).validate_against_call_request(request); + } +} + +/** + * validate_arrays + */ + +#[test(should_fail_with="invalid array")] +fn validate_arrays_malformed_note_hash_read_requests_fails() { + let mut builder = PrivateCallDataValidatorBuilder::new(); + + builder.private_call.public_inputs.note_hash_read_requests.extend_from_array( + [ + ReadRequest::empty(), + ReadRequest { value: 9123, counter: 1 } + ] + ); + + builder.validate(); +} + +#[test(should_fail_with="invalid array")] +fn validate_arrays_malformed_nullifier_read_requests_fails() { + let mut builder = PrivateCallDataValidatorBuilder::new(); + + builder.private_call.public_inputs.nullifier_read_requests.extend_from_array( + [ + ReadRequest::empty(), + ReadRequest { value: 9123, counter: 1 } + ] + ); + + builder.validate(); +} + +// Enable this test if MAX_NULLIFIER_KEY_VALIDATION_REQUESTS_PER_CALL is greater than 1. +// #[test(should_fail_with="invalid array")] +// fn validate_arrays_malformed_nullifier_key_validation_requests_fails() { +// let mut builder = PrivateCallDataValidatorBuilder::new(); + +// builder.private_call.public_inputs.nullifier_key_validation_requests.extend_from_array( +// [ +// NullifierKeyValidationRequest::empty(), +// NullifierKeyValidationRequest { master_nullifier_public_key: GrumpkinPoint { x: 12, y: 34 }, app_nullifier_secret_key: 5 } +// ] +// ); + +// builder.validate(); +// } + +#[test(should_fail_with="invalid array")] +fn validate_arrays_malformed_note_hashes_fails() { + let mut builder = PrivateCallDataValidatorBuilder::new(); + + builder.private_call.public_inputs.new_note_hashes.extend_from_array( + [ + NoteHash::empty(), + NoteHash { value: 9123, counter: 1 } + ] + ); + + builder.validate(); +} + +#[test(should_fail_with="invalid array")] +fn validate_arrays_malformed_nullifiers_fails() { + let mut builder = PrivateCallDataValidatorBuilder::new(); + + builder.private_call.public_inputs.new_nullifiers.extend_from_array( + [ + Nullifier::empty(), + Nullifier { value: 9123, note_hash: 0, counter: 1 } + ] + ); + + builder.validate(); +} + +#[test(should_fail_with="invalid array")] +fn validate_arrays_malformed_l2_to_l1_msgs_fails() { + let mut builder = PrivateCallDataValidatorBuilder::new(); + + builder.private_call.public_inputs.new_l2_to_l1_msgs.extend_from_array( + [ + L2ToL1Message::empty(), + L2ToL1Message { recipient: EthAddress::from_field(6), content: 9123, counter: 0 } + ] + ); + + builder.validate(); +} + +#[test(should_fail_with="invalid array")] +fn validate_arrays_malformed_private_call_stack_fails() { + let mut builder = PrivateCallDataValidatorBuilder::new(); + + builder.private_call.public_inputs.private_call_stack_hashes.extend_from_array([0, 9123]); + + builder.validate(); +} + +#[test(should_fail_with="invalid array")] +fn validate_arrays_malformed_public_call_stack_fails() { + let mut builder = PrivateCallDataValidatorBuilder::new(); + + builder.private_call.public_inputs.public_call_stack_hashes.extend_from_array([0, 9123]); + + builder.validate(); +} + +#[test(should_fail_with="invalid array")] +fn validate_arrays_malformed_encrypted_logs_hashes_fails() { + let mut builder = PrivateCallDataValidatorBuilder::new(); + + builder.private_call.public_inputs.encrypted_logs_hashes.extend_from_array( + [ + SideEffect { value: 0, counter: 0 }, + SideEffect { value: 9123, counter: 1 } + ] + ); + + builder.validate(); +} + +#[test(should_fail_with="invalid array")] +fn validate_arrays_malformed_unencrypted_logs_hashes_fails() { + let mut builder = PrivateCallDataValidatorBuilder::new(); + + builder.private_call.public_inputs.unencrypted_logs_hashes.extend_from_array( + [ + SideEffect { value: 0, counter: 0 }, + SideEffect { value: 9123, counter: 1 } + ] + ); + + builder.validate(); +} + +/** + * validate_contract_address + */ + +#[test(should_fail_with="contract address cannot be zero")] +fn validate_contract_address_zero_storage_contract_address_fails() { + let mut builder = PrivateCallDataValidatorBuilder::new(); + + // Set (storage) contract_address to 0 + builder.private_call.contract_address = AztecAddress::zero(); + builder.private_call.public_inputs.call_context.storage_contract_address = AztecAddress::zero(); + + builder.validate(); +} + +#[test(should_fail_with="computed contract address does not match expected one")] +fn validate_contract_address_incorrect_function_leaf_index_fails() { + let mut builder = PrivateCallDataValidatorBuilder::new(); + + // Set the leaf index of the function leaf to a wrong value (the correct value + 1). + let leaf_index = builder.private_call.function_leaf_membership_witness.leaf_index; + builder.private_call.function_leaf_membership_witness.leaf_index = leaf_index + 1; + + builder.validate(); +} + +#[test(should_fail_with="computed contract address does not match expected one")] +fn validate_contract_address_incorrect_function_leaf_sibling_path_fails() { + let mut builder = PrivateCallDataValidatorBuilder::new(); + + // Set the first value of the sibling path to a wrong value (the correct value + 1). + let sibling_path_0 = builder.private_call.function_leaf_membership_witness.sibling_path[0]; + builder.private_call.function_leaf_membership_witness.sibling_path[0] = sibling_path_0 + 1; + + builder.validate(); +} + +#[test(should_fail_with="computed contract address does not match expected one")] +fn validate_contract_address_incorrect_contract_class_preimage_fails() { + let mut builder = PrivateCallDataValidatorBuilder::new(); + + builder.private_call.contract_class_artifact_hash = builder.private_call.contract_class_artifact_hash + 1; + + builder.validate(); +} + +#[test(should_fail_with="computed contract address does not match expected one")] +fn validate_contract_address_incorrect_partial_address_preimage_fails() { + let mut builder = PrivateCallDataValidatorBuilder::new(); + + builder.private_call.salted_initialization_hash.inner = builder.private_call.salted_initialization_hash.inner + 1; + + builder.validate(); +} + +#[test(should_fail_with="computed contract address does not match expected one")] +fn validate_contract_address_incorrect_address_preimage_fails() { + let mut builder = PrivateCallDataValidatorBuilder::new(); + + builder.private_call.public_keys_hash.inner = builder.private_call.public_keys_hash.inner + 1; + + builder.validate(); +} + +/** + * validate_call + */ + +#[test] +fn validate_call_is_regular_succeeds() { + let builder = PrivateCallDataValidatorBuilder::new(); + builder.validate(); +} + +#[test(should_fail_with="call stack storage address does not match expected contract address")] +fn validate_call_is_regular_mismatch_storage_contract_fails() { + let mut builder = PrivateCallDataValidatorBuilder::new(); + + // Change the storage contract address to be a different value. + builder.private_call.public_inputs.call_context.storage_contract_address.inner += 1; + + builder.validate(); +} + +#[test] +fn validate_call_is_delegate_succeeds() { + let builder = PrivateCallDataValidatorBuilder::new().is_delegate_call(); + builder.validate(); +} + +#[test(should_fail_with="current contract address must not match storage contract address for delegate calls")] +fn validate_call_is_delegate_call_from_same_contract_fails() { + let mut builder = PrivateCallDataValidatorBuilder::new().is_delegate_call(); + + // Change the caller's storage contract address to be the same as the contract address. + builder.private_call.public_inputs.call_context.storage_contract_address = builder.private_call.contract_address; + + builder.validate(); +} + +#[test] +fn validate_call_is_static_succeeds() { + let mut builder = PrivateCallDataValidatorBuilder::new().is_static_call(); + builder.validate(); +} + +#[test(should_fail_with="call stack storage address does not match expected contract address")] +fn validate_call_is_static_mismatch_storage_contract_fails() { + let mut builder = PrivateCallDataValidatorBuilder::new().is_static_call(); + + // Change the storage contract address to be a different value. + builder.private_call.public_inputs.call_context.storage_contract_address.inner += 1; + + builder.validate(); +} + +#[test(should_fail_with="new_note_hashes must be empty for static calls")] +fn validate_call_is_static_creating_note_hashes_fails() { + let mut builder = PrivateCallDataValidatorBuilder::new().is_static_call(); + + builder.private_call.public_inputs.new_note_hashes.push(NoteHash { value: 1, counter: 0 }); + + builder.validate(); +} + +#[test(should_fail_with="new_nullifiers must be empty for static calls")] +fn validate_call_is_static_creating_nullifiers_fails() { + let mut builder = PrivateCallDataValidatorBuilder::new().is_static_call(); + + builder.private_call.public_inputs.new_nullifiers.push(Nullifier { value: 1, counter: 0, note_hash: 0 }); + + builder.validate(); +} + +#[test(should_fail_with="new_l2_to_l1_msgs must be empty for static calls")] +fn validate_call_is_static_creating_l2_to_l1_msgs_fails() { + let mut builder = PrivateCallDataValidatorBuilder::new().is_static_call(); + + builder.private_call.public_inputs.new_l2_to_l1_msgs.push(L2ToL1Message { recipient: EthAddress::from_field(6), content: 9123, counter: 0 }); + + builder.validate(); +} + +#[test(should_fail_with="encrypted_logs_hashes must be empty for static calls")] +fn validate_call_is_static_creating_encrypted_logs_hashes_fails() { + let mut builder = PrivateCallDataValidatorBuilder::new().is_static_call(); + + builder.private_call.public_inputs.encrypted_logs_hashes.push(SideEffect { value: 9123, counter: 1 }); + + builder.validate(); +} + +#[test(should_fail_with="unencrypted_logs_hashes must be empty for static calls")] +fn validate_call_is_static_creating_unencrypted_logs_hashes_fails() { + let mut builder = PrivateCallDataValidatorBuilder::new().is_static_call(); + + builder.private_call.public_inputs.unencrypted_logs_hashes.push(SideEffect { value: 9123, counter: 1 }); + + builder.validate(); +} + +/** + * validate_private_call_requests + */ + +#[test] +fn validate_private_call_requests_succeeds() { + let mut builder = PrivateCallDataValidatorBuilder::new(); + + builder.private_call.append_private_call_requests(2, false); + + builder.validate(); +} + +#[test] +fn validate_private_call_requests_delegate_calls_succeeds() { + let mut builder = PrivateCallDataValidatorBuilder::new(); + + builder.private_call.append_private_call_requests(2, true); + + builder.validate(); +} + +#[test(should_fail_with="call stack hash does not match call request hash")] +fn validate_private_call_requests_incorrect_hash_fails() { + let mut builder = PrivateCallDataValidatorBuilder::new(); + + builder.private_call.append_private_call_requests(2, false); + let mut call_request = builder.private_call.private_call_stack.pop(); + // Change the hash to be a different value. + call_request.hash += 1; + builder.private_call.private_call_stack.push(call_request); + + builder.validate(); +} + +#[test(should_fail_with="invalid caller")] +fn validate_private_call_requests_incorrect_caller_address_fails() { + let mut builder = PrivateCallDataValidatorBuilder::new(); + + builder.private_call.append_private_call_requests(1, false); + let mut call_request = builder.private_call.private_call_stack.pop(); + // Change the caller contract address to be a different value. + call_request.caller_contract_address.inner += 1; + builder.private_call.private_call_stack.push(call_request); + + builder.validate(); +} + +#[test(should_fail_with="invalid caller")] +fn validate_private_call_requests_incorrect_caller_storage_contract_address_fails() { + let mut builder = PrivateCallDataValidatorBuilder::new(); + + builder.private_call.append_private_call_requests(1, true); + let mut call_request = builder.private_call.private_call_stack.pop(); + // Change the storage contract to be a different value. + call_request.caller_context.storage_contract_address.inner += 1; + builder.private_call.private_call_stack.push(call_request); + + builder.validate(); +} + +#[test(should_fail_with="invalid caller")] +fn validate_private_call_requests_incorrect_caller_msg_sender_fails() { + let mut builder = PrivateCallDataValidatorBuilder::new(); + + builder.private_call.append_private_call_requests(1, true); + let mut call_request = builder.private_call.private_call_stack.pop(); + // Change the msg_sender to be a different value. + call_request.caller_context.msg_sender.inner += 1; + builder.private_call.private_call_stack.push(call_request); + + builder.validate(); +} + +#[test(should_fail_with="call requests length does not match the expected length")] +fn validate_private_call_requests_fewer_hashes_fails() { + let mut builder = PrivateCallDataValidatorBuilder::new(); + + builder.private_call.append_private_call_requests(2, false); + // Remove one call stack item hash. + let _ = builder.private_call.public_inputs.private_call_stack_hashes.pop(); + + builder.validate(); +} + +#[test(should_fail_with="call stack hash does not match call request hash")] +fn validate_private_call_requests_more_hashes_fails() { + let mut builder = PrivateCallDataValidatorBuilder::new(); + + builder.private_call.append_private_call_requests(2, false); + // Add one random call stack item hash. + builder.private_call.public_inputs.private_call_stack_hashes.push(9123); + + builder.validate(); +} + +/** + * validate_public_call_requests + */ + +#[test] +fn validate_public_call_requests_succeeds() { + let mut builder = PrivateCallDataValidatorBuilder::new(); + + builder.private_call.append_public_call_requests(2, false); + + builder.validate(); +} + +#[test] +fn validate_public_call_requests_delegate_calls_succeeds() { + let mut builder = PrivateCallDataValidatorBuilder::new(); + + builder.private_call.append_public_call_requests(2, true); + + builder.validate(); +} + +#[test(should_fail_with="call stack hash does not match call request hash")] +fn validate_public_call_requests_incorrect_hash_fails() { + let mut builder = PrivateCallDataValidatorBuilder::new(); + + builder.private_call.append_public_call_requests(2, false); + let mut call_request = builder.private_call.public_call_stack.pop(); + // Change the hash to be a different value. + call_request.hash += 1; + builder.private_call.public_call_stack.push(call_request); + + builder.validate(); +} + +#[test(should_fail_with="invalid caller")] +fn validate_public_call_requests_incorrect_caller_address_fails() { + let mut builder = PrivateCallDataValidatorBuilder::new(); + + builder.private_call.append_public_call_requests(1, false); + let mut call_request = builder.private_call.public_call_stack.pop(); + // Change the caller contract address to be a different value. + call_request.caller_contract_address.inner += 1; + builder.private_call.public_call_stack.push(call_request); + + builder.validate(); +} + +#[test(should_fail_with="invalid caller")] +fn validate_public_call_requests_incorrect_caller_storage_contract_address_fails() { + let mut builder = PrivateCallDataValidatorBuilder::new(); + + builder.private_call.append_public_call_requests(1, true); + let mut call_request = builder.private_call.public_call_stack.pop(); + // Change the storage contract to be a different value. + call_request.caller_context.storage_contract_address.inner += 1; + builder.private_call.public_call_stack.push(call_request); + + builder.validate(); +} + +#[test(should_fail_with="invalid caller")] +fn validate_public_call_requests_incorrect_caller_msg_sender_fails() { + let mut builder = PrivateCallDataValidatorBuilder::new(); + + builder.private_call.append_public_call_requests(1, true); + let mut call_request = builder.private_call.public_call_stack.pop(); + // Change the msg_sender to be a different value. + call_request.caller_context.msg_sender.inner += 1; + builder.private_call.public_call_stack.push(call_request); + + builder.validate(); +} + +#[test(should_fail_with="call requests length does not match the expected length")] +fn validate_public_call_requests_fewer_hashes_fails() { + let mut builder = PrivateCallDataValidatorBuilder::new(); + + builder.private_call.append_public_call_requests(2, false); + // Remove one call stack item hash. + let _ = builder.private_call.public_inputs.public_call_stack_hashes.pop(); + + builder.validate(); +} + +#[test(should_fail_with="call stack hash does not match call request hash")] +fn validate_public_call_requests_more_hashes_fails() { + let mut builder = PrivateCallDataValidatorBuilder::new(); + + builder.private_call.append_public_call_requests(2, false); + // Add one random call stack item hash. + builder.private_call.public_inputs.public_call_stack_hashes.push(9123); + + builder.validate(); +} + +/** + * validate_teardown_call_request + */ + +#[test] +fn validate_teardown_call_request_succeeds() { + let mut builder = PrivateCallDataValidatorBuilder::new(); + + builder.private_call.add_teaddown_call_request(false); + + builder.validate(); +} + +#[test] +fn validate_teardown_call_request_delegate_calls_succeeds() { + let mut builder = PrivateCallDataValidatorBuilder::new(); + + builder.private_call.add_teaddown_call_request(true); + + builder.validate(); +} + +#[test(should_fail_with="call stack hash does not match call request hash")] +fn validate_teardown_call_request_incorrect_hash_fails() { + let mut builder = PrivateCallDataValidatorBuilder::new(); + + builder.private_call.add_teaddown_call_request(true); + // Change the hash to be a different value. + builder.private_call.public_teardown_call_request.hash += 1; + + builder.validate(); +} + +#[test(should_fail_with="invalid caller")] +fn validate_teardown_call_request_incorrect_caller_address_fails() { + let mut builder = PrivateCallDataValidatorBuilder::new(); + + builder.private_call.add_teaddown_call_request(true); + // Change the caller contract address to be a different value. + builder.private_call.public_teardown_call_request.caller_contract_address.inner += 1; + + builder.validate(); +} + +#[test(should_fail_with="invalid caller")] +fn validate_teardown_call_request_incorrect_caller_storage_contract_address_fails() { + let mut builder = PrivateCallDataValidatorBuilder::new(); + + builder.private_call.add_teaddown_call_request(true); + // Change the storage contract to be a different value. + builder.private_call.public_teardown_call_request.caller_context.storage_contract_address.inner += 1; + + builder.validate(); +} + +#[test(should_fail_with="invalid caller")] +fn validate_teardown_call_request_incorrect_caller_msg_sender_fails() { + let mut builder = PrivateCallDataValidatorBuilder::new(); + + builder.private_call.add_teaddown_call_request(true); + // Change the msg_sender to be a different value. + builder.private_call.public_teardown_call_request.caller_context.msg_sender.inner += 1; + + builder.validate(); +} + +#[test(should_fail_with="call requests length does not match the expected length")] +fn validate_teardown_call_request_fewer_hashes_fails() { + let mut builder = PrivateCallDataValidatorBuilder::new(); + + builder.private_call.add_teaddown_call_request(true); + // Remove the call stack item hash. + builder.private_call.public_inputs.public_teardown_function_hash = 0; + + builder.validate(); +} + +#[test(should_fail_with="call stack hash does not match call request hash")] +fn validate_teardown_call_request_more_hashes_fails() { + let mut builder = PrivateCallDataValidatorBuilder::new(); + + builder.private_call.add_teaddown_call_request(true); + // Remove the call request. + builder.private_call.public_teardown_call_request = CallRequest::empty(); + + builder.validate(); +} + +/** + * validate_against_tx_request + */ + +#[test] +fn validate_against_tx_request_succeeds() { + let builder = PrivateCallDataValidatorBuilder::new(); + + let request = builder.private_call.build_tx_request(); + + builder.validate_against_tx_request(request); +} + +#[test(should_fail_with="origin address does not match call stack items contract address")] +fn validate_against_tx_request_mismatch_contract_address_fails() { + let builder = PrivateCallDataValidatorBuilder::new(); + + let mut request = builder.private_call.build_tx_request(); + // Tweak the origin to be a different value. + request.origin.inner += 1; + + builder.validate_against_tx_request(request); +} + +#[test(should_fail_with="tx_request function_data must match call_stack_item function_data")] +fn validate_against_tx_request_mismatch_function_data_fails() { + let builder = PrivateCallDataValidatorBuilder::new(); + + let mut request = builder.private_call.build_tx_request(); + // Tweak the function selector to be a different value. + request.function_data.selector.inner += 1; + + builder.validate_against_tx_request(request); +} + +#[test(should_fail_with="noir function args passed to tx_request must match args in the call_stack_item")] +fn validate_against_tx_request_mismatch_args_hash_fails() { + let builder = PrivateCallDataValidatorBuilder::new(); + + let mut request = builder.private_call.build_tx_request(); + // Tweak the args hash to be a different value. + request.args_hash += 1; + + builder.validate_against_tx_request(request); +} + +#[test(should_fail_with="tx_context in tx_request must match tx_context in call_stack_item")] +fn validate_against_tx_request_mismatch_chain_id_fails() { + let builder = PrivateCallDataValidatorBuilder::new(); + + let mut request = builder.private_call.build_tx_request(); + // Tweak the chain id to be a different value. + request.tx_context.chain_id += 1; + + builder.validate_against_tx_request(request); +} + +#[test(should_fail_with="tx_context in tx_request must match tx_context in call_stack_item")] +fn validate_against_tx_request_mismatch_version_fails() { + let builder = PrivateCallDataValidatorBuilder::new(); + + let mut request = builder.private_call.build_tx_request(); + // Tweak the version to be a different value. + request.tx_context.version += 1; + + builder.validate_against_tx_request(request); +} + +#[test(should_fail_with="Users cannot make a static call")] +fn validate_against_tx_request_static_call_fails() { + let builder = PrivateCallDataValidatorBuilder::new().is_static_call(); + let request = builder.private_call.build_tx_request(); + builder.validate_against_tx_request(request); +} + +#[test(should_fail_with="Users cannot make a delegatecall")] +fn validate_against_tx_request_delegate_call_fails() { + let builder = PrivateCallDataValidatorBuilder::new().is_delegate_call(); + let request = builder.private_call.build_tx_request(); + builder.validate_against_tx_request(request); +} + +/** + * validate_against_call_request + */ + +#[test] +fn validate_against_call_request_succeeds() { + let builder = PrivateCallDataValidatorBuilder::new(); + + let request = builder.private_call.build_call_request(); + + builder.validate_against_call_request(request); +} + +#[test] +fn validate_against_call_request_delegate_call_succeeds() { + let builder = PrivateCallDataValidatorBuilder::new().is_delegate_call(); + + let request = builder.private_call.build_call_request(); + + builder.validate_against_call_request(request); +} + +#[test] +fn validate_against_call_request_static_call_succeeds() { + let builder = PrivateCallDataValidatorBuilder::new().is_static_call(); + + let request = builder.private_call.build_call_request(); + + builder.validate_against_call_request(request); +} + +#[test(should_fail_with="calculated private_call_hash does not match provided private_call_hash at the top of the call stack")] +fn validate_against_call_request_mismatch_hash_fails() { + let builder = PrivateCallDataValidatorBuilder::new(); + + let mut request = builder.private_call.build_call_request(); + // Tweak the hash to be a different value. + request.hash += 1; + + builder.validate_against_call_request(request); +} + +#[test(should_fail_with="caller context cannot be empty for delegate calls")] +fn validate_against_call_request_empty_caller_context_for_delegate_calls_fails() { + let builder = PrivateCallDataValidatorBuilder::new().is_delegate_call(); + + let mut request = builder.private_call.build_call_request(); + request.caller_context = CallerContext::empty(); + + builder.validate_against_call_request(request); +} + +#[test(should_fail_with="call stack msg_sender does not match expected msg_sender for delegate calls")] +fn validate_against_call_request_incorrect_msg_sender_for_delegate_call_fails() { + let builder = PrivateCallDataValidatorBuilder::new().is_delegate_call(); + + let mut request = builder.private_call.build_call_request(); + // Tweak the msg_sender to be a different value. + request.caller_context.msg_sender.inner += 1; + + builder.validate_against_call_request(request); +} + +#[test(should_fail_with="call stack storage address does not match expected contract address for delegate calls")] +fn validate_against_call_request_incorrect_storage_contract_address_for_delegate_call_fails() { + let builder = PrivateCallDataValidatorBuilder::new().is_delegate_call(); + + let mut request = builder.private_call.build_call_request(); + // Tweak the storage contract address to be a different value. + request.caller_context.storage_contract_address.inner += 1; + + builder.validate_against_call_request(request); +} + +#[test(should_fail_with="call stack msg_sender does not match caller contract address")] +fn validate_against_call_request_incorrect_msg_sender_for_regular_call_fails() { + let builder = PrivateCallDataValidatorBuilder::new(); + + let mut request = builder.private_call.build_call_request(); + // Tweak the caller's contract address to be a different value. + request.caller_contract_address.inner += 1; + + builder.validate_against_call_request(request); +} diff --git a/noir-projects/noir-protocol-circuits/crates/types/src/abis/call_request.nr b/noir-projects/noir-protocol-circuits/crates/types/src/abis/call_request.nr index 140b1967ca7..2c667c25f08 100644 --- a/noir-projects/noir-protocol-circuits/crates/types/src/abis/call_request.nr +++ b/noir-projects/noir-protocol-circuits/crates/types/src/abis/call_request.nr @@ -34,12 +34,6 @@ impl Empty for CallRequest { } } -impl CallRequest { - pub fn is_empty(self) -> bool { - self.hash == 0 - } -} - impl Serialize for CallRequest { fn serialize(self) -> [Field; CALL_REQUEST_LENGTH] { let mut fields: BoundedVec = BoundedVec::new(); diff --git a/noir-projects/noir-protocol-circuits/crates/types/src/abis/kernel_circuit_public_inputs/public_kernel_circuit_public_inputs.nr b/noir-projects/noir-protocol-circuits/crates/types/src/abis/kernel_circuit_public_inputs/public_kernel_circuit_public_inputs.nr index c385a96f25a..06722dfce8e 100644 --- a/noir-projects/noir-protocol-circuits/crates/types/src/abis/kernel_circuit_public_inputs/public_kernel_circuit_public_inputs.nr +++ b/noir-projects/noir-protocol-circuits/crates/types/src/abis/kernel_circuit_public_inputs/public_kernel_circuit_public_inputs.nr @@ -17,17 +17,17 @@ impl PublicKernelCircuitPublicInputs { pub fn needs_setup(self) -> bool { // public calls for setup are deposited in the non-revertible public call stack. // if an element is present, we need to run setup - !self.end_non_revertible.public_call_stack[0].is_empty() + self.end_non_revertible.public_call_stack[0].hash != 0 } pub fn needs_app_logic(self) -> bool { // public calls for app logic are deposited in the revertible public call stack. // if an element is present, we need to run app logic - !self.end.public_call_stack[0].is_empty() + self.end.public_call_stack[0].hash != 0 } pub fn needs_teardown(self) -> bool { // the public call specified for teardown, if any, is placed in the teardown call stack - !self.public_teardown_call_stack[0].is_empty() + self.public_teardown_call_stack[0].hash != 0 } } diff --git a/noir-projects/noir-protocol-circuits/crates/types/src/address/aztec_address.nr b/noir-projects/noir-protocol-circuits/crates/types/src/address/aztec_address.nr index 6413bedf15e..6c91a609990 100644 --- a/noir-projects/noir-protocol-circuits/crates/types/src/address/aztec_address.nr +++ b/noir-projects/noir-protocol-circuits/crates/types/src/address/aztec_address.nr @@ -59,25 +59,6 @@ impl AztecAddress { ) } - pub fn compute_from_public_keys_and_partial_address( - nullifier_public_key: GrumpkinPoint, - incoming_public_key: GrumpkinPoint, - outgoing_public_key: GrumpkinPoint, - tagging_public_key: GrumpkinPoint, - partial_address: PartialAddress - ) -> AztecAddress { - let public_keys_hash = PublicKeysHash::compute_new( - nullifier_public_key, - incoming_public_key, - outgoing_public_key, - tagging_public_key - ); - - let computed_address = AztecAddress::compute(public_keys_hash, partial_address); - - computed_address - } - pub fn is_zero(self) -> bool { self.inner == 0 } @@ -93,7 +74,7 @@ impl AztecAddress { } #[test] -fn compute_address_from_partial_and_pubkey() { +fn compute_address_from_partial_and_pub_keys_hash() { let pub_keys_hash = PublicKeysHash::from_field(1); let partial_address = PartialAddress::from_field(2); diff --git a/noir-projects/noir-protocol-circuits/crates/types/src/address/public_keys_hash.nr b/noir-projects/noir-protocol-circuits/crates/types/src/address/public_keys_hash.nr index f91d1383a19..09ad9ba1a15 100644 --- a/noir-projects/noir-protocol-circuits/crates/types/src/address/public_keys_hash.nr +++ b/noir-projects/noir-protocol-circuits/crates/types/src/address/public_keys_hash.nr @@ -38,37 +38,18 @@ impl PublicKeysHash { Self { inner: field } } - // TODO(#5830): When we do this refactor, rename compute_new -> compute - pub fn compute(public_key: GrumpkinPoint) -> Self { - PublicKeysHash::from_field( - pedersen_hash( - [ - public_key.x, - public_key.y - ], - GENERATOR_INDEX__PARTIAL_ADDRESS - ) - ) - } - - // TODO(#5830): When we do this refactor, rename compute_new -> compute - pub fn compute_new( - nullifier_public_key: GrumpkinPoint, - incoming_public_key: GrumpkinPoint, - outgoing_public_key: GrumpkinPoint, - tagging_public_key: GrumpkinPoint - ) -> Self { + pub fn compute(npk_m: GrumpkinPoint, ivpk_m: GrumpkinPoint, ovpk_m: GrumpkinPoint, tpk_m: GrumpkinPoint) -> Self { PublicKeysHash::from_field( poseidon2_hash( [ - nullifier_public_key.x, - nullifier_public_key.y, - incoming_public_key.x, - incoming_public_key.y, - outgoing_public_key.x, - outgoing_public_key.y, - tagging_public_key.x, - tagging_public_key.y, + npk_m.x, + npk_m.y, + ivpk_m.x, + ivpk_m.y, + ovpk_m.x, + ovpk_m.y, + tpk_m.x, + tpk_m.y, GENERATOR_INDEX__PUBLIC_KEYS_HASH ] ) @@ -84,11 +65,14 @@ impl PublicKeysHash { } } -// TODO(#5830): re-enable this test once the compute function is updated -// #[test] -// fn compute_public_keys_hash() { -// let point = GrumpkinPoint { x: 1, y: 2 }; -// let actual = PublicKeysHash::compute(point); -// let expected_public_keys_hash = 0x22d83a089d7650514c2de24cd30185a414d943eaa19817c67bffe2c3183006a3; -// assert(actual.to_field() == expected_public_keys_hash); -// } +#[test] +fn compute_public_keys_hash() { + let npk_m = GrumpkinPoint { x: 1, y: 2 }; + let ivpk_m = GrumpkinPoint { x: 3, y: 4 }; + let ovpk_m = GrumpkinPoint { x: 5, y: 6 }; + let tpk_m = GrumpkinPoint { x: 7, y: 8 }; + + let actual = PublicKeysHash::compute(npk_m, ivpk_m, ovpk_m, tpk_m); + let expected_public_keys_hash = 0x1936abe4f6a920d16a9f6917f10a679507687e2cd935dd1f1cdcb1e908c027f3; + assert(actual.to_field() == expected_public_keys_hash); +} diff --git a/noir-projects/noir-protocol-circuits/crates/types/src/tests/private_call_data_builder.nr b/noir-projects/noir-protocol-circuits/crates/types/src/tests/private_call_data_builder.nr index 999da287c85..68154e67806 100644 --- a/noir-projects/noir-protocol-circuits/crates/types/src/tests/private_call_data_builder.nr +++ b/noir-projects/noir-protocol-circuits/crates/types/src/tests/private_call_data_builder.nr @@ -75,7 +75,7 @@ impl PrivateCallDataBuilder { } pub fn build_tx_request(self) -> TxRequest { - let tx_context = self.build_tx_context(); + let tx_context = self.public_inputs.build_tx_context(); TxRequest { origin: self.contract_address, args_hash: self.public_inputs.args_hash, @@ -84,8 +84,24 @@ impl PrivateCallDataBuilder { } } - pub fn build_tx_context(self) -> TxContext { - self.public_inputs.build_tx_context() + pub fn build_call_request(self) -> CallRequest { + let hash = self.build_call_stack_item().hash(); + let is_delegate_call = self.public_inputs.call_context.is_delegate_call; + let caller_context = if is_delegate_call { + CallerContext { + msg_sender: fixtures::MSG_SENDER, + storage_contract_address: self.public_inputs.call_context.storage_contract_address + } + } else { + CallerContext::empty() + }; + CallRequest { + hash, + caller_contract_address: self.public_inputs.call_context.msg_sender, + caller_context, + start_side_effect_counter: 0, + end_side_effect_counter: 0 + } } pub fn append_private_call_requests(&mut self, num_requests: u64, is_delegate_call: bool) { @@ -100,6 +116,12 @@ impl PrivateCallDataBuilder { self.public_call_stack.extend_from_bounded_vec(call_requests); } + pub fn add_teaddown_call_request(&mut self, is_delegate_call: bool) { + let hash = 99887654; + self.public_inputs.public_teardown_function_hash = hash; + self.public_teardown_call_request = self.generate_call_request(hash, is_delegate_call); + } + fn generate_call_requests( self, requests: BoundedVec, @@ -107,12 +129,6 @@ impl PrivateCallDataBuilder { is_delegate_call: bool ) -> (BoundedVec, BoundedVec) { let value_offset = requests.len(); - let mut caller_context = CallerContext::empty(); - if is_delegate_call { - let call_context = self.public_inputs.call_context; - caller_context.msg_sender = call_context.msg_sender; - caller_context.storage_contract_address = call_context.storage_contract_address; - } let mut call_requests: BoundedVec = BoundedVec::new(); let mut hashes: BoundedVec = BoundedVec::new(); let mut exceeded_len = false; @@ -121,14 +137,7 @@ impl PrivateCallDataBuilder { if !exceeded_len { // The default hash is its index + 7788. let hash = (value_offset + 7788) as Field; - let request = CallRequest { - hash, - caller_contract_address: self.contract_address, - caller_context, - // TODO: populate these - start_side_effect_counter: 0, - end_side_effect_counter: 0 - }; + let request = self.generate_call_request(hash, is_delegate_call); hashes.push(hash); call_requests.push(request); } @@ -136,6 +145,23 @@ impl PrivateCallDataBuilder { (hashes, call_requests) } + fn generate_call_request(self, hash: Field, is_delegate_call: bool) -> CallRequest { + let mut caller_context = CallerContext::empty(); + if is_delegate_call { + let call_context = self.public_inputs.call_context; + caller_context.msg_sender = call_context.msg_sender; + caller_context.storage_contract_address = call_context.storage_contract_address; + } + CallRequest { + hash, + caller_contract_address: self.contract_address, + caller_context, + // TODO: populate these + start_side_effect_counter: 0, + end_side_effect_counter: 0 + } + } + pub fn set_tx_max_block_number(&mut self, max_block_number: u32) { self.public_inputs.max_block_number = MaxBlockNumber::new(max_block_number); } diff --git a/noir-projects/noir-protocol-circuits/crates/types/src/utils/arrays.nr b/noir-projects/noir-protocol-circuits/crates/types/src/utils/arrays.nr index 52c277355c8..a84baf83a2a 100644 --- a/noir-projects/noir-protocol-circuits/crates/types/src/utils/arrays.nr +++ b/noir-projects/noir-protocol-circuits/crates/types/src/utils/arrays.nr @@ -16,21 +16,18 @@ pub fn array_to_bounded_vec(array: [T; N]) -> BoundedVec where T: Em // Routine which validates that all zero values of an array form a contiguous region at the end, i.e., // of the form: [*,*,*...,0,0,0,0] where any * is non-zero. Note that a full array of non-zero values is // valid. -pub fn validate_array(array: [T; N]) where T: Empty + Eq { - let array_length = array.len(); - - let mut first_zero_pos = array_length; - let mut last_non_zero_pos = 0; - - for i in 0..array_length { - let is_empty = is_empty(array[i]); - if !is_empty { - last_non_zero_pos = i; - } else if is_empty & (first_zero_pos == array_length) { - first_zero_pos = i; +pub fn validate_array(array: [T; N]) -> u64 where T: Empty + Eq { + let mut seen_empty = false; + let mut length = 0; + for i in 0..N { + if is_empty(array[i]) { + seen_empty = true; + } else { + assert(seen_empty == false, "invalid array"); + length += 1; } } - assert(last_non_zero_pos <= first_zero_pos, "invalid array"); + length } // Helper method to determine the number of non-zero/empty elements in a validated array (ie, validate_array(array) @@ -147,31 +144,40 @@ pub fn assert_sorted_array( #[test] fn smoke_validate_array() { let valid_array: [Field; 0] = []; - validate_array(valid_array); + assert(validate_array(valid_array) == 0); let valid_array = [0]; - validate_array(valid_array); + assert(validate_array(valid_array) == 0); + + let valid_array = [3]; + assert(validate_array(valid_array) == 1); let valid_array = [1, 2, 3]; - validate_array(valid_array); + assert(validate_array(valid_array) == 3); let valid_array = [1, 2, 3, 0]; - validate_array(valid_array); + assert(validate_array(valid_array) == 3); let valid_array = [1, 2, 3, 0, 0]; - validate_array(valid_array); + assert(validate_array(valid_array) == 3); } #[test(should_fail_with = "invalid array")] -fn smoke_validate_array_invalid() { +fn smoke_validate_array_invalid_case0() { let invalid_array = [0, 1]; - validate_array(invalid_array); + let _ = validate_array(invalid_array); +} + +#[test(should_fail_with = "invalid array")] +fn smoke_validate_array_invalid_case1() { + let invalid_array = [1, 0, 0, 1, 0]; + let _ = validate_array(invalid_array); } #[test(should_fail_with = "invalid array")] fn smoke_validate_array_invalid_case2() { let invalid_array = [0, 0, 0, 0, 1]; - validate_array(invalid_array); + let _ = validate_array(invalid_array); } #[test] diff --git a/noir/noir-repo/acvm-repo/acir/codegen/acir.cpp b/noir/noir-repo/acvm-repo/acir/codegen/acir.cpp index 5afcd68e987..222a7da6399 100644 --- a/noir/noir-repo/acvm-repo/acir/codegen/acir.cpp +++ b/noir/noir-repo/acvm-repo/acir/codegen/acir.cpp @@ -870,7 +870,17 @@ namespace Program { static Sha256Compression bincodeDeserialize(std::vector); }; - std::variant value; + struct ToRadix { + Program::MemoryAddress input; + uint32_t radix; + Program::HeapArray output; + + friend bool operator==(const ToRadix&, const ToRadix&); + std::vector bincodeSerialize() const; + static ToRadix bincodeDeserialize(std::vector); + }; + + std::variant value; friend bool operator==(const BlackBoxOp&, const BlackBoxOp&); std::vector bincodeSerialize() const; @@ -4293,6 +4303,50 @@ Program::BlackBoxOp::Sha256Compression serde::Deserializable BlackBoxOp::ToRadix::bincodeSerialize() const { + auto serializer = serde::BincodeSerializer(); + serde::Serializable::serialize(*this, serializer); + return std::move(serializer).bytes(); + } + + inline BlackBoxOp::ToRadix BlackBoxOp::ToRadix::bincodeDeserialize(std::vector input) { + auto deserializer = serde::BincodeDeserializer(input); + auto value = serde::Deserializable::deserialize(deserializer); + if (deserializer.get_buffer_offset() < input.size()) { + throw serde::deserialization_error("Some input bytes were not read"); + } + return value; + } + +} // end of namespace Program + +template <> +template +void serde::Serializable::serialize(const Program::BlackBoxOp::ToRadix &obj, Serializer &serializer) { + serde::Serializable::serialize(obj.input, serializer); + serde::Serializable::serialize(obj.radix, serializer); + serde::Serializable::serialize(obj.output, serializer); +} + +template <> +template +Program::BlackBoxOp::ToRadix serde::Deserializable::deserialize(Deserializer &deserializer) { + Program::BlackBoxOp::ToRadix obj; + obj.input = serde::Deserializable::deserialize(deserializer); + obj.radix = serde::Deserializable::deserialize(deserializer); + obj.output = serde::Deserializable::deserialize(deserializer); + return obj; +} + namespace Program { inline bool operator==(const BlockId &lhs, const BlockId &rhs) { diff --git a/noir/noir-repo/acvm-repo/acvm_js/build.sh b/noir/noir-repo/acvm-repo/acvm_js/build.sh index c07d2d8a4c1..ee93413ab85 100755 --- a/noir/noir-repo/acvm-repo/acvm_js/build.sh +++ b/noir/noir-repo/acvm-repo/acvm_js/build.sh @@ -25,7 +25,7 @@ function run_if_available { require_command jq require_command cargo require_command wasm-bindgen -#require_command wasm-opt +# require_command wasm-opt self_path=$(dirname "$(readlink -f "$0")") pname=$(cargo read-manifest | jq -r '.name') diff --git a/noir/noir-repo/acvm-repo/brillig/src/black_box.rs b/noir/noir-repo/acvm-repo/brillig/src/black_box.rs index 15abc19ed90..9a66b428dc3 100644 --- a/noir/noir-repo/acvm-repo/brillig/src/black_box.rs +++ b/noir/noir-repo/acvm-repo/brillig/src/black_box.rs @@ -126,4 +126,9 @@ pub enum BlackBoxOp { hash_values: HeapVector, output: HeapArray, }, + ToRadix { + input: MemoryAddress, + radix: u32, + output: HeapArray, + }, } diff --git a/noir/noir-repo/acvm-repo/brillig_vm/src/black_box.rs b/noir/noir-repo/acvm-repo/brillig_vm/src/black_box.rs index c999b5bf330..d6ecd25f454 100644 --- a/noir/noir-repo/acvm-repo/brillig_vm/src/black_box.rs +++ b/noir/noir-repo/acvm-repo/brillig_vm/src/black_box.rs @@ -5,6 +5,7 @@ use acvm_blackbox_solver::{ aes128_encrypt, blake2s, blake3, ecdsa_secp256k1_verify, ecdsa_secp256r1_verify, keccak256, keccakf1600, sha256, sha256compression, BlackBoxFunctionSolver, BlackBoxResolutionError, }; +use num_bigint::BigUint; use crate::memory::MemoryValue; use crate::Memory; @@ -295,6 +296,25 @@ pub(crate) fn evaluate_black_box( memory.write_slice(memory.read_ref(output.pointer), &state); Ok(()) } + BlackBoxOp::ToRadix { input, radix, output } => { + let input: FieldElement = + memory.read(*input).try_into().expect("ToRadix input not a field"); + + let mut input = BigUint::from_bytes_be(&input.to_be_bytes()); + let radix = BigUint::from(*radix); + + let mut limbs: Vec = Vec::with_capacity(output.size); + + for _ in 0..output.size { + let limb = &input % &radix; + limbs.push(FieldElement::from_be_bytes_reduce(&limb.to_bytes_be()).into()); + input /= &radix; + } + + memory.write_slice(memory.read_ref(output.pointer), &limbs); + + Ok(()) + } } } @@ -321,6 +341,7 @@ fn black_box_function_from_op(op: &BlackBoxOp) -> BlackBoxFunc { BlackBoxOp::BigIntToLeBytes { .. } => BlackBoxFunc::BigIntToLeBytes, BlackBoxOp::Poseidon2Permutation { .. } => BlackBoxFunc::Poseidon2Permutation, BlackBoxOp::Sha256Compression { .. } => BlackBoxFunc::Sha256Compression, + BlackBoxOp::ToRadix { .. } => unreachable!("ToRadix is not an ACIR BlackBoxFunc"), } } diff --git a/noir/noir-repo/compiler/noirc_driver/src/lib.rs b/noir/noir-repo/compiler/noirc_driver/src/lib.rs index 5f1985b0553..d7368f299b8 100644 --- a/noir/noir-repo/compiler/noirc_driver/src/lib.rs +++ b/noir/noir-repo/compiler/noirc_driver/src/lib.rs @@ -84,10 +84,6 @@ pub struct CompileOptions { #[arg(long, conflicts_with = "deny_warnings")] pub silence_warnings: bool, - /// Output ACIR gzipped bytecode instead of the JSON artefact - #[arg(long, hide = true)] - pub only_acir: bool, - /// Disables the builtin Aztec macros being used in the compiler #[arg(long, hide = true)] pub disable_macros: bool, @@ -103,6 +99,10 @@ pub struct CompileOptions { /// Force Brillig output (for step debugging) #[arg(long, hide = true)] pub force_brillig: bool, + + /// Enable the experimental elaborator pass + #[arg(long, hide = true)] + pub use_elaborator: bool, } fn parse_expression_width(input: &str) -> Result { @@ -245,12 +245,13 @@ pub fn check_crate( crate_id: CrateId, deny_warnings: bool, disable_macros: bool, + use_elaborator: bool, ) -> CompilationResult<()> { let macros: &[&dyn MacroProcessor] = if disable_macros { &[] } else { &[&aztec_macros::AztecMacro as &dyn MacroProcessor] }; let mut errors = vec![]; - let diagnostics = CrateDefMap::collect_defs(crate_id, context, macros); + let diagnostics = CrateDefMap::collect_defs(crate_id, context, use_elaborator, macros); errors.extend(diagnostics.into_iter().map(|(error, file_id)| { let diagnostic = CustomDiagnostic::from(&error); diagnostic.in_file(file_id) @@ -282,8 +283,13 @@ pub fn compile_main( options: &CompileOptions, cached_program: Option, ) -> CompilationResult { - let (_, mut warnings) = - check_crate(context, crate_id, options.deny_warnings, options.disable_macros)?; + let (_, mut warnings) = check_crate( + context, + crate_id, + options.deny_warnings, + options.disable_macros, + options.use_elaborator, + )?; let main = context.get_main_function(&crate_id).ok_or_else(|| { // TODO(#2155): This error might be a better to exist in Nargo @@ -318,8 +324,13 @@ pub fn compile_contract( crate_id: CrateId, options: &CompileOptions, ) -> CompilationResult { - let (_, warnings) = - check_crate(context, crate_id, options.deny_warnings, options.disable_macros)?; + let (_, warnings) = check_crate( + context, + crate_id, + options.deny_warnings, + options.disable_macros, + options.use_elaborator, + )?; // TODO: We probably want to error if contracts is empty let contracts = context.get_all_contracts(&crate_id); diff --git a/noir/noir-repo/compiler/noirc_driver/tests/stdlib_warnings.rs b/noir/noir-repo/compiler/noirc_driver/tests/stdlib_warnings.rs index 6f437621123..327c8daad06 100644 --- a/noir/noir-repo/compiler/noirc_driver/tests/stdlib_warnings.rs +++ b/noir/noir-repo/compiler/noirc_driver/tests/stdlib_warnings.rs @@ -24,7 +24,8 @@ fn stdlib_does_not_produce_constant_warnings() -> Result<(), ErrorsAndWarnings> let mut context = Context::new(file_manager, parsed_files); let root_crate_id = prepare_crate(&mut context, file_name); - let ((), warnings) = noirc_driver::check_crate(&mut context, root_crate_id, false, false)?; + let ((), warnings) = + noirc_driver::check_crate(&mut context, root_crate_id, false, false, false)?; assert_eq!(warnings, Vec::new(), "stdlib is producing warnings"); diff --git a/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_block.rs b/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_block.rs index f660c8e0b7a..6a4f9f5cc0e 100644 --- a/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_block.rs +++ b/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_block.rs @@ -488,8 +488,22 @@ impl<'block> BrilligBlock<'block> { } Value::Intrinsic(Intrinsic::ToRadix(endianness)) => { let source = self.convert_ssa_single_addr_value(arguments[0], dfg); - let radix = self.convert_ssa_single_addr_value(arguments[1], dfg); - let limb_count = self.convert_ssa_single_addr_value(arguments[2], dfg); + + let radix: u32 = dfg + .get_numeric_constant(arguments[1]) + .expect("Radix should be known") + .try_to_u64() + .expect("Radix should fit in u64") + .try_into() + .expect("Radix should be u32"); + + let limb_count: usize = dfg + .get_numeric_constant(arguments[2]) + .expect("Limb count should be known") + .try_to_u64() + .expect("Limb count should fit in u64") + .try_into() + .expect("Limb count should fit in usize"); let results = dfg.instruction_results(instruction_id); @@ -511,7 +525,8 @@ impl<'block> BrilligBlock<'block> { .extract_vector(); // Update the user-facing slice length - self.brillig_context.cast_instruction(target_len, limb_count); + self.brillig_context + .usize_const_instruction(target_len.address, limb_count.into()); self.brillig_context.codegen_to_radix( source, @@ -524,7 +539,13 @@ impl<'block> BrilligBlock<'block> { } Value::Intrinsic(Intrinsic::ToBits(endianness)) => { let source = self.convert_ssa_single_addr_value(arguments[0], dfg); - let limb_count = self.convert_ssa_single_addr_value(arguments[1], dfg); + let limb_count: usize = dfg + .get_numeric_constant(arguments[1]) + .expect("Limb count should be known") + .try_to_u64() + .expect("Limb count should fit in u64") + .try_into() + .expect("Limb count should fit in usize"); let results = dfg.instruction_results(instruction_id); @@ -549,21 +570,18 @@ impl<'block> BrilligBlock<'block> { BrilligVariable::SingleAddr(..) => unreachable!("ICE: ToBits on non-array"), }; - let radix = self.brillig_context.make_constant_instruction(2_usize.into(), 32); - // Update the user-facing slice length - self.brillig_context.cast_instruction(target_len, limb_count); + self.brillig_context + .usize_const_instruction(target_len.address, limb_count.into()); self.brillig_context.codegen_to_radix( source, target_vector, - radix, + 2, limb_count, matches!(endianness, Endian::Big), 1, ); - - self.brillig_context.deallocate_single_addr(radix); } _ => { unreachable!("unsupported function call type {:?}", dfg[*func]) diff --git a/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_ir/codegen_intrinsic.rs b/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_ir/codegen_intrinsic.rs index ab756217bcd..58166554e1d 100644 --- a/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_ir/codegen_intrinsic.rs +++ b/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_ir/codegen_intrinsic.rs @@ -1,6 +1,7 @@ -use acvm::FieldElement; - -use crate::brillig::brillig_ir::BrilligBinaryOp; +use acvm::{ + acir::brillig::{BlackBoxOp, HeapArray}, + FieldElement, +}; use super::{ brillig_variable::{BrilligVector, SingleAddrVariable}, @@ -36,57 +37,46 @@ impl BrilligContext { &mut self, source_field: SingleAddrVariable, target_vector: BrilligVector, - radix: SingleAddrVariable, - limb_count: SingleAddrVariable, + radix: u32, + limb_count: usize, big_endian: bool, limb_bit_size: u32, ) { assert!(source_field.bit_size == FieldElement::max_num_bits()); - assert!(radix.bit_size == 32); - assert!(limb_count.bit_size == 32); - let radix_as_field = - SingleAddrVariable::new(self.allocate_register(), FieldElement::max_num_bits()); - self.cast_instruction(radix_as_field, radix); - self.cast_instruction(SingleAddrVariable::new_usize(target_vector.size), limb_count); + self.usize_const_instruction(target_vector.size, limb_count.into()); self.usize_const_instruction(target_vector.rc, 1_usize.into()); self.codegen_allocate_array(target_vector.pointer, target_vector.size); - let shifted_field = - SingleAddrVariable::new(self.allocate_register(), FieldElement::max_num_bits()); - self.mov_instruction(shifted_field.address, source_field.address); + self.black_box_op_instruction(BlackBoxOp::ToRadix { + input: source_field.address, + radix, + output: HeapArray { pointer: target_vector.pointer, size: limb_count }, + }); let limb_field = SingleAddrVariable::new(self.allocate_register(), FieldElement::max_num_bits()); let limb_casted = SingleAddrVariable::new(self.allocate_register(), limb_bit_size); - self.codegen_loop(target_vector.size, |ctx, iterator_register| { - // Compute the modulus - ctx.binary_instruction( - shifted_field, - radix_as_field, - limb_field, - BrilligBinaryOp::Modulo, - ); - // Cast it - ctx.cast_instruction(limb_casted, limb_field); - // Write it - ctx.codegen_array_set(target_vector.pointer, iterator_register, limb_casted.address); - // Integer div the field - ctx.binary_instruction( - shifted_field, - radix_as_field, - shifted_field, - BrilligBinaryOp::UnsignedDiv, - ); - }); + if limb_bit_size != FieldElement::max_num_bits() { + self.codegen_loop(target_vector.size, |ctx, iterator_register| { + // Read the limb + ctx.codegen_array_get(target_vector.pointer, iterator_register, limb_field.address); + // Cast it + ctx.cast_instruction(limb_casted, limb_field); + // Write it + ctx.codegen_array_set( + target_vector.pointer, + iterator_register, + limb_casted.address, + ); + }); + } // Deallocate our temporary registers - self.deallocate_single_addr(shifted_field); self.deallocate_single_addr(limb_field); self.deallocate_single_addr(limb_casted); - self.deallocate_single_addr(radix_as_field); if big_endian { self.codegen_reverse_vector_in_place(target_vector); diff --git a/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_ir/debug_show.rs b/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_ir/debug_show.rs index 667ccf6ddbe..f02f6059e7c 100644 --- a/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_ir/debug_show.rs +++ b/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_ir/debug_show.rs @@ -451,6 +451,15 @@ impl DebugShow { output ); } + BlackBoxOp::ToRadix { input, radix, output } => { + debug_println!( + self.enable_debug_trace, + " TO_RADIX {} {} -> {}", + input, + radix, + output + ); + } } } diff --git a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/opt/inlining.rs b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/opt/inlining.rs index bddfb25f26c..77b9e545e03 100644 --- a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/opt/inlining.rs +++ b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/opt/inlining.rs @@ -11,7 +11,7 @@ use crate::ssa::{ ir::{ basic_block::BasicBlockId, dfg::{CallStack, InsertInstructionResult}, - function::{Function, FunctionId}, + function::{Function, FunctionId, RuntimeType}, instruction::{Instruction, InstructionId, TerminatorInstruction}, value::{Value, ValueId}, }, @@ -392,10 +392,11 @@ impl<'function> PerFunctionContext<'function> { Some(func_id) => { let function = &ssa.functions[&func_id]; // If we have not already finished the flattening pass, functions marked - // to not have predicates should be marked as entry points. + // to not have predicates should be marked as entry points unless we are inlining into brillig. let no_predicates_is_entry_point = self.context.no_predicates_is_entry_point - && function.is_no_predicates(); + && function.is_no_predicates() + && !matches!(self.source_function.runtime(), RuntimeType::Brillig); if function.runtime().is_entry_point() || no_predicates_is_entry_point { self.push_instruction(*id); } else { diff --git a/noir/noir-repo/compiler/noirc_frontend/src/ast/function.rs b/noir/noir-repo/compiler/noirc_frontend/src/ast/function.rs index dc426a4642a..8acc068d86a 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/ast/function.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/ast/function.rs @@ -32,6 +32,15 @@ pub enum FunctionKind { Recursive, } +impl FunctionKind { + pub fn can_ignore_return_type(self) -> bool { + match self { + FunctionKind::LowLevel | FunctionKind::Builtin | FunctionKind::Oracle => true, + FunctionKind::Normal | FunctionKind::Recursive => false, + } + } +} + impl NoirFunction { pub fn normal(def: FunctionDefinition) -> NoirFunction { NoirFunction { kind: FunctionKind::Normal, def } diff --git a/noir/noir-repo/compiler/noirc_frontend/src/ast/statement.rs b/noir/noir-repo/compiler/noirc_frontend/src/ast/statement.rs index 0da39edfd85..94b5841e52c 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/ast/statement.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/ast/statement.rs @@ -565,7 +565,7 @@ impl ForRange { identifier: Ident, block: Expression, for_loop_span: Span, - ) -> StatementKind { + ) -> Statement { /// Counter used to generate unique names when desugaring /// code in the parser requires the creation of fresh variables. /// The parser is stateless so this is a static global instead. @@ -662,7 +662,8 @@ impl ForRange { let block = ExpressionKind::Block(BlockExpression { statements: vec![let_array, for_loop], }); - StatementKind::Expression(Expression::new(block, for_loop_span)) + let kind = StatementKind::Expression(Expression::new(block, for_loop_span)); + Statement { kind, span: for_loop_span } } } } diff --git a/noir/noir-repo/compiler/noirc_frontend/src/elaborator/expressions.rs b/noir/noir-repo/compiler/noirc_frontend/src/elaborator/expressions.rs new file mode 100644 index 00000000000..ed8ed5305d1 --- /dev/null +++ b/noir/noir-repo/compiler/noirc_frontend/src/elaborator/expressions.rs @@ -0,0 +1,604 @@ +use iter_extended::vecmap; +use noirc_errors::{Location, Span}; +use regex::Regex; +use rustc_hash::FxHashSet as HashSet; + +use crate::{ + ast::{ + ArrayLiteral, ConstructorExpression, IfExpression, InfixExpression, Lambda, + UnresolvedTypeExpression, + }, + hir::{ + resolution::{errors::ResolverError, resolver::LambdaContext}, + type_check::TypeCheckError, + }, + hir_def::{ + expr::{ + HirArrayLiteral, HirBinaryOp, HirBlockExpression, HirCallExpression, HirCastExpression, + HirConstructorExpression, HirIdent, HirIfExpression, HirIndexExpression, + HirInfixExpression, HirLambda, HirMemberAccess, HirMethodCallExpression, + HirMethodReference, HirPrefixExpression, + }, + traits::TraitConstraint, + }, + macros_api::{ + BlockExpression, CallExpression, CastExpression, Expression, ExpressionKind, HirExpression, + HirLiteral, HirStatement, Ident, IndexExpression, Literal, MemberAccessExpression, + MethodCallExpression, PrefixExpression, + }, + node_interner::{DefinitionKind, ExprId, FuncId}, + Shared, StructType, Type, +}; + +use super::Elaborator; + +impl<'context> Elaborator<'context> { + pub(super) fn elaborate_expression(&mut self, expr: Expression) -> (ExprId, Type) { + let (hir_expr, typ) = match expr.kind { + ExpressionKind::Literal(literal) => self.elaborate_literal(literal, expr.span), + ExpressionKind::Block(block) => self.elaborate_block(block), + ExpressionKind::Prefix(prefix) => self.elaborate_prefix(*prefix), + ExpressionKind::Index(index) => self.elaborate_index(*index), + ExpressionKind::Call(call) => self.elaborate_call(*call, expr.span), + ExpressionKind::MethodCall(call) => self.elaborate_method_call(*call, expr.span), + ExpressionKind::Constructor(constructor) => self.elaborate_constructor(*constructor), + ExpressionKind::MemberAccess(access) => { + return self.elaborate_member_access(*access, expr.span) + } + ExpressionKind::Cast(cast) => self.elaborate_cast(*cast, expr.span), + ExpressionKind::Infix(infix) => return self.elaborate_infix(*infix, expr.span), + ExpressionKind::If(if_) => self.elaborate_if(*if_), + ExpressionKind::Variable(variable) => return self.elaborate_variable(variable), + ExpressionKind::Tuple(tuple) => self.elaborate_tuple(tuple), + ExpressionKind::Lambda(lambda) => self.elaborate_lambda(*lambda), + ExpressionKind::Parenthesized(expr) => return self.elaborate_expression(*expr), + ExpressionKind::Quote(quote) => self.elaborate_quote(quote), + ExpressionKind::Comptime(comptime) => self.elaborate_comptime_block(comptime), + ExpressionKind::Error => (HirExpression::Error, Type::Error), + }; + let id = self.interner.push_expr(hir_expr); + self.interner.push_expr_location(id, expr.span, self.file); + self.interner.push_expr_type(id, typ.clone()); + (id, typ) + } + + pub(super) fn elaborate_block(&mut self, block: BlockExpression) -> (HirExpression, Type) { + self.push_scope(); + let mut block_type = Type::Unit; + let mut statements = Vec::with_capacity(block.statements.len()); + + for (i, statement) in block.statements.into_iter().enumerate() { + let (id, stmt_type) = self.elaborate_statement(statement); + statements.push(id); + + if let HirStatement::Semi(expr) = self.interner.statement(&id) { + let inner_expr_type = self.interner.id_type(expr); + let span = self.interner.expr_span(&expr); + + self.unify(&inner_expr_type, &Type::Unit, || TypeCheckError::UnusedResultError { + expr_type: inner_expr_type.clone(), + expr_span: span, + }); + + if i + 1 == statements.len() { + block_type = stmt_type; + } + } + } + + self.pop_scope(); + (HirExpression::Block(HirBlockExpression { statements }), block_type) + } + + fn elaborate_literal(&mut self, literal: Literal, span: Span) -> (HirExpression, Type) { + use HirExpression::Literal as Lit; + match literal { + Literal::Unit => (Lit(HirLiteral::Unit), Type::Unit), + Literal::Bool(b) => (Lit(HirLiteral::Bool(b)), Type::Bool), + Literal::Integer(integer, sign) => { + let int = HirLiteral::Integer(integer, sign); + (Lit(int), self.polymorphic_integer_or_field()) + } + Literal::Str(str) | Literal::RawStr(str, _) => { + let len = Type::Constant(str.len() as u64); + (Lit(HirLiteral::Str(str)), Type::String(Box::new(len))) + } + Literal::FmtStr(str) => self.elaborate_fmt_string(str, span), + Literal::Array(array_literal) => { + self.elaborate_array_literal(array_literal, span, true) + } + Literal::Slice(array_literal) => { + self.elaborate_array_literal(array_literal, span, false) + } + } + } + + fn elaborate_array_literal( + &mut self, + array_literal: ArrayLiteral, + span: Span, + is_array: bool, + ) -> (HirExpression, Type) { + let (expr, elem_type, length) = match array_literal { + ArrayLiteral::Standard(elements) => { + let first_elem_type = self.interner.next_type_variable(); + let first_span = elements.first().map(|elem| elem.span).unwrap_or(span); + + let elements = vecmap(elements.into_iter().enumerate(), |(i, elem)| { + let span = elem.span; + let (elem_id, elem_type) = self.elaborate_expression(elem); + + self.unify(&elem_type, &first_elem_type, || { + TypeCheckError::NonHomogeneousArray { + first_span, + first_type: first_elem_type.to_string(), + first_index: 0, + second_span: span, + second_type: elem_type.to_string(), + second_index: i, + } + .add_context("elements in an array must have the same type") + }); + elem_id + }); + + let length = Type::Constant(elements.len() as u64); + (HirArrayLiteral::Standard(elements), first_elem_type, length) + } + ArrayLiteral::Repeated { repeated_element, length } => { + let span = length.span; + let length = + UnresolvedTypeExpression::from_expr(*length, span).unwrap_or_else(|error| { + self.push_err(ResolverError::ParserError(Box::new(error))); + UnresolvedTypeExpression::Constant(0, span) + }); + + let length = self.convert_expression_type(length); + let (repeated_element, elem_type) = self.elaborate_expression(*repeated_element); + + let length_clone = length.clone(); + (HirArrayLiteral::Repeated { repeated_element, length }, elem_type, length_clone) + } + }; + let constructor = if is_array { HirLiteral::Array } else { HirLiteral::Slice }; + let elem_type = Box::new(elem_type); + let typ = if is_array { + Type::Array(Box::new(length), elem_type) + } else { + Type::Slice(elem_type) + }; + (HirExpression::Literal(constructor(expr)), typ) + } + + fn elaborate_fmt_string(&mut self, str: String, call_expr_span: Span) -> (HirExpression, Type) { + let re = Regex::new(r"\{([a-zA-Z0-9_]+)\}") + .expect("ICE: an invalid regex pattern was used for checking format strings"); + + let mut fmt_str_idents = Vec::new(); + let mut capture_types = Vec::new(); + + for field in re.find_iter(&str) { + let matched_str = field.as_str(); + let ident_name = &matched_str[1..(matched_str.len() - 1)]; + + let scope_tree = self.scopes.current_scope_tree(); + let variable = scope_tree.find(ident_name); + if let Some((old_value, _)) = variable { + old_value.num_times_used += 1; + let ident = HirExpression::Ident(old_value.ident.clone()); + let expr_id = self.interner.push_expr(ident); + self.interner.push_expr_location(expr_id, call_expr_span, self.file); + let ident = old_value.ident.clone(); + let typ = self.type_check_variable(ident, expr_id); + self.interner.push_expr_type(expr_id, typ.clone()); + capture_types.push(typ); + fmt_str_idents.push(expr_id); + } else if ident_name.parse::().is_ok() { + self.push_err(ResolverError::NumericConstantInFormatString { + name: ident_name.to_owned(), + span: call_expr_span, + }); + } else { + self.push_err(ResolverError::VariableNotDeclared { + name: ident_name.to_owned(), + span: call_expr_span, + }); + } + } + + let len = Type::Constant(str.len() as u64); + let typ = Type::FmtString(Box::new(len), Box::new(Type::Tuple(capture_types))); + (HirExpression::Literal(HirLiteral::FmtStr(str, fmt_str_idents)), typ) + } + + fn elaborate_prefix(&mut self, prefix: PrefixExpression) -> (HirExpression, Type) { + let span = prefix.rhs.span; + let (rhs, rhs_type) = self.elaborate_expression(prefix.rhs); + let ret_type = self.type_check_prefix_operand(&prefix.operator, &rhs_type, span); + (HirExpression::Prefix(HirPrefixExpression { operator: prefix.operator, rhs }), ret_type) + } + + fn elaborate_index(&mut self, index_expr: IndexExpression) -> (HirExpression, Type) { + let span = index_expr.index.span; + let (index, index_type) = self.elaborate_expression(index_expr.index); + + let expected = self.polymorphic_integer_or_field(); + self.unify(&index_type, &expected, || TypeCheckError::TypeMismatch { + expected_typ: "an integer".to_owned(), + expr_typ: index_type.to_string(), + expr_span: span, + }); + + // When writing `a[i]`, if `a : &mut ...` then automatically dereference `a` as many + // times as needed to get the underlying array. + let lhs_span = index_expr.collection.span; + let (lhs, lhs_type) = self.elaborate_expression(index_expr.collection); + let (collection, lhs_type) = self.insert_auto_dereferences(lhs, lhs_type); + + let typ = match lhs_type.follow_bindings() { + // XXX: We can check the array bounds here also, but it may be better to constant fold first + // and have ConstId instead of ExprId for constants + Type::Array(_, base_type) => *base_type, + Type::Slice(base_type) => *base_type, + Type::Error => Type::Error, + typ => { + self.push_err(TypeCheckError::TypeMismatch { + expected_typ: "Array".to_owned(), + expr_typ: typ.to_string(), + expr_span: lhs_span, + }); + Type::Error + } + }; + + let expr = HirExpression::Index(HirIndexExpression { collection, index }); + (expr, typ) + } + + fn elaborate_call(&mut self, call: CallExpression, span: Span) -> (HirExpression, Type) { + let (func, func_type) = self.elaborate_expression(*call.func); + + let mut arguments = Vec::with_capacity(call.arguments.len()); + let args = vecmap(call.arguments, |arg| { + let span = arg.span; + let (arg, typ) = self.elaborate_expression(arg); + arguments.push(arg); + (typ, arg, span) + }); + + let location = Location::new(span, self.file); + let call = HirCallExpression { func, arguments, location }; + let typ = self.type_check_call(&call, func_type, args, span); + (HirExpression::Call(call), typ) + } + + fn elaborate_method_call( + &mut self, + method_call: MethodCallExpression, + span: Span, + ) -> (HirExpression, Type) { + let object_span = method_call.object.span; + let (mut object, mut object_type) = self.elaborate_expression(method_call.object); + object_type = object_type.follow_bindings(); + + let method_name = method_call.method_name.0.contents.as_str(); + match self.lookup_method(&object_type, method_name, span) { + Some(method_ref) => { + // Automatically add `&mut` if the method expects a mutable reference and + // the object is not already one. + if let HirMethodReference::FuncId(func_id) = &method_ref { + if *func_id != FuncId::dummy_id() { + let function_type = self.interner.function_meta(func_id).typ.clone(); + + self.try_add_mutable_reference_to_object( + &function_type, + &mut object_type, + &mut object, + ); + } + } + + // These arguments will be given to the desugared function call. + // Compared to the method arguments, they also contain the object. + let mut function_args = Vec::with_capacity(method_call.arguments.len() + 1); + let mut arguments = Vec::with_capacity(method_call.arguments.len()); + + function_args.push((object_type.clone(), object, object_span)); + + for arg in method_call.arguments { + let span = arg.span; + let (arg, typ) = self.elaborate_expression(arg); + arguments.push(arg); + function_args.push((typ, arg, span)); + } + + let location = Location::new(span, self.file); + let method = method_call.method_name; + let method_call = HirMethodCallExpression { method, object, arguments, location }; + + // Desugar the method call into a normal, resolved function call + // so that the backend doesn't need to worry about methods + // TODO: update object_type here? + let ((function_id, function_name), function_call) = method_call.into_function_call( + &method_ref, + object_type, + location, + self.interner, + ); + + let func_type = self.type_check_variable(function_name, function_id); + + // Type check the new call now that it has been changed from a method call + // to a function call. This way we avoid duplicating code. + let typ = self.type_check_call(&function_call, func_type, function_args, span); + (HirExpression::Call(function_call), typ) + } + None => (HirExpression::Error, Type::Error), + } + } + + fn elaborate_constructor( + &mut self, + constructor: ConstructorExpression, + ) -> (HirExpression, Type) { + let span = constructor.type_name.span(); + + match self.lookup_type_or_error(constructor.type_name) { + Some(Type::Struct(r#type, struct_generics)) => { + let struct_type = r#type.clone(); + let generics = struct_generics.clone(); + + let fields = constructor.fields; + let field_types = r#type.borrow().get_fields(&struct_generics); + let fields = self.resolve_constructor_expr_fields( + struct_type.clone(), + field_types, + fields, + span, + ); + let expr = HirExpression::Constructor(HirConstructorExpression { + fields, + r#type, + struct_generics, + }); + (expr, Type::Struct(struct_type, generics)) + } + Some(typ) => { + self.push_err(ResolverError::NonStructUsedInConstructor { typ, span }); + (HirExpression::Error, Type::Error) + } + None => (HirExpression::Error, Type::Error), + } + } + + /// Resolve all the fields of a struct constructor expression. + /// Ensures all fields are present, none are repeated, and all + /// are part of the struct. + fn resolve_constructor_expr_fields( + &mut self, + struct_type: Shared, + field_types: Vec<(String, Type)>, + fields: Vec<(Ident, Expression)>, + span: Span, + ) -> Vec<(Ident, ExprId)> { + let mut ret = Vec::with_capacity(fields.len()); + let mut seen_fields = HashSet::default(); + let mut unseen_fields = struct_type.borrow().field_names(); + + for (field_name, field) in fields { + let expected_type = field_types.iter().find(|(name, _)| name == &field_name.0.contents); + let expected_type = expected_type.map(|(_, typ)| typ).unwrap_or(&Type::Error); + + let field_span = field.span; + let (resolved, field_type) = self.elaborate_expression(field); + + if unseen_fields.contains(&field_name) { + unseen_fields.remove(&field_name); + seen_fields.insert(field_name.clone()); + + self.unify_with_coercions(&field_type, expected_type, resolved, || { + TypeCheckError::TypeMismatch { + expected_typ: expected_type.to_string(), + expr_typ: field_type.to_string(), + expr_span: field_span, + } + }); + } else if seen_fields.contains(&field_name) { + // duplicate field + self.push_err(ResolverError::DuplicateField { field: field_name.clone() }); + } else { + // field not required by struct + self.push_err(ResolverError::NoSuchField { + field: field_name.clone(), + struct_definition: struct_type.borrow().name.clone(), + }); + } + + ret.push((field_name, resolved)); + } + + if !unseen_fields.is_empty() { + self.push_err(ResolverError::MissingFields { + span, + missing_fields: unseen_fields.into_iter().map(|field| field.to_string()).collect(), + struct_definition: struct_type.borrow().name.clone(), + }); + } + + ret + } + + fn elaborate_member_access( + &mut self, + access: MemberAccessExpression, + span: Span, + ) -> (ExprId, Type) { + let (lhs, lhs_type) = self.elaborate_expression(access.lhs); + let rhs = access.rhs; + // `is_offset` is only used when lhs is a reference and we want to return a reference to rhs + let access = HirMemberAccess { lhs, rhs, is_offset: false }; + let expr_id = self.intern_expr(HirExpression::MemberAccess(access.clone()), span); + let typ = self.type_check_member_access(access, expr_id, lhs_type, span); + self.interner.push_expr_type(expr_id, typ.clone()); + (expr_id, typ) + } + + pub fn intern_expr(&mut self, expr: HirExpression, span: Span) -> ExprId { + let id = self.interner.push_expr(expr); + self.interner.push_expr_location(id, span, self.file); + id + } + + fn elaborate_cast(&mut self, cast: CastExpression, span: Span) -> (HirExpression, Type) { + let (lhs, lhs_type) = self.elaborate_expression(cast.lhs); + let r#type = self.resolve_type(cast.r#type); + let result = self.check_cast(lhs_type, &r#type, span); + let expr = HirExpression::Cast(HirCastExpression { lhs, r#type }); + (expr, result) + } + + fn elaborate_infix(&mut self, infix: InfixExpression, span: Span) -> (ExprId, Type) { + let (lhs, lhs_type) = self.elaborate_expression(infix.lhs); + let (rhs, rhs_type) = self.elaborate_expression(infix.rhs); + let trait_id = self.interner.get_operator_trait_method(infix.operator.contents); + + let operator = HirBinaryOp::new(infix.operator, self.file); + let expr = HirExpression::Infix(HirInfixExpression { + lhs, + operator, + trait_method_id: trait_id, + rhs, + }); + + let expr_id = self.interner.push_expr(expr); + self.interner.push_expr_location(expr_id, span, self.file); + + let typ = match self.infix_operand_type_rules(&lhs_type, &operator, &rhs_type, span) { + Ok((typ, use_impl)) => { + if use_impl { + // Delay checking the trait constraint until the end of the function. + // Checking it now could bind an unbound type variable to any type + // that implements the trait. + let constraint = TraitConstraint { + typ: lhs_type.clone(), + trait_id: trait_id.trait_id, + trait_generics: Vec::new(), + }; + self.trait_constraints.push((constraint, expr_id)); + self.type_check_operator_method(expr_id, trait_id, &lhs_type, span); + } + typ + } + Err(error) => { + self.push_err(error); + Type::Error + } + }; + + self.interner.push_expr_type(expr_id, typ.clone()); + (expr_id, typ) + } + + fn elaborate_if(&mut self, if_expr: IfExpression) -> (HirExpression, Type) { + let expr_span = if_expr.condition.span; + let (condition, cond_type) = self.elaborate_expression(if_expr.condition); + let (consequence, mut ret_type) = self.elaborate_expression(if_expr.consequence); + + self.unify(&cond_type, &Type::Bool, || TypeCheckError::TypeMismatch { + expected_typ: Type::Bool.to_string(), + expr_typ: cond_type.to_string(), + expr_span, + }); + + let alternative = if_expr.alternative.map(|alternative| { + let expr_span = alternative.span; + let (else_, else_type) = self.elaborate_expression(alternative); + + self.unify(&ret_type, &else_type, || { + let err = TypeCheckError::TypeMismatch { + expected_typ: ret_type.to_string(), + expr_typ: else_type.to_string(), + expr_span, + }; + + let context = if ret_type == Type::Unit { + "Are you missing a semicolon at the end of your 'else' branch?" + } else if else_type == Type::Unit { + "Are you missing a semicolon at the end of the first block of this 'if'?" + } else { + "Expected the types of both if branches to be equal" + }; + + err.add_context(context) + }); + else_ + }); + + if alternative.is_none() { + ret_type = Type::Unit; + } + + let if_expr = HirIfExpression { condition, consequence, alternative }; + (HirExpression::If(if_expr), ret_type) + } + + fn elaborate_tuple(&mut self, tuple: Vec) -> (HirExpression, Type) { + let mut element_ids = Vec::with_capacity(tuple.len()); + let mut element_types = Vec::with_capacity(tuple.len()); + + for element in tuple { + let (id, typ) = self.elaborate_expression(element); + element_ids.push(id); + element_types.push(typ); + } + + (HirExpression::Tuple(element_ids), Type::Tuple(element_types)) + } + + fn elaborate_lambda(&mut self, lambda: Lambda) -> (HirExpression, Type) { + self.push_scope(); + let scope_index = self.scopes.current_scope_index(); + + self.lambda_stack.push(LambdaContext { captures: Vec::new(), scope_index }); + + let mut arg_types = Vec::with_capacity(lambda.parameters.len()); + let parameters = vecmap(lambda.parameters, |(pattern, typ)| { + let parameter = DefinitionKind::Local(None); + let typ = self.resolve_inferred_type(typ); + arg_types.push(typ.clone()); + (self.elaborate_pattern(pattern, typ.clone(), parameter), typ) + }); + + let return_type = self.resolve_inferred_type(lambda.return_type); + let body_span = lambda.body.span; + let (body, body_type) = self.elaborate_expression(lambda.body); + + let lambda_context = self.lambda_stack.pop().unwrap(); + self.pop_scope(); + + self.unify(&body_type, &return_type, || TypeCheckError::TypeMismatch { + expected_typ: return_type.to_string(), + expr_typ: body_type.to_string(), + expr_span: body_span, + }); + + let captured_vars = vecmap(&lambda_context.captures, |capture| { + self.interner.definition_type(capture.ident.id) + }); + + let env_type = + if captured_vars.is_empty() { Type::Unit } else { Type::Tuple(captured_vars) }; + + let captures = lambda_context.captures; + let expr = HirExpression::Lambda(HirLambda { parameters, return_type, body, captures }); + (expr, Type::Function(arg_types, Box::new(body_type), Box::new(env_type))) + } + + fn elaborate_quote(&mut self, block: BlockExpression) -> (HirExpression, Type) { + (HirExpression::Quote(block), Type::Code) + } + + fn elaborate_comptime_block(&mut self, _comptime: BlockExpression) -> (HirExpression, Type) { + todo!("Elaborate comptime block") + } +} diff --git a/noir/noir-repo/compiler/noirc_frontend/src/elaborator/mod.rs b/noir/noir-repo/compiler/noirc_frontend/src/elaborator/mod.rs new file mode 100644 index 00000000000..446e5b62ead --- /dev/null +++ b/noir/noir-repo/compiler/noirc_frontend/src/elaborator/mod.rs @@ -0,0 +1,782 @@ +#![allow(unused)] +use std::{ + collections::{BTreeMap, BTreeSet}, + rc::Rc, +}; + +use crate::hir::def_map::CrateDefMap; +use crate::{ + ast::{ + ArrayLiteral, ConstructorExpression, FunctionKind, IfExpression, InfixExpression, Lambda, + UnresolvedTraitConstraint, UnresolvedTypeExpression, + }, + hir::{ + def_collector::dc_crate::CompilationError, + resolution::{errors::ResolverError, path_resolver::PathResolver, resolver::LambdaContext}, + scope::ScopeForest as GenericScopeForest, + type_check::TypeCheckError, + }, + hir_def::{ + expr::{ + HirArrayLiteral, HirBinaryOp, HirBlockExpression, HirCallExpression, HirCastExpression, + HirConstructorExpression, HirIdent, HirIfExpression, HirIndexExpression, + HirInfixExpression, HirLambda, HirMemberAccess, HirMethodCallExpression, + HirMethodReference, HirPrefixExpression, + }, + traits::TraitConstraint, + }, + macros_api::{ + BlockExpression, CallExpression, CastExpression, Expression, ExpressionKind, HirExpression, + HirLiteral, HirStatement, Ident, IndexExpression, Literal, MemberAccessExpression, + MethodCallExpression, NodeInterner, NoirFunction, PrefixExpression, Statement, + StatementKind, StructId, + }, + node_interner::{DefinitionKind, DependencyId, ExprId, FuncId, StmtId, TraitId}, + Shared, StructType, Type, TypeVariable, +}; +use crate::{ + ast::{TraitBound, UnresolvedGenerics}, + graph::CrateId, + hir::{ + def_collector::{ + dc_crate::{CollectedItems, DefCollector}, + errors::DefCollectorErrorKind, + }, + def_map::{LocalModuleId, ModuleDefId, ModuleId, MAIN_FUNCTION}, + resolution::{ + errors::PubPosition, + import::{PathResolution, PathResolutionError}, + path_resolver::StandardPathResolver, + }, + Context, + }, + hir_def::function::{FuncMeta, HirFunction}, + macros_api::{Param, Path, UnresolvedType, UnresolvedTypeData, Visibility}, + node_interner::TraitImplId, + token::FunctionAttribute, + Generics, +}; + +mod expressions; +mod patterns; +mod scope; +mod statements; +mod types; + +use fm::FileId; +use iter_extended::vecmap; +use noirc_errors::{Location, Span}; +use regex::Regex; +use rustc_hash::FxHashSet as HashSet; + +/// ResolverMetas are tagged onto each definition to track how many times they are used +#[derive(Debug, PartialEq, Eq)] +pub struct ResolverMeta { + num_times_used: usize, + ident: HirIdent, + warn_if_unused: bool, +} + +type ScopeForest = GenericScopeForest; + +pub struct Elaborator<'context> { + scopes: ScopeForest, + + errors: Vec<(CompilationError, FileId)>, + + interner: &'context mut NodeInterner, + + def_maps: &'context BTreeMap, + + file: FileId, + + in_unconstrained_fn: bool, + nested_loops: usize, + + /// True if the current module is a contract. + /// This is usually determined by self.path_resolver.module_id(), but it can + /// be overridden for impls. Impls are an odd case since the methods within resolve + /// as if they're in the parent module, but should be placed in a child module. + /// Since they should be within a child module, in_contract is manually set to false + /// for these so we can still resolve them in the parent module without them being in a contract. + in_contract: bool, + + /// Contains a mapping of the current struct or functions's generics to + /// unique type variables if we're resolving a struct. Empty otherwise. + /// This is a Vec rather than a map to preserve the order a functions generics + /// were declared in. + generics: Vec<(Rc, TypeVariable, Span)>, + + /// When resolving lambda expressions, we need to keep track of the variables + /// that are captured. We do this in order to create the hidden environment + /// parameter for the lambda function. + lambda_stack: Vec, + + /// Set to the current type if we're resolving an impl + self_type: Option, + + /// The current dependency item we're resolving. + /// Used to link items to their dependencies in the dependency graph + current_item: Option, + + /// If we're currently resolving methods within a trait impl, this will be set + /// to the corresponding trait impl ID. + current_trait_impl: Option, + + trait_id: Option, + + /// In-resolution names + /// + /// This needs to be a set because we can have multiple in-resolution + /// names when resolving structs that are declared in reverse order of their + /// dependencies, such as in the following case: + /// + /// ``` + /// struct Wrapper { + /// value: Wrapped + /// } + /// struct Wrapped { + /// } + /// ``` + resolving_ids: BTreeSet, + + trait_bounds: Vec, + + current_function: Option, + + /// All type variables created in the current function. + /// This map is used to default any integer type variables at the end of + /// a function (before checking trait constraints) if a type wasn't already chosen. + type_variables: Vec, + + /// Trait constraints are collected during type checking until they are + /// verified at the end of a function. This is because constraints arise + /// on each variable, but it is only until function calls when the types + /// needed for the trait constraint may become known. + trait_constraints: Vec<(TraitConstraint, ExprId)>, + + /// The current module this elaborator is in. + /// Initially empty, it is set whenever a new top-level item is resolved. + local_module: LocalModuleId, + + crate_id: CrateId, +} + +impl<'context> Elaborator<'context> { + pub fn new(context: &'context mut Context, crate_id: CrateId) -> Self { + Self { + scopes: ScopeForest::default(), + errors: Vec::new(), + interner: &mut context.def_interner, + def_maps: &context.def_maps, + file: FileId::dummy(), + in_unconstrained_fn: false, + nested_loops: 0, + in_contract: false, + generics: Vec::new(), + lambda_stack: Vec::new(), + self_type: None, + current_item: None, + trait_id: None, + local_module: LocalModuleId::dummy_id(), + crate_id, + resolving_ids: BTreeSet::new(), + trait_bounds: Vec::new(), + current_function: None, + type_variables: Vec::new(), + trait_constraints: Vec::new(), + current_trait_impl: None, + } + } + + pub fn elaborate( + context: &'context mut Context, + crate_id: CrateId, + items: CollectedItems, + ) -> Vec<(CompilationError, FileId)> { + let mut this = Self::new(context, crate_id); + + // the resolver filters literal globals first + for global in items.globals {} + + for alias in items.type_aliases {} + + for trait_ in items.traits {} + + for struct_ in items.types {} + + for trait_impl in &items.trait_impls { + // only collect now + } + + for impl_ in &items.impls { + // only collect now + } + + // resolver resolves non-literal globals here + + for functions in items.functions { + this.file = functions.file_id; + this.trait_id = functions.trait_id; // TODO: Resolve? + for (local_module, id, func) in functions.functions { + this.local_module = local_module; + this.elaborate_function(func, id); + } + } + + for impl_ in items.impls {} + + for trait_impl in items.trait_impls {} + + let cycle_errors = this.interner.check_for_dependency_cycles(); + this.errors.extend(cycle_errors); + + this.errors + } + + fn elaborate_function(&mut self, mut function: NoirFunction, id: FuncId) { + self.current_function = Some(id); + self.resolve_where_clause(&mut function.def.where_clause); + + // Without this, impl methods can accidentally be placed in contracts. See #3254 + if self.self_type.is_some() { + self.in_contract = false; + } + + self.scopes.start_function(); + self.current_item = Some(DependencyId::Function(id)); + + // Check whether the function has globals in the local module and add them to the scope + self.resolve_local_globals(); + self.add_generics(&function.def.generics); + + self.desugar_impl_trait_args(&mut function, id); + self.trait_bounds = function.def.where_clause.clone(); + + let is_low_level_or_oracle = function + .attributes() + .function + .as_ref() + .map_or(false, |func| func.is_low_level() || func.is_oracle()); + + if function.def.is_unconstrained { + self.in_unconstrained_fn = true; + } + + let func_meta = self.extract_meta(&function, id); + + self.add_trait_constraints_to_scope(&func_meta); + + let (hir_func, body_type) = match function.kind { + FunctionKind::Builtin | FunctionKind::LowLevel | FunctionKind::Oracle => { + (HirFunction::empty(), Type::Error) + } + FunctionKind::Normal | FunctionKind::Recursive => { + let block_span = function.def.span; + let (block, body_type) = self.elaborate_block(function.def.body); + let expr_id = self.intern_expr(block, block_span); + self.interner.push_expr_type(expr_id, body_type.clone()); + (HirFunction::unchecked_from_expr(expr_id), body_type) + } + }; + + if !func_meta.can_ignore_return_type() { + self.type_check_function_body(body_type, &func_meta, hir_func.as_expr()); + } + + // Default any type variables that still need defaulting. + // This is done before trait impl search since leaving them bindable can lead to errors + // when multiple impls are available. Instead we default first to choose the Field or u64 impl. + for typ in &self.type_variables { + if let Type::TypeVariable(variable, kind) = typ.follow_bindings() { + let msg = "TypeChecker should only track defaultable type vars"; + variable.bind(kind.default_type().expect(msg)); + } + } + + // Verify any remaining trait constraints arising from the function body + for (constraint, expr_id) in std::mem::take(&mut self.trait_constraints) { + let span = self.interner.expr_span(&expr_id); + self.verify_trait_constraint( + &constraint.typ, + constraint.trait_id, + &constraint.trait_generics, + expr_id, + span, + ); + } + + // Now remove all the `where` clause constraints we added + for constraint in &func_meta.trait_constraints { + self.interner.remove_assumed_trait_implementations_for_trait(constraint.trait_id); + } + + let func_scope_tree = self.scopes.end_function(); + + // The arguments to low-level and oracle functions are always unused so we do not produce warnings for them. + if !is_low_level_or_oracle { + self.check_for_unused_variables_in_scope_tree(func_scope_tree); + } + + self.trait_bounds.clear(); + + self.interner.push_fn_meta(func_meta, id); + self.interner.update_fn(id, hir_func); + self.current_function = None; + } + + /// This turns function parameters of the form: + /// fn foo(x: impl Bar) + /// + /// into + /// fn foo(x: T0_impl_Bar) where T0_impl_Bar: Bar + fn desugar_impl_trait_args(&mut self, func: &mut NoirFunction, func_id: FuncId) { + let mut impl_trait_generics = HashSet::default(); + let mut counter: usize = 0; + for parameter in func.def.parameters.iter_mut() { + if let UnresolvedTypeData::TraitAsType(path, args) = ¶meter.typ.typ { + let mut new_generic_ident: Ident = + format!("T{}_impl_{}", func_id, path.as_string()).into(); + let mut new_generic_path = Path::from_ident(new_generic_ident.clone()); + while impl_trait_generics.contains(&new_generic_ident) + || self.lookup_generic_or_global_type(&new_generic_path).is_some() + { + new_generic_ident = + format!("T{}_impl_{}_{}", func_id, path.as_string(), counter).into(); + new_generic_path = Path::from_ident(new_generic_ident.clone()); + counter += 1; + } + impl_trait_generics.insert(new_generic_ident.clone()); + + let is_synthesized = true; + let new_generic_type_data = + UnresolvedTypeData::Named(new_generic_path, vec![], is_synthesized); + let new_generic_type = + UnresolvedType { typ: new_generic_type_data.clone(), span: None }; + let new_trait_bound = TraitBound { + trait_path: path.clone(), + trait_id: None, + trait_generics: args.to_vec(), + }; + let new_trait_constraint = UnresolvedTraitConstraint { + typ: new_generic_type, + trait_bound: new_trait_bound, + }; + + parameter.typ.typ = new_generic_type_data; + func.def.generics.push(new_generic_ident); + func.def.where_clause.push(new_trait_constraint); + } + } + self.add_generics(&impl_trait_generics.into_iter().collect()); + } + + /// Add the given generics to scope. + /// Each generic will have a fresh Shared associated with it. + pub fn add_generics(&mut self, generics: &UnresolvedGenerics) -> Generics { + vecmap(generics, |generic| { + // Map the generic to a fresh type variable + let id = self.interner.next_type_variable_id(); + let typevar = TypeVariable::unbound(id); + let span = generic.0.span(); + + // Check for name collisions of this generic + let name = Rc::new(generic.0.contents.clone()); + + if let Some((_, _, first_span)) = self.find_generic(&name) { + self.push_err(ResolverError::DuplicateDefinition { + name: generic.0.contents.clone(), + first_span: *first_span, + second_span: span, + }); + } else { + self.generics.push((name, typevar.clone(), span)); + } + + typevar + }) + } + + fn push_err(&mut self, error: impl Into) { + self.errors.push((error.into(), self.file)); + } + + fn resolve_where_clause(&mut self, clause: &mut [UnresolvedTraitConstraint]) { + for bound in clause { + if let Some(trait_id) = self.resolve_trait_by_path(bound.trait_bound.trait_path.clone()) + { + bound.trait_bound.trait_id = Some(trait_id); + } + } + } + + fn resolve_trait_by_path(&mut self, path: Path) -> Option { + let path_resolver = StandardPathResolver::new(self.module_id()); + + let error = match path_resolver.resolve(self.def_maps, path.clone()) { + Ok(PathResolution { module_def_id: ModuleDefId::TraitId(trait_id), error }) => { + if let Some(error) = error { + self.push_err(error); + } + return Some(trait_id); + } + Ok(_) => DefCollectorErrorKind::NotATrait { not_a_trait_name: path }, + Err(_) => DefCollectorErrorKind::TraitNotFound { trait_path: path }, + }; + self.push_err(error); + None + } + + fn resolve_local_globals(&mut self) { + let globals = vecmap(self.interner.get_all_globals(), |global| { + (global.id, global.local_id, global.ident.clone()) + }); + for (id, local_module_id, name) in globals { + if local_module_id == self.local_module { + let definition = DefinitionKind::Global(id); + self.add_global_variable_decl(name, definition); + } + } + } + + /// TODO: This is currently only respected for generic free functions + /// there's a bunch of other places where trait constraints can pop up + fn resolve_trait_constraints( + &mut self, + where_clause: &[UnresolvedTraitConstraint], + ) -> Vec { + where_clause + .iter() + .cloned() + .filter_map(|constraint| self.resolve_trait_constraint(constraint)) + .collect() + } + + pub fn resolve_trait_constraint( + &mut self, + constraint: UnresolvedTraitConstraint, + ) -> Option { + let typ = self.resolve_type(constraint.typ); + let trait_generics = + vecmap(constraint.trait_bound.trait_generics, |typ| self.resolve_type(typ)); + + let span = constraint.trait_bound.trait_path.span(); + let the_trait = self.lookup_trait_or_error(constraint.trait_bound.trait_path)?; + let trait_id = the_trait.id; + + let expected_generics = the_trait.generics.len(); + let actual_generics = trait_generics.len(); + + if actual_generics != expected_generics { + let item_name = the_trait.name.to_string(); + self.push_err(ResolverError::IncorrectGenericCount { + span, + item_name, + actual: actual_generics, + expected: expected_generics, + }); + } + + Some(TraitConstraint { typ, trait_id, trait_generics }) + } + + /// Extract metadata from a NoirFunction + /// to be used in analysis and intern the function parameters + /// Prerequisite: self.add_generics() has already been called with the given + /// function's generics, including any generics from the impl, if any. + fn extract_meta(&mut self, func: &NoirFunction, func_id: FuncId) -> FuncMeta { + let location = Location::new(func.name_ident().span(), self.file); + let id = self.interner.function_definition_id(func_id); + let name_ident = HirIdent::non_trait_method(id, location); + + let attributes = func.attributes().clone(); + let has_no_predicates_attribute = attributes.is_no_predicates(); + let should_fold = attributes.is_foldable(); + if !self.inline_attribute_allowed(func) { + if has_no_predicates_attribute { + self.push_err(ResolverError::NoPredicatesAttributeOnUnconstrained { + ident: func.name_ident().clone(), + }); + } else if should_fold { + self.push_err(ResolverError::FoldAttributeOnUnconstrained { + ident: func.name_ident().clone(), + }); + } + } + // Both the #[fold] and #[no_predicates] alter a function's inline type and code generation in similar ways. + // In certain cases such as type checking (for which the following flag will be used) both attributes + // indicate we should code generate in the same way. Thus, we unify the attributes into one flag here. + let has_inline_attribute = has_no_predicates_attribute || should_fold; + let is_entry_point = self.is_entry_point_function(func); + + let mut generics = vecmap(&self.generics, |(_, typevar, _)| typevar.clone()); + let mut parameters = vec![]; + let mut parameter_types = vec![]; + + for Param { visibility, pattern, typ, span: _ } in func.parameters().iter().cloned() { + if visibility == Visibility::Public && !self.pub_allowed(func) { + self.push_err(ResolverError::UnnecessaryPub { + ident: func.name_ident().clone(), + position: PubPosition::Parameter, + }); + } + + let type_span = typ.span.unwrap_or_else(|| pattern.span()); + let typ = self.resolve_type_inner(typ, &mut generics); + self.check_if_type_is_valid_for_program_input( + &typ, + is_entry_point, + has_inline_attribute, + type_span, + ); + let pattern = self.elaborate_pattern(pattern, typ.clone(), DefinitionKind::Local(None)); + + parameters.push((pattern, typ.clone(), visibility)); + parameter_types.push(typ); + } + + let return_type = Box::new(self.resolve_type(func.return_type())); + + self.declare_numeric_generics(¶meter_types, &return_type); + + if !self.pub_allowed(func) && func.def.return_visibility == Visibility::Public { + self.push_err(ResolverError::UnnecessaryPub { + ident: func.name_ident().clone(), + position: PubPosition::ReturnType, + }); + } + + let is_low_level_function = + attributes.function.as_ref().map_or(false, |func| func.is_low_level()); + + if !self.crate_id.is_stdlib() && is_low_level_function { + let error = + ResolverError::LowLevelFunctionOutsideOfStdlib { ident: func.name_ident().clone() }; + self.push_err(error); + } + + // 'pub' is required on return types for entry point functions + if is_entry_point + && return_type.as_ref() != &Type::Unit + && func.def.return_visibility == Visibility::Private + { + self.push_err(ResolverError::NecessaryPub { ident: func.name_ident().clone() }); + } + // '#[recursive]' attribute is only allowed for entry point functions + if !is_entry_point && func.kind == FunctionKind::Recursive { + self.push_err(ResolverError::MisplacedRecursiveAttribute { + ident: func.name_ident().clone(), + }); + } + + if matches!(attributes.function, Some(FunctionAttribute::Test { .. })) + && !parameters.is_empty() + { + self.push_err(ResolverError::TestFunctionHasParameters { + span: func.name_ident().span(), + }); + } + + let mut typ = Type::Function(parameter_types, return_type, Box::new(Type::Unit)); + + if !generics.is_empty() { + typ = Type::Forall(generics, Box::new(typ)); + } + + self.interner.push_definition_type(name_ident.id, typ.clone()); + + let direct_generics = func.def.generics.iter(); + let direct_generics = direct_generics + .filter_map(|generic| self.find_generic(&generic.0.contents)) + .map(|(name, typevar, _span)| (name.clone(), typevar.clone())) + .collect(); + + FuncMeta { + name: name_ident, + kind: func.kind, + location, + typ, + direct_generics, + trait_impl: self.current_trait_impl, + parameters: parameters.into(), + return_type: func.def.return_type.clone(), + return_visibility: func.def.return_visibility, + has_body: !func.def.body.is_empty(), + trait_constraints: self.resolve_trait_constraints(&func.def.where_clause), + is_entry_point, + has_inline_attribute, + } + } + + /// Only sized types are valid to be used as main's parameters or the parameters to a contract + /// function. If the given type is not sized (e.g. contains a slice or NamedGeneric type), an + /// error is issued. + fn check_if_type_is_valid_for_program_input( + &mut self, + typ: &Type, + is_entry_point: bool, + has_inline_attribute: bool, + span: Span, + ) { + if (is_entry_point && !typ.is_valid_for_program_input()) + || (has_inline_attribute && !typ.is_valid_non_inlined_function_input()) + { + self.push_err(TypeCheckError::InvalidTypeForEntryPoint { span }); + } + } + + fn inline_attribute_allowed(&self, func: &NoirFunction) -> bool { + // Inline attributes are only relevant for constrained functions + // as all unconstrained functions are not inlined + !func.def.is_unconstrained + } + + /// True if the 'pub' keyword is allowed on parameters in this function + /// 'pub' on function parameters is only allowed for entry point functions + fn pub_allowed(&self, func: &NoirFunction) -> bool { + self.is_entry_point_function(func) || func.attributes().is_foldable() + } + + fn is_entry_point_function(&self, func: &NoirFunction) -> bool { + if self.in_contract { + func.attributes().is_contract_entry_point() + } else { + func.name() == MAIN_FUNCTION + } + } + + fn declare_numeric_generics(&mut self, params: &[Type], return_type: &Type) { + if self.generics.is_empty() { + return; + } + + for (name_to_find, type_variable) in Self::find_numeric_generics(params, return_type) { + // Declare any generics to let users use numeric generics in scope. + // Don't issue a warning if these are unused + // + // We can fail to find the generic in self.generics if it is an implicit one created + // by the compiler. This can happen when, e.g. eliding array lengths using the slice + // syntax [T]. + if let Some((name, _, span)) = + self.generics.iter().find(|(name, _, _)| name.as_ref() == &name_to_find) + { + let ident = Ident::new(name.to_string(), *span); + let definition = DefinitionKind::GenericType(type_variable); + self.add_variable_decl_inner(ident, false, false, false, definition); + } + } + } + + fn find_numeric_generics( + parameters: &[Type], + return_type: &Type, + ) -> Vec<(String, TypeVariable)> { + let mut found = BTreeMap::new(); + for parameter in parameters { + Self::find_numeric_generics_in_type(parameter, &mut found); + } + Self::find_numeric_generics_in_type(return_type, &mut found); + found.into_iter().collect() + } + + fn find_numeric_generics_in_type(typ: &Type, found: &mut BTreeMap) { + match typ { + Type::FieldElement + | Type::Integer(_, _) + | Type::Bool + | Type::Unit + | Type::Error + | Type::TypeVariable(_, _) + | Type::Constant(_) + | Type::NamedGeneric(_, _) + | Type::Code + | Type::Forall(_, _) => (), + + Type::TraitAsType(_, _, args) => { + for arg in args { + Self::find_numeric_generics_in_type(arg, found); + } + } + + Type::Array(length, element_type) => { + if let Type::NamedGeneric(type_variable, name) = length.as_ref() { + found.insert(name.to_string(), type_variable.clone()); + } + Self::find_numeric_generics_in_type(element_type, found); + } + + Type::Slice(element_type) => { + Self::find_numeric_generics_in_type(element_type, found); + } + + Type::Tuple(fields) => { + for field in fields { + Self::find_numeric_generics_in_type(field, found); + } + } + + Type::Function(parameters, return_type, _env) => { + for parameter in parameters { + Self::find_numeric_generics_in_type(parameter, found); + } + Self::find_numeric_generics_in_type(return_type, found); + } + + Type::Struct(struct_type, generics) => { + for (i, generic) in generics.iter().enumerate() { + if let Type::NamedGeneric(type_variable, name) = generic { + if struct_type.borrow().generic_is_numeric(i) { + found.insert(name.to_string(), type_variable.clone()); + } + } else { + Self::find_numeric_generics_in_type(generic, found); + } + } + } + Type::Alias(alias, generics) => { + for (i, generic) in generics.iter().enumerate() { + if let Type::NamedGeneric(type_variable, name) = generic { + if alias.borrow().generic_is_numeric(i) { + found.insert(name.to_string(), type_variable.clone()); + } + } else { + Self::find_numeric_generics_in_type(generic, found); + } + } + } + Type::MutableReference(element) => Self::find_numeric_generics_in_type(element, found), + Type::String(length) => { + if let Type::NamedGeneric(type_variable, name) = length.as_ref() { + found.insert(name.to_string(), type_variable.clone()); + } + } + Type::FmtString(length, fields) => { + if let Type::NamedGeneric(type_variable, name) = length.as_ref() { + found.insert(name.to_string(), type_variable.clone()); + } + Self::find_numeric_generics_in_type(fields, found); + } + } + } + + fn add_trait_constraints_to_scope(&mut self, func_meta: &FuncMeta) { + for constraint in &func_meta.trait_constraints { + let object = constraint.typ.clone(); + let trait_id = constraint.trait_id; + let generics = constraint.trait_generics.clone(); + + if !self.interner.add_assumed_trait_implementation(object, trait_id, generics) { + if let Some(the_trait) = self.interner.try_get_trait(trait_id) { + let trait_name = the_trait.name.to_string(); + let typ = constraint.typ.clone(); + let span = func_meta.location.span; + self.push_err(TypeCheckError::UnneededTraitConstraint { + trait_name, + typ, + span, + }); + } + } + } + } +} diff --git a/noir/noir-repo/compiler/noirc_frontend/src/elaborator/patterns.rs b/noir/noir-repo/compiler/noirc_frontend/src/elaborator/patterns.rs new file mode 100644 index 00000000000..195d37878f1 --- /dev/null +++ b/noir/noir-repo/compiler/noirc_frontend/src/elaborator/patterns.rs @@ -0,0 +1,465 @@ +use iter_extended::vecmap; +use noirc_errors::{Location, Span}; +use rustc_hash::FxHashSet as HashSet; + +use crate::{ + ast::ERROR_IDENT, + hir::{ + resolution::errors::ResolverError, + type_check::{Source, TypeCheckError}, + }, + hir_def::{ + expr::{HirIdent, ImplKind}, + stmt::HirPattern, + }, + macros_api::{HirExpression, Ident, Path, Pattern}, + node_interner::{DefinitionId, DefinitionKind, ExprId, TraitImplKind}, + Shared, StructType, Type, TypeBindings, +}; + +use super::{Elaborator, ResolverMeta}; + +impl<'context> Elaborator<'context> { + pub(super) fn elaborate_pattern( + &mut self, + pattern: Pattern, + expected_type: Type, + definition_kind: DefinitionKind, + ) -> HirPattern { + self.elaborate_pattern_mut(pattern, expected_type, definition_kind, None) + } + + fn elaborate_pattern_mut( + &mut self, + pattern: Pattern, + expected_type: Type, + definition: DefinitionKind, + mutable: Option, + ) -> HirPattern { + match pattern { + Pattern::Identifier(name) => { + // If this definition is mutable, do not store the rhs because it will + // not always refer to the correct value of the variable + let definition = match (mutable, definition) { + (Some(_), DefinitionKind::Local(_)) => DefinitionKind::Local(None), + (_, other) => other, + }; + let ident = self.add_variable_decl(name, mutable.is_some(), true, definition); + self.interner.push_definition_type(ident.id, expected_type); + HirPattern::Identifier(ident) + } + Pattern::Mutable(pattern, span, _) => { + if let Some(first_mut) = mutable { + self.push_err(ResolverError::UnnecessaryMut { first_mut, second_mut: span }); + } + + let pattern = + self.elaborate_pattern_mut(*pattern, expected_type, definition, Some(span)); + let location = Location::new(span, self.file); + HirPattern::Mutable(Box::new(pattern), location) + } + Pattern::Tuple(fields, span) => { + let field_types = match expected_type { + Type::Tuple(fields) => fields, + Type::Error => Vec::new(), + expected_type => { + let tuple = + Type::Tuple(vecmap(&fields, |_| self.interner.next_type_variable())); + + self.push_err(TypeCheckError::TypeMismatchWithSource { + expected: expected_type, + actual: tuple, + span, + source: Source::Assignment, + }); + Vec::new() + } + }; + + let fields = vecmap(fields.into_iter().enumerate(), |(i, field)| { + let field_type = field_types.get(i).cloned().unwrap_or(Type::Error); + self.elaborate_pattern_mut(field, field_type, definition.clone(), mutable) + }); + let location = Location::new(span, self.file); + HirPattern::Tuple(fields, location) + } + Pattern::Struct(name, fields, span) => self.elaborate_struct_pattern( + name, + fields, + span, + expected_type, + definition, + mutable, + ), + } + } + + fn elaborate_struct_pattern( + &mut self, + name: Path, + fields: Vec<(Ident, Pattern)>, + span: Span, + expected_type: Type, + definition: DefinitionKind, + mutable: Option, + ) -> HirPattern { + let error_identifier = |this: &mut Self| { + // Must create a name here to return a HirPattern::Identifier. Allowing + // shadowing here lets us avoid further errors if we define ERROR_IDENT + // multiple times. + let name = ERROR_IDENT.into(); + let identifier = this.add_variable_decl(name, false, true, definition.clone()); + HirPattern::Identifier(identifier) + }; + + let (struct_type, generics) = match self.lookup_type_or_error(name) { + Some(Type::Struct(struct_type, generics)) => (struct_type, generics), + None => return error_identifier(self), + Some(typ) => { + self.push_err(ResolverError::NonStructUsedInConstructor { typ, span }); + return error_identifier(self); + } + }; + + let actual_type = Type::Struct(struct_type.clone(), generics); + let location = Location::new(span, self.file); + + self.unify(&actual_type, &expected_type, || TypeCheckError::TypeMismatchWithSource { + expected: expected_type.clone(), + actual: actual_type.clone(), + span: location.span, + source: Source::Assignment, + }); + + let typ = struct_type.clone(); + let fields = self.resolve_constructor_pattern_fields( + typ, + fields, + span, + expected_type.clone(), + definition, + mutable, + ); + + HirPattern::Struct(expected_type, fields, location) + } + + /// Resolve all the fields of a struct constructor expression. + /// Ensures all fields are present, none are repeated, and all + /// are part of the struct. + fn resolve_constructor_pattern_fields( + &mut self, + struct_type: Shared, + fields: Vec<(Ident, Pattern)>, + span: Span, + expected_type: Type, + definition: DefinitionKind, + mutable: Option, + ) -> Vec<(Ident, HirPattern)> { + let mut ret = Vec::with_capacity(fields.len()); + let mut seen_fields = HashSet::default(); + let mut unseen_fields = struct_type.borrow().field_names(); + + for (field, pattern) in fields { + let field_type = expected_type.get_field_type(&field.0.contents).unwrap_or(Type::Error); + let resolved = + self.elaborate_pattern_mut(pattern, field_type, definition.clone(), mutable); + + if unseen_fields.contains(&field) { + unseen_fields.remove(&field); + seen_fields.insert(field.clone()); + } else if seen_fields.contains(&field) { + // duplicate field + self.push_err(ResolverError::DuplicateField { field: field.clone() }); + } else { + // field not required by struct + self.push_err(ResolverError::NoSuchField { + field: field.clone(), + struct_definition: struct_type.borrow().name.clone(), + }); + } + + ret.push((field, resolved)); + } + + if !unseen_fields.is_empty() { + self.push_err(ResolverError::MissingFields { + span, + missing_fields: unseen_fields.into_iter().map(|field| field.to_string()).collect(), + struct_definition: struct_type.borrow().name.clone(), + }); + } + + ret + } + + pub(super) fn add_variable_decl( + &mut self, + name: Ident, + mutable: bool, + allow_shadowing: bool, + definition: DefinitionKind, + ) -> HirIdent { + self.add_variable_decl_inner(name, mutable, allow_shadowing, true, definition) + } + + pub fn add_variable_decl_inner( + &mut self, + name: Ident, + mutable: bool, + allow_shadowing: bool, + warn_if_unused: bool, + definition: DefinitionKind, + ) -> HirIdent { + if definition.is_global() { + return self.add_global_variable_decl(name, definition); + } + + let location = Location::new(name.span(), self.file); + let id = + self.interner.push_definition(name.0.contents.clone(), mutable, definition, location); + let ident = HirIdent::non_trait_method(id, location); + let resolver_meta = + ResolverMeta { num_times_used: 0, ident: ident.clone(), warn_if_unused }; + + let scope = self.scopes.get_mut_scope(); + let old_value = scope.add_key_value(name.0.contents.clone(), resolver_meta); + + if !allow_shadowing { + if let Some(old_value) = old_value { + self.push_err(ResolverError::DuplicateDefinition { + name: name.0.contents, + first_span: old_value.ident.location.span, + second_span: location.span, + }); + } + } + + ident + } + + pub fn add_global_variable_decl( + &mut self, + name: Ident, + definition: DefinitionKind, + ) -> HirIdent { + let scope = self.scopes.get_mut_scope(); + + // This check is necessary to maintain the same definition ids in the interner. Currently, each function uses a new resolver that has its own ScopeForest and thus global scope. + // We must first check whether an existing definition ID has been inserted as otherwise there will be multiple definitions for the same global statement. + // This leads to an error in evaluation where the wrong definition ID is selected when evaluating a statement using the global. The check below prevents this error. + let mut global_id = None; + let global = self.interner.get_all_globals(); + for global_info in global { + if global_info.ident == name && global_info.local_id == self.local_module { + global_id = Some(global_info.id); + } + } + + let (ident, resolver_meta) = if let Some(id) = global_id { + let global = self.interner.get_global(id); + let hir_ident = HirIdent::non_trait_method(global.definition_id, global.location); + let ident = hir_ident.clone(); + let resolver_meta = ResolverMeta { num_times_used: 0, ident, warn_if_unused: true }; + (hir_ident, resolver_meta) + } else { + let location = Location::new(name.span(), self.file); + let id = + self.interner.push_definition(name.0.contents.clone(), false, definition, location); + let ident = HirIdent::non_trait_method(id, location); + let resolver_meta = + ResolverMeta { num_times_used: 0, ident: ident.clone(), warn_if_unused: true }; + (ident, resolver_meta) + }; + + let old_global_value = scope.add_key_value(name.0.contents.clone(), resolver_meta); + if let Some(old_global_value) = old_global_value { + self.push_err(ResolverError::DuplicateDefinition { + name: name.0.contents.clone(), + first_span: old_global_value.ident.location.span, + second_span: name.span(), + }); + } + ident + } + + // Checks for a variable having been declared before. + // (Variable declaration and definition cannot be separate in Noir.) + // Once the variable has been found, intern and link `name` to this definition, + // returning (the ident, the IdentId of `name`) + // + // If a variable is not found, then an error is logged and a dummy id + // is returned, for better error reporting UX + pub(super) fn find_variable_or_default(&mut self, name: &Ident) -> (HirIdent, usize) { + self.use_variable(name).unwrap_or_else(|error| { + self.push_err(error); + let id = DefinitionId::dummy_id(); + let location = Location::new(name.span(), self.file); + (HirIdent::non_trait_method(id, location), 0) + }) + } + + /// Lookup and use the specified variable. + /// This will increment its use counter by one and return the variable if found. + /// If the variable is not found, an error is returned. + pub(super) fn use_variable( + &mut self, + name: &Ident, + ) -> Result<(HirIdent, usize), ResolverError> { + // Find the definition for this Ident + let scope_tree = self.scopes.current_scope_tree(); + let variable = scope_tree.find(&name.0.contents); + + let location = Location::new(name.span(), self.file); + if let Some((variable_found, scope)) = variable { + variable_found.num_times_used += 1; + let id = variable_found.ident.id; + Ok((HirIdent::non_trait_method(id, location), scope)) + } else { + Err(ResolverError::VariableNotDeclared { + name: name.0.contents.clone(), + span: name.0.span(), + }) + } + } + + pub(super) fn elaborate_variable(&mut self, variable: Path) -> (ExprId, Type) { + let span = variable.span; + let expr = self.resolve_variable(variable); + let id = self.interner.push_expr(HirExpression::Ident(expr.clone())); + self.interner.push_expr_location(id, span, self.file); + let typ = self.type_check_variable(expr, id); + self.interner.push_expr_type(id, typ.clone()); + (id, typ) + } + + fn resolve_variable(&mut self, path: Path) -> HirIdent { + if let Some((method, constraint, assumed)) = self.resolve_trait_generic_path(&path) { + HirIdent { + location: Location::new(path.span, self.file), + id: self.interner.trait_method_id(method), + impl_kind: ImplKind::TraitMethod(method, constraint, assumed), + } + } else { + // If the Path is being used as an Expression, then it is referring to a global from a separate module + // Otherwise, then it is referring to an Identifier + // This lookup allows support of such statements: let x = foo::bar::SOME_GLOBAL + 10; + // If the expression is a singular indent, we search the resolver's current scope as normal. + let (hir_ident, var_scope_index) = self.get_ident_from_path(path); + + if hir_ident.id != DefinitionId::dummy_id() { + match self.interner.definition(hir_ident.id).kind { + DefinitionKind::Function(id) => { + if let Some(current_item) = self.current_item { + self.interner.add_function_dependency(current_item, id); + } + } + DefinitionKind::Global(global_id) => { + if let Some(current_item) = self.current_item { + self.interner.add_global_dependency(current_item, global_id); + } + } + DefinitionKind::GenericType(_) => { + // Initialize numeric generics to a polymorphic integer type in case + // they're used in expressions. We must do this here since type_check_variable + // does not check definition kinds and otherwise expects parameters to + // already be typed. + if self.interner.definition_type(hir_ident.id) == Type::Error { + let typ = Type::polymorphic_integer_or_field(self.interner); + self.interner.push_definition_type(hir_ident.id, typ); + } + } + DefinitionKind::Local(_) => { + // only local variables can be captured by closures. + self.resolve_local_variable(hir_ident.clone(), var_scope_index); + } + } + } + + hir_ident + } + } + + pub(super) fn type_check_variable(&mut self, ident: HirIdent, expr_id: ExprId) -> Type { + let mut bindings = TypeBindings::new(); + + // Add type bindings from any constraints that were used. + // We need to do this first since otherwise instantiating the type below + // will replace each trait generic with a fresh type variable, rather than + // the type used in the trait constraint (if it exists). See #4088. + if let ImplKind::TraitMethod(_, constraint, _) = &ident.impl_kind { + let the_trait = self.interner.get_trait(constraint.trait_id); + assert_eq!(the_trait.generics.len(), constraint.trait_generics.len()); + + for (param, arg) in the_trait.generics.iter().zip(&constraint.trait_generics) { + // Avoid binding t = t + if !arg.occurs(param.id()) { + bindings.insert(param.id(), (param.clone(), arg.clone())); + } + } + } + + // An identifiers type may be forall-quantified in the case of generic functions. + // E.g. `fn foo(t: T, field: Field) -> T` has type `forall T. fn(T, Field) -> T`. + // We must instantiate identifiers at every call site to replace this T with a new type + // variable to handle generic functions. + let t = self.interner.id_type_substitute_trait_as_type(ident.id); + + // This instantiates a trait's generics as well which need to be set + // when the constraint below is later solved for when the function is + // finished. How to link the two? + let (typ, bindings) = t.instantiate_with_bindings(bindings, self.interner); + + // Push any trait constraints required by this definition to the context + // to be checked later when the type of this variable is further constrained. + if let Some(definition) = self.interner.try_definition(ident.id) { + if let DefinitionKind::Function(function) = definition.kind { + let function = self.interner.function_meta(&function); + + for mut constraint in function.trait_constraints.clone() { + constraint.apply_bindings(&bindings); + self.trait_constraints.push((constraint, expr_id)); + } + } + } + + if let ImplKind::TraitMethod(_, mut constraint, assumed) = ident.impl_kind { + constraint.apply_bindings(&bindings); + if assumed { + let trait_impl = TraitImplKind::Assumed { + object_type: constraint.typ, + trait_generics: constraint.trait_generics, + }; + self.interner.select_impl_for_expression(expr_id, trait_impl); + } else { + // Currently only one impl can be selected per expr_id, so this + // constraint needs to be pushed after any other constraints so + // that monomorphization can resolve this trait method to the correct impl. + self.trait_constraints.push((constraint, expr_id)); + } + } + + self.interner.store_instantiation_bindings(expr_id, bindings); + typ + } + + fn get_ident_from_path(&mut self, path: Path) -> (HirIdent, usize) { + let location = Location::new(path.span(), self.file); + + let error = match path.as_ident().map(|ident| self.use_variable(ident)) { + Some(Ok(found)) => return found, + // Try to look it up as a global, but still issue the first error if we fail + Some(Err(error)) => match self.lookup_global(path) { + Ok(id) => return (HirIdent::non_trait_method(id, location), 0), + Err(_) => error, + }, + None => match self.lookup_global(path) { + Ok(id) => return (HirIdent::non_trait_method(id, location), 0), + Err(error) => error, + }, + }; + self.push_err(error); + let id = DefinitionId::dummy_id(); + (HirIdent::non_trait_method(id, location), 0) + } +} diff --git a/noir/noir-repo/compiler/noirc_frontend/src/elaborator/scope.rs b/noir/noir-repo/compiler/noirc_frontend/src/elaborator/scope.rs new file mode 100644 index 00000000000..cf10dbbc2b2 --- /dev/null +++ b/noir/noir-repo/compiler/noirc_frontend/src/elaborator/scope.rs @@ -0,0 +1,200 @@ +use noirc_errors::Spanned; +use rustc_hash::FxHashMap as HashMap; + +use crate::ast::ERROR_IDENT; +use crate::hir::comptime::Value; +use crate::hir::def_map::{LocalModuleId, ModuleId}; +use crate::hir::resolution::path_resolver::{PathResolver, StandardPathResolver}; +use crate::hir::resolution::resolver::SELF_TYPE_NAME; +use crate::hir::scope::{Scope as GenericScope, ScopeTree as GenericScopeTree}; +use crate::macros_api::Ident; +use crate::{ + hir::{ + def_map::{ModuleDefId, TryFromModuleDefId}, + resolution::errors::ResolverError, + }, + hir_def::{ + expr::{HirCapturedVar, HirIdent}, + traits::Trait, + }, + macros_api::{Path, StructId}, + node_interner::{DefinitionId, TraitId, TypeAliasId}, + Shared, StructType, +}; +use crate::{Type, TypeAlias}; + +use super::{Elaborator, ResolverMeta}; + +type Scope = GenericScope; +type ScopeTree = GenericScopeTree; + +impl<'context> Elaborator<'context> { + pub(super) fn lookup(&mut self, path: Path) -> Result { + let span = path.span(); + let id = self.resolve_path(path)?; + T::try_from(id).ok_or_else(|| ResolverError::Expected { + expected: T::description(), + got: id.as_str().to_owned(), + span, + }) + } + + pub(super) fn module_id(&self) -> ModuleId { + assert_ne!(self.local_module, LocalModuleId::dummy_id(), "local_module is unset"); + ModuleId { krate: self.crate_id, local_id: self.local_module } + } + + pub(super) fn resolve_path(&mut self, path: Path) -> Result { + let resolver = StandardPathResolver::new(self.module_id()); + let path_resolution = resolver.resolve(self.def_maps, path)?; + + if let Some(error) = path_resolution.error { + self.push_err(error); + } + + Ok(path_resolution.module_def_id) + } + + pub(super) fn get_struct(&self, type_id: StructId) -> Shared { + self.interner.get_struct(type_id) + } + + pub(super) fn get_trait_mut(&mut self, trait_id: TraitId) -> &mut Trait { + self.interner.get_trait_mut(trait_id) + } + + pub(super) fn resolve_local_variable(&mut self, hir_ident: HirIdent, var_scope_index: usize) { + let mut transitive_capture_index: Option = None; + + for lambda_index in 0..self.lambda_stack.len() { + if self.lambda_stack[lambda_index].scope_index > var_scope_index { + // Beware: the same variable may be captured multiple times, so we check + // for its presence before adding the capture below. + let position = self.lambda_stack[lambda_index] + .captures + .iter() + .position(|capture| capture.ident.id == hir_ident.id); + + if position.is_none() { + self.lambda_stack[lambda_index].captures.push(HirCapturedVar { + ident: hir_ident.clone(), + transitive_capture_index, + }); + } + + if lambda_index + 1 < self.lambda_stack.len() { + // There is more than one closure between the current scope and + // the scope of the variable, so this is a propagated capture. + // We need to track the transitive capture index as we go up in + // the closure stack. + transitive_capture_index = Some(position.unwrap_or( + // If this was a fresh capture, we added it to the end of + // the captures vector: + self.lambda_stack[lambda_index].captures.len() - 1, + )); + } + } + } + } + + pub(super) fn lookup_global(&mut self, path: Path) -> Result { + let span = path.span(); + let id = self.resolve_path(path)?; + + if let Some(function) = TryFromModuleDefId::try_from(id) { + return Ok(self.interner.function_definition_id(function)); + } + + if let Some(global) = TryFromModuleDefId::try_from(id) { + let global = self.interner.get_global(global); + return Ok(global.definition_id); + } + + let expected = "global variable".into(); + let got = "local variable".into(); + Err(ResolverError::Expected { span, expected, got }) + } + + pub fn push_scope(&mut self) { + self.scopes.start_scope(); + } + + pub fn pop_scope(&mut self) { + let scope = self.scopes.end_scope(); + self.check_for_unused_variables_in_scope_tree(scope.into()); + } + + pub fn check_for_unused_variables_in_scope_tree(&mut self, scope_decls: ScopeTree) { + let mut unused_vars = Vec::new(); + for scope in scope_decls.0.into_iter() { + Self::check_for_unused_variables_in_local_scope(scope, &mut unused_vars); + } + + for unused_var in unused_vars.iter() { + if let Some(definition_info) = self.interner.try_definition(unused_var.id) { + let name = &definition_info.name; + if name != ERROR_IDENT && !definition_info.is_global() { + let ident = Ident(Spanned::from(unused_var.location.span, name.to_owned())); + self.push_err(ResolverError::UnusedVariable { ident }); + } + } + } + } + + fn check_for_unused_variables_in_local_scope(decl_map: Scope, unused_vars: &mut Vec) { + let unused_variables = decl_map.filter(|(variable_name, metadata)| { + let has_underscore_prefix = variable_name.starts_with('_'); // XXX: This is used for development mode, and will be removed + metadata.warn_if_unused && metadata.num_times_used == 0 && !has_underscore_prefix + }); + unused_vars.extend(unused_variables.map(|(_, meta)| meta.ident.clone())); + } + + /// Lookup a given trait by name/path. + pub fn lookup_trait_or_error(&mut self, path: Path) -> Option<&mut Trait> { + match self.lookup(path) { + Ok(trait_id) => Some(self.get_trait_mut(trait_id)), + Err(error) => { + self.push_err(error); + None + } + } + } + + /// Lookup a given struct type by name. + pub fn lookup_struct_or_error(&mut self, path: Path) -> Option> { + match self.lookup(path) { + Ok(struct_id) => Some(self.get_struct(struct_id)), + Err(error) => { + self.push_err(error); + None + } + } + } + + /// Looks up a given type by name. + /// This will also instantiate any struct types found. + pub(super) fn lookup_type_or_error(&mut self, path: Path) -> Option { + let ident = path.as_ident(); + if ident.map_or(false, |i| i == SELF_TYPE_NAME) { + if let Some(typ) = &self.self_type { + return Some(typ.clone()); + } + } + + match self.lookup(path) { + Ok(struct_id) => { + let struct_type = self.get_struct(struct_id); + let generics = struct_type.borrow().instantiate(self.interner); + Some(Type::Struct(struct_type, generics)) + } + Err(error) => { + self.push_err(error); + None + } + } + } + + pub fn lookup_type_alias(&mut self, path: Path) -> Option> { + self.lookup(path).ok().map(|id| self.interner.get_type_alias(id)) + } +} diff --git a/noir/noir-repo/compiler/noirc_frontend/src/elaborator/statements.rs b/noir/noir-repo/compiler/noirc_frontend/src/elaborator/statements.rs new file mode 100644 index 00000000000..a7a2df4041e --- /dev/null +++ b/noir/noir-repo/compiler/noirc_frontend/src/elaborator/statements.rs @@ -0,0 +1,409 @@ +use noirc_errors::{Location, Span}; + +use crate::{ + ast::{AssignStatement, ConstrainStatement, LValue}, + hir::{ + resolution::errors::ResolverError, + type_check::{Source, TypeCheckError}, + }, + hir_def::{ + expr::HirIdent, + stmt::{ + HirAssignStatement, HirConstrainStatement, HirForStatement, HirLValue, HirLetStatement, + }, + }, + macros_api::{ + ForLoopStatement, ForRange, HirStatement, LetStatement, Statement, StatementKind, + }, + node_interner::{DefinitionId, DefinitionKind, StmtId}, + Type, +}; + +use super::Elaborator; + +impl<'context> Elaborator<'context> { + fn elaborate_statement_value(&mut self, statement: Statement) -> (HirStatement, Type) { + match statement.kind { + StatementKind::Let(let_stmt) => self.elaborate_let(let_stmt), + StatementKind::Constrain(constrain) => self.elaborate_constrain(constrain), + StatementKind::Assign(assign) => self.elaborate_assign(assign), + StatementKind::For(for_stmt) => self.elaborate_for(for_stmt), + StatementKind::Break => self.elaborate_jump(true, statement.span), + StatementKind::Continue => self.elaborate_jump(false, statement.span), + StatementKind::Comptime(statement) => self.elaborate_comptime(*statement), + StatementKind::Expression(expr) => { + let (expr, typ) = self.elaborate_expression(expr); + (HirStatement::Expression(expr), typ) + } + StatementKind::Semi(expr) => { + let (expr, _typ) = self.elaborate_expression(expr); + (HirStatement::Semi(expr), Type::Unit) + } + StatementKind::Error => (HirStatement::Error, Type::Error), + } + } + + pub(super) fn elaborate_statement(&mut self, statement: Statement) -> (StmtId, Type) { + let span = statement.span; + let (hir_statement, typ) = self.elaborate_statement_value(statement); + let id = self.interner.push_stmt(hir_statement); + self.interner.push_stmt_location(id, span, self.file); + (id, typ) + } + + pub(super) fn elaborate_let(&mut self, let_stmt: LetStatement) -> (HirStatement, Type) { + let expr_span = let_stmt.expression.span; + let (expression, expr_type) = self.elaborate_expression(let_stmt.expression); + let definition = DefinitionKind::Local(Some(expression)); + let annotated_type = self.resolve_type(let_stmt.r#type); + + // First check if the LHS is unspecified + // If so, then we give it the same type as the expression + let r#type = if annotated_type != Type::Error { + // Now check if LHS is the same type as the RHS + // Importantly, we do not coerce any types implicitly + self.unify_with_coercions(&expr_type, &annotated_type, expression, || { + TypeCheckError::TypeMismatch { + expected_typ: annotated_type.to_string(), + expr_typ: expr_type.to_string(), + expr_span, + } + }); + if annotated_type.is_unsigned() { + self.lint_overflowing_uint(&expression, &annotated_type); + } + annotated_type + } else { + expr_type + }; + + let let_ = HirLetStatement { + pattern: self.elaborate_pattern(let_stmt.pattern, r#type.clone(), definition), + r#type, + expression, + attributes: let_stmt.attributes, + comptime: let_stmt.comptime, + }; + (HirStatement::Let(let_), Type::Unit) + } + + pub(super) fn elaborate_constrain(&mut self, stmt: ConstrainStatement) -> (HirStatement, Type) { + let expr_span = stmt.0.span; + let (expr_id, expr_type) = self.elaborate_expression(stmt.0); + + // Must type check the assertion message expression so that we instantiate bindings + let msg = stmt.1.map(|assert_msg_expr| self.elaborate_expression(assert_msg_expr).0); + + self.unify(&expr_type, &Type::Bool, || TypeCheckError::TypeMismatch { + expr_typ: expr_type.to_string(), + expected_typ: Type::Bool.to_string(), + expr_span, + }); + + (HirStatement::Constrain(HirConstrainStatement(expr_id, self.file, msg)), Type::Unit) + } + + pub(super) fn elaborate_assign(&mut self, assign: AssignStatement) -> (HirStatement, Type) { + let span = assign.expression.span; + let (expression, expr_type) = self.elaborate_expression(assign.expression); + let (lvalue, lvalue_type, mutable) = self.elaborate_lvalue(assign.lvalue, span); + + if !mutable { + let (name, span) = self.get_lvalue_name_and_span(&lvalue); + self.push_err(TypeCheckError::VariableMustBeMutable { name, span }); + } + + self.unify_with_coercions(&expr_type, &lvalue_type, expression, || { + TypeCheckError::TypeMismatchWithSource { + actual: expr_type.clone(), + expected: lvalue_type.clone(), + span, + source: Source::Assignment, + } + }); + + let stmt = HirAssignStatement { lvalue, expression }; + (HirStatement::Assign(stmt), Type::Unit) + } + + pub(super) fn elaborate_for(&mut self, for_loop: ForLoopStatement) -> (HirStatement, Type) { + let (start, end) = match for_loop.range { + ForRange::Range(start, end) => (start, end), + ForRange::Array(_) => { + let for_stmt = + for_loop.range.into_for(for_loop.identifier, for_loop.block, for_loop.span); + + return self.elaborate_statement_value(for_stmt); + } + }; + + let start_span = start.span; + let end_span = end.span; + + let (start_range, start_range_type) = self.elaborate_expression(start); + let (end_range, end_range_type) = self.elaborate_expression(end); + let (identifier, block) = (for_loop.identifier, for_loop.block); + + self.nested_loops += 1; + self.push_scope(); + + // TODO: For loop variables are currently mutable by default since we haven't + // yet implemented syntax for them to be optionally mutable. + let kind = DefinitionKind::Local(None); + let identifier = self.add_variable_decl(identifier, false, true, kind); + + // Check that start range and end range have the same types + let range_span = start_span.merge(end_span); + self.unify(&start_range_type, &end_range_type, || TypeCheckError::TypeMismatch { + expected_typ: start_range_type.to_string(), + expr_typ: end_range_type.to_string(), + expr_span: range_span, + }); + + let expected_type = self.polymorphic_integer(); + + self.unify(&start_range_type, &expected_type, || TypeCheckError::TypeCannotBeUsed { + typ: start_range_type.clone(), + place: "for loop", + span: range_span, + }); + + self.interner.push_definition_type(identifier.id, start_range_type); + + let (block, _block_type) = self.elaborate_expression(block); + + self.pop_scope(); + self.nested_loops -= 1; + + let statement = + HirStatement::For(HirForStatement { start_range, end_range, block, identifier }); + + (statement, Type::Unit) + } + + fn elaborate_jump(&mut self, is_break: bool, span: noirc_errors::Span) -> (HirStatement, Type) { + if !self.in_unconstrained_fn { + self.push_err(ResolverError::JumpInConstrainedFn { is_break, span }); + } + if self.nested_loops == 0 { + self.push_err(ResolverError::JumpOutsideLoop { is_break, span }); + } + + let expr = if is_break { HirStatement::Break } else { HirStatement::Continue }; + (expr, self.interner.next_type_variable()) + } + + fn get_lvalue_name_and_span(&self, lvalue: &HirLValue) -> (String, Span) { + match lvalue { + HirLValue::Ident(name, _) => { + let span = name.location.span; + + if let Some(definition) = self.interner.try_definition(name.id) { + (definition.name.clone(), span) + } else { + ("(undeclared variable)".into(), span) + } + } + HirLValue::MemberAccess { object, .. } => self.get_lvalue_name_and_span(object), + HirLValue::Index { array, .. } => self.get_lvalue_name_and_span(array), + HirLValue::Dereference { lvalue, .. } => self.get_lvalue_name_and_span(lvalue), + } + } + + fn elaborate_lvalue(&mut self, lvalue: LValue, assign_span: Span) -> (HirLValue, Type, bool) { + match lvalue { + LValue::Ident(ident) => { + let mut mutable = true; + let (ident, scope_index) = self.find_variable_or_default(&ident); + self.resolve_local_variable(ident.clone(), scope_index); + + let typ = if ident.id == DefinitionId::dummy_id() { + Type::Error + } else { + if let Some(definition) = self.interner.try_definition(ident.id) { + mutable = definition.mutable; + } + + let typ = self.interner.definition_type(ident.id).instantiate(self.interner).0; + typ.follow_bindings() + }; + + (HirLValue::Ident(ident.clone(), typ.clone()), typ, mutable) + } + LValue::MemberAccess { object, field_name, span } => { + let (object, lhs_type, mut mutable) = self.elaborate_lvalue(*object, assign_span); + let mut object = Box::new(object); + let field_name = field_name.clone(); + + let object_ref = &mut object; + let mutable_ref = &mut mutable; + let location = Location::new(span, self.file); + + let dereference_lhs = move |_: &mut Self, _, element_type| { + // We must create a temporary value first to move out of object_ref before + // we eventually reassign to it. + let id = DefinitionId::dummy_id(); + let ident = HirIdent::non_trait_method(id, location); + let tmp_value = HirLValue::Ident(ident, Type::Error); + + let lvalue = std::mem::replace(object_ref, Box::new(tmp_value)); + *object_ref = + Box::new(HirLValue::Dereference { lvalue, element_type, location }); + *mutable_ref = true; + }; + + let name = &field_name.0.contents; + let (object_type, field_index) = self + .check_field_access(&lhs_type, name, field_name.span(), Some(dereference_lhs)) + .unwrap_or((Type::Error, 0)); + + let field_index = Some(field_index); + let typ = object_type.clone(); + let lvalue = + HirLValue::MemberAccess { object, field_name, field_index, typ, location }; + (lvalue, object_type, mutable) + } + LValue::Index { array, index, span } => { + let expr_span = index.span; + let (index, index_type) = self.elaborate_expression(index); + let location = Location::new(span, self.file); + + let expected = self.polymorphic_integer_or_field(); + self.unify(&index_type, &expected, || TypeCheckError::TypeMismatch { + expected_typ: "an integer".to_owned(), + expr_typ: index_type.to_string(), + expr_span, + }); + + let (mut lvalue, mut lvalue_type, mut mutable) = + self.elaborate_lvalue(*array, assign_span); + + // Before we check that the lvalue is an array, try to dereference it as many times + // as needed to unwrap any &mut wrappers. + while let Type::MutableReference(element) = lvalue_type.follow_bindings() { + let element_type = element.as_ref().clone(); + lvalue = + HirLValue::Dereference { lvalue: Box::new(lvalue), element_type, location }; + lvalue_type = *element; + // We know this value to be mutable now since we found an `&mut` + mutable = true; + } + + let typ = match lvalue_type.follow_bindings() { + Type::Array(_, elem_type) => *elem_type, + Type::Slice(elem_type) => *elem_type, + Type::Error => Type::Error, + Type::String(_) => { + let (_lvalue_name, lvalue_span) = self.get_lvalue_name_and_span(&lvalue); + self.push_err(TypeCheckError::StringIndexAssign { span: lvalue_span }); + Type::Error + } + other => { + // TODO: Need a better span here + self.push_err(TypeCheckError::TypeMismatch { + expected_typ: "array".to_string(), + expr_typ: other.to_string(), + expr_span: assign_span, + }); + Type::Error + } + }; + + let array = Box::new(lvalue); + let array_type = typ.clone(); + (HirLValue::Index { array, index, typ, location }, array_type, mutable) + } + LValue::Dereference(lvalue, span) => { + let (lvalue, reference_type, _) = self.elaborate_lvalue(*lvalue, assign_span); + let lvalue = Box::new(lvalue); + let location = Location::new(span, self.file); + + let element_type = Type::type_variable(self.interner.next_type_variable_id()); + let expected_type = Type::MutableReference(Box::new(element_type.clone())); + + self.unify(&reference_type, &expected_type, || TypeCheckError::TypeMismatch { + expected_typ: expected_type.to_string(), + expr_typ: reference_type.to_string(), + expr_span: assign_span, + }); + + // Dereferences are always mutable since we already type checked against a &mut T + let typ = element_type.clone(); + let lvalue = HirLValue::Dereference { lvalue, element_type, location }; + (lvalue, typ, true) + } + } + } + + /// Type checks a field access, adding dereference operators as necessary + pub(super) fn check_field_access( + &mut self, + lhs_type: &Type, + field_name: &str, + span: Span, + dereference_lhs: Option, + ) -> Option<(Type, usize)> { + let lhs_type = lhs_type.follow_bindings(); + + match &lhs_type { + Type::Struct(s, args) => { + let s = s.borrow(); + if let Some((field, index)) = s.get_field(field_name, args) { + return Some((field, index)); + } + } + Type::Tuple(elements) => { + if let Ok(index) = field_name.parse::() { + let length = elements.len(); + if index < length { + return Some((elements[index].clone(), index)); + } else { + self.push_err(TypeCheckError::TupleIndexOutOfBounds { + index, + lhs_type, + length, + span, + }); + return None; + } + } + } + // If the lhs is a mutable reference we automatically transform + // lhs.field into (*lhs).field + Type::MutableReference(element) => { + if let Some(mut dereference_lhs) = dereference_lhs { + dereference_lhs(self, lhs_type.clone(), element.as_ref().clone()); + return self.check_field_access( + element, + field_name, + span, + Some(dereference_lhs), + ); + } else { + let (element, index) = + self.check_field_access(element, field_name, span, dereference_lhs)?; + return Some((Type::MutableReference(Box::new(element)), index)); + } + } + _ => (), + } + + // If we get here the type has no field named 'access.rhs'. + // Now we specialize the error message based on whether we know the object type in question yet. + if let Type::TypeVariable(..) = &lhs_type { + self.push_err(TypeCheckError::TypeAnnotationsNeeded { span }); + } else if lhs_type != Type::Error { + self.push_err(TypeCheckError::AccessUnknownMember { + lhs_type, + field_name: field_name.to_string(), + span, + }); + } + + None + } + + pub(super) fn elaborate_comptime(&self, _statement: Statement) -> (HirStatement, Type) { + todo!("Comptime scanning") + } +} diff --git a/noir/noir-repo/compiler/noirc_frontend/src/elaborator/types.rs b/noir/noir-repo/compiler/noirc_frontend/src/elaborator/types.rs new file mode 100644 index 00000000000..4c8364b6dda --- /dev/null +++ b/noir/noir-repo/compiler/noirc_frontend/src/elaborator/types.rs @@ -0,0 +1,1438 @@ +use std::rc::Rc; + +use iter_extended::vecmap; +use noirc_errors::{Location, Span}; + +use crate::{ + ast::{BinaryOpKind, IntegerBitSize, UnresolvedTraitConstraint, UnresolvedTypeExpression}, + hir::{ + def_map::ModuleDefId, + resolution::{ + errors::ResolverError, + import::PathResolution, + resolver::{verify_mutable_reference, SELF_TYPE_NAME}, + }, + type_check::{Source, TypeCheckError}, + }, + hir_def::{ + expr::{ + HirBinaryOp, HirCallExpression, HirIdent, HirMemberAccess, HirMethodReference, + HirPrefixExpression, + }, + function::FuncMeta, + traits::{Trait, TraitConstraint}, + }, + macros_api::{ + HirExpression, HirLiteral, HirStatement, Path, PathKind, SecondaryAttribute, Signedness, + UnaryOp, UnresolvedType, UnresolvedTypeData, + }, + node_interner::{DefinitionKind, ExprId, GlobalId, TraitId, TraitImplKind, TraitMethodId}, + Generics, Shared, StructType, Type, TypeAlias, TypeBinding, TypeVariable, TypeVariableKind, +}; + +use super::Elaborator; + +impl<'context> Elaborator<'context> { + /// Translates an UnresolvedType to a Type + pub(super) fn resolve_type(&mut self, typ: UnresolvedType) -> Type { + let span = typ.span; + let resolved_type = self.resolve_type_inner(typ, &mut vec![]); + if resolved_type.is_nested_slice() { + self.push_err(ResolverError::NestedSlices { span: span.unwrap() }); + } + + resolved_type + } + + /// Translates an UnresolvedType into a Type and appends any + /// freshly created TypeVariables created to new_variables. + pub fn resolve_type_inner( + &mut self, + typ: UnresolvedType, + new_variables: &mut Generics, + ) -> Type { + use crate::ast::UnresolvedTypeData::*; + + let resolved_type = match typ.typ { + FieldElement => Type::FieldElement, + Array(size, elem) => { + let elem = Box::new(self.resolve_type_inner(*elem, new_variables)); + let size = self.resolve_array_size(Some(size), new_variables); + Type::Array(Box::new(size), elem) + } + Slice(elem) => { + let elem = Box::new(self.resolve_type_inner(*elem, new_variables)); + Type::Slice(elem) + } + Expression(expr) => self.convert_expression_type(expr), + Integer(sign, bits) => Type::Integer(sign, bits), + Bool => Type::Bool, + String(size) => { + let resolved_size = self.resolve_array_size(size, new_variables); + Type::String(Box::new(resolved_size)) + } + FormatString(size, fields) => { + let resolved_size = self.convert_expression_type(size); + let fields = self.resolve_type_inner(*fields, new_variables); + Type::FmtString(Box::new(resolved_size), Box::new(fields)) + } + Code => Type::Code, + Unit => Type::Unit, + Unspecified => Type::Error, + Error => Type::Error, + Named(path, args, _) => self.resolve_named_type(path, args, new_variables), + TraitAsType(path, args) => self.resolve_trait_as_type(path, args, new_variables), + + Tuple(fields) => { + Type::Tuple(vecmap(fields, |field| self.resolve_type_inner(field, new_variables))) + } + Function(args, ret, env) => { + let args = vecmap(args, |arg| self.resolve_type_inner(arg, new_variables)); + let ret = Box::new(self.resolve_type_inner(*ret, new_variables)); + + // expect() here is valid, because the only places we don't have a span are omitted types + // e.g. a function without return type implicitly has a spanless UnresolvedType::Unit return type + // To get an invalid env type, the user must explicitly specify the type, which will have a span + let env_span = + env.span.expect("Unexpected missing span for closure environment type"); + + let env = Box::new(self.resolve_type_inner(*env, new_variables)); + + match *env { + Type::Unit | Type::Tuple(_) | Type::NamedGeneric(_, _) => { + Type::Function(args, ret, env) + } + _ => { + self.push_err(ResolverError::InvalidClosureEnvironment { + typ: *env, + span: env_span, + }); + Type::Error + } + } + } + MutableReference(element) => { + Type::MutableReference(Box::new(self.resolve_type_inner(*element, new_variables))) + } + Parenthesized(typ) => self.resolve_type_inner(*typ, new_variables), + }; + + if let Type::Struct(_, _) = resolved_type { + if let Some(unresolved_span) = typ.span { + // Record the location of the type reference + self.interner.push_type_ref_location( + resolved_type.clone(), + Location::new(unresolved_span, self.file), + ); + } + } + resolved_type + } + + pub fn find_generic(&self, target_name: &str) -> Option<&(Rc, TypeVariable, Span)> { + self.generics.iter().find(|(name, _, _)| name.as_ref() == target_name) + } + + fn resolve_named_type( + &mut self, + path: Path, + args: Vec, + new_variables: &mut Generics, + ) -> Type { + if args.is_empty() { + if let Some(typ) = self.lookup_generic_or_global_type(&path) { + return typ; + } + } + + // Check if the path is a type variable first. We currently disallow generics on type + // variables since we do not support higher-kinded types. + if path.segments.len() == 1 { + let name = &path.last_segment().0.contents; + + if name == SELF_TYPE_NAME { + if let Some(self_type) = self.self_type.clone() { + if !args.is_empty() { + self.push_err(ResolverError::GenericsOnSelfType { span: path.span() }); + } + return self_type; + } + } + } + + let span = path.span(); + let mut args = vecmap(args, |arg| self.resolve_type_inner(arg, new_variables)); + + if let Some(type_alias) = self.lookup_type_alias(path.clone()) { + let type_alias = type_alias.borrow(); + let expected_generic_count = type_alias.generics.len(); + let type_alias_string = type_alias.to_string(); + let id = type_alias.id; + + self.verify_generics_count(expected_generic_count, &mut args, span, || { + type_alias_string + }); + + if let Some(item) = self.current_item { + self.interner.add_type_alias_dependency(item, id); + } + + // Collecting Type Alias references [Location]s to be used by LSP in order + // to resolve the definition of the type alias + self.interner.add_type_alias_ref(id, Location::new(span, self.file)); + + // Because there is no ordering to when type aliases (and other globals) are resolved, + // it is possible for one to refer to an Error type and issue no error if it is set + // equal to another type alias. Fixing this fully requires an analysis to create a DFG + // of definition ordering, but for now we have an explicit check here so that we at + // least issue an error that the type was not found instead of silently passing. + let alias = self.interner.get_type_alias(id); + return Type::Alias(alias, args); + } + + match self.lookup_struct_or_error(path) { + Some(struct_type) => { + if self.resolving_ids.contains(&struct_type.borrow().id) { + self.push_err(ResolverError::SelfReferentialStruct { + span: struct_type.borrow().name.span(), + }); + + return Type::Error; + } + + let expected_generic_count = struct_type.borrow().generics.len(); + if !self.in_contract + && self + .interner + .struct_attributes(&struct_type.borrow().id) + .iter() + .any(|attr| matches!(attr, SecondaryAttribute::Abi(_))) + { + self.push_err(ResolverError::AbiAttributeOutsideContract { + span: struct_type.borrow().name.span(), + }); + } + self.verify_generics_count(expected_generic_count, &mut args, span, || { + struct_type.borrow().to_string() + }); + + if let Some(current_item) = self.current_item { + let dependency_id = struct_type.borrow().id; + self.interner.add_type_dependency(current_item, dependency_id); + } + + Type::Struct(struct_type, args) + } + None => Type::Error, + } + } + + fn resolve_trait_as_type( + &mut self, + path: Path, + args: Vec, + new_variables: &mut Generics, + ) -> Type { + let args = vecmap(args, |arg| self.resolve_type_inner(arg, new_variables)); + + if let Some(t) = self.lookup_trait_or_error(path) { + Type::TraitAsType(t.id, Rc::new(t.name.to_string()), args) + } else { + Type::Error + } + } + + fn verify_generics_count( + &mut self, + expected_count: usize, + args: &mut Vec, + span: Span, + type_name: impl FnOnce() -> String, + ) { + if args.len() != expected_count { + self.push_err(ResolverError::IncorrectGenericCount { + span, + item_name: type_name(), + actual: args.len(), + expected: expected_count, + }); + + // Fix the generic count so we can continue typechecking + args.resize_with(expected_count, || Type::Error); + } + } + + pub fn lookup_generic_or_global_type(&mut self, path: &Path) -> Option { + if path.segments.len() == 1 { + let name = &path.last_segment().0.contents; + if let Some((name, var, _)) = self.find_generic(name) { + return Some(Type::NamedGeneric(var.clone(), name.clone())); + } + } + + // If we cannot find a local generic of the same name, try to look up a global + match self.resolve_path(path.clone()) { + Ok(ModuleDefId::GlobalId(id)) => { + if let Some(current_item) = self.current_item { + self.interner.add_global_dependency(current_item, id); + } + + Some(Type::Constant(self.eval_global_as_array_length(id, path))) + } + _ => None, + } + } + + fn resolve_array_size( + &mut self, + length: Option, + new_variables: &mut Generics, + ) -> Type { + match length { + None => { + let id = self.interner.next_type_variable_id(); + let typevar = TypeVariable::unbound(id); + new_variables.push(typevar.clone()); + + // 'Named'Generic is a bit of a misnomer here, we want a type variable that + // wont be bound over but this one has no name since we do not currently + // require users to explicitly be generic over array lengths. + Type::NamedGeneric(typevar, Rc::new("".into())) + } + Some(length) => self.convert_expression_type(length), + } + } + + pub(super) fn convert_expression_type(&mut self, length: UnresolvedTypeExpression) -> Type { + match length { + UnresolvedTypeExpression::Variable(path) => { + self.lookup_generic_or_global_type(&path).unwrap_or_else(|| { + self.push_err(ResolverError::NoSuchNumericTypeVariable { path }); + Type::Constant(0) + }) + } + UnresolvedTypeExpression::Constant(int, _) => Type::Constant(int), + UnresolvedTypeExpression::BinaryOperation(lhs, op, rhs, _) => { + let (lhs_span, rhs_span) = (lhs.span(), rhs.span()); + let lhs = self.convert_expression_type(*lhs); + let rhs = self.convert_expression_type(*rhs); + + match (lhs, rhs) { + (Type::Constant(lhs), Type::Constant(rhs)) => { + Type::Constant(op.function()(lhs, rhs)) + } + (lhs, _) => { + let span = + if !matches!(lhs, Type::Constant(_)) { lhs_span } else { rhs_span }; + self.push_err(ResolverError::InvalidArrayLengthExpr { span }); + Type::Constant(0) + } + } + } + } + } + + // this resolves Self::some_static_method, inside an impl block (where we don't have a concrete self_type) + // + // Returns the trait method, trait constraint, and whether the impl is assumed to exist by a where clause or not + // E.g. `t.method()` with `where T: Foo` in scope will return `(Foo::method, T, vec![Bar])` + fn resolve_trait_static_method_by_self( + &mut self, + path: &Path, + ) -> Option<(TraitMethodId, TraitConstraint, bool)> { + let trait_id = self.trait_id?; + + if path.kind == PathKind::Plain && path.segments.len() == 2 { + let name = &path.segments[0].0.contents; + let method = &path.segments[1]; + + if name == SELF_TYPE_NAME { + let the_trait = self.interner.get_trait(trait_id); + let method = the_trait.find_method(method.0.contents.as_str())?; + + let constraint = TraitConstraint { + typ: self.self_type.clone()?, + trait_generics: Type::from_generics(&the_trait.generics), + trait_id, + }; + return Some((method, constraint, false)); + } + } + None + } + + // this resolves TraitName::some_static_method + // + // Returns the trait method, trait constraint, and whether the impl is assumed to exist by a where clause or not + // E.g. `t.method()` with `where T: Foo` in scope will return `(Foo::method, T, vec![Bar])` + fn resolve_trait_static_method( + &mut self, + path: &Path, + ) -> Option<(TraitMethodId, TraitConstraint, bool)> { + if path.kind == PathKind::Plain && path.segments.len() == 2 { + let method = &path.segments[1]; + + let mut trait_path = path.clone(); + trait_path.pop(); + let trait_id = self.lookup(trait_path).ok()?; + let the_trait = self.interner.get_trait(trait_id); + + let method = the_trait.find_method(method.0.contents.as_str())?; + let constraint = TraitConstraint { + typ: Type::TypeVariable( + the_trait.self_type_typevar.clone(), + TypeVariableKind::Normal, + ), + trait_generics: Type::from_generics(&the_trait.generics), + trait_id, + }; + return Some((method, constraint, false)); + } + None + } + + // This resolves a static trait method T::trait_method by iterating over the where clause + // + // Returns the trait method, trait constraint, and whether the impl is assumed from a where + // clause. This is always true since this helper searches where clauses for a generic constraint. + // E.g. `t.method()` with `where T: Foo` in scope will return `(Foo::method, T, vec![Bar])` + fn resolve_trait_method_by_named_generic( + &mut self, + path: &Path, + ) -> Option<(TraitMethodId, TraitConstraint, bool)> { + if path.segments.len() != 2 { + return None; + } + + for UnresolvedTraitConstraint { typ, trait_bound } in self.trait_bounds.clone() { + if let UnresolvedTypeData::Named(constraint_path, _, _) = &typ.typ { + // if `path` is `T::method_name`, we're looking for constraint of the form `T: SomeTrait` + if constraint_path.segments.len() == 1 + && path.segments[0] != constraint_path.last_segment() + { + continue; + } + + if let Ok(ModuleDefId::TraitId(trait_id)) = + self.resolve_path(trait_bound.trait_path.clone()) + { + let the_trait = self.interner.get_trait(trait_id); + if let Some(method) = + the_trait.find_method(path.segments.last().unwrap().0.contents.as_str()) + { + let constraint = TraitConstraint { + trait_id, + typ: self.resolve_type(typ.clone()), + trait_generics: vecmap(trait_bound.trait_generics, |typ| { + self.resolve_type(typ) + }), + }; + return Some((method, constraint, true)); + } + } + } + } + None + } + + // Try to resolve the given trait method path. + // + // Returns the trait method, trait constraint, and whether the impl is assumed to exist by a where clause or not + // E.g. `t.method()` with `where T: Foo` in scope will return `(Foo::method, T, vec![Bar])` + pub(super) fn resolve_trait_generic_path( + &mut self, + path: &Path, + ) -> Option<(TraitMethodId, TraitConstraint, bool)> { + self.resolve_trait_static_method_by_self(path) + .or_else(|| self.resolve_trait_static_method(path)) + .or_else(|| self.resolve_trait_method_by_named_generic(path)) + } + + fn eval_global_as_array_length(&mut self, global: GlobalId, path: &Path) -> u64 { + let Some(stmt) = self.interner.get_global_let_statement(global) else { + let path = path.clone(); + self.push_err(ResolverError::NoSuchNumericTypeVariable { path }); + return 0; + }; + + let length = stmt.expression; + let span = self.interner.expr_span(&length); + let result = self.try_eval_array_length_id(length, span); + + match result.map(|length| length.try_into()) { + Ok(Ok(length_value)) => return length_value, + Ok(Err(_cast_err)) => self.push_err(ResolverError::IntegerTooLarge { span }), + Err(Some(error)) => self.push_err(error), + Err(None) => (), + } + 0 + } + + fn try_eval_array_length_id( + &self, + rhs: ExprId, + span: Span, + ) -> Result> { + // Arbitrary amount of recursive calls to try before giving up + let fuel = 100; + self.try_eval_array_length_id_with_fuel(rhs, span, fuel) + } + + fn try_eval_array_length_id_with_fuel( + &self, + rhs: ExprId, + span: Span, + fuel: u32, + ) -> Result> { + if fuel == 0 { + // If we reach here, it is likely from evaluating cyclic globals. We expect an error to + // be issued for them after name resolution so issue no error now. + return Err(None); + } + + match self.interner.expression(&rhs) { + HirExpression::Literal(HirLiteral::Integer(int, false)) => { + int.try_into_u128().ok_or(Some(ResolverError::IntegerTooLarge { span })) + } + HirExpression::Ident(ident) => { + let definition = self.interner.definition(ident.id); + match definition.kind { + DefinitionKind::Global(global_id) => { + let let_statement = self.interner.get_global_let_statement(global_id); + if let Some(let_statement) = let_statement { + let expression = let_statement.expression; + self.try_eval_array_length_id_with_fuel(expression, span, fuel - 1) + } else { + Err(Some(ResolverError::InvalidArrayLengthExpr { span })) + } + } + _ => Err(Some(ResolverError::InvalidArrayLengthExpr { span })), + } + } + HirExpression::Infix(infix) => { + let lhs = self.try_eval_array_length_id_with_fuel(infix.lhs, span, fuel - 1)?; + let rhs = self.try_eval_array_length_id_with_fuel(infix.rhs, span, fuel - 1)?; + + match infix.operator.kind { + BinaryOpKind::Add => Ok(lhs + rhs), + BinaryOpKind::Subtract => Ok(lhs - rhs), + BinaryOpKind::Multiply => Ok(lhs * rhs), + BinaryOpKind::Divide => Ok(lhs / rhs), + BinaryOpKind::Equal => Ok((lhs == rhs) as u128), + BinaryOpKind::NotEqual => Ok((lhs != rhs) as u128), + BinaryOpKind::Less => Ok((lhs < rhs) as u128), + BinaryOpKind::LessEqual => Ok((lhs <= rhs) as u128), + BinaryOpKind::Greater => Ok((lhs > rhs) as u128), + BinaryOpKind::GreaterEqual => Ok((lhs >= rhs) as u128), + BinaryOpKind::And => Ok(lhs & rhs), + BinaryOpKind::Or => Ok(lhs | rhs), + BinaryOpKind::Xor => Ok(lhs ^ rhs), + BinaryOpKind::ShiftRight => Ok(lhs >> rhs), + BinaryOpKind::ShiftLeft => Ok(lhs << rhs), + BinaryOpKind::Modulo => Ok(lhs % rhs), + } + } + _other => Err(Some(ResolverError::InvalidArrayLengthExpr { span })), + } + } + + /// Check if an assignment is overflowing with respect to `annotated_type` + /// in a declaration statement where `annotated_type` is an unsigned integer + pub(super) fn lint_overflowing_uint(&mut self, rhs_expr: &ExprId, annotated_type: &Type) { + let expr = self.interner.expression(rhs_expr); + let span = self.interner.expr_span(rhs_expr); + match expr { + HirExpression::Literal(HirLiteral::Integer(value, false)) => { + let v = value.to_u128(); + if let Type::Integer(_, bit_count) = annotated_type { + let bit_count: u32 = (*bit_count).into(); + let max = 1 << bit_count; + if v >= max { + self.push_err(TypeCheckError::OverflowingAssignment { + expr: value, + ty: annotated_type.clone(), + range: format!("0..={}", max - 1), + span, + }); + }; + }; + } + HirExpression::Prefix(expr) => { + self.lint_overflowing_uint(&expr.rhs, annotated_type); + if matches!(expr.operator, UnaryOp::Minus) { + self.push_err(TypeCheckError::InvalidUnaryOp { + kind: "annotated_type".to_string(), + span, + }); + } + } + HirExpression::Infix(expr) => { + self.lint_overflowing_uint(&expr.lhs, annotated_type); + self.lint_overflowing_uint(&expr.rhs, annotated_type); + } + _ => {} + } + } + + pub(super) fn unify( + &mut self, + actual: &Type, + expected: &Type, + make_error: impl FnOnce() -> TypeCheckError, + ) { + let mut errors = Vec::new(); + actual.unify(expected, &mut errors, make_error); + self.errors.extend(errors.into_iter().map(|error| (error.into(), self.file))); + } + + /// Wrapper of Type::unify_with_coercions using self.errors + pub(super) fn unify_with_coercions( + &mut self, + actual: &Type, + expected: &Type, + expression: ExprId, + make_error: impl FnOnce() -> TypeCheckError, + ) { + let mut errors = Vec::new(); + actual.unify_with_coercions(expected, expression, self.interner, &mut errors, make_error); + self.errors.extend(errors.into_iter().map(|error| (error.into(), self.file))); + } + + /// Return a fresh integer or field type variable and log it + /// in self.type_variables to default it later. + pub(super) fn polymorphic_integer_or_field(&mut self) -> Type { + let typ = Type::polymorphic_integer_or_field(self.interner); + self.type_variables.push(typ.clone()); + typ + } + + /// Return a fresh integer type variable and log it + /// in self.type_variables to default it later. + pub(super) fn polymorphic_integer(&mut self) -> Type { + let typ = Type::polymorphic_integer(self.interner); + self.type_variables.push(typ.clone()); + typ + } + + /// Translates a (possibly Unspecified) UnresolvedType to a Type. + /// Any UnresolvedType::Unspecified encountered are replaced with fresh type variables. + pub(super) fn resolve_inferred_type(&mut self, typ: UnresolvedType) -> Type { + match &typ.typ { + UnresolvedTypeData::Unspecified => self.interner.next_type_variable(), + _ => self.resolve_type_inner(typ, &mut vec![]), + } + } + + pub(super) fn type_check_prefix_operand( + &mut self, + op: &crate::ast::UnaryOp, + rhs_type: &Type, + span: Span, + ) -> Type { + let mut unify = |this: &mut Self, expected| { + this.unify(rhs_type, &expected, || TypeCheckError::TypeMismatch { + expr_typ: rhs_type.to_string(), + expected_typ: expected.to_string(), + expr_span: span, + }); + expected + }; + + match op { + crate::ast::UnaryOp::Minus => { + if rhs_type.is_unsigned() { + self.push_err(TypeCheckError::InvalidUnaryOp { + kind: rhs_type.to_string(), + span, + }); + } + let expected = self.polymorphic_integer_or_field(); + self.unify(rhs_type, &expected, || TypeCheckError::InvalidUnaryOp { + kind: rhs_type.to_string(), + span, + }); + expected + } + crate::ast::UnaryOp::Not => { + let rhs_type = rhs_type.follow_bindings(); + + // `!` can work on booleans or integers + if matches!(rhs_type, Type::Integer(..)) { + return rhs_type; + } + + unify(self, Type::Bool) + } + crate::ast::UnaryOp::MutableReference => { + Type::MutableReference(Box::new(rhs_type.follow_bindings())) + } + crate::ast::UnaryOp::Dereference { implicitly_added: _ } => { + let element_type = self.interner.next_type_variable(); + unify(self, Type::MutableReference(Box::new(element_type.clone()))); + element_type + } + } + } + + /// Insert as many dereference operations as necessary to automatically dereference a method + /// call object to its base value type T. + pub(super) fn insert_auto_dereferences(&mut self, object: ExprId, typ: Type) -> (ExprId, Type) { + if let Type::MutableReference(element) = typ { + let location = self.interner.id_location(object); + + let object = self.interner.push_expr(HirExpression::Prefix(HirPrefixExpression { + operator: UnaryOp::Dereference { implicitly_added: true }, + rhs: object, + })); + self.interner.push_expr_type(object, element.as_ref().clone()); + self.interner.push_expr_location(object, location.span, location.file); + + // Recursively dereference to allow for converting &mut &mut T to T + self.insert_auto_dereferences(object, *element) + } else { + (object, typ) + } + } + + /// Given a method object: `(*foo).bar` of a method call `(*foo).bar.baz()`, remove the + /// implicitly added dereference operator if one is found. + /// + /// Returns Some(new_expr_id) if a dereference was removed and None otherwise. + fn try_remove_implicit_dereference(&mut self, object: ExprId) -> Option { + match self.interner.expression(&object) { + HirExpression::MemberAccess(mut access) => { + let new_lhs = self.try_remove_implicit_dereference(access.lhs)?; + access.lhs = new_lhs; + access.is_offset = true; + + // `object` will have a different type now, which will be filled in + // later when type checking the method call as a function call. + self.interner.replace_expr(&object, HirExpression::MemberAccess(access)); + Some(object) + } + HirExpression::Prefix(prefix) => match prefix.operator { + // Found a dereference we can remove. Now just replace it with its rhs to remove it. + UnaryOp::Dereference { implicitly_added: true } => Some(prefix.rhs), + _ => None, + }, + _ => None, + } + } + + fn bind_function_type_impl( + &mut self, + fn_params: &[Type], + fn_ret: &Type, + callsite_args: &[(Type, ExprId, Span)], + span: Span, + ) -> Type { + if fn_params.len() != callsite_args.len() { + self.push_err(TypeCheckError::ParameterCountMismatch { + expected: fn_params.len(), + found: callsite_args.len(), + span, + }); + return Type::Error; + } + + for (param, (arg, _, arg_span)) in fn_params.iter().zip(callsite_args) { + self.unify(arg, param, || TypeCheckError::TypeMismatch { + expected_typ: param.to_string(), + expr_typ: arg.to_string(), + expr_span: *arg_span, + }); + } + + fn_ret.clone() + } + + pub(super) fn bind_function_type( + &mut self, + function: Type, + args: Vec<(Type, ExprId, Span)>, + span: Span, + ) -> Type { + // Could do a single unification for the entire function type, but matching beforehand + // lets us issue a more precise error on the individual argument that fails to type check. + match function { + Type::TypeVariable(binding, TypeVariableKind::Normal) => { + if let TypeBinding::Bound(typ) = &*binding.borrow() { + return self.bind_function_type(typ.clone(), args, span); + } + + let ret = self.interner.next_type_variable(); + let args = vecmap(args, |(arg, _, _)| arg); + let env_type = self.interner.next_type_variable(); + let expected = Type::Function(args, Box::new(ret.clone()), Box::new(env_type)); + + if let Err(error) = binding.try_bind(expected, span) { + self.push_err(error); + } + ret + } + // The closure env is ignored on purpose: call arguments never place + // constraints on closure environments. + Type::Function(parameters, ret, _env) => { + self.bind_function_type_impl(¶meters, &ret, &args, span) + } + Type::Error => Type::Error, + found => { + self.push_err(TypeCheckError::ExpectedFunction { found, span }); + Type::Error + } + } + } + + pub(super) fn check_cast(&mut self, from: Type, to: &Type, span: Span) -> Type { + match from.follow_bindings() { + Type::Integer(..) + | Type::FieldElement + | Type::TypeVariable(_, TypeVariableKind::IntegerOrField) + | Type::TypeVariable(_, TypeVariableKind::Integer) + | Type::Bool => (), + + Type::TypeVariable(_, _) => { + self.push_err(TypeCheckError::TypeAnnotationsNeeded { span }); + return Type::Error; + } + Type::Error => return Type::Error, + from => { + self.push_err(TypeCheckError::InvalidCast { from, span }); + return Type::Error; + } + } + + match to { + Type::Integer(sign, bits) => Type::Integer(*sign, *bits), + Type::FieldElement => Type::FieldElement, + Type::Bool => Type::Bool, + Type::Error => Type::Error, + _ => { + self.push_err(TypeCheckError::UnsupportedCast { span }); + Type::Error + } + } + } + + // Given a binary comparison operator and another type. This method will produce the output type + // and a boolean indicating whether to use the trait impl corresponding to the operator + // or not. A value of false indicates the caller to use a primitive operation for this + // operator, while a true value indicates a user-provided trait impl is required. + fn comparator_operand_type_rules( + &mut self, + lhs_type: &Type, + rhs_type: &Type, + op: &HirBinaryOp, + span: Span, + ) -> Result<(Type, bool), TypeCheckError> { + use Type::*; + + match (lhs_type, rhs_type) { + // Avoid reporting errors multiple times + (Error, _) | (_, Error) => Ok((Bool, false)), + (Alias(alias, args), other) | (other, Alias(alias, args)) => { + let alias = alias.borrow().get_type(args); + self.comparator_operand_type_rules(&alias, other, op, span) + } + + // Matches on TypeVariable must be first to follow any type + // bindings. + (TypeVariable(var, _), other) | (other, TypeVariable(var, _)) => { + if let TypeBinding::Bound(binding) = &*var.borrow() { + return self.comparator_operand_type_rules(other, binding, op, span); + } + + let use_impl = self.bind_type_variables_for_infix(lhs_type, op, rhs_type, span); + Ok((Bool, use_impl)) + } + (Integer(sign_x, bit_width_x), Integer(sign_y, bit_width_y)) => { + if sign_x != sign_y { + return Err(TypeCheckError::IntegerSignedness { + sign_x: *sign_x, + sign_y: *sign_y, + span, + }); + } + if bit_width_x != bit_width_y { + return Err(TypeCheckError::IntegerBitWidth { + bit_width_x: *bit_width_x, + bit_width_y: *bit_width_y, + span, + }); + } + Ok((Bool, false)) + } + (FieldElement, FieldElement) => { + if op.kind.is_valid_for_field_type() { + Ok((Bool, false)) + } else { + Err(TypeCheckError::FieldComparison { span }) + } + } + + // <= and friends are technically valid for booleans, just not very useful + (Bool, Bool) => Ok((Bool, false)), + + (lhs, rhs) => { + self.unify(lhs, rhs, || TypeCheckError::TypeMismatchWithSource { + expected: lhs.clone(), + actual: rhs.clone(), + span: op.location.span, + source: Source::Binary, + }); + Ok((Bool, true)) + } + } + } + + /// Handles the TypeVariable case for checking binary operators. + /// Returns true if we should use the impl for the operator instead of the primitive + /// version of it. + fn bind_type_variables_for_infix( + &mut self, + lhs_type: &Type, + op: &HirBinaryOp, + rhs_type: &Type, + span: Span, + ) -> bool { + self.unify(lhs_type, rhs_type, || TypeCheckError::TypeMismatchWithSource { + expected: lhs_type.clone(), + actual: rhs_type.clone(), + source: Source::Binary, + span, + }); + + let use_impl = !lhs_type.is_numeric(); + + // If this operator isn't valid for fields we have to possibly narrow + // TypeVariableKind::IntegerOrField to TypeVariableKind::Integer. + // Doing so also ensures a type error if Field is used. + // The is_numeric check is to allow impls for custom types to bypass this. + if !op.kind.is_valid_for_field_type() && lhs_type.is_numeric() { + let target = Type::polymorphic_integer(self.interner); + + use crate::ast::BinaryOpKind::*; + use TypeCheckError::*; + self.unify(lhs_type, &target, || match op.kind { + Less | LessEqual | Greater | GreaterEqual => FieldComparison { span }, + And | Or | Xor | ShiftRight | ShiftLeft => FieldBitwiseOp { span }, + Modulo => FieldModulo { span }, + other => unreachable!("Operator {other:?} should be valid for Field"), + }); + } + + use_impl + } + + // Given a binary operator and another type. This method will produce the output type + // and a boolean indicating whether to use the trait impl corresponding to the operator + // or not. A value of false indicates the caller to use a primitive operation for this + // operator, while a true value indicates a user-provided trait impl is required. + pub(super) fn infix_operand_type_rules( + &mut self, + lhs_type: &Type, + op: &HirBinaryOp, + rhs_type: &Type, + span: Span, + ) -> Result<(Type, bool), TypeCheckError> { + if op.kind.is_comparator() { + return self.comparator_operand_type_rules(lhs_type, rhs_type, op, span); + } + + use Type::*; + match (lhs_type, rhs_type) { + // An error type on either side will always return an error + (Error, _) | (_, Error) => Ok((Error, false)), + (Alias(alias, args), other) | (other, Alias(alias, args)) => { + let alias = alias.borrow().get_type(args); + self.infix_operand_type_rules(&alias, op, other, span) + } + + // Matches on TypeVariable must be first so that we follow any type + // bindings. + (TypeVariable(int, _), other) | (other, TypeVariable(int, _)) => { + if let TypeBinding::Bound(binding) = &*int.borrow() { + return self.infix_operand_type_rules(binding, op, other, span); + } + if op.kind == BinaryOpKind::ShiftLeft || op.kind == BinaryOpKind::ShiftRight { + self.unify( + rhs_type, + &Type::Integer(Signedness::Unsigned, IntegerBitSize::Eight), + || TypeCheckError::InvalidShiftSize { span }, + ); + let use_impl = if lhs_type.is_numeric() { + let integer_type = Type::polymorphic_integer(self.interner); + self.bind_type_variables_for_infix(lhs_type, op, &integer_type, span) + } else { + true + }; + return Ok((lhs_type.clone(), use_impl)); + } + let use_impl = self.bind_type_variables_for_infix(lhs_type, op, rhs_type, span); + Ok((other.clone(), use_impl)) + } + (Integer(sign_x, bit_width_x), Integer(sign_y, bit_width_y)) => { + if op.kind == BinaryOpKind::ShiftLeft || op.kind == BinaryOpKind::ShiftRight { + if *sign_y != Signedness::Unsigned || *bit_width_y != IntegerBitSize::Eight { + return Err(TypeCheckError::InvalidShiftSize { span }); + } + return Ok((Integer(*sign_x, *bit_width_x), false)); + } + if sign_x != sign_y { + return Err(TypeCheckError::IntegerSignedness { + sign_x: *sign_x, + sign_y: *sign_y, + span, + }); + } + if bit_width_x != bit_width_y { + return Err(TypeCheckError::IntegerBitWidth { + bit_width_x: *bit_width_x, + bit_width_y: *bit_width_y, + span, + }); + } + Ok((Integer(*sign_x, *bit_width_x), false)) + } + // The result of two Fields is always a witness + (FieldElement, FieldElement) => { + if !op.kind.is_valid_for_field_type() { + if op.kind == BinaryOpKind::Modulo { + return Err(TypeCheckError::FieldModulo { span }); + } else { + return Err(TypeCheckError::FieldBitwiseOp { span }); + } + } + Ok((FieldElement, false)) + } + + (Bool, Bool) => Ok((Bool, false)), + + (lhs, rhs) => { + if op.kind == BinaryOpKind::ShiftLeft || op.kind == BinaryOpKind::ShiftRight { + if rhs == &Type::Integer(Signedness::Unsigned, IntegerBitSize::Eight) { + return Ok((lhs.clone(), true)); + } + return Err(TypeCheckError::InvalidShiftSize { span }); + } + self.unify(lhs, rhs, || TypeCheckError::TypeMismatchWithSource { + expected: lhs.clone(), + actual: rhs.clone(), + span: op.location.span, + source: Source::Binary, + }); + Ok((lhs.clone(), true)) + } + } + } + + /// Prerequisite: verify_trait_constraint of the operator's trait constraint. + /// + /// Although by this point the operator is expected to already have a trait impl, + /// we still need to match the operator's type against the method's instantiated type + /// to ensure the instantiation bindings are correct and the monomorphizer can + /// re-apply the needed bindings. + pub(super) fn type_check_operator_method( + &mut self, + expr_id: ExprId, + trait_method_id: TraitMethodId, + object_type: &Type, + span: Span, + ) { + let the_trait = self.interner.get_trait(trait_method_id.trait_id); + + let method = &the_trait.methods[trait_method_id.method_index]; + let (method_type, mut bindings) = method.typ.clone().instantiate(self.interner); + + match method_type { + Type::Function(args, _, _) => { + // We can cheat a bit and match against only the object type here since no operator + // overload uses other generic parameters or return types aside from the object type. + let expected_object_type = &args[0]; + self.unify(object_type, expected_object_type, || TypeCheckError::TypeMismatch { + expected_typ: expected_object_type.to_string(), + expr_typ: object_type.to_string(), + expr_span: span, + }); + } + other => { + unreachable!("Expected operator method to have a function type, but found {other}") + } + } + + // We must also remember to apply these substitutions to the object_type + // referenced by the selected trait impl, if one has yet to be selected. + let impl_kind = self.interner.get_selected_impl_for_expression(expr_id); + if let Some(TraitImplKind::Assumed { object_type, trait_generics }) = impl_kind { + let the_trait = self.interner.get_trait(trait_method_id.trait_id); + let object_type = object_type.substitute(&bindings); + bindings.insert( + the_trait.self_type_typevar_id, + (the_trait.self_type_typevar.clone(), object_type.clone()), + ); + self.interner.select_impl_for_expression( + expr_id, + TraitImplKind::Assumed { object_type, trait_generics }, + ); + } + + self.interner.store_instantiation_bindings(expr_id, bindings); + } + + pub(super) fn type_check_member_access( + &mut self, + mut access: HirMemberAccess, + expr_id: ExprId, + lhs_type: Type, + span: Span, + ) -> Type { + let access_lhs = &mut access.lhs; + + let dereference_lhs = |this: &mut Self, lhs_type, element| { + let old_lhs = *access_lhs; + *access_lhs = this.interner.push_expr(HirExpression::Prefix(HirPrefixExpression { + operator: crate::ast::UnaryOp::Dereference { implicitly_added: true }, + rhs: old_lhs, + })); + this.interner.push_expr_type(old_lhs, lhs_type); + this.interner.push_expr_type(*access_lhs, element); + + let old_location = this.interner.id_location(old_lhs); + this.interner.push_expr_location(*access_lhs, span, old_location.file); + }; + + // If this access is just a field offset, we want to avoid dereferencing + let dereference_lhs = (!access.is_offset).then_some(dereference_lhs); + + match self.check_field_access(&lhs_type, &access.rhs.0.contents, span, dereference_lhs) { + Some((element_type, index)) => { + self.interner.set_field_index(expr_id, index); + // We must update `access` in case we added any dereferences to it + self.interner.replace_expr(&expr_id, HirExpression::MemberAccess(access)); + element_type + } + None => Type::Error, + } + } + + pub(super) fn lookup_method( + &mut self, + object_type: &Type, + method_name: &str, + span: Span, + ) -> Option { + match object_type.follow_bindings() { + Type::Struct(typ, _args) => { + let id = typ.borrow().id; + match self.interner.lookup_method(object_type, id, method_name, false) { + Some(method_id) => Some(HirMethodReference::FuncId(method_id)), + None => { + self.push_err(TypeCheckError::UnresolvedMethodCall { + method_name: method_name.to_string(), + object_type: object_type.clone(), + span, + }); + None + } + } + } + // TODO: We should allow method calls on `impl Trait`s eventually. + // For now it is fine since they are only allowed on return types. + Type::TraitAsType(..) => { + self.push_err(TypeCheckError::UnresolvedMethodCall { + method_name: method_name.to_string(), + object_type: object_type.clone(), + span, + }); + None + } + Type::NamedGeneric(_, _) => { + let func_meta = self.interner.function_meta( + &self.current_function.expect("unexpected method outside a function"), + ); + + for constraint in &func_meta.trait_constraints { + if *object_type == constraint.typ { + if let Some(the_trait) = self.interner.try_get_trait(constraint.trait_id) { + for (method_index, method) in the_trait.methods.iter().enumerate() { + if method.name.0.contents == method_name { + let trait_method = TraitMethodId { + trait_id: constraint.trait_id, + method_index, + }; + return Some(HirMethodReference::TraitMethodId( + trait_method, + constraint.trait_generics.clone(), + )); + } + } + } + } + } + + self.push_err(TypeCheckError::UnresolvedMethodCall { + method_name: method_name.to_string(), + object_type: object_type.clone(), + span, + }); + None + } + // Mutable references to another type should resolve to methods of their element type. + // This may be a struct or a primitive type. + Type::MutableReference(element) => self + .interner + .lookup_primitive_trait_method_mut(element.as_ref(), method_name) + .map(HirMethodReference::FuncId) + .or_else(|| self.lookup_method(&element, method_name, span)), + + // If we fail to resolve the object to a struct type, we have no way of type + // checking its arguments as we can't even resolve the name of the function + Type::Error => None, + + // The type variable must be unbound at this point since follow_bindings was called + Type::TypeVariable(_, TypeVariableKind::Normal) => { + self.push_err(TypeCheckError::TypeAnnotationsNeeded { span }); + None + } + + other => match self.interner.lookup_primitive_method(&other, method_name) { + Some(method_id) => Some(HirMethodReference::FuncId(method_id)), + None => { + self.push_err(TypeCheckError::UnresolvedMethodCall { + method_name: method_name.to_string(), + object_type: object_type.clone(), + span, + }); + None + } + }, + } + } + + pub(super) fn type_check_call( + &mut self, + call: &HirCallExpression, + func_type: Type, + args: Vec<(Type, ExprId, Span)>, + span: Span, + ) -> Type { + // Need to setup these flags here as `self` is borrowed mutably to type check the rest of the call expression + // These flags are later used to type check calls to unconstrained functions from constrained functions + let func_mod = self.current_function.map(|func| self.interner.function_modifiers(&func)); + let is_current_func_constrained = + func_mod.map_or(true, |func_mod| !func_mod.is_unconstrained); + + let is_unconstrained_call = self.is_unconstrained_call(call.func); + self.check_if_deprecated(call.func); + + // Check that we are not passing a mutable reference from a constrained runtime to an unconstrained runtime + if is_current_func_constrained && is_unconstrained_call { + for (typ, _, _) in args.iter() { + if matches!(&typ.follow_bindings(), Type::MutableReference(_)) { + self.push_err(TypeCheckError::ConstrainedReferenceToUnconstrained { span }); + } + } + } + + let return_type = self.bind_function_type(func_type, args, span); + + // Check that we are not passing a slice from an unconstrained runtime to a constrained runtime + if is_current_func_constrained && is_unconstrained_call { + if return_type.contains_slice() { + self.push_err(TypeCheckError::UnconstrainedSliceReturnToConstrained { span }); + } else if matches!(&return_type.follow_bindings(), Type::MutableReference(_)) { + self.push_err(TypeCheckError::UnconstrainedReferenceToConstrained { span }); + } + }; + + return_type + } + + fn check_if_deprecated(&mut self, expr: ExprId) { + if let HirExpression::Ident(HirIdent { location, id, impl_kind: _ }) = + self.interner.expression(&expr) + { + if let Some(DefinitionKind::Function(func_id)) = + self.interner.try_definition(id).map(|def| &def.kind) + { + let attributes = self.interner.function_attributes(func_id); + if let Some(note) = attributes.get_deprecated_note() { + self.push_err(TypeCheckError::CallDeprecated { + name: self.interner.definition_name(id).to_string(), + note, + span: location.span, + }); + } + } + } + } + + fn is_unconstrained_call(&self, expr: ExprId) -> bool { + if let HirExpression::Ident(HirIdent { id, .. }) = self.interner.expression(&expr) { + if let Some(DefinitionKind::Function(func_id)) = + self.interner.try_definition(id).map(|def| &def.kind) + { + let modifiers = self.interner.function_modifiers(func_id); + return modifiers.is_unconstrained; + } + } + false + } + + /// Check if the given method type requires a mutable reference to the object type, and check + /// if the given object type is already a mutable reference. If not, add one. + /// This is used to automatically transform a method call: `foo.bar()` into a function + /// call: `bar(&mut foo)`. + /// + /// A notable corner case of this function is where it interacts with auto-deref of `.`. + /// If a field is being mutated e.g. `foo.bar.mutate_bar()` where `foo: &mut Foo`, the compiler + /// will insert a dereference before bar `(*foo).bar.mutate_bar()` which would cause us to + /// mutate a copy of bar rather than a reference to it. We must check for this corner case here + /// and remove the implicitly added dereference operator if we find one. + pub(super) fn try_add_mutable_reference_to_object( + &mut self, + function_type: &Type, + object_type: &mut Type, + object: &mut ExprId, + ) { + let expected_object_type = match function_type { + Type::Function(args, _, _) => args.first(), + Type::Forall(_, typ) => match typ.as_ref() { + Type::Function(args, _, _) => args.first(), + typ => unreachable!("Unexpected type for function: {typ}"), + }, + typ => unreachable!("Unexpected type for function: {typ}"), + }; + + if let Some(expected_object_type) = expected_object_type { + let actual_type = object_type.follow_bindings(); + + if matches!(expected_object_type.follow_bindings(), Type::MutableReference(_)) { + if !matches!(actual_type, Type::MutableReference(_)) { + if let Err(error) = verify_mutable_reference(self.interner, *object) { + self.push_err(TypeCheckError::ResolverError(error)); + } + + let new_type = Type::MutableReference(Box::new(actual_type)); + *object_type = new_type.clone(); + + // First try to remove a dereference operator that may have been implicitly + // inserted by a field access expression `foo.bar` on a mutable reference `foo`. + let new_object = self.try_remove_implicit_dereference(*object); + + // If that didn't work, then wrap the whole expression in an `&mut` + *object = new_object.unwrap_or_else(|| { + let location = self.interner.id_location(*object); + + let new_object = + self.interner.push_expr(HirExpression::Prefix(HirPrefixExpression { + operator: UnaryOp::MutableReference, + rhs: *object, + })); + self.interner.push_expr_type(new_object, new_type); + self.interner.push_expr_location(new_object, location.span, location.file); + new_object + }); + } + // Otherwise if the object type is a mutable reference and the method is not, insert as + // many dereferences as needed. + } else if matches!(actual_type, Type::MutableReference(_)) { + let (new_object, new_type) = self.insert_auto_dereferences(*object, actual_type); + *object_type = new_type; + *object = new_object; + } + } + } + + pub fn type_check_function_body(&mut self, body_type: Type, meta: &FuncMeta, body_id: ExprId) { + let (expr_span, empty_function) = self.function_info(body_id); + let declared_return_type = meta.return_type(); + + let func_span = self.interner.expr_span(&body_id); // XXX: We could be more specific and return the span of the last stmt, however stmts do not have spans yet + if let Type::TraitAsType(trait_id, _, generics) = declared_return_type { + if self.interner.lookup_trait_implementation(&body_type, *trait_id, generics).is_err() { + self.push_err(TypeCheckError::TypeMismatchWithSource { + expected: declared_return_type.clone(), + actual: body_type, + span: func_span, + source: Source::Return(meta.return_type.clone(), expr_span), + }); + } + } else { + self.unify_with_coercions(&body_type, declared_return_type, body_id, || { + let mut error = TypeCheckError::TypeMismatchWithSource { + expected: declared_return_type.clone(), + actual: body_type.clone(), + span: func_span, + source: Source::Return(meta.return_type.clone(), expr_span), + }; + + if empty_function { + error = error.add_context( + "implicitly returns `()` as its body has no tail or `return` expression", + ); + } + error + }); + } + } + + fn function_info(&self, function_body_id: ExprId) -> (noirc_errors::Span, bool) { + let (expr_span, empty_function) = + if let HirExpression::Block(block) = self.interner.expression(&function_body_id) { + let last_stmt = block.statements().last(); + let mut span = self.interner.expr_span(&function_body_id); + + if let Some(last_stmt) = last_stmt { + if let HirStatement::Expression(expr) = self.interner.statement(last_stmt) { + span = self.interner.expr_span(&expr); + } + } + + (span, last_stmt.is_none()) + } else { + (self.interner.expr_span(&function_body_id), false) + }; + (expr_span, empty_function) + } + + pub fn verify_trait_constraint( + &mut self, + object_type: &Type, + trait_id: TraitId, + trait_generics: &[Type], + function_ident_id: ExprId, + span: Span, + ) { + match self.interner.lookup_trait_implementation(object_type, trait_id, trait_generics) { + Ok(impl_kind) => { + self.interner.select_impl_for_expression(function_ident_id, impl_kind); + } + Err(erroring_constraints) => { + if erroring_constraints.is_empty() { + self.push_err(TypeCheckError::TypeAnnotationsNeeded { span }); + } else { + // Don't show any errors where try_get_trait returns None. + // This can happen if a trait is used that was never declared. + let constraints = erroring_constraints + .into_iter() + .map(|constraint| { + let r#trait = self.interner.try_get_trait(constraint.trait_id)?; + let mut name = r#trait.name.to_string(); + if !constraint.trait_generics.is_empty() { + let generics = + vecmap(&constraint.trait_generics, ToString::to_string); + name += &format!("<{}>", generics.join(", ")); + } + Some((constraint.typ, name)) + }) + .collect::>>(); + + if let Some(constraints) = constraints { + self.push_err(TypeCheckError::NoMatchingImplFound { constraints, span }); + } + } + } + } + } +} diff --git a/noir/noir-repo/compiler/noirc_frontend/src/hir/def_collector/dc_crate.rs b/noir/noir-repo/compiler/noirc_frontend/src/hir/def_collector/dc_crate.rs index 2f6b101e62f..4aac0fec9c3 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/hir/def_collector/dc_crate.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/hir/def_collector/dc_crate.rs @@ -1,5 +1,6 @@ use super::dc_mod::collect_defs; use super::errors::{DefCollectorErrorKind, DuplicateType}; +use crate::elaborator::Elaborator; use crate::graph::CrateId; use crate::hir::comptime::{Interpreter, InterpreterError}; use crate::hir::def_map::{CrateDefMap, LocalModuleId, ModuleId}; @@ -129,14 +130,18 @@ pub struct UnresolvedGlobal { /// Given a Crate root, collect all definitions in that crate pub struct DefCollector { pub(crate) def_map: CrateDefMap, - pub(crate) collected_imports: Vec, - pub(crate) collected_functions: Vec, - pub(crate) collected_types: BTreeMap, - pub(crate) collected_type_aliases: BTreeMap, - pub(crate) collected_traits: BTreeMap, - pub(crate) collected_globals: Vec, - pub(crate) collected_impls: ImplMap, - pub(crate) collected_traits_impls: Vec, + pub(crate) imports: Vec, + pub(crate) items: CollectedItems, +} + +pub struct CollectedItems { + pub(crate) functions: Vec, + pub(crate) types: BTreeMap, + pub(crate) type_aliases: BTreeMap, + pub(crate) traits: BTreeMap, + pub(crate) globals: Vec, + pub(crate) impls: ImplMap, + pub(crate) trait_impls: Vec, } /// Maps the type and the module id in which the impl is defined to the functions contained in that @@ -210,14 +215,16 @@ impl DefCollector { fn new(def_map: CrateDefMap) -> DefCollector { DefCollector { def_map, - collected_imports: vec![], - collected_functions: vec![], - collected_types: BTreeMap::new(), - collected_type_aliases: BTreeMap::new(), - collected_traits: BTreeMap::new(), - collected_impls: HashMap::new(), - collected_globals: vec![], - collected_traits_impls: vec![], + imports: vec![], + items: CollectedItems { + functions: vec![], + types: BTreeMap::new(), + type_aliases: BTreeMap::new(), + traits: BTreeMap::new(), + impls: HashMap::new(), + globals: vec![], + trait_impls: vec![], + }, } } @@ -229,6 +236,7 @@ impl DefCollector { context: &mut Context, ast: SortedModule, root_file_id: FileId, + use_elaborator: bool, macro_processors: &[&dyn MacroProcessor], ) -> Vec<(CompilationError, FileId)> { let mut errors: Vec<(CompilationError, FileId)> = vec![]; @@ -242,7 +250,12 @@ impl DefCollector { let crate_graph = &context.crate_graph[crate_id]; for dep in crate_graph.dependencies.clone() { - errors.extend(CrateDefMap::collect_defs(dep.crate_id, context, macro_processors)); + errors.extend(CrateDefMap::collect_defs( + dep.crate_id, + context, + use_elaborator, + macro_processors, + )); let dep_def_root = context.def_map(&dep.crate_id).expect("ice: def map was just created").root; @@ -275,18 +288,13 @@ impl DefCollector { // Add the current crate to the collection of DefMaps context.def_maps.insert(crate_id, def_collector.def_map); - inject_prelude(crate_id, context, crate_root, &mut def_collector.collected_imports); + inject_prelude(crate_id, context, crate_root, &mut def_collector.imports); for submodule in submodules { - inject_prelude( - crate_id, - context, - LocalModuleId(submodule), - &mut def_collector.collected_imports, - ); + inject_prelude(crate_id, context, LocalModuleId(submodule), &mut def_collector.imports); } // Resolve unresolved imports collected from the crate, one by one. - for collected_import in def_collector.collected_imports { + for collected_import in std::mem::take(&mut def_collector.imports) { match resolve_import(crate_id, &collected_import, &context.def_maps) { Ok(resolved_import) => { if let Some(error) = resolved_import.error { @@ -323,6 +331,12 @@ impl DefCollector { } } + if use_elaborator { + let mut more_errors = Elaborator::elaborate(context, crate_id, def_collector.items); + more_errors.append(&mut errors); + return errors; + } + let mut resolved_module = ResolvedModule { errors, ..Default::default() }; // We must first resolve and intern the globals before we can resolve any stmts inside each function. @@ -330,26 +344,25 @@ impl DefCollector { // // Additionally, we must resolve integer globals before structs since structs may refer to // the values of integer globals as numeric generics. - let (literal_globals, other_globals) = - filter_literal_globals(def_collector.collected_globals); + let (literal_globals, other_globals) = filter_literal_globals(def_collector.items.globals); resolved_module.resolve_globals(context, literal_globals, crate_id); resolved_module.errors.extend(resolve_type_aliases( context, - def_collector.collected_type_aliases, + def_collector.items.type_aliases, crate_id, )); resolved_module.errors.extend(resolve_traits( context, - def_collector.collected_traits, + def_collector.items.traits, crate_id, )); // Must resolve structs before we resolve globals. resolved_module.errors.extend(resolve_structs( context, - def_collector.collected_types, + def_collector.items.types, crate_id, )); @@ -358,7 +371,7 @@ impl DefCollector { resolved_module.errors.extend(collect_trait_impls( context, crate_id, - &mut def_collector.collected_traits_impls, + &mut def_collector.items.trait_impls, )); // Before we resolve any function symbols we must go through our impls and @@ -368,11 +381,7 @@ impl DefCollector { // // These are resolved after trait impls so that struct methods are chosen // over trait methods if there are name conflicts. - resolved_module.errors.extend(collect_impls( - context, - crate_id, - &def_collector.collected_impls, - )); + resolved_module.errors.extend(collect_impls(context, crate_id, &def_collector.items.impls)); // We must wait to resolve non-integer globals until after we resolve structs since struct // globals will need to reference the struct type they're initialized to to ensure they are valid. @@ -383,7 +392,7 @@ impl DefCollector { &mut context.def_interner, crate_id, &context.def_maps, - def_collector.collected_functions, + def_collector.items.functions, None, &mut resolved_module.errors, ); @@ -392,13 +401,13 @@ impl DefCollector { &mut context.def_interner, crate_id, &context.def_maps, - def_collector.collected_impls, + def_collector.items.impls, &mut resolved_module.errors, )); resolved_module.trait_impl_functions = resolve_trait_impls( context, - def_collector.collected_traits_impls, + def_collector.items.trait_impls, crate_id, &mut resolved_module.errors, ); @@ -431,15 +440,18 @@ fn inject_prelude( crate_root: LocalModuleId, collected_imports: &mut Vec, ) { - let segments: Vec<_> = "std::prelude" - .split("::") - .map(|segment| crate::ast::Ident::new(segment.into(), Span::default())) - .collect(); + if !crate_id.is_stdlib() { + let segments: Vec<_> = "std::prelude" + .split("::") + .map(|segment| crate::ast::Ident::new(segment.into(), Span::default())) + .collect(); - let path = - Path { segments: segments.clone(), kind: crate::ast::PathKind::Dep, span: Span::default() }; + let path = Path { + segments: segments.clone(), + kind: crate::ast::PathKind::Dep, + span: Span::default(), + }; - if !crate_id.is_stdlib() { if let Ok(PathResolution { module_def_id, error }) = path_resolver::resolve_path( &context.def_maps, ModuleId { krate: crate_id, local_id: crate_root }, diff --git a/noir/noir-repo/compiler/noirc_frontend/src/hir/def_collector/dc_mod.rs b/noir/noir-repo/compiler/noirc_frontend/src/hir/def_collector/dc_mod.rs index b2ec7dbc813..e688f192d3d 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/hir/def_collector/dc_mod.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/hir/def_collector/dc_mod.rs @@ -70,7 +70,7 @@ pub fn collect_defs( // Then add the imports to defCollector to resolve once all modules in the hierarchy have been resolved for import in ast.imports { - collector.def_collector.collected_imports.push(ImportDirective { + collector.def_collector.imports.push(ImportDirective { module_id: collector.module_id, path: import.path, alias: import.alias, @@ -126,7 +126,7 @@ impl<'a> ModCollector<'a> { errors.push((err.into(), self.file_id)); } - self.def_collector.collected_globals.push(UnresolvedGlobal { + self.def_collector.items.globals.push(UnresolvedGlobal { file_id: self.file_id, module_id: self.module_id, global_id, @@ -154,7 +154,7 @@ impl<'a> ModCollector<'a> { } let key = (r#impl.object_type, self.module_id); - let methods = self.def_collector.collected_impls.entry(key).or_default(); + let methods = self.def_collector.items.impls.entry(key).or_default(); methods.push((r#impl.generics, r#impl.type_span, unresolved_functions)); } } @@ -191,7 +191,7 @@ impl<'a> ModCollector<'a> { trait_generics: trait_impl.trait_generics, }; - self.def_collector.collected_traits_impls.push(unresolved_trait_impl); + self.def_collector.items.trait_impls.push(unresolved_trait_impl); } } @@ -269,7 +269,7 @@ impl<'a> ModCollector<'a> { } } - self.def_collector.collected_functions.push(unresolved_functions); + self.def_collector.items.functions.push(unresolved_functions); errors } @@ -316,7 +316,7 @@ impl<'a> ModCollector<'a> { } // And store the TypeId -> StructType mapping somewhere it is reachable - self.def_collector.collected_types.insert(id, unresolved); + self.def_collector.items.types.insert(id, unresolved); } definition_errors } @@ -354,7 +354,7 @@ impl<'a> ModCollector<'a> { errors.push((err.into(), self.file_id)); } - self.def_collector.collected_type_aliases.insert(type_alias_id, unresolved); + self.def_collector.items.type_aliases.insert(type_alias_id, unresolved); } errors } @@ -506,7 +506,7 @@ impl<'a> ModCollector<'a> { method_ids, fns_with_default_impl: unresolved_functions, }; - self.def_collector.collected_traits.insert(trait_id, unresolved); + self.def_collector.items.traits.insert(trait_id, unresolved); } errors } diff --git a/noir/noir-repo/compiler/noirc_frontend/src/hir/def_map/mod.rs b/noir/noir-repo/compiler/noirc_frontend/src/hir/def_map/mod.rs index 590c2e3d6b6..19e06387d43 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/hir/def_map/mod.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/hir/def_map/mod.rs @@ -73,6 +73,7 @@ impl CrateDefMap { pub fn collect_defs( crate_id: CrateId, context: &mut Context, + use_elaborator: bool, macro_processors: &[&dyn MacroProcessor], ) -> Vec<(CompilationError, FileId)> { // Check if this Crate has already been compiled @@ -116,7 +117,14 @@ impl CrateDefMap { }; // Now we want to populate the CrateDefMap using the DefCollector - errors.extend(DefCollector::collect(def_map, context, ast, root_file_id, macro_processors)); + errors.extend(DefCollector::collect( + def_map, + context, + ast, + root_file_id, + use_elaborator, + macro_processors, + )); errors.extend( parsing_errors.iter().map(|e| (e.clone().into(), root_file_id)).collect::>(), diff --git a/noir/noir-repo/compiler/noirc_frontend/src/hir/resolution/import.rs b/noir/noir-repo/compiler/noirc_frontend/src/hir/resolution/import.rs index 8850331f683..343113836ed 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/hir/resolution/import.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/hir/resolution/import.rs @@ -2,11 +2,14 @@ use noirc_errors::{CustomDiagnostic, Span}; use thiserror::Error; use crate::graph::CrateId; +use crate::hir::def_collector::dc_crate::CompilationError; use std::collections::BTreeMap; use crate::ast::{Ident, ItemVisibility, Path, PathKind}; use crate::hir::def_map::{CrateDefMap, LocalModuleId, ModuleDefId, ModuleId, PerNs}; +use super::errors::ResolverError; + #[derive(Debug, Clone)] pub struct ImportDirective { pub module_id: LocalModuleId, @@ -53,6 +56,12 @@ pub struct ResolvedImport { pub error: Option, } +impl From for CompilationError { + fn from(error: PathResolutionError) -> Self { + Self::ResolverError(ResolverError::PathResolutionError(error)) + } +} + impl<'a> From<&'a PathResolutionError> for CustomDiagnostic { fn from(error: &'a PathResolutionError) -> Self { match &error { diff --git a/noir/noir-repo/compiler/noirc_frontend/src/hir/resolution/resolver.rs b/noir/noir-repo/compiler/noirc_frontend/src/hir/resolution/resolver.rs index 60baaecab59..7dc307fe716 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/hir/resolution/resolver.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/hir/resolution/resolver.rs @@ -56,17 +56,17 @@ use crate::hir_def::{ use super::errors::{PubPosition, ResolverError}; use super::import::PathResolution; -const SELF_TYPE_NAME: &str = "Self"; +pub const SELF_TYPE_NAME: &str = "Self"; type Scope = GenericScope; type ScopeTree = GenericScopeTree; type ScopeForest = GenericScopeForest; pub struct LambdaContext { - captures: Vec, + pub captures: Vec, /// the index in the scope tree /// (sometimes being filled by ScopeTree's find method) - scope_index: usize, + pub scope_index: usize, } /// The primary jobs of the Resolver are to validate that every variable found refers to exactly 1 @@ -1345,7 +1345,7 @@ impl<'a> Resolver<'a> { range @ ForRange::Array(_) => { let for_stmt = range.into_for(for_loop.identifier, for_loop.block, for_loop.span); - self.resolve_stmt(for_stmt, for_loop.span) + self.resolve_stmt(for_stmt.kind, for_loop.span) } } } @@ -1361,7 +1361,7 @@ impl<'a> Resolver<'a> { StatementKind::Comptime(statement) => { let hir_statement = self.resolve_stmt(statement.kind, statement.span); let statement_id = self.interner.push_stmt(hir_statement); - self.interner.push_statement_location(statement_id, statement.span, self.file); + self.interner.push_stmt_location(statement_id, statement.span, self.file); HirStatement::Comptime(statement_id) } } @@ -1370,7 +1370,7 @@ impl<'a> Resolver<'a> { pub fn intern_stmt(&mut self, stmt: Statement) -> StmtId { let hir_stmt = self.resolve_stmt(stmt.kind, stmt.span); let id = self.interner.push_stmt(hir_stmt); - self.interner.push_statement_location(id, stmt.span, self.file); + self.interner.push_stmt_location(id, stmt.span, self.file); id } diff --git a/noir/noir-repo/compiler/noirc_frontend/src/hir/type_check/expr.rs b/noir/noir-repo/compiler/noirc_frontend/src/hir/type_check/expr.rs index 9b40c959981..48598109829 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/hir/type_check/expr.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/hir/type_check/expr.rs @@ -250,14 +250,14 @@ impl<'interner> TypeChecker<'interner> { } // TODO: update object_type here? - let function_call = method_call.into_function_call( + let (_, function_call) = method_call.into_function_call( &method_ref, object_type, location, self.interner, ); - self.interner.replace_expr(expr_id, function_call); + self.interner.replace_expr(expr_id, HirExpression::Call(function_call)); // Type check the new call now that it has been changed from a method call // to a function call. This way we avoid duplicating code. diff --git a/noir/noir-repo/compiler/noirc_frontend/src/hir/type_check/mod.rs b/noir/noir-repo/compiler/noirc_frontend/src/hir/type_check/mod.rs index 0f8131d6ebb..2e448858d9e 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/hir/type_check/mod.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/hir/type_check/mod.rs @@ -25,7 +25,7 @@ use crate::{ Type, TypeBindings, }; -use self::errors::Source; +pub use self::errors::Source; pub struct TypeChecker<'interner> { interner: &'interner mut NodeInterner, diff --git a/noir/noir-repo/compiler/noirc_frontend/src/hir_def/expr.rs b/noir/noir-repo/compiler/noirc_frontend/src/hir_def/expr.rs index bf7d9b7b4ba..8df6785e0eb 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/hir_def/expr.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/hir_def/expr.rs @@ -200,13 +200,15 @@ pub enum HirMethodReference { impl HirMethodCallExpression { /// Converts a method call into a function call + /// + /// Returns ((func_var_id, func_var), call_expr) pub fn into_function_call( mut self, method: &HirMethodReference, object_type: Type, location: Location, interner: &mut NodeInterner, - ) -> HirExpression { + ) -> ((ExprId, HirIdent), HirCallExpression) { let mut arguments = vec![self.object]; arguments.append(&mut self.arguments); @@ -224,10 +226,11 @@ impl HirMethodCallExpression { (id, ImplKind::TraitMethod(*method_id, constraint, false)) } }; - let func = HirExpression::Ident(HirIdent { location, id, impl_kind }); - let func = interner.push_expr(func); + let func_var = HirIdent { location, id, impl_kind }; + let func = interner.push_expr(HirExpression::Ident(func_var.clone())); interner.push_expr_location(func, location.span, location.file); - HirExpression::Call(HirCallExpression { func, arguments, location }) + let expr = HirCallExpression { func, arguments, location }; + ((func, func_var), expr) } } diff --git a/noir/noir-repo/compiler/noirc_frontend/src/hir_def/function.rs b/noir/noir-repo/compiler/noirc_frontend/src/hir_def/function.rs index c38dd41fd3d..ceec9ad8580 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/hir_def/function.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/hir_def/function.rs @@ -135,10 +135,7 @@ impl FuncMeta { /// So this method tells the type checker to ignore the return /// of the empty function, which is unit pub fn can_ignore_return_type(&self) -> bool { - match self.kind { - FunctionKind::LowLevel | FunctionKind::Builtin | FunctionKind::Oracle => true, - FunctionKind::Normal | FunctionKind::Recursive => false, - } + self.kind.can_ignore_return_type() } pub fn function_signature(&self) -> FunctionSignature { diff --git a/noir/noir-repo/compiler/noirc_frontend/src/hir_def/types.rs b/noir/noir-repo/compiler/noirc_frontend/src/hir_def/types.rs index f3b2a24c1f0..f31aeea0552 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/hir_def/types.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/hir_def/types.rs @@ -1423,14 +1423,14 @@ impl Type { /// Retrieves the type of the given field name /// Panics if the type is not a struct or tuple. - pub fn get_field_type(&self, field_name: &str) -> Type { + pub fn get_field_type(&self, field_name: &str) -> Option { match self { - Type::Struct(def, args) => def.borrow().get_field(field_name, args).unwrap().0, + Type::Struct(def, args) => def.borrow().get_field(field_name, args).map(|(typ, _)| typ), Type::Tuple(fields) => { let mut fields = fields.iter().enumerate(); - fields.find(|(i, _)| i.to_string() == *field_name).unwrap().1.clone() + fields.find(|(i, _)| i.to_string() == *field_name).map(|(_, typ)| typ).cloned() } - other => panic!("Tried to iterate over the fields of '{other}', which has none"), + _ => None, } } diff --git a/noir/noir-repo/compiler/noirc_frontend/src/lib.rs b/noir/noir-repo/compiler/noirc_frontend/src/lib.rs index 958a18ac2fb..b05c635f436 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/lib.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/lib.rs @@ -12,6 +12,7 @@ pub mod ast; pub mod debug; +pub mod elaborator; pub mod graph; pub mod lexer; pub mod monomorphization; diff --git a/noir/noir-repo/compiler/noirc_frontend/src/node_interner.rs b/noir/noir-repo/compiler/noirc_frontend/src/node_interner.rs index 88adc7a9414..faf89016f96 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/node_interner.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/node_interner.rs @@ -532,7 +532,7 @@ impl NodeInterner { self.id_to_type.insert(expr_id.into(), typ); } - /// Store the type for an interned expression + /// Store the type for a definition pub fn push_definition_type(&mut self, definition_id: DefinitionId, typ: Type) { self.definition_to_type.insert(definition_id, typ); } @@ -696,7 +696,7 @@ impl NodeInterner { let statement = self.push_stmt(HirStatement::Error); let span = name.span(); let id = self.push_global(name, local_id, statement, file, attributes, mutable); - self.push_statement_location(statement, span, file); + self.push_stmt_location(statement, span, file); id } @@ -942,7 +942,7 @@ impl NodeInterner { self.id_location(stmt_id) } - pub fn push_statement_location(&mut self, id: StmtId, span: Span, file: FileId) { + pub fn push_stmt_location(&mut self, id: StmtId, span: Span, file: FileId) { self.id_to_location.insert(id.into(), Location::new(span, file)); } diff --git a/noir/noir-repo/compiler/noirc_frontend/src/tests.rs b/noir/noir-repo/compiler/noirc_frontend/src/tests.rs index 6f7470807be..fb80a7d8018 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/tests.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/tests.rs @@ -81,6 +81,7 @@ pub(crate) fn get_program(src: &str) -> (ParsedModule, Context, Vec<(Compilation &mut context, program.clone().into_sorted(), root_file_id, + false, &[], // No macro processors )); } diff --git a/noir/noir-repo/noir_stdlib/src/field/bn254.nr b/noir/noir-repo/noir_stdlib/src/field/bn254.nr index d70310be391..2e82d9e7c23 100644 --- a/noir/noir-repo/noir_stdlib/src/field/bn254.nr +++ b/noir/noir-repo/noir_stdlib/src/field/bn254.nr @@ -25,7 +25,7 @@ unconstrained fn decompose_unsafe(x: Field) -> (Field, Field) { fn assert_gt_limbs(a: (Field, Field), b: (Field, Field)) { let (alo, ahi) = a; let (blo, bhi) = b; - let borrow = lte_unsafe(alo, blo, 16); + let borrow = lte_unsafe_16(alo, blo); let rlo = alo - blo - 1 + (borrow as Field) * TWO_POW_128; let rhi = ahi - bhi - (borrow as Field); @@ -51,9 +51,9 @@ pub fn decompose(x: Field) -> (Field, Field) { (xlo, xhi) } -unconstrained fn lt_unsafe(x: Field, y: Field, num_bytes: u32) -> bool { - let x_bytes = x.__to_le_radix(256, num_bytes); - let y_bytes = y.__to_le_radix(256, num_bytes); +fn lt_unsafe_internal(x: Field, y: Field, num_bytes: u32) -> bool { + let x_bytes = x.to_le_radix(256, num_bytes); + let y_bytes = y.to_le_radix(256, num_bytes); let mut x_is_lt = false; let mut done = false; for i in 0..num_bytes { @@ -70,8 +70,20 @@ unconstrained fn lt_unsafe(x: Field, y: Field, num_bytes: u32) -> bool { x_is_lt } -unconstrained fn lte_unsafe(x: Field, y: Field, num_bytes: u32) -> bool { - lt_unsafe(x, y, num_bytes) | (x == y) +fn lte_unsafe_internal(x: Field, y: Field, num_bytes: u32) -> bool { + if x == y { + true + } else { + lt_unsafe_internal(x, y, num_bytes) + } +} + +unconstrained fn lt_unsafe_32(x: Field, y: Field) -> bool { + lt_unsafe_internal(x, y, 32) +} + +unconstrained fn lte_unsafe_16(x: Field, y: Field) -> bool { + lte_unsafe_internal(x, y, 16) } pub fn assert_gt(a: Field, b: Field) { @@ -90,7 +102,7 @@ pub fn assert_lt(a: Field, b: Field) { pub fn gt(a: Field, b: Field) -> bool { if a == b { false - } else if lt_unsafe(a, b, 32) { + } else if lt_unsafe_32(a, b) { assert_gt(b, a); false } else { @@ -105,7 +117,10 @@ pub fn lt(a: Field, b: Field) -> bool { mod tests { // TODO: Allow imports from "super" - use crate::field::bn254::{decompose_unsafe, decompose, lt_unsafe, assert_gt, gt, lt, TWO_POW_128, lte_unsafe, PLO, PHI}; + use crate::field::bn254::{ + decompose_unsafe, decompose, lt_unsafe_internal, assert_gt, gt, lt, TWO_POW_128, + lte_unsafe_internal, PLO, PHI + }; #[test] fn check_decompose_unsafe() { @@ -123,23 +138,23 @@ mod tests { #[test] fn check_lt_unsafe() { - assert(lt_unsafe(0, 1, 16)); - assert(lt_unsafe(0, 0x100, 16)); - assert(lt_unsafe(0x100, TWO_POW_128 - 1, 16)); - assert(!lt_unsafe(0, TWO_POW_128, 16)); + assert(lt_unsafe_internal(0, 1, 16)); + assert(lt_unsafe_internal(0, 0x100, 16)); + assert(lt_unsafe_internal(0x100, TWO_POW_128 - 1, 16)); + assert(!lt_unsafe_internal(0, TWO_POW_128, 16)); } #[test] fn check_lte_unsafe() { - assert(lte_unsafe(0, 1, 16)); - assert(lte_unsafe(0, 0x100, 16)); - assert(lte_unsafe(0x100, TWO_POW_128 - 1, 16)); - assert(!lte_unsafe(0, TWO_POW_128, 16)); - - assert(lte_unsafe(0, 0, 16)); - assert(lte_unsafe(0x100, 0x100, 16)); - assert(lte_unsafe(TWO_POW_128 - 1, TWO_POW_128 - 1, 16)); - assert(lte_unsafe(TWO_POW_128, TWO_POW_128, 16)); + assert(lte_unsafe_internal(0, 1, 16)); + assert(lte_unsafe_internal(0, 0x100, 16)); + assert(lte_unsafe_internal(0x100, TWO_POW_128 - 1, 16)); + assert(!lte_unsafe_internal(0, TWO_POW_128, 16)); + + assert(lte_unsafe_internal(0, 0, 16)); + assert(lte_unsafe_internal(0x100, 0x100, 16)); + assert(lte_unsafe_internal(TWO_POW_128 - 1, TWO_POW_128 - 1, 16)); + assert(lte_unsafe_internal(TWO_POW_128, TWO_POW_128, 16)); } #[test] diff --git a/noir/noir-repo/test_programs/execution_success/no_predicates_brillig/Nargo.toml b/noir/noir-repo/test_programs/execution_success/no_predicates_brillig/Nargo.toml new file mode 100644 index 00000000000..328d78c8f99 --- /dev/null +++ b/noir/noir-repo/test_programs/execution_success/no_predicates_brillig/Nargo.toml @@ -0,0 +1,7 @@ +[package] +name = "no_predicates_brillig" +type = "bin" +authors = [""] +compiler_version = ">=0.27.0" + +[dependencies] diff --git a/noir/noir-repo/test_programs/execution_success/no_predicates_brillig/Prover.toml b/noir/noir-repo/test_programs/execution_success/no_predicates_brillig/Prover.toml new file mode 100644 index 00000000000..93a825f609f --- /dev/null +++ b/noir/noir-repo/test_programs/execution_success/no_predicates_brillig/Prover.toml @@ -0,0 +1,2 @@ +x = "10" +y = "20" diff --git a/noir/noir-repo/test_programs/execution_success/no_predicates_brillig/src/main.nr b/noir/noir-repo/test_programs/execution_success/no_predicates_brillig/src/main.nr new file mode 100644 index 00000000000..1d088473aa7 --- /dev/null +++ b/noir/noir-repo/test_programs/execution_success/no_predicates_brillig/src/main.nr @@ -0,0 +1,12 @@ +unconstrained fn main(x: u32, y: pub u32) { + basic_checks(x, y); +} + +#[no_predicates] +fn basic_checks(x: u32, y: u32) { + if x > y { + assert(x == 10); + } else { + assert(y == 20); + } +} diff --git a/noir/noir-repo/test_programs/execution_success/unit_value/Nargo.toml b/noir/noir-repo/test_programs/execution_success/unit_value/Nargo.toml index f7e3697a7c1..1f9c4524ec5 100644 --- a/noir/noir-repo/test_programs/execution_success/unit_value/Nargo.toml +++ b/noir/noir-repo/test_programs/execution_success/unit_value/Nargo.toml @@ -1,7 +1,7 @@ [package] -name = "short" +name = "unit_value" type = "bin" authors = [""] compiler_version = ">=0.23.0" -[dependencies] \ No newline at end of file +[dependencies] diff --git a/noir/noir-repo/test_programs/rebuild.sh b/noir/noir-repo/test_programs/rebuild.sh index 51e97278281..4733bad10c3 100755 --- a/noir/noir-repo/test_programs/rebuild.sh +++ b/noir/noir-repo/test_programs/rebuild.sh @@ -16,13 +16,14 @@ process_dir() { if [ -d ./target/ ]; then rm -r ./target/ fi - nargo compile --only-acir && nargo execute witness + nargo execute witness if [ -d "$current_dir/acir_artifacts/$dir_name/target" ]; then rm -r "$current_dir/acir_artifacts/$dir_name/target" fi mkdir $current_dir/acir_artifacts/$dir_name/target + mv ./target/$dir_name.json $current_dir/acir_artifacts/$dir_name/target/program.json mv ./target/*.gz $current_dir/acir_artifacts/$dir_name/target/ cd $current_dir @@ -70,4 +71,4 @@ if [ ! -z "$exit_status" ]; then echo "Rebuild failed!" exit $exit_status fi -echo "Rebuild Succeeded!" \ No newline at end of file +echo "Rebuild Succeeded!" diff --git a/noir/noir-repo/tooling/backend_interface/src/cli/contract.rs b/noir/noir-repo/tooling/backend_interface/src/cli/contract.rs index e83fc1909b6..935b96b3ac4 100644 --- a/noir/noir-repo/tooling/backend_interface/src/cli/contract.rs +++ b/noir/noir-repo/tooling/backend_interface/src/cli/contract.rs @@ -48,15 +48,15 @@ fn contract_command() -> Result<(), BackendError> { let temp_directory = tempdir().expect("could not create a temporary directory"); let temp_directory_path = temp_directory.path(); - let bytecode_path = temp_directory_path.join("acir.gz"); + let artifact_path = temp_directory_path.join("program.json"); let vk_path = temp_directory_path.join("vk"); let crs_path = backend.backend_directory(); - std::fs::File::create(&bytecode_path).expect("file should be created"); + std::fs::File::create(&artifact_path).expect("file should be created"); let write_vk_command = super::WriteVkCommand { - bytecode_path, + artifact_path, vk_path_output: vk_path.clone(), crs_path: crs_path.clone(), }; diff --git a/noir/noir-repo/tooling/backend_interface/src/cli/gates.rs b/noir/noir-repo/tooling/backend_interface/src/cli/gates.rs index 9e12596bfd7..ce6c6cebfd3 100644 --- a/noir/noir-repo/tooling/backend_interface/src/cli/gates.rs +++ b/noir/noir-repo/tooling/backend_interface/src/cli/gates.rs @@ -10,7 +10,7 @@ use super::string_from_stderr; /// for the given bytecode. pub(crate) struct GatesCommand { pub(crate) crs_path: PathBuf, - pub(crate) bytecode_path: PathBuf, + pub(crate) artifact_path: PathBuf, } #[derive(Deserialize)] @@ -31,7 +31,7 @@ impl GatesCommand { .arg("-c") .arg(self.crs_path) .arg("-b") - .arg(self.bytecode_path) + .arg(self.artifact_path) .output()?; if !output.status.success() { @@ -53,12 +53,12 @@ fn gate_command() -> Result<(), BackendError> { let temp_directory = tempdir().expect("could not create a temporary directory"); let temp_directory_path = temp_directory.path(); - let bytecode_path = temp_directory_path.join("acir.gz"); + let artifact_path = temp_directory_path.join("program.json"); let crs_path = backend.backend_directory(); - std::fs::File::create(&bytecode_path).expect("file should be created"); + std::fs::File::create(&artifact_path).expect("file should be created"); - let gate_command = GatesCommand { crs_path, bytecode_path }; + let gate_command = GatesCommand { crs_path, artifact_path }; let output = gate_command.run(backend.binary_path())?; // Mock backend always returns zero gates. diff --git a/noir/noir-repo/tooling/backend_interface/src/cli/prove.rs b/noir/noir-repo/tooling/backend_interface/src/cli/prove.rs index c63d8afab54..30a27048b48 100644 --- a/noir/noir-repo/tooling/backend_interface/src/cli/prove.rs +++ b/noir/noir-repo/tooling/backend_interface/src/cli/prove.rs @@ -13,7 +13,7 @@ use super::string_from_stderr; /// The proof will be written to the specified output file. pub(crate) struct ProveCommand { pub(crate) crs_path: PathBuf, - pub(crate) bytecode_path: PathBuf, + pub(crate) artifact_path: PathBuf, pub(crate) witness_path: PathBuf, } @@ -26,7 +26,7 @@ impl ProveCommand { .arg("-c") .arg(self.crs_path) .arg("-b") - .arg(self.bytecode_path) + .arg(self.artifact_path) .arg("-w") .arg(self.witness_path) .arg("-o") @@ -49,14 +49,14 @@ fn prove_command() -> Result<(), BackendError> { let temp_directory = tempdir().expect("could not create a temporary directory"); let temp_directory_path = temp_directory.path(); - let bytecode_path = temp_directory_path.join("acir.gz"); + let artifact_path = temp_directory_path.join("acir.gz"); let witness_path = temp_directory_path.join("witness.tr"); - std::fs::File::create(&bytecode_path).expect("file should be created"); + std::fs::File::create(&artifact_path).expect("file should be created"); std::fs::File::create(&witness_path).expect("file should be created"); let crs_path = backend.backend_directory(); - let prove_command = ProveCommand { crs_path, bytecode_path, witness_path }; + let prove_command = ProveCommand { crs_path, artifact_path, witness_path }; let proof = prove_command.run(backend.binary_path())?; assert_eq!(proof, "proof".as_bytes()); diff --git a/noir/noir-repo/tooling/backend_interface/src/cli/verify.rs b/noir/noir-repo/tooling/backend_interface/src/cli/verify.rs index 1a4ba50b7de..beea4bbec7d 100644 --- a/noir/noir-repo/tooling/backend_interface/src/cli/verify.rs +++ b/noir/noir-repo/tooling/backend_interface/src/cli/verify.rs @@ -41,25 +41,25 @@ fn verify_command() -> Result<(), BackendError> { let temp_directory = tempdir().expect("could not create a temporary directory"); let temp_directory_path = temp_directory.path(); - let bytecode_path = temp_directory_path.join("acir.gz"); + let artifact_path = temp_directory_path.join("acir.json"); let witness_path = temp_directory_path.join("witness.tr"); let proof_path = temp_directory_path.join("1_mul.proof"); let vk_path_output = temp_directory_path.join("vk"); let crs_path = backend.backend_directory(); - std::fs::File::create(&bytecode_path).expect("file should be created"); + std::fs::File::create(&artifact_path).expect("file should be created"); std::fs::File::create(&witness_path).expect("file should be created"); let write_vk_command = WriteVkCommand { - bytecode_path: bytecode_path.clone(), + artifact_path: artifact_path.clone(), crs_path: crs_path.clone(), vk_path_output: vk_path_output.clone(), }; write_vk_command.run(backend.binary_path())?; - let prove_command = ProveCommand { crs_path: crs_path.clone(), bytecode_path, witness_path }; + let prove_command = ProveCommand { crs_path: crs_path.clone(), artifact_path, witness_path }; let proof = prove_command.run(backend.binary_path())?; write_to_file(&proof, &proof_path); diff --git a/noir/noir-repo/tooling/backend_interface/src/cli/write_vk.rs b/noir/noir-repo/tooling/backend_interface/src/cli/write_vk.rs index da9fc04cbef..3d51b5a4a8c 100644 --- a/noir/noir-repo/tooling/backend_interface/src/cli/write_vk.rs +++ b/noir/noir-repo/tooling/backend_interface/src/cli/write_vk.rs @@ -7,7 +7,7 @@ use crate::BackendError; /// to write a verification key to a file pub(crate) struct WriteVkCommand { pub(crate) crs_path: PathBuf, - pub(crate) bytecode_path: PathBuf, + pub(crate) artifact_path: PathBuf, pub(crate) vk_path_output: PathBuf, } @@ -21,7 +21,7 @@ impl WriteVkCommand { .arg("-c") .arg(self.crs_path) .arg("-b") - .arg(self.bytecode_path) + .arg(self.artifact_path) .arg("-o") .arg(self.vk_path_output); @@ -42,14 +42,14 @@ fn write_vk_command() -> Result<(), BackendError> { let temp_directory = tempdir().expect("could not create a temporary directory"); let temp_directory_path = temp_directory.path(); - let bytecode_path = temp_directory_path.join("acir.gz"); + let artifact_path = temp_directory_path.join("program.json"); let vk_path_output = temp_directory.path().join("vk"); let crs_path = backend.backend_directory(); - std::fs::File::create(&bytecode_path).expect("file should be created"); + std::fs::File::create(&artifact_path).expect("file should be created"); - let write_vk_command = WriteVkCommand { bytecode_path, crs_path, vk_path_output }; + let write_vk_command = WriteVkCommand { artifact_path, crs_path, vk_path_output }; write_vk_command.run(backend.binary_path())?; drop(temp_directory); diff --git a/noir/noir-repo/tooling/backend_interface/src/proof_system.rs b/noir/noir-repo/tooling/backend_interface/src/proof_system.rs index ffd46acef0e..49fd57c968f 100644 --- a/noir/noir-repo/tooling/backend_interface/src/proof_system.rs +++ b/noir/noir-repo/tooling/backend_interface/src/proof_system.rs @@ -1,11 +1,8 @@ -use std::fs::File; use std::io::Write; use std::path::Path; +use std::{fs::File, path::PathBuf}; -use acvm::acir::{ - circuit::Program, - native_types::{WitnessMap, WitnessStack}, -}; +use acvm::acir::native_types::{WitnessMap, WitnessStack}; use acvm::FieldElement; use tempfile::tempdir; use tracing::warn; @@ -19,28 +16,20 @@ use crate::{Backend, BackendError}; impl Backend { pub fn get_exact_circuit_sizes( &self, - program: &Program, + artifact_path: PathBuf, ) -> Result, BackendError> { let binary_path = self.assert_binary_exists()?; self.assert_correct_version()?; - let temp_directory = tempdir().expect("could not create a temporary directory"); - let temp_directory = temp_directory.path().to_path_buf(); - - // Create a temporary file for the circuit - let circuit_path = temp_directory.join("circuit").with_extension("bytecode"); - let serialized_program = Program::serialize_program(program); - write_to_file(&serialized_program, &circuit_path); - - GatesCommand { crs_path: self.crs_directory(), bytecode_path: circuit_path } - .run(binary_path) + GatesCommand { crs_path: self.crs_directory(), artifact_path }.run(binary_path) } #[tracing::instrument(level = "trace", skip_all)] pub fn prove( &self, - program: &Program, + artifact_path: PathBuf, witness_stack: WitnessStack, + num_public_inputs: u32, ) -> Result, BackendError> { let binary_path = self.assert_binary_exists()?; self.assert_correct_version()?; @@ -54,20 +43,14 @@ impl Backend { let witness_path = temp_directory.join("witness").with_extension("tr"); write_to_file(&serialized_witnesses, &witness_path); - // Create a temporary file for the circuit - // - let bytecode_path = temp_directory.join("program").with_extension("bytecode"); - let serialized_program = Program::serialize_program(program); - write_to_file(&serialized_program, &bytecode_path); - // Create proof and store it in the specified path let proof_with_public_inputs = - ProveCommand { crs_path: self.crs_directory(), bytecode_path, witness_path } + ProveCommand { crs_path: self.crs_directory(), artifact_path, witness_path } .run(binary_path)?; let proof = bb_abstraction_leaks::remove_public_inputs( // TODO(https://github.com/noir-lang/noir/issues/4428) - program.functions[0].public_inputs().0.len(), + num_public_inputs as usize, &proof_with_public_inputs, ); Ok(proof) @@ -78,7 +61,7 @@ impl Backend { &self, proof: &[u8], public_inputs: WitnessMap, - program: &Program, + artifact_path: PathBuf, ) -> Result { let binary_path = self.assert_binary_exists()?; self.assert_correct_version()?; @@ -92,17 +75,12 @@ impl Backend { let proof_path = temp_directory.join("proof").with_extension("proof"); write_to_file(&proof_with_public_inputs, &proof_path); - // Create a temporary file for the circuit - let bytecode_path = temp_directory.join("program").with_extension("bytecode"); - let serialized_program = Program::serialize_program(program); - write_to_file(&serialized_program, &bytecode_path); - // Create the verification key and write it to the specified path let vk_path = temp_directory.join("vk"); WriteVkCommand { crs_path: self.crs_directory(), - bytecode_path, + artifact_path, vk_path_output: vk_path.clone(), } .run(binary_path)?; @@ -113,7 +91,7 @@ impl Backend { pub fn get_intermediate_proof_artifacts( &self, - program: &Program, + artifact_path: PathBuf, proof: &[u8], public_inputs: WitnessMap, ) -> Result<(Vec, FieldElement, Vec), BackendError> { @@ -123,18 +101,12 @@ impl Backend { let temp_directory = tempdir().expect("could not create a temporary directory"); let temp_directory = temp_directory.path().to_path_buf(); - // Create a temporary file for the circuit - // - let bytecode_path = temp_directory.join("program").with_extension("bytecode"); - let serialized_program = Program::serialize_program(program); - write_to_file(&serialized_program, &bytecode_path); - // Create the verification key and write it to the specified path let vk_path = temp_directory.join("vk"); WriteVkCommand { crs_path: self.crs_directory(), - bytecode_path, + artifact_path, vk_path_output: vk_path.clone(), } .run(binary_path)?; diff --git a/noir/noir-repo/tooling/backend_interface/src/smart_contract.rs b/noir/noir-repo/tooling/backend_interface/src/smart_contract.rs index 153ab52c83f..8b26ea07a2f 100644 --- a/noir/noir-repo/tooling/backend_interface/src/smart_contract.rs +++ b/noir/noir-repo/tooling/backend_interface/src/smart_contract.rs @@ -1,30 +1,25 @@ -use super::proof_system::write_to_file; +use std::path::PathBuf; + use crate::{ cli::{ContractCommand, WriteVkCommand}, Backend, BackendError, }; -use acvm::acir::circuit::Program; use tempfile::tempdir; impl Backend { - pub fn eth_contract(&self, program: &Program) -> Result { + pub fn eth_contract(&self, artifact_path: PathBuf) -> Result { let binary_path = self.assert_binary_exists()?; self.assert_correct_version()?; let temp_directory = tempdir().expect("could not create a temporary directory"); let temp_directory_path = temp_directory.path().to_path_buf(); - // Create a temporary file for the circuit - let bytecode_path = temp_directory_path.join("program").with_extension("bytecode"); - let serialized_program = Program::serialize_program(program); - write_to_file(&serialized_program, &bytecode_path); - // Create the verification key and write it to the specified path let vk_path = temp_directory_path.join("vk"); WriteVkCommand { crs_path: self.crs_directory(), - bytecode_path, + artifact_path, vk_path_output: vk_path.clone(), } .run(binary_path)?; @@ -35,33 +30,23 @@ impl Backend { #[cfg(test)] mod tests { - use std::collections::BTreeSet; - use acvm::acir::{ - circuit::{Circuit, ExpressionWidth, Opcode, Program, PublicInputs}, - native_types::{Expression, Witness}, - }; + use serde_json::json; + use tempfile::tempdir; - use crate::{get_mock_backend, BackendError}; + use crate::{get_mock_backend, proof_system::write_to_file, BackendError}; #[test] fn test_smart_contract() -> Result<(), BackendError> { - let expression = &(Witness(1) + Witness(2)) - &Expression::from(Witness(3)); - let constraint = Opcode::AssertZero(expression); + let dummy_artifact = json!({"bytecode": ""}); + let artifact_bytes = serde_json::to_vec(&dummy_artifact).unwrap(); - let circuit = Circuit { - current_witness_index: 4, - expression_width: ExpressionWidth::Bounded { width: 4 }, - opcodes: vec![constraint], - private_parameters: BTreeSet::from([Witness(1), Witness(2)]), - public_parameters: PublicInputs::default(), - return_values: PublicInputs::default(), - assert_messages: Default::default(), - recursive: false, - }; - let program = Program { functions: vec![circuit], unconstrained_functions: Vec::new() }; + let temp_directory = tempdir().expect("could not create a temporary directory"); + let temp_directory_path = temp_directory.path(); + let artifact_path = temp_directory_path.join("program.json"); + write_to_file(&artifact_bytes, &artifact_path); - let contract = get_mock_backend()?.eth_contract(&program)?; + let contract = get_mock_backend()?.eth_contract(artifact_path)?; assert!(contract.contains("contract VerifierContract")); diff --git a/noir/noir-repo/tooling/lsp/src/lib.rs b/noir/noir-repo/tooling/lsp/src/lib.rs index be9b83e02f6..05345b96c80 100644 --- a/noir/noir-repo/tooling/lsp/src/lib.rs +++ b/noir/noir-repo/tooling/lsp/src/lib.rs @@ -345,7 +345,7 @@ fn prepare_package_from_source_string() { let mut state = LspState::new(&client, acvm::blackbox_solver::StubbedBlackBoxSolver); let (mut context, crate_id) = crate::prepare_source(source.to_string(), &mut state); - let _check_result = noirc_driver::check_crate(&mut context, crate_id, false, false); + let _check_result = noirc_driver::check_crate(&mut context, crate_id, false, false, false); let main_func_id = context.get_main_function(&crate_id); assert!(main_func_id.is_some()); } diff --git a/noir/noir-repo/tooling/lsp/src/notifications/mod.rs b/noir/noir-repo/tooling/lsp/src/notifications/mod.rs index 355bb7832c4..3856bdc79e9 100644 --- a/noir/noir-repo/tooling/lsp/src/notifications/mod.rs +++ b/noir/noir-repo/tooling/lsp/src/notifications/mod.rs @@ -56,7 +56,7 @@ pub(super) fn on_did_change_text_document( state.input_files.insert(params.text_document.uri.to_string(), text.clone()); let (mut context, crate_id) = prepare_source(text, state); - let _ = check_crate(&mut context, crate_id, false, false); + let _ = check_crate(&mut context, crate_id, false, false, false); let workspace = match resolve_workspace_for_source_path( params.text_document.uri.to_file_path().unwrap().as_path(), @@ -139,7 +139,7 @@ fn process_noir_document( let (mut context, crate_id) = prepare_package(&workspace_file_manager, &parsed_files, package); - let file_diagnostics = match check_crate(&mut context, crate_id, false, false) { + let file_diagnostics = match check_crate(&mut context, crate_id, false, false, false) { Ok(((), warnings)) => warnings, Err(errors_and_warnings) => errors_and_warnings, }; diff --git a/noir/noir-repo/tooling/lsp/src/requests/code_lens_request.rs b/noir/noir-repo/tooling/lsp/src/requests/code_lens_request.rs index 893ba33d845..744bddedd9d 100644 --- a/noir/noir-repo/tooling/lsp/src/requests/code_lens_request.rs +++ b/noir/noir-repo/tooling/lsp/src/requests/code_lens_request.rs @@ -67,7 +67,7 @@ fn on_code_lens_request_inner( let (mut context, crate_id) = prepare_source(source_string, state); // We ignore the warnings and errors produced by compilation for producing code lenses // because we can still get the test functions even if compilation fails - let _ = check_crate(&mut context, crate_id, false, false); + let _ = check_crate(&mut context, crate_id, false, false, false); let collected_lenses = collect_lenses_for_package(&context, crate_id, &workspace, package, None); diff --git a/noir/noir-repo/tooling/lsp/src/requests/goto_declaration.rs b/noir/noir-repo/tooling/lsp/src/requests/goto_declaration.rs index 8e6d519b895..5cff16b2348 100644 --- a/noir/noir-repo/tooling/lsp/src/requests/goto_declaration.rs +++ b/noir/noir-repo/tooling/lsp/src/requests/goto_declaration.rs @@ -46,7 +46,7 @@ fn on_goto_definition_inner( interner = def_interner; } else { // We ignore the warnings and errors produced by compilation while resolving the definition - let _ = noirc_driver::check_crate(&mut context, crate_id, false, false); + let _ = noirc_driver::check_crate(&mut context, crate_id, false, false, false); interner = &context.def_interner; } diff --git a/noir/noir-repo/tooling/lsp/src/requests/goto_definition.rs b/noir/noir-repo/tooling/lsp/src/requests/goto_definition.rs index 88bb667f2e8..32e13ce00f6 100644 --- a/noir/noir-repo/tooling/lsp/src/requests/goto_definition.rs +++ b/noir/noir-repo/tooling/lsp/src/requests/goto_definition.rs @@ -54,7 +54,7 @@ fn on_goto_definition_inner( interner = def_interner; } else { // We ignore the warnings and errors produced by compilation while resolving the definition - let _ = noirc_driver::check_crate(&mut context, crate_id, false, false); + let _ = noirc_driver::check_crate(&mut context, crate_id, false, false, false); interner = &context.def_interner; } diff --git a/noir/noir-repo/tooling/lsp/src/requests/test_run.rs b/noir/noir-repo/tooling/lsp/src/requests/test_run.rs index 1844a3d9bf0..83b05ba06a2 100644 --- a/noir/noir-repo/tooling/lsp/src/requests/test_run.rs +++ b/noir/noir-repo/tooling/lsp/src/requests/test_run.rs @@ -60,7 +60,7 @@ fn on_test_run_request_inner( Some(package) => { let (mut context, crate_id) = prepare_package(&workspace_file_manager, &parsed_files, package); - if check_crate(&mut context, crate_id, false, false).is_err() { + if check_crate(&mut context, crate_id, false, false, false).is_err() { let result = NargoTestRunResult { id: params.id.clone(), result: "error".to_string(), diff --git a/noir/noir-repo/tooling/lsp/src/requests/tests.rs b/noir/noir-repo/tooling/lsp/src/requests/tests.rs index 5b78fcc65c3..cdf4ad338c4 100644 --- a/noir/noir-repo/tooling/lsp/src/requests/tests.rs +++ b/noir/noir-repo/tooling/lsp/src/requests/tests.rs @@ -61,7 +61,7 @@ fn on_tests_request_inner( prepare_package(&workspace_file_manager, &parsed_files, package); // We ignore the warnings and errors produced by compilation for producing tests // because we can still get the test functions even if compilation fails - let _ = check_crate(&mut context, crate_id, false, false); + let _ = check_crate(&mut context, crate_id, false, false, false); // We don't add test headings for a package if it contains no `#[test]` functions get_package_tests_in_crate(&context, &crate_id, &package.name) diff --git a/noir/noir-repo/tooling/nargo_cli/src/cli/check_cmd.rs b/noir/noir-repo/tooling/nargo_cli/src/cli/check_cmd.rs index 208379b098d..d5313d96076 100644 --- a/noir/noir-repo/tooling/nargo_cli/src/cli/check_cmd.rs +++ b/noir/noir-repo/tooling/nargo_cli/src/cli/check_cmd.rs @@ -87,6 +87,7 @@ fn check_package( compile_options.deny_warnings, compile_options.disable_macros, compile_options.silence_warnings, + compile_options.use_elaborator, )?; if package.is_library() || package.is_contract() { @@ -173,8 +174,9 @@ pub(crate) fn check_crate_and_report_errors( deny_warnings: bool, disable_macros: bool, silence_warnings: bool, + use_elaborator: bool, ) -> Result<(), CompileError> { - let result = check_crate(context, crate_id, deny_warnings, disable_macros); + let result = check_crate(context, crate_id, deny_warnings, disable_macros, use_elaborator); report_errors(result, &context.file_manager, deny_warnings, silence_warnings) } diff --git a/noir/noir-repo/tooling/nargo_cli/src/cli/codegen_verifier_cmd.rs b/noir/noir-repo/tooling/nargo_cli/src/cli/codegen_verifier_cmd.rs index 04ed5c2b6b8..6247560f621 100644 --- a/noir/noir-repo/tooling/nargo_cli/src/cli/codegen_verifier_cmd.rs +++ b/noir/noir-repo/tooling/nargo_cli/src/cli/codegen_verifier_cmd.rs @@ -7,7 +7,7 @@ use crate::errors::CliError; use clap::Args; use nargo_toml::{get_package_manifest, resolve_workspace_from_toml, PackageSelection}; -use noirc_driver::{CompileOptions, CompiledProgram, NOIR_ARTIFACT_VERSION_STRING}; +use noirc_driver::{CompileOptions, NOIR_ARTIFACT_VERSION_STRING}; use noirc_frontend::graph::CrateName; /// Generates a Solidity verifier smart contract for the program @@ -46,15 +46,15 @@ pub(crate) fn run( let binary_packages = workspace.into_iter().filter(|package| package.is_binary()); for package in binary_packages { let program_artifact_path = workspace.package_build_path(package); - let program: CompiledProgram = read_program_from_file(program_artifact_path)?.into(); + let program = read_program_from_file(&program_artifact_path)?; // TODO(https://github.com/noir-lang/noir/issues/4428): // We do not expect to have a smart contract verifier for a foldable program with multiple circuits. // However, in the future we can expect to possibly have non-inlined ACIR functions during compilation // that will be inlined at a later step such as by the ACVM compiler or by the backend. // Add appropriate handling here once the compiler enables multiple ACIR functions. - assert_eq!(program.program.functions.len(), 1); - let smart_contract_string = backend.eth_contract(&program.program)?; + assert_eq!(program.bytecode.functions.len(), 1); + let smart_contract_string = backend.eth_contract(program_artifact_path)?; let contract_dir = workspace.contracts_directory_path(package); create_named_dir(&contract_dir, "contract"); diff --git a/noir/noir-repo/tooling/nargo_cli/src/cli/compile_cmd.rs b/noir/noir-repo/tooling/nargo_cli/src/cli/compile_cmd.rs index 8f28e5d9388..ecf2e2e9f53 100644 --- a/noir/noir-repo/tooling/nargo_cli/src/cli/compile_cmd.rs +++ b/noir/noir-repo/tooling/nargo_cli/src/cli/compile_cmd.rs @@ -22,7 +22,6 @@ use notify_debouncer_full::new_debouncer; use crate::errors::CliError; -use super::fs::program::only_acir; use super::fs::program::{read_program_from_file, save_contract_to_file, save_program_to_file}; use super::NargoConfig; use rayon::prelude::*; @@ -136,10 +135,9 @@ pub(super) fn compile_workspace_full( .partition(|package| package.is_binary()); // Save build artifacts to disk. - let only_acir = compile_options.only_acir; for (package, program) in binary_packages.into_iter().zip(compiled_programs) { let program = nargo::ops::transform_program(program, compile_options.expression_width); - save_program(program.clone(), &package, &workspace.target_directory_path(), only_acir); + save_program(program.clone(), &package, &workspace.target_directory_path()); } let circuit_dir = workspace.target_directory_path(); for (package, contract) in contract_packages.into_iter().zip(compiled_contracts) { @@ -197,18 +195,9 @@ pub(super) fn compile_workspace( } } -pub(super) fn save_program( - program: CompiledProgram, - package: &Package, - circuit_dir: &Path, - only_acir_opt: bool, -) { - if only_acir_opt { - only_acir(program.program, circuit_dir); - } else { - let program_artifact = ProgramArtifact::from(program.clone()); - save_program_to_file(&program_artifact, &package.name, circuit_dir); - } +pub(super) fn save_program(program: CompiledProgram, package: &Package, circuit_dir: &Path) { + let program_artifact = ProgramArtifact::from(program.clone()); + save_program_to_file(&program_artifact, &package.name, circuit_dir); } fn save_contract(contract: CompiledContract, package: &Package, circuit_dir: &Path) { diff --git a/noir/noir-repo/tooling/nargo_cli/src/cli/export_cmd.rs b/noir/noir-repo/tooling/nargo_cli/src/cli/export_cmd.rs index a61f3ccfc02..324eed340ad 100644 --- a/noir/noir-repo/tooling/nargo_cli/src/cli/export_cmd.rs +++ b/noir/noir-repo/tooling/nargo_cli/src/cli/export_cmd.rs @@ -89,6 +89,7 @@ fn compile_exported_functions( compile_options.deny_warnings, compile_options.disable_macros, compile_options.silence_warnings, + compile_options.use_elaborator, )?; let exported_functions = context.get_all_exported_functions_in_crate(&crate_id); diff --git a/noir/noir-repo/tooling/nargo_cli/src/cli/fs/program.rs b/noir/noir-repo/tooling/nargo_cli/src/cli/fs/program.rs index 72d686b0b36..ba017651667 100644 --- a/noir/noir-repo/tooling/nargo_cli/src/cli/fs/program.rs +++ b/noir/noir-repo/tooling/nargo_cli/src/cli/fs/program.rs @@ -1,6 +1,5 @@ use std::path::{Path, PathBuf}; -use acvm::acir::circuit::Program; use nargo::artifacts::{contract::ContractArtifact, program::ProgramArtifact}; use noirc_frontend::graph::CrateName; @@ -17,16 +16,6 @@ pub(crate) fn save_program_to_file>( save_build_artifact_to_file(program_artifact, &circuit_name, circuit_dir) } -/// Writes the bytecode as acir.gz -pub(crate) fn only_acir>(program: Program, circuit_dir: P) -> PathBuf { - create_named_dir(circuit_dir.as_ref(), "target"); - let circuit_path = circuit_dir.as_ref().join("acir").with_extension("gz"); - let bytes = Program::serialize_program(&program); - write_to_file(&bytes, &circuit_path); - - circuit_path -} - pub(crate) fn save_contract_to_file>( compiled_contract: &ContractArtifact, circuit_name: &str, @@ -60,16 +49,3 @@ pub(crate) fn read_program_from_file>( Ok(program) } - -pub(crate) fn read_contract_from_file>( - circuit_path: P, -) -> Result { - let file_path = circuit_path.as_ref().with_extension("json"); - - let input_string = - std::fs::read(&file_path).map_err(|_| FilesystemError::PathNotValid(file_path))?; - let contract = serde_json::from_slice(&input_string) - .map_err(|err| FilesystemError::ProgramSerializationError(err.to_string()))?; - - Ok(contract) -} diff --git a/noir/noir-repo/tooling/nargo_cli/src/cli/info_cmd.rs b/noir/noir-repo/tooling/nargo_cli/src/cli/info_cmd.rs index f8f645d3c3a..d68aef497f6 100644 --- a/noir/noir-repo/tooling/nargo_cli/src/cli/info_cmd.rs +++ b/noir/noir-repo/tooling/nargo_cli/src/cli/info_cmd.rs @@ -1,11 +1,11 @@ -use std::collections::HashMap; +use std::{collections::HashMap, path::PathBuf}; use acvm::acir::circuit::ExpressionWidth; use backend_interface::BackendError; use clap::Args; use iter_extended::vecmap; use nargo::{ - artifacts::{contract::ContractArtifact, debug::DebugArtifact, program::ProgramArtifact}, + artifacts::{debug::DebugArtifact, program::ProgramArtifact}, package::Package, }; use nargo_toml::{get_package_manifest, resolve_workspace_from_toml, PackageSelection}; @@ -20,9 +20,7 @@ use crate::backends::Backend; use crate::errors::CliError; use super::{ - compile_cmd::compile_workspace_full, - fs::program::{read_contract_from_file, read_program_from_file}, - NargoConfig, + compile_cmd::compile_workspace_full, fs::program::read_program_from_file, NargoConfig, }; /// Provides detailed information on each of a program's function (represented by a single circuit) @@ -80,15 +78,6 @@ pub(crate) fn run( }) .collect::>()?; - let compiled_contracts: Vec = workspace - .into_iter() - .filter(|package| package.is_contract()) - .map(|package| { - let contract_artifact_path = workspace.package_build_path(package); - read_contract_from_file(contract_artifact_path) - }) - .collect::>()?; - if args.profile_info { for (_, compiled_program) in &binary_packages { let debug_artifact = DebugArtifact::from(compiled_program.clone()); @@ -97,17 +86,6 @@ pub(crate) fn run( print_span_opcodes(span_opcodes, &debug_artifact); } } - - for compiled_contract in &compiled_contracts { - let debug_artifact = DebugArtifact::from(compiled_contract.clone()); - let functions = &compiled_contract.functions; - for contract_function in functions { - for function_debug in contract_function.debug_symbols.debug_infos.iter() { - let span_opcodes = function_debug.count_span_opcodes(); - print_span_opcodes(span_opcodes, &debug_artifact); - } - } - } } let program_info = binary_packages @@ -116,6 +94,7 @@ pub(crate) fn run( .map(|(package, program)| { count_opcodes_and_gates_in_program( backend, + workspace.package_build_path(&package), program, &package, args.compile_options.expression_width, @@ -123,18 +102,7 @@ pub(crate) fn run( }) .collect::>()?; - let contract_info = compiled_contracts - .into_par_iter() - .map(|contract| { - count_opcodes_and_gates_in_contract( - backend, - contract, - args.compile_options.expression_width, - ) - }) - .collect::>()?; - - let info_report = InfoReport { programs: program_info, contracts: contract_info }; + let info_report = InfoReport { programs: program_info, contracts: Vec::new() }; if args.json { // Expose machine-readable JSON data. @@ -152,23 +120,6 @@ pub(crate) fn run( } program_table.printstd(); } - if !info_report.contracts.is_empty() { - let mut contract_table = table!([ - Fm->"Contract", - Fm->"Function", - Fm->"Expression Width", - Fm->"ACIR Opcodes", - Fm->"Backend Circuit Size" - ]); - for contract_info in info_report.contracts { - let contract_rows: Vec = contract_info.into(); - for row in contract_rows { - contract_table.add_row(row); - } - } - - contract_table.printstd(); - } } Ok(()) @@ -283,15 +234,12 @@ impl From for Vec { fn count_opcodes_and_gates_in_program( backend: &Backend, - mut compiled_program: ProgramArtifact, + program_artifact_path: PathBuf, + compiled_program: ProgramArtifact, package: &Package, expression_width: ExpressionWidth, ) -> Result { - // Unconstrained functions do not matter to a backend circuit count so we clear them - // before sending a serialized program to the backend - compiled_program.bytecode.unconstrained_functions.clear(); - - let program_circuit_sizes = backend.get_exact_circuit_sizes(&compiled_program.bytecode)?; + let program_circuit_sizes = backend.get_exact_circuit_sizes(program_artifact_path)?; let functions = compiled_program .bytecode .functions @@ -309,24 +257,3 @@ fn count_opcodes_and_gates_in_program( Ok(ProgramInfo { package_name: package.name.to_string(), expression_width, functions }) } - -fn count_opcodes_and_gates_in_contract( - backend: &Backend, - contract: ContractArtifact, - expression_width: ExpressionWidth, -) -> Result { - let functions = contract - .functions - .into_par_iter() - .map(|function| -> Result<_, BackendError> { - Ok(FunctionInfo { - name: function.name, - // TODO(https://github.com/noir-lang/noir/issues/4720) - acir_opcodes: function.bytecode.functions[0].opcodes.len(), - circuit_size: backend.get_exact_circuit_sizes(&function.bytecode)?[0].circuit_size, - }) - }) - .collect::>()?; - - Ok(ContractInfo { name: contract.name, expression_width, functions }) -} diff --git a/noir/noir-repo/tooling/nargo_cli/src/cli/prove_cmd.rs b/noir/noir-repo/tooling/nargo_cli/src/cli/prove_cmd.rs index 6fb6e7269f7..127c5ac2ebb 100644 --- a/noir/noir-repo/tooling/nargo_cli/src/cli/prove_cmd.rs +++ b/noir/noir-repo/tooling/nargo_cli/src/cli/prove_cmd.rs @@ -1,3 +1,5 @@ +use std::path::PathBuf; + use clap::Args; use nargo::constants::{PROVER_INPUT_FILE, VERIFIER_INPUT_FILE}; use nargo::package::Package; @@ -68,12 +70,13 @@ pub(crate) fn run( let binary_packages = workspace.into_iter().filter(|package| package.is_binary()); for package in binary_packages { let program_artifact_path = workspace.package_build_path(package); - let program: CompiledProgram = read_program_from_file(program_artifact_path)?.into(); + let program: CompiledProgram = read_program_from_file(&program_artifact_path)?.into(); let proof = prove_package( backend, package, program, + program_artifact_path, &args.prover_name, &args.verifier_name, args.verify, @@ -86,10 +89,12 @@ pub(crate) fn run( Ok(()) } +#[allow(clippy::too_many_arguments)] fn prove_package( backend: &Backend, package: &Package, compiled_program: CompiledProgram, + program_artifact_path: PathBuf, prover_name: &str, verifier_name: &str, check_proof: bool, @@ -117,11 +122,15 @@ fn prove_package( Format::Toml, )?; - let proof = backend.prove(&compiled_program.program, witness_stack)?; + let proof = backend.prove( + program_artifact_path.clone(), + witness_stack, + compiled_program.program.functions[0].public_inputs().0.len() as u32, + )?; if check_proof { let public_inputs = public_abi.encode(&public_inputs, return_value)?; - let valid_proof = backend.verify(&proof, public_inputs, &compiled_program.program)?; + let valid_proof = backend.verify(&proof, public_inputs, program_artifact_path)?; if !valid_proof { return Err(CliError::InvalidProof("".into())); diff --git a/noir/noir-repo/tooling/nargo_cli/src/cli/test_cmd.rs b/noir/noir-repo/tooling/nargo_cli/src/cli/test_cmd.rs index 967d4c87e6d..51e21248afd 100644 --- a/noir/noir-repo/tooling/nargo_cli/src/cli/test_cmd.rs +++ b/noir/noir-repo/tooling/nargo_cli/src/cli/test_cmd.rs @@ -175,6 +175,7 @@ fn run_test( crate_id, compile_options.deny_warnings, compile_options.disable_macros, + compile_options.use_elaborator, ) .expect("Any errors should have occurred when collecting test functions"); @@ -208,6 +209,7 @@ fn get_tests_in_package( compile_options.deny_warnings, compile_options.disable_macros, compile_options.silence_warnings, + compile_options.use_elaborator, )?; Ok(context diff --git a/noir/noir-repo/tooling/nargo_cli/src/cli/verify_cmd.rs b/noir/noir-repo/tooling/nargo_cli/src/cli/verify_cmd.rs index a7f2772330a..ad1978cabe0 100644 --- a/noir/noir-repo/tooling/nargo_cli/src/cli/verify_cmd.rs +++ b/noir/noir-repo/tooling/nargo_cli/src/cli/verify_cmd.rs @@ -81,7 +81,8 @@ fn verify_package( let proof = load_hex_data(&proof_path)?; - let valid_proof = backend.verify(&proof, public_inputs, &compiled_program.program)?; + let valid_proof = + backend.verify(&proof, public_inputs, workspace.package_build_path(package))?; if valid_proof { Ok(()) diff --git a/noir/noir-repo/tooling/nargo_cli/tests/stdlib-tests.rs b/noir/noir-repo/tooling/nargo_cli/tests/stdlib-tests.rs index 9d377cfaee9..70a9354f50a 100644 --- a/noir/noir-repo/tooling/nargo_cli/tests/stdlib-tests.rs +++ b/noir/noir-repo/tooling/nargo_cli/tests/stdlib-tests.rs @@ -10,8 +10,7 @@ use nargo::{ parse_all, prepare_package, }; -#[test] -fn stdlib_noir_tests() { +fn run_stdlib_tests(use_elaborator: bool) { let mut file_manager = file_manager_with_stdlib(&PathBuf::from(".")); file_manager.add_file_with_source_canonical_path(&PathBuf::from("main.nr"), "".to_owned()); let parsed_files = parse_all(&file_manager); @@ -30,7 +29,7 @@ fn stdlib_noir_tests() { let (mut context, dummy_crate_id) = prepare_package(&file_manager, &parsed_files, &dummy_package); - let result = check_crate(&mut context, dummy_crate_id, true, false); + let result = check_crate(&mut context, dummy_crate_id, true, false, use_elaborator); report_errors(result, &context.file_manager, true, false) .expect("Error encountered while compiling standard library"); @@ -60,3 +59,15 @@ fn stdlib_noir_tests() { assert!(!test_report.is_empty(), "Could not find any tests within the stdlib"); assert!(test_report.iter().all(|(_, status)| !status.failed())); } + +#[test] +fn stdlib_noir_tests() { + run_stdlib_tests(false) +} + +// Once this no longer panics we can use the elaborator by default and remove the old passes +#[test] +#[should_panic] +fn stdlib_elaborator_tests() { + run_stdlib_tests(true) +} diff --git a/noir/noir-repo/tooling/noir_js/test/node/execute.test.ts b/noir/noir-repo/tooling/noir_js/test/node/execute.test.ts index dcf9f489003..d047e35035f 100644 --- a/noir/noir-repo/tooling/noir_js/test/node/execute.test.ts +++ b/noir/noir-repo/tooling/noir_js/test/node/execute.test.ts @@ -81,3 +81,51 @@ it('circuit with a raw assert payload should fail with the decoded payload', asy }); } }); + +it('successfully executes a program with multiple acir circuits', async () => { + const inputs = { + x: '10', + }; + try { + await new Noir(fold_fibonacci_program).execute(inputs); + } catch (error) { + const knownError = error as Error; + expect(knownError.message).to.equal('Circuit execution failed: Error: Cannot satisfy constraint'); + } +}); + +it('successfully executes a program with multiple acir circuits', async () => { + const inputs = { + x: '10', + }; + try { + await new Noir(fold_fibonacci_program).execute(inputs); + } catch (error) { + const knownError = error as Error; + expect(knownError.message).to.equal('Circuit execution failed: Error: Cannot satisfy constraint'); + } +}); + +it('successfully executes a program with multiple acir circuits', async () => { + const inputs = { + x: '10', + }; + try { + await new Noir(fold_fibonacci_program).execute(inputs); + } catch (error) { + const knownError = error as Error; + expect(knownError.message).to.equal('Circuit execution failed: Error: Cannot satisfy constraint'); + } +}); + +it('successfully executes a program with multiple acir circuits', async () => { + const inputs = { + x: '10', + }; + try { + await new Noir(fold_fibonacci_program).execute(inputs); + } catch (error) { + const knownError = error as Error; + expect(knownError.message).to.equal('Circuit execution failed: Error: Cannot satisfy constraint'); + } +}); diff --git a/yarn-project/accounts/src/defaults/account_contract.ts b/yarn-project/accounts/src/defaults/account_contract.ts index dc3b2330059..f2842c9ac0f 100644 --- a/yarn-project/accounts/src/defaults/account_contract.ts +++ b/yarn-project/accounts/src/defaults/account_contract.ts @@ -1,6 +1,5 @@ import { type AccountContract, type AccountInterface, type AuthWitnessProvider } from '@aztec/aztec.js/account'; import { type CompleteAddress } from '@aztec/circuit-types'; -import { type Fr } from '@aztec/circuits.js'; import { type ContractArtifact } from '@aztec/foundation/abi'; import { type NodeInfo } from '@aztec/types/interfaces'; @@ -20,7 +19,7 @@ export abstract class DefaultAccountContract implements AccountContract { return this.artifact; } - getInterface(address: CompleteAddress, publicKeysHash: Fr, nodeInfo: NodeInfo): AccountInterface { - return new DefaultAccountInterface(this.getAuthWitnessProvider(address), address, publicKeysHash, nodeInfo); + getInterface(address: CompleteAddress, nodeInfo: NodeInfo): AccountInterface { + return new DefaultAccountInterface(this.getAuthWitnessProvider(address), address, nodeInfo); } } diff --git a/yarn-project/accounts/src/defaults/account_interface.ts b/yarn-project/accounts/src/defaults/account_interface.ts index f32e96aa208..5d7fa311c6e 100644 --- a/yarn-project/accounts/src/defaults/account_interface.ts +++ b/yarn-project/accounts/src/defaults/account_interface.ts @@ -17,7 +17,6 @@ export class DefaultAccountInterface implements AccountInterface { constructor( private authWitnessProvider: AuthWitnessProvider, private address: CompleteAddress, - private publicKeysHash: Fr, nodeInfo: Pick, ) { this.entrypoint = new DefaultAccountEntrypoint( @@ -38,10 +37,6 @@ export class DefaultAccountInterface implements AccountInterface { return this.authWitnessProvider.createAuthWit(messageHash); } - getPublicKeysHash(): Fr { - return this.publicKeysHash; - } - getCompleteAddress(): CompleteAddress { return this.address; } diff --git a/yarn-project/accounts/src/testing/configuration.ts b/yarn-project/accounts/src/testing/configuration.ts index 7fc376ddd70..cc37380d93f 100644 --- a/yarn-project/accounts/src/testing/configuration.ts +++ b/yarn-project/accounts/src/testing/configuration.ts @@ -45,7 +45,9 @@ export async function getDeployedTestAccountsWallets(pxe: PXE): Promise { const initialEncryptionKey = sha512ToGrumpkinScalar([initialSecretKey, GeneratorIndex.IVSK_M]); const publicKey = generatePublicKey(initialEncryptionKey); - return registeredAccounts.find(registered => registered.publicKey.equals(publicKey)) != undefined; + return ( + registeredAccounts.find(registered => registered.masterIncomingViewingPublicKey.equals(publicKey)) != undefined + ); }).map(secretKey => { const signingKey = sha512ToGrumpkinScalar([secretKey, GeneratorIndex.IVSK_M]); // TODO(#5726): use actual salt here instead of hardcoding Fr.ZERO diff --git a/yarn-project/aztec.js/src/account/contract.ts b/yarn-project/aztec.js/src/account/contract.ts index 6ae607d386b..6c49a3b5cf0 100644 --- a/yarn-project/aztec.js/src/account/contract.ts +++ b/yarn-project/aztec.js/src/account/contract.ts @@ -1,5 +1,4 @@ import { type CompleteAddress } from '@aztec/circuit-types'; -import { type Fr } from '@aztec/circuits.js'; import { type ContractArtifact } from '@aztec/foundation/abi'; import { type NodeInfo } from '@aztec/types/interfaces'; @@ -26,11 +25,10 @@ export interface AccountContract { * The account interface is responsible for assembling tx requests given requested function calls, and * for creating signed auth witnesses given action identifiers (message hashes). * @param address - Address where this account contract is deployed. - * @param publicKeysHash - Hash of the public keys used to authorize actions. * @param nodeInfo - Info on the chain where it is deployed. * @returns An account interface instance for creating tx requests and authorizing actions. */ - getInterface(address: CompleteAddress, publicKeysHash: Fr, nodeInfo: NodeInfo): AccountInterface; + getInterface(address: CompleteAddress, nodeInfo: NodeInfo): AccountInterface; /** * Returns the auth witness provider for the given address. diff --git a/yarn-project/aztec.js/src/account/interface.ts b/yarn-project/aztec.js/src/account/interface.ts index 555fce8cbbc..5a5ab2cf28e 100644 --- a/yarn-project/aztec.js/src/account/interface.ts +++ b/yarn-project/aztec.js/src/account/interface.ts @@ -42,9 +42,6 @@ export interface AccountInterface extends AuthWitnessProvider, EntrypointInterfa /** Returns the complete address for this account. */ getCompleteAddress(): CompleteAddress; - /** Returns the public keys hash for this account. */ - getPublicKeysHash(): Fr; - /** Returns the address for this account. */ getAddress(): AztecAddress; diff --git a/yarn-project/aztec.js/src/account_manager/index.ts b/yarn-project/aztec.js/src/account_manager/index.ts index 549855a4d97..842236286a1 100644 --- a/yarn-project/aztec.js/src/account_manager/index.ts +++ b/yarn-project/aztec.js/src/account_manager/index.ts @@ -51,7 +51,7 @@ export class AccountManager { public async getAccount(): Promise { const nodeInfo = await this.pxe.getNodeInfo(); const completeAddress = this.getCompleteAddress(); - return this.accountContract.getInterface(completeAddress, this.getPublicKeysHash(), nodeInfo); + return this.accountContract.getInterface(completeAddress, nodeInfo); } /** diff --git a/yarn-project/aztec.js/src/utils/account.ts b/yarn-project/aztec.js/src/utils/account.ts index c128d8e227e..b9cc606b9b6 100644 --- a/yarn-project/aztec.js/src/utils/account.ts +++ b/yarn-project/aztec.js/src/utils/account.ts @@ -14,7 +14,7 @@ export async function waitForAccountSynch( address: CompleteAddress, { interval, timeout }: WaitOpts = DefaultWaitOpts, ): Promise { - const publicKey = address.publicKey.toString(); + const publicKey = address.masterIncomingViewingPublicKey.toString(); await retryUntil( async () => { const status = await pxe.getSyncStatus(); diff --git a/yarn-project/aztec.js/src/wallet/account_wallet.ts b/yarn-project/aztec.js/src/wallet/account_wallet.ts index a1f7cea1848..803d07010eb 100644 --- a/yarn-project/aztec.js/src/wallet/account_wallet.ts +++ b/yarn-project/aztec.js/src/wallet/account_wallet.ts @@ -16,10 +16,6 @@ export class AccountWallet extends BaseWallet { super(pxe); } - getPublicKeysHash(): Fr { - return this.account.getPublicKeysHash(); - } - createTxExecutionRequest(exec: ExecutionRequestInit): Promise { return this.account.createTxExecutionRequest(exec); } diff --git a/yarn-project/aztec.js/src/wallet/base_wallet.ts b/yarn-project/aztec.js/src/wallet/base_wallet.ts index eeacdb4f23a..200ad930dee 100644 --- a/yarn-project/aztec.js/src/wallet/base_wallet.ts +++ b/yarn-project/aztec.js/src/wallet/base_wallet.ts @@ -32,8 +32,6 @@ export abstract class BaseWallet implements Wallet { abstract getCompleteAddress(): CompleteAddress; - abstract getPublicKeysHash(): Fr; - abstract getChainId(): Fr; abstract getVersion(): Fr; @@ -80,9 +78,6 @@ export abstract class BaseWallet implements Wallet { getRegisteredAccount(address: AztecAddress): Promise { return this.pxe.getRegisteredAccount(address); } - getRegisteredAccountPublicKeysHash(address: AztecAddress): Promise { - return this.pxe.getRegisteredAccountPublicKeysHash(address); - } getRecipients(): Promise { return this.pxe.getRecipients(); } diff --git a/yarn-project/aztec.js/src/wallet/index.ts b/yarn-project/aztec.js/src/wallet/index.ts index ad92b67fdd0..08e8cb27c41 100644 --- a/yarn-project/aztec.js/src/wallet/index.ts +++ b/yarn-project/aztec.js/src/wallet/index.ts @@ -25,11 +25,7 @@ export async function getWallet( if (!completeAddress) { throw new Error(`Account ${address} not found`); } - const publicKeysHash = await pxe.getRegisteredAccountPublicKeysHash(address); - if (!publicKeysHash) { - throw new Error(`Public keys hash for account ${address} not found`); - } const nodeInfo = await pxe.getNodeInfo(); - const entrypoint = accountContract.getInterface(completeAddress, publicKeysHash, nodeInfo); + const entrypoint = accountContract.getInterface(completeAddress, nodeInfo); return new AccountWallet(pxe, entrypoint); } diff --git a/yarn-project/aztec/src/cli/util.ts b/yarn-project/aztec/src/cli/util.ts index 769e3b1aba1..610aa727288 100644 --- a/yarn-project/aztec/src/cli/util.ts +++ b/yarn-project/aztec/src/cli/util.ts @@ -126,7 +126,14 @@ export async function createAccountLogs( accountLogStrings.push(` Address: ${completeAddress.address.toString()}\n`); accountLogStrings.push(` Partial Address: ${completeAddress.partialAddress.toString()}\n`); accountLogStrings.push(` Secret Key: ${account.secretKey.toString()}\n`); - accountLogStrings.push(` Public Key: ${completeAddress.publicKey.toString()}\n\n`); + accountLogStrings.push(` Master nullifier public key: ${completeAddress.masterNullifierPublicKey.toString()}\n`); + accountLogStrings.push( + ` Master incoming viewing public key: ${completeAddress.masterIncomingViewingPublicKey.toString()}\n\n`, + ); + accountLogStrings.push( + ` Master outgoing viewing public key: ${completeAddress.masterOutgoingViewingPublicKey.toString()}\n\n`, + ); + accountLogStrings.push(` Master tagging public key: ${completeAddress.masterTaggingPublicKey.toString()}\n\n`); } } return accountLogStrings; diff --git a/yarn-project/circuit-types/src/interfaces/pxe.ts b/yarn-project/circuit-types/src/interfaces/pxe.ts index 0e95d024727..9e01820e4f7 100644 --- a/yarn-project/circuit-types/src/interfaces/pxe.ts +++ b/yarn-project/circuit-types/src/interfaces/pxe.ts @@ -1,4 +1,4 @@ -import { type AztecAddress, type CompleteAddress, type Fr, type PartialAddress, type Point } from '@aztec/circuits.js'; +import { type AztecAddress, type CompleteAddress, type Fr, type PartialAddress } from '@aztec/circuits.js'; import { type ContractArtifact } from '@aztec/foundation/abi'; import { type ContractClassWithId, type ContractInstanceWithAddress } from '@aztec/types/contracts'; import { type NodeInfo } from '@aztec/types/interfaces'; @@ -73,8 +73,7 @@ export interface PXE { * the recipient's notes. We can send notes to this account because we can encrypt them with the recipient's * public key. */ - // TODO: #5834: Nuke publicKeys optional parameter after `CompleteAddress` refactor. - registerRecipient(recipient: CompleteAddress, publicKeys?: Point[]): Promise; + registerRecipient(recipient: CompleteAddress): Promise; /** * Retrieves the user accounts registered on this PXE Service. @@ -91,15 +90,6 @@ export interface PXE { */ getRegisteredAccount(address: AztecAddress): Promise; - /** - * Retrieves the public keys hash of the account corresponding to the provided aztec address. - * - * @param address - The address of account. - * @returns The public keys hash of the requested account if found. - * TODO(#5834): refactor complete address and merge with getRegisteredAccount? - */ - getRegisteredAccountPublicKeysHash(address: AztecAddress): Promise; - /** * Retrieves the recipients added to this PXE Service. * @returns An array of recipients registered on this PXE Service. diff --git a/yarn-project/circuit-types/src/keys/key_store.ts b/yarn-project/circuit-types/src/keys/key_store.ts index 168ec8d5f04..b4a0d7ce300 100644 --- a/yarn-project/circuit-types/src/keys/key_store.ts +++ b/yarn-project/circuit-types/src/keys/key_store.ts @@ -1,9 +1,9 @@ import { type AztecAddress, + type CompleteAddress, type Fr, type GrumpkinPrivateKey, type PartialAddress, - type Point, type PublicKey, } from '@aztec/circuits.js'; @@ -13,17 +13,17 @@ import { export interface KeyStore { /** * Creates a new account from a randomly generated secret key. - * @returns A promise that resolves to the newly created account's AztecAddress. + * @returns A promise that resolves to the newly created account's CompleteAddress. */ - createAccount(): Promise; + createAccount(): Promise; /** * Adds an account to the key store from the provided secret key. * @param sk - The secret key of the account. * @param partialAddress - The partial address of the account. - * @returns The account's address. + * @returns The account's complete address. */ - addAccount(sk: Fr, partialAddress: PartialAddress): Promise; + addAccount(sk: Fr, partialAddress: PartialAddress): Promise; /** * Retrieves addresses of accounts stored in the key store. @@ -117,21 +117,4 @@ export interface KeyStore { * @returns A Promise that resolves to the public keys hash. */ getPublicKeysHash(account: AztecAddress): Promise; - - /** - * This is used to register a recipient / for storing public keys of an address - * @param accountAddress - The account address to store keys for. - * @param masterNullifierPublicKey - The stored master nullifier public key - * @param masterIncomingViewingPublicKey - The stored incoming viewing public key - * @param masterOutgoingViewingPublicKey - The stored outgoing viewing public key - * @param masterTaggingPublicKey - The stored master tagging public key - */ - // TODO(#5834): Move this function out of here. Key store should only be used for accounts, not recipients - addPublicKeysForAccount( - accountAddress: AztecAddress, - masterNullifierPublicKey: Point, - masterIncomingViewingPublicKey: Point, - masterOutgoingViewingPublicKey: Point, - masterTaggingPublicKey: Point, - ): Promise; } diff --git a/yarn-project/circuits.js/src/contract/__snapshots__/contract_address.test.ts.snap b/yarn-project/circuits.js/src/contract/__snapshots__/contract_address.test.ts.snap index dc6ef757820..37d75fc64af 100644 --- a/yarn-project/circuits.js/src/contract/__snapshots__/contract_address.test.ts.snap +++ b/yarn-project/circuits.js/src/contract/__snapshots__/contract_address.test.ts.snap @@ -1,7 +1,5 @@ // Jest Snapshot v1, https://goo.gl/fbAQLP -exports[`ContractAddress Address from partial matches Noir 1`] = `"0x1b6ead051e7b42665064ca6cf1ec77da0a36d86e00d1ff6e44077966c0c3a9fa"`; - exports[`ContractAddress Public key hash matches Noir 1`] = `"0x22d83a089d7650514c2de24cd30185a414d943eaa19817c67bffe2c3183006a3"`; exports[`ContractAddress computeContractAddressFromInstance 1`] = `"0x0bed63221d281713007bfb0c063e1f61d0646404fb3701b99bb92f41b6390604"`; diff --git a/yarn-project/circuits.js/src/contract/contract_address.test.ts b/yarn-project/circuits.js/src/contract/contract_address.test.ts index 6199e69a25d..a2c84f657a1 100644 --- a/yarn-project/circuits.js/src/contract/contract_address.test.ts +++ b/yarn-project/circuits.js/src/contract/contract_address.test.ts @@ -5,7 +5,6 @@ import { setupCustomSnapshotSerializers, updateInlineTestData } from '@aztec/fou import { AztecAddress, deriveKeys } from '../index.js'; import { computeContractAddressFromInstance, - computeContractAddressFromPartial, computeInitializationHash, computePartialAddress, computeSaltedInitializationHash, @@ -69,14 +68,6 @@ describe('ContractAddress', () => { }).toString(); expect(address).toMatchSnapshot(); - - // TODO(#5834): the following was removed from aztec_address.nr, should it be re-introduced? - // // Run with AZTEC_GENERATE_TEST_DATA=1 to update noir test data - // updateInlineTestData( - // 'noir-projects/noir-protocol-circuits/crates/types/src/address/aztec_address.nr', - // 'expected_computed_address_from_preimage', - // address.toString(), - // ); }); it('Public key hash matches Noir', () => { @@ -91,18 +82,4 @@ describe('ContractAddress', () => { hash.toString(), ); }); - - it('Address from partial matches Noir', () => { - const publicKeysHash = new Fr(1n); - const partialAddress = new Fr(2n); - const address = computeContractAddressFromPartial({ publicKeysHash, partialAddress }).toString(); - expect(address).toMatchSnapshot(); - - // Run with AZTEC_GENERATE_TEST_DATA=1 to update noir test data - updateInlineTestData( - 'noir-projects/noir-protocol-circuits/crates/types/src/address/aztec_address.nr', - 'expected_computed_address_from_partial_and_pubkey', - address.toString(), - ); - }); }); diff --git a/yarn-project/circuits.js/src/contract/contract_address.ts b/yarn-project/circuits.js/src/contract/contract_address.ts index 11c4dade226..353e3737d90 100644 --- a/yarn-project/circuits.js/src/contract/contract_address.ts +++ b/yarn-project/circuits.js/src/contract/contract_address.ts @@ -1,12 +1,12 @@ import { type FunctionAbi, FunctionSelector, encodeArguments } from '@aztec/foundation/abi'; -import { AztecAddress } from '@aztec/foundation/aztec-address'; -import { pedersenHash, poseidon2Hash } from '@aztec/foundation/crypto'; +import { type AztecAddress } from '@aztec/foundation/aztec-address'; +import { pedersenHash } from '@aztec/foundation/crypto'; import { Fr } from '@aztec/foundation/fields'; import { type ContractInstance } from '@aztec/types/contracts'; import { GeneratorIndex } from '../constants.gen.js'; import { computeVarArgsHash } from '../hash/hash.js'; -import { deriveKeys } from '../keys/index.js'; +import { computeAddress } from '../keys/index.js'; // TODO(@spalladino): Review all generator indices in this file @@ -26,7 +26,7 @@ export function computeContractAddressFromInstance( ): AztecAddress { const partialAddress = computePartialAddress(instance); const publicKeysHash = instance.publicKeysHash; - return computeContractAddressFromPartial({ partialAddress, publicKeysHash }); + return computeAddress(publicKeysHash, partialAddress); } /** @@ -56,19 +56,6 @@ export function computeSaltedInitializationHash( return pedersenHash([instance.salt, instance.initializationHash, instance.deployer], GeneratorIndex.PARTIAL_ADDRESS); } -/** - * Computes a contract address from its partial address and public keys hash. - * @param args - The hash of the public keys or the plain public key to be hashed, along with the partial address. - * @returns The contract address. - */ -export function computeContractAddressFromPartial( - args: ({ publicKeysHash: Fr } | { secretKey: Fr }) & { partialAddress: Fr }, -): AztecAddress { - const publicKeysHash = 'secretKey' in args ? deriveKeys(args.secretKey).publicKeysHash : args.publicKeysHash; - const result = poseidon2Hash([publicKeysHash, args.partialAddress, GeneratorIndex.CONTRACT_ADDRESS_V1]); - return AztecAddress.fromField(result); -} - /** * Computes the initialization hash for an instance given its constructor function and arguments. * @param initFn - Constructor function or empty if no initialization is expected. diff --git a/yarn-project/circuits.js/src/hints/build_nullifier_non_existent_read_request_hints.test.ts b/yarn-project/circuits.js/src/hints/build_nullifier_non_existent_read_request_hints.test.ts index 89a2fb7ed6d..7c5834fe945 100644 --- a/yarn-project/circuits.js/src/hints/build_nullifier_non_existent_read_request_hints.test.ts +++ b/yarn-project/circuits.js/src/hints/build_nullifier_non_existent_read_request_hints.test.ts @@ -131,6 +131,8 @@ describe('buildNullifierNonExistentReadRequestHints', () => { nonExistentReadRequests[0] = makeReadRequest(innerNullifier(2)); - await expect(() => buildHints()).rejects.toThrow('Nullifier exists in the pending set.'); + await expect(() => buildHints()).rejects.toThrow( + 'Nullifier DOES exists in the pending set at the time of reading, but there is a NonExistentReadRequest for it.', + ); }); }); diff --git a/yarn-project/circuits.js/src/hints/build_nullifier_non_existent_read_request_hints.ts b/yarn-project/circuits.js/src/hints/build_nullifier_non_existent_read_request_hints.ts index 5bb6fa3eb76..b2393cf4bc4 100644 --- a/yarn-project/circuits.js/src/hints/build_nullifier_non_existent_read_request_hints.ts +++ b/yarn-project/circuits.js/src/hints/build_nullifier_non_existent_read_request_hints.ts @@ -71,8 +71,13 @@ export async function buildNullifierNonExistentReadRequestHints( let nextPendingValueIndex = sortedValues.findIndex(v => !v.value.lt(siloedValue)); if (nextPendingValueIndex == -1) { nextPendingValueIndex = numPendingNullifiers; - } else if (sortedValues[nextPendingValueIndex].value.equals(siloedValue)) { - throw new Error('Nullifier exists in the pending set.'); + } else if ( + sortedValues[nextPendingValueIndex].value.equals(siloedValue) && + sortedValues[nextPendingValueIndex].counter < readRequest.counter + ) { + throw new Error( + 'Nullifier DOES exists in the pending set at the time of reading, but there is a NonExistentReadRequest for it.', + ); } builder.addHint(membershipWitness, leafPreimage, nextPendingValueIndex); diff --git a/yarn-project/circuits.js/src/keys/__snapshots__/index.test.ts.snap b/yarn-project/circuits.js/src/keys/__snapshots__/index.test.ts.snap new file mode 100644 index 00000000000..1ec1734aed8 --- /dev/null +++ b/yarn-project/circuits.js/src/keys/__snapshots__/index.test.ts.snap @@ -0,0 +1,3 @@ +// Jest Snapshot v1, https://goo.gl/fbAQLP + +exports[`🔑 Address from partial matches Noir 1`] = `"0x1b6ead051e7b42665064ca6cf1ec77da0a36d86e00d1ff6e44077966c0c3a9fa"`; diff --git a/yarn-project/circuits.js/src/keys/index.test.ts b/yarn-project/circuits.js/src/keys/index.test.ts new file mode 100644 index 00000000000..13d54bbbab1 --- /dev/null +++ b/yarn-project/circuits.js/src/keys/index.test.ts @@ -0,0 +1,44 @@ +import { Fr, Point } from '@aztec/foundation/fields'; +import { updateInlineTestData } from '@aztec/foundation/testing'; + +import { computeAddress, computePublicKeysHash } from './index.js'; + +describe('🔑', () => { + it('computing public keys hash matches Noir', () => { + const masterNullifierPublicKey = new Point(new Fr(1), new Fr(2)); + const masterIncomingViewingPublicKey = new Point(new Fr(3), new Fr(4)); + const masterOutgoingViewingPublicKey = new Point(new Fr(5), new Fr(6)); + const masterTaggingPublicKey = new Point(new Fr(7), new Fr(8)); + + const expected = Fr.fromString('0x1936abe4f6a920d16a9f6917f10a679507687e2cd935dd1f1cdcb1e908c027f3'); + expect( + computePublicKeysHash( + masterNullifierPublicKey, + masterIncomingViewingPublicKey, + masterOutgoingViewingPublicKey, + masterTaggingPublicKey, + ), + ).toEqual(expected); + + // Run with AZTEC_GENERATE_TEST_DATA=1 to update noir test data + updateInlineTestData( + 'noir-projects/noir-protocol-circuits/crates/types/src/address/public_keys_hash.nr', + 'expected_public_keys_hash', + expected.toString(), + ); + }); + + it('Address from partial matches Noir', () => { + const publicKeysHash = new Fr(1n); + const partialAddress = new Fr(2n); + const address = computeAddress(publicKeysHash, partialAddress).toString(); + expect(address).toMatchSnapshot(); + + // Run with AZTEC_GENERATE_TEST_DATA=1 to update noir test data + updateInlineTestData( + 'noir-projects/noir-protocol-circuits/crates/types/src/address/aztec_address.nr', + 'expected_computed_address_from_partial_and_pubkey', + address.toString(), + ); + }); +}); diff --git a/yarn-project/circuits.js/src/keys/index.ts b/yarn-project/circuits.js/src/keys/index.ts index f8da77fcba5..11fd962e75e 100644 --- a/yarn-project/circuits.js/src/keys/index.ts +++ b/yarn-project/circuits.js/src/keys/index.ts @@ -1,4 +1,4 @@ -import { type AztecAddress } from '@aztec/foundation/aztec-address'; +import { AztecAddress } from '@aztec/foundation/aztec-address'; import { poseidon2Hash, sha512ToGrumpkinScalar } from '@aztec/foundation/crypto'; import { type Fr, type GrumpkinScalar } from '@aztec/foundation/fields'; @@ -39,6 +39,11 @@ export function computePublicKeysHash( ]); } +export function computeAddress(publicKeysHash: Fr, partialAddress: Fr) { + const addressFr = poseidon2Hash([publicKeysHash, partialAddress, GeneratorIndex.CONTRACT_ADDRESS_V1]); + return AztecAddress.fromField(addressFr); +} + /** * Computes secret and public keys and public keys hash from a secret key. * @param secretKey - The secret key to derive keys from. diff --git a/yarn-project/circuits.js/src/structs/complete_address.test.ts b/yarn-project/circuits.js/src/structs/complete_address.test.ts index e8ce620e5e4..70c006ed2b2 100644 --- a/yarn-project/circuits.js/src/structs/complete_address.test.ts +++ b/yarn-project/circuits.js/src/structs/complete_address.test.ts @@ -4,16 +4,29 @@ import { Fr, Point } from '@aztec/foundation/fields'; import { CompleteAddress } from './complete_address.js'; describe('CompleteAddress', () => { - // TODO(#5834): re-enable or remove this test - it.skip('refuses to add an account with incorrect address for given partial address and pubkey', () => { - expect(() => CompleteAddress.create(AztecAddress.random(), Point.random(), Fr.random())).toThrow( - /cannot be derived/, - ); + it('refuses to add an account with incorrect address for given partial address and pubkey', () => { + expect(() => + CompleteAddress.create( + AztecAddress.random(), + Point.random(), + Point.random(), + Point.random(), + Point.random(), + Fr.random(), + ), + ).toThrow(/cannot be derived/); }); it('equals returns true when 2 instances are equal', () => { const address1 = CompleteAddress.random(); - const address2 = CompleteAddress.create(address1.address, address1.publicKey, address1.partialAddress); + const address2 = CompleteAddress.create( + address1.address, + address1.masterNullifierPublicKey, + address1.masterIncomingViewingPublicKey, + address1.masterOutgoingViewingPublicKey, + address1.masterTaggingPublicKey, + address1.partialAddress, + ); expect(address1.equals(address2)).toBe(true); }); diff --git a/yarn-project/circuits.js/src/structs/complete_address.ts b/yarn-project/circuits.js/src/structs/complete_address.ts index f4465685ca0..2e57265516a 100644 --- a/yarn-project/circuits.js/src/structs/complete_address.ts +++ b/yarn-project/circuits.js/src/structs/complete_address.ts @@ -2,8 +2,8 @@ import { AztecAddress } from '@aztec/foundation/aztec-address'; import { Fr, Point } from '@aztec/foundation/fields'; import { BufferReader } from '@aztec/foundation/serialize'; -import { computeContractAddressFromPartial, computePartialAddress } from '../contract/contract_address.js'; -import { deriveKeys } from '../keys/index.js'; +import { computePartialAddress } from '../contract/contract_address.js'; +import { computeAddress, computePublicKeysHash, deriveKeys } from '../keys/index.js'; import { type PartialAddress } from '../types/partial_address.js'; import { type PublicKey } from '../types/public_key.js'; @@ -22,8 +22,14 @@ export class CompleteAddress { public constructor( /** Contract address (typically of an account contract) */ public address: AztecAddress, - /** Public key corresponding to the address (used during note encryption). */ - public publicKey: PublicKey, + /** Master nullifier public key */ + public masterNullifierPublicKey: PublicKey, + /** Master incoming viewing public key */ + public masterIncomingViewingPublicKey: PublicKey, + /** Master outgoing viewing public key */ + public masterOutgoingViewingPublicKey: PublicKey, + /** Master tagging viewing public key */ + public masterTaggingPublicKey: PublicKey, /** Partial key corresponding to the public key to the address. */ public partialAddress: PartialAddress, ) {} @@ -31,32 +37,47 @@ export class CompleteAddress { /** Size in bytes of an instance */ static readonly SIZE_IN_BYTES = 32 * 4; - static create(address: AztecAddress, publicKey: PublicKey, partialAddress: PartialAddress) { - const completeAddress = new CompleteAddress(address, publicKey, partialAddress); - // TODO(#5834): re-enable validation - // completeAddress.validate(); + static create( + address: AztecAddress, + masterNullifierPublicKey: PublicKey, + masterIncomingViewingPublicKey: PublicKey, + masterOutgoingViewingPublicKey: PublicKey, + masterTaggingPublicKey: PublicKey, + partialAddress: PartialAddress, + ): CompleteAddress { + const completeAddress = new CompleteAddress( + address, + masterNullifierPublicKey, + masterIncomingViewingPublicKey, + masterOutgoingViewingPublicKey, + masterTaggingPublicKey, + partialAddress, + ); + completeAddress.validate(); return completeAddress; } - static random() { - // TODO(#5834): the following should be cleaned up - const secretKey = Fr.random(); - const partialAddress = Fr.random(); - const address = computeContractAddressFromPartial({ secretKey, partialAddress }); - const publicKey = deriveKeys(secretKey).masterIncomingViewingPublicKey; - return new CompleteAddress(address, publicKey, partialAddress); - } - - static fromRandomSecretKey() { - const secretKey = Fr.random(); - const partialAddress = Fr.random(); - return { secretKey, completeAddress: CompleteAddress.fromSecretKeyAndPartialAddress(secretKey, partialAddress) }; + static random(): CompleteAddress { + return this.fromSecretKeyAndPartialAddress(Fr.random(), Fr.random()); } static fromSecretKeyAndPartialAddress(secretKey: Fr, partialAddress: Fr): CompleteAddress { - const address = computeContractAddressFromPartial({ secretKey, partialAddress }); - const publicKey = deriveKeys(secretKey).masterIncomingViewingPublicKey; - return new CompleteAddress(address, publicKey, partialAddress); + const { + masterNullifierPublicKey, + masterIncomingViewingPublicKey, + masterOutgoingViewingPublicKey, + masterTaggingPublicKey, + publicKeysHash, + } = deriveKeys(secretKey); + const address = computeAddress(publicKeysHash, partialAddress); + return new CompleteAddress( + address, + masterNullifierPublicKey, + masterIncomingViewingPublicKey, + masterOutgoingViewingPublicKey, + masterTaggingPublicKey, + partialAddress, + ); } static fromSecretKeyAndInstance( @@ -64,29 +85,31 @@ export class CompleteAddress { instance: Parameters[0], ): CompleteAddress { const partialAddress = computePartialAddress(instance); - const address = computeContractAddressFromPartial({ secretKey, partialAddress }); - const publicKey = deriveKeys(secretKey).masterIncomingViewingPublicKey; - return new CompleteAddress(address, publicKey, partialAddress); + return CompleteAddress.fromSecretKeyAndPartialAddress(secretKey, partialAddress); } - // TODO(#5834): re-enable validation - // /** Throws if the address is not correctly derived from the public key and partial address.*/ - // public validate() { - // const expectedAddress = computeContractAddressFromPartial(this); - // const address = this.address; - // if (!expectedAddress.equals(address)) { - // throw new Error( - // `Address cannot be derived from pubkey and partial address (received ${address.toString()}, derived ${expectedAddress.toString()})`, - // ); - // } - // } + /** Throws if the address is not correctly derived from the public key and partial address.*/ + public validate() { + const publicKeysHash = computePublicKeysHash( + this.masterNullifierPublicKey, + this.masterIncomingViewingPublicKey, + this.masterOutgoingViewingPublicKey, + this.masterTaggingPublicKey, + ); + const expectedAddress = computeAddress(publicKeysHash, this.partialAddress); + if (!expectedAddress.equals(this.address)) { + throw new Error( + `Address cannot be derived from public keys and partial address (received ${this.address.toString()}, derived ${expectedAddress.toString()})`, + ); + } + } /** - * Gets a readable string representation of a the complete address. + * Gets a readable string representation of the complete address. * @returns A readable string representation of the complete address. */ public toReadableString(): string { - return ` Address: ${this.address.toString()}\n Public Key: ${this.publicKey.toString()}\n Partial Address: ${this.partialAddress.toString()}\n`; + return `Address: ${this.address.toString()}\nMaster Nullifier Public Key: ${this.masterNullifierPublicKey.toString()}\nMaster Incoming Viewing Public Key: ${this.masterIncomingViewingPublicKey.toString()}\nMaster Outgoing Viewing Public Key: ${this.masterOutgoingViewingPublicKey.toString()}\nMaster Tagging Public Key: ${this.masterTaggingPublicKey.toString()}\nPartial Address: ${this.partialAddress.toString()}\n`; } /** @@ -96,10 +119,13 @@ export class CompleteAddress { * @param other - The CompleteAddress instance to compare against. * @returns True if the buffers of both instances are equal, false otherwise. */ - equals(other: CompleteAddress) { + equals(other: CompleteAddress): boolean { return ( this.address.equals(other.address) && - this.publicKey.equals(other.publicKey) && + this.masterNullifierPublicKey.equals(other.masterNullifierPublicKey) && + this.masterIncomingViewingPublicKey.equals(other.masterIncomingViewingPublicKey) && + this.masterOutgoingViewingPublicKey.equals(other.masterOutgoingViewingPublicKey) && + this.masterTaggingPublicKey.equals(other.masterTaggingPublicKey) && this.partialAddress.equals(other.partialAddress) ); } @@ -110,8 +136,15 @@ export class CompleteAddress { * * @returns A Buffer representation of the CompleteAddress instance. */ - toBuffer() { - return Buffer.concat([this.address.toBuffer(), this.publicKey.toBuffer(), this.partialAddress.toBuffer()]); + toBuffer(): Buffer { + return Buffer.concat([ + this.address.toBuffer(), + this.masterNullifierPublicKey.toBuffer(), + this.masterIncomingViewingPublicKey.toBuffer(), + this.masterOutgoingViewingPublicKey.toBuffer(), + this.masterTaggingPublicKey.toBuffer(), + this.partialAddress.toBuffer(), + ]); } /** @@ -122,12 +155,22 @@ export class CompleteAddress { * @param buffer - The input buffer or BufferReader containing the address data. * @returns - A new CompleteAddress instance with the extracted address data. */ - static fromBuffer(buffer: Buffer | BufferReader) { + static fromBuffer(buffer: Buffer | BufferReader): CompleteAddress { const reader = BufferReader.asReader(buffer); const address = reader.readObject(AztecAddress); - const publicKey = reader.readObject(Point); + const masterNullifierPublicKey = reader.readObject(Point); + const masterIncomingViewingPublicKey = reader.readObject(Point); + const masterOutgoingViewingPublicKey = reader.readObject(Point); + const masterTaggingPublicKey = reader.readObject(Point); const partialAddress = reader.readObject(Fr); - return new this(address, publicKey, partialAddress); + return new CompleteAddress( + address, + masterNullifierPublicKey, + masterIncomingViewingPublicKey, + masterOutgoingViewingPublicKey, + masterTaggingPublicKey, + partialAddress, + ); } /** @@ -151,4 +194,13 @@ export class CompleteAddress { toString(): string { return `0x${this.toBuffer().toString('hex')}`; } + + get publicKeysHash(): Fr { + return computePublicKeysHash( + this.masterNullifierPublicKey, + this.masterIncomingViewingPublicKey, + this.masterOutgoingViewingPublicKey, + this.masterTaggingPublicKey, + ); + } } diff --git a/yarn-project/circuits.js/src/structs/private_circuit_public_inputs.ts b/yarn-project/circuits.js/src/structs/private_circuit_public_inputs.ts index 8697c01cbdd..03c22846910 100644 --- a/yarn-project/circuits.js/src/structs/private_circuit_public_inputs.ts +++ b/yarn-project/circuits.js/src/structs/private_circuit_public_inputs.ts @@ -39,7 +39,6 @@ import { TxContext } from './tx_context.js'; /** * Public inputs to a private circuit. - * @see abis/private_circuit_public_inputs.hpp. */ export class PrivateCircuitPublicInputs { constructor( diff --git a/yarn-project/end-to-end/src/benchmarks/utils.ts b/yarn-project/end-to-end/src/benchmarks/utils.ts index 0dbbe2d6162..1072040b1ce 100644 --- a/yarn-project/end-to-end/src/benchmarks/utils.ts +++ b/yarn-project/end-to-end/src/benchmarks/utils.ts @@ -127,7 +127,8 @@ export async function waitNewPXESynced( */ export async function waitRegisteredAccountSynced(pxe: PXE, secretKey: Fr, partialAddress: PartialAddress) { const l2Block = await pxe.getBlockNumber(); - const { publicKey } = await pxe.registerAccount(secretKey, partialAddress); - const isAccountSynced = async () => (await pxe.getSyncStatus()).notes[publicKey.toString()] === l2Block; + const { masterIncomingViewingPublicKey } = await pxe.registerAccount(secretKey, partialAddress); + const isAccountSynced = async () => + (await pxe.getSyncStatus()).notes[masterIncomingViewingPublicKey.toString()] === l2Block; await retryUntil(isAccountSynced, 'pxe-notes-sync'); } diff --git a/yarn-project/end-to-end/src/client_prover_integration/client_prover_test.ts b/yarn-project/end-to-end/src/client_prover_integration/client_prover_test.ts index 0303dcecdab..54318e1702c 100644 --- a/yarn-project/end-to-end/src/client_prover_integration/client_prover_test.ts +++ b/yarn-project/end-to-end/src/client_prover_integration/client_prover_test.ts @@ -130,7 +130,6 @@ export class ClientProverTest { this.logger.debug(`Main setup completed, initializing full prover PXE...`); ({ pxe: this.fullProverPXE, teardown: this.provenPXETeardown } = await setupPXEService( - 0, this.aztecNode, { proverEnabled: false, diff --git a/yarn-project/end-to-end/src/e2e_2_pxes.test.ts b/yarn-project/end-to-end/src/e2e_2_pxes.test.ts index 7e185f169e0..ebb1ec71e14 100644 --- a/yarn-project/end-to-end/src/e2e_2_pxes.test.ts +++ b/yarn-project/end-to-end/src/e2e_2_pxes.test.ts @@ -1,4 +1,5 @@ import { getUnsafeSchnorrAccount } from '@aztec/accounts/single_key'; +import { createAccounts } from '@aztec/accounts/testing'; import { type AztecAddress, type AztecNode, @@ -40,11 +41,9 @@ describe('e2e_2_pxes', () => { teardown: teardownA, } = await setup(1)); - ({ - pxe: pxeB, - wallets: [walletB], - teardown: teardownB, - } = await setupPXEService(1, aztecNode!, {}, undefined, true)); + ({ pxe: pxeB, teardown: teardownB } = await setupPXEService(aztecNode!, {}, undefined, true)); + + [walletB] = await createAccounts(pxeB, 1); }); afterEach(async () => { diff --git a/yarn-project/end-to-end/src/e2e_account_contracts.test.ts b/yarn-project/end-to-end/src/e2e_account_contracts.test.ts index e5f8f0743a5..68ac9c89e28 100644 --- a/yarn-project/end-to-end/src/e2e_account_contracts.test.ts +++ b/yarn-project/end-to-end/src/e2e_account_contracts.test.ts @@ -74,11 +74,7 @@ describe('e2e_account_contracts', () => { const walletAt = async (pxe: PXE, accountContract: AccountContract, address: CompleteAddress) => { const nodeInfo = await pxe.getNodeInfo(); - const publicKeysHash = await pxe.getRegisteredAccountPublicKeysHash(address.address); - if (!publicKeysHash) { - throw new Error(`Public keys hash for account ${address.address} not found`); - } - const entrypoint = accountContract.getInterface(address, publicKeysHash, nodeInfo); + const entrypoint = accountContract.getInterface(address, nodeInfo); return new AccountWallet(pxe, entrypoint); }; diff --git a/yarn-project/end-to-end/src/e2e_card_game.test.ts b/yarn-project/end-to-end/src/e2e_card_game.test.ts index f0949b16663..743ff3a38b3 100644 --- a/yarn-project/end-to-end/src/e2e_card_game.test.ts +++ b/yarn-project/end-to-end/src/e2e_card_game.test.ts @@ -105,7 +105,7 @@ describe('e2e_card_game', () => { const publicKey = deriveKeys(key).masterIncomingViewingPublicKey; return ( preRegisteredAccounts.find(preRegisteredAccount => { - return preRegisteredAccount.publicKey.equals(publicKey); + return preRegisteredAccount.masterIncomingViewingPublicKey.equals(publicKey); }) == undefined ); }); diff --git a/yarn-project/end-to-end/src/e2e_deploy_contract/deploy_test.ts b/yarn-project/end-to-end/src/e2e_deploy_contract/deploy_test.ts index 05b31422828..ffafbb038e7 100644 --- a/yarn-project/end-to-end/src/e2e_deploy_contract/deploy_test.ts +++ b/yarn-project/end-to-end/src/e2e_deploy_contract/deploy_test.ts @@ -3,11 +3,10 @@ import { type AccountWallet, type AztecAddress, type AztecNode, - CompleteAddress, type ContractArtifact, type ContractBase, type DebugLogger, - type Fr, + Fr, type PXE, type Wallet, createDebugLogger, @@ -81,10 +80,8 @@ export class DeployTest { } async registerRandomAccount(): Promise { - const pxe = this.pxe; - const { completeAddress: owner, secretKey } = CompleteAddress.fromRandomSecretKey(); - await pxe.registerAccount(secretKey, owner.partialAddress); - return owner.address; + const completeAddress = await this.pxe.registerAccount(Fr.random(), Fr.random()); + return completeAddress.address; } } diff --git a/yarn-project/end-to-end/src/e2e_deploy_contract/legacy.test.ts b/yarn-project/end-to-end/src/e2e_deploy_contract/legacy.test.ts index 403da38154c..25c96999fff 100644 --- a/yarn-project/end-to-end/src/e2e_deploy_contract/legacy.test.ts +++ b/yarn-project/end-to-end/src/e2e_deploy_contract/legacy.test.ts @@ -33,7 +33,7 @@ describe('e2e_deploy_contract legacy', () => { */ it('should deploy a test contract', async () => { const salt = Fr.random(); - const publicKeysHash = wallet.getPublicKeysHash(); + const publicKeysHash = wallet.getCompleteAddress().publicKeysHash; const deploymentData = getContractInstanceFromDeployParams(TestContractArtifact, { salt, publicKeysHash, @@ -68,7 +68,7 @@ describe('e2e_deploy_contract legacy', () => { logger.info(`Deploying contract ${index + 1}...`); const receipt = await deployer.deploy().send({ contractAddressSalt: Fr.random() }).wait({ wallet }); logger.info(`Sending TX to contract ${index + 1}...`); - await receipt.contract.methods.get_public_key(wallet.getAddress()).send().wait(); + await receipt.contract.methods.get_master_incoming_viewing_public_key(wallet.getAddress()).send().wait(); } }); diff --git a/yarn-project/end-to-end/src/e2e_key_registry.test.ts b/yarn-project/end-to-end/src/e2e_key_registry.test.ts index c770ceaf9dd..88d65a5037a 100644 --- a/yarn-project/end-to-end/src/e2e_key_registry.test.ts +++ b/yarn-project/end-to-end/src/e2e_key_registry.test.ts @@ -1,6 +1,5 @@ import { type AccountWallet, AztecAddress, Fr, type PXE } from '@aztec/aztec.js'; -import { CompleteAddress, GeneratorIndex, type PartialAddress, Point, deriveKeys } from '@aztec/circuits.js'; -import { poseidon2Hash } from '@aztec/foundation/crypto'; +import { CompleteAddress, Point } from '@aztec/circuits.js'; import { KeyRegistryContract, TestContract } from '@aztec/noir-contracts.js'; import { getCanonicalKeyRegistryAddress } from '@aztec/protocol-contracts/key-registry'; @@ -23,16 +22,7 @@ describe('Key Registry', () => { let teardown: () => Promise; - // TODO(#5834): use AztecAddress.compute or smt - const { - masterNullifierPublicKey, - masterIncomingViewingPublicKey, - masterOutgoingViewingPublicKey, - masterTaggingPublicKey, - publicKeysHash, - } = deriveKeys(Fr.random()); - const partialAddress: PartialAddress = Fr.random(); - let account: AztecAddress; + const account = CompleteAddress.random(); beforeAll(async () => { ({ teardown, pxe, wallets } = await setup(3)); @@ -41,11 +31,6 @@ describe('Key Registry', () => { testContract = await TestContract.deploy(wallets[0]).send().deployed(); await publicDeployAccounts(wallets[0], wallets.slice(0, 2)); - - // TODO(#5834): use AztecAddress.compute or smt - account = AztecAddress.fromField( - poseidon2Hash([publicKeysHash, partialAddress, GeneratorIndex.CONTRACT_ADDRESS_V1]), - ); }); const crossDelay = async () => { @@ -60,10 +45,10 @@ describe('Key Registry', () => { describe('failure cases', () => { it('throws when address preimage check fails', async () => { const keys = [ - masterNullifierPublicKey, - masterIncomingViewingPublicKey, - masterOutgoingViewingPublicKey, - masterTaggingPublicKey, + account.masterNullifierPublicKey, + account.masterIncomingViewingPublicKey, + account.masterOutgoingViewingPublicKey, + account.masterTaggingPublicKey, ]; // We randomly invalidate some of the keys @@ -72,7 +57,7 @@ describe('Key Registry', () => { await expect( keyRegistry .withWallet(wallets[0]) - .methods.register(AztecAddress.fromField(account), partialAddress, keys[0], keys[1], keys[2], keys[3]) + .methods.register(account, account.partialAddress, keys[0], keys[1], keys[2], keys[3]) .send() .wait(), ).rejects.toThrow('Computed address does not match supplied address'); @@ -82,7 +67,7 @@ describe('Key Registry', () => { await expect( keyRegistry .withWallet(wallets[0]) - .methods.rotate_nullifier_public_key(wallets[1].getAddress(), Point.random(), Fr.ZERO) + .methods.rotate_npk_m(wallets[1].getAddress(), Point.random(), Fr.ZERO) .send() .wait(), ).rejects.toThrow('Assertion failed: Message not authorized by account'); @@ -96,33 +81,20 @@ describe('Key Registry', () => { await expect( testContract.methods.test_nullifier_key_freshness(randomAddress, randomMasterNullifierPublicKey).send().wait(), - ).rejects.toThrow(`Cannot satisfy constraint 'computed_address.eq(address)'`); + ).rejects.toThrow(/No public key registered for address/); }); }); it('fresh key lib succeeds for non-registered account available in PXE', async () => { - // TODO(#5834): Make this not disgusting - const newAccountKeys = deriveKeys(Fr.random()); - const newAccountPartialAddress = Fr.random(); - const newAccount = AztecAddress.fromField( - poseidon2Hash([newAccountKeys.publicKeysHash, newAccountPartialAddress, GeneratorIndex.CONTRACT_ADDRESS_V1]), - ); - const newAccountCompleteAddress = CompleteAddress.create( - newAccount, - newAccountKeys.masterIncomingViewingPublicKey, - newAccountPartialAddress, - ); - - await pxe.registerRecipient(newAccountCompleteAddress, [ - newAccountKeys.masterNullifierPublicKey, - newAccountKeys.masterIncomingViewingPublicKey, - newAccountKeys.masterOutgoingViewingPublicKey, - newAccountKeys.masterTaggingPublicKey, - ]); + const newAccountCompleteAddress = CompleteAddress.random(); + await pxe.registerRecipient(newAccountCompleteAddress); // Should succeed as the account is now registered as a recipient in PXE await testContract.methods - .test_nullifier_key_freshness(newAccount, newAccountKeys.masterNullifierPublicKey) + .test_nullifier_key_freshness( + newAccountCompleteAddress.address, + newAccountCompleteAddress.masterNullifierPublicKey, + ) .send() .wait(); }); @@ -133,11 +105,11 @@ describe('Key Registry', () => { .withWallet(wallets[0]) .methods.register( account, - partialAddress, - masterNullifierPublicKey, - masterIncomingViewingPublicKey, - masterOutgoingViewingPublicKey, - masterTaggingPublicKey, + account.partialAddress, + account.masterNullifierPublicKey, + account.masterIncomingViewingPublicKey, + account.masterOutgoingViewingPublicKey, + account.masterTaggingPublicKey, ) .send() .wait(); @@ -157,13 +129,13 @@ describe('Key Registry', () => { .test_shared_mutable_private_getter_for_registry_contract(1, account) .simulate(); - expect(new Fr(nullifierPublicKeyX)).toEqual(masterNullifierPublicKey.x); + expect(new Fr(nullifierPublicKeyX)).toEqual(account.masterNullifierPublicKey.x); }); // Note: This test case is dependent on state from the previous one it('key lib succeeds for registered account', async () => { // Should succeed as the account is registered in key registry from tests before - await testContract.methods.test_nullifier_key_freshness(account, masterNullifierPublicKey).send().wait(); + await testContract.methods.test_nullifier_key_freshness(account, account.masterNullifierPublicKey).send().wait(); }); }); @@ -174,7 +146,7 @@ describe('Key Registry', () => { it('rotates npk_m', async () => { await keyRegistry .withWallet(wallets[0]) - .methods.rotate_nullifier_public_key(wallets[0].getAddress(), firstNewMasterNullifierPublicKey, Fr.ZERO) + .methods.rotate_npk_m(wallets[0].getAddress(), firstNewMasterNullifierPublicKey, Fr.ZERO) .send() .wait(); @@ -199,7 +171,7 @@ describe('Key Registry', () => { it(`rotates npk_m with authwit`, async () => { const action = keyRegistry .withWallet(wallets[1]) - .methods.rotate_nullifier_public_key(wallets[0].getAddress(), secondNewMasterNullifierPublicKey, Fr.ZERO); + .methods.rotate_npk_m(wallets[0].getAddress(), secondNewMasterNullifierPublicKey, Fr.ZERO); await wallets[0] .setPublicAuthWit({ caller: wallets[1].getCompleteAddress().address, action }, true) diff --git a/yarn-project/end-to-end/src/e2e_multiple_accounts_1_enc_key.test.ts b/yarn-project/end-to-end/src/e2e_multiple_accounts_1_enc_key.test.ts index a8b0ed53439..6aaae7545f7 100644 --- a/yarn-project/end-to-end/src/e2e_multiple_accounts_1_enc_key.test.ts +++ b/yarn-project/end-to-end/src/e2e_multiple_accounts_1_enc_key.test.ts @@ -50,7 +50,7 @@ describe('e2e_multiple_accounts_1_enc_key', () => { const encryptionPublicKey = deriveKeys(encryptionPrivateKey).masterIncomingViewingPublicKey; for (const account of accounts) { - expect(account.publicKey).toEqual(encryptionPublicKey); + expect(account.masterIncomingViewingPublicKey).toEqual(encryptionPublicKey); } logger.info(`Deploying Token...`); diff --git a/yarn-project/end-to-end/src/fixtures/snapshot_manager.ts b/yarn-project/end-to-end/src/fixtures/snapshot_manager.ts index 72651f79e27..7129a268eb9 100644 --- a/yarn-project/end-to-end/src/fixtures/snapshot_manager.ts +++ b/yarn-project/end-to-end/src/fixtures/snapshot_manager.ts @@ -8,9 +8,11 @@ import { EthCheatCodes, Fr, GrumpkinPrivateKey, + SignerlessWallet, type Wallet, } from '@aztec/aztec.js'; import { deployInstance, registerContractClass } from '@aztec/aztec.js/deployment'; +import { DefaultMultiCallEntrypoint } from '@aztec/aztec.js/entrypoint'; import { asyncMap } from '@aztec/foundation/async-map'; import { type Logger, createDebugLogger } from '@aztec/foundation/log'; import { makeBackoff, retry } from '@aztec/foundation/retry'; @@ -27,6 +29,7 @@ import { mnemonicToAccount } from 'viem/accounts'; import { MNEMONIC } from './fixtures.js'; import { getACVMConfig } from './get_acvm_config.js'; import { setupL1Contracts } from './setup_l1_contracts.js'; +import { deployCanonicalKeyRegistry } from './utils.js'; export type SubsystemsContext = { anvil: Anvil; @@ -264,6 +267,11 @@ async function setupFromFresh(statePath: string | undefined, logger: Logger): Pr pxeConfig.dataDirectory = statePath; const pxe = await createPXEService(aztecNode, pxeConfig); + logger.verbose('Deploying key registry...'); + await deployCanonicalKeyRegistry( + new SignerlessWallet(pxe, new DefaultMultiCallEntrypoint(aztecNodeConfig.chainId, aztecNodeConfig.version)), + ); + if (statePath) { writeFileSync(`${statePath}/aztec_node_config.json`, JSON.stringify(aztecNodeConfig)); } diff --git a/yarn-project/end-to-end/src/fixtures/utils.ts b/yarn-project/end-to-end/src/fixtures/utils.ts index 93439822fec..725d35a6ab6 100644 --- a/yarn-project/end-to-end/src/fixtures/utils.ts +++ b/yarn-project/end-to-end/src/fixtures/utils.ts @@ -153,7 +153,6 @@ async function initGasBridge({ walletClient, l1ContractAddresses }: DeployL1Cont /** * Sets up Private eXecution Environment (PXE). - * @param numberOfAccounts - The number of new accounts to be created once the PXE is initiated. * @param aztecNode - An instance of Aztec Node. * @param opts - Partial configuration for the PXE service. * @param firstPrivKey - The private key of the first account to be created. @@ -163,7 +162,6 @@ async function initGasBridge({ walletClient, l1ContractAddresses }: DeployL1Cont * @returns Private eXecution Environment (PXE), accounts, wallets and logger. */ export async function setupPXEService( - numberOfAccounts: number, aztecNode: AztecNode, opts: Partial = {}, logger = getLogger(), @@ -174,10 +172,6 @@ export async function setupPXEService( * The PXE instance. */ pxe: PXEService; - /** - * The wallets to be used. - */ - wallets: AccountWalletWithSecretKey[]; /** * Logger instance named as the current test. */ @@ -190,15 +184,12 @@ export async function setupPXEService( const pxeServiceConfig = { ...getPXEServiceConfig(), ...opts }; const pxe = await createPXEService(aztecNode, pxeServiceConfig, useLogSuffix, proofCreator); - const wallets = await createAccounts(pxe, numberOfAccounts); - const teardown = async () => { await pxe.stop(); }; return { pxe, - wallets, logger, teardown, }; @@ -230,14 +221,6 @@ async function setupWithRemoteEnvironment( logger.verbose('JSON RPC client connected to PXE'); logger.verbose(`Retrieving contract addresses from ${PXE_URL}`); const l1Contracts = (await pxeClient.getNodeInfo()).l1ContractAddresses; - logger.verbose('PXE created, constructing available wallets from already registered accounts...'); - const wallets = await getDeployedTestAccountsWallets(pxeClient); - - if (wallets.length < numberOfAccounts) { - const numNewAccounts = numberOfAccounts - wallets.length; - logger.verbose(`Deploying ${numNewAccounts} accounts...`); - wallets.push(...(await createAccounts(pxeClient, numNewAccounts))); - } const walletClient = createWalletClient({ account, @@ -269,6 +252,15 @@ async function setupWithRemoteEnvironment( ); } + logger.verbose('Constructing available wallets from already registered accounts...'); + const wallets = await getDeployedTestAccountsWallets(pxeClient); + + if (wallets.length < numberOfAccounts) { + const numNewAccounts = numberOfAccounts - wallets.length; + logger.verbose(`Deploying ${numNewAccounts} accounts...`); + wallets.push(...(await createAccounts(pxeClient, numNewAccounts))); + } + return { aztecNode, sequencer: undefined, @@ -400,7 +392,8 @@ export async function setup( const prover = aztecNode.getProver(); logger.verbose('Creating a pxe...'); - const { pxe, wallets } = await setupPXEService(numberOfAccounts, aztecNode!, pxeOpts, logger); + + const { pxe } = await setupPXEService(aztecNode!, pxeOpts, logger); logger.verbose('Deploying key registry...'); await deployCanonicalKeyRegistry( @@ -414,6 +407,7 @@ export async function setup( ); } + const wallets = await createAccounts(pxe, numberOfAccounts); const cheatCodes = CheatCodes.create(config.rpcUrl, pxe!); const teardown = async () => { @@ -616,7 +610,7 @@ export async function deployCanonicalGasToken(deployer: Wallet) { await expect(deployer.isContractPubliclyDeployed(gasToken.address)).resolves.toBe(true); } -async function deployCanonicalKeyRegistry(deployer: Wallet) { +export async function deployCanonicalKeyRegistry(deployer: Wallet) { const canonicalKeyRegistry = getCanonicalKeyRegistry(); // We check to see if there exists a contract at the canonical Key Registry address with the same contract class id as we expect. This means that diff --git a/yarn-project/key-store/src/test_key_store.test.ts b/yarn-project/key-store/src/test_key_store.test.ts index 61647e0097c..2395dbf1472 100644 --- a/yarn-project/key-store/src/test_key_store.test.ts +++ b/yarn-project/key-store/src/test_key_store.test.ts @@ -11,7 +11,7 @@ describe('TestKeyStore', () => { const sk = new Fr(8923n); const partialAddress = new Fr(243523n); - const accountAddress = await keyStore.addAccount(sk, partialAddress); + const { address: accountAddress } = await keyStore.addAccount(sk, partialAddress); expect(accountAddress.toString()).toMatchInlineSnapshot( `"0x1a8a9a1d91cbb353d8df4f1bbfd0283f7fc63766f671edd9443a1270a7b2a954"`, ); diff --git a/yarn-project/key-store/src/test_key_store.ts b/yarn-project/key-store/src/test_key_store.ts index a3c0d4b239b..a21763ebf83 100644 --- a/yarn-project/key-store/src/test_key_store.ts +++ b/yarn-project/key-store/src/test_key_store.ts @@ -1,12 +1,14 @@ import { type KeyStore, type PublicKey } from '@aztec/circuit-types'; import { AztecAddress, + CompleteAddress, Fr, GeneratorIndex, type GrumpkinPrivateKey, GrumpkinScalar, type PartialAddress, Point, + computeAddress, computeAppNullifierSecretKey, deriveKeys, } from '@aztec/circuits.js'; @@ -26,9 +28,9 @@ export class TestKeyStore implements KeyStore { /** * Creates a new account from a randomly generated secret key. - * @returns A promise that resolves to the newly created account's AztecAddress. + * @returns A promise that resolves to the newly created account's CompleteAddress. */ - public createAccount(): Promise { + public createAccount(): Promise { const sk = Fr.random(); const partialAddress = Fr.random(); return this.addAccount(sk, partialAddress); @@ -38,9 +40,9 @@ export class TestKeyStore implements KeyStore { * Adds an account to the key store from the provided secret key. * @param sk - The secret key of the account. * @param partialAddress - The partial address of the account. - * @returns The account's address. + * @returns The account's complete address. */ - public async addAccount(sk: Fr, partialAddress: PartialAddress): Promise { + public async addAccount(sk: Fr, partialAddress: PartialAddress): Promise { const { publicKeysHash, masterNullifierSecretKey, @@ -53,10 +55,7 @@ export class TestKeyStore implements KeyStore { masterTaggingPublicKey, } = deriveKeys(sk); - // We hash the partial address and the public keys hash to get the account address - // TODO(#5726): Move the following line to AztecAddress class? - const accountAddressFr = poseidon2Hash([publicKeysHash, partialAddress, GeneratorIndex.CONTRACT_ADDRESS_V1]); - const accountAddress = AztecAddress.fromField(accountAddressFr); + const accountAddress = computeAddress(publicKeysHash, partialAddress); // We save the keys to db await this.#keys.set(`${accountAddress.toString()}-public_keys_hash`, publicKeysHash.toBuffer()); @@ -72,7 +71,16 @@ export class TestKeyStore implements KeyStore { await this.#keys.set(`${accountAddress.toString()}-tpk_m`, masterTaggingPublicKey.toBuffer()); // At last, we return the newly derived account address - return Promise.resolve(accountAddress); + return Promise.resolve( + CompleteAddress.create( + accountAddress, + masterNullifierPublicKey, + masterIncomingViewingPublicKey, + masterOutgoingViewingPublicKey, + masterTaggingPublicKey, + partialAddress, + ), + ); } /** @@ -292,18 +300,4 @@ export class TestKeyStore implements KeyStore { } return Promise.resolve(Fr.fromBuffer(publicKeysHashBuffer)); } - - // TODO(#5834): Re-add separation between recipients and accounts in keystore. - public async addPublicKeysForAccount( - accountAddress: AztecAddress, - masterNullifierPublicKey: Point, - masterIncomingViewingPublicKey: Point, - masterOutgoingViewingPublicKey: Point, - masterTaggingPublicKey: Point, - ): Promise { - await this.#keys.set(`${accountAddress.toString()}-npk_m`, masterNullifierPublicKey.toBuffer()); - await this.#keys.set(`${accountAddress.toString()}-ivpk_m`, masterIncomingViewingPublicKey.toBuffer()); - await this.#keys.set(`${accountAddress.toString()}-ovpk_m`, masterOutgoingViewingPublicKey.toBuffer()); - await this.#keys.set(`${accountAddress.toString()}-tpk_m`, masterTaggingPublicKey.toBuffer()); - } } diff --git a/yarn-project/pxe/src/database/kv_pxe_database.ts b/yarn-project/pxe/src/database/kv_pxe_database.ts index c07a29219de..7452f98f5b9 100644 --- a/yarn-project/pxe/src/database/kv_pxe_database.ts +++ b/yarn-project/pxe/src/database/kv_pxe_database.ts @@ -209,7 +209,7 @@ export class KVPxeDatabase implements PxeDatabase { #getNotes(filter: NoteFilter): NoteDao[] { const publicKey: PublicKey | undefined = filter.owner - ? this.#getCompleteAddress(filter.owner)?.publicKey + ? this.#getCompleteAddress(filter.owner)?.masterIncomingViewingPublicKey : undefined; filter.status = filter.status ?? NoteStatus.ACTIVE; diff --git a/yarn-project/pxe/src/database/pxe_database_test_suite.ts b/yarn-project/pxe/src/database/pxe_database_test_suite.ts index 9a1fbbe46f8..440df3db400 100644 --- a/yarn-project/pxe/src/database/pxe_database_test_suite.ts +++ b/yarn-project/pxe/src/database/pxe_database_test_suite.ts @@ -92,7 +92,10 @@ export function describePxeDatabase(getDatabase: () => PxeDatabase) { [() => ({ txHash: notes[0].txHash }), () => [notes[0]]], [() => ({ txHash: randomTxHash() }), () => []], - [() => ({ owner: owners[0].address }), () => notes.filter(note => note.publicKey.equals(owners[0].publicKey))], + [ + () => ({ owner: owners[0].address }), + () => notes.filter(note => note.publicKey.equals(owners[0].masterIncomingViewingPublicKey)), + ], [ () => ({ contractAddress: contractAddresses[0], storageSlot: storageSlots[0] }), @@ -113,7 +116,7 @@ export function describePxeDatabase(getDatabase: () => PxeDatabase) { randomNoteDao({ contractAddress: contractAddresses[i % contractAddresses.length], storageSlot: storageSlots[i % storageSlots.length], - publicKey: owners[i % owners.length].publicKey, + publicKey: owners[i % owners.length].masterIncomingViewingPublicKey, index: BigInt(i), }), ); @@ -142,9 +145,11 @@ export function describePxeDatabase(getDatabase: () => PxeDatabase) { // Nullify all notes and use the same filter as other test cases for (const owner of owners) { - const notesToNullify = notes.filter(note => note.publicKey.equals(owner.publicKey)); + const notesToNullify = notes.filter(note => note.publicKey.equals(owner.masterIncomingViewingPublicKey)); const nullifiers = notesToNullify.map(note => note.siloedNullifier); - await expect(database.removeNullifiedNotes(nullifiers, owner.publicKey)).resolves.toEqual(notesToNullify); + await expect( + database.removeNullifiedNotes(nullifiers, owner.masterIncomingViewingPublicKey), + ).resolves.toEqual(notesToNullify); } await expect(database.getNotes({ ...getFilter(), status: NoteStatus.ACTIVE_OR_NULLIFIED })).resolves.toEqual( @@ -155,7 +160,7 @@ export function describePxeDatabase(getDatabase: () => PxeDatabase) { it('skips nullified notes by default or when requesting active', async () => { await database.addNotes(notes); - const notesToNullify = notes.filter(note => note.publicKey.equals(owners[0].publicKey)); + const notesToNullify = notes.filter(note => note.publicKey.equals(owners[0].masterIncomingViewingPublicKey)); const nullifiers = notesToNullify.map(note => note.siloedNullifier); await expect(database.removeNullifiedNotes(nullifiers, notesToNullify[0].publicKey)).resolves.toEqual( notesToNullify, @@ -171,7 +176,7 @@ export function describePxeDatabase(getDatabase: () => PxeDatabase) { it('returns active and nullified notes when requesting either', async () => { await database.addNotes(notes); - const notesToNullify = notes.filter(note => note.publicKey.equals(owners[0].publicKey)); + const notesToNullify = notes.filter(note => note.publicKey.equals(owners[0].masterIncomingViewingPublicKey)); const nullifiers = notesToNullify.map(note => note.siloedNullifier); await expect(database.removeNullifiedNotes(nullifiers, notesToNullify[0].publicKey)).resolves.toEqual( notesToNullify, @@ -215,7 +220,14 @@ export function describePxeDatabase(getDatabase: () => PxeDatabase) { it.skip('refuses to overwrite an address with a different public key', async () => { const address = CompleteAddress.random(); - const otherAddress = new CompleteAddress(address.address, Point.random(), address.partialAddress); + const otherAddress = new CompleteAddress( + address.address, + Point.random(), + Point.random(), + Point.random(), + Point.random(), + address.partialAddress, + ); await database.addCompleteAddress(address); await expect(database.addCompleteAddress(otherAddress)).rejects.toThrow(); diff --git a/yarn-project/pxe/src/pxe_service/pxe_service.ts b/yarn-project/pxe/src/pxe_service/pxe_service.ts index 3ed4fa30cf4..c918ffaafa9 100644 --- a/yarn-project/pxe/src/pxe_service/pxe_service.ts +++ b/yarn-project/pxe/src/pxe_service/pxe_service.ts @@ -25,7 +25,7 @@ import { type TxPXEProcessingStats } from '@aztec/circuit-types/stats'; import { AztecAddress, CallRequest, - CompleteAddress, + type CompleteAddress, FunctionData, MAX_PUBLIC_CALL_STACK_LENGTH_PER_TX, type PartialAddress, @@ -37,7 +37,7 @@ import { import { computeNoteHashNonce, siloNullifier } from '@aztec/circuits.js/hash'; import { type ContractArtifact, type DecodedReturn, FunctionSelector, encodeArguments } from '@aztec/foundation/abi'; import { arrayNonEmptyLength, padArrayEnd } from '@aztec/foundation/collection'; -import { Fr, type Point } from '@aztec/foundation/fields'; +import { Fr } from '@aztec/foundation/fields'; import { SerialQueue } from '@aztec/foundation/fifo'; import { type DebugLogger, createDebugLogger } from '@aztec/foundation/log'; import { Timer } from '@aztec/foundation/timer'; @@ -115,12 +115,12 @@ export class PXEService implements PXE { let count = 0; for (const address of registeredAddresses) { - if (!publicKeysSet.has(address.publicKey.toString())) { + if (!publicKeysSet.has(address.masterIncomingViewingPublicKey.toString())) { continue; } count++; - this.synchronizer.addAccount(address.publicKey, this.keyStore, this.config.l2StartingBlock); + this.synchronizer.addAccount(address.masterIncomingViewingPublicKey, this.keyStore, this.config.l2StartingBlock); } if (count > 0) { @@ -170,24 +170,21 @@ export class PXEService implements PXE { public async registerAccount(secretKey: Fr, partialAddress: PartialAddress): Promise { const accounts = await this.keyStore.getAccounts(); - const account = await this.keyStore.addAccount(secretKey, partialAddress); - const completeAddress = new CompleteAddress( - account, - await this.keyStore.getMasterIncomingViewingPublicKey(account), - partialAddress, - ); - if (accounts.includes(account)) { - this.log.info(`Account:\n "${completeAddress.address.toString()}"\n already registered.`); - return completeAddress; + const accountCompleteAddress = await this.keyStore.addAccount(secretKey, partialAddress); + if (accounts.includes(accountCompleteAddress.address)) { + this.log.info(`Account:\n "${accountCompleteAddress.address.toString()}"\n already registered.`); + return accountCompleteAddress; } else { - const masterIncomingViewingPublicKey = await this.keyStore.getMasterIncomingViewingPublicKey(account); + const masterIncomingViewingPublicKey = await this.keyStore.getMasterIncomingViewingPublicKey( + accountCompleteAddress.address, + ); this.synchronizer.addAccount(masterIncomingViewingPublicKey, this.keyStore, this.config.l2StartingBlock); - this.log.info(`Registered account ${completeAddress.address.toString()}`); - this.log.debug(`Registered account\n ${completeAddress.toReadableString()}`); + this.log.info(`Registered account ${accountCompleteAddress.address.toString()}`); + this.log.debug(`Registered account\n ${accountCompleteAddress.toReadableString()}`); } - await this.db.addCompleteAddress(completeAddress); - return completeAddress; + await this.db.addCompleteAddress(accountCompleteAddress); + return accountCompleteAddress; } public async getRegisteredAccounts(): Promise { @@ -214,20 +211,9 @@ export class PXEService implements PXE { return this.keyStore.getPublicKeysHash(address); } - public async registerRecipient(recipient: CompleteAddress, publicKeys: Point[] = []): Promise { + public async registerRecipient(recipient: CompleteAddress): Promise { const wasAdded = await this.db.addCompleteAddress(recipient); - // TODO #5834: This should be refactored to be okay with only adding complete address - if (publicKeys.length !== 0) { - await this.keyStore.addPublicKeysForAccount( - recipient.address, - publicKeys[0], - publicKeys[1], - publicKeys[2], - publicKeys[3], - ); - } - if (wasAdded) { this.log.info(`Added recipient:\n ${recipient.toReadableString()}`); } else { @@ -306,7 +292,7 @@ export class PXEService implements PXE { let owner = filter.owner; if (owner === undefined) { const completeAddresses = (await this.db.getCompleteAddresses()).find(address => - address.publicKey.equals(dao.publicKey), + address.masterIncomingViewingPublicKey.equals(dao.publicKey), ); if (completeAddresses === undefined) { throw new Error(`Cannot find complete address for public key ${dao.publicKey.toString()}`); @@ -319,8 +305,8 @@ export class PXEService implements PXE { } public async addNote(note: ExtendedNote) { - const { publicKey } = (await this.db.getCompleteAddress(note.owner)) ?? {}; - if (!publicKey) { + const { masterIncomingViewingPublicKey } = (await this.db.getCompleteAddress(note.owner)) ?? {}; + if (!masterIncomingViewingPublicKey) { throw new Error('Unknown account.'); } @@ -360,7 +346,7 @@ export class PXEService implements PXE { innerNoteHash, siloedNullifier, index, - publicKey, + masterIncomingViewingPublicKey, ), ); } diff --git a/yarn-project/pxe/src/pxe_service/test/pxe_test_suite.ts b/yarn-project/pxe/src/pxe_service/test/pxe_test_suite.ts index 75d2f11b142..cf9e7d6c4c4 100644 --- a/yarn-project/pxe/src/pxe_service/test/pxe_test_suite.ts +++ b/yarn-project/pxe/src/pxe_service/test/pxe_test_suite.ts @@ -70,7 +70,14 @@ export const pxeTestSuite = (testName: string, pxeSetup: () => Promise) => it('cannot register a recipient with the same aztec address but different pub key or partial address', async () => { const recipient1 = CompleteAddress.random(); - const recipient2 = new CompleteAddress(recipient1.address, Point.random(), Fr.random()); + const recipient2 = new CompleteAddress( + recipient1.address, + Point.random(), + Point.random(), + Point.random(), + Point.random(), + Fr.random(), + ); await pxe.registerRecipient(recipient1); await expect(() => pxe.registerRecipient(recipient2)).rejects.toThrow( diff --git a/yarn-project/pxe/src/simulator_oracle/index.ts b/yarn-project/pxe/src/simulator_oracle/index.ts index 12e540148b7..ac9ee3566ef 100644 --- a/yarn-project/pxe/src/simulator_oracle/index.ts +++ b/yarn-project/pxe/src/simulator_oracle/index.ts @@ -15,7 +15,6 @@ import { type FunctionSelector, type Header, type L1_TO_L2_MSG_TREE_HEIGHT, - type Point, } from '@aztec/circuits.js'; import { computeL1ToL2MessageNullifier } from '@aztec/circuits.js/hash'; import { type FunctionArtifact, getFunctionArtifact } from '@aztec/foundation/abi'; @@ -44,7 +43,6 @@ export class SimulatorOracle implements DBOracle { return { masterNullifierPublicKey, appNullifierSecretKey }; } - // TODO: #5834 async getCompleteAddress(address: AztecAddress): Promise { const completeAddress = await this.db.getCompleteAddress(address); if (!completeAddress) { @@ -79,16 +77,6 @@ export class SimulatorOracle implements DBOracle { return capsule; } - // TODO: #5834 - async getPublicKeysForAddress(address: AztecAddress): Promise { - const nullifierPublicKey = await this.keyStore.getMasterNullifierPublicKey(address); - const incomingViewingPublicKey = await this.keyStore.getMasterIncomingViewingPublicKey(address); - const outgoingViewingPublicKey = await this.keyStore.getMasterOutgoingViewingPublicKey(address); - const taggingPublicKey = await this.keyStore.getMasterTaggingPublicKey(address); - - return [nullifierPublicKey, incomingViewingPublicKey, outgoingViewingPublicKey, taggingPublicKey]; - } - async getNotes(contractAddress: AztecAddress, storageSlot: Fr, status: NoteStatus) { const noteDaos = await this.db.getNotes({ contractAddress, diff --git a/yarn-project/pxe/src/synchronizer/synchronizer.test.ts b/yarn-project/pxe/src/synchronizer/synchronizer.test.ts index f8deb8b8ca3..1c145eb3302 100644 --- a/yarn-project/pxe/src/synchronizer/synchronizer.test.ts +++ b/yarn-project/pxe/src/synchronizer/synchronizer.test.ts @@ -1,5 +1,5 @@ import { type AztecNode, L2Block } from '@aztec/circuit-types'; -import { CompleteAddress, Fr, type Header, INITIAL_L2_BLOCK_NUM } from '@aztec/circuits.js'; +import { Fr, type Header, INITIAL_L2_BLOCK_NUM } from '@aztec/circuits.js'; import { makeHeader } from '@aztec/circuits.js/testing'; import { randomInt } from '@aztec/foundation/crypto'; import { SerialQueue } from '@aztec/foundation/fifo'; @@ -130,12 +130,9 @@ describe('Synchronizer', () => { const addAddress = async (startingBlockNum: number) => { const secretKey = Fr.random(); const partialAddress = Fr.random(); - const accountAddress = await keyStore.addAccount(secretKey, partialAddress); - const masterIncomingViewingPublicKey = await keyStore.getMasterIncomingViewingPublicKey(accountAddress); - - const completeAddress = new CompleteAddress(accountAddress, masterIncomingViewingPublicKey, partialAddress); + const completeAddress = await keyStore.addAccount(secretKey, partialAddress); await database.addCompleteAddress(completeAddress); - synchronizer.addAccount(completeAddress.publicKey, keyStore, startingBlockNum); + synchronizer.addAccount(completeAddress.masterIncomingViewingPublicKey, keyStore, startingBlockNum); return completeAddress; }; diff --git a/yarn-project/pxe/src/synchronizer/synchronizer.ts b/yarn-project/pxe/src/synchronizer/synchronizer.ts index dc7f1890877..d7da26c991e 100644 --- a/yarn-project/pxe/src/synchronizer/synchronizer.ts +++ b/yarn-project/pxe/src/synchronizer/synchronizer.ts @@ -285,7 +285,8 @@ export class Synchronizer { if (!completeAddress) { throw new Error(`Checking if account is synched is not possible for ${account} because it is not registered.`); } - const findByPublicKey = (x: NoteProcessor) => x.masterIncomingViewingPublicKey.equals(completeAddress.publicKey); + const findByPublicKey = (x: NoteProcessor) => + x.masterIncomingViewingPublicKey.equals(completeAddress.masterIncomingViewingPublicKey); const processor = this.noteProcessors.find(findByPublicKey) ?? this.noteProcessorsToCatchUp.find(findByPublicKey); if (!processor) { throw new Error( diff --git a/yarn-project/simulator/src/acvm/oracle/oracle.ts b/yarn-project/simulator/src/acvm/oracle/oracle.ts index 415f8c3e84e..7df1704f427 100644 --- a/yarn-project/simulator/src/acvm/oracle/oracle.ts +++ b/yarn-project/simulator/src/acvm/oracle/oracle.ts @@ -1,5 +1,5 @@ import { MerkleTreeId, UnencryptedL2Log } from '@aztec/circuit-types'; -import { type PartialAddress, acvmFieldMessageToString, oracleDebugCallToFormattedStr } from '@aztec/circuits.js'; +import { acvmFieldMessageToString, oracleDebugCallToFormattedStr } from '@aztec/circuits.js'; import { EventSelector, FunctionSelector } from '@aztec/foundation/abi'; import { AztecAddress } from '@aztec/foundation/aztec-address'; import { Fr, Point } from '@aztec/foundation/fields'; @@ -53,14 +53,6 @@ export class Oracle { ]; } - // TODO: #5834 Nuke this - async getPublicKeyAndPartialAddress([address]: ACVMField[]) { - const { publicKey, partialAddress } = await this.typedOracle.getCompleteAddress( - AztecAddress.fromField(fromACVMField(address)), - ); - return [publicKey.x, publicKey.y, partialAddress].map(toACVMField); - } - async getContractInstance([address]: ACVMField[]) { const instance = await this.typedOracle.getContractInstance(AztecAddress.fromField(fromACVMField(address))); @@ -173,25 +165,22 @@ export class Oracle { } async getPublicKeysAndPartialAddress([address]: ACVMField[]): Promise { - let publicKeys: Point[] | undefined; - let partialAddress: PartialAddress; - - // TODO #5834: This should be reworked to return the public keys as well - try { - ({ partialAddress } = await this.typedOracle.getCompleteAddress(AztecAddress.fromField(fromACVMField(address)))); - } catch (err) { - partialAddress = Fr.ZERO; - } + const parsedAddress = AztecAddress.fromField(fromACVMField(address)); + const { + masterNullifierPublicKey, + masterIncomingViewingPublicKey, + masterOutgoingViewingPublicKey, + masterTaggingPublicKey, + partialAddress, + } = await this.typedOracle.getCompleteAddress(parsedAddress); - try { - publicKeys = await this.typedOracle.getPublicKeysForAddress(AztecAddress.fromField(fromACVMField(address))); - } catch (err) { - publicKeys = Array(4).fill(Point.ZERO); - } - - const acvmPublicKeys = publicKeys.flatMap(key => key.toFields()); - - return [...acvmPublicKeys, partialAddress].map(toACVMField); + return [ + ...masterNullifierPublicKey.toFields(), + ...masterIncomingViewingPublicKey.toFields(), + ...masterOutgoingViewingPublicKey.toFields(), + ...masterTaggingPublicKey.toFields(), + partialAddress, + ].map(toACVMField); } async getNotes( diff --git a/yarn-project/simulator/src/acvm/oracle/typed_oracle.ts b/yarn-project/simulator/src/acvm/oracle/typed_oracle.ts index 171ccb4d757..231d8cd99d1 100644 --- a/yarn-project/simulator/src/acvm/oracle/typed_oracle.ts +++ b/yarn-project/simulator/src/acvm/oracle/typed_oracle.ts @@ -17,7 +17,7 @@ import { } from '@aztec/circuits.js'; import { type FunctionSelector } from '@aztec/foundation/abi'; import { type AztecAddress } from '@aztec/foundation/aztec-address'; -import { Fr, type Point } from '@aztec/foundation/fields'; +import { Fr } from '@aztec/foundation/fields'; import { type ContractInstance } from '@aztec/types/contracts'; /** Nullifier keys which both correspond to the same master nullifier secret key. */ @@ -93,10 +93,6 @@ export abstract class TypedOracle { throw new OracleMethodNotAvailableError('getNullifierKeys'); } - getPublicKeyAndPartialAddress(_address: AztecAddress): Promise { - throw new OracleMethodNotAvailableError('getPublicKeyAndPartialAddress'); - } - getContractInstance(_address: AztecAddress): Promise { throw new OracleMethodNotAvailableError('getContractInstance'); } @@ -140,10 +136,6 @@ export abstract class TypedOracle { throw new OracleMethodNotAvailableError('popCapsule'); } - getPublicKeysForAddress(_address: AztecAddress): Promise { - throw new OracleMethodNotAvailableError('getPublicKeysForAddress'); - } - getNotes( _storageSlot: Fr, _numSelects: number, diff --git a/yarn-project/simulator/src/avm/avm_execution_environment.test.ts b/yarn-project/simulator/src/avm/avm_execution_environment.test.ts index 6aad1e2ea25..68bde3962fb 100644 --- a/yarn-project/simulator/src/avm/avm_execution_environment.test.ts +++ b/yarn-project/simulator/src/avm/avm_execution_environment.test.ts @@ -1,3 +1,4 @@ +import { FunctionSelector } from '@aztec/circuits.js'; import { Fr } from '@aztec/foundation/fields'; import { allSameExcept, anyAvmContextInputs, initExecutionEnvironment } from './fixtures/index.js'; @@ -5,10 +6,11 @@ import { allSameExcept, anyAvmContextInputs, initExecutionEnvironment } from './ describe('Execution Environment', () => { const newAddress = new Fr(123456n); const calldata = [new Fr(1n), new Fr(2n), new Fr(3n)]; + const selector = FunctionSelector.empty(); it('New call should fork execution environment correctly', () => { const executionEnvironment = initExecutionEnvironment(); - const newExecutionEnvironment = executionEnvironment.deriveEnvironmentForNestedCall(newAddress, calldata); + const newExecutionEnvironment = executionEnvironment.deriveEnvironmentForNestedCall(newAddress, calldata, selector); expect(newExecutionEnvironment).toEqual( allSameExcept(executionEnvironment, { @@ -20,9 +22,10 @@ describe('Execution Environment', () => { ); }); - it('New delegate call should fork execution environment correctly', () => { + // Delegate calls not supported. + it.skip('New delegate call should fork execution environment correctly', () => { const executionEnvironment = initExecutionEnvironment(); - const newExecutionEnvironment = executionEnvironment.newDelegateCall(newAddress, calldata); + const newExecutionEnvironment = executionEnvironment.newDelegateCall(newAddress, calldata, selector); expect(newExecutionEnvironment).toEqual( allSameExcept(executionEnvironment, { @@ -36,7 +39,11 @@ describe('Execution Environment', () => { it('New static call call should fork execution environment correctly', () => { const executionEnvironment = initExecutionEnvironment(); - const newExecutionEnvironment = executionEnvironment.deriveEnvironmentForNestedStaticCall(newAddress, calldata); + const newExecutionEnvironment = executionEnvironment.deriveEnvironmentForNestedStaticCall( + newAddress, + calldata, + selector, + ); expect(newExecutionEnvironment).toEqual( allSameExcept(executionEnvironment, { diff --git a/yarn-project/simulator/src/avm/avm_execution_environment.ts b/yarn-project/simulator/src/avm/avm_execution_environment.ts index 4a38171103d..e57f94eecda 100644 --- a/yarn-project/simulator/src/avm/avm_execution_environment.ts +++ b/yarn-project/simulator/src/avm/avm_execution_environment.ts @@ -45,72 +45,64 @@ export class AvmExecutionEnvironment { this.calldata = [...inputs.toFields(), ...calldata]; } - public deriveEnvironmentForNestedCall( + private deriveEnvironmentForNestedCallInternal( targetAddress: AztecAddress, calldata: Fr[], - temporaryFunctionSelector: FunctionSelector = FunctionSelector.empty(), - ): AvmExecutionEnvironment { + functionSelector: FunctionSelector, + isStaticCall: boolean, + isDelegateCall: boolean, + ) { return new AvmExecutionEnvironment( - targetAddress, + /*address=*/ targetAddress, /*storageAddress=*/ targetAddress, - this.address, + /*sender=*/ this.address, this.feePerL2Gas, this.feePerDaGas, this.contractCallDepth, this.header, this.globals, - this.isStaticCall, - this.isDelegateCall, + isStaticCall, + isDelegateCall, calldata, this.gasSettings, this.transactionFee, - temporaryFunctionSelector, + functionSelector, ); } - public deriveEnvironmentForNestedStaticCall( - address: AztecAddress, + public deriveEnvironmentForNestedCall( + targetAddress: AztecAddress, calldata: Fr[], - temporaryFunctionSelector: FunctionSelector = FunctionSelector.empty(), + functionSelector: FunctionSelector = FunctionSelector.empty(), ): AvmExecutionEnvironment { - return new AvmExecutionEnvironment( - address, - /*storageAddress=*/ address, - this.sender, - this.feePerL2Gas, - this.feePerDaGas, - this.contractCallDepth, - this.header, - this.globals, - /*isStaticCall=*/ true, - this.isDelegateCall, + return this.deriveEnvironmentForNestedCallInternal( + targetAddress, calldata, - this.gasSettings, - this.transactionFee, - temporaryFunctionSelector, + functionSelector, + /*isStaticCall=*/ false, + /*isDelegateCall=*/ false, ); } - public newDelegateCall( - address: AztecAddress, + public deriveEnvironmentForNestedStaticCall( + targetAddress: AztecAddress, calldata: Fr[], - temporaryFunctionSelector: FunctionSelector = FunctionSelector.empty(), + functionSelector: FunctionSelector, ): AvmExecutionEnvironment { - return new AvmExecutionEnvironment( - address, - this.storageAddress, - this.sender, - this.feePerL2Gas, - this.feePerDaGas, - this.contractCallDepth, - this.header, - this.globals, - this.isStaticCall, - /*isDelegateCall=*/ true, + return this.deriveEnvironmentForNestedCallInternal( + targetAddress, calldata, - this.gasSettings, - this.transactionFee, - temporaryFunctionSelector, + functionSelector, + /*isStaticCall=*/ true, + /*isDelegateCall=*/ false, ); } + + public newDelegateCall( + _targetAddress: AztecAddress, + _calldata: Fr[], + _functionSelector: FunctionSelector, + ): AvmExecutionEnvironment { + throw new Error('Delegate calls not supported!'); + } } diff --git a/yarn-project/simulator/src/avm/avm_simulator.test.ts b/yarn-project/simulator/src/avm/avm_simulator.test.ts index a18b4c05e43..80ea7a60e29 100644 --- a/yarn-project/simulator/src/avm/avm_simulator.test.ts +++ b/yarn-project/simulator/src/avm/avm_simulator.test.ts @@ -879,6 +879,20 @@ describe('AVM simulator: transpiled Noir contracts', () => { expect(results.revertReason?.message).toEqual('Assertion failed: Values are not equal'); }); }); + + it('conversions', async () => { + const calldata: Fr[] = [new Fr(0b1011101010100)]; + const context = initContext({ env: initExecutionEnvironment({ calldata }) }); + + const bytecode = getAvmTestContractBytecode('to_radix_le'); + const results = await new AvmSimulator(context).executeBytecode(bytecode); + + expect(results.reverted).toBe(false); + const expectedResults = Buffer.concat('0010101011'.split('').map(c => new Fr(Number(c)).toBuffer())); + const resultBuffer = Buffer.concat(results.output.map(f => f.toBuffer())); + + expect(resultBuffer.equals(expectedResults)).toBe(true); + }); }); function getAvmTestContractBytecode(functionName: string): Buffer { diff --git a/yarn-project/simulator/src/avm/opcodes/conversion.ts b/yarn-project/simulator/src/avm/opcodes/conversion.ts index dc9884d9aab..e07165b8ec6 100644 --- a/yarn-project/simulator/src/avm/opcodes/conversion.ts +++ b/yarn-project/simulator/src/avm/opcodes/conversion.ts @@ -1,4 +1,5 @@ -import { assert } from '../../../../foundation/src/json-rpc/js_utils.js'; +import { strict as assert } from 'assert'; + import { type AvmContext } from '../avm_context.js'; import { TypeTag, Uint8 } from '../avm_memory_types.js'; import { Opcode, OperandType } from '../serialization/instruction_serialization.js'; diff --git a/yarn-project/simulator/src/avm/serialization/bytecode_serialization.ts b/yarn-project/simulator/src/avm/serialization/bytecode_serialization.ts index 0d2866ee965..c6158e6bd76 100644 --- a/yarn-project/simulator/src/avm/serialization/bytecode_serialization.ts +++ b/yarn-project/simulator/src/avm/serialization/bytecode_serialization.ts @@ -1,4 +1,5 @@ import { DAGasLeft, L2GasLeft } from '../opcodes/context_getters.js'; +import { ToRadixLE } from '../opcodes/conversion.js'; import { Keccak, Pedersen, Poseidon2, Sha256 } from '../opcodes/hashing.js'; import type { Instruction } from '../opcodes/index.js'; import { @@ -136,6 +137,8 @@ const INSTRUCTION_SET = () => [Poseidon2.opcode, Poseidon2], [Sha256.opcode, Sha256], [Pedersen.opcode, Pedersen], + // Conversions + [ToRadixLE.opcode, ToRadixLE], ]); interface Serializable { diff --git a/yarn-project/simulator/src/client/db_oracle.ts b/yarn-project/simulator/src/client/db_oracle.ts index a7e78619eb1..0fbdbd64364 100644 --- a/yarn-project/simulator/src/client/db_oracle.ts +++ b/yarn-project/simulator/src/client/db_oracle.ts @@ -8,7 +8,7 @@ import { import { type CompleteAddress, type Header } from '@aztec/circuits.js'; import { type FunctionArtifact, type FunctionSelector } from '@aztec/foundation/abi'; import { type AztecAddress } from '@aztec/foundation/aztec-address'; -import { type Fr, type Point } from '@aztec/foundation/fields'; +import { type Fr } from '@aztec/foundation/fields'; import { type ContractInstance } from '@aztec/types/contracts'; import { type NoteData, type NullifierKeys } from '../acvm/index.js'; @@ -64,14 +64,6 @@ export interface DBOracle extends CommitmentsDB { */ popCapsule(): Promise; - /** - * Gets public keys for an address. - * @param The address to look up - * @returns The public keys for a specific address - * TODO(#5834): Replace with `getCompleteAddress`. - */ - getPublicKeysForAddress(address: AztecAddress): Promise; - /** * Retrieve nullifier keys associated with a specific account and app/contract address. * diff --git a/yarn-project/simulator/src/client/private_execution.test.ts b/yarn-project/simulator/src/client/private_execution.test.ts index 1037f15109e..cdcc07bd5a5 100644 --- a/yarn-project/simulator/src/client/private_execution.test.ts +++ b/yarn-project/simulator/src/client/private_execution.test.ts @@ -4,6 +4,8 @@ import { type L1ToL2Message, Note, PackedValues, + PublicDataWitness, + SiblingPath, TxExecutionRequest, } from '@aztec/circuit-types'; import { @@ -17,9 +19,10 @@ import { Header, L1_TO_L2_MSG_TREE_HEIGHT, NOTE_HASH_TREE_HEIGHT, + PUBLIC_DATA_TREE_HEIGHT, PartialStateReference, PublicCallRequest, - type PublicKey, + PublicDataTreeLeafPreimage, StateReference, TxContext, computeAppNullifierSecretKey, @@ -39,7 +42,7 @@ import { Fr } from '@aztec/foundation/fields'; import { type DebugLogger, createDebugLogger } from '@aztec/foundation/log'; import { type FieldsOf } from '@aztec/foundation/types'; import { openTmpStore } from '@aztec/kv-store/utils'; -import { type AppendOnlyTree, Pedersen, StandardTree, newTree } from '@aztec/merkle-tree'; +import { type AppendOnlyTree, INITIAL_LEAF, Pedersen, StandardTree, newTree } from '@aztec/merkle-tree'; import { ChildContractArtifact, ImportTestContractArtifact, @@ -79,14 +82,13 @@ describe('Private Execution test suite', () => { let ownerCompleteAddress: CompleteAddress; let recipientCompleteAddress: CompleteAddress; - let ownerMasterNullifierPublicKey: PublicKey; - let recipientMasterNullifierPublicKey: PublicKey; let ownerMasterNullifierSecretKey: GrumpkinPrivateKey; let recipientMasterNullifierSecretKey: GrumpkinPrivateKey; const treeHeights: { [name: string]: number } = { noteHash: NOTE_HASH_TREE_HEIGHT, l1ToL2Messages: L1_TO_L2_MSG_TREE_HEIGHT, + publicData: PUBLIC_DATA_TREE_HEIGHT, }; let trees: { [name: keyof typeof treeHeights]: AppendOnlyTree } = {}; @@ -139,7 +141,7 @@ describe('Private Execution test suite', () => { // Create a new snapshot. const newSnap = new AppendOnlyTreeSnapshot(Fr.fromBuffer(tree.getRoot(true)), Number(tree.getNumLeaves(true))); - if (name === 'noteHash' || name === 'l1ToL2Messages') { + if (name === 'noteHash' || name === 'l1ToL2Messages' || 'publicData') { header = new Header( header.lastArchive, header.contentCommitment, @@ -148,7 +150,7 @@ describe('Private Execution test suite', () => { new PartialStateReference( name === 'noteHash' ? newSnap : header.state.partial.noteHashTree, header.state.partial.nullifierTree, - header.state.partial.publicDataTree, + name === 'publicData' ? newSnap : header.state.partial.publicDataTree, ), ), header.globalVariables, @@ -175,42 +177,60 @@ describe('Private Execution test suite', () => { const ownerPartialAddress = Fr.random(); ownerCompleteAddress = CompleteAddress.fromSecretKeyAndPartialAddress(ownerSk, ownerPartialAddress); - - const allOwnerKeys = deriveKeys(ownerSk); - ownerMasterNullifierPublicKey = allOwnerKeys.masterNullifierPublicKey; - ownerMasterNullifierSecretKey = allOwnerKeys.masterNullifierSecretKey; + ownerMasterNullifierSecretKey = deriveKeys(ownerSk).masterNullifierSecretKey; const recipientPartialAddress = Fr.random(); recipientCompleteAddress = CompleteAddress.fromSecretKeyAndPartialAddress(recipientSk, recipientPartialAddress); - - const allRecipientKeys = deriveKeys(recipientSk); - recipientMasterNullifierPublicKey = allRecipientKeys.masterNullifierPublicKey; - recipientMasterNullifierSecretKey = allRecipientKeys.masterNullifierSecretKey; + recipientMasterNullifierSecretKey = deriveKeys(recipientSk).masterNullifierSecretKey; owner = ownerCompleteAddress.address; recipient = recipientCompleteAddress.address; }); - beforeEach(() => { + beforeEach(async () => { trees = {}; oracle = mock(); oracle.getNullifierKeys.mockImplementation((accountAddress: AztecAddress, contractAddress: AztecAddress) => { if (accountAddress.equals(ownerCompleteAddress.address)) { return Promise.resolve({ - masterNullifierPublicKey: ownerMasterNullifierPublicKey, + masterNullifierPublicKey: ownerCompleteAddress.masterNullifierPublicKey, appNullifierSecretKey: computeAppNullifierSecretKey(ownerMasterNullifierSecretKey, contractAddress), }); } if (accountAddress.equals(recipientCompleteAddress.address)) { return Promise.resolve({ - masterNullifierPublicKey: recipientMasterNullifierPublicKey, + masterNullifierPublicKey: recipientCompleteAddress.masterNullifierPublicKey, appNullifierSecretKey: computeAppNullifierSecretKey(recipientMasterNullifierSecretKey, contractAddress), }); } throw new Error(`Unknown address ${accountAddress}`); }); + + // We call insertLeaves here with no leaves to populate empty public data tree root --> this is necessary to be + // able to get ivpk_m during execution + await insertLeaves([], 'publicData'); oracle.getHeader.mockResolvedValue(header); + oracle.getCompleteAddress.mockImplementation((address: AztecAddress) => { + if (address.equals(owner)) { + return Promise.resolve(ownerCompleteAddress); + } + if (address.equals(recipient)) { + return Promise.resolve(recipientCompleteAddress); + } + throw new Error(`Unknown address ${address}`); + }); + // This oracle gets called when reading ivpk_m from key registry --> we return zero witness indicating that + // the keys were not registered. This triggers non-registered keys flow in which getCompleteAddress oracle + // gets called and we constrain the result by hashing address preimage and checking it matches. + oracle.getPublicDataTreeWitness.mockResolvedValue( + new PublicDataWitness( + 0n, + PublicDataTreeLeafPreimage.empty(), + SiblingPath.ZERO(PUBLIC_DATA_TREE_HEIGHT, INITIAL_LEAF, new Pedersen()), + ), + ); + acirSimulator = new AcirSimulator(oracle, node); }); @@ -286,16 +306,6 @@ describe('Private Execution test suite', () => { }; beforeEach(() => { - oracle.getCompleteAddress.mockImplementation((address: AztecAddress) => { - if (address.equals(owner)) { - return Promise.resolve(ownerCompleteAddress); - } - if (address.equals(recipient)) { - return Promise.resolve(recipientCompleteAddress); - } - throw new Error(`Unknown address ${address}`); - }); - oracle.getFunctionArtifactByName.mockImplementation((_, functionName: string) => Promise.resolve(getFunctionArtifact(StatefulTestContractArtifact, functionName)), ); @@ -555,15 +565,6 @@ describe('Private Execution test suite', () => { describe('consuming messages', () => { const contractAddress = defaultContractAddress; - beforeEach(() => { - oracle.getCompleteAddress.mockImplementation((address: AztecAddress) => { - if (address.equals(recipient)) { - return Promise.resolve(recipientCompleteAddress); - } - throw new Error(`Unknown address ${address}`); - }); - }); - describe('L1 to L2', () => { const artifact = getFunctionArtifact(TestContractArtifact, 'consume_mint_private_message'); let bridgedAmount = 100n; @@ -865,15 +866,6 @@ describe('Private Execution test suite', () => { }); describe('pending note hashes contract', () => { - beforeEach(() => { - oracle.getCompleteAddress.mockImplementation((address: AztecAddress) => { - if (address.equals(owner)) { - return Promise.resolve(ownerCompleteAddress); - } - throw new Error(`Unknown address ${address}`); - }); - }); - beforeEach(() => { oracle.getFunctionArtifact.mockImplementation((_, selector) => Promise.resolve(getFunctionArtifact(PendingNoteHashesContractArtifact, selector)), @@ -1045,15 +1037,15 @@ describe('Private Execution test suite', () => { }); }); - describe('get public key', () => { + describe('get master incoming viewing public key', () => { it('gets the public key for an address', async () => { // Tweak the contract artifact so we can extract return values - const artifact = getFunctionArtifact(TestContractArtifact, 'get_public_key'); + const artifact = getFunctionArtifact(TestContractArtifact, 'get_master_incoming_viewing_public_key'); // Generate a partial address, pubkey, and resulting address const completeAddress = CompleteAddress.random(); const args = [completeAddress.address]; - const pubKey = completeAddress.publicKey; + const pubKey = completeAddress.masterIncomingViewingPublicKey; oracle.getCompleteAddress.mockResolvedValue(completeAddress); const result = await runSimulator({ artifact, args }); diff --git a/yarn-project/simulator/src/client/view_data_oracle.ts b/yarn-project/simulator/src/client/view_data_oracle.ts index 50dc2552c25..b4c02039175 100644 --- a/yarn-project/simulator/src/client/view_data_oracle.ts +++ b/yarn-project/simulator/src/client/view_data_oracle.ts @@ -166,16 +166,6 @@ export class ViewDataOracle extends TypedOracle { return this.db.popCapsule(); } - /** - * Gets public keys for an address. - * @param The address to look up - * @returns The public keys for a specific address - * TODO(#5834): Replace with `getCompleteAddress`. - */ - public override getPublicKeysForAddress(address: AztecAddress) { - return this.db.getPublicKeysForAddress(address); - } - /** * Gets some notes for a contract address and storage slot. * Returns a flattened array containing filtered notes.