Skip to content

Commit

Permalink
feat(avm): better field arithmetic (#4142)
Browse files Browse the repository at this point in the history
fixes: #3993
  • Loading branch information
Maddiaa0 authored Jan 19, 2024
1 parent fa4d919 commit 7308e31
Show file tree
Hide file tree
Showing 5 changed files with 170 additions and 13 deletions.
110 changes: 110 additions & 0 deletions yarn-project/acir-simulator/src/avm/opcodes/arithmetic.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
import { Fr } from '@aztec/foundation/fields';

import { mock } from 'jest-mock-extended';

import { AvmMachineState } from '../avm_machine_state.js';
import { AvmStateManager } from '../avm_state_manager.js';
import { Add, Div, Mul, Sub } from './arithmetic.js';

describe('Arithmetic Instructions', () => {
let machineState: AvmMachineState;
let stateManager = mock<AvmStateManager>();

beforeEach(() => {
machineState = new AvmMachineState([]);
stateManager = mock<AvmStateManager>();
});

describe('Add', () => {
it('Should add correctly over Fr type', () => {
const a = new Fr(1n);
const b = new Fr(2n);

machineState.writeMemory(0, a);
machineState.writeMemory(1, b);

new Add(0, 1, 2).execute(machineState, stateManager);

const expected = new Fr(3n);
const actual = machineState.readMemory(2);
expect(actual).toEqual(expected);
});

it('Should wrap around on addition', () => {
const a = new Fr(1n);
const b = new Fr(Fr.MODULUS - 1n);

machineState.writeMemory(0, a);
machineState.writeMemory(1, b);

new Add(0, 1, 2).execute(machineState, stateManager);

const expected = new Fr(0n);
const actual = machineState.readMemory(3);
expect(actual).toEqual(expected);
});
});

describe('Sub', () => {
it('Should subtract correctly over Fr type', () => {
const a = new Fr(1n);
const b = new Fr(2n);

machineState.writeMemory(0, a);
machineState.writeMemory(1, b);

new Sub(0, 1, 2).execute(machineState, stateManager);

const expected = new Fr(Fr.MODULUS - 1n);
const actual = machineState.readMemory(2);
expect(actual).toEqual(expected);
});
});

describe('Mul', () => {
it('Should multiply correctly over Fr type', () => {
const a = new Fr(2n);
const b = new Fr(3n);

machineState.writeMemory(0, a);
machineState.writeMemory(1, b);

new Mul(0, 1, 2).execute(machineState, stateManager);

const expected = new Fr(6n);
const actual = machineState.readMemory(2);
expect(actual).toEqual(expected);
});

it('Should wrap around on multiplication', () => {
const a = new Fr(2n);
const b = new Fr(Fr.MODULUS / 2n - 1n);

machineState.writeMemory(0, a);
machineState.writeMemory(1, b);

new Mul(0, 1, 2).execute(machineState, stateManager);

const expected = new Fr(Fr.MODULUS - 3n);
const actual = machineState.readMemory(2);
expect(actual).toEqual(expected);
});
});

describe('Div', () => {
it('Should perform field division', () => {
const a = new Fr(2n);
const b = new Fr(3n);

machineState.writeMemory(0, a);
machineState.writeMemory(1, b);

new Div(0, 1, 2).execute(machineState, stateManager);

// Note
const actual = machineState.readMemory(2);
const recovered = actual.mul(b);
expect(recovered).toEqual(a);
});
});
});
9 changes: 4 additions & 5 deletions yarn-project/acir-simulator/src/avm/opcodes/arithmetic.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ export class Add extends Instruction {
const a = machineState.readMemory(this.aOffset);
const b = machineState.readMemory(this.bOffset);

const dest = new Fr((a.toBigInt() + b.toBigInt()) % Fr.MODULUS);
const dest = a.add(b);
machineState.writeMemory(this.destOffset, dest);

this.incrementPc(machineState);
Expand All @@ -37,7 +37,7 @@ export class Sub extends Instruction {
const a = machineState.readMemory(this.aOffset);
const b = machineState.readMemory(this.bOffset);

const dest = new Fr((a.toBigInt() - b.toBigInt()) % Fr.MODULUS);
const dest = a.sub(b);
machineState.writeMemory(this.destOffset, dest);

this.incrementPc(machineState);
Expand All @@ -57,7 +57,7 @@ export class Mul extends Instruction {
const a: Fr = machineState.readMemory(this.aOffset);
const b: Fr = machineState.readMemory(this.bOffset);

const dest = new Fr((a.toBigInt() * b.toBigInt()) % Fr.MODULUS);
const dest = a.mul(b);
machineState.writeMemory(this.destOffset, dest);

this.incrementPc(machineState);
Expand All @@ -77,8 +77,7 @@ export class Div extends Instruction {
const a: Fr = machineState.readMemory(this.aOffset);
const b: Fr = machineState.readMemory(this.bOffset);

// TODO(https://github.com/AztecProtocol/aztec-packages/issues/3993): proper field division
const dest = new Fr(a.toBigInt() / b.toBigInt());
const dest = a.div(b);
machineState.writeMemory(this.destOffset, dest);

this.incrementPc(machineState);
Expand Down
4 changes: 3 additions & 1 deletion yarn-project/acir-simulator/src/avm/opcodes/bitwise.ts
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,10 @@ export class Not extends Instruction {
execute(machineState: AvmMachineState, _stateManager: AvmStateManager): void {
const a: Fr = machineState.readMemory(this.aOffset);

// TODO: hack -> until proper field arithmetic is implemented
// TODO: hack -> Bitwise operations should not occur over field elements
// It should only work over integers
const result = ~a.toBigInt();

const dest = new Fr(result < 0 ? Fr.MODULUS + /* using a + as result is -ve*/ result : result);
machineState.writeMemory(this.destOffset, dest);

Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,4 @@
import {
Add,
/*Div,*/
Mul,
Sub,
} from './arithmetic.js';
import { Add, Div, Mul, Sub } from './arithmetic.js';
//import { And, Not, Or, Shl, Shr, Xor } from './bitwise.js';
//import { Eq, Lt, Lte } from './comparators.js';
import { InternalCall, InternalReturn, Jump, JumpI, Return } from './control_flow.js';
Expand All @@ -29,7 +24,7 @@ export const INSTRUCTION_SET: Map<Opcode, InstructionConstructorAndMembers> = ne
[Opcode.ADD, Add],
[Opcode.SUB, Sub],
[Opcode.MUL, Mul],
//[Opcode.DIV, Div],
[Opcode.DIV, Div],
//// Compute - Comparators
//[Opcode.EQ, Eq],
//[Opcode.LT, Lt],
Expand Down
51 changes: 51 additions & 0 deletions yarn-project/foundation/src/fields/fields.ts
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,30 @@ export class Fr extends BaseField {
static fromString(buf: string) {
return fromString(buf, Fr);
}

/** Arithmetic */

add(rhs: Fr) {
return new Fr((this.toBigInt() + rhs.toBigInt()) % Fr.MODULUS);
}

sub(rhs: Fr) {
const result = this.toBigInt() - rhs.toBigInt();
return new Fr(result < 0 ? result + Fr.MODULUS : result);
}

mul(rhs: Fr) {
return new Fr((this.toBigInt() * rhs.toBigInt()) % Fr.MODULUS);
}

div(rhs: Fr) {
if (rhs.isZero()) {
throw new Error('Division by zero');
}

const bInv = modInverse(rhs.toBigInt());
return this.mul(bInv);
}
}

/**
Expand Down Expand Up @@ -252,6 +276,33 @@ export class Fq extends BaseField {
}
}

// Beware: Performance bottleneck below

/**
* Find the modular inverse of a given element, for BN254 Fr.
*/
function modInverse(b: bigint) {
const [gcd, x, _] = extendedEuclidean(b, Fr.MODULUS);
if (gcd != 1n) {
throw Error('Inverse does not exist');
}
// Add modulus to ensure positive
return new Fr(x + Fr.MODULUS);
}

/**
* The extended Euclidean algorithm can be used to find the multiplicative inverse of a field element
* This is used to perform field division.
*/
function extendedEuclidean(a: bigint, modulus: bigint): [bigint, bigint, bigint] {
if (a == 0n) {
return [modulus, 0n, 1n];
} else {
const [gcd, x, y] = extendedEuclidean(modulus % a, a);
return [gcd, y - (modulus / a) * x, x];
}
}

/**
* GrumpkinScalar is an Fq.
* @remarks Called GrumpkinScalar because it is used to represent elements in Grumpkin's scalar field as defined in
Expand Down

0 comments on commit 7308e31

Please sign in to comment.