diff --git a/packages/upscalerjs/src/image.browser.ts b/packages/upscalerjs/src/image.browser.ts index 94a9e6814..c8d3e8b67 100644 --- a/packages/upscalerjs/src/image.browser.ts +++ b/packages/upscalerjs/src/image.browser.ts @@ -64,8 +64,9 @@ export const isHTMLImageElement = (pixels: GetImageAsTensorInput): pixels is HTM } }; -export const tensorAsBase64 = async (tensor: tf.Tensor3D) => { - const arr = await tensorAsClampedArray(tensor); +export const tensorAsBase64 = (tensor: tf.Tensor3D): string => { + const arr = tensorAsClampedArray(tensor); + tensor.dispose(); const [height, width, ] = tensor.shape; const imageData = new ImageData(width, height); imageData.data.set(arr); diff --git a/packages/upscalerjs/src/image.node.ts b/packages/upscalerjs/src/image.node.ts index 343ebc5be..492ac4365 100644 --- a/packages/upscalerjs/src/image.node.ts +++ b/packages/upscalerjs/src/image.node.ts @@ -67,8 +67,9 @@ export const getImageAsTensor = async ( throw getInvalidTensorError(tensor); }; -export const tensorAsBase64 = async (tensor: tf.Tensor3D) => { - const arr = await tensorAsClampedArray(tensor); +export const tensorAsBase64 = (tensor: tf.Tensor3D): string => { + const arr = tensorAsClampedArray(tensor); + tensor.dispose(); return Buffer.from(arr).toString('base64'); }; diff --git a/packages/upscalerjs/src/upscale.test.ts b/packages/upscalerjs/src/upscale.test.ts index 178039540..689a42573 100644 --- a/packages/upscalerjs/src/upscale.test.ts +++ b/packages/upscalerjs/src/upscale.test.ts @@ -1,4 +1,5 @@ -import { tf } from './dependencies.generated'; +// import { tf } from './dependencies.generated'; +import * as tf from '@tensorflow/tfjs-node'; import { AbortError, predict, @@ -32,8 +33,6 @@ jest.mock('./image.generated', () => ({ const mockedImage = image as jest.Mocked; // const mockedTensorAsBase = tensorAsBase as jest.Mocked; -console.log(mockedImage); - describe('concatTensors', () => { beforeEach(() => { try { @@ -1218,7 +1217,7 @@ describe('predict', () => { it('should invoke progress callback with percent and slice', async () => { console.warn = jest.fn(); const mockResponse = 'foobarbaz'; - (mockedImage as any).default.tensorAsBase64 = async() => mockResponse; + (mockedImage as any).default.tensorAsBase64 = () => mockResponse; const img: tf.Tensor4D = tf.ones([4, 2, 3,]).expandDims(0); const scale = 2; const patchSize = 2; @@ -1543,7 +1542,7 @@ describe('upscale', () => { const model = { predict: jest.fn(() => tf.ones([1, 2, 2, 3,])), } as unknown as tf.LayersModel; - (mockedImage as any).default.tensorAsBase64 = async () => 'foobarbaz'; + (mockedImage as any).default.tensorAsBase64 = () => 'foobarbaz'; const result = await wrapGenerator(upscale(img, {}, { model, modelDefinition: { scale: 2, } as ModelDefinition, diff --git a/packages/upscalerjs/src/upscale.ts b/packages/upscalerjs/src/upscale.ts index 92cbbe5d4..649b156b9 100644 --- a/packages/upscalerjs/src/upscale.ts +++ b/packages/upscalerjs/src/upscale.ts @@ -254,6 +254,8 @@ export async function* predict

, O extends ResultFormat modelDefinition, }: UpscaleInternalArgs ): AsyncGenerator { + // TODO: Remove this + await Promise.resolve(); const scale = modelDefinition.scale; if (originalPatchSize && padding === undefined) { @@ -328,7 +330,8 @@ export async function* predict

, O extends ResultFormat (>progress)(percent, squeezedTensor); } else { // because we are returning a string, we can safely dispose of our tensor - const src = await tensorAsBase64(squeezedTensor); + const src = tensorAsBase64(squeezedTensor); + console.log('what is src', src); squeezedTensor.dispose(); (>progress)(percent, src); } @@ -431,7 +434,7 @@ export async function* upscale

, O extends ResultFormat return >postprocessedPixels; } - const base64Src = await tensorAsBase64(postprocessedPixels); + const base64Src = tensorAsBase64(postprocessedPixels); postprocessedPixels.dispose(); return >base64Src; } diff --git a/packages/upscalerjs/src/utils.test.ts b/packages/upscalerjs/src/utils.test.ts index 6e9274d6e..4a1a03989 100644 --- a/packages/upscalerjs/src/utils.test.ts +++ b/packages/upscalerjs/src/utils.test.ts @@ -286,13 +286,13 @@ describe('isTensor', () => { }); describe('tensorAsClampedArray', () => { - it('returns an array', async () => { - const result = await tensorAsClampedArray(tf.tensor([[[2, 2, 3], [2, 1, 4], [5,5,5],[6,6,6]]])) - expect(Array.from(result)).toEqual([2,2,3,255,2,1,4,255,5,5,5,255,6,6,6,255]); + it('returns an array', () => { + const result = tensorAsClampedArray(tf.tensor([[[2, 2, 3], [2, 1, 4], [5,5,5],[6,6,6], [7,7,7],[8,8,8]]])) + expect(Array.from(result)).toEqual([2,2,3,255,2,1,4,255,5,5,5,255,6,6,6,255,7,7,7,255,8,8,8,255]); }); - it('returns a clamped array', async () => { - const result = await tensorAsClampedArray(tf.tensor([[[-100, 2, 3], [256, 1, 4], [500,5,5],[6,6,6]]])) + it('returns a clamped array', () => { + const result = tensorAsClampedArray(tf.tensor([[[-100, 2, 3], [256, 1, 4], [500,5,5],[6,6,6]]])) expect(Array.from(result)).toEqual([0,2,3,255,255,1,4,255,255,5,5,255,6,6,6,255]); }); }); diff --git a/packages/upscalerjs/src/utils.ts b/packages/upscalerjs/src/utils.ts index 021be6a25..c52ca8c84 100644 --- a/packages/upscalerjs/src/utils.ts +++ b/packages/upscalerjs/src/utils.ts @@ -86,20 +86,8 @@ export async function wrapGenerator export function isModelDefinitionFn (modelDefinition: ModelDefinitionObjectOrFn): modelDefinition is ModelDefinitionFn { return typeof modelDefinition === 'function'; } -export const tensorAsClampedArray = async (tensor: tf.Tensor3D) => { - const [height, width, ] = tensor.shape; - const arr = new Uint8ClampedArray(width * height * 4); - const data = await tensor.data(); - let i = 0; - for (let y = 0; y < height; y++) { - for (let x = 0; x < width; x++) { - const pos = (y * width + x) * 4; - arr[pos] = data[i]; // R - arr[pos + 1] = data[i + 1]; // G - arr[pos + 2] = data[i + 2]; // B - arr[pos + 3] = 255; // Alpha - i += 3; - } - } - return arr; -}; +export const tensorAsClampedArray = (tensor: tf.Tensor3D) => tf.tidy(() => { + const [height, width,] = tensor.shape; + const fill = tf.fill([height, width,], 255).expandDims(2); + return tensor.clipByValue(0, 255).concat([fill,], 2).dataSync(); +});