Skip to content

Commit

Permalink
webgl: Add shapes uniforms to reduce shader compilation time
Browse files Browse the repository at this point in the history
PERF
Fix tensorflow#5205

This PR adds support for shape uniforms and enables it for unary/binary
ops.
  • Loading branch information
qjia7 committed Jun 22, 2021
1 parent 2d16dc9 commit 8ac5bc2
Show file tree
Hide file tree
Showing 8 changed files with 781 additions and 83 deletions.
4 changes: 3 additions & 1 deletion tfjs-backend-webgl/src/binaryop_gpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

import {backend_util} from '@tensorflow/tfjs-core';

import {GPGPUProgram} from './gpgpu_math';
import {GPGPUProgram, useShapeUniforms} from './gpgpu_math';

export const CHECK_NAN_SNIPPET = `
if (isnan(a)) return a;
Expand All @@ -29,9 +29,11 @@ export class BinaryOpProgram implements GPGPUProgram {
variableNames = ['A', 'B'];
outputShape: number[];
userCode: string;
enableShapeUniforms: boolean;

constructor(op: string, aShape: number[], bShape: number[]) {
this.outputShape = backend_util.assertAndGetBroadcastShape(aShape, bShape);
this.enableShapeUniforms = useShapeUniforms(this.outputShape.length);
this.userCode = `
float binaryOperation(float a, float b) {
${op}
Expand Down
28 changes: 25 additions & 3 deletions tfjs-backend-webgl/src/binaryop_packed_gpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

import {backend_util, util} from '@tensorflow/tfjs-core';

import {GPGPUProgram} from './gpgpu_math';
import {GPGPUProgram, useShapeUniforms} from './gpgpu_math';
import {getChannels} from './packing_util';
import {getCoordsDataType} from './shader_compiler';

Expand All @@ -44,12 +44,14 @@ export class BinaryOpPackedProgram implements GPGPUProgram {
supportsBroadcasting = true;
packedInputs = true;
packedOutput = true;
enableShapeUniforms: boolean;

constructor(
op: string, aShape: number[], bShape: number[],
checkOutOfBounds = false) {
this.outputShape = backend_util.assertAndGetBroadcastShape(aShape, bShape);
const rank = this.outputShape.length;
this.enableShapeUniforms = useShapeUniforms(rank);
let checkOutOfBoundsString = '';
if (checkOutOfBounds) {
if (rank === 0 || util.sizeFromShape(this.outputShape) === 1) {
Expand All @@ -64,14 +66,33 @@ export class BinaryOpPackedProgram implements GPGPUProgram {
${dtype} coords = getOutputCoords();
`;
if (rank === 1) {
checkOutOfBoundsString += `
if (this.enableShapeUniforms) {
checkOutOfBoundsString += `
result.y = (coords + 1) >= outShape ? 0. : result.y;
result.z = 0.;
result.w = 0.;
`;
} else {
checkOutOfBoundsString += `
result.y = (coords + 1) >= ${this.outputShape[0]} ? 0. : result.y;
result.z = 0.;
result.w = 0.;
`;
}
} else {
const channels = getChannels('coords', rank);
checkOutOfBoundsString += `
if (this.enableShapeUniforms) {
checkOutOfBoundsString += `
bool nextRowOutOfBounds =
(${channels[rank - 2]} + 1) >= outShape[${rank} - 2];
bool nextColOutOfBounds =
(${channels[rank - 1]} + 1) >= outShape[${rank} - 1];
result.y = nextColOutOfBounds ? 0. : result.y;
result.z = nextRowOutOfBounds ? 0. : result.z;
result.w = nextColOutOfBounds || nextRowOutOfBounds ? 0. : result.w;
`;
} else {
checkOutOfBoundsString += `
bool nextRowOutOfBounds =
(${channels[rank - 2]} + 1) >= ${this.outputShape[rank - 2]};
bool nextColOutOfBounds =
Expand All @@ -80,6 +101,7 @@ export class BinaryOpPackedProgram implements GPGPUProgram {
result.z = nextRowOutOfBounds ? 0. : result.z;
result.w = nextColOutOfBounds || nextRowOutOfBounds ? 0. : result.w;
`;
}
}
}
}
Expand Down
3 changes: 3 additions & 0 deletions tfjs-backend-webgl/src/flags_webgl.ts
Original file line number Diff line number Diff line change
Expand Up @@ -216,3 +216,6 @@ ENV.registerFlag(
* Default value is 128.
*/
ENV.registerFlag('CPU_HANDOFF_SIZE_THRESHOLD', () => 128);

/**
 * Whether WebGL programs may use shape uniforms.
 *
 * This flag is read by `useShapeUniforms` in gpgpu_math.ts, which additionally
 * requires the rank to be at most 4.
 * Default value is true.
 */
ENV.registerFlag('WEBGL_USE_SHAPES_UNIFORMS', () => true);
155 changes: 145 additions & 10 deletions tfjs-backend-webgl/src/gpgpu_math.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
* =============================================================================
*/

import {env, Tensor, TypedArray, util} from '@tensorflow/tfjs-core';
import {backend_util, env, Tensor, TypedArray, util} from '@tensorflow/tfjs-core';

import {GPGPUContext} from './gpgpu_context';
import * as shader_compiler from './shader_compiler';
Expand All @@ -26,6 +26,7 @@ export interface GPGPUProgram {
variableNames: string[];
outputShape: number[];
userCode: string;
enableShapeUniforms?: boolean;
/** If true, this program expects packed input textures. Defaults to false. */
packedInputs?: boolean;
/** If true, this program produces a packed texture. Defaults to false. */
Expand All @@ -51,6 +52,11 @@ export interface GPGPUBinary {
outShapeInfo: ShapeInfo;
infLoc: WebGLUniformLocation;
nanLoc: WebGLUniformLocation;
inShapesLocations?: {[name: string]: WebGLUniformLocation};
inTexShapesLocations?: {[name: string]: WebGLUniformLocation};
outShapeLocation?: WebGLUniformLocation;
outShapeStridesLocation?: WebGLUniformLocation;
outTexShapeLocation?: WebGLUniformLocation;
}

export interface TensorData {
Expand All @@ -64,7 +70,6 @@ export interface TensorData {
export function compileProgram<T extends Tensor, K extends Tensor>(
gpgpu: GPGPUContext, program: GPGPUProgram, inputs: TensorData[],
output: TensorData): GPGPUBinary {
const userCode = program.userCode;
const inputInfos: InputInfo[] = inputs.map((input, i) => {
const shapeInfo: ShapeInfo = {
logicalShape: input.shape,
Expand All @@ -87,8 +92,7 @@ export function compileProgram<T extends Tensor, K extends Tensor>(
isPacked: output.texData.isPacked,
flatOffset: null
};
const source = shader_compiler.makeShader(
inputInfos, outShapeInfo, userCode, program.packedInputs);
const source = shader_compiler.makeShader(inputInfos, outShapeInfo, program);

const webGLProgram = gpgpu.createProgram(source);

Expand All @@ -100,14 +104,34 @@ export function compileProgram<T extends Tensor, K extends Tensor>(
}

// Add user-defined uniforms
const shouldThrow = false;
const uniformLocations: {[name: string]: WebGLUniformLocation} = {};
const inShapesLocations: {[name: string]: WebGLUniformLocation} = {};
const inTexShapesLocations: {[name: string]: WebGLUniformLocation} = {};
for (let i = 0; i < program.variableNames.length; i++) {
const varName = program.variableNames[i];
const shouldThrow = false;
uniformLocations[varName] =
gpgpu.getUniformLocation(webGLProgram, varName, shouldThrow);
uniformLocations[`offset${varName}`] =
gpgpu.getUniformLocation(webGLProgram, `offset${varName}`, shouldThrow);
if (program.enableShapeUniforms) {
inShapesLocations[`${varName}Shape`] = gpgpu.getUniformLocation(
webGLProgram, `${varName}Shape`, shouldThrow);
inTexShapesLocations[`${varName}TexShape`] = gpgpu.getUniformLocation(
webGLProgram, `${varName}TexShape`, shouldThrow);
}
}

let outShapeLocation: WebGLUniformLocation;
let outTexShapeLocation: WebGLUniformLocation;
let outShapeStridesLocation: WebGLUniformLocation;
if (program.enableShapeUniforms) {
outShapeLocation =
gpgpu.getUniformLocation(webGLProgram, 'outShape', shouldThrow);
outShapeStridesLocation =
gpgpu.getUniformLocation(webGLProgram, 'outShapeStrides', shouldThrow);
outTexShapeLocation =
gpgpu.getUniformLocation(webGLProgram, 'outTexShape', shouldThrow);
}

return {
Expand All @@ -119,6 +143,11 @@ export function compileProgram<T extends Tensor, K extends Tensor>(
outShapeInfo,
infLoc,
nanLoc,
inShapesLocations,
inTexShapesLocations,
outShapeLocation,
outShapeStridesLocation,
outTexShapeLocation
};
}

Expand Down Expand Up @@ -160,8 +189,10 @@ export function runProgram<T extends Tensor, K extends Tensor>(
output: TensorData,
customSetup?: (gpgpu: GPGPUContext, webGLProgram: WebGLProgram) =>
void): void {
validateBinaryAndProgram(binary.inShapeInfos, inputs);
validateBinaryAndProgram([binary.outShapeInfo], [output]);
if (!binary.program.enableShapeUniforms) {
validateBinaryAndProgram(binary.inShapeInfos, inputs);
validateBinaryAndProgram([binary.outShapeInfo], [output]);
}

const outTex = output.texData.texture;
const outTexShape = output.texData.texShape;
Expand All @@ -187,6 +218,33 @@ export function runProgram<T extends Tensor, K extends Tensor>(
const varName = binary.program.variableNames[i];
const varLoc = binary.uniformLocations[varName];
const varOffsetLoc = binary.uniformLocations[`offset${varName}`];
const varShapeLoc = binary.inShapesLocations[`${varName}Shape`];
const varTexShapeLoc = binary.inTexShapesLocations[`${varName}TexShape`];

if (varShapeLoc) {
const {uniformShape} = shader_compiler.getUniformInfoFromShape(
binary.program.packedInputs, input.shape, input.texData.texShape);
switch (uniformShape.length) {
case 1:
gpgpu.gl.uniform1iv(varShapeLoc, new Int32Array(uniformShape));
break;
case 2:
gpgpu.gl.uniform2iv(varShapeLoc, new Int32Array(uniformShape));
break;
case 3:
gpgpu.gl.uniform3iv(varShapeLoc, new Int32Array(uniformShape));
break;
case 4:
gpgpu.gl.uniform4iv(varShapeLoc, new Int32Array(uniformShape));
break;
default:
break;
}
}
if (varTexShapeLoc) {
gpgpu.gl.uniform2i(
varTexShapeLoc, input.texData.texShape[0], input.texData.texShape[1]);
}

if (varLoc == null) {
// The compiler inferred that this variable is not used in this shader.
Expand Down Expand Up @@ -215,6 +273,50 @@ export function runProgram<T extends Tensor, K extends Tensor>(
gpgpu.setInputMatrixTexture(input.texData.texture, varLoc, i);
});

const outShapeLoc = binary.outShapeLocation;
if (outShapeLoc) {
switch (output.shape.length) {
case 1:
gpgpu.gl.uniform1iv(outShapeLoc, new Int32Array(output.shape));
break;
case 2:
gpgpu.gl.uniform2iv(outShapeLoc, new Int32Array(output.shape));
break;
case 3:
gpgpu.gl.uniform3iv(outShapeLoc, new Int32Array(output.shape));
break;
case 4:
gpgpu.gl.uniform4iv(outShapeLoc, new Int32Array(output.shape));
break;
default:
break;
}
}
if (binary.outShapeStridesLocation) {
const strides = util.computeStrides(output.shape);
switch (output.shape.length) {
case 2:
gpgpu.gl.uniform1iv(
binary.outShapeStridesLocation, new Int32Array(strides));
break;
case 3:
gpgpu.gl.uniform2iv(
binary.outShapeStridesLocation, new Int32Array(strides));
break;
case 4:
gpgpu.gl.uniform3iv(
binary.outShapeStridesLocation, new Int32Array(strides));
break;
default:
break;
}
}
if (binary.outTexShapeLocation) {
gpgpu.gl.uniform2i(
binary.outTexShapeLocation, output.texData.texShape[0],
output.texData.texShape[1]);
}

if (customSetup != null) {
customSetup(gpgpu, binary.webGLProgram);
}
Expand All @@ -227,12 +329,45 @@ export function makeShaderKey(
inputs.concat(output).forEach(x => {
const hasOffset = x.texData != null && x.texData.slice != null &&
x.texData.slice.flatOffset > 0;
const texShape = x.isUniform ? 'uniform' : x.texData.texShape;
keyInputs += `${x.shape}_${texShape}_${hasOffset}`;
if (program.enableShapeUniforms && !x.isUniform) {
const {useSqueezeShape, uniformShape} =
shader_compiler.getUniformInfoFromShape(
program.packedInputs, x.shape, x.texData.texShape);
let rank1 = '', rank2 = '', rank34 = '';
if (uniformShape.length === 1 && program.packedInputs) {
const packedTexShape = [
Math.ceil(x.texData.texShape[0] / 2),
Math.ceil(x.texData.texShape[1] / 2)
];
rank1 = `${packedTexShape[0] > 1}_${packedTexShape[1] > 1}`;
} else if (uniformShape.length === 2) {
rank2 = `${uniformShape[0] > 1}_${uniformShape[1] > 1}`;
} else if (uniformShape.length > 2 && !program.packedInputs) {
const strides = util.computeStrides(uniformShape);
rank34 = `${strides[0] === x.texData.texShape[1]}_${
strides[strides.length - 1] === x.texData.texShape[1]}`;
}
const isScalar = util.sizeFromShape(x.shape) === 1;
const broadcastDims =
backend_util.getBroadcastDims(x.shape, output.shape);
keyInputs += `${x.shape.length}_${useSqueezeShape}_${
uniformShape.length}_${isScalar}_${broadcastDims}_${
util.arraysEqual(
x.shape, x.texData.texShape)}_${rank1}_${rank2}_${rank34}_${
x.texData.texShape[0] > 1}_${x.texData.texShape[1] > 1}_${hasOffset}`;
} else {
const texShape = x.isUniform ? 'uniform' : x.texData.texShape;
keyInputs += `${x.shape}_${texShape}_${hasOffset}`;
}
});
const keyUserCode = program.userCode;
let key = program.constructor.name;
// Fast string concat. See https://jsperf.com/string-concatenation/14.
key += '_' + keyInputs + '_' + keyUserCode;
// Append the WebGL version reported by the environment to the key. The value
// must be interpolated with ${...}; a bare backtick string would append the
// literal text "env().getNumber('WEBGL_VERSION')" rather than the version.
key += '_' + keyInputs + '_' + keyUserCode +
`${env().getNumber('WEBGL_VERSION')}`;
return key;
}

/**
 * Reports whether a program with the given rank should use shape uniforms.
 *
 * @param rank Rank (number of dimensions) to check.
 * @returns true only when the `WEBGL_USE_SHAPES_UNIFORMS` flag is enabled and
 *     `rank` is at most 4.
 */
export function useShapeUniforms(rank: number) {
const flagEnabled = env().getBool('WEBGL_USE_SHAPES_UNIFORMS');
if (!flagEnabled) {
return flagEnabled;
}
return rank <= 4;
}
Loading

0 comments on commit 8ac5bc2

Please sign in to comment.