
Commit

Add more tests to dropout and pass flag to computeTransformer to disable dropout during evaluation.

aliciafmachado committed Sep 8, 2024
1 parent 70d0040 commit cf25062
Showing 3 changed files with 104 additions and 20 deletions.
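Note on the intended call pattern: the new evalMode flag is threaded from computeTransformer into computeAttnHead and on to each dropout call (as its deterministic argument), so dropout is applied during training and becomes a no-op during evaluation. A minimal sketch, assuming spec, params, and batchedInput are prepared elsewhere as in the existing code:

// Training: evalMode defaults to false, so dropout is applied at spec.dropoutRate.
const trainCompute = computeTransformer(spec, params, batchedInput);
// Evaluation: evalMode = true, so every dropout call returns its input unchanged.
const evalCompute = computeTransformer(spec, params, batchedInput, true);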
44 changes: 41 additions & 3 deletions animated-transformer/src/lib/transformer/dropout.spec.ts
@@ -22,15 +22,53 @@ describe('dropout', () => {

it('Basic dropout', () => {
const beforeDropout =
- gtensor.makeRange('input', 0, 4, 1, 'float32');
- const afterDropout = dropout(0.5, beforeDropout, 0);
+ gtensor.makeRange('input', 1, 5, 1, 'float32');
+ const afterDropout = dropout(0.5, beforeDropout, false, 0);
afterDropout.tensor.print();

tf.test_util.expectArraysClose(afterDropout.tensor.dataSync(),
- [0, 2, 0, 6]);
+ [2, 4, 0, 8]);
expect(afterDropout.dimNames).toEqual(
['input']);
});

+ it('Deterministic output', () => {
+ const beforeDropout =
+ gtensor.makeRange('input', 1, 5, 1, 'float32');
+ const afterDropout = dropout(0.5, beforeDropout, true, 0);
+ afterDropout.tensor.print();
+
+ tf.test_util.expectArraysClose(afterDropout.tensor.dataSync(),
+ [1, 2, 3, 4]);
+ expect(afterDropout.dimNames).toEqual(
+ ['input']);
+ });
+
+ it('Dropout with noise shape ', () => {
+ const beforeDropout = new gtensor.GTensor(
+ tf.tensor([
+ [
+ [1, 2, 3, 4],
+ [5, 6, 7, 8],
+ [9, 10, 11, 12],
+ ]
+ ]),
+ ['batch', 'pos', 'inputRep']
+ );
+ const afterDropout = dropout(0.5, beforeDropout, false, 1, ['pos']);
+ afterDropout.tensor.print();
+
+ tf.test_util.expectArraysClose(afterDropout.tensor.dataSync(),
+ [
+ [
+ [2, 4, 6, 8],
+ [0, 0, 0, 0],
+ [0, 0, 0, 0],
+ ]
+ ]);
+ expect(afterDropout.dimNames).toEqual(
+ ['batch', 'pos', 'inputRep']);
+ });

// TODO(@aliciafmachado): Test that grads are not applied on de-activated neurons.
});
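For reference, the expected values above follow tf.dropout's inverted-dropout scaling: with a rate of 0.5, kept elements are scaled by 1 / (1 - 0.5) = 2 and dropped elements become 0, so [1, 2, 3, 4] with seed 0 becomes [2, 4, 0, 8]. In the noise-shape test the mask is broadcast across 'batch' and 'inputRep', so each 'pos' row is kept (and doubled) or zeroed as a whole.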
21 changes: 19 additions & 2 deletions animated-transformer/src/lib/transformer/dropout.ts
@@ -23,9 +23,26 @@ import {
export function dropout<G extends string, D extends G>(
dropoutRate: number,
g: GTensor<G>,
+ deterministic: boolean,
seed?: number,
- dim?: number[],
+ dimNames?: string[],
): GTensor<G> {
- return new GTensor(tf_dropout(g.tensor, dropoutRate, dim, seed), g.dimNames);
+ if (deterministic) {
+ return g;
+ }
+
+ let dimensions: number[] = g.tensor.shape;
+ if (dimNames) {
+ dimensions = [];
+ for (const d of g.dimNames) {
+ if (dimNames.includes(d)) {
+ dimensions = dimensions.concat(g.dim[d].size);
+ }
+ else {
+ dimensions = dimensions.concat(1);
+ }
+ }
+ }
+ return new GTensor(tf_dropout(g.tensor, dropoutRate, dimensions, seed), g.dimNames);
}
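The dimNames branch above builds the noise shape handed to tf_dropout: dimensions named in dimNames keep their full size and all others are collapsed to 1, so the dropout mask is shared across the unnamed dimensions. A sketch of the resulting shapes for a hypothetical ['batch', 'pos', 'inputRep'] tensor x with sizes [1, 3, 4]:

// dropout(0.5, x, false, seed)               -> noise shape [1, 3, 4]: independent mask per element
// dropout(0.5, x, false, seed, ['pos'])      -> noise shape [1, 3, 1]: whole positions kept or zeroed together
// dropout(0.5, x, false, seed, ['inputRep']) -> noise shape [1, 1, 4]: the same features dropped at every position
// dropout(0.5, x, true, seed, ['pos'])       -> returns x unchanged (deterministic, i.e. eval mode)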

59 changes: 44 additions & 15 deletions animated-transformer/src/lib/transformer/transformer_gtensor.ts
@@ -230,13 +230,19 @@ export function computeAttnHead(
spec: AttnHeadComputeSpec,
params: AttnHeadParams<TensorKind>,
seqInput: GTensor<'batch' | 'pos' | 'inputRep'>,
- eval_mode: boolean = false
+ evalMode: boolean = false
): BatchAttnHeadCompututation {
const { queryM, keyM, valueM, headsToInputRepM, ff } = params;

- const queries = seqInput.contract(queryM, ['inputRep']);
- const keys = seqInput.contract(keyM, ['inputRep']);
- const values = seqInput.contract(valueM, ['inputRep']);
+ // Dropout on the input of the stack.
+ let seqInputAfterDropout = seqInput;
+ if (spec.dropoutRate > 0){
+ seqInputAfterDropout = dropout(spec.dropoutRate, seqInput, evalMode);
+ }
+
+ const queries = seqInputAfterDropout.contract(queryM, ['inputRep']);
+ const keys = seqInputAfterDropout.contract(keyM, ['inputRep']);
+ const values = seqInputAfterDropout.contract(valueM, ['inputRep']);

let rawAttention = keys
.rename('pos', 'keyPos')
@@ -255,25 +261,33 @@
.scalarDiv(makeScalar(Math.sqrt(seqInput.dim.inputRep.size), 'float32'));
}

+ // Dropout on attention weights.
+ if (spec.dropoutRate > 0) {
+ rawAttention = dropout(
+ spec.dropoutRate,
+ rawAttention,
+ evalMode,
+ );
+ }

const attention = rawAttention.softmax('queryPos');
const attendedValues = values
.contract(attention.rename('queryPos', 'pos'), ['pos'])
.rename('keyPos', 'pos');

const headsReduction = attendedValues.contract(headsToInputRepM, ['value', 'heads']);

- // Dropout after attention weights.
- // TODO(@aliciafmachado): Add proper seeding to dropout so that results are reproducible.
- let dropoutResult = headsReduction;
- // Disable dropout when evaluating.
- if (spec.dropoutRate != 0 && !eval_mode) {
- dropoutResult = dropout(
+ // Dropout before layer norm and residual connection.
+ let headsReductionAfterDropout = headsReduction;
+ if (spec.dropoutRate > 0) {
+ headsReductionAfterDropout = dropout(
spec.dropoutRate,
headsReduction,
+ evalMode,
);
}

- let normedHeadReduction = dropoutResult;
+ let normedHeadReduction = headsReductionAfterDropout;
if (params.layerNormHeadsProjection) {
normedHeadReduction = layerNorm(
params.layerNormHeadsProjection,
@@ -290,19 +304,34 @@
inputToFF = normedHeadReduction.pointwiseAdd(seqInput.rename('inputRep', 'inputRepToFF'));
}

+ // Skipped dropout in the FF, since the FF nn is a single layer.
let unNormedSeqOuput = inputToFF
.contract(ff.w, ['inputRepToFF'])
.pointwiseAdd(ff.bIn)
.applyPointWiseTfFn(gelu)
.pointwiseAdd(ff.bOut);

+ // Dropout before layer norm and residual connection.
+ let unNormedSeqOuputAfterDropout = unNormedSeqOuput;
+ if (spec.dropoutRate > 0) {
+ unNormedSeqOuputAfterDropout = dropout(
+ spec.dropoutRate,
+ unNormedSeqOuput,
+ evalMode,
+ );
+ }

if (spec.residuals) {
- // FF residual
- unNormedSeqOuput = unNormedSeqOuput.pointwiseAdd(inputToFF.rename('inputRepToFF', 'inputRep'));
+ // FF residual.
+ unNormedSeqOuput = unNormedSeqOuputAfterDropout.pointwiseAdd(inputToFF.rename('inputRepToFF', 'inputRep'));
}
let seqOuput = unNormedSeqOuput;
if (params.layerNormPostFF) {
seqOuput = layerNorm(params.layerNormPostFF, unNormedSeqOuput, 'inputRep');
}

+ // Skipped dropout in the output, since results are being outputted directly without additional
+ // computations.

return {
seqInput,
@@ -360,7 +389,7 @@ export function computeTransformer(
spec: TransformerParamSpec,
params: TransformerParams,
seqInput: GTensor<'batch' | 'pos' | 'inputRep'>,
- eval_mode: boolean = false
+ evalMode: boolean = false
): TransformerComputation {
const compute: TransformerComputation = { layers: [] };
let currentLayerInput = seqInput;
@@ -369,7 +398,7 @@
spec.layers[i].computeSpec,
layerParams,
currentLayerInput,
- eval_mode,
+ evalMode,
);
compute.layers.push(layerCompute);
currentLayerInput = layerCompute.seqOuput;
