This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Add option to use scheduled sampling in CopyNet #309

Merged
11 commits merged into allenai:main on Dec 13, 2021

Conversation

@JohnGiorgi (Contributor) commented Nov 22, 2021

This PR adds the ability to use scheduled sampling in CopyNetSeq2Seq by supplying a scheduled_sampling_ratio argument greater than zero. The implementation is essentially copied from SimpleSeq2Seq.

This helps reduce the differences between the SimpleSeq2Seq and CopyNetSeq2Seq model arguments. It is also backwards compatible, with a default value of 0 (no scheduled sampling, i.e. teacher forcing).
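
For reference, a minimal sketch of the per-timestep decision scheduled sampling makes (a hypothetical standalone helper for illustration; the actual change lives in the model's decoder loop):

import torch

def choose_decoder_input(
    training: bool,
    scheduled_sampling_ratio: float,
    gold_tokens: torch.Tensor,       # shape: (batch_size,)
    last_predictions: torch.Tensor,  # shape: (batch_size,)
) -> torch.Tensor:
    # During training, feed the model its own previous predictions with
    # probability scheduled_sampling_ratio; otherwise use the gold
    # (teacher-forced) tokens. A ratio of 0 is plain teacher forcing.
    if training and torch.rand(1).item() < scheduled_sampling_ratio:
        return last_predictions
    return gold_tokens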

@epwalsh (Member) left a comment


Thanks @JohnGiorgi! I think this is a good addition. I just have a couple of suggestions:

  • I think we should have a test
  • I think the default for scheduled_sampling_ratio should be None. And then when it is None we shouldn't call torch.rand(). That way there is no performance penalty for this feature.

@JohnGiorgi (Contributor, Author) commented Dec 10, 2021

@epwalsh Awesome, thanks for the feedback.

  • Updated the if statement so that torch.rand is not called when scheduled_sampling_ratio is falsy. I could set the default to None, but that would make the default value for this parameter different from simple_seq2seq's.
  • Added a simple test that the model can train with 0 < scheduled_sampling_ratio < 1. I wasn't really sure how to write a more significant test due to the randomness. Let me know if you have a different test in mind!

I could also update simple_seq2seq so that torch.rand is not called when scheduled_sampling_ratio is falsy, to save a tiny bit of performance there! It currently calls torch.rand even for the default scheduled_sampling_ratio of 0.0.
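
As an illustration of why the falsy check is enough, Python's "and" short-circuits, so torch.rand is never reached when the ratio is 0.0 (illustrative variable names, not the actual model code):

import torch

scheduled_sampling_ratio = 0.0  # the default: plain teacher forcing

# "and" short-circuits: the right-hand operand is only evaluated when
# the ratio is non-zero, so the default configuration never calls
# torch.rand at all.
use_prediction = (
    scheduled_sampling_ratio > 0.0
    and torch.rand(1).item() < scheduled_sampling_ratio
)
assert use_prediction is False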

# Use gold tokens at test time and at a rate of 1 - _scheduled_sampling_ratio
# during training.
# shape: (batch_size,)
input_choices = last_predictions
@JohnGiorgi (Contributor, Author) commented Dec 10, 2021


Ah, I realized this implementation doesn't work, because last_predictions is never updated. I would have had to take the index of the token with the highest probability for this timestep under the model. Something like:

# torch.max over a dim returns (values, indices); we want the indices
_, last_predictions = torch.max(torch.cat((generation_scores, copy_scores), -1), dim=-1)

@epwalsh does this make sense?

@epwalsh (Member)


Hmm, yeup, good catch. To avoid duplicate computation you could use all_scores from the _get_ll_contrib() method. And note that you will need to take into account this mask. So I suggest returning all_scores and mask from _get_ll_contrib so you can use them here.

@JohnGiorgi (Contributor, Author)


Gotcha. Could I just return log_probs from _get_ll_contrib()? It's computed like: log_probs = util.masked_log_softmax(all_scores, mask).

@epwalsh (Member)


Yes, good point.

@JohnGiorgi (Contributor, Author)


Awesome, just pushed that change.
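
For reference, a rough sketch of the shape of that change, assuming _get_ll_contrib() now also returns log_probs = util.masked_log_softmax(all_scores, mask):

import torch

def update_last_predictions(log_probs: torch.Tensor) -> torch.Tensor:
    # log_probs is assumed to have shape
    # (batch_size, target_vocab_size + source_sequence_length), i.e. the
    # masked log-softmax over the concatenated generation and copy scores.
    # shape: (batch_size,) -- index of the most likely token per example.
    _, last_predictions = log_probs.max(dim=-1)
    return last_predictions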

allennlp_models/generation/models/copynet_seq2seq.py (outdated; resolved)
tests/generation/models/copynet_test.py (resolved)
Comment on lines +49 to +55
def test_model_can_train_with_scheduled_sampling_ratio(self):
    train_model_from_file(
        self.param_file,
        self.TEST_DIR,
        overrides="{'model.scheduled_sampling_ratio':0.5}",
    )

@JohnGiorgi (Contributor, Author) commented Dec 13, 2021


@epwalsh Added the same test for scheduled sampling to simple_seq2seq.

Comment on lines +370 to +374
if (
    self.training
    and self._scheduled_sampling_ratio > 0.0
    and torch.rand(1).item() < self._scheduled_sampling_ratio
):
@JohnGiorgi (Contributor, Author)


@epwalsh Added a similar condition to simple_seq2seq to avoid the call to torch.rand when _scheduled_sampling_ratio is 0.0.

@epwalsh (Member) left a comment


This LGTM! Can you just update the CHANGELOG? Then I think this is good to go.

@JohnGiorgi (Contributor, Author)

Cool! Changelog updated 👍

@epwalsh epwalsh enabled auto-merge (squash) December 13, 2021 17:57
@epwalsh epwalsh merged commit 4866862 into allenai:main Dec 13, 2021