Skip to content
This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Commit

Permalink
Make BeamSearch Registrable (#5231)
Browse files Browse the repository at this point in the history
* Make BeamSearch Registrable

* Update changelog

* Remove unused import

* Update CHANGELOG.md

Co-authored-by: Pete <[email protected]>
Co-authored-by: Pete <[email protected]>
  • Loading branch information
3 people authored Jun 1, 2021
1 parent c014232 commit 39d7e5a
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 2 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Renamed `sanity_checks` to `confidence_checks` (`sanity_checks` is deprecated and will be removed in AllenNLP 3.0).
- Trainer callbacks can now store and restore state in case a training run gets interrupted.
- VilBERT backbone now rolls and unrolls extra dimensions to handle input with > 3 dimensions.
- `BeamSearch` is now a `Registrable` class.

### Added

Expand Down
9 changes: 7 additions & 2 deletions allennlp/nn/beam_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from overrides import overrides
import torch

from allennlp.common import FromParams, Registrable
from allennlp.common import Registrable
from allennlp.common.checks import ConfigurationError
from allennlp.nn.util import min_value_of_dtype

Expand Down Expand Up @@ -683,7 +683,7 @@ def _update_state(
return state


class BeamSearch(FromParams):
class BeamSearch(Registrable):
"""
Implements the beam search algorithm for decoding the most likely sequences.
Expand Down Expand Up @@ -731,6 +731,8 @@ class BeamSearch(FromParams):
provided, no constraints will be enforced.
"""

default_implementation = "beam_search"

def __init__(
self,
end_index: int,
Expand Down Expand Up @@ -1180,3 +1182,6 @@ def _update_state(self, state: StateType, backpointer: torch.Tensor):
.gather(1, expanded_backpointer)
.reshape(batch_size * self.beam_size, *last_dims)
)


BeamSearch.register("beam_search")(BeamSearch)

0 comments on commit 39d7e5a

Please sign in to comment.