Skip to content

Commit

Permalink
refactor: toFields()/fromFields(...) methods in more classes (#4335)
Browse files Browse the repository at this point in the history
Partially fixes #4332

The goal of this PR was to progress on nuking
yarn-project/acir-simulator/src/acvm/serialize.ts and move it to the
individual classes to then be able to easily test it in the classes
tests and improve readability.

I didn't fully tackle the issue in this one because it was all getting
too big.
  • Loading branch information
benesjan authored Feb 1, 2024
1 parent e7db0da commit 433b9eb
Show file tree
Hide file tree
Showing 30 changed files with 434 additions and 340 deletions.
4 changes: 2 additions & 2 deletions yarn-project/acir-simulator/src/acvm/deserialize.ts
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ export function extractPrivateCircuitPublicInputs(
const encryptedLogPreimagesLength = witnessReader.readField();
const unencryptedLogPreimagesLength = witnessReader.readField();

const header = Header.fromFieldArray(witnessReader.readFieldArray(HEADER_LENGTH));
const header = Header.fromFields(witnessReader.readFieldArray(HEADER_LENGTH));

const contractDeploymentData = new ContractDeploymentData(
new Point(witnessReader.readField(), witnessReader.readField()),
Expand Down Expand Up @@ -176,7 +176,7 @@ export function extractPublicCircuitPublicInputs(partialWitness: ACVMWitness, ac
const unencryptedLogsHash = witnessReader.readFieldArray(NUM_FIELDS_PER_SHA256);
const unencryptedLogPreimagesLength = witnessReader.readField();

const header = Header.fromFieldArray(witnessReader.readFieldArray(HEADER_LENGTH));
const header = Header.fromFields(witnessReader.readFieldArray(HEADER_LENGTH));

const proverAddress = AztecAddress.fromField(witnessReader.readField());

Expand Down
16 changes: 5 additions & 11 deletions yarn-project/acir-simulator/src/acvm/oracle/oracle.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,7 @@ import { createDebugLogger } from '@aztec/foundation/log';

import { ACVMField } from '../acvm_types.js';
import { frToNumber, fromACVMField } from '../deserialize.js';
import {
toACVMField,
toACVMHeader,
toAcvmCallPrivateStackItem,
toAcvmEnqueuePublicFunctionResult,
toAcvmL1ToL2MessageLoadOracleInputs,
} from '../serialize.js';
import { toACVMField, toAcvmEnqueuePublicFunctionResult } from '../serialize.js';
import { acvmFieldMessageToString, oracleDebugCallToFormattedStr } from './debug.js';
import { TypedOracle } from './typed_oracle.js';

Expand Down Expand Up @@ -132,7 +126,7 @@ export class Oracle {
if (!header) {
throw new Error(`Block header not found for block ${parsedBlockNumber}.`);
}
return toACVMHeader(header);
return header.toFields().map(toACVMField);
}

async getAuthWitness([messageHash]: ACVMField[]): Promise<ACVMField[]> {
Expand Down Expand Up @@ -226,8 +220,8 @@ export class Oracle {
}

async getL1ToL2Message([msgKey]: ACVMField[]): Promise<ACVMField[]> {
const { ...message } = await this.typedOracle.getL1ToL2Message(fromACVMField(msgKey));
return toAcvmL1ToL2MessageLoadOracleInputs(message);
const message = await this.typedOracle.getL1ToL2Message(fromACVMField(msgKey));
return message.toFields().map(toACVMField);
}

async getPortalContractAddress([aztecAddress]: ACVMField[]): Promise<ACVMField> {
Expand Down Expand Up @@ -297,7 +291,7 @@ export class Oracle {
fromACVMField(argsHash),
frToNumber(fromACVMField(sideffectCounter)),
);
return toAcvmCallPrivateStackItem(callStackItem);
return callStackItem.toFields().map(toACVMField);
}

async callPublicFunction(
Expand Down
33 changes: 16 additions & 17 deletions yarn-project/acir-simulator/src/acvm/oracle/typed_oracle.ts
Original file line number Diff line number Diff line change
Expand Up @@ -48,23 +48,22 @@ export interface NoteData {
index?: bigint;
}

/**
* The data for L1 to L2 Messages provided by other data sources.
*/
export interface MessageLoadOracleInputs {
/**
* An collapsed array of fields containing all of the l1 to l2 message components.
* `l1ToL2Message.toFieldArray()` -\> [sender, chainId, recipient, version, content, secretHash, deadline, fee]
*/
message: Fr[];
/**
* The path in the merkle tree to the message.
*/
siblingPath: Fr[];
/**
* The index of the message commitment in the merkle tree.
*/
index: bigint;
export class MessageLoadOracleInputs {
constructor(
/**
* An collapsed array of fields containing all of the l1 to l2 message components.
* `l1ToL2Message.toFieldArray()` -\> [sender, chainId, recipient, version, content, secretHash, deadline, fee]
*/
public message: Fr[],
/** The index of the message commitment in the merkle tree. */
public index: bigint,
/** The path in the merkle tree to the message. */
public siblingPath: Fr[],
) {}

toFields(): Fr[] {
return [...this.message, new Fr(this.index), ...this.siblingPath];
}
}

/**
Expand Down
154 changes: 6 additions & 148 deletions yarn-project/acir-simulator/src/acvm/serialize.ts
Original file line number Diff line number Diff line change
@@ -1,19 +1,9 @@
import {
CallContext,
ContractDeploymentData,
FunctionData,
GlobalVariables,
Header,
PrivateCallStackItem,
PrivateCircuitPublicInputs,
PublicCallRequest,
} from '@aztec/circuits.js';
import { PublicCallRequest } from '@aztec/circuits.js';
import { AztecAddress } from '@aztec/foundation/aztec-address';
import { EthAddress } from '@aztec/foundation/eth-address';
import { Fr } from '@aztec/foundation/fields';

import { ACVMField } from './acvm_types.js';
import { MessageLoadOracleInputs } from './oracle/typed_oracle.js';

/**
* Adapts the buffer to the field size.
Expand Down Expand Up @@ -53,125 +43,6 @@ export function toACVMField(
// Utilities to write TS classes to ACVM Field arrays
// In the order that the ACVM expects them

/**
* Converts a function data to ACVM fields.
* @param functionData - The function data to convert.
* @returns The ACVM fields.
*/
export function toACVMFunctionData(functionData: FunctionData): ACVMField[] {
return [
toACVMField(functionData.selector.toBuffer()),
toACVMField(functionData.isInternal),
toACVMField(functionData.isPrivate),
toACVMField(functionData.isConstructor),
];
}

/**
* Converts a call context to ACVM fields.
* @param callContext - The call context to convert.
* @returns The ACVM fields.
*/
export function toACVMCallContext(callContext: CallContext): ACVMField[] {
return [
toACVMField(callContext.msgSender),
toACVMField(callContext.storageContractAddress),
toACVMField(callContext.portalContractAddress),
toACVMField(callContext.functionSelector.toField()),
toACVMField(callContext.isDelegateCall),
toACVMField(callContext.isStaticCall),
toACVMField(callContext.isContractDeployment),
toACVMField(callContext.startSideEffectCounter),
];
}

/**
* Converts a contract deployment data to ACVM fields.
* @param contractDeploymentData - The contract deployment data to convert.
* @returns The ACVM fields.
*/
export function toACVMContractDeploymentData(contractDeploymentData: ContractDeploymentData): ACVMField[] {
return [
toACVMField(contractDeploymentData.publicKey.x),
toACVMField(contractDeploymentData.publicKey.y),
toACVMField(contractDeploymentData.initializationHash),
toACVMField(contractDeploymentData.contractClassId),
toACVMField(contractDeploymentData.contractAddressSalt),
toACVMField(contractDeploymentData.portalContractAddress),
];
}

/**
* Converts a block header into ACVM fields.
* @param header - The block header object to convert.
* @returns The ACVM fields.
*/
export function toACVMHeader(header: Header): ACVMField[] {
return header.toFieldArray().map(toACVMField);
}

/**
* Converts global variables into ACVM fields
* @param globalVariables - The global variables object to convert.
* @returns The ACVM fields
*/
export function toACVMGlobalVariables(globalVariables: GlobalVariables): ACVMField[] {
return [
toACVMField(globalVariables.chainId),
toACVMField(globalVariables.version),
toACVMField(globalVariables.blockNumber),
toACVMField(globalVariables.timestamp),
];
}

/**
* Converts the public inputs structure to ACVM fields.
* @param publicInputs - The public inputs to convert.
* @returns The ACVM fields.
*/
export function toACVMPublicInputs(publicInputs: PrivateCircuitPublicInputs): ACVMField[] {
return [
...toACVMCallContext(publicInputs.callContext),
toACVMField(publicInputs.argsHash),

...publicInputs.returnValues.map(toACVMField),
...publicInputs.readRequests.flatMap(x => x.toFields()).map(toACVMField),
...publicInputs.nullifierKeyValidationRequests.flatMap(x => x.toFields()).map(toACVMField),
...publicInputs.newCommitments.flatMap(x => x.toFields()).map(toACVMField),
...publicInputs.newNullifiers.flatMap(x => x.toFields()).map(toACVMField),
...publicInputs.privateCallStackHashes.map(toACVMField),
...publicInputs.publicCallStackHashes.map(toACVMField),
...publicInputs.newL2ToL1Msgs.map(toACVMField),
toACVMField(publicInputs.endSideEffectCounter),
...publicInputs.encryptedLogsHash.map(toACVMField),
...publicInputs.unencryptedLogsHash.map(toACVMField),

toACVMField(publicInputs.encryptedLogPreimagesLength),
toACVMField(publicInputs.unencryptedLogPreimagesLength),

...toACVMHeader(publicInputs.historicalHeader),

...toACVMContractDeploymentData(publicInputs.contractDeploymentData),

toACVMField(publicInputs.chainId),
toACVMField(publicInputs.version),
];
}

/**
* Converts a private call stack item to ACVM fields.
* @param item - The private call stack item to convert.
* @returns The ACVM fields.
*/
export function toAcvmCallPrivateStackItem(item: PrivateCallStackItem): ACVMField[] {
return [
toACVMField(item.contractAddress),
...toACVMFunctionData(item.functionData),
...toACVMPublicInputs(item.publicInputs),
toACVMField(item.isExecutionRequest),
];
}

/**
* Converts a public call stack item with the request for executing a public function to
* a set of ACVM fields accepted by the enqueue_public_function_call_oracle Aztec.nr function.
Expand All @@ -182,24 +53,11 @@ export function toAcvmCallPrivateStackItem(item: PrivateCallStackItem): ACVMFiel
*/
export function toAcvmEnqueuePublicFunctionResult(item: PublicCallRequest): ACVMField[] {
return [
toACVMField(item.contractAddress),
...toACVMFunctionData(item.functionData),
...toACVMCallContext(item.callContext),
toACVMField(item.getArgsHash()),
];
}

/**
* Converts the result of loading messages to ACVM fields.
* @param messageLoadOracleInputs - The result of loading messages to convert.
* @returns The Message Oracle Fields.
*/
export function toAcvmL1ToL2MessageLoadOracleInputs(messageLoadOracleInputs: MessageLoadOracleInputs): ACVMField[] {
return [
...messageLoadOracleInputs.message.map(f => toACVMField(f)),
toACVMField(messageLoadOracleInputs.index),
...messageLoadOracleInputs.siblingPath.map(f => toACVMField(f)),
];
item.contractAddress.toField(),
...item.functionData.toFields(),
...item.callContext.toFields(),
item.getArgsHash(),
].map(toACVMField);
}

/**
Expand Down
14 changes: 4 additions & 10 deletions yarn-project/acir-simulator/src/client/client_execution_context.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,7 @@ import { AztecAddress } from '@aztec/foundation/aztec-address';
import { Fr, Point } from '@aztec/foundation/fields';
import { createDebugLogger } from '@aztec/foundation/log';

import {
NoteData,
toACVMCallContext,
toACVMContractDeploymentData,
toACVMHeader,
toACVMWitness,
} from '../acvm/index.js';
import { NoteData, toACVMWitness } from '../acvm/index.js';
import { PackedArgsCache } from '../common/packed_args_cache.js';
import { DBOracle } from './db_oracle.js';
import { ExecutionNoteCache } from './execution_note_cache.js';
Expand Down Expand Up @@ -96,9 +90,9 @@ export class ClientExecutionContext extends ViewDataOracle {
}

const fields = [
...toACVMCallContext(this.callContext),
...toACVMHeader(this.historicalHeader),
...toACVMContractDeploymentData(contractDeploymentData),
...this.callContext.toFields(),
...this.historicalHeader.toFields(),
...contractDeploymentData.toFields(),

this.txContext.chainId,
this.txContext.version,
Expand Down
14 changes: 8 additions & 6 deletions yarn-project/acir-simulator/src/client/private_execution.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ import { jest } from '@jest/globals';
import { MockProxy, mock } from 'jest-mock-extended';
import { getFunctionSelector } from 'viem';

import { KeyPair } from '../acvm/index.js';
import { KeyPair, MessageLoadOracleInputs } from '../acvm/index.js';
import { buildL1ToL2Message } from '../test/utils.js';
import { computeSlotForMapping } from '../utils.js';
import { DBOracle } from './db_oracle.js';
Expand Down Expand Up @@ -544,11 +544,13 @@ describe('Private Execution test suite', () => {
const mockOracles = async () => {
const tree = await insertLeaves([messageKey ?? preimage.hash()], 'l1ToL2Messages');
oracle.getL1ToL2Message.mockImplementation(async () => {
return Promise.resolve({
message: preimage.toFieldArray(),
index: 0n,
siblingPath: (await tree.getSiblingPath(0n, false)).toFieldArray(),
});
return Promise.resolve(
new MessageLoadOracleInputs(
preimage.toFieldArray(),
0n,
(await tree.getSiblingPath(0n, false)).toFieldArray(),
),
);
});
};

Expand Down
9 changes: 3 additions & 6 deletions yarn-project/acir-simulator/src/public/index.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import { MockProxy, mock } from 'jest-mock-extended';
import { type MemDown, default as memdown } from 'memdown';
import { getFunctionSelector } from 'viem';

import { MessageLoadOracleInputs } from '../index.js';
import { buildL1ToL2Message } from '../test/utils.js';
import { computeSlotForMapping } from '../utils.js';
import { CommitmentsDB, PublicContractsDB, PublicStateDB } from './db.js';
Expand Down Expand Up @@ -459,12 +460,8 @@ describe('ACIR public execution simulator', () => {
for (const sibling of siblingPath) {
root = Fr.fromBuffer(pedersenHash([root.toBuffer(), sibling.toBuffer()]));
}
commitmentsDb.getL1ToL2Message.mockImplementation(async () => {
return await Promise.resolve({
message: preimage.toFieldArray(),
index: 0n,
siblingPath,
});
commitmentsDb.getL1ToL2Message.mockImplementation(() => {
return Promise.resolve(new MessageLoadOracleInputs(preimage.toFieldArray(), 0n, siblingPath));
});

return new AppendOnlyTreeSnapshot(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import { EthAddress } from '@aztec/foundation/eth-address';
import { Fr } from '@aztec/foundation/fields';
import { createDebugLogger } from '@aztec/foundation/log';

import { TypedOracle, toACVMCallContext, toACVMGlobalVariables, toACVMHeader, toACVMWitness } from '../acvm/index.js';
import { TypedOracle, toACVMWitness } from '../acvm/index.js';
import { PackedArgsCache, SideEffectCounter } from '../common/index.js';
import { CommitmentsDB, PublicContractsDB, PublicStateDB } from './db.js';
import { PublicExecution, PublicExecutionResult } from './execution.js';
Expand Down Expand Up @@ -49,13 +49,7 @@ export class PublicExecutionContext extends TypedOracle {
*/
public getInitialWitness(witnessStartIndex = 0) {
const { callContext, args } = this.execution;
const fields = [
...toACVMCallContext(callContext),
...toACVMHeader(this.header),
...toACVMGlobalVariables(this.globalVariables),

...args,
];
const fields = [...callContext.toFields(), ...this.header.toFields(), ...this.globalVariables.toFields(), ...args];

return toACVMWitness(witnessStartIndex, fields);
}
Expand Down
4 changes: 2 additions & 2 deletions yarn-project/circuits.js/src/abis/abis.ts
Original file line number Diff line number Diff line change
Expand Up @@ -397,7 +397,7 @@ function computePrivateInputsHash(input: PrivateCircuitPublicInputs) {
...input.unencryptedLogsHash.map(fr => fr.toBuffer()),
input.encryptedLogPreimagesLength.toBuffer(),
input.unencryptedLogPreimagesLength.toBuffer(),
...(input.historicalHeader.toFieldArray().map(fr => fr.toBuffer()) as Buffer[]),
...(input.historicalHeader.toFields().map(fr => fr.toBuffer()) as Buffer[]),
computeContractDeploymentDataHash(input.contractDeploymentData).toBuffer(),
input.chainId.toBuffer(),
input.version.toBuffer(),
Expand Down Expand Up @@ -463,7 +463,7 @@ export function computePublicInputsHash(input: PublicCircuitPublicInputs) {
...input.newL2ToL1Msgs.map(fr => fr.toBuffer()),
...input.unencryptedLogsHash.map(fr => fr.toBuffer()),
input.unencryptedLogPreimagesLength.toBuffer(),
...input.historicalHeader.toFieldArray().map(fr => fr.toBuffer()),
...input.historicalHeader.toFields().map(fr => fr.toBuffer()),
input.proverAddress.toBuffer(),
];
if (toHash.length != PUBLIC_CIRCUIT_PUBLIC_INPUTS_HASH_INPUT_LENGTH) {
Expand Down
Loading

0 comments on commit 433b9eb

Please sign in to comment.