Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor utility function #387

Merged
merged 10 commits into from
Jul 24, 2022
5 changes: 3 additions & 2 deletions packages/upscalerjs/src/image.browser.ts
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,9 @@ export const isHTMLImageElement = (pixels: GetImageAsTensorInput): pixels is HTM
}
};

export const tensorAsBase64 = async (tensor: tf.Tensor3D) => {
const arr = await tensorAsClampedArray(tensor);
export const tensorAsBase64 = (tensor: tf.Tensor3D): string => {
const arr = tensorAsClampedArray(tensor);
tensor.dispose();
const [height, width, ] = tensor.shape;
const imageData = new ImageData(width, height);
imageData.data.set(arr);
Expand Down
5 changes: 3 additions & 2 deletions packages/upscalerjs/src/image.node.ts
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,9 @@ export const getImageAsTensor = async (
throw getInvalidTensorError(tensor);
};

export const tensorAsBase64 = async (tensor: tf.Tensor3D) => {
const arr = await tensorAsClampedArray(tensor);
export const tensorAsBase64 = (tensor: tf.Tensor3D): string => {
const arr = tensorAsClampedArray(tensor);
tensor.dispose();
return Buffer.from(arr).toString('base64');
};

9 changes: 4 additions & 5 deletions packages/upscalerjs/src/upscale.test.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import { tf } from './dependencies.generated';
// import { tf } from './dependencies.generated';
import * as tf from '@tensorflow/tfjs-node';
import {
AbortError,
predict,
Expand Down Expand Up @@ -32,8 +33,6 @@ jest.mock('./image.generated', () => ({
const mockedImage = image as jest.Mocked<typeof image>;
// const mockedTensorAsBase = tensorAsBase as jest.Mocked<typeof tensorAsBase>;

console.log(mockedImage);

describe('concatTensors', () => {
beforeEach(() => {
try {
Expand Down Expand Up @@ -1218,7 +1217,7 @@ describe('predict', () => {
it('should invoke progress callback with percent and slice', async () => {
console.warn = jest.fn();
const mockResponse = 'foobarbaz';
(mockedImage as any).default.tensorAsBase64 = async() => mockResponse;
(mockedImage as any).default.tensorAsBase64 = () => mockResponse;
const img: tf.Tensor4D = tf.ones([4, 2, 3,]).expandDims(0);
const scale = 2;
const patchSize = 2;
Expand Down Expand Up @@ -1543,7 +1542,7 @@ describe('upscale', () => {
const model = {
predict: jest.fn(() => tf.ones([1, 2, 2, 3,])),
} as unknown as tf.LayersModel;
(mockedImage as any).default.tensorAsBase64 = async () => 'foobarbaz';
(mockedImage as any).default.tensorAsBase64 = () => 'foobarbaz';
const result = await wrapGenerator(upscale(img, {}, {
model,
modelDefinition: { scale: 2, } as ModelDefinition,
Expand Down
7 changes: 5 additions & 2 deletions packages/upscalerjs/src/upscale.ts
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,8 @@ export async function* predict<P extends Progress<O, PO>, O extends ResultFormat
modelDefinition,
}: UpscaleInternalArgs
): AsyncGenerator<YieldedIntermediaryValue, tf.Tensor3D> {
// TODO: Remove this
await Promise.resolve();
const scale = modelDefinition.scale;

if (originalPatchSize && padding === undefined) {
Expand Down Expand Up @@ -328,7 +330,8 @@ export async function* predict<P extends Progress<O, PO>, O extends ResultFormat
(<MultiArgProgress<'tensor'>>progress)(percent, squeezedTensor);
} else {
// because we are returning a string, we can safely dispose of our tensor
const src = await tensorAsBase64(squeezedTensor);
const src = tensorAsBase64(squeezedTensor);
console.log('what is src', src);
squeezedTensor.dispose();
(<MultiArgProgress<'src'>>progress)(percent, src);
}
Expand Down Expand Up @@ -431,7 +434,7 @@ export async function* upscale<P extends Progress<O, PO>, O extends ResultFormat
return <UpscaleResponse<O>>postprocessedPixels;
}

const base64Src = await tensorAsBase64(postprocessedPixels);
const base64Src = tensorAsBase64(postprocessedPixels);
postprocessedPixels.dispose();
return <UpscaleResponse<O>>base64Src;
}
Expand Down
10 changes: 5 additions & 5 deletions packages/upscalerjs/src/utils.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -286,13 +286,13 @@ describe('isTensor', () => {
});

describe('tensorAsClampedArray', () => {
it('returns an array', async () => {
const result = await tensorAsClampedArray(tf.tensor([[[2, 2, 3], [2, 1, 4], [5,5,5],[6,6,6]]]))
expect(Array.from(result)).toEqual([2,2,3,255,2,1,4,255,5,5,5,255,6,6,6,255]);
it('returns an array', () => {
const result = tensorAsClampedArray(tf.tensor([[[2, 2, 3], [2, 1, 4], [5,5,5],[6,6,6], [7,7,7],[8,8,8]]]))
expect(Array.from(result)).toEqual([2,2,3,255,2,1,4,255,5,5,5,255,6,6,6,255,7,7,7,255,8,8,8,255]);
});

it('returns a clamped array', async () => {
const result = await tensorAsClampedArray(tf.tensor([[[-100, 2, 3], [256, 1, 4], [500,5,5],[6,6,6]]]))
it('returns a clamped array', () => {
const result = tensorAsClampedArray(tf.tensor([[[-100, 2, 3], [256, 1, 4], [500,5,5],[6,6,6]]]))
expect(Array.from(result)).toEqual([0,2,3,255,255,1,4,255,255,5,5,255,6,6,6,255]);
});
});
22 changes: 5 additions & 17 deletions packages/upscalerjs/src/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -86,20 +86,8 @@ export async function wrapGenerator<T = unknown, TReturn = any, TNext = unknown>

export function isModelDefinitionFn (modelDefinition: ModelDefinitionObjectOrFn): modelDefinition is ModelDefinitionFn { return typeof modelDefinition === 'function'; }

export const tensorAsClampedArray = async (tensor: tf.Tensor3D) => {
const [height, width, ] = tensor.shape;
const arr = new Uint8ClampedArray(width * height * 4);
const data = await tensor.data();
let i = 0;
for (let y = 0; y < height; y++) {
for (let x = 0; x < width; x++) {
const pos = (y * width + x) * 4;
arr[pos] = data[i]; // R
arr[pos + 1] = data[i + 1]; // G
arr[pos + 2] = data[i + 2]; // B
arr[pos + 3] = 255; // Alpha
i += 3;
}
}
return arr;
};
export const tensorAsClampedArray = (tensor: tf.Tensor3D) => tf.tidy(() => {
const [height, width,] = tensor.shape;
const fill = tf.fill([height, width,], 255).expandDims(2);
return tensor.clipByValue(0, 255).concat([fill,], 2).dataSync();
});