From 110605ee7eaac2917b5467adccee8732696afe80 Mon Sep 17 00:00:00 2001 From: Qin Jiajia Date: Fri, 12 Nov 2021 13:50:47 +0800 Subject: [PATCH] webgl: Support uniforms for pack/unpack programs BUG: 5205 --- tfjs-backend-webgl/src/pack_gpu.ts | 129 ++++++++++++++------------- tfjs-backend-webgl/src/unpack_gpu.ts | 4 +- 2 files changed, 70 insertions(+), 63 deletions(-) diff --git a/tfjs-backend-webgl/src/pack_gpu.ts b/tfjs-backend-webgl/src/pack_gpu.ts index cedc6ffb442..695d3ad664a 100644 --- a/tfjs-backend-webgl/src/pack_gpu.ts +++ b/tfjs-backend-webgl/src/pack_gpu.ts @@ -15,7 +15,7 @@ * ============================================================================= */ -import {GPGPUProgram} from './gpgpu_math'; +import {GPGPUProgram, useShapeUniforms} from './gpgpu_math'; import {getChannels} from './packing_util'; import {getCoordsDataType} from './shader_compiler'; @@ -25,29 +25,30 @@ export class PackProgram implements GPGPUProgram { userCode: string; packedInputs = false; packedOutput = true; + enableShapeUniforms: boolean; + rank: number; constructor( outputShape: number[]) { // TODO(https://github.com/tensorflow/tfjs/issues/893): // Only input / output 3D tensors. this.outputShape = outputShape; - const rank = outputShape.length; + this.rank = outputShape.length; + this.enableShapeUniforms = useShapeUniforms(this.outputShape.length); - if (rank === 0) { + if (this.rank === 0) { this.userCode = ` void main() { setOutput(vec4(getA(), 0., 0., 0.)); } `; } else { - const channels = getChannels('rc', rank); - const dtype = getCoordsDataType(rank); + const channels = getChannels('rc', this.rank); + const dtype = getCoordsDataType(this.rank); const outOfBoundsCondition = - getOutOfBoundsCondition(rank, outputShape, channels); - const setup = getSetup( - rank, outputShape[outputShape.length - 1], - outputShape[outputShape.length - 2], channels); - const output = getOutput(outputShape, channels); + this.getOutOfBoundsCondition(channels); + const setup = this.getSetup(channels); + const output = this.getOutput(channels); this.userCode = ` void main() { @@ -64,72 +65,76 @@ export class PackProgram implements GPGPUProgram { `; } } -} -function getSourceCoordsArr(rank: number, dims: string[]): string[] { - const coords = []; + private getSourceCoordsArr(dims: string[]): string[] { + const coords = []; - for (let row = 0; row <= 1; row++) { - for (let col = 0; col <= 1; col++) { - let coord = `${row === 0 ? 'r' : 'rp1'}, ${col === 0 ? 'c' : 'cp1'}`; + for (let row = 0; row <= 1; row++) { + for (let col = 0; col <= 1; col++) { + let coord = `${row === 0 ? 'r' : 'rp1'}, ${col === 0 ? 'c' : 'cp1'}`; - for (let d = 2; d < rank; d++) { - coord = `${dims[dims.length - 1 - d]},` + coord; - } + for (let d = 2; d < this.rank; d++) { + coord = `${dims[dims.length - 1 - d]},` + coord; + } - coords.push(coord); + coords.push(coord); + } } - } - return coords; -} - -function getOutOfBoundsCondition( - rank: number, shape: number[], dims: string[]): string { - if (rank === 1) { - return `rc > ${shape[0]}`; + return coords; } - let cond = ''; - for (let i = rank - 2; i < rank; i++) { - cond += `${dims[i]} >= ${shape[i]}`; - if (i < rank - 1) { - cond += '||'; + private getOutOfBoundsCondition(dims: string[]): string { + if (this.rank === 1) { + return `rc > ${this.enableShapeUniforms? 'outShape' : + this.outputShape[0]}`; } - } - return cond; -} + let cond = ''; + for (let i = this.rank - 2; i < this.rank; i++) { + cond += `${dims[i]} >= ${this.enableShapeUniforms? `outShape[${i}]` : + this.outputShape[i]}`; + if (i < this.rank - 1) { + cond += '||'; + } + } -function getSetup( - rank: number, cols: number, rows: number, dims: string[]): string { - if (rank === 1) { - return ''; + return cond; } - const innerDims = dims.slice(-2); + private getSetup(dims: string[]): string { + if (this.rank === 1) { + return ''; + } - return ` - int r = ${innerDims[0]}; - int c = ${innerDims[1]}; - int rp1 = r + 1; - int cp1 = c + 1; + const innerDims = dims.slice(-2); + const col = this.enableShapeUniforms? `outShape[${this.rank} - 1]` : + this.outputShape[this.rank - 1]; + const row = this.enableShapeUniforms? `outShape[${this.rank} - 2]` : + this.outputShape[this.rank - 2]; + + return ` + int r = ${innerDims[0]}; + int c = ${innerDims[1]}; + int rp1 = r + 1; + int cp1 = c + 1; + + bool cEdge = cp1 >= ${col}; + bool rEdge = rp1 >= ${row}; + `; + } - bool cEdge = cp1 >= ${cols}; - bool rEdge = rp1 >= ${rows}; - `; -} + private getOutput(dims: string[]): string { + const sourceCoords = this.getSourceCoordsArr(dims); + if (this.rank === 1) { + return `getA(rc), + rc + 1 >= ${this.enableShapeUniforms? 'outShape' : + this.outputShape[0]} ? 0. : getA(rc + 1), + 0, 0`; + } -function getOutput(shape: number[], dims: string[]): string { - const rank = shape.length; - const sourceCoords = getSourceCoordsArr(rank, dims); - if (rank === 1) { - return `getA(rc), - rc + 1 >= ${shape[0]} ? 0. : getA(rc + 1), - 0, 0`; + return `getA(${sourceCoords[0]}), + cEdge ? 0. : getA(${sourceCoords[1]}), + rEdge ? 0. : getA(${sourceCoords[2]}), + rEdge || cEdge ? 0. : getA(${sourceCoords[3]})`; } - - return `getA(${sourceCoords[0]}), - cEdge ? 0. : getA(${sourceCoords[1]}), - rEdge ? 0. : getA(${sourceCoords[2]}), - rEdge || cEdge ? 0. : getA(${sourceCoords[3]})`; } diff --git a/tfjs-backend-webgl/src/unpack_gpu.ts b/tfjs-backend-webgl/src/unpack_gpu.ts index b3dc4adf0e3..79bb3447578 100644 --- a/tfjs-backend-webgl/src/unpack_gpu.ts +++ b/tfjs-backend-webgl/src/unpack_gpu.ts @@ -15,7 +15,7 @@ * ============================================================================= */ -import {GPGPUProgram} from './gpgpu_math'; +import {GPGPUProgram, useShapeUniforms} from './gpgpu_math'; import {getChannels, getSourceCoords} from './packing_util'; import {getCoordsDataType} from './shader_compiler'; @@ -25,9 +25,11 @@ export class UnpackProgram implements GPGPUProgram { packedOutput = false; outputShape: number[]; userCode: string; + enableShapeUniforms: boolean; constructor(outputShape: number[]) { this.outputShape = outputShape; + this.enableShapeUniforms = useShapeUniforms(this.outputShape.length); const rank = outputShape.length; const channels = getChannels('rc', rank);