Skip to content

Commit

Permalink
[js/webgpu] Support uniforms for conv, conv transpose, conv grouped
Browse files Browse the repository at this point in the history
  • Loading branch information
axinging committed Dec 7, 2023
1 parent 9479ba5 commit 7e77cc2
Show file tree
Hide file tree
Showing 7 changed files with 219 additions and 184 deletions.
10 changes: 5 additions & 5 deletions js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ import {biasAdd} from './ops/bias-add';
import {biasSplitGelu} from './ops/bias-split-gelu';
import * as binaryOps from './ops/binary-op';
import {concat, parseConcatAttributes} from './ops/concat';
import {conv, parseConvAttributes} from './ops/conv';
import {convTranspose, parseConvTransposeAttributes} from './ops/conv-transpose';
import {conv} from './ops/conv';
import {convTranspose} from './ops/conv-transpose';
import {cumsum, parseCumSumAttributes} from './ops/cumsum';
import {einsum, parseEinsumAttributes} from './ops/einsum';
import {expand} from './ops/expand';
Expand Down Expand Up @@ -60,8 +60,8 @@ export const WEBGPU_OP_RESOLVE_RULES: Map<string, OperatorImplementation> = new
['Ceil', [unaryOps.ceil]],
['Clip', [unaryOps.clip]],
['Concat', [concat, parseConcatAttributes]],
['Conv', [conv, parseConvAttributes]],
['ConvTranspose', [convTranspose, parseConvTransposeAttributes]],
['Conv', [conv]],
['ConvTranspose', [convTranspose]],
['Cos', [unaryOps.cos]],
['Cosh', [unaryOps.cosh]],
['CumSum', [cumsum, parseCumSumAttributes]],
Expand All @@ -73,7 +73,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map<string, OperatorImplementation> = new
['Exp', [unaryOps.exp]],
['Expand', [expand]],
['Floor', [unaryOps.floor]],
['FusedConv', [conv, parseConvAttributes]],
['FusedConv', [conv]],
['Gather', [gather, parseGatherAttributes]],
['GatherElements', [gatherElements, parseGatherElementsAttributes]],
['Gelu', [unaryOps.gelu]],
Expand Down
42 changes: 22 additions & 20 deletions js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@

import {LOG_DEBUG} from '../../../log';
import {TensorView} from '../../../tensor-view';
import {ProgramInfo, ProgramUniform} from '../../types';
import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType} from '../common';
import {ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../../types';
import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType, UniformsArrayType} from '../common';
import {ConvAttributes} from '../conv';
import {getActivationSnippet} from '../fuse-utils';

Expand Down Expand Up @@ -88,10 +88,10 @@ const conv2dCommonSnippet =
let outRow = ${row} / outWidth;
let outCol = ${row} % outWidth;
let WRow = ${col} / (filterDims[1] * inChannels);
let WCol = ${col} / inChannels % filterDims[1];
let xRow = outRow * stride[0] + dilation[0] * WRow - pad[0];
let xCol = outCol * stride[1] + dilation[1] * WCol - pad[1];
let WRow = ${col} / (i32(uniforms.w_shape[1]) * inChannels);
let WCol = ${col} / inChannels % i32(uniforms.w_shape[1]);
let xRow = outRow * uniforms.stride[0] + uniforms.dilation[0] * WRow - uniforms.pad[0];
let xCol = outCol * uniforms.stride[1] + uniforms.dilation[1] * WCol - uniforms.pad[1];
let xCh = ${col} % inChannels;
var resData = ${typeSnippet(innerElementSizeX, dataType)}(0.0);
// The bounds checking is always needed since we use it to pad zero for
Expand Down Expand Up @@ -195,15 +195,18 @@ export const createConv2DMatMulProgramInfo =

// TODO: support component 2, 3.
const components = isVec4 ? 4 : 1;
const programUniforms: ProgramUniform[] =
[{type: 'int32', data: dimAOuter}, {type: 'int32', data: dimBOuter}, {type: 'int32', data: dimInner}];
const programUniforms: ProgramUniform[] = [
{type: 'int32', data: dimAOuter}, {type: 'int32', data: dimBOuter}, {type: 'int32', data: dimInner},
{type: 'int32', data: [attributes.pads[0], attributes.pads[1]]}, {type: 'int32', data: attributes.strides},
{type: 'int32', data: attributes.dilations}
];
const x =
inputVariable('x', inputs[0].dataType, inputs[0].dims.length, innerElementSize === 3 ? 1 : innerElementSize);
const w = inputVariable('w', inputs[1].dataType, inputs[1].dims.length, components);
const inputVariables = [x, w];

programUniforms.push(...createTensorShapeVariables(inputs[0].dims));
programUniforms.push(...createTensorShapeVariables(inputs[1].dims));
programUniforms.push(...createTensorShapeVariables(inputs[0].dims), ...createTensorShapeVariables(inputs[1].dims));
const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'rank'];

let declareFunctions = `
fn setOutputAtIndex(flatIndex : i32, value : ${isVec4 ? `vec4<${t}>` : t}) {
Expand All @@ -218,6 +221,7 @@ export const createConv2DMatMulProgramInfo =
inputVariables.push(bias);

programUniforms.push(...createTensorShapeVariables(inputs[2].dims));
inputDependencies.push('rank');

declareFunctions += `
fn getBiasByOutputCoords(coords : vec4<i32>) -> ${isVec4 ? `vec4<${t}>` : t} {
Expand All @@ -226,9 +230,15 @@ export const createConv2DMatMulProgramInfo =
}
const output = outputVariable('result', inputs[0].dataType, outputShape.length, components);
programUniforms.push(...createTensorShapeVariables(outputShape));

const uniforms: UniformsArrayType = [
{name: 'dimAOuter', type: 'i32'}, {name: 'dimBOuter', type: 'i32'}, {name: 'dimInner', type: 'i32'},
{name: 'pad', type: 'i32', length: 2}, {name: 'stride', type: 'i32', length: 2},
{name: 'dilation', type: 'i32', length: 2}
];
return {
name: 'Conv2DMatMul',
shaderCache: {hint: attributes.cacheKey},
shaderCache: {hint: `${attributes.format}`, inputDependencies},
getRunData: () => ({
outputs: [{dims: outputShape, dataType: inputs[0].dataType}],
dispatchGroup: {x: dispatch[0], y: dispatch[1], z: dispatch[2]},
Expand All @@ -239,15 +249,7 @@ export const createConv2DMatMulProgramInfo =
//struct Uniforms { xShape : vec4<i32>, wShape : vec4<i32>, outShape : vec4<i32>,
// outShapeStrides: vec3<i32>, filterDims : vec2<i32>, pad : vec2<i32>, stride : vec2<i32>,
// dilation : vec2<i32>, dimAOuter : i32, dimBOuter : i32, dimInner : i32 };
${
shaderHelper.registerUniform('dimAOuter', 'i32')
.registerUniform('dimBOuter', 'i32')
.registerUniform('dimInner', 'i32')
.declareVariables(...inputVariables, output)}
const filterDims : vec2<i32> = vec2<i32>(${attributes.kernelShape[0]}, ${attributes.kernelShape[1]});
const pad : vec2<i32> = vec2<i32>(${attributes.pads[0]}, ${attributes.pads[1]});
const stride : vec2<i32> = vec2<i32>(${attributes.strides[0]}, ${attributes.strides[1]});
const dilation : vec2<i32> = vec2<i32>(${attributes.dilations[0]}, ${attributes.dilations[1]});
${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVariables, output)}
${declareFunctions}
${
conv2dCommonSnippet(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@

import {LOG_DEBUG} from '../../../log';
import {TensorView} from '../../../tensor-view';
import {ProgramInfo, ProgramUniform} from '../../types';
import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper} from '../common';
import {ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../../types';
import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper, UniformsArrayType} from '../common';
import {ConvTransposeAttributes} from '../conv-transpose';
import {getActivationSnippet} from '../fuse-utils';

Expand Down Expand Up @@ -74,21 +74,21 @@ const conv2dTransposeCommonSnippet =
col % outWidth);
`;

const xHeight = isChannelsLast ? 'outBackprop[1]' : 'outBackprop[2]';
const xWidth = isChannelsLast ? 'outBackprop[2]' : 'outBackprop[3]';
const xHeight = isChannelsLast ? 'i32(uniforms.x_shape[1])' : 'i32(uniforms.x_shape[2])';
const xWidth = isChannelsLast ? 'i32(uniforms.x_shape[2])' : 'i32(uniforms.x_shape[3])';
const row = isChannelsLast ? 'row' : 'col';
const col = isChannelsLast ? 'col' : 'row';

const readASnippet = `
let inChannels = ${isChannelsLast ? 'outBackprop[3]' : 'outBackprop[1]'};
let inChannels = ${isChannelsLast ? 'i32(uniforms.x_shape[3])' : 'i32(uniforms.x_shape[1])'};
let outWidth = ${isChannelsLast ? 'i32(uniforms.result_shape[2])' : 'i32(uniforms.result_shape[3])'};
let outRow = ${row} / outWidth;
let outCol = ${row} % outWidth;
let WRow = ${col} / (filterDims[1] * inChannels);
let WCol = ${col} / inChannels % filterDims[1];
let xR = f32(outRow - pads[0] + dilation[0] * WRow) / f32(strides[0]);
let xC = f32(outCol - pads[1] + dilation[1] * WCol) / f32(strides[1]);
let WRow = ${col} / (uniforms.filterDims[1] * inChannels);
let WCol = ${col} / inChannels % uniforms.filterDims[1];
let xR = f32(outRow - uniforms.pads[0] + uniforms.dilations[0] * WRow) / f32(uniforms.strides[0]);
let xC = f32(outCol - uniforms.pads[1] + uniforms.dilations[1] * WCol) / f32(uniforms.strides[1]);
if (xR < 0.0 || xR >= f32(${xHeight}) || fract(xR) > 0.0) {
return ${type}(0.0);
}
Expand Down Expand Up @@ -116,9 +116,9 @@ const conv2dTransposeCommonSnippet =

const sampleW = `
let col = colIn * ${innerElementSize};
let inChannels = ${isChannelsLast ? 'outBackprop[3]' : 'outBackprop[1]'};
let coordX = filterDims.x - 1 - row / (filterDims[1] * inChannels);
let coordY = filterDims.y - 1 - (row / inChannels) % filterDims[1];
let inChannels = ${isChannelsLast ? 'i32(uniforms.x_shape[3])' : 'i32(uniforms.x_shape[1])'};
let coordX = uniforms.filterDims[0] - 1 - row / (uniforms.filterDims[1] * inChannels);
let coordY = uniforms.filterDims[1] - 1 - (row / inChannels) % uniforms.filterDims[1];
if (${
isChannelsLast ? 'row < uniforms.dimInner && col < uniforms.dimBOuter' :
'row < uniforms.dimInner && col < uniforms.dimAOuter'} && coordX >= 0 && coordY >= 0) {
Expand Down Expand Up @@ -186,20 +186,33 @@ export const createConv2DTransposeMatMulProgramInfo =
const innerElementSize = isVec4 ? 4 : 1;
const tileInner = Math.max(workGroupSize[0] * innerElementSize, workGroupSize[1]);
const components = isVec4 ? 4 : 1;
const programUniforms: ProgramUniform[] =
[{type: 'int32', data: dimAOuter}, {type: 'int32', data: dimBOuter}, {type: 'int32', data: dimInner}];
const filterDims0 = attributes.kernelShape[isChannelsLast ? 1 : 2];
const filterDims1 = attributes.kernelShape[isChannelsLast ? 2 : 3];
const effectiveFilterDims0 =
filterDims0 + (attributes.dilations[0] <= 1 ? 0 : (filterDims0 - 1) * (attributes.dilations[0] - 1));
const effectiveFilterDims1 =
filterDims1 + (attributes.dilations[1] <= 1 ? 0 : (filterDims1 - 1) * (attributes.dilations[1] - 1));
const pads0 = effectiveFilterDims0 - 1 - Math.floor((attributes.pads[0] + attributes.pads[2]) / 2);
const pads1 = effectiveFilterDims1 - 1 - Math.floor((attributes.pads[1] + attributes.pads[3]) / 2);
const programUniforms: ProgramUniform[] = [
{type: 'int32', data: dimAOuter}, {type: 'int32', data: dimBOuter}, {type: 'int32', data: dimInner},
{type: 'int32', data: attributes.strides}, {type: 'int32', data: attributes.dilations},
{type: 'int32', data: [filterDims0, filterDims1]}, {type: 'int32', data: [pads0, pads1]}
];
const x = inputVariable('x', inputs[0].dataType, inputs[0].dims.length, components);
const w = inputVariable('w', inputs[1].dataType, inputs[1].dims.length, 1);
const output = outputVariable('result', inputs[0].dataType, outputShape.length, components);
const inputVariables = [x, w];
programUniforms.push(...createTensorShapeVariables(inputs[0].dims));
programUniforms.push(...createTensorShapeVariables(inputs[1].dims));
programUniforms.push(
...createTensorShapeVariables(inputs[0].dims), ...createTensorShapeVariables(inputs[1].dims));

const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'rank'];
let declareFunctions = '';
if (hasBias) {
const bias = inputVariable('bias', inputs[2].dataType, inputs[2].dims.length, components);
inputVariables.push(bias);
programUniforms.push(...createTensorShapeVariables(inputs[2].dims));
inputDependencies.push('rank');

declareFunctions += `
fn getBiasByOutputCoords(coords : vec4<i32>) -> ${isVec4 ? 'vec4<f32>' : 'f32'} {
Expand All @@ -209,42 +222,23 @@ export const createConv2DTransposeMatMulProgramInfo =

programUniforms.push(...createTensorShapeVariables(outputShape));

const uniforms: UniformsArrayType = [
{name: 'dimAOuter', type: 'i32'}, {name: 'dimBOuter', type: 'i32'}, {name: 'dimInner', type: 'i32'},
{name: 'strides', type: 'i32', length: 2}, {name: 'dilations', type: 'i32', length: 2},
{name: 'filterDims', type: 'i32', length: 2}, {name: 'pads', type: 'i32', length: 2}
];

return {
name: 'Conv2DTransposeMatMul',
shaderCache: {hint: attributes.cacheKey},
shaderCache: {hint: `${attributes.format}`, inputDependencies},
getRunData: () => ({
outputs: [{dims: outputShape, dataType: inputs[0].dataType}],
dispatchGroup: {x: dispatch[0], y: dispatch[1], z: dispatch[2]},
programUniforms
}),
getShaderSource: (shaderHelper: ShaderHelper) => `
${utilFunctions('uniforms.result_strides')}
${
shaderHelper.registerUniform('dimAOuter', 'i32')
.registerUniform('dimBOuter', 'i32')
.registerUniform('dimInner', 'i32')
.declareVariables(...inputVariables, output)};
const outBackprop : vec4<i32> = vec4<i32>(${inputs[0].dims.join(',')});
const filterDims : vec2<i32> = vec2<i32>(${attributes.kernelShape[isChannelsLast ? 1 : 2]}, ${
attributes.kernelShape[isChannelsLast ? 2 : 3]});
const effectiveFilterDims : vec2<i32> = filterDims + vec2<i32>(
${
attributes.dilations[0] <= 1 ?
0 :
(attributes.kernelShape[isChannelsLast ? 1 : 2] - 1) * (attributes.dilations[0] - 1)},
${
attributes.dilations[1] <= 1 ?
0 :
(attributes.kernelShape[isChannelsLast ? 2 : 3] - 1) * (attributes.dilations[1] - 1)});
const pads : vec2<i32> = vec2<i32>(i32(effectiveFilterDims[0]) - 1 - (${
attributes.pads[0] + attributes.pads[2]})/2,
i32(effectiveFilterDims[1]) - 1 - (${
attributes.pads[1] + attributes.pads[3]})/2);
const strides : vec2<i32> = vec2<i32>(${attributes.strides[0]}, ${attributes.strides[1]});
const dilation : vec2<i32> = vec2<i32>(${attributes.dilations[0]}, ${attributes.dilations[1]});
const dimAOuter : i32 = ${dimAOuter};
const dimBOuter : i32 = ${dimBOuter};
const dimInner : i32 = ${dimInner};
${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVariables, output)};
${declareFunctions}
${conv2dTransposeCommonSnippet(isChannelsLast, hasBias, attributes, innerElementSize)}
${
Expand Down
Loading

0 comments on commit 7e77cc2

Please sign in to comment.