Add support for dropout. #29
Conversation
…ble dropout during evaluation.
There are some commits going back and forth on a few things that I did not fully understand at first, so feel free to squash the commits before merging to main to avoid confusion. Otherwise I can recreate the pull request and fix the commit history.
I also have a few questions / discussion topics:
- I added support for dropout, but we need something to manage random seeds so that we can seed properly. Should we create an issue for that?
- I tried to add dropout based on the T5 architecture, but I decided not to add it after the FF layer or on the stack output. For the FF, it doesn't seem to make sense since we have a single layer and we already apply dropout before the residual connection after the FF network. For the output, there are no additional computations after leaving the stack, so another dropout there would only add noise (the haiku implementation linked in the issue to add dropout also has no extra dropout on the output). A sketch of this placement follows below.
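A minimal sketch of that placement, written against plain TensorFlow.js rather than this repo's GTensor wrapper (the function and argument names here are illustrative, not the project's actual API):

```ts
import * as tf from '@tensorflow/tfjs';

// Illustrative only: T5-style placement applies dropout to the sub-layer
// output *before* the residual add; there is no extra dropout after the
// final layer of the stack.
function residualWithDropout(
  x: tf.Tensor,
  subLayerOut: tf.Tensor,
  dropoutRate: number
): tf.Tensor {
  const dropped =
    dropoutRate > 0 ? tf.dropout(subLayerOut, dropoutRate) : subLayerOut;
  return tf.add(x, dropped);
}
```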
Looks great, a few small things.
```diff
@@ -225,13 +229,20 @@ function gelu(x: tf.Tensor) {
 export function computeAttnHead(
   spec: AttnHeadComputeSpec,
   params: AttnHeadParams<TensorKind>,
-  seqInput: GTensor<'batch' | 'pos' | 'inputRep'>
+  seqInput: GTensor<'batch' | 'pos' | 'inputRep'>,
+  evalMode: boolean = false
```
Let's drop the evalMode flag and just depend on the spec having dropoutRate set differently at training vs. inference time.
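A hedged sketch of what that could look like (field names taken from the computeSpec literal shown later in this review; the two-spec pattern is just one way to arrange it):

```ts
// Assumption: AttnHeadComputeSpec carries residuals and dropoutRate,
// as in the computeSpec literal used in the tests below.
const trainSpec = { residuals: true, dropoutRate: 0.1 };

// At inference time the caller passes a spec with dropoutRate 0, which
// disables dropout entirely -- no separate evalMode flag needed.
const evalSpec = { ...trainSpec, dropoutRate: 0 };
```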
```ts
export function dropout<G extends string, D extends G>(
  dropoutRate: number,
  g: GTensor<G>,
  deterministic: boolean,
```
Let's remove `deterministic` and just check whether the rate is 0.
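A minimal sketch of the suggested shape, again in plain TensorFlow.js rather than the GTensor API (the seed plumbing is illustrative):

```ts
import * as tf from '@tensorflow/tfjs';

// Sketch: no `deterministic` argument; a rate of 0 (the inference-time
// setting) is a no-op, so one code path serves training and inference.
function dropoutSketch(rate: number, x: tf.Tensor, seed?: number): tf.Tensor {
  if (rate === 0) {
    return x;
  }
  return tf.dropout(x, rate, undefined, seed);
}
```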
```ts
let unNormedSeqOuput = inputToFF
  .contract(ff.w, ['inputRepToFF'])
  .pointwiseAdd(ff.bIn)
  .applyPointWiseTfFn(gelu)
  .pointwiseAdd(ff.bOut);

// Dropout before layer norm and residual connection.
let unNormedSeqOuputAfterDropout = unNormedSeqOuput;
```
Let's use this: https://github.com/Shivanandroy/simpleT5 as the reference for where to put it for T5. And maybe name this function computeT5AttnHead; later we can make a GPT-2 one.
I suggest we use the T5 implementation in JAX (T5X): https://github.com/google-research/t5x/blob/705247b743d26a33d0c058b41c72ad030e51891b/t5x/examples/t5/network.py#L222
Sounds good!
```ts
const layerSpec: transformer.TransformerParamLayerSpec = {
  nHeads: 1,
  hasPosEncoding: true,
  computeSpec: { residuals: true, dropoutRate: 0.1 },
```
Maybe also add one test for a dropout rate of 1, and then check that the loss doesn't decrease.
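A hedged sketch of what such a test could look like (jasmine-style, to match the repo's *.spec.ts files; computeLoss and trainStep are hypothetical stand-ins for the trainer's actual API):

```ts
// Hypothetical helpers standing in for the real trainer API.
declare function computeLoss(spec: { residuals: boolean; dropoutRate: number }): number;
declare function trainStep(spec: { residuals: boolean; dropoutRate: number }): void;

it('does not learn when dropoutRate is 1', () => {
  // With rate 1 every activation is dropped, so training should not help.
  const spec = { residuals: true, dropoutRate: 1.0 };
  const lossBefore = computeLoss(spec);
  for (let i = 0; i < 10; i++) {
    trainStep(spec);
  }
  expect(computeLoss(spec)).not.toBeLessThan(lossBefore);
});
```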
Done.
…to follow the T5X implementation.
Looks great, just a couple of minor things.
I added the generator everywhere to be able to seed dropout. I also addressed your comments and fixed a small bug. Please take a look.

There's one additional thing: ideally we should pass a single dropout rate and use it everywhere, but with the current setup we have to pass it to each layer and even outside the stack. I can try to think of a clean way to pass it just once, or I can open an issue and fix it in a future PR. A sketch of the single-rate idea follows below.
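A hedged sketch of the single-rate idea (the type and field names are illustrative, not the project's actual spec types):

```ts
// Illustrative only: one top-level dropoutRate threaded into each
// per-layer compute spec, instead of being repeated by every caller.
interface TransformerSpecSketch {
  dropoutRate: number;
  layers: { nHeads: number; hasPosEncoding: boolean }[];
}

function layerComputeSpecs(spec: TransformerSpecSketch) {
  return spec.layers.map(() => ({
    residuals: true,
    dropoutRate: spec.dropoutRate,
  }));
}
```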
Thanks! LGTM. We can discuss improvements more later, but this is great.
Intended to resolve Issue: #1