This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Implementing abstraction to score final sequences in BeamSearch #5208

Merged · 7 commits · May 18, 2021
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -29,6 +29,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
See [PR #5172](https://github.com/allenai/allennlp/pull/5172) for more details.
- Added `SpanExtractorWithSpanWidthEmbedding`, which moves extractor-specific span embedding computations into the `_embed_spans` method and keeps the shared argument-handling code in `SpanExtractorWithSpanWidthEmbedding` itself; `BidirectionalEndpointSpanExtractor`, `EndpointSpanExtractor`, and `SelfAttentiveSpanExtractor` were modified accordingly. Now `SelfAttentiveSpanExtractor` can also embed span widths.
- Added a `min_steps` parameter to `BeamSearch` to set a minimum length for the predicted sequences.
- Added the `FinalSequenceScorer` abstraction to calculate the final scores of the generated sequences in `BeamSearch`.

### Fixed

120 changes: 117 additions & 3 deletions allennlp/nn/beam_search.py
@@ -431,6 +431,99 @@ def gumbel_with_max(self, phi, T) -> torch.Tensor:
return T - torch.nn.functional.relu(v) - torch.log1p(torch.exp(-v.abs()))


class FinalSequenceScorer(Registrable):
"""
An abstract class that can be used to score the final generated sequences found
by beam search. Given the predicted sequences and the corresponding log probabilities of
those sequences, the class calculates and returns the final score of the sequences.

The default implementation scores each sequence with the sum of its per-token log
probabilities, which is exactly the value passed in as input.
"""

default_implementation = "sequence-log-prob"

def score(
self, predictions: torch.Tensor, log_probabilities: torch.Tensor, end_index: int
) -> torch.Tensor:
"""
Score the final predictions found by beam search.

# Parameters

predictions : `torch.Tensor`
A tensor containing the predicted sequences, with shape `(batch_size, beam_size, max_steps)`.

log_probabilities : `torch.Tensor`
A tensor containing the log probabilities of the sequence, defined as the sum
of the log probabilities per token, with shape `(batch_size, beam_size)`.

end_index : `int`
The index of the end symbol.

# Returns

`torch.Tensor`
A tensor of the final sequence scores of shape `(batch_size, beam_size)`.
"""
raise NotImplementedError


@FinalSequenceScorer.register("sequence-log-prob")
class SequenceLogProbabilityScorer(FinalSequenceScorer):
"""
A `FinalSequenceScorer` which scores the sequences by the sum of the log probabilities
across the sequence's tokens.
"""

@overrides
def score(
self, predictions: torch.Tensor, log_probabilities: torch.Tensor, end_index: int
) -> torch.Tensor:
# The sum of the sequence log probabilities is the input parameter, so just
# return it.
return log_probabilities


@FinalSequenceScorer.register("length-normalized-sequence-log-prob")
class LengthNormalizedSequenceLogProbabilityScorer(FinalSequenceScorer):
"""
A `FinalSequenceScorer` which scores the sequences by the average log probability of the
tokens in the sequence. It optionally includes a length penalty which promotes
or demotes sequences based on their lengths. The final score for a sequence will
be `(sequence_log_probability) / (sequence_length ** length_penalty)`. The sequence length
here includes the end token.

# Parameters

length_penalty : `float`, optional (default = `1.0`)
The length penalty exponent. A value of `1.0` gives plain length normalization
(the average log probability). Values greater than `1.0` favor longer sequences,
and values less than `1.0` favor shorter sequences.
"""

def __init__(self, length_penalty: float = 1.0):
super().__init__()
self.length_penalty = length_penalty

@overrides
def score(
self, predictions: torch.Tensor, log_probabilities: torch.Tensor, end_index: int
) -> torch.Tensor:
# shape: (batch_size, beam_size)
lengths = (predictions != end_index).long().sum(dim=2)

# If the sequence ended during beam search, the `log_probabilities` will include
# the transition to the end token. Therefore, in such situations, `lengths` is
# actually off by 1. This corrects for that.
# shape: (batch_size, beam_size)
is_end_token = predictions[:, :, -1] == end_index
lengths += is_end_token.long()

# shape: (batch_size, beam_size)
average_log_probs = log_probabilities / (lengths ** self.length_penalty)
return average_log_probs
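To illustrate how the abstraction extends beyond the two built-in scorers, here is a minimal sketch of a custom scorer. The `"token-count-penalty"` name, the class, and its penalty logic are invented for illustration and are not part of this PR:

```python
@FinalSequenceScorer.register("token-count-penalty")
class TokenCountPenaltyScorer(FinalSequenceScorer):
    """
    A hypothetical scorer that subtracts a flat penalty for every token in the
    sequence, discouraging long outputs.
    """

    def __init__(self, penalty_per_token: float = 0.1):
        super().__init__()
        self.penalty_per_token = penalty_per_token

    @overrides
    def score(
        self, predictions: torch.Tensor, log_probabilities: torch.Tensor, end_index: int
    ) -> torch.Tensor:
        # shape: (batch_size, beam_size)
        lengths = (predictions != end_index).long().sum(dim=2)
        return log_probabilities - self.penalty_per_token * lengths.float()
```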


class BeamSearch(FromParams):
"""
Implements the beam search algorithm for decoding the most likely sequences.
@@ -467,6 +560,12 @@ class BeamSearch(FromParams):
The minimum number of decoding steps to take, i.e. the minimum length of
the predicted sequences. This does not include the start or end tokens. If `None`,
no minimum is enforced.

final_sequence_scorer : `FinalSequenceScorer`, optional (default = `None`)
An optional `FinalSequenceScorer` which is used to score the final generated sequences.
The output from this module is what is returned by the `search` method. If not
specified, `SequenceLogProbabilityScorer` will be used, which scores the sequences
by the sum of the token log probabilities.
"""

def __init__(
@@ -477,6 +576,7 @@ def __init__(
per_node_beam_size: int = None,
sampler: Sampler = None,
min_steps: Optional[int] = None,
final_sequence_scorer: FinalSequenceScorer = None,
) -> None:
if not max_steps > 0:
raise ValueError("max_steps must be positive")
@@ -496,6 +596,7 @@
self.per_node_beam_size = per_node_beam_size or beam_size
self.sampler = sampler or DeterministicSampler()
self.min_steps = min_steps or 0
self.final_sequence_scorer = final_sequence_scorer or SequenceLogProbabilityScorer()
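A minimal usage sketch of the new parameter (assuming the usual `BeamSearch` constructor arguments; `eos_index`, `start_predictions`, `start_state`, and `take_step` are placeholders for the caller's own values):

```python
beam_search = BeamSearch(
    end_index=eos_index,  # placeholder: the model's end-of-sequence token id
    max_steps=50,
    beam_size=5,
    final_sequence_scorer=LengthNormalizedSequenceLogProbabilityScorer(length_penalty=2.0),
)
# `search` returns the predictions sorted by the scorer's output, best first:
top_k_predictions, final_scores = beam_search.search(start_predictions, start_state, take_step)
```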

@staticmethod
def _reconstruct_sequences(predictions, backpointers):
@@ -580,8 +681,8 @@ def search(
# Returns

`Tuple[torch.Tensor, torch.Tensor]`
Tuple of `(predictions, log_probabilities)`, where `predictions`
has shape `(batch_size, beam_size, max_steps)` and `log_probabilities`
Tuple of `(predictions, final_scores)`, where `predictions`
has shape `(batch_size, beam_size, max_steps)` and `final_scores`
has shape `(batch_size, beam_size)`.
"""
step_signature = signature(step)
@@ -786,7 +887,20 @@ def _search(
# shape: (batch_size, beam_size, max_steps)
all_predictions = torch.cat(list(reversed(reconstructed_predictions)), 2)

return all_predictions, last_log_probabilities
# Calculate the final sequence scores
# shape: (batch_size, beam_size)
final_scores = self.final_sequence_scorer.score(
all_predictions, last_log_probabilities, self._end_index
)

# Sort the sequences based on the final scores so the best scoring
# sequence is at index 0
sorted_final_scores, sorted_indices = torch.sort(final_scores, dim=1, descending=True)
sorted_all_predictions = torch.gather(
all_predictions, 1, sorted_indices.unsqueeze(-1).expand_as(all_predictions)
)

return sorted_all_predictions, sorted_final_scores

@staticmethod
def _is_multilayer_rnn_decoder(key: str, state_tensor: torch.Tensor) -> bool:
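The sort-and-gather idiom at the end of `_search` above is worth seeing in isolation; a tiny self-contained demonstration with made-up values:

```python
import torch

scores = torch.tensor([[0.1, 0.9, 0.5]])          # (batch_size=1, beam_size=3)
preds = torch.tensor([[[1, 2], [3, 4], [5, 6]]])  # (1, 3, max_steps=2)

sorted_scores, indices = torch.sort(scores, dim=1, descending=True)
sorted_preds = torch.gather(preds, 1, indices.unsqueeze(-1).expand_as(preds))
# sorted_preds is tensor([[[3, 4], [5, 6], [1, 2]]]): each beam's full
# prediction sequence moves together with its score.
```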
58 changes: 58 additions & 0 deletions tests/nn/beam_search_test.py
@@ -12,6 +12,8 @@
TopKSampler,
TopPSampler,
GumbelSampler,
SequenceLogProbabilityScorer,
LengthNormalizedSequenceLogProbabilityScorer,
)
from allennlp.common.params import Params

@@ -538,3 +540,59 @@ def test_gumbel_sampler(self):

assert all([x >= 0 and x < 4 for x in indices[0]])
assert all([x > 1 and x <= 5 for x in indices[1]])

def test_sequence_log_prob_scorer(self):
# SequenceLogProbabilityScorer is the default, so manually setting the
# sequence scorer shouldn't actually change anything
self.beam_search.final_sequence_scorer = SequenceLogProbabilityScorer()
self._check_results()

def test_length_normalized_sequence_log_prob_scorer(self):
"""
Tests to ensure the sequences are normalized by the correct values. The end token is
included in the length. The start token is not.
"""
self.beam_search.final_sequence_scorer = LengthNormalizedSequenceLogProbabilityScorer()
expected_log_probs = np.log(np.array([0.4, 0.3, 0.2]))
length_normalization = np.array([5, 4, 3])
expected_scores = expected_log_probs / length_normalization
self._check_results(expected_log_probs=expected_scores)

# Introduce a length penalty
length_penalty = 2.0
self.beam_search.final_sequence_scorer = LengthNormalizedSequenceLogProbabilityScorer(
length_penalty=length_penalty
)
expected_log_probs = np.log(np.array([0.4, 0.3, 0.2]))
length_normalization = np.array(
[5 ** length_penalty, 4 ** length_penalty, 3 ** length_penalty]
)
expected_scores = expected_log_probs / length_normalization
self._check_results(expected_log_probs=expected_scores)

# Pick a length penalty so extreme that the order of the sequences is reversed
length_penalty = -2.0
self.beam_search.final_sequence_scorer = LengthNormalizedSequenceLogProbabilityScorer(
length_penalty=length_penalty
)
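# With length_penalty = -2.0, the score is log_prob / length**(-2) = log_prob * length**2,
# so longer sequences are penalized heavily: log(0.4) * 5**2 ~= -22.9 now ranks below
# log(0.2) * 3**2 ~= -14.5, which reverses the original order.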
expected_top_k = np.array([[3, 4, 5, 5, 5], [2, 3, 4, 5, 5], [1, 2, 3, 4, 5]])
expected_log_probs = np.log(np.array([0.2, 0.3, 0.4]))
length_normalization = np.array(
[3 ** length_penalty, 4 ** length_penalty, 5 ** length_penalty]
)
expected_scores = expected_log_probs / length_normalization
self._check_results(expected_top_k=expected_top_k, expected_log_probs=expected_scores)

# Here, we set max_steps = 4. This prevents the first sequence from finishing,
# so its length does not include the end token, whereas the other sequences do.
length_penalty = 2.0
self.beam_search.max_steps = 4
self.beam_search.final_sequence_scorer = LengthNormalizedSequenceLogProbabilityScorer(
length_penalty=length_penalty
)
expected_top_k = np.array([[1, 2, 3, 4], [2, 3, 4, 5], [3, 4, 5, 5]])
expected_log_probs = np.log(np.array([0.4, 0.3, 0.2]))
length_normalization = np.array(
[4 ** length_penalty, 4 ** length_penalty, 3 ** length_penalty]
)
expected_scores = expected_log_probs / length_normalization
self._check_results(expected_top_k=expected_top_k, expected_log_probs=expected_scores)
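
Because `FinalSequenceScorer` is `Registrable`, the scorer can also be chosen from configuration. A sketch using `from_params` (the parameter values here are placeholders, not from this PR's tests):

```python
from allennlp.common.params import Params
from allennlp.nn.beam_search import BeamSearch

beam_search = BeamSearch.from_params(
    Params(
        {
            "end_index": 5,  # placeholder end token id
            "beam_size": 3,
            "final_sequence_scorer": {
                "type": "length-normalized-sequence-log-prob",
                "length_penalty": 2.0,
            },
        }
    )
)
```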