This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Add option to use scheduled sampling in CopyNet #309

Merged (11 commits) on Dec 13, 2021
48 changes: 40 additions & 8 deletions allennlp_models/generation/models/copynet_seq2seq.py
@@ -56,6 +56,14 @@ class CopyNetSeq2Seq(Model):
    This is used during inference to select the tokens of the decoded output sequence.
target_embedding_dim : `int`, optional (default = `30`)
The size of the embeddings for the target vocabulary.
scheduled_sampling_ratio : `float`, optional (default = `0.`)
At each timestep during training, we sample a random number between 0 and 1, and if it is
not less than this value, we use the ground truth labels for the whole batch. Else, we use
the predictions from the previous time step for the whole batch. If this value is 0.0
(default), this corresponds to teacher forcing, and if it is 1.0, it corresponds to not
using target side ground truth labels. See the following paper for more information:
[Scheduled Sampling for Sequence Prediction with Recurrent Neural Networks. Bengio et al.,
2015](https://arxiv.org/abs/1506.03099).
copy_token : `str`, optional (default = `'@COPY@'`)
The token used to indicate that a target token was copied from the source.
If this token is not already in your target vocabulary, it will be added.
@@ -83,6 +91,7 @@ def __init__(
label_smoothing: float = None,
beam_search: Lazy[BeamSearch] = Lazy(BeamSearch),
target_embedding_dim: int = 30,
scheduled_sampling_ratio: float = 0.0,
copy_token: str = "@COPY@",
target_namespace: str = "target_tokens",
tensor_based_metric: Metric = None,
@@ -92,6 +101,7 @@
) -> None:
super().__init__(vocab)
self._target_namespace = target_namespace
self._scheduled_sampling_ratio = scheduled_sampling_ratio
self._start_index = self.vocab.get_token_index(START_SYMBOL, self._target_namespace)
self._end_index = self.vocab.get_token_index(END_SYMBOL, self._target_namespace)
self._oov_index = self.vocab.get_token_index(self.vocab._oov_token, self._target_namespace)
@@ -391,7 +401,7 @@ def _get_ll_contrib(
target_tokens: torch.Tensor,
target_to_source: torch.Tensor,
source_mask: torch.BoolTensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Get the log-likelihood contribution from a single timestep.

@@ -474,7 +484,7 @@ def _get_ll_contrib(
# shape: (batch_size,)
step_log_likelihood = util.logsumexp(combined_gen_and_copy)

return step_log_likelihood, selective_weights
return step_log_likelihood, selective_weights, log_probs

def _forward_loss(
self,
@@ -514,13 +524,33 @@ def _forward_loss(
(batch_size, self._target_vocab_size), fill_value=1.0, dtype=torch.bool
)

# Initialize target predictions with the start index.
# shape: (batch_size,)
last_predictions = source_mask.new_full(
(batch_size,), fill_value=self._start_index, dtype=torch.long
)

step_log_likelihoods = []
for timestep in range(num_decoding_steps):
# shape: (batch_size,)
input_choices = target_tokens["tokens"]["tokens"][:, timestep]
# Get mask tensor indicating which instances were copied.
# shape: (batch_size,)
copied = ((input_choices == self._oov_index) & (target_to_source.sum(-1) > 0)).long()
if (
self.training
and self._scheduled_sampling_ratio > 0.0
and torch.rand(1).item() < self._scheduled_sampling_ratio
):
# Use gold tokens at test time and at a rate of 1 - _scheduled_sampling_ratio
# during training.
# shape: (batch_size,)
input_choices = last_predictions
@JohnGiorgi (Contributor, Author), Dec 10, 2021:

Ah, I realized this implementation doesn't work, because last_predictions is never updated. I would have had to take the index of the token with the highest probability for this timestep under the model. Something like:

last_predictions = torch.max(torch.cat((generation_scores, copy_scores), -1), -1)

@epwalsh does this make sense?

@epwalsh (Member):

Hmm, yeup, good catch. To avoid duplicate computation you could use all_scores from the _get_ll_contrib() method. And note that you will need to take this mask into account. So I suggest returning all_scores and mask from _get_ll_contrib so you can use them here.

@JohnGiorgi (Contributor, Author):

Gotcha. Could I just return log_probs from _get_ll_contrib()? It's computed like: log_probs = util.masked_log_softmax(all_scores, mask).

@epwalsh (Member):

Yes, good point.

@JohnGiorgi (Contributor, Author):

Awesome, just pushed that change.
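
To make the resolved approach concrete, here is a minimal standalone sketch (not the PR's exact code) of how the next decoder input could be derived from the log_probs returned by _get_ll_contrib(). The shape assumption and the copy_index argument are illustrative:

```python
import torch

def next_input_choices(
    log_probs: torch.Tensor,  # assumed shape: (batch_size, target_vocab_size + source_sequence_length)
    target_vocab_size: int,
    copy_index: int,  # index of the @COPY@ token in the target vocabulary
) -> torch.Tensor:
    # Pick the highest-probability class per instance: either a generated target-vocab
    # token or a copied source position (offset by target_vocab_size).
    _, last_predictions = torch.max(log_probs, 1)
    # Indices at or beyond the target vocabulary size indicate a copied source token.
    copied = (last_predictions >= target_vocab_size).long()
    # Feed the special copy token back to the decoder for copied instances.
    return last_predictions * (1 - copied) + copy_index * copied
```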

# Get mask tensor indicating which instances were copied.
# shape: (batch_size,)
copied = (input_choices >= self._target_vocab_size).long()
else:
# shape: (batch_size,)
input_choices = target_tokens["tokens"]["tokens"][:, timestep]
# shape: (batch_size,)
copied = (
(input_choices == self._oov_index) & (target_to_source.sum(-1) > 0)
).long()
# shape: (batch_size,)
input_choices = input_choices * (1 - copied) + copy_input_choices * copied
# shape: (batch_size, source_sequence_length)
@@ -538,7 +568,7 @@ def _forward_loss(
copy_scores = self._get_copy_scores(state)
# shape: (batch_size,)
step_target_tokens = target_tokens["tokens"]["tokens"][:, timestep + 1]
step_log_likelihood, selective_weights = self._get_ll_contrib(
step_log_likelihood, selective_weights, log_probs = self._get_ll_contrib(
generation_scores,
generation_scores_mask,
copy_scores,
@@ -547,6 +577,8 @@
source_mask,
)
step_log_likelihoods.append(step_log_likelihood.unsqueeze(1))
# shape (predicted_classes): (batch_size,)
_, last_predictions = torch.max(log_probs, 1)

# Gather step log-likelihoods.
# shape: (batch_size, num_decoding_steps = target_sequence_length - 1)
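As a quick reference for what the new branch above does, here is a minimal framework-free sketch of the per-timestep scheduled sampling decision; the function name and arguments are illustrative, not AllenNLP API:

```python
import torch

def choose_decoder_input(
    gold_tokens: torch.Tensor,       # shape: (batch_size,) ground-truth tokens for this timestep
    last_predictions: torch.Tensor,  # shape: (batch_size,) model predictions from the previous timestep
    scheduled_sampling_ratio: float,
    training: bool,
) -> torch.Tensor:
    # Only sample during training, and skip the torch.rand call entirely when the ratio is 0.0.
    if (
        training
        and scheduled_sampling_ratio > 0.0
        and torch.rand(1).item() < scheduled_sampling_ratio
    ):
        # With probability scheduled_sampling_ratio, feed the model its own predictions.
        return last_predictions
    # Otherwise (and always outside of training), use the ground-truth tokens.
    return gold_tokens
```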
6 changes: 5 additions & 1 deletion allennlp_models/generation/models/simple_seq2seq.py
@@ -367,7 +367,11 @@ def _forward_loop(
step_logits: List[torch.Tensor] = []
step_predictions: List[torch.Tensor] = []
for timestep in range(num_decoding_steps):
if self.training and torch.rand(1).item() < self._scheduled_sampling_ratio:
if (
self.training
and self._scheduled_sampling_ratio > 0.0
and torch.rand(1).item() < self._scheduled_sampling_ratio
):
Comment on lines +370 to +374

@JohnGiorgi (Contributor, Author): @epwalsh Added a similar condition to simple_seq2seq to avoid the call to torch.rand when _scheduled_sampling_ratio is 0.0.

# Use gold tokens at test time and at a rate of 1 - _scheduled_sampling_ratio
# during training.
# shape: (batch_size,)
9 changes: 8 additions & 1 deletion tests/generation/models/copynet_test.py
@@ -44,6 +44,13 @@ def test_model_can_train_with_amp(self):
overrides="{'trainer.use_amp':true,'trainer.cuda_device':0}",
)

def test_model_can_train_with_scheduled_sampling_ratio(self):
train_model_from_file(
self.param_file,
self.TEST_DIR,
overrides="{'model.scheduled_sampling_ratio':0.5}",
)

def test_vocab(self):
vocab = self.model.vocab
assert vocab.get_vocab_size(self.model._target_namespace) == 8
@@ -133,7 +140,7 @@ def test_get_ll_contrib(self):
generation_scores_mask = generation_scores.new_full(
generation_scores.size(), True, dtype=torch.bool
)
ll_actual, selective_weights_actual = self.model._get_ll_contrib(
ll_actual, selective_weights_actual, _ = self.model._get_ll_contrib(
generation_scores,
generation_scores_mask,
copy_scores,
7 changes: 7 additions & 0 deletions tests/generation/models/simple_seq2seq_test.py
@@ -46,6 +46,13 @@ def test_model_can_train_with_amp(self):
overrides="{'trainer.use_amp':true,'trainer.cuda_device':0}",
)

def test_model_can_train_with_scheduled_sampling_ratio(self):
train_model_from_file(
self.param_file,
self.TEST_DIR,
overrides="{'model.scheduled_sampling_ratio':0.5}",
)

Comment on lines +49 to +55

@JohnGiorgi (Contributor, Author), Dec 13, 2021: @epwalsh Added the same test for scheduled sampling to simple_seq2seq.

def test_bidirectional_model_can_train_save_and_load(self):
param_overrides = json.dumps({"model.encoder.bidirectional": True})
self.ensure_model_can_train_save_and_load(
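
For context, enabling the new option outside the test suite mirrors the overrides used in the tests above. A minimal sketch, with hypothetical config and output paths:

```python
from allennlp.commands.train import train_model_from_file

# Hypothetical paths: point these at your own CopyNet experiment config and output directory.
train_model_from_file(
    "experiments/copynet.jsonnet",
    "output/copynet_scheduled_sampling",
    overrides="{'model.scheduled_sampling_ratio': 0.5}",
)
```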