Commit fa9b3f6

Improve dropout setup, and fix where the dropout is used in the code to follow the T5X implementation.
aliciafmachado committed Oct 2, 2024
1 parent cf25062 commit fa9b3f6
Showing 6 changed files with 41 additions and 62 deletions.
@@ -93,6 +93,7 @@ const defaultConfig: ModelConfig = {
inputRep: 8,
kqvRep: 8,
layers: [layerSpec],
dropoutRate: 0.0,
},
init: {
stddev: 0.5,
@@ -118,6 +119,7 @@ const transWithLayerNormed: ModelConfig = {
inputRep: 8,
kqvRep: 8,
layers: [layerSpecWithNorm],
dropoutRate: 0.0,
},
init: {
stddev: 0.5,
@@ -143,6 +145,7 @@ const transWithLayerNormedAndDropout: ModelConfig = {
inputRep: 8,
kqvRep: 8,
layers: [layerSpecWithNormAndDropout],
dropoutRate: 0.1,
},
init: {
stddev: 0.5,
@@ -42,6 +42,7 @@ describe('basic_transformer_trainer', () => {
inputRep: 4,
kqvRep: 3,
layers: [layerSpec, layerSpec],
dropoutRate: 0,
},
init: {
stddev: 0.5,
@@ -77,6 +78,8 @@ describe('basic_transformer_trainer', () => {
);
// Taking a couple of steps...
const initLoss = trainState.batchMeanLoss;
console.log('initLoss');

@iislucas (Collaborator) commented on Oct 8, 2024: You can remove the logs now I think.

console.log(initLoss);
expect(trainState.nSteps).toBe(0);
expect(trainState.nExamples).toBe(0);
const stillTraining = trySgdTrainStep(trainState);
@@ -90,11 +93,14 @@ describe('basic_transformer_trainer', () => {
jstree.forEach((g: GTensor<any>) => g.dispose(), initParams);
trainState.dispose();
});
// TODO(@aliciafmachado): Make sure that this test is reproducible by adding seeding.
// Currently it's flaky since dropout and some other sources of randomness are not
// properly seeded yet.

@iislucas (Collaborator) commented on Oct 8, 2024: This is fixed now?

it('AorBisMaxTaskWithDropout training', async () => {
const layerSpec: transformer.TransformerParamLayerSpec = {
nHeads: 1,
hasPosEncoding: true,
computeSpec: { residuals: true, dropoutRate: 0.1 },
computeSpec: { residuals: true, dropoutRate: 0.999 },

@iislucas (Collaborator) commented on Oct 8, 2024: curious why not 1.0?

// TODO: investigate: these make 0 gradients?
layerNormFF: false,
layerNormHeadsProjection: false,
@@ -105,6 +111,7 @@ describe('basic_transformer_trainer', () => {
inputRep: 4,
kqvRep: 3,
layers: [layerSpec, layerSpec],
dropoutRate: 0.999,
},
init: {
stddev: 0.5,
@@ -147,7 +154,8 @@ describe('basic_transformer_trainer', () => {
const newLoss = trainState.batchMeanLoss;
expect(trainState.nSteps).toBe(1);
expect(trainState.nExamples).toBe(trainState.batchExamples.length);
expect(newLoss).toBeLessThan(initLoss);
// Transformer with 99.9% dropout should not learn anything, so the loss should not improve.
expect(newLoss).toBeGreaterThanOrEqual(initLoss);

// Memory cleanup
jstree.forEach((g: GTensor<any>) => g.dispose(), initParams);
@@ -109,8 +109,6 @@ export function initTransformerTrainState(
}

export function computeMetrics(state: TransformerTrainState): TrainMetrics {
// Eval mode is passed inside computeStateBatchAccuracy and computeLossAndAccuracy to
// computeTransformer.
const trainBatchAcc: number = computeStateBatchAccuracy(state);
const testLossAndAcc = computeLossAndAccuracy(state, state.taskSplit.testSetExamples);
return {
@@ -130,8 +128,7 @@ export function computeStateBatchAccuracy(state: TransformerTrainState): number
const decoderComputation = transformer.computeTransformer(
state.spec,
state.params,
state.inputsVar,
true
state.inputsVar
);
meanAcc = transformerAccuracy(
decoderComputation,
@@ -159,8 +156,7 @@ export function computeLossAndAccuracy(
const decoderComputation = transformer.computeTransformer(
state.spec,
state.params,
state.inputsVar,
true
state.inputsVar
);
const batchAcc = transformerAccuracy(
decoderComputation,
6 changes: 3 additions & 3 deletions animated-transformer/src/lib/transformer/dropout.spec.ts
@@ -23,7 +23,7 @@ describe('dropout', () => {
it('Basic dropout', () => {
const beforeDropout =
gtensor.makeRange('input', 1, 5, 1, 'float32');
const afterDropout = dropout(0.5, beforeDropout, false, 0);
const afterDropout = dropout(0.5, beforeDropout, 0);
afterDropout.tensor.print();

tf.test_util.expectArraysClose(afterDropout.tensor.dataSync(),
@@ -35,7 +35,7 @@
it('Deterministic output', () => {
const beforeDropout =
gtensor.makeRange('input', 1, 5, 1, 'float32');
const afterDropout = dropout(0.5, beforeDropout, true, 0);
const afterDropout = dropout(0, beforeDropout, 0);
afterDropout.tensor.print();

tf.test_util.expectArraysClose(afterDropout.tensor.dataSync(),
@@ -55,7 +55,7 @@
]),
['batch', 'pos', 'inputRep']
);
const afterDropout = dropout(0.5, beforeDropout, false, 1, ['pos']);
const afterDropout = dropout(0.5, beforeDropout, 1, ['pos']);
afterDropout.tensor.print();

tf.test_util.expectArraysClose(afterDropout.tensor.dataSync(),
3 changes: 1 addition & 2 deletions animated-transformer/src/lib/transformer/dropout.ts
@@ -23,11 +23,10 @@ import {
export function dropout<G extends string, D extends G>(
dropoutRate: number,
g: GTensor<G>,
deterministic: boolean,
seed?: number,
dimNames?: string[],
): GTensor<G> {
if (deterministic) {
if (dropoutRate === 0) {
return g;
}
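With the deterministic flag removed, passing a rate of 0 is now how callers request a no-op. As a rough sketch of the same logic on a plain tf.Tensor (illustrative only, not the repo's GTensor-based implementation; the function name and parameters below are made up for this example):

    import * as tf from '@tensorflow/tfjs';

    function dropoutSketch(
      dropoutRate: number,
      x: tf.Tensor,
      seed?: number,
      noiseShape?: number[],
    ): tf.Tensor {
      // A rate of 0 makes dropout a no-op, so return the input unchanged.
      if (dropoutRate === 0) {
        return x;
      }
      // tf.dropout zeroes a dropoutRate fraction of entries and rescales the
      // survivors by 1 / (1 - dropoutRate), preserving the expected value.
      return tf.dropout(x, dropoutRate, noiseShape, seed);
    }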

71 changes: 22 additions & 49 deletions animated-transformer/src/lib/transformer/transformer_gtensor.ts
@@ -72,6 +72,8 @@ export type TransformerParamSpec = {
inputRep: number;
kqvRep: number;
layers: TransformerParamLayerSpec[];
// Dropout rate on the input before going into the stack.
dropoutRate: number;
relPosEncodingSeqLength?: number;
};
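For illustration, a spec written against the updated type might look like the following (a sketch that assumes the fields shown in this hunk are the ones that matter; layerSpec stands for a TransformerParamLayerSpec defined elsewhere, as in the tests above):

    // Illustrative values only; dropoutRate applies once to the embedded input
    // before the first attention layer, while per-layer dropout stays in each
    // layer's computeSpec.
    const exampleSpec: TransformerParamSpec = {
      inputRep: 8,
      kqvRep: 8,
      layers: [layerSpec, layerSpec],
      dropoutRate: 0.1,
    };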

@@ -229,20 +231,13 @@ function gelu(x: tf.Tensor) {
export function computeAttnHead(
spec: AttnHeadComputeSpec,
params: AttnHeadParams<TensorKind>,
seqInput: GTensor<'batch' | 'pos' | 'inputRep'>,
evalMode: boolean = false
seqInput: GTensor<'batch' | 'pos' | 'inputRep'>
): BatchAttnHeadCompututation {
const { queryM, keyM, valueM, headsToInputRepM, ff } = params;

// Dropout on the input of the stack.
let seqInputAfterDropout = seqInput;
if (spec.dropoutRate > 0){
seqInputAfterDropout = dropout(spec.dropoutRate, seqInput, evalMode);
}

const queries = seqInputAfterDropout.contract(queryM, ['inputRep']);
const keys = seqInputAfterDropout.contract(keyM, ['inputRep']);
const values = seqInputAfterDropout.contract(valueM, ['inputRep']);
const queries = seqInput.contract(queryM, ['inputRep']);
const keys = seqInput.contract(keyM, ['inputRep']);
const values = seqInput.contract(valueM, ['inputRep']);

let rawAttention = keys
.rename('pos', 'keyPos')
@@ -261,37 +256,25 @@ export function computeAttnHead(
.scalarDiv(makeScalar(Math.sqrt(seqInput.dim.inputRep.size), 'float32'));
}

// Dropout on attention weights.
if (spec.dropoutRate > 0) {
rawAttention = dropout(
spec.dropoutRate,
rawAttention,
evalMode,
);
}

const attention = rawAttention.softmax('queryPos');

// Dropout on the attention weights.
const attentionAfterDropout = dropout(spec.dropoutRate, attention);

const attendedValues = values
.contract(attention.rename('queryPos', 'pos'), ['pos'])
.contract(attentionAfterDropout.rename('queryPos', 'pos'), ['pos'])
.rename('keyPos', 'pos');
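To make the new ordering concrete, here is a small plain-tfjs sketch (hypothetical shapes and names, not the repo's GTensor API, and using the usual normalization over key positions): dropout acts on the softmax-normalized attention weights before they are combined with the values, which is the T5X-style placement adopted here.

    import * as tf from '@tensorflow/tfjs';

    const logits = tf.randomNormal([2, 4, 4]);       // [batch, queryPos, keyPos]
    const attnWeights = tf.softmax(logits, -1);      // normalize over keyPos
    // Dropout is applied after the softmax, to the normalized weights.
    const attnAfterDropout = tf.dropout(attnWeights, 0.1);
    const values = tf.randomNormal([2, 4, 8]);       // [batch, keyPos, valueRep]
    const attended = tf.matMul(attnAfterDropout, values);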

const headsReduction = attendedValues.contract(headsToInputRepM, ['value', 'heads']);

// Dropout before layer norm and residual connection.
let headsReductionAfterDropout = headsReduction;
if (spec.dropoutRate > 0) {
headsReductionAfterDropout = dropout(
spec.dropoutRate,
headsReduction,
evalMode,
);
}
let headsReductionAfterDropout = dropout(spec.dropoutRate, headsReduction);

let normedHeadReduction = headsReductionAfterDropout;
if (params.layerNormHeadsProjection) {
normedHeadReduction = layerNorm(
params.layerNormHeadsProjection,
headsReduction,
headsReductionAfterDropout,
'inputRepToFF'
);
}
@@ -304,34 +287,25 @@
inputToFF = normedHeadReduction.pointwiseAdd(seqInput.rename('inputRep', 'inputRepToFF'));
}

// Skipped dropout in the FF, since the FF nn is a single layer.
// Skipped dropout in the FF, since the FF nn is a single layer with two biases.
let unNormedSeqOuput = inputToFF
.contract(ff.w, ['inputRepToFF'])
.pointwiseAdd(ff.bIn)
.applyPointWiseTfFn(gelu)
.pointwiseAdd(ff.bOut);

// Dropout before layer norm and residual connection.
let unNormedSeqOuputAfterDropout = unNormedSeqOuput;
if (spec.dropoutRate > 0) {
unNormedSeqOuputAfterDropout = dropout(
spec.dropoutRate,
unNormedSeqOuput,
evalMode,
);
}
const unNormedSeqOuputAfterDropout = dropout(spec.dropoutRate, unNormedSeqOuput);

if (spec.residuals) {
// FF residual.
unNormedSeqOuput = unNormedSeqOuputAfterDropout.pointwiseAdd(inputToFF.rename('inputRepToFF', 'inputRep'));
}

let seqOuput = unNormedSeqOuput;
if (params.layerNormPostFF) {
seqOuput = layerNorm(params.layerNormPostFF, unNormedSeqOuput, 'inputRep');
}

// Skipped dropout on the output, since the results are returned directly without
// further computation.

return {
seqInput,
@@ -388,21 +362,22 @@ export type TransformerComputation = {
export function computeTransformer(
spec: TransformerParamSpec,
params: TransformerParams,
seqInput: GTensor<'batch' | 'pos' | 'inputRep'>,
evalMode: boolean = false
seqInput: GTensor<'batch' | 'pos' | 'inputRep'>
): TransformerComputation {
const compute: TransformerComputation = { layers: [] };
let currentLayerInput = seqInput;
// Dropout on the input.
let currentLayerInput = dropout(spec.dropoutRate, seqInput);
params.layers.forEach((layerParams, i) => {
const layerCompute = computeAttnHead(
spec.layers[i].computeSpec,
layerParams,
currentLayerInput,
evalMode,
currentLayerInput
);
compute.layers.push(layerCompute);
currentLayerInput = layerCompute.seqOuput;
});
// TODO(@aliciafmachado): Skipped dropout on the output, since I am not sure how to integrate
// this in the TransformerComputation output.
return compute;
}
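Since the evalMode parameter is gone, a call site that wants a deterministic (evaluation-time) forward pass now expresses that through the spec's dropout rates instead. A hypothetical usage sketch, assuming spec, params, and seqInput are already constructed (for example as in the trainer code above):

    // Zero out the stack-level input dropout and every layer's dropout to get
    // a deterministic forward pass at evaluation time.
    const evalSpec: TransformerParamSpec = {
      ...spec,
      dropoutRate: 0,
      layers: spec.layers.map((layer) => ({
        ...layer,
        computeSpec: { ...layer.computeSpec, dropoutRate: 0 },
      })),
    };
    const computation = computeTransformer(evalSpec, params, seqInput);
    // computation.layers holds per-layer results; the last layer's seqOuput is
    // what downstream loss / accuracy code consumes.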

@@ -501,9 +476,7 @@ export function computeDecoder(
(max, curInput) => (max >= curInput.length ? max : curInput.length),
0
);
// Tokenization.
const gtensorInputs = inputPrepFn(tokenRep, params, maxInputLength, inputs);
// Transformer computation.
return computeTransformer(spec, params, gtensorInputs);
}
