Skip to content

Commit

Permalink
Generalize loss function for both CPU and GPU
Browse files Browse the repository at this point in the history
Allow for using the same loss function between both CPU and GPU neural networks
  • Loading branch information
voidvoxel committed Jun 16, 2024
1 parent ba12f82 commit 7371a23
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 4 deletions.
28 changes: 24 additions & 4 deletions src/neural-network.ts
Original file line number Diff line number Diff line change
Expand Up @@ -768,7 +768,12 @@ export class NeuralNetwork<

let error = 0;
if (layer === this.outputLayer) {
if (typeof this._lossFunction === "function") error = this._lossFunction.call({thread: {x: node}}, output, target[node], input, this.lossState);
if (typeof this._lossFunction === "function") {
const kernelFunctionThis = { thread: { x: node, y: layer, z: 0 } };
// eslint-disable-next-line @typescript-eslint/ban-ts-comment
// @ts-ignore
error = this._lossFunction.call(kernelFunctionThis, output, target[node], input, this.lossState);
}
else error = target[node] - output;
} else {
const deltas = this.deltas[layer + 1];
Expand Down Expand Up @@ -796,7 +801,12 @@ export class NeuralNetwork<

let error = 0;
if (layer === this.outputLayer) {
if (typeof this._lossFunction === "function") error = this._lossFunction.call({thread: {x: node}}, output, target[node], input, this.lossState);
if (typeof this._lossFunction === "function") {
const kernelFunctionThis = { thread: { x: node, y: layer, z: 0 } };
// eslint-disable-next-line @typescript-eslint/ban-ts-comment
// @ts-ignore
error = this._lossFunction.call(kernelFunctionThis, output, target[node], input, this.lossState);
}
else error = target[node] - output;
} else {
for (let k = 0; k < nextDeltas.length; k++) {
Expand Down Expand Up @@ -824,7 +834,12 @@ export class NeuralNetwork<

let error = 0;
if (layer === this.outputLayer) {
if (typeof this._lossFunction === "function") error = this._lossFunction.call({thread: {x: node}}, output, target[node], input, this.lossState);
if (typeof this._lossFunction === "function") {
const kernelFunctionThis = { thread: { x: node, y: layer, z: 0 } };
// eslint-disable-next-line @typescript-eslint/ban-ts-comment
// @ts-ignore
error = this._lossFunction.call(kernelFunctionThis, output, target[node], input, this.lossState);
}
else error = target[node] - output;
} else {
for (let k = 0; k < nextDeltas.length; k++) {
Expand All @@ -851,7 +866,12 @@ export class NeuralNetwork<

let error = 0;
if (layer === this.outputLayer) {
if (typeof this._lossFunction === "function") error = this._lossFunction.call({thread: {x: node}}, output, target[node], input, this.lossState);
if (typeof this._lossFunction === "function") {
const kernelFunctionThis = { thread: { x: node, y: layer, z: 0 } };
// eslint-disable-next-line @typescript-eslint/ban-ts-comment
// @ts-ignore
error = this._lossFunction.call(kernelFunctionThis, output, target[node], input, this.lossState);
}
else error = target[node] - output;
} else {
for (let k = 0; k < nextDeltas.length; k++) {
Expand Down
3 changes: 3 additions & 0 deletions src/utilities/loss.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import { IKernelFunctionThis } from "gpu.js";

export type LossFunctionInputs = number[] | number[][] | number[][][] | Float32Array | Float32Array[] | Float32Array[][];

export type LossFunctionState = number[][][] | Float32Array[][];

export type LossFunction = (
this: IKernelFunctionThis,
actual: number,
expected: number,
inputs: LossFunctionInputs,
Expand Down

0 comments on commit 7371a23

Please sign in to comment.