diff --git a/yarn-project/acir-simulator/src/avm/opcodes/arithmetic.test.ts b/yarn-project/acir-simulator/src/avm/opcodes/arithmetic.test.ts new file mode 100644 index 00000000000..76a802231a3 --- /dev/null +++ b/yarn-project/acir-simulator/src/avm/opcodes/arithmetic.test.ts @@ -0,0 +1,110 @@ +import { Fr } from '@aztec/foundation/fields'; + +import { mock } from 'jest-mock-extended'; + +import { AvmMachineState } from '../avm_machine_state.js'; +import { AvmStateManager } from '../avm_state_manager.js'; +import { Add, Div, Mul, Sub } from './arithmetic.js'; + +describe('Arithmetic Instructions', () => { + let machineState: AvmMachineState; + let stateManager = mock(); + + beforeEach(() => { + machineState = new AvmMachineState([]); + stateManager = mock(); + }); + + describe('Add', () => { + it('Should add correctly over Fr type', () => { + const a = new Fr(1n); + const b = new Fr(2n); + + machineState.writeMemory(0, a); + machineState.writeMemory(1, b); + + new Add(0, 1, 2).execute(machineState, stateManager); + + const expected = new Fr(3n); + const actual = machineState.readMemory(2); + expect(actual).toEqual(expected); + }); + + it('Should wrap around on addition', () => { + const a = new Fr(1n); + const b = new Fr(Fr.MODULUS - 1n); + + machineState.writeMemory(0, a); + machineState.writeMemory(1, b); + + new Add(0, 1, 2).execute(machineState, stateManager); + + const expected = new Fr(0n); + const actual = machineState.readMemory(3); + expect(actual).toEqual(expected); + }); + }); + + describe('Sub', () => { + it('Should subtract correctly over Fr type', () => { + const a = new Fr(1n); + const b = new Fr(2n); + + machineState.writeMemory(0, a); + machineState.writeMemory(1, b); + + new Sub(0, 1, 2).execute(machineState, stateManager); + + const expected = new Fr(Fr.MODULUS - 1n); + const actual = machineState.readMemory(2); + expect(actual).toEqual(expected); + }); + }); + + describe('Mul', () => { + it('Should multiply correctly over Fr type', () => { + const a = new Fr(2n); + const b = new Fr(3n); + + machineState.writeMemory(0, a); + machineState.writeMemory(1, b); + + new Mul(0, 1, 2).execute(machineState, stateManager); + + const expected = new Fr(6n); + const actual = machineState.readMemory(2); + expect(actual).toEqual(expected); + }); + + it('Should wrap around on multiplication', () => { + const a = new Fr(2n); + const b = new Fr(Fr.MODULUS / 2n - 1n); + + machineState.writeMemory(0, a); + machineState.writeMemory(1, b); + + new Mul(0, 1, 2).execute(machineState, stateManager); + + const expected = new Fr(Fr.MODULUS - 3n); + const actual = machineState.readMemory(2); + expect(actual).toEqual(expected); + }); + }); + + describe('Div', () => { + it('Should perform field division', () => { + const a = new Fr(2n); + const b = new Fr(3n); + + machineState.writeMemory(0, a); + machineState.writeMemory(1, b); + + new Div(0, 1, 2).execute(machineState, stateManager); + + // Note + const actual = machineState.readMemory(2); + const recovered = actual.mul(b); + expect(recovered).toEqual(a); + }); + }); +}); diff --git a/yarn-project/acir-simulator/src/avm/opcodes/arithmetic.ts b/yarn-project/acir-simulator/src/avm/opcodes/arithmetic.ts index 7f2fb3db052..94df9713c83 100644 --- a/yarn-project/acir-simulator/src/avm/opcodes/arithmetic.ts +++ b/yarn-project/acir-simulator/src/avm/opcodes/arithmetic.ts @@ -17,7 +17,7 @@ export class Add extends Instruction { const a = machineState.readMemory(this.aOffset); const b = machineState.readMemory(this.bOffset); - const dest = new Fr((a.toBigInt() + b.toBigInt()) % Fr.MODULUS); + const dest = a.add(b); machineState.writeMemory(this.destOffset, dest); this.incrementPc(machineState); @@ -37,7 +37,7 @@ export class Sub extends Instruction { const a = machineState.readMemory(this.aOffset); const b = machineState.readMemory(this.bOffset); - const dest = new Fr((a.toBigInt() - b.toBigInt()) % Fr.MODULUS); + const dest = a.sub(b); machineState.writeMemory(this.destOffset, dest); this.incrementPc(machineState); @@ -57,7 +57,7 @@ export class Mul extends Instruction { const a: Fr = machineState.readMemory(this.aOffset); const b: Fr = machineState.readMemory(this.bOffset); - const dest = new Fr((a.toBigInt() * b.toBigInt()) % Fr.MODULUS); + const dest = a.mul(b); machineState.writeMemory(this.destOffset, dest); this.incrementPc(machineState); @@ -77,8 +77,7 @@ export class Div extends Instruction { const a: Fr = machineState.readMemory(this.aOffset); const b: Fr = machineState.readMemory(this.bOffset); - // TODO(https://github.com/AztecProtocol/aztec-packages/issues/3993): proper field division - const dest = new Fr(a.toBigInt() / b.toBigInt()); + const dest = a.div(b); machineState.writeMemory(this.destOffset, dest); this.incrementPc(machineState); diff --git a/yarn-project/acir-simulator/src/avm/opcodes/bitwise.ts b/yarn-project/acir-simulator/src/avm/opcodes/bitwise.ts index bc5c3a211cb..4950c771f7c 100644 --- a/yarn-project/acir-simulator/src/avm/opcodes/bitwise.ts +++ b/yarn-project/acir-simulator/src/avm/opcodes/bitwise.ts @@ -76,8 +76,10 @@ export class Not extends Instruction { execute(machineState: AvmMachineState, _stateManager: AvmStateManager): void { const a: Fr = machineState.readMemory(this.aOffset); - // TODO: hack -> until proper field arithmetic is implemented + // TODO: hack -> Bitwise operations should not occur over field elements + // It should only work over integers const result = ~a.toBigInt(); + const dest = new Fr(result < 0 ? Fr.MODULUS + /* using a + as result is -ve*/ result : result); machineState.writeMemory(this.destOffset, dest); diff --git a/yarn-project/acir-simulator/src/avm/opcodes/instruction_set.ts b/yarn-project/acir-simulator/src/avm/opcodes/instruction_set.ts index 1211205b979..4395ba46d53 100644 --- a/yarn-project/acir-simulator/src/avm/opcodes/instruction_set.ts +++ b/yarn-project/acir-simulator/src/avm/opcodes/instruction_set.ts @@ -1,9 +1,4 @@ -import { - Add, - /*Div,*/ - Mul, - Sub, -} from './arithmetic.js'; +import { Add, Div, Mul, Sub } from './arithmetic.js'; //import { And, Not, Or, Shl, Shr, Xor } from './bitwise.js'; //import { Eq, Lt, Lte } from './comparators.js'; import { InternalCall, InternalReturn, Jump, JumpI, Return } from './control_flow.js'; @@ -29,7 +24,7 @@ export const INSTRUCTION_SET: Map = ne [Opcode.ADD, Add], [Opcode.SUB, Sub], [Opcode.MUL, Mul], - //[Opcode.DIV, Div], + [Opcode.DIV, Div], //// Compute - Comparators //[Opcode.EQ, Eq], //[Opcode.LT, Lt], diff --git a/yarn-project/foundation/src/fields/fields.ts b/yarn-project/foundation/src/fields/fields.ts index 90600643cfe..bbf2f3b9a20 100644 --- a/yarn-project/foundation/src/fields/fields.ts +++ b/yarn-project/foundation/src/fields/fields.ts @@ -192,6 +192,30 @@ export class Fr extends BaseField { static fromString(buf: string) { return fromString(buf, Fr); } + + /** Arithmetic */ + + add(rhs: Fr) { + return new Fr((this.toBigInt() + rhs.toBigInt()) % Fr.MODULUS); + } + + sub(rhs: Fr) { + const result = this.toBigInt() - rhs.toBigInt(); + return new Fr(result < 0 ? result + Fr.MODULUS : result); + } + + mul(rhs: Fr) { + return new Fr((this.toBigInt() * rhs.toBigInt()) % Fr.MODULUS); + } + + div(rhs: Fr) { + if (rhs.isZero()) { + throw new Error('Division by zero'); + } + + const bInv = modInverse(rhs.toBigInt()); + return this.mul(bInv); + } } /** @@ -252,6 +276,33 @@ export class Fq extends BaseField { } } +// Beware: Performance bottleneck below + +/** + * Find the modular inverse of a given element, for BN254 Fr. + */ +function modInverse(b: bigint) { + const [gcd, x, _] = extendedEuclidean(b, Fr.MODULUS); + if (gcd != 1n) { + throw Error('Inverse does not exist'); + } + // Add modulus to ensure positive + return new Fr(x + Fr.MODULUS); +} + +/** + * The extended Euclidean algorithm can be used to find the multiplicative inverse of a field element + * This is used to perform field division. + */ +function extendedEuclidean(a: bigint, modulus: bigint): [bigint, bigint, bigint] { + if (a == 0n) { + return [modulus, 0n, 1n]; + } else { + const [gcd, x, y] = extendedEuclidean(modulus % a, a); + return [gcd, y - (modulus / a) * x, x]; + } +} + /** * GrumpkinScalar is an Fq. * @remarks Called GrumpkinScalar because it is used to represent elements in Grumpkin's scalar field as defined in