-
Notifications
You must be signed in to change notification settings - Fork 2
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add support to dropout. #29
Merged
Merged
Changes from 5 commits
Commits
Show all changes
8 commits
Select commit
Hold shift + click to select a range
d8ea0d8
Add dropout skeleton.
aliciafmachado 53f72f7
Add test for trainer when there is a Dropout layer.
aliciafmachado 07a1711
Make dropout a spec and not a parameter, and add a test.
aliciafmachado 70d0040
Add flag to disable dropout during evaluation.
aliciafmachado cf25062
Add more tests to dropout and pass flag to computeTransformer to disa…
aliciafmachado fa9b3f6
Improve dropout setup, and fix where the dropout is used in the code …
aliciafmachado bbbbd98
Merge branch 'PAIR-code:main' into main
aliciafmachado d49fe1f
Pass generator through TrainState and transformer functions. Fixed bu…
aliciafmachado File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
/* Copyright 2023 Google LLC. All Rights Reserved. | ||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
|
||
http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
==============================================================================*/ | ||
|
||
/* Dropout */ | ||
import * as tf from '@tensorflow/tfjs'; | ||
import { gtensor } from '..'; | ||
import { dropout } from './dropout'; | ||
|
||
describe('dropout', () => { | ||
|
||
it('Basic dropout', () => { | ||
const beforeDropout = | ||
gtensor.makeRange('input', 1, 5, 1, 'float32'); | ||
const afterDropout = dropout(0.5, beforeDropout, false, 0); | ||
afterDropout.tensor.print(); | ||
|
||
tf.test_util.expectArraysClose(afterDropout.tensor.dataSync(), | ||
[2, 4, 0, 8]); | ||
expect(afterDropout.dimNames).toEqual( | ||
['input']); | ||
}); | ||
|
||
it('Deterministic output', () => { | ||
const beforeDropout = | ||
gtensor.makeRange('input', 1, 5, 1, 'float32'); | ||
const afterDropout = dropout(0.5, beforeDropout, true, 0); | ||
afterDropout.tensor.print(); | ||
|
||
tf.test_util.expectArraysClose(afterDropout.tensor.dataSync(), | ||
[1, 2, 3, 4]); | ||
expect(afterDropout.dimNames).toEqual( | ||
['input']); | ||
}); | ||
|
||
it('Dropout with noise shape ', () => { | ||
const beforeDropout = new gtensor.GTensor( | ||
tf.tensor([ | ||
[ | ||
[1, 2, 3, 4], | ||
[5, 6, 7, 8], | ||
[9, 10, 11, 12], | ||
] | ||
]), | ||
['batch', 'pos', 'inputRep'] | ||
); | ||
const afterDropout = dropout(0.5, beforeDropout, false, 1, ['pos']); | ||
afterDropout.tensor.print(); | ||
|
||
tf.test_util.expectArraysClose(afterDropout.tensor.dataSync(), | ||
[ | ||
[ | ||
[2, 4, 6, 8], | ||
[0, 0, 0, 0], | ||
[0, 0, 0, 0], | ||
] | ||
]); | ||
expect(afterDropout.dimNames).toEqual( | ||
['batch', 'pos', 'inputRep']); | ||
}); | ||
|
||
// TODO(@aliciafmachado): Test that grads are not applied on de-activated neurons. | ||
}); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
/* Copyright 2023 Google LLC. All Rights Reserved. | ||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
|
||
http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
==============================================================================*/ | ||
|
||
import { | ||
GTensor, | ||
} from '../gtensor/gtensor'; | ||
|
||
import {dropout as tf_dropout} from '@tensorflow/tfjs'; | ||
|
||
// Wrapper for tf ts dropout. | ||
export function dropout<G extends string, D extends G>( | ||
dropoutRate: number, | ||
g: GTensor<G>, | ||
deterministic: boolean, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Lets remove deterministic, and just check if rate is 0. |
||
seed?: number, | ||
dimNames?: string[], | ||
aliciafmachado marked this conversation as resolved.
Show resolved
Hide resolved
|
||
): GTensor<G> { | ||
if (deterministic) { | ||
return g; | ||
} | ||
|
||
let dimensions: number[] = g.tensor.shape; | ||
if (dimNames) { | ||
dimensions = []; | ||
for (const d of g.dimNames) { | ||
if (dimNames.includes(d)) { | ||
dimensions = dimensions.concat(g.dim[d].size); | ||
} | ||
else { | ||
dimensions = dimensions.concat(1); | ||
} | ||
} | ||
} | ||
return new GTensor(tf_dropout(g.tensor, dropoutRate, dimensions, seed), g.dimNames); | ||
} | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe add one test also for dropout rate of 1, and then test that loss doesn't decrease.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.