From b533733a6d4662fe3898406b4dfd1e8887b45400 Mon Sep 17 00:00:00 2001
From: Xin Zhang
Date: Tue, 4 May 2021 02:43:50 +0800
Subject: [PATCH] Refactor span extractors and unify forward. (#5160)

* Refactor span extractors

* add SpanExtractorWithSpanWidthEmbedding

* update changelog

* fix blank lines

Co-authored-by: Dirk Groeneveld
---
 CHANGELOG.md                                   |   2 +-
 .../bidirectional_endpoint_span_extractor.py   |  50 ++----
 .../endpoint_span_extractor.py                 |  51 ++-----
 .../self_attentive_span_extractor.py           |  45 ++++--
 ...pan_extractor_with_span_width_embedding.py  | 144 ++++++++++++++++++
 .../self_attentive_span_extractor_test.py      |  40 ++++-
 6 files changed, 235 insertions(+), 97 deletions(-)
 create mode 100644 allennlp/modules/span_extractors/span_extractor_with_span_width_embedding.py

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 94f046c2ee8..6310c926cd2 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -20,8 +20,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
   You can do this by setting the parameter `load_weights` to `False`.
   See [PR #5172](https://github.com/allenai/allennlp/pull/5172) for more details.
+- Added `SpanExtractorWithSpanWidthEmbedding`, which pulls the shared span-width-embedding code and a unified `forward` into a common base class and leaves the extractor-specific computation to an `_embed_spans` method. `BidirectionalEndpointSpanExtractor`, `EndpointSpanExtractor` and `SelfAttentiveSpanExtractor` were modified accordingly, and `SelfAttentiveSpanExtractor` can now also embed span widths.
 
-## Unreleased
 
 ### Fixed
 
diff --git a/allennlp/modules/span_extractors/bidirectional_endpoint_span_extractor.py b/allennlp/modules/span_extractors/bidirectional_endpoint_span_extractor.py
index e63a19b53aa..464039ede6e 100644
--- a/allennlp/modules/span_extractors/bidirectional_endpoint_span_extractor.py
+++ b/allennlp/modules/span_extractors/bidirectional_endpoint_span_extractor.py
@@ -1,17 +1,16 @@
-from typing import Optional
-
 import torch
-from overrides import overrides
 from torch.nn.parameter import Parameter
 
 from allennlp.common.checks import ConfigurationError
 from allennlp.modules.span_extractors.span_extractor import SpanExtractor
-from allennlp.modules.token_embedders.embedding import Embedding
+from allennlp.modules.span_extractors.span_extractor_with_span_width_embedding import (
+    SpanExtractorWithSpanWidthEmbedding,
+)
 from allennlp.nn import util
 
 
 @SpanExtractor.register("bidirectional_endpoint")
-class BidirectionalEndpointSpanExtractor(SpanExtractor):
+class BidirectionalEndpointSpanExtractor(SpanExtractorWithSpanWidthEmbedding):
     """
     Represents spans from a bidirectional encoder as a concatenation of two different
     representations of the span endpoints, one for the forward direction of the encoder
@@ -79,12 +78,14 @@ def __init__(
         bucket_widths: bool = False,
         use_sentinels: bool = True,
     ) -> None:
-        super().__init__()
-        self._input_dim = input_dim
+        super().__init__(
+            input_dim=input_dim,
+            num_width_embeddings=num_width_embeddings,
+            span_width_embedding_dim=span_width_embedding_dim,
+            bucket_widths=bucket_widths,
+        )
         self._forward_combination = forward_combination
         self._backward_combination = backward_combination
-        self._num_width_embeddings = num_width_embeddings
-        self._bucket_widths = bucket_widths
 
         if self._input_dim % 2 != 0:
             raise ConfigurationError(
                 "The input dimension is not divisible by 2, but the "
                 "span extractor assumes the embedded representation "
                 "is bidirectional (and hence divisible by 2)."
             )
 
-        self._span_width_embedding: Optional[Embedding] = None
-        if num_width_embeddings is not None and span_width_embedding_dim is not None:
-            self._span_width_embedding = Embedding(
-                num_embeddings=num_width_embeddings, embedding_dim=span_width_embedding_dim
-            )
-        elif num_width_embeddings is not None or span_width_embedding_dim is not None:
-            raise ConfigurationError(
-                "To use a span width embedding representation, you must"
-                "specify both num_width_buckets and span_width_embedding_dim."
-            )
-
         self._use_sentinels = use_sentinels
         if use_sentinels:
             self._start_sentinel = Parameter(torch.randn([1, 1, int(input_dim / 2)]))
             self._end_sentinel = Parameter(torch.randn([1, 1, int(input_dim / 2)]))
 
-    def get_input_dim(self) -> int:
-        return self._input_dim
-
     def get_output_dim(self) -> int:
         unidirectional_dim = int(self._input_dim / 2)
         forward_combined_dim = util.get_combined_dim(
@@ -128,8 +115,7 @@ def get_output_dim(self) -> int:
         )
         return forward_combined_dim + backward_combined_dim
 
-    @overrides
-    def forward(
+    def _embed_spans(
         self,
         sequence_tensor: torch.FloatTensor,
         span_indices: torch.LongTensor,
@@ -238,18 +224,4 @@ def forward(
 
         # Shape (batch_size, num_spans, forward_combination_dim + backward_combination_dim)
         span_embeddings = torch.cat([forward_spans, backward_spans], -1)
 
-        if self._span_width_embedding is not None:
-            # Embed the span widths and concatenate to the rest of the representations.
-            if self._bucket_widths:
-                span_widths = util.bucket_values(
-                    span_ends - span_starts, num_total_buckets=self._num_width_embeddings  # type: ignore
-                )
-            else:
-                span_widths = span_ends - span_starts
-
-            span_width_embeddings = self._span_width_embedding(span_widths)
-            return torch.cat([span_embeddings, span_width_embeddings], -1)
-
-        if span_indices_mask is not None:
-            return span_embeddings * span_indices_mask.unsqueeze(-1)
         return span_embeddings
diff --git a/allennlp/modules/span_extractors/endpoint_span_extractor.py b/allennlp/modules/span_extractors/endpoint_span_extractor.py
index 86b19cb4a7e..fa229e00929 100644
--- a/allennlp/modules/span_extractors/endpoint_span_extractor.py
+++ b/allennlp/modules/span_extractors/endpoint_span_extractor.py
@@ -1,17 +1,15 @@
-from typing import Optional
-
 import torch
 from torch.nn.parameter import Parameter
-from overrides import overrides
 
 from allennlp.modules.span_extractors.span_extractor import SpanExtractor
-from allennlp.modules.token_embedders.embedding import Embedding
+from allennlp.modules.span_extractors.span_extractor_with_span_width_embedding import (
+    SpanExtractorWithSpanWidthEmbedding,
+)
 from allennlp.nn import util
-from allennlp.common.checks import ConfigurationError
 
 
 @SpanExtractor.register("endpoint")
-class EndpointSpanExtractor(SpanExtractor):
+class EndpointSpanExtractor(SpanExtractorWithSpanWidthEmbedding):
     """
     Represents spans as a combination of the embeddings of their endpoints. Additionally,
     the width of the spans can be embedded and concatenated on to the final combination.
@@ -61,38 +59,25 @@ def __init__(
         bucket_widths: bool = False,
         use_exclusive_start_indices: bool = False,
     ) -> None:
-        super().__init__()
-        self._input_dim = input_dim
+        super().__init__(
+            input_dim=input_dim,
+            num_width_embeddings=num_width_embeddings,
+            span_width_embedding_dim=span_width_embedding_dim,
+            bucket_widths=bucket_widths,
+        )
         self._combination = combination
-        self._num_width_embeddings = num_width_embeddings
-        self._bucket_widths = bucket_widths
 
         self._use_exclusive_start_indices = use_exclusive_start_indices
         if use_exclusive_start_indices:
             self._start_sentinel = Parameter(torch.randn([1, 1, int(input_dim)]))
 
-        self._span_width_embedding: Optional[Embedding] = None
-        if num_width_embeddings is not None and span_width_embedding_dim is not None:
-            self._span_width_embedding = Embedding(
-                num_embeddings=num_width_embeddings, embedding_dim=span_width_embedding_dim
-            )
-        elif num_width_embeddings is not None or span_width_embedding_dim is not None:
-            raise ConfigurationError(
-                "To use a span width embedding representation, you must"
-                "specify both num_width_buckets and span_width_embedding_dim."
-            )
-
-    def get_input_dim(self) -> int:
-        return self._input_dim
-
     def get_output_dim(self) -> int:
         combined_dim = util.get_combined_dim(self._combination, [self._input_dim, self._input_dim])
         if self._span_width_embedding is not None:
             return combined_dim + self._span_width_embedding.get_output_dim()
         return combined_dim
 
-    @overrides
-    def forward(
+    def _embed_spans(
         self,
         sequence_tensor: torch.FloatTensor,
         span_indices: torch.LongTensor,
@@ -148,19 +133,5 @@ def forward(
         combined_tensors = util.combine_tensors(
             self._combination, [start_embeddings, end_embeddings]
         )
-        if self._span_width_embedding is not None:
-            # Embed the span widths and concatenate to the rest of the representations.
-            if self._bucket_widths:
-                span_widths = util.bucket_values(
-                    span_ends - span_starts, num_total_buckets=self._num_width_embeddings  # type: ignore
-                )
-            else:
-                span_widths = span_ends - span_starts
-
-            span_width_embeddings = self._span_width_embedding(span_widths)
-            combined_tensors = torch.cat([combined_tensors, span_width_embeddings], -1)
-
-        if span_indices_mask is not None:
-            return combined_tensors * span_indices_mask.unsqueeze(-1)
 
         return combined_tensors
diff --git a/allennlp/modules/span_extractors/self_attentive_span_extractor.py b/allennlp/modules/span_extractors/self_attentive_span_extractor.py
index 28d68a308d9..b05aeaf0da8 100644
--- a/allennlp/modules/span_extractors/self_attentive_span_extractor.py
+++ b/allennlp/modules/span_extractors/self_attentive_span_extractor.py
@@ -1,13 +1,15 @@
 import torch
-from overrides import overrides
 
 from allennlp.modules.span_extractors.span_extractor import SpanExtractor
+from allennlp.modules.span_extractors.span_extractor_with_span_width_embedding import (
+    SpanExtractorWithSpanWidthEmbedding,
+)
 from allennlp.modules.time_distributed import TimeDistributed
 from allennlp.nn import util
 
 
 @SpanExtractor.register("self_attentive")
-class SelfAttentiveSpanExtractor(SpanExtractor):
+class SelfAttentiveSpanExtractor(SpanExtractorWithSpanWidthEmbedding):
     """
     Computes span representations by generating an unnormalized attention score for each word
     in the document. Spans representations are computed with respect to these
@@ -23,6 +25,14 @@ class SelfAttentiveSpanExtractor(SpanExtractor):
     input_dim : `int`, required.
         The final dimension of the `sequence_tensor`.
+    num_width_embeddings : `int`, optional (default = `None`).
+        Specifies the number of buckets to use when representing
+        span width features.
+    span_width_embedding_dim : `int`, optional (default = `None`).
+        The embedding size for the span_width features.
+    bucket_widths : `bool`, optional (default = `False`).
+        Whether to bucket the span widths into log-space buckets. If `False`,
+        the raw span widths are used.
 
     # Returns
 
@@ -33,22 +43,31 @@ class SelfAttentiveSpanExtractor(SpanExtractor):
     over which they are normalized.
     """
 
-    def __init__(self, input_dim: int) -> None:
-        super().__init__()
-        self._input_dim = input_dim
+    def __init__(
+        self,
+        input_dim: int,
+        num_width_embeddings: int = None,
+        span_width_embedding_dim: int = None,
+        bucket_widths: bool = False,
+    ) -> None:
+        super().__init__(
+            input_dim=input_dim,
+            num_width_embeddings=num_width_embeddings,
+            span_width_embedding_dim=span_width_embedding_dim,
+            bucket_widths=bucket_widths,
+        )
         self._global_attention = TimeDistributed(torch.nn.Linear(input_dim, 1))
 
-    def get_input_dim(self) -> int:
-        return self._input_dim
-
     def get_output_dim(self) -> int:
+        if self._span_width_embedding is not None:
+            return self._input_dim + self._span_width_embedding.get_output_dim()
         return self._input_dim
 
-    @overrides
-    def forward(
+    def _embed_spans(
         self,
         sequence_tensor: torch.FloatTensor,
         span_indices: torch.LongTensor,
+        sequence_mask: torch.BoolTensor = None,
         span_indices_mask: torch.BoolTensor = None,
     ) -> torch.FloatTensor:
         # shape (batch_size, sequence_length, 1)
@@ -72,10 +91,4 @@ def forward(
         # Shape: (batch_size, num_spans, embedding_dim)
         attended_text_embeddings = util.weighted_sum(span_embeddings, span_attention_weights)
 
-        if span_indices_mask is not None:
-            # Above we were masking the widths of spans with respect to the max
-            # span width in the batch. Here we are masking the spans which were
-            # originally passed in as padding.
-            return attended_text_embeddings * span_indices_mask.unsqueeze(-1)
-
         return attended_text_embeddings
diff --git a/allennlp/modules/span_extractors/span_extractor_with_span_width_embedding.py b/allennlp/modules/span_extractors/span_extractor_with_span_width_embedding.py
new file mode 100644
index 00000000000..de98059226f
--- /dev/null
+++ b/allennlp/modules/span_extractors/span_extractor_with_span_width_embedding.py
@@ -0,0 +1,144 @@
+from typing import Optional
+from overrides import overrides
+
+import torch
+
+from allennlp.common.checks import ConfigurationError
+from allennlp.modules.span_extractors.span_extractor import SpanExtractor
+from allennlp.modules.token_embedders.embedding import Embedding
+from allennlp.nn import util
+
+
+class SpanExtractorWithSpanWidthEmbedding(SpanExtractor):
+    """
+    `SpanExtractorWithSpanWidthEmbedding` implements some common code for span
+    extractors which need to embed span widths.
+
+    Specifically, we initialize the span width embedding matrix and other
+    attributes in `__init__`, leave an `_embed_spans` method that can be
+    implemented to compute span embeddings in different ways, and in `forward`
+    we concatenate span embeddings returned by `_embed_spans` with span width
+    embeddings to form the final span representations.
+
+    We keep `SpanExtractor` as a purely abstract base class, just in case someone
+    wants to build a totally different span extractor.
+
+    # Parameters
+
+    input_dim : `int`, required.
+        The final dimension of the `sequence_tensor`.
+    num_width_embeddings : `int`, optional (default = `None`).
+        Specifies the number of buckets to use when representing
+        span width features.
+    span_width_embedding_dim : `int`, optional (default = `None`).
+        The embedding size for the span_width features.
+    bucket_widths : `bool`, optional (default = `False`).
+        Whether to bucket the span widths into log-space buckets. If `False`,
+        the raw span widths are used.
+
+    # Returns
+
+    span_embeddings : `torch.FloatTensor`.
+        A tensor of shape `(batch_size, num_spans, embedded_span_size)`,
+        where `embedded_span_size` depends on the way spans are represented.
+    """
+
+    def __init__(
+        self,
+        input_dim: int,
+        num_width_embeddings: int = None,
+        span_width_embedding_dim: int = None,
+        bucket_widths: bool = False,
+    ) -> None:
+        super().__init__()
+        self._input_dim = input_dim
+        self._num_width_embeddings = num_width_embeddings
+        self._bucket_widths = bucket_widths
+
+        self._span_width_embedding: Optional[Embedding] = None
+        if num_width_embeddings is not None and span_width_embedding_dim is not None:
+            self._span_width_embedding = Embedding(
+                num_embeddings=num_width_embeddings, embedding_dim=span_width_embedding_dim
+            )
+        elif num_width_embeddings is not None or span_width_embedding_dim is not None:
+            raise ConfigurationError(
+                "To use a span width embedding representation, you must "
+                "specify both num_width_embeddings and span_width_embedding_dim."
+            )
+
+    @overrides
+    def forward(
+        self,
+        sequence_tensor: torch.FloatTensor,
+        span_indices: torch.LongTensor,
+        sequence_mask: torch.BoolTensor = None,
+        span_indices_mask: torch.BoolTensor = None,
+    ):
+        """
+        Given a sequence tensor, extract spans, concatenate width embeddings
+        when needed, and return representations of them.
+
+        # Parameters
+
+        sequence_tensor : `torch.FloatTensor`, required.
+            A tensor of shape (batch_size, sequence_length, embedding_size)
+            representing an embedded sequence of words.
+        span_indices : `torch.LongTensor`, required.
+            A tensor of shape `(batch_size, num_spans, 2)`, where the last
+            dimension represents the inclusive start and end indices of the
+            span to be extracted from the `sequence_tensor`.
+        sequence_mask : `torch.BoolTensor`, optional (default = `None`).
+            A tensor of shape (batch_size, sequence_length) representing padded
+            elements of the sequence.
+        span_indices_mask : `torch.BoolTensor`, optional (default = `None`).
+            A tensor of shape (batch_size, num_spans) representing the valid
+            spans in the `indices` tensor. This mask is optional because
+            sometimes it's easier to worry about masking after calling this
+            function, rather than passing a mask directly.
+
+        # Returns
+
+        A tensor of shape `(batch_size, num_spans, embedded_span_size)`,
+        where `embedded_span_size` depends on the way spans are represented.
+        """
+        # shape (batch_size, num_spans, embedding_dim)
+        span_embeddings = self._embed_spans(
+            sequence_tensor, span_indices, sequence_mask, span_indices_mask
+        )
+        if self._span_width_embedding is not None:
+            # width = end_index - start_index + 1, since `SpanField` uses inclusive indices.
+            # But here we do not add 1 because we often initialize the span width
+            # embedding matrix with `num_width_embeddings = max_span_width`.
+            # shape (batch_size, num_spans)
+            widths_minus_one = span_indices[..., 1] - span_indices[..., 0]
+
+            if self._bucket_widths:
+                widths_minus_one = util.bucket_values(
+                    widths_minus_one, num_total_buckets=self._num_width_embeddings  # type: ignore
+                )
+
+            # Embed the span widths and concatenate to the rest of the representations.
+            span_width_embeddings = self._span_width_embedding(widths_minus_one)
+            span_embeddings = torch.cat([span_embeddings, span_width_embeddings], -1)
+
+        if span_indices_mask is not None:
+            # Here we are masking the spans which were originally passed in as padding.
+            return span_embeddings * span_indices_mask.unsqueeze(-1)
+
+        return span_embeddings
+
+    @overrides
+    def get_input_dim(self) -> int:
+        return self._input_dim
+
+    def _embed_spans(
+        self,
+        sequence_tensor: torch.FloatTensor,
+        span_indices: torch.LongTensor,
+        sequence_mask: torch.BoolTensor = None,
+        span_indices_mask: torch.BoolTensor = None,
+    ) -> torch.Tensor:
+        """
+        Returns the span embeddings computed in many different ways.
+        """
+        raise NotImplementedError
diff --git a/tests/modules/span_extractors/self_attentive_span_extractor_test.py b/tests/modules/span_extractors/self_attentive_span_extractor_test.py
index d66c878c09c..4e5af1348d7 100644
--- a/tests/modules/span_extractors/self_attentive_span_extractor_test.py
+++ b/tests/modules/span_extractors/self_attentive_span_extractor_test.py
@@ -7,9 +7,17 @@ class TestSelfAttentiveSpanExtractor:
     def test_locally_normalised_span_extractor_can_build_from_params(self):
-        params = Params({"type": "self_attentive", "input_dim": 5})
+        params = Params(
+            {
+                "type": "self_attentive",
+                "input_dim": 7,
+                "num_width_embeddings": 5,
+                "span_width_embedding_dim": 3,
+            }
+        )
         extractor = SpanExtractor.from_params(params)
         assert isinstance(extractor, SelfAttentiveSpanExtractor)
+        assert extractor.get_output_dim() == 10  # input_dim + span_width_embedding_dim
 
     def test_attention_is_normalised_correctly(self):
         input_dim = 7
@@ -70,3 +78,33 @@ def test_attention_is_normalised_correctly(self):
         numpy.testing.assert_array_almost_equal(spans[0].data.numpy(), mean_embeddings.data.numpy())
         # Second span was masked, so should be completely zero.
         numpy.testing.assert_array_almost_equal(spans[1].data.numpy(), numpy.zeros([input_dim]))
+
+    def test_widths_are_embedded_correctly(self):
+        input_dim = 7
+        max_span_width = 5
+        span_width_embedding_dim = 3
+        output_dim = input_dim + span_width_embedding_dim
+        extractor = SelfAttentiveSpanExtractor(
+            input_dim=input_dim,
+            num_width_embeddings=max_span_width,
+            span_width_embedding_dim=span_width_embedding_dim,
+        )
+        assert extractor.get_output_dim() == output_dim
+        assert extractor.get_input_dim() == input_dim
+
+        sequence_tensor = torch.randn([2, max_span_width, input_dim])
+        indices = torch.LongTensor(
+            [[[1, 3], [0, 4], [0, 0]], [[0, 2], [1, 4], [2, 2]]]
+        )  # smaller span tests masking.
+        span_representations = extractor(sequence_tensor, indices)
+        assert list(span_representations.size()) == [2, 3, output_dim]
+
+        width_embeddings = extractor._span_width_embedding.weight.data.numpy()
+        widths_minus_one = indices[..., 1] - indices[..., 0]
+        for element in range(indices.size(0)):
+            for span in range(indices.size(1)):
+                width = widths_minus_one[element, span].item()
+                width_embedding = span_representations[element, span, input_dim:]
+                numpy.testing.assert_array_almost_equal(
+                    width_embedding.data.numpy(), width_embeddings[width]
+                )
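
Usage sketch (not part of the patch): based on the docstrings and the new test above, enabling width embeddings on `SelfAttentiveSpanExtractor` should simply widen the output by `span_width_embedding_dim`. The tensor sizes below are arbitrary and chosen only for illustration.

```python
import torch

from allennlp.modules.span_extractors import SelfAttentiveSpanExtractor

# Two sequences of length 10, each token embedded with 16 dimensions.
sequence_tensor = torch.randn(2, 10, 16)
# Three spans per sequence, given as inclusive (start, end) token indices.
span_indices = torch.LongTensor([[[0, 2], [3, 3], [5, 9]], [[1, 4], [2, 2], [6, 8]]])

extractor = SelfAttentiveSpanExtractor(
    input_dim=16,
    num_width_embeddings=10,  # spans here are at most 10 tokens wide
    span_width_embedding_dim=8,
)

# The attended span representation (16 dims) is concatenated with the width embedding (8 dims).
span_embeddings = extractor(sequence_tensor, span_indices)
assert extractor.get_output_dim() == 24
assert span_embeddings.shape == (2, 3, 24)
```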
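The `bucket_widths` option used throughout the patch relies on `util.bucket_values`, which maps raw widths onto a fixed number of (partly log-spaced) buckets so that they remain valid indices into the width embedding matrix. A minimal sketch, assuming the default bucketing behaviour of `allennlp.nn.util.bucket_values`:

```python
import torch

from allennlp.nn import util

# Raw span widths (already "minus one", as in the forward method above).
widths_minus_one = torch.LongTensor([0, 1, 2, 5, 20, 100])

# With bucket_widths=True, the extractor passes the widths through this call.
buckets = util.bucket_values(widths_minus_one, num_total_buckets=10)

# Every bucketed value remains a valid row index into a 10-row width embedding matrix.
assert int(buckets.max()) < 10 and int(buckets.min()) >= 0
```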
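The division of labour the new base class sets up is that a subclass only implements `_embed_spans` (plus `get_output_dim`), while width embedding and padding masking are inherited from `SpanExtractorWithSpanWidthEmbedding.forward`. A hypothetical subclass, sketched here for illustration only (the `MaxPoolingSpanExtractor` name and the max-pooling strategy are not part of this PR):

```python
import torch

from allennlp.modules.span_extractors.span_extractor_with_span_width_embedding import (
    SpanExtractorWithSpanWidthEmbedding,
)
from allennlp.nn import util


class MaxPoolingSpanExtractor(SpanExtractorWithSpanWidthEmbedding):
    """Represents each span by max-pooling the embeddings of the tokens it covers."""

    def get_output_dim(self) -> int:
        if self._span_width_embedding is not None:
            return self._input_dim + self._span_width_embedding.get_output_dim()
        return self._input_dim

    def _embed_spans(
        self,
        sequence_tensor: torch.FloatTensor,
        span_indices: torch.LongTensor,
        sequence_mask: torch.BoolTensor = None,
        span_indices_mask: torch.BoolTensor = None,
    ) -> torch.Tensor:
        # Shape: (batch_size, num_spans, max_span_width, embedding_dim), plus a mask
        # marking which of those positions actually fall inside each span.
        span_embeddings, span_mask = util.batched_span_select(sequence_tensor, span_indices)
        # Push positions outside the span to a very small value so they never win the max.
        span_embeddings = util.replace_masked_values(
            span_embeddings,
            span_mask.unsqueeze(-1),
            util.min_value_of_dtype(span_embeddings.dtype),
        )
        # Shape: (batch_size, num_spans, embedding_dim); width embeddings and the
        # span_indices_mask are applied by the inherited forward().
        return span_embeddings.max(dim=2)[0]
```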