diff --git a/src/neural-network.ts b/src/neural-network.ts index 8d334a0f..b341cb4e 100644 --- a/src/neural-network.ts +++ b/src/neural-network.ts @@ -768,7 +768,12 @@ export class NeuralNetwork< let error = 0; if (layer === this.outputLayer) { - if (typeof this._lossFunction === "function") error = this._lossFunction.call({thread: {x: node}}, output, target[node], input, this.lossState); + if (typeof this._lossFunction === "function") { + const kernelFunctionThis = { thread: { x: node, y: layer, z: 0 } }; + // eslint-disable-next-line @typescript-eslint/ban-ts-comment + // @ts-ignore + error = this._lossFunction.call(kernelFunctionThis, output, target[node], input, this.lossState); + } else error = target[node] - output; } else { const deltas = this.deltas[layer + 1]; @@ -796,7 +801,12 @@ export class NeuralNetwork< let error = 0; if (layer === this.outputLayer) { - if (typeof this._lossFunction === "function") error = this._lossFunction.call({thread: {x: node}}, output, target[node], input, this.lossState); + if (typeof this._lossFunction === "function") { + const kernelFunctionThis = { thread: { x: node, y: layer, z: 0 } }; + // eslint-disable-next-line @typescript-eslint/ban-ts-comment + // @ts-ignore + error = this._lossFunction.call(kernelFunctionThis, output, target[node], input, this.lossState); + } else error = target[node] - output; } else { for (let k = 0; k < nextDeltas.length; k++) { @@ -824,7 +834,12 @@ export class NeuralNetwork< let error = 0; if (layer === this.outputLayer) { - if (typeof this._lossFunction === "function") error = this._lossFunction.call({thread: {x: node}}, output, target[node], input, this.lossState); + if (typeof this._lossFunction === "function") { + const kernelFunctionThis = { thread: { x: node, y: layer, z: 0 } }; + // eslint-disable-next-line @typescript-eslint/ban-ts-comment + // @ts-ignore + error = this._lossFunction.call(kernelFunctionThis, output, target[node], input, this.lossState); + } else error = target[node] - output; } else { for (let k = 0; k < nextDeltas.length; k++) { @@ -851,7 +866,12 @@ export class NeuralNetwork< let error = 0; if (layer === this.outputLayer) { - if (typeof this._lossFunction === "function") error = this._lossFunction.call({thread: {x: node}}, output, target[node], input, this.lossState); + if (typeof this._lossFunction === "function") { + const kernelFunctionThis = { thread: { x: node, y: layer, z: 0 } }; + // eslint-disable-next-line @typescript-eslint/ban-ts-comment + // @ts-ignore + error = this._lossFunction.call(kernelFunctionThis, output, target[node], input, this.lossState); + } else error = target[node] - output; } else { for (let k = 0; k < nextDeltas.length; k++) { diff --git a/src/utilities/loss.ts b/src/utilities/loss.ts index 3e07072f..56a4da66 100644 --- a/src/utilities/loss.ts +++ b/src/utilities/loss.ts @@ -1,8 +1,11 @@ +import { IKernelFunctionThis } from "gpu.js"; + export type LossFunctionInputs = number[] | number[][] | number[][][] | Float32Array | Float32Array[] | Float32Array[][]; export type LossFunctionState = number[][][] | Float32Array[][]; export type LossFunction = ( + this: IKernelFunctionThis, actual: number, expected: number, inputs: LossFunctionInputs,