From e89c992a0709fca414246fd3ec741a342ed899cf Mon Sep 17 00:00:00 2001
From: Dirk Groeneveld
Date: Fri, 23 Apr 2021 12:02:53 -0700
Subject: [PATCH 1/4] Fixes token type ids for folded sequences

---
 .../pretrained_transformer_indexer.py        | 29 ++++++++++++++++++-
 .../pretrained_transformer_indexer_test.py   | 16 ++++++++++
 2 files changed, 44 insertions(+), 1 deletion(-)

diff --git a/allennlp/data/token_indexers/pretrained_transformer_indexer.py b/allennlp/data/token_indexers/pretrained_transformer_indexer.py
index ae90663062c..ee88a153f45 100644
--- a/allennlp/data/token_indexers/pretrained_transformer_indexer.py
+++ b/allennlp/data/token_indexers/pretrained_transformer_indexer.py
@@ -154,25 +154,52 @@ def _postprocess_output(self, output: IndexedTokenList) -> IndexedTokenList:
             # TODO(zhaofengw): we aren't respecting word boundaries when segmenting wordpieces.

             indices = output["token_ids"]
+            type_ids = output.get("type_ids", [0] * len(indices))
+
             # Strips original special tokens
             indices = indices[
                 self._num_added_start_tokens : len(indices) - self._num_added_end_tokens
             ]
+            type_ids = type_ids[
+                self._num_added_start_tokens : len(type_ids) - self._num_added_end_tokens
+            ]
+
             # Folds indices
             folded_indices = [
                 indices[i : i + self._effective_max_length]
                 for i in range(0, len(indices), self._effective_max_length)
             ]
+            folded_type_ids = [
+                type_ids[i : i + self._effective_max_length]
+                for i in range(0, len(type_ids), self._effective_max_length)
+            ]
+
             # Adds special tokens to each segment
             folded_indices = [
                 self._tokenizer.build_inputs_with_special_tokens(segment)
                 for segment in folded_indices
             ]
+            single_sequence_start_type_ids = [
+                t.type_id for t in self._allennlp_tokenizer.single_sequence_start_tokens
+            ]
+            single_sequence_end_type_ids = [
+                t.type_id for t in self._allennlp_tokenizer.single_sequence_end_tokens
+            ]
+            folded_type_ids = [
+                single_sequence_start_type_ids + segment + single_sequence_end_type_ids
+                for segment in folded_type_ids
+            ]
+            assert all(
+                len(segment_indices) == len(segment_type_ids)
+                for segment_indices, segment_type_ids in zip(folded_indices, folded_type_ids)
+            )
+
             # Flattens
             indices = [i for segment in folded_indices for i in segment]
+            type_ids = [i for segment in folded_type_ids for i in segment]

             output["token_ids"] = indices
-            output["type_ids"] = [0] * len(indices)
+            output["type_ids"] = type_ids
             output["segment_concat_mask"] = [True] * len(indices)

         return output

diff --git a/tests/data/token_indexers/pretrained_transformer_indexer_test.py b/tests/data/token_indexers/pretrained_transformer_indexer_test.py
index d817af9b392..84021394cf3 100644
--- a/tests/data/token_indexers/pretrained_transformer_indexer_test.py
+++ b/tests/data/token_indexers/pretrained_transformer_indexer_test.py
@@ -163,6 +163,22 @@ def test_long_sequence_splitting(self):
         assert indexed["segment_concat_mask"] == [True] * len(expected_ids)
         assert indexed["mask"] == [True] * 7  # original length

+    def test_type_ids_when_folding(self):
+        allennlp_tokenizer = PretrainedTransformerTokenizer(
+            "bert-base-uncased", add_special_tokens=False
+        )
+        indexer = PretrainedTransformerIndexer(model_name="bert-base-uncased", max_length=6)
+        first_string = "How do trees get online?"
+        second_string = "They log in!"
+
+        tokens = allennlp_tokenizer.add_special_tokens(
+            allennlp_tokenizer.tokenize(first_string), allennlp_tokenizer.tokenize(second_string)
+        )
+        vocab = Vocabulary()
+        indexed = indexer.tokens_to_indices(tokens, vocab)
+        assert min(indexed["type_ids"]) == 0
+        assert max(indexed["type_ids"]) == 1
+
     @staticmethod
     def _assert_tokens_equal(expected_tokens, actual_tokens):
         for expected, actual in zip(expected_tokens, actual_tokens):

From a1d94524d5068bb2c80692d6e949c6947238e472 Mon Sep 17 00:00:00 2001
From: Dirk Groeneveld
Date: Fri, 23 Apr 2021 12:07:14 -0700
Subject: [PATCH 2/4] Changelog

---
 CHANGELOG.md | 7 +++++++
 1 file changed, 7 insertions(+)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 7adc2ad183d..29652f5a2d3 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -6,6 +6,13 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
 and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

+## Unreleased
+
+### Fixed
+
+- When `PretrainedTransformerIndexer` folds long sequences, it no longer loses the information from token type ids.
+
+
 ## [v2.4.0](https://github.com/allenai/allennlp/releases/tag/v2.4.0) - 2021-04-22

 ### Added

From ad26dd6ca35da7c464124c876a822d59d38fba45 Mon Sep 17 00:00:00 2001
From: Dirk Groeneveld
Date: Mon, 26 Apr 2021 15:33:13 -0700
Subject: [PATCH 3/4] Save memory on the GitHub test runners

---
 .../pretrained_transformer_embedder_test.py | 16 +++++++++++-----
 1 file changed, 11 insertions(+), 5 deletions(-)

diff --git a/tests/modules/token_embedders/pretrained_transformer_embedder_test.py b/tests/modules/token_embedders/pretrained_transformer_embedder_test.py
index 00356d80ae8..60f0d34e8d1 100644
--- a/tests/modules/token_embedders/pretrained_transformer_embedder_test.py
+++ b/tests/modules/token_embedders/pretrained_transformer_embedder_test.py
@@ -3,7 +3,7 @@
 import torch

 from allennlp.common import Params
-from allennlp.common.testing import AllenNlpTestCase
+from allennlp.common.testing import AllenNlpTestCase, requires_gpu
 from allennlp.data import Vocabulary
 from allennlp.data.batch import Batch
 from allennlp.data.fields import TextField
@@ -15,11 +15,12 @@ class TestPretrainedTransformerEmbedder(AllenNlpTestCase):
+    @requires_gpu
     def test_forward_runs_when_initialized_from_params(self):
         # This code just passes things off to `transformers`, so we only have a very simple
         # test.
         params = Params({"model_name": "bert-base-uncased"})
-        embedder = PretrainedTransformerEmbedder.from_params(params)
+        embedder = PretrainedTransformerEmbedder.from_params(params).cuda()
         token_ids = torch.randint(0, 100, (1, 4))
         mask = torch.randint(0, 2, (1, 4)).bool()
         output = embedder(token_ids=token_ids, mask=mask)
         assert tuple(output.size()) == (1, 4, 768)

     @pytest.mark.parametrize(
@@ -169,8 +170,9 @@ def test_end_to_end_t5(
         assert bert_vectors.size() == (2, 8, 64)
         assert bert_vectors.requires_grad == (train_parameters or not last_layer_only)

+    @requires_gpu
     def test_big_token_type_ids(self):
-        token_embedder = PretrainedTransformerEmbedder("roberta-base")
+        token_embedder = PretrainedTransformerEmbedder("roberta-base").cuda()
         token_ids = torch.LongTensor([[1, 2, 3], [2, 3, 4]])
         mask = torch.ones_like(token_ids).bool()
         type_ids = torch.zeros_like(token_ids)
@@ -178,8 +180,9 @@ def test_big_token_type_ids(self):
         type_ids[1, 1] = 1
         with pytest.raises(ValueError):
             token_embedder(token_ids, mask, type_ids)

+    @requires_gpu
     def test_xlnet_token_type_ids(self):
-        token_embedder = PretrainedTransformerEmbedder("xlnet-base-cased")
+        token_embedder = PretrainedTransformerEmbedder("xlnet-base-cased").cuda()
         token_ids = torch.LongTensor([[1, 2, 3], [2, 3, 4]])
         mask = torch.ones_like(token_ids).bool()
         type_ids = torch.zeros_like(token_ids)
@@ -310,8 +313,11 @@ def test_unfold_long_sequences(self):
         )
         assert (unfolded_embeddings_out == unfolded_embeddings).all()

+    @requires_gpu
     def test_encoder_decoder_model(self):
-        token_embedder = PretrainedTransformerEmbedder("facebook/bart-large", sub_module="encoder")
+        token_embedder = PretrainedTransformerEmbedder(
+            "facebook/bart-large", sub_module="encoder"
+        ).cuda()
         token_ids = torch.LongTensor([[1, 2, 3], [2, 3, 4]])
         mask = torch.ones_like(token_ids).bool()
         token_embedder(token_ids, mask)

From 3d565517ab1fde2f28e3b4f7b66eeac4d293fd34 Mon Sep 17 00:00:00 2001
From: Dirk Groeneveld
Date: Mon, 26 Apr 2021 15:41:43 -0700
Subject: [PATCH 4/4] Tensors have to be on the same device

---
 .../pretrained_transformer_embedder_test.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/tests/modules/token_embedders/pretrained_transformer_embedder_test.py b/tests/modules/token_embedders/pretrained_transformer_embedder_test.py
index 60f0d34e8d1..a3343c07700 100644
--- a/tests/modules/token_embedders/pretrained_transformer_embedder_test.py
+++ b/tests/modules/token_embedders/pretrained_transformer_embedder_test.py
@@ -23,7 +23,7 @@ def test_forward_runs_when_initialized_from_params(self):
         embedder = PretrainedTransformerEmbedder.from_params(params).cuda()
         token_ids = torch.randint(0, 100, (1, 4))
         mask = torch.randint(0, 2, (1, 4)).bool()
-        output = embedder(token_ids=token_ids, mask=mask)
+        output = embedder(token_ids=token_ids.cuda(), mask=mask.cuda())
         assert tuple(output.size()) == (1, 4, 768)

     @pytest.mark.parametrize(
@@ -178,7 +178,7 @@ def test_big_token_type_ids(self):
         type_ids = torch.zeros_like(token_ids)
         type_ids[1, 1] = 1
         with pytest.raises(ValueError):
-            token_embedder(token_ids, mask, type_ids)
+            token_embedder(token_ids.cuda(), mask.cuda(), type_ids.cuda())

     @requires_gpu
     def test_xlnet_token_type_ids(self):
@@ -187,7 +187,7 @@ def test_xlnet_token_type_ids(self):
         mask = torch.ones_like(token_ids).bool()
         type_ids = torch.zeros_like(token_ids)
         type_ids[1, 1] = 1
-        token_embedder(token_ids, mask, type_ids)
+        token_embedder(token_ids.cuda(), mask.cuda(), type_ids.cuda())

     def test_long_sequence_splitting_end_to_end(self):
         # Mostly the same as the end_to_end test (except for adding max_length=4),
@@ -320,7 +320,7 @@ def test_encoder_decoder_model(self):
         ).cuda()
         token_ids = torch.LongTensor([[1, 2, 3], [2, 3, 4]])
         mask = torch.ones_like(token_ids).bool()
-        token_embedder(token_ids, mask)
+        token_embedder(token_ids.cuda(), mask.cuda())

     def test_embeddings_resize(self):
         regular_token_embedder = PretrainedTransformerEmbedder("bert-base-cased")
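
Note (illustrative only, not part of the patch series): the sketch below mirrors the fold/flatten bookkeeping that PATCH 1/4 adds to `_postprocess_output`, so that type ids survive long-sequence folding instead of being replaced by zeros. The constants are assumptions chosen for the example (an effective segment length of 4, and [CLS]/[SEP]-style special tokens with type id 0); the real indexer takes all of them from its tokenizer and the configured `max_length`.

    # Toy illustration: fold token ids AND type ids in lockstep, re-wrap each
    # segment with special tokens, then flatten everything back into one list.
    EFFECTIVE_MAX_LENGTH = 4           # assumed segment size once special tokens are excluded
    START_IDS, END_IDS = [101], [102]  # stand-ins for [CLS] / [SEP] token ids
    START_TYPE_IDS, END_TYPE_IDS = [0], [0]


    def fold_with_type_ids(token_ids, type_ids):
        def fold(xs):
            # Split into segments of EFFECTIVE_MAX_LENGTH items each.
            return [xs[i : i + EFFECTIVE_MAX_LENGTH] for i in range(0, len(xs), EFFECTIVE_MAX_LENGTH)]

        folded_ids = [START_IDS + seg + END_IDS for seg in fold(token_ids)]
        folded_type_ids = [START_TYPE_IDS + seg + END_TYPE_IDS for seg in fold(type_ids)]
        # Segments must stay aligned, mirroring the assert added in the patch.
        assert all(len(a) == len(b) for a, b in zip(folded_ids, folded_type_ids))

        def flatten(segments):
            return [x for seg in segments for x in seg]

        return flatten(folded_ids), flatten(folded_type_ids)


    # A two-sentence input: the second sentence carries type id 1, which the old
    # code discarded by emitting all zeros after folding.
    ids, types = fold_with_type_ids(list(range(1, 8)), [0, 0, 0, 1, 1, 1, 1])
    print(ids)    # [101, 1, 2, 3, 4, 102, 101, 5, 6, 7, 102]
    print(types)  # [0, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0]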