Skip to content

Commit

Permalink
[JS/WebGPU] Added Uniforms to SkipLayerNorm. (#18788)
Browse files Browse the repository at this point in the history
### Description
Added Uniforms to SkipLayerNorm



### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
Improve performance

---------

Co-authored-by: Yulong Wang <[email protected]>
  • Loading branch information
satyajandhyala and fs-eire authored Jan 24, 2024
1 parent a39ac4a commit a33b5bd
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 58 deletions.
4 changes: 2 additions & 2 deletions js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ import * as pool from './ops/pool';
import {range} from './ops/range';
import {reduceL1, reduceL2, reduceLogSum, reduceLogSumExp, reduceMax, reduceMean, reduceMin, reduceProd, reduceSum, reduceSumSquare} from './ops/reduce';
import {parseResizeAttributes, resize} from './ops/resize';
import {parseSkipLayerNormAttributes, skipLayerNorm} from './ops/skip-layer-norm';
import {skipLayerNorm} from './ops/skip-layer-norm';
import {parseSliceAttributes, slice} from './ops/slice';
import {parseSoftmaxAttributes, softmax} from './ops/softmax';
import {parseSplitAttributes, split} from './ops/split';
Expand Down Expand Up @@ -116,7 +116,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map<string, OperatorImplementation> = new
['Sin', [unaryOps.sin]],
['Sinh', [unaryOps.sinh]],
['Slice', [slice, parseSliceAttributes]],
['SkipLayerNormalization', [skipLayerNorm, parseSkipLayerNormAttributes]],
['SkipLayerNormalization', [skipLayerNorm]],
['Split', [split, parseSplitAttributes]],
['Sqrt', [unaryOps.sqrt]],
['Softmax', [softmax, parseSoftmaxAttributes]],
Expand Down
123 changes: 67 additions & 56 deletions js/web/lib/wasm/jsep/webgpu/ops/skip-layer-norm.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@
import {DataType} from '../../../wasm-common';
import {TensorView} from '../../tensor-view';
import {ShapeUtil} from '../../util';
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
import {ComputeContext, ProgramInfo} from '../types';
import {AttributeWithCacheKey} from '../attribute-with-cache-key';
import {ComputeContext, ProgramInfo, ProgramUniform} from '../types';

import {castToF32, fillVector, getMaxComponents, inputVariable, outputVariable, ShaderHelper, sumVector, tensorTypeToWsglStorageType,} from './common';
import {castToF32, fillVector, getMaxComponents, inputVariable, outputVariable, ShaderHelper, sumVector, tensorTypeToWsglStorageType, UniformsArrayType} from './common';

export interface SkipLayerNormAttributes extends AttributeWithCacheKey {
epsilon: number;
Expand Down Expand Up @@ -86,60 +86,74 @@ const createSkipLayerNormProgramInfo =
const hasInputSkipBiasSumOutput = outputCount > 3;

const components = getMaxComponents(hiddenSize);
const variables = [
inputVariable('x', inputs[0].dataType, inputs[0].dims, components),
inputVariable('skip', inputs[1].dataType, inputs[1].dims, components),
inputVariable('gamma', inputs[2].dataType, inputs[2].dims, components),
];
if (hasBetaInput) {
variables.push(inputVariable('beta', inputs[3].dataType, inputs[3].dims, components));
}
if (hasBiasInput) {
variables.push(inputVariable('bias', inputs[4].dataType, inputs[4].dims, components));
}
variables.push(outputVariable('output', inputs[0].dataType, outputShape, components));
if (hasMeanOutput) {
variables.push(outputVariable('meanOutput', DataType.float, meanInvStdDevDim));
}
if (hasInvStdDevOutput) {
variables.push(outputVariable('invStdOutput', DataType.float, meanInvStdDevDim));
}
if (hasInputSkipBiasSumOutput) {
variables.push(outputVariable('inputSkipBiasSum', inputs[0].dataType, outputShape, components));
}
const dataType = tensorTypeToWsglStorageType(inputs[0].dataType);
const getShaderSource = (shaderHelper: ShaderHelper) => `
const hiddenSize: f32 = ${hiddenSize};
const hiddenSizeVectorized: u32 = ${hiddenSize / components};
const epsilon: f32 = ${attributes.epsilon};

${shaderHelper.declareVariables(...variables)}
const programUniforms: ProgramUniform[] = [
{type: 'uint32', data: outputSize},
{type: 'uint32', data: components},
{type: 'uint32', data: hiddenSize},
{type: 'float32', data: attributes.epsilon},
];
const getShaderSource = (shaderHelper: ShaderHelper) => {
const uniformsArray: UniformsArrayType = [
{name: 'output_size', type: 'u32'},
{name: 'components', type: 'u32'},
{name: 'hidden_size', type: 'u32'},
{name: 'epsilon', type: 'f32'},
];
const variables = [
inputVariable('x', inputs[0].dataType, inputs[0].dims, components),
inputVariable('skip', inputs[1].dataType, inputs[1].dims, components),
inputVariable('gamma', inputs[2].dataType, inputs[2].dims, components),
];
if (hasBetaInput) {
variables.push(inputVariable('beta', inputs[3].dataType, inputs[3].dims, components));
}
if (hasBiasInput) {
variables.push(inputVariable('bias', inputs[4].dataType, inputs[4].dims, components));
}
variables.push(outputVariable('output', inputs[0].dataType, outputShape, components));
if (hasMeanOutput) {
variables.push(outputVariable('mean_output', DataType.float, meanInvStdDevDim));
}
if (hasInvStdDevOutput) {
variables.push(outputVariable('inv_std_output', DataType.float, meanInvStdDevDim));
}
if (hasInputSkipBiasSumOutput) {
variables.push(outputVariable('input_skip_bias_sum', inputs[0].dataType, outputShape, components));
}
const dataType = tensorTypeToWsglStorageType(inputs[0].dataType);
return `
${shaderHelper.registerUniforms(uniformsArray).declareVariables(...variables)}
${shaderHelper.mainStart()}
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize / hiddenSize)}
let offset = global_idx * hiddenSizeVectorized;
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size / uniforms.hidden_size')}
let hidden_size_vectorized: u32 = uniforms.hidden_size / uniforms.components;
let offset = global_idx * hidden_size_vectorized;
var sum = ${fillVector('f32', components)};
var squareSum = ${fillVector('f32', components)};
for (var i: u32 = 0; i < hiddenSizeVectorized; i++) {
let skipValue = skip[offset + i];
let biasValue = ${hasBiasInput ? 'bias[i]' : '0.0'};
let inputValue = x[offset + i];
let value = inputValue + skipValue + biasValue;
${hasInputSkipBiasSumOutput ? 'inputSkipBiasSum[offset + i] = value;' : ''}
for (var i: u32 = 0; i < hidden_size_vectorized; i++) {
let skip_value = skip[offset + i];
let bias_value = ${hasBiasInput ? 'bias[i]' : '0.0'};
let input_value = x[offset + i];
let value = input_value + skip_value + bias_value;
${hasInputSkipBiasSumOutput ? 'input_skip_bias_sum[offset + i] = value;' : ''}
output[offset + i] = value;
let f32Value = ${castToF32(dataType, components, 'value')};
sum += f32Value;
squareSum += f32Value * f32Value;
let f32_value = ${castToF32(dataType, components, 'value')};
sum += f32_value;
squareSum += f32_value * f32_value;
}
let mean = ${sumVector('sum', components)} / hiddenSize;
let invStdDev = inverseSqrt(${sumVector('squareSum', components)} / hiddenSize - mean * mean + epsilon);
${hasMeanOutput ? 'meanOutput[global_idx] = mean;' : ''}
${hasInvStdDevOutput ? 'invStdOutput[global_idx] = invStdDev;' : ''}
for (var i: u32 = 0; i < hiddenSizeVectorized; i++) {
output[offset + i] = (output[offset + i] - ${dataType}(mean)) * ${dataType}(invStdDev) * gamma[i]
+ ${hasBetaInput ? 'beta[i]' : '0.0'};
let mean = ${sumVector('sum', components)} / f32(uniforms.hidden_size);
let inv_std_dev = inverseSqrt(${
sumVector('squareSum', components)} / f32(uniforms.hidden_size) - mean * mean + uniforms.epsilon);
${hasMeanOutput ? 'mean_output[global_idx] = mean;' : ''}
${hasInvStdDevOutput ? 'inv_std_output[global_idx] = inv_std_dev;' : ''}
for (var i: u32 = 0; i < hidden_size_vectorized; i++) {
output[offset + i] = (output[offset + i] - ${dataType}(mean)) * ${dataType}(inv_std_dev) * gamma[i] + ${
hasBetaInput ? 'beta[i]' : '0.0'};
}
}`;
};
const outputs = [{dims: outputShape, dataType: inputs[0].dataType}];
if (outputCount > 1) {
outputs.push({dims: meanInvStdDevDim, dataType: DataType.float});
Expand All @@ -150,12 +164,14 @@ const createSkipLayerNormProgramInfo =
if (outputCount > 3) {
outputs.push({dims: inputShape, dataType: inputs[0].dataType});
}

return {
name: 'SkipLayerNormalization',
shaderCache: {hint: attributes.cacheKey},
shaderCache: {
hint: `${components};${hasMeanOutput};${hasInvStdDevOutput};${hasInputSkipBiasSumOutput}`,
inputDependencies: inputs.map((_input, _index) => 'type')
},
getShaderSource,
getRunData: () => ({outputs, dispatchGroup: {x: Math.ceil(outputSize / hiddenSize / 64)}}),
getRunData: () => ({outputs, dispatchGroup: {x: Math.ceil(outputSize / hiddenSize / 64)}, programUniforms}),
};
};

Expand All @@ -178,8 +194,3 @@ export const skipLayerNorm = (context: ComputeContext, attributes: SkipLayerNorm
context.compute(
createSkipLayerNormProgramInfo(context.inputs, attributes, context.outputCount, isTraining), {outputs});
};

export const parseSkipLayerNormAttributes = (attributes: Record<string, unknown>): SkipLayerNormAttributes => {
const epsilon = attributes.epsilon as number;
return createAttributeWithCacheKey({epsilon});
};

0 comments on commit a33b5bd

Please sign in to comment.