
Add support to dropout. #29

Merged: 8 commits merged into PAIR-code:main on Oct 16, 2024

Conversation

@aliciafmachado (Collaborator) commented on Sep 8, 2024

  • Add dropout.
  • Add basic tests for dropout and trainer with dropout enabled.
  • Add example in app.
  • Add evalMode flag so that dropout can be disabled during eval.

Intended to resolve Issue: #1
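
To make the description above concrete, here is a minimal, generic sketch of inverted dropout with an evalMode switch, written against plain tf.Tensor rather than this PR's GTensor-based code (the function below is illustrative, not the repo's implementation):

import * as tf from '@tensorflow/tfjs';

// Generic inverted-dropout sketch: zero each element with probability
// dropoutRate and rescale the survivors so the expected activation is
// unchanged; evalMode disables dropout entirely.
function dropout(x: tf.Tensor, dropoutRate: number, evalMode = false): tf.Tensor {
  if (evalMode || dropoutRate === 0) {
    return x;
  }
  const keepProb = 1 - dropoutRate;
  // Bernoulli mask: 1 with probability keepProb, 0 otherwise.
  const mask = tf.randomUniform(x.shape).less(keepProb).cast('float32');
  return x.mul(mask).div(keepProb);
}

(TensorFlow.js also provides a built-in tf.dropout op; whether the PR uses it or a hand-rolled mask is not shown here.)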

@aliciafmachado changed the title from "Add more tests to dropout and pass flag to computeTransformer to disable dropout during evaluation." to "Add support to dropout." on Sep 8, 2024
@aliciafmachado (Collaborator, PR author) left a comment
There are some commits going back and forth on a few things that I did not fully understand at first, so feel free to squash the commits before merging to main to avoid confusion. Otherwise, I can recreate the pull request and fix the commit history.

I also have a few questions / discussion topics:

  1. I added support for dropout, but we need something to manage random seeds so that we can seed it properly. Should we create an issue for that?
  2. I tried to add dropout following the T5 architecture, but I decided not to add it after the FF layer or on the output. For the FF, I don't think it makes sense since we have a single layer and we already apply dropout before the residual connection after the FF network. For the output, there are no additional computations after leaving the stack, so adding another dropout there would only increase noise (I also did not see an additional dropout on the output in the Haiku implementation linked in the issue for adding dropout). See the placement sketch below.
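
As a rough illustration of the placement described in point 2 (using plain tf.Tensor and tf.dropout for brevity; the function and parameter names here are illustrative, not the repo's API):

import * as tf from '@tensorflow/tfjs';

// T5-style placement sketch: dropout is applied to a sub-layer's output just
// before the residual add, and no extra dropout is applied to the final stack
// output.
function residualSubLayer(
  x: tf.Tensor,
  subLayer: (t: tf.Tensor) => tf.Tensor,  // e.g. the attention or FF block
  dropoutRate: number,
  evalMode: boolean
): tf.Tensor {
  const y = subLayer(x);
  const yDropped = evalMode ? y : tf.dropout(y, dropoutRate);
  return x.add(yDropped);  // residual connection after dropout
}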

@aliciafmachado marked this pull request as ready for review on September 8, 2024 at 16:06
@iislucas (Collaborator) left a comment
Looks great, a few small things.

@@ -225,13 +229,20 @@ function gelu(x: tf.Tensor) {
 export function computeAttnHead(
   spec: AttnHeadComputeSpec,
   params: AttnHeadParams<TensorKind>,
-  seqInput: GTensor<'batch' | 'pos' | 'inputRep'>
+  seqInput: GTensor<'batch' | 'pos' | 'inputRep'>,
+  evalMode: boolean = false
@iislucas (Collaborator) commented:
Let's drop the evalMode flag and just depend on the spec having dropoutRate set differently at training vs. eval time.
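
For illustration, the caller-side difference would look roughly like this (the computeSpec field names are taken from the test spec shown later in this review):

// With the spec-driven approach there is no evalMode parameter: training and
// eval simply pass different compute specs, and a dropoutRate of 0 means "off".
const trainComputeSpec = { residuals: true, dropoutRate: 0.1 };
const evalComputeSpec = { residuals: true, dropoutRate: 0.0 };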

export function dropout<G extends string, D extends G>(
  dropoutRate: number,
  g: GTensor<G>,
  deterministic: boolean,
@iislucas (Collaborator) commented:
Let's remove deterministic and just check whether the rate is 0.

let unNormedSeqOuput = inputToFF
  .contract(ff.w, ['inputRepToFF'])
  .pointwiseAdd(ff.bIn)
  .applyPointWiseTfFn(gelu)
  .pointwiseAdd(ff.bOut);

// Dropout before layer norm and residual connection.
let unNormedSeqOuputAfterDropout = unNormedSeqOuput;
@iislucas (Collaborator) commented:
Let's use this as the reference for where to put it for T5: https://github.com/Shivanandroy/simpleT5. And maybe name this function computeT5AttnHead; later we can make a GPT-2 one.

@aliciafmachado (Collaborator, PR author) replied:
@iislucas (Collaborator) replied:
sg!

const layerSpec: transformer.TransformerParamLayerSpec = {
  nHeads: 1,
  hasPosEncoding: true,
  computeSpec: { residuals: true, dropoutRate: 0.1 },
@iislucas (Collaborator) commented on Sep 12, 2024:
Maybe also add a test for a dropout rate of 1, and then check that the loss doesn't decrease.
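
A hypothetical Jasmine-style sketch of such a test (the trainer helpers below are placeholders, not the repo's actual API):

// With a dropout rate of 1 every activation is dropped, so a single SGD step
// should not be able to reduce the loss.
it('does not decrease the loss when dropoutRate is 1', () => {
  const trainer = makeTrainerWithDropout(1.0);  // placeholder factory
  const lossBefore = trainer.computeLoss();     // placeholder method
  trainer.sgdStep();                            // placeholder method
  const lossAfter = trainer.computeLoss();      // placeholder method
  expect(lossAfter).toBeGreaterThanOrEqual(lossBefore);
});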

@aliciafmachado (Collaborator, PR author) replied:
Done.

@aliciafmachado (Collaborator, PR author) commented:
Will rebase once #36 is submitted and then pass a generator so that the dropout is reproducible; then you can take a second look, @iislucas.
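
As an illustration of one way seeding could work (this is just a sketch, not the generator mechanism from #36): successive dropout calls draw seeds from a shared deterministic stream, and tf.randomUniform accepts an explicit seed as its last argument.

import * as tf from '@tensorflow/tfjs';

// Illustration only: a deterministic seed stream shared by all dropout calls,
// so re-running training produces identical dropout masks.
class SeedStream {
  constructor(private nextSeed: number) {}
  next(): number {
    return this.nextSeed++;
  }
}

const seeds = new SeedStream(42);
const x = tf.ones([2, 4]);
const maskA = tf.randomUniform(x.shape, 0, 1, 'float32', seeds.next()).less(0.9);
const maskB = tf.randomUniform(x.shape, 0, 1, 'float32', seeds.next()).less(0.9);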

@iislucas (Collaborator) left a comment
Looks great, just a couple of minor things.

@aliciafmachado (Collaborator, PR author) commented:
I added the generator everywhere so that dropout can be seeded. I also addressed your comments and fixed a small bug. Please take a look.
Additionally, I realized two things:

  1. I'm not sure it makes sense to have a test with 0.99 dropout, because a single SGD step is highly sensitive: even without dropout, changing the seed means the loss does not necessarily go down. So we could either remove the test with dropout or just run one SGD step on it to make sure nothing is broken.
  2. If the dropout rate is very high (0.999), you sometimes get a NaN loss. I think it probably has to do with everything being dropped out, but I need to investigate a bit more. We won't use such a high dropout rate in any case, so I can create an issue to investigate this corner case and unblock the PR, if you agree.

@aliciafmachado (Collaborator, PR author) commented:
There's one additional thing: ideally we would pass a single dropout rate and use it everywhere, but the way it's set up we have to pass it in each layer and even outside the layers. I can try to think of a clean way to pass it only once, or I can open an issue and fix it in a future PR.
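
One possible shape for that cleanup, as a hypothetical sketch (the spec type below is simplified and is not the repo's actual TransformerParamLayerSpec):

// Copy a single top-level dropoutRate into every layer's computeSpec, so the
// rate is specified once rather than per layer.
interface SimpleLayerSpec {
  computeSpec: { residuals: boolean; dropoutRate: number };
}

function withDropoutRate(layerSpecs: SimpleLayerSpec[], dropoutRate: number): SimpleLayerSpec[] {
  return layerSpecs.map(spec => ({
    ...spec,
    computeSpec: { ...spec.computeSpec, dropoutRate },
  }));
}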

@iislucas (Collaborator) left a comment
Thanks! LGTM. We can discuss improvements more later, but this is great.

@aliciafmachado merged commit d2ccc7b into PAIR-code:main on Oct 16, 2024
1 check passed