diff --git a/CHANGELOG.md b/CHANGELOG.md
index 22f351856..3d48d7dc1 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -37,6 +37,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Added a configuration to train on the PIQA dataset with AllenNLP Tango.
 - Added a transformer classification model.
 - Added a configuration to train on the IMDB dataset with AllenNLP Tango.
+- Added `scheduled_sampling_ratio` argument to `CopyNetSeq2Seq` to use scheduled sampling during training.
 
 ### Fixed
 
diff --git a/allennlp_models/generation/models/copynet_seq2seq.py b/allennlp_models/generation/models/copynet_seq2seq.py
index 758c47fe5..380523c7f 100644
--- a/allennlp_models/generation/models/copynet_seq2seq.py
+++ b/allennlp_models/generation/models/copynet_seq2seq.py
@@ -56,6 +56,14 @@ class CopyNetSeq2Seq(Model):
         This is used to during inference to select the tokens of the decoded output sequence.
     target_embedding_dim : `int`, optional (default = `30`)
         The size of the embeddings for the target vocabulary.
+    scheduled_sampling_ratio : `float`, optional (default = `0.`)
+        At each timestep during training, we sample a random number between 0 and 1, and if it is
+        not less than this value, we use the ground truth labels for the whole batch. Else, we use
+        the predictions from the previous time step for the whole batch. If this value is 0.0
+        (default), this corresponds to teacher forcing, and if it is 1.0, it corresponds to not
+        using target side ground truth labels. See the following paper for more information:
+        [Scheduled Sampling for Sequence Prediction with Recurrent Neural Networks. Bengio et al.,
+        2015](https://arxiv.org/abs/1506.03099).
     copy_token : `str`, optional (default = `'@COPY@'`)
         The token used to indicate that a target token was copied from the source.
         If this token is not already in your target vocabulary, it will be added.
@@ -83,6 +91,7 @@ def __init__(
         label_smoothing: float = None,
         beam_search: Lazy[BeamSearch] = Lazy(BeamSearch),
         target_embedding_dim: int = 30,
+        scheduled_sampling_ratio: float = 0.0,
         copy_token: str = "@COPY@",
         target_namespace: str = "target_tokens",
         tensor_based_metric: Metric = None,
@@ -92,6 +101,7 @@ def __init__(
     ) -> None:
         super().__init__(vocab)
         self._target_namespace = target_namespace
+        self._scheduled_sampling_ratio = scheduled_sampling_ratio
         self._start_index = self.vocab.get_token_index(START_SYMBOL, self._target_namespace)
         self._end_index = self.vocab.get_token_index(END_SYMBOL, self._target_namespace)
         self._oov_index = self.vocab.get_token_index(self.vocab._oov_token, self._target_namespace)
@@ -391,7 +401,7 @@ def _get_ll_contrib(
         target_tokens: torch.Tensor,
         target_to_source: torch.Tensor,
         source_mask: torch.BoolTensor,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         """
         Get the log-likelihood contribution from a single timestep.
 
@@ -474,7 +484,7 @@ def _get_ll_contrib(
         # shape: (batch_size,)
         step_log_likelihood = util.logsumexp(combined_gen_and_copy)
 
-        return step_log_likelihood, selective_weights
+        return step_log_likelihood, selective_weights, log_probs
 
     def _forward_loss(
         self,
@@ -514,13 +524,33 @@ def _forward_loss(
             (batch_size, self._target_vocab_size), fill_value=1.0, dtype=torch.bool
         )
 
+        # Initialize target predictions with the start index.
+        # shape: (batch_size,)
+        last_predictions = source_mask.new_full(
+            (batch_size,), fill_value=self._start_index, dtype=torch.long
+        )
+
         step_log_likelihoods = []
         for timestep in range(num_decoding_steps):
-            # shape: (batch_size,)
-            input_choices = target_tokens["tokens"]["tokens"][:, timestep]
-            # Get mask tensor indicating which instances were copied.
-            # shape: (batch_size,)
-            copied = ((input_choices == self._oov_index) & (target_to_source.sum(-1) > 0)).long()
+            if (
+                self.training
+                and self._scheduled_sampling_ratio > 0.0
+                and torch.rand(1).item() < self._scheduled_sampling_ratio
+            ):
+                # Use gold tokens at test time and at a rate of 1 - _scheduled_sampling_ratio
+                # during training.
+                # shape: (batch_size,)
+                input_choices = last_predictions
+                # Get mask tensor indicating which instances were copied.
+                # shape: (batch_size,)
+                copied = (input_choices >= self._target_vocab_size).long()
+            else:
+                # shape: (batch_size,)
+                input_choices = target_tokens["tokens"]["tokens"][:, timestep]
+                # shape: (batch_size,)
+                copied = (
+                    (input_choices == self._oov_index) & (target_to_source.sum(-1) > 0)
+                ).long()
             # shape: (batch_size,)
             input_choices = input_choices * (1 - copied) + copy_input_choices * copied
             # shape: (batch_size, source_sequence_length)
@@ -538,7 +568,7 @@ def _forward_loss(
             copy_scores = self._get_copy_scores(state)
             # shape: (batch_size,)
             step_target_tokens = target_tokens["tokens"]["tokens"][:, timestep + 1]
-            step_log_likelihood, selective_weights = self._get_ll_contrib(
+            step_log_likelihood, selective_weights, log_probs = self._get_ll_contrib(
                 generation_scores,
                 generation_scores_mask,
                 copy_scores,
@@ -547,6 +577,8 @@ def _forward_loss(
                 source_mask,
             )
             step_log_likelihoods.append(step_log_likelihood.unsqueeze(1))
+            # shape (predicted_classes): (batch_size,)
+            _, last_predictions = torch.max(log_probs, 1)
 
         # Gather step log-likelihoods.
         # shape: (batch_size, num_decoding_steps = target_sequence_length - 1)
diff --git a/allennlp_models/generation/models/simple_seq2seq.py b/allennlp_models/generation/models/simple_seq2seq.py
index cc73d4cfc..9458ad8ff 100644
--- a/allennlp_models/generation/models/simple_seq2seq.py
+++ b/allennlp_models/generation/models/simple_seq2seq.py
@@ -367,7 +367,11 @@ def _forward_loop(
         step_logits: List[torch.Tensor] = []
         step_predictions: List[torch.Tensor] = []
         for timestep in range(num_decoding_steps):
-            if self.training and torch.rand(1).item() < self._scheduled_sampling_ratio:
+            if (
+                self.training
+                and self._scheduled_sampling_ratio > 0.0
+                and torch.rand(1).item() < self._scheduled_sampling_ratio
+            ):
                 # Use gold tokens at test time and at a rate of 1 - _scheduled_sampling_ratio
                 # during training.
                 # shape: (batch_size,)
diff --git a/tests/generation/models/copynet_test.py b/tests/generation/models/copynet_test.py
index c857c2ac7..86b66b3ca 100644
--- a/tests/generation/models/copynet_test.py
+++ b/tests/generation/models/copynet_test.py
@@ -44,6 +44,13 @@ def test_model_can_train_with_amp(self):
             overrides="{'trainer.use_amp':true,'trainer.cuda_device':0}",
         )
 
+    def test_model_can_train_with_scheduled_sampling_ratio(self):
+        train_model_from_file(
+            self.param_file,
+            self.TEST_DIR,
+            overrides="{'model.scheduled_sampling_ratio':0.5}",
+        )
+
     def test_vocab(self):
         vocab = self.model.vocab
         assert vocab.get_vocab_size(self.model._target_namespace) == 8
@@ -133,7 +140,7 @@ def test_get_ll_contrib(self):
         generation_scores_mask = generation_scores.new_full(
             generation_scores.size(), True, dtype=torch.bool
         )
-        ll_actual, selective_weights_actual = self.model._get_ll_contrib(
+        ll_actual, selective_weights_actual, _ = self.model._get_ll_contrib(
             generation_scores,
             generation_scores_mask,
             copy_scores,
diff --git a/tests/generation/models/simple_seq2seq_test.py b/tests/generation/models/simple_seq2seq_test.py
index 7b48ff863..28aa8922b 100644
--- a/tests/generation/models/simple_seq2seq_test.py
+++ b/tests/generation/models/simple_seq2seq_test.py
@@ -46,6 +46,13 @@ def test_model_can_train_with_amp(self):
             overrides="{'trainer.use_amp':true,'trainer.cuda_device':0}",
         )
 
+    def test_model_can_train_with_scheduled_sampling_ratio(self):
+        train_model_from_file(
+            self.param_file,
+            self.TEST_DIR,
+            overrides="{'model.scheduled_sampling_ratio':0.5}",
+        )
+
     def test_bidirectional_model_can_train_save_and_load(self):
         param_overrides = json.dumps({"model.encoder.bidirectional": True})
         self.ensure_model_can_train_save_and_load(
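
For reference, the decision this patch adds to `CopyNetSeq2Seq._forward_loss` (and tightens in `SimpleSeq2Seq._forward_loop`) is the standard scheduled-sampling coin flip: at each training timestep, with probability `scheduled_sampling_ratio`, the decoder is fed its own greedy prediction from the previous step instead of the gold token. The sketch below is illustrative only and not part of the patch; it ignores CopyNet's copy-vs-generate bookkeeping, and `decoder_step` and `gold_tokens` are hypothetical stand-ins for the model's decoder step and target tensor.

    import torch

    def choose_decoder_inputs(
        gold_tokens: torch.Tensor,       # shape: (batch_size, target_length), dtype long
        start_index: int,                # index of the start symbol in the target vocab
        scheduled_sampling_ratio: float,
        decoder_step,                    # hypothetical: (batch_size,) -> (batch_size, vocab_size) log-probs
        training: bool = True,
    ):
        """Return the input tokens chosen at each decoding timestep."""
        batch_size, target_length = gold_tokens.size()
        # Initialize "previous predictions" with the start symbol, as the patch does.
        last_predictions = gold_tokens.new_full((batch_size,), fill_value=start_index)
        chosen_inputs = []
        for timestep in range(target_length - 1):
            if (
                training
                and scheduled_sampling_ratio > 0.0
                and torch.rand(1).item() < scheduled_sampling_ratio
            ):
                # Scheduled sampling: feed back the model's own predictions for the whole batch.
                input_choices = last_predictions
            else:
                # Teacher forcing: feed the gold token for this timestep.
                input_choices = gold_tokens[:, timestep]
            chosen_inputs.append(input_choices)
            log_probs = decoder_step(input_choices)
            # Greedy prediction, used as the candidate input for the next timestep.
            _, last_predictions = torch.max(log_probs, 1)
        return chosen_inputs

As the new tests above show, the feature can be enabled from a config override such as {'model.scheduled_sampling_ratio': 0.5}; the default of 0.0 keeps the existing teacher-forcing behaviour.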