Skip to content

Commit

Permalink
[Web] Implement linear congruential generator, make runtime seedable (#…
Browse files Browse the repository at this point in the history
…16722)

This PR implements `LinearCongruentialGenerator` in TVMjs,
following the C++ counterpart in #8642.
The motivation is that we want to seed autoregressive generation
to make results reproducible, supporting the OpenAI field `seed`.
The main function is `nextInt()`, which generates a number
`(0, 2^32 - 1)` non-inclusive.

Subsequently, we change all `Math.random()` in `runtime.ts` to
`this.rng.randomFloat()`, exposing API `Instance.setSeed()`.

Unit tests are added for `LinearCongruentialGenerator` for testing
seed and coverage.
  • Loading branch information
CharlieFRuan authored Mar 15, 2024
1 parent 94866f7 commit 45df124
Show file tree
Hide file tree
Showing 4 changed files with 161 additions and 5 deletions.
2 changes: 1 addition & 1 deletion web/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ export {
} from "./runtime";
export { Disposable, LibraryProvider } from "./types";
export { RPCServer } from "./rpc_server";
export { wasmPath } from "./support";
export { wasmPath, LinearCongruentialGenerator } from "./support";
export { detectGPUDevice, GPUDeviceDetectOutput } from "./webgpu";
export { assert } from "./support";
export { createPolyfillWASI } from "./compact";
17 changes: 13 additions & 4 deletions web/src/runtime.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import { Pointer, PtrOffset, SizeOf, ArgTypeCode } from "./ctypes";
import { Disposable } from "./types";
import { Memory, CachedCallStack } from "./memory";
import { assert, StringToUint8Array } from "./support";
import { assert, StringToUint8Array, LinearCongruentialGenerator } from "./support";
import { Environment } from "./environment";
import { AsyncifyHandler } from "./asyncify";
import { FunctionInfo, WebGPUContext } from "./webgpu";
Expand Down Expand Up @@ -1079,6 +1079,7 @@ export class Instance implements Disposable {
private ctx: RuntimeContext;
private asyncifyHandler: AsyncifyHandler;
private initProgressCallback: Array<InitProgressCallback> = [];
private rng: LinearCongruentialGenerator;

/**
* Internal function(registered by the runtime)
Expand Down Expand Up @@ -1131,6 +1132,7 @@ export class Instance implements Disposable {
);
this.registerEnvGlobalPackedFuncs();
this.registerObjectFactoryFuncs();
this.rng = new LinearCongruentialGenerator();
}

/**
Expand Down Expand Up @@ -1811,11 +1813,18 @@ export class Instance implements Disposable {
const scale = high - low;
const input = new Float32Array(size);
for (let i = 0; i < input.length; ++i) {
input[i] = low + Math.random() * scale;
input[i] = low + this.rng.randomFloat() * scale;
}
return ret.copyFrom(input);
}

/**
* Set the seed of the internal LinearCongruentialGenerator.
*/
setSeed(seed: number): void {
this.rng.setSeed(seed);
}

/**
* Sample index via top-p sampling.
*
Expand All @@ -1825,7 +1834,7 @@ export class Instance implements Disposable {
* @returns The sampled index.
*/
sampleTopPFromLogits(logits: NDArray, temperature: number, top_p: number): number {
return this.ctx.sampleTopPFromLogits(logits, temperature, top_p, Math.random());
return this.ctx.sampleTopPFromLogits(logits, temperature, top_p, this.rng.randomFloat());
}

/**
Expand All @@ -1836,7 +1845,7 @@ export class Instance implements Disposable {
* @returns The sampled index.
*/
sampleTopPFromProb(prob: NDArray, top_p: number): number {
return this.ctx.sampleTopPFromProb(prob, top_p, Math.random());
return this.ctx.sampleTopPFromProb(prob, top_p, this.rng.randomFloat());
}

/**
Expand Down
76 changes: 76 additions & 0 deletions web/src/support.ts
Original file line number Diff line number Diff line change
Expand Up @@ -74,3 +74,79 @@ export function assert(condition: boolean, msg?: string): asserts condition {
export function wasmPath(): string {
return __dirname + "/wasm";
}

/**
* Linear congruential generator for random number generating that can be seeded.
*
* Follows the implementation of `include/tvm/support/random_engine.h`, which follows the
* sepcification in https://en.cppreference.com/w/cpp/numeric/random/linear_congruential_engine.
*
* Note `Number.MAX_SAFE_INTEGER = 2^53 - 1`, and our intermediates are strictly less than 2^48.
*/

export class LinearCongruentialGenerator {
readonly modulus: number;
readonly multiplier: number;
readonly increment: number;
// Always within the range (0, 2^32 - 1) non-inclusive; if 0, will forever generate 0.
private rand_state: number;

/**
* Set modulus, multiplier, and increment. Initialize `rand_state` according to `Date.now()`.
*/
constructor() {
this.modulus = 2147483647; // 2^32 - 1
this.multiplier = 48271; // between 2^15 and 2^16
this.increment = 0;
this.setSeed(Date.now());
}

/**
* Sets `rand_state` after normalized with `modulus` to ensure that it is within range.
* @param seed Any integer. Used to set `rand_state` after normalized with `modulus`.
*
* Postcondition: pass `checkRandState()`, i.e. rand_state > 0 and is an integer.
*/
setSeed(seed: number) {
if (!Number.isInteger(seed)) {
throw new Error("Seed should be an integer.");
}
this.rand_state = seed % this.modulus;
if (this.rand_state == 0) {
this.rand_state = 1;
}
this.checkRandState();
}

/**
* Generate the next integer in the range (0, this.modulus) non-inclusive, updating `rand_state`.
*
* Postcondition: pass `checkRandState()`, i.e. rand_state > 0 and is an integer.
*/
nextInt(): number {
// `intermediate` is always < 2^48, hence less than `Number.MAX_SAFE_INTEGER` due to the
// invariants as commented in the constructor.
const intermediate = this.multiplier * this.rand_state + this.increment;
this.rand_state = intermediate % this.modulus;
this.checkRandState();
return this.rand_state;
}

/**
* Generates random float between (0, 1) non-inclusive, updating `rand_state`.
*
* Postcondition: pass `checkRandState()`, i.e. rand_state > 0 and is an integer.
*/
randomFloat(): number {
return this.nextInt() / this.modulus;
}

private checkRandState(): void {
if (this.rand_state <= 0) {
throw new Error("Random state is unexpectedly not strictly positive.");
}
if (!Number.isInteger(this.rand_state)) {
throw new Error("Random state is unexpectedly not an integer.");
}
}
}
71 changes: 71 additions & 0 deletions web/tests/node/test_random_generator.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/* eslint-disable no-undef */

const tvmjs = require("../../dist");

test("Test coverage of [0,100] inclusive", () => {
const covered = Array(100);
const rng = new tvmjs.LinearCongruentialGenerator();
for (let i = 0; i < 100000; i++) {
covered[rng.nextInt() % 100] = true;
}
const notCovered = [];
for (let i = 0; i < 100; i++) {
if (!covered[i]) {
notCovered.push(i);
}
}
expect(notCovered).toEqual([]);
});

test("Test whether the same seed make two RNGs generate same results", () => {
const rng1 = new tvmjs.LinearCongruentialGenerator();
const rng2 = new tvmjs.LinearCongruentialGenerator();
rng1.setSeed(42);
rng2.setSeed(42);

for (let i = 0; i < 100; i++) {
expect(rng1.randomFloat()).toBeCloseTo(rng2.randomFloat());
}
});

test("Test two RNGs with different seeds generate different results", () => {
const rng1 = new tvmjs.LinearCongruentialGenerator();
const rng2 = new tvmjs.LinearCongruentialGenerator();
rng1.setSeed(41);
rng2.setSeed(42);
let numSame = 0;
const numTest = 100;

// Generate `numTest` random numbers, make sure not all are the same.
for (let i = 0; i < numTest; i++) {
if (rng1.nextInt() === rng2.nextInt()) {
numSame += 1;
}
}
expect(numSame < numTest).toBe(true);
});

test('Illegal argument to `setSeed()`', () => {
expect(() => {
const rng1 = new tvmjs.LinearCongruentialGenerator();
rng1.setSeed(42.5);
}).toThrow("Seed should be an integer.");
});

0 comments on commit 45df124

Please sign in to comment.