This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Commit

Fixes token type ids for folded sequences (#5149)
* Fixes token type ids for folded sequences

* Changelog

* Save memory on the GitHub test runners

* Tensors have to be on the same device
dirkgr committed May 10, 2021
1 parent 402bc78 commit a6a600d
Showing 4 changed files with 66 additions and 10 deletions.
7 changes: 7 additions & 0 deletions CHANGELOG.md
@@ -12,6 +12,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added `TaskSuite` base class and command line functionality for running [`checklist`](https://github.com/marcotcr/checklist) test suites, along with implementations for `SentimentAnalysisSuite`, `QuestionAnsweringSuite`, and `TextualEntailmentSuite`. These can be found in the `allennlp.sanity_checks.task_checklists` module.


## Unreleased

### Fixed

- When `PretrainedTransformerIndexer` folds long sequences, it no longer loses the information from token type ids.


## [v2.4.0](https://github.com/allenai/allennlp/releases/tag/v2.4.0) - 2021-04-22

### Added
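For context on the `Fixed` entry above: BERT-style models distinguish the two segments of a sentence pair through token type ids (0 over the first segment, 1 over the second), so silently replacing them with zeros can hurt any model that relies on segment information. A quick, illustrative way to see the ids in question, assuming the Hugging Face `transformers` tokenizer API and the same sentence pair used in the new test further down:

```python
# Illustrative only: inspect the token type ids a BERT tokenizer assigns to a
# sentence pair (this is the information the indexer used to drop when folding).
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
encoded = tokenizer("How do trees get online?", "They log in!")
print(encoded["token_type_ids"])  # 0s over the first sentence, 1s over the second
```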
29 changes: 28 additions & 1 deletion allennlp/data/token_indexers/pretrained_transformer_indexer.py
@@ -154,25 +154,52 @@ def _postprocess_output(self, output: IndexedTokenList) -> IndexedTokenList:
# TODO(zhaofengw): we aren't respecting word boundaries when segmenting wordpieces.

indices = output["token_ids"]
type_ids = output.get("type_ids", [0] * len(indices))

# Strips original special tokens
indices = indices[
self._num_added_start_tokens : len(indices) - self._num_added_end_tokens
]
type_ids = type_ids[
self._num_added_start_tokens : len(type_ids) - self._num_added_end_tokens
]

# Folds indices
folded_indices = [
indices[i : i + self._effective_max_length]
for i in range(0, len(indices), self._effective_max_length)
]
folded_type_ids = [
type_ids[i : i + self._effective_max_length]
for i in range(0, len(type_ids), self._effective_max_length)
]

# Adds special tokens to each segment
folded_indices = [
self._tokenizer.build_inputs_with_special_tokens(segment)
for segment in folded_indices
]
single_sequence_start_type_ids = [
t.type_id for t in self._allennlp_tokenizer.single_sequence_start_tokens
]
single_sequence_end_type_ids = [
t.type_id for t in self._allennlp_tokenizer.single_sequence_end_tokens
]
folded_type_ids = [
single_sequence_start_type_ids + segment + single_sequence_end_type_ids
for segment in folded_type_ids
]
assert all(
len(segment_indices) == len(segment_type_ids)
for segment_indices, segment_type_ids in zip(folded_indices, folded_type_ids)
)

# Flattens
indices = [i for segment in folded_indices for i in segment]
type_ids = [i for segment in folded_type_ids for i in segment]

output["token_ids"] = indices
output["type_ids"] = [0] * len(indices)
output["type_ids"] = type_ids
output["segment_concat_mask"] = [True] * len(indices)

return output
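Read end to end, the updated `_postprocess_output` does four things: strip the special tokens that wrap the unfolded sequence, cut both the token ids and the type ids into windows of `_effective_max_length`, re-wrap each window with special tokens (and their type ids, taken from the tokenizer's single-sequence start/end tokens), and flatten the windows back into flat, aligned lists. A simplified, self-contained sketch of that flow, using plain Python lists and a hypothetical `fold_with_type_ids` helper in place of `build_inputs_with_special_tokens`:

```python
# Simplified sketch of the folding step above (illustrative names and values,
# not the library code): special tokens are passed in as explicit id lists.
def fold_with_type_ids(indices, type_ids, start_ids, end_ids,
                       start_type_ids, end_type_ids, effective_max_length):
    # Strip the special tokens that wrap the original, unfolded sequence.
    indices = indices[len(start_ids): len(indices) - len(end_ids)]
    type_ids = type_ids[len(start_type_ids): len(type_ids) - len(end_type_ids)]

    folded_ids, folded_types = [], []
    for i in range(0, len(indices), effective_max_length):
        # Re-wrap every window with special tokens and their type ids, so the
        # two lists stay the same length segment by segment.
        folded_ids.append(start_ids + indices[i:i + effective_max_length] + end_ids)
        folded_types.append(
            start_type_ids + type_ids[i:i + effective_max_length] + end_type_ids
        )

    # Flatten the segments back into flat, aligned lists.
    flat_ids = [i for seg in folded_ids for i in seg]
    flat_types = [t for seg in folded_types for t in seg]
    assert len(flat_ids) == len(flat_types)
    return flat_ids, flat_types


# Example with window size 3 and [CLS]=101 / [SEP]=102 (BERT-style ids):
ids, types = fold_with_type_ids(
    [101, 7, 8, 9, 10, 11, 12, 102], [0, 0, 0, 0, 1, 1, 1, 1],
    start_ids=[101], end_ids=[102], start_type_ids=[0], end_type_ids=[0],
    effective_max_length=3,
)
assert ids == [101, 7, 8, 9, 102, 101, 10, 11, 12, 102]
assert types == [0, 0, 0, 0, 0, 0, 1, 1, 1, 0]
```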
16 changes: 16 additions & 0 deletions tests/data/token_indexers/pretrained_transformer_indexer_test.py
@@ -163,6 +163,22 @@ def test_long_sequence_splitting(self):
assert indexed["segment_concat_mask"] == [True] * len(expected_ids)
assert indexed["mask"] == [True] * 7 # original length

def test_type_ids_when_folding(self):
allennlp_tokenizer = PretrainedTransformerTokenizer(
"bert-base-uncased", add_special_tokens=False
)
indexer = PretrainedTransformerIndexer(model_name="bert-base-uncased", max_length=6)
first_string = "How do trees get online?"
second_string = "They log in!"

tokens = allennlp_tokenizer.add_special_tokens(
allennlp_tokenizer.tokenize(first_string), allennlp_tokenizer.tokenize(second_string)
)
vocab = Vocabulary()
indexed = indexer.tokens_to_indices(tokens, vocab)
assert min(indexed["type_ids"]) == 0
assert max(indexed["type_ids"]) == 1

@staticmethod
def _assert_tokens_equal(expected_tokens, actual_tokens):
for expected, actual in zip(expected_tokens, actual_tokens):
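A hedged aside, not part of the commit: because the folded `token_ids`, `type_ids`, and `segment_concat_mask` are all built from the same flattened segments, their lengths should agree, which makes a cheap extra sanity check for tests like `test_type_ids_when_folding` above.

```python
# Hypothetical extra assertions; `indexed` and `tokens` are the objects from the
# test above (the tokens_to_indices output and the original token list).
assert len(indexed["token_ids"]) == len(indexed["type_ids"])
assert len(indexed["token_ids"]) == len(indexed["segment_concat_mask"])
assert len(indexed["mask"]) == len(tokens)  # "mask" keeps the original, unfolded length
```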
24 changes: 15 additions & 9 deletions tests/modules/token_embedders/pretrained_transformer_embedder_test.py
@@ -3,7 +3,7 @@
import torch

from allennlp.common import Params
from allennlp.common.testing import AllenNlpTestCase
from allennlp.common.testing import AllenNlpTestCase, requires_gpu
from allennlp.data import Vocabulary
from allennlp.data.batch import Batch
from allennlp.data.fields import TextField
@@ -15,14 +15,15 @@


class TestPretrainedTransformerEmbedder(AllenNlpTestCase):
@requires_gpu
def test_forward_runs_when_initialized_from_params(self):
# This code just passes things off to `transformers`, so we only have a very simple
# test.
params = Params({"model_name": "bert-base-uncased"})
embedder = PretrainedTransformerEmbedder.from_params(params)
embedder = PretrainedTransformerEmbedder.from_params(params).cuda()
token_ids = torch.randint(0, 100, (1, 4))
mask = torch.randint(0, 2, (1, 4)).bool()
output = embedder(token_ids=token_ids, mask=mask)
output = embedder(token_ids=token_ids.cuda(), mask=mask.cuda())
assert tuple(output.size()) == (1, 4, 768)

@pytest.mark.parametrize(
@@ -169,22 +169,24 @@ def test_end_to_end_t5(
assert bert_vectors.size() == (2, 8, 64)
assert bert_vectors.requires_grad == (train_parameters or not last_layer_only)

@requires_gpu
def test_big_token_type_ids(self):
token_embedder = PretrainedTransformerEmbedder("roberta-base")
token_embedder = PretrainedTransformerEmbedder("roberta-base").cuda()
token_ids = torch.LongTensor([[1, 2, 3], [2, 3, 4]])
mask = torch.ones_like(token_ids).bool()
type_ids = torch.zeros_like(token_ids)
type_ids[1, 1] = 1
with pytest.raises(ValueError):
token_embedder(token_ids, mask, type_ids)
token_embedder(token_ids.cuda(), mask.cuda(), type_ids.cuda())

@requires_gpu
def test_xlnet_token_type_ids(self):
token_embedder = PretrainedTransformerEmbedder("xlnet-base-cased")
token_embedder = PretrainedTransformerEmbedder("xlnet-base-cased").cuda()
token_ids = torch.LongTensor([[1, 2, 3], [2, 3, 4]])
mask = torch.ones_like(token_ids).bool()
type_ids = torch.zeros_like(token_ids)
type_ids[1, 1] = 1
token_embedder(token_ids, mask, type_ids)
token_embedder(token_ids.cuda(), mask.cuda(), type_ids.cuda())

def test_long_sequence_splitting_end_to_end(self):
# Mostly the same as the end_to_end test (except for adding max_length=4),
@@ -310,11 +313,14 @@ def test_unfold_long_sequences(self):
)
assert (unfolded_embeddings_out == unfolded_embeddings).all()

@requires_gpu
def test_encoder_decoder_model(self):
token_embedder = PretrainedTransformerEmbedder("facebook/bart-large", sub_module="encoder")
token_embedder = PretrainedTransformerEmbedder(
"facebook/bart-large", sub_module="encoder"
).cuda()
token_ids = torch.LongTensor([[1, 2, 3], [2, 3, 4]])
mask = torch.ones_like(token_ids).bool()
token_embedder(token_ids, mask)
token_embedder(token_ids.cuda(), mask.cuda())

def test_embeddings_resize(self):
regular_token_embedder = PretrainedTransformerEmbedder("bert-base-cased")
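The `.cuda()` edits in this file line up with the commit notes "Save memory on the GitHub test runners" and "Tensors have to be on the same device": once a test is marked `@requires_gpu` and the embedder is moved to the GPU, every tensor passed to it has to live on that device too. A generic sketch of the pattern, with a plain `nn.Embedding` standing in for the AllenNLP embedder (illustrative, not the project's test code):

```python
# Minimal device-placement sketch: the module and all tensors passed to its
# forward() must be on the same device.
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = torch.nn.Embedding(100, 16).to(device)            # stand-in for the embedder
token_ids = torch.randint(0, 100, (1, 4), device=device)  # inputs follow the model
output = model(token_ids)
assert output.shape == (1, 4, 16)
assert output.device.type == device.type
```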
