
Add support for dropout. #29

Merged: 8 commits, Oct 16, 2024
@@ -80,14 +80,14 @@ export class ModelSpecAndData {
const layerSpec: TransformerParamLayerSpec = {
nHeads: 4,
hasPosEncoding: true,
- computeSpec: { residuals: true },
+ computeSpec: { residuals: true, dropoutRate: 0.0 },
layerNormFF: false,
layerNormHeadsProjection: false,
addLayerNormBias: false,
};

const defaultConfig: ModelConfig = {
- name: 'd=8 l=1 h=4, !layerN',
+ name: 'd=8 l=1 h=4, !layerN !dropout',
transformer: {
spec: {
inputRep: 8,
@@ -105,14 +105,14 @@ const defaultConfig: ModelConfig = {
const layerSpecWithNorm: TransformerParamLayerSpec = {
nHeads: 4,
hasPosEncoding: true,
- computeSpec: { residuals: true },
+ computeSpec: { residuals: true, dropoutRate: 0.0 },
layerNormFF: true,
layerNormHeadsProjection: true,
addLayerNormBias: false,
};

const transWithLayerNormed: ModelConfig = {
- name: 'd=8 l=1 h=4 +layerN',
+ name: 'd=8 l=1 h=4 +layerN !dropout',
transformer: {
spec: {
inputRep: 8,
@@ -127,15 +127,42 @@ const transWithLayerNormed: ModelConfig = {
},
};

const layerSpecWithNormAndDropout: TransformerParamLayerSpec = {
nHeads: 4,
hasPosEncoding: true,
computeSpec: { residuals: true, dropoutRate: 0.1 },
layerNormFF: true,
layerNormHeadsProjection: true,
addLayerNormBias: false,
};

const transWithLayerNormedAndDropout: ModelConfig = {
name: 'd=8 l=1 h=4 +layerN +dropout',
transformer: {
spec: {
inputRep: 8,
kqvRep: 8,
layers: [layerSpecWithNormAndDropout],
},
init: {
stddev: 0.5,
mean: 0,
seed: 96,
},
},
};

const simpleTransformer = new ModelSpecAndData('transformer', defaultConfig);

const simpleTransformerWithLayerNorm = new ModelSpecAndData('transformer', transWithLayerNormed);

const simpleTransformerWithLayerNormAndDropout = new ModelSpecAndData('transformer', transWithLayerNormedAndDropout);

export interface ModelUpdate {
model: ModelSpecAndData | null;
}

- const initModels: ModelSpecAndData[] = [simpleTransformer, simpleTransformerWithLayerNorm];
+ const initModels: ModelSpecAndData[] = [simpleTransformer, simpleTransformerWithLayerNorm, simpleTransformerWithLayerNormAndDropout];
const initModelsMap: { [name: string]: ModelSpecAndData } = {};
initModels.forEach((m) => (initModelsMap[m.config.name] = m));

@@ -31,7 +31,70 @@ describe('basic_transformer_trainer', () => {
const layerSpec: transformer.TransformerParamLayerSpec = {
nHeads: 1,
hasPosEncoding: false,
- computeSpec: { residuals: true },
+ computeSpec: { residuals: true, dropoutRate: 0 },
// TODO: investigate: these make 0 gradients?
layerNormFF: false,
layerNormHeadsProjection: false,
addLayerNormBias: false,
};
const decoderConfig: transformer.TransformerConfig = {
spec: {
inputRep: 4,
kqvRep: 3,
layers: [layerSpec, layerSpec],
},
init: {
stddev: 0.5,
mean: 0,
seed: 1,
},
};
const taskConfig: BasicRandSeededTaskConfig = {
name: 'AorBisMaxTask',
maxInputLen: 4,
maxOutputLen: 4,
seed: 0,
};
const trainStateConfig: TrainStateConfig = {
learningRate: 0.5,
batchSize: 64,
maxInputlength: taskConfig.maxInputLen,
testSetSize: 0,
trainSetSize: 64,
};
const task = new abtask.AorBisMaxTask(taskConfig);
const tokenRep = prepareBasicTaskTokenRep(task.baseVocab);
const initParams = transformer.initDecoderParamsTree(tokenRep, decoderConfig);
console.log('initTransformerTrainState...');
const trainState = initTransformerTrainState(
task,
tokenRep,
strSeqPrepFn,
singleNextTokenIdxOutputPrepFn,
decoderConfig,
initParams,
trainStateConfig
);
// Taking a couple of steps...
const initLoss = trainState.batchMeanLoss;
expect(trainState.nSteps).toBe(0);
expect(trainState.nExamples).toBe(0);
const stillTraining = trySgdTrainStep(trainState);
expect(stillTraining).toBe(true);
const newLoss = trainState.batchMeanLoss;
expect(trainState.nSteps).toBe(1);
expect(trainState.nExamples).toBe(trainState.batchExamples.length);
expect(newLoss).toBeLessThan(initLoss);

// Memory cleanup
jstree.forEach((g: GTensor<any>) => g.dispose(), initParams);
trainState.dispose();
});
it('AorBisMaxTaskWithDropout training', async () => {
const layerSpec: transformer.TransformerParamLayerSpec = {
nHeads: 1,
hasPosEncoding: true,
computeSpec: { residuals: true, dropoutRate: 0.1 },
Review comment from @iislucas (Collaborator), Sep 12, 2024:

    Maybe add one test also for dropout rate of 1, and then test that loss doesn't decrease.

Reply from the author: Done.

// TODO: investigate: these make 0 gradients?
layerNormFF: false,
layerNormHeadsProjection: false,
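Following up on the review exchange above: with a dropout rate of 1 every activation is zeroed, so a training step should not be able to reduce the loss. A hypothetical sketch of such a test, assuming a trainState built exactly as in the preceding test but with computeSpec: { residuals: true, dropoutRate: 1.0 } (the test actually added in this PR is not shown in this excerpt):

it('AorBisMaxTask training with dropoutRate 1 does not learn', () => {
  // Same setup as the test above, but with dropoutRate: 1.0 in the computeSpec.
  const initLoss = trainState.batchMeanLoss;
  trySgdTrainStep(trainState);
  // All activations are dropped, so the step cannot reduce the loss.
  expect(trainState.batchMeanLoss).not.toBeLessThan(initLoss);
});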
@@ -109,6 +109,8 @@ export function initTransformerTrainState(
}

export function computeMetrics(state: TransformerTrainState): TrainMetrics {
// Eval mode is enabled inside computeStateBatchAccuracy and computeLossAndAccuracy
// by passing `true` to computeTransformer.
const trainBatchAcc: number = computeStateBatchAccuracy(state);
const testLossAndAcc = computeLossAndAccuracy(state, state.taskSplit.testSetExamples);
return {
@@ -128,7 +130,8 @@ export function computeStateBatchAccuracy(state: TransformerTrainState): number
const decoderComputation = transformer.computeTransformer(
state.spec,
state.params,
- state.inputsVar
+ state.inputsVar,
+ true
);
meanAcc = transformerAccuracy(
decoderComputation,
@@ -156,7 +159,8 @@ export function computeLossAndAccuracy(
const decoderComputation = transformer.computeTransformer(
state.spec,
state.params,
- state.inputsVar
+ state.inputsVar,
+ true
);
const batchAcc = transformerAccuracy(
decoderComputation,
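For context: the `true` passed as the new final argument to computeTransformer in both functions above appears to put the model in deterministic (eval) mode, so dropout is skipped while loss and accuracy are measured. A minimal sketch of that pattern in plain tfjs, using a hypothetical maybeDropout helper that is not part of this codebase:

import * as tf from '@tensorflow/tfjs';

// Hypothetical helper illustrating the train/eval dropout toggle: in eval mode
// activations pass through unchanged; in training mode entries are zeroed with
// probability `rate` and survivors are scaled by 1 / (1 - rate) to keep the
// expected value constant.
function maybeDropout(x: tf.Tensor, rate: number, deterministic: boolean): tf.Tensor {
  return deterministic || rate === 0 ? x : tf.dropout(x, rate);
}

const acts = tf.ones([2, 4]);
maybeDropout(acts, 0.1, true).print();  // eval: unchanged
maybeDropout(acts, 0.1, false).print(); // train: some zeros, rest scaled by 1/0.9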
animated-transformer/src/lib/transformer/dropout.spec.ts (new file: 74 additions, 0 deletions)
@@ -0,0 +1,74 @@
/* Copyright 2023 Google LLC. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

/* Dropout */
import * as tf from '@tensorflow/tfjs';
import { gtensor } from '..';
import { dropout } from './dropout';

describe('dropout', () => {

it('Basic dropout', () => {
const beforeDropout =
gtensor.makeRange('input', 1, 5, 1, 'float32');
const afterDropout = dropout(0.5, beforeDropout, false, 0);
afterDropout.tensor.print();

tf.test_util.expectArraysClose(afterDropout.tensor.dataSync(),
[2, 4, 0, 8]);
expect(afterDropout.dimNames).toEqual(
['input']);
});

it('Deterministic output', () => {
const beforeDropout =
gtensor.makeRange('input', 1, 5, 1, 'float32');
const afterDropout = dropout(0.5, beforeDropout, true, 0);
afterDropout.tensor.print();

tf.test_util.expectArraysClose(afterDropout.tensor.dataSync(),
[1, 2, 3, 4]);
expect(afterDropout.dimNames).toEqual(
['input']);
});

it('Dropout with noise shape', () => {
const beforeDropout = new gtensor.GTensor(
tf.tensor([
[
[1, 2, 3, 4],
[5, 6, 7, 8],
[9, 10, 11, 12],
]
]),
['batch', 'pos', 'inputRep']
);
const afterDropout = dropout(0.5, beforeDropout, false, 1, ['pos']);
afterDropout.tensor.print();

tf.test_util.expectArraysClose(afterDropout.tensor.dataSync(),
[
[
[2, 4, 6, 8],
[0, 0, 0, 0],
[0, 0, 0, 0],
]
]);
expect(afterDropout.dimNames).toEqual(
['batch', 'pos', 'inputRep']);
});

// TODO(@aliciafmachado): Test that grads are not applied on de-activated neurons.
});
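To make the noise-shape test above concrete: the input has shape [1, 3, 4] with dims ['batch', 'pos', 'inputRep'], and passing dimNames = ['pos'] produces a noise shape of [1, 3, 1]. One Bernoulli draw is then made per position and broadcast across batch and inputRep, which is why each row is either zeroed whole or kept and scaled by 1 / (1 - 0.5) = 2. An equivalent call in plain tfjs, for illustration:

import * as tf from '@tensorflow/tfjs';

const x = tf.tensor([[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]]]);
// Noise shape [1, 3, 1]: an independent mask per 'pos', broadcast over the rest.
tf.dropout(x, 0.5, [1, 3, 1], 1).print();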
animated-transformer/src/lib/transformer/dropout.ts (new file: 48 additions, 0 deletions)
@@ -0,0 +1,48 @@
/* Copyright 2023 Google LLC. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

import {
GTensor,
} from '../gtensor/gtensor';

import {dropout as tf_dropout} from '@tensorflow/tfjs';

// Wrapper around tf.js's dropout that preserves GTensor dimension names.
export function dropout<G extends string, D extends G>(
dropoutRate: number,
g: GTensor<G>,
deterministic: boolean,
Review comment from a Collaborator on the `deterministic` parameter (a conversation
on the lines below was later marked as resolved by aliciafmachado):

    Let's remove deterministic, and just check if rate is 0.

seed?: number,
dimNames?: string[],
): GTensor<G> {
if (deterministic) {
return g;
}

// By default every element gets an independent dropout mask (the noise shape
// is the tensor's full shape). If dimNames is given, only the named dimensions
// get independent draws; every other dimension is set to 1 in the noise shape,
// so the same mask is broadcast along it.
let dimensions: number[] = g.tensor.shape;
if (dimNames) {
dimensions = [];
for (const d of g.dimNames) {
if (dimNames.includes(d)) {
dimensions.push(g.dim[d].size);
} else {
dimensions.push(1);
}
}
}
return new GTensor(tf_dropout(g.tensor, dropoutRate, dimensions, seed), g.dimNames);
}
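A short usage sketch for the wrapper above (assuming the GTensor constructor used in the tests; the actual transformer call sites are elsewhere in this PR):

import * as tf from '@tensorflow/tfjs';
import { GTensor } from '../gtensor/gtensor';
import { dropout } from './dropout';

// Per-element dropout on a [batch, pos, inputRep] activation tensor, seeded
// for reproducibility; with deterministic=true the input is returned as-is.
const acts = new GTensor(tf.randomNormal([2, 3, 4]), ['batch', 'pos', 'inputRep']);
const trainActs = dropout(0.1, acts, /* deterministic= */ false, 42);
const evalActs = dropout(0.1, acts, /* deterministic= */ true);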

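The reviewer's suggestion on the deterministic parameter would fold the eval-mode check into the rate itself. A sketch of that variant, under the same assumptions as the wrapper above (not necessarily the code that was finally merged):

import { GTensor } from '../gtensor/gtensor';
import { dropout as tf_dropout } from '@tensorflow/tfjs';

export function dropoutByRate<G extends string>(
  dropoutRate: number,
  g: GTensor<G>,
  seed?: number,
  dimNames?: string[],
): GTensor<G> {
  // Rate 0 doubles as deterministic mode: return the input untouched.
  if (dropoutRate === 0) {
    return g;
  }
  // Same noise-shape logic as above: named dims get independent masks,
  // all other dims share a broadcast mask.
  const dimensions = g.dimNames.map((d) =>
    !dimNames || dimNames.includes(d) ? g.dim[d].size : 1
  );
  return new GTensor(tf_dropout(g.tensor, dropoutRate, dimensions, seed), g.dimNames);
}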
@@ -24,6 +24,7 @@ describe('GTensor Transformers', () => {
it('basic transformer shapes', () => {
const spec: AttnHeadComputeSpec = {
residuals: true,
dropoutRate: 0.0
};
const paramSizes: AttnHeadParamSpec = {
inputRep: 2,