From 8bba728012ac357c8b115d862edf02f4896d9880 Mon Sep 17 00:00:00 2001 From: calpt <36051308+calpt@users.noreply.github.com> Date: Thu, 7 Jul 2022 09:53:26 +0200 Subject: [PATCH 1/3] Update tests --- tests_adapters/test_adapter_embeddings.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/tests_adapters/test_adapter_embeddings.py b/tests_adapters/test_adapter_embeddings.py index 01632c479..3c0c5daf7 100644 --- a/tests_adapters/test_adapter_embeddings.py +++ b/tests_adapters/test_adapter_embeddings.py @@ -27,6 +27,14 @@ def test_add_embeddings(self): model.add_embeddings("test", tokenizer) self.assertEqual(model.active_embeddings, "test") + def test_add_embedding_tokens(self): + model = self.get_model() + tokenizer = AutoTokenizer.from_pretrained("tests_adapters/fixtures/SiBERT") + self.assertEqual(tokenizer.vocab_size, 10000) + tokenizer.add_tokens(["test_token"]) + model.add_embeddings("test", tokenizer) + self.assertEqual(model.get_input_embeddings().num_embeddings, 10001) + def test_delete_embeddings(self): model = self.get_model() tokenizer = AutoTokenizer.from_pretrained("tests_adapters/fixtures/SiBERT") @@ -73,7 +81,12 @@ def test_back_to_default(self): self.assertTrue(torch.equal(output1[0], output2[0])) def test_training_embedding(self): + tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name, use_fast=False) + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token model = AutoAdapterModel.from_config(self.config()) + model.add_embeddings("test", tokenizer) + self.assertEqual(model.active_embeddings, "test") model.add_adapter("test") self.add_head(model, "test") model.train_adapter("test", train_embeddings=True) @@ -85,7 +98,7 @@ def test_training_embedding(self): state_dict_pre = copy.deepcopy(model.state_dict()) - train_dataset = self.dataset() + train_dataset = self.dataset(tokenizer=tokenizer) training_args = TrainingArguments( output_dir="./examples", do_train=True, From fa4d3fdd35a76846b260cb79a93b628b3e146f49 Mon Sep 17 00:00:00 2001 From: calpt Date: Thu, 7 Jul 2022 12:46:17 +0200 Subject: [PATCH 2/3] Add a new `EmbeddingAdaptersWrapperMixin`. Fixes issues with embedding method accessibility, embedding training, and embedding size.
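With the wrapper mixin in place, the embedding methods can be called directly on a heads model and are forwarded to its base model. A rough usage sketch (the checkpoint and adapter names below are illustrative, not part of this patch):

    from transformers import AutoAdapterModel, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    model = AutoAdapterModel.from_pretrained("bert-base-uncased")

    # forwarded to model.base_model by EmbeddingAdaptersWrapperMixin
    model.add_embeddings("custom", tokenizer)
    assert model.active_embeddings == "custom"

    # embeddings can be trained together with an adapter
    model.add_adapter("custom_adapter")
    model.train_adapter("custom_adapter", train_embeddings=True)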
--- src/transformers/adapters/mixins/bart.py | 12 +++- src/transformers/adapters/mixins/bert.py | 12 +++- .../adapters/mixins/distilbert.py | 12 +++- src/transformers/adapters/mixins/gpt2.py | 12 +++- src/transformers/adapters/mixins/t5.py | 12 +++- src/transformers/adapters/model_mixin.py | 63 +++++++++---------- .../adapters/models/bart/adapter_model.py | 3 +- .../adapters/models/bert/adapter_model.py | 3 +- .../adapters/models/deberta/adapter_model.py | 3 +- .../models/debertaV2/adapter_model.py | 5 +- .../models/distilbert/adapter_model.py | 5 +- .../adapters/models/gpt2/adapter_model.py | 3 +- .../adapters/models/mbart/adapter_model.py | 3 +- .../adapters/models/roberta/adapter_model.py | 3 +- .../adapters/models/t5/adapter_model.py | 3 +- src/transformers/models/bart/modeling_bart.py | 11 ++-- src/transformers/models/bert/modeling_bert.py | 24 ++++--- .../models/deberta/modeling_deberta.py | 16 +++-- .../models/deberta_v2/modeling_deberta_v2.py | 18 +++--- .../models/distilbert/modeling_distilbert.py | 17 ++--- src/transformers/models/gpt2/modeling_gpt2.py | 15 +++-- .../models/mbart/modeling_mbart.py | 11 ++-- .../models/roberta/modeling_roberta.py | 20 +++--- src/transformers/models/t5/modeling_t5.py | 5 +- tests_adapters/test_adapter_embeddings.py | 4 ++ 25 files changed, 191 insertions(+), 104 deletions(-) diff --git a/src/transformers/adapters/mixins/bart.py b/src/transformers/adapters/mixins/bart.py index f098dfdd1..6bf4e5576 100644 --- a/src/transformers/adapters/mixins/bart.py +++ b/src/transformers/adapters/mixins/bart.py @@ -3,7 +3,13 @@ import torch.nn as nn from ..layer import AdapterLayer -from ..model_mixin import EmbeddingAdaptersMixin, InvertibleAdaptersMixin, ModelAdaptersMixin +from ..model_mixin import ( + EmbeddingAdaptersMixin, + EmbeddingAdaptersWrapperMixin, + InvertibleAdaptersMixin, + ModelAdaptersMixin, + ModelWithHeadsAdaptersMixin, +) class BartEncoderLayerAdaptersMixin: @@ -48,3 +54,7 @@ def _init_adapter_modules(self): self.enable_invertible_adapters = self.encoder.enable_invertible_adapters self.invertible_adapters_forward = self.encoder.invertible_adapters_forward super()._init_adapter_modules() + + +class BartModelWithHeadsAdaptersMixin(EmbeddingAdaptersWrapperMixin, ModelWithHeadsAdaptersMixin): + pass diff --git a/src/transformers/adapters/mixins/bert.py b/src/transformers/adapters/mixins/bert.py index 21b2a203c..618802917 100644 --- a/src/transformers/adapters/mixins/bert.py +++ b/src/transformers/adapters/mixins/bert.py @@ -4,7 +4,13 @@ import torch.nn as nn from ..layer import AdapterLayer -from ..model_mixin import EmbeddingAdaptersMixin, InvertibleAdaptersMixin, ModelAdaptersMixin +from ..model_mixin import ( + EmbeddingAdaptersMixin, + EmbeddingAdaptersWrapperMixin, + InvertibleAdaptersMixin, + ModelAdaptersMixin, + ModelWithHeadsAdaptersMixin, +) logger = logging.getLogger(__name__) @@ -32,3 +38,7 @@ class BertModelAdaptersMixin(EmbeddingAdaptersMixin, InvertibleAdaptersMixin, Mo def iter_layers(self) -> Iterable[Tuple[int, nn.Module]]: for i, layer in enumerate(self.encoder.layer): yield i, layer + + +class BertModelWithHeadsAdaptersMixin(EmbeddingAdaptersWrapperMixin, ModelWithHeadsAdaptersMixin): + pass diff --git a/src/transformers/adapters/mixins/distilbert.py b/src/transformers/adapters/mixins/distilbert.py index 8713a682d..60a9d0333 100644 --- a/src/transformers/adapters/mixins/distilbert.py +++ b/src/transformers/adapters/mixins/distilbert.py @@ -3,7 +3,13 @@ import torch.nn as nn from ..layer import AdapterLayer -from ..model_mixin 
import EmbeddingAdaptersMixin, InvertibleAdaptersMixin, ModelAdaptersMixin +from ..model_mixin import ( + EmbeddingAdaptersMixin, + EmbeddingAdaptersWrapperMixin, + InvertibleAdaptersMixin, + ModelAdaptersMixin, + ModelWithHeadsAdaptersMixin, +) class DistilBertTransfomerBlockAdaptersMixin: @@ -22,3 +28,7 @@ class DistilBertModelAdaptersMixin(EmbeddingAdaptersMixin, InvertibleAdaptersMix def iter_layers(self) -> Iterable[Tuple[int, nn.Module]]: for i, layer in enumerate(self.transformer.layer): yield i, layer + + +class DistilBertModelWithHeadsAdaptersMixin(EmbeddingAdaptersWrapperMixin, ModelWithHeadsAdaptersMixin): + pass diff --git a/src/transformers/adapters/mixins/gpt2.py b/src/transformers/adapters/mixins/gpt2.py index 50c724e2a..236b73bb4 100644 --- a/src/transformers/adapters/mixins/gpt2.py +++ b/src/transformers/adapters/mixins/gpt2.py @@ -3,7 +3,13 @@ import torch.nn as nn from ..layer import AdapterLayer -from ..model_mixin import EmbeddingAdaptersMixin, InvertibleAdaptersMixin, ModelAdaptersMixin +from ..model_mixin import ( + EmbeddingAdaptersMixin, + EmbeddingAdaptersWrapperMixin, + InvertibleAdaptersMixin, + ModelAdaptersMixin, + ModelWithHeadsAdaptersMixin, +) class GPT2DecoderBlockAdaptersMixin: @@ -20,3 +26,7 @@ class GPT2ModelAdapterMixin(EmbeddingAdaptersMixin, InvertibleAdaptersMixin, Mod def iter_layers(self) -> Iterable[Tuple[int, nn.Module]]: for i, layer in enumerate(self.base_model.h): yield i, layer + + +class GPT2ModelWithHeadsAdaptersMixin(EmbeddingAdaptersWrapperMixin, ModelWithHeadsAdaptersMixin): + pass diff --git a/src/transformers/adapters/mixins/t5.py b/src/transformers/adapters/mixins/t5.py index 49117d827..20250a9eb 100644 --- a/src/transformers/adapters/mixins/t5.py +++ b/src/transformers/adapters/mixins/t5.py @@ -3,7 +3,13 @@ import torch.nn as nn from ..layer import AdapterLayer -from ..model_mixin import EmbeddingAdaptersMixin, InvertibleAdaptersMixin, ModelAdaptersMixin +from ..model_mixin import ( + EmbeddingAdaptersMixin, + EmbeddingAdaptersWrapperMixin, + InvertibleAdaptersMixin, + ModelAdaptersMixin, + ModelWithHeadsAdaptersMixin, +) class T5SelfAttentionLayerAdaptersMixin(AdapterLayer): @@ -45,3 +51,7 @@ def _init_adapter_modules(self): self.invertible_adapters_forward = self.encoder.invertible_adapters_forward self.delete_invertible_adapter = self.encoder.delete_invertible_adapter super()._init_adapter_modules() + + +class T5ModelWithHeadsAdaptersMixin(EmbeddingAdaptersWrapperMixin, ModelWithHeadsAdaptersMixin): + pass diff --git a/src/transformers/adapters/model_mixin.py b/src/transformers/adapters/model_mixin.py index 815b2526e..f82ee23e1 100644 --- a/src/transformers/adapters/model_mixin.py +++ b/src/transformers/adapters/model_mixin.py @@ -152,7 +152,7 @@ def add_embeddings(self, name, tokenizer, reference_embedding=None, reference_to raise ValueError("An embedding with the name {} already exists".format(name)) if embedding_dim is None: embedding_dim = self.config.hidden_size - embedding = nn.Embedding(tokenizer.vocab_size, embedding_dim) + embedding = nn.Embedding(len(tokenizer), embedding_dim) embedding.requires_grad_(False) if (reference_embedding is not None and reference_tokenizer is None) or ( reference_tokenizer is not None and reference_embedding is None @@ -167,7 +167,9 @@ def add_embeddings(self, name, tokenizer, reference_embedding=None, reference_to for t in tokens: idx_reference = reference_vocab[t] idx = vocab[t] - embedding.weight[idx] = self.loaded_embeddings[reference_embedding].weight[idx_reference].clone() + 
embedding.weight[idx] = ( + self.loaded_embeddings[reference_embedding].weight[idx_reference].detach().clone() + ) embedding.train(False) self.loaded_embeddings[name] = embedding self.set_active_embeddings(name) @@ -224,6 +226,31 @@ def active_embeddings(self): return self._active_embedding +class EmbeddingAdaptersWrapperMixin: + def load_embeddings(self, path: str, name: str): + return self.base_model.load_embeddings(path, name) + + def add_embeddings(self, name, tokenizer, reference_embedding=None, reference_tokenizer=None, embedding_dim=None): + return self.base_model.add_embeddings(name, tokenizer, reference_embedding, reference_tokenizer, embedding_dim) + + def delete_embeddings(self, name): + return self.base_model.delete_embeddings(name) + + def save_embeddings(self, path, name, tokenizer=None): + return self.base_model.save_embeddings(path, name, tokenizer) + + def set_active_embeddings(self, name): + return self.base_model.set_active_embeddings(name) + + @property + def active_embeddings(self): + return self.base_model.active_embeddings + + @property + def loaded_embeddings(self): + return self.base_model.loaded_embeddings + + class ModelAdaptersMixin(PushAdapterToHubMixin, ABC): """Mixin for transformer models adding support for loading/ saving adapters.""" @@ -1128,35 +1155,3 @@ def get_adapter(self, name): return super().get_adapter(name) else: return self.base_model.get_adapter(name) - - def load_embeddings(self, path: str, name: str): - if self.base_model is self: - return super().load_embeddings(path, name) - else: - return self.base_model.load_embeddings(path, name) - - def save_embeddings(self, path, name, tokenizer=None): - if self.base_model is self: - return super().save_embeddings(path, name, tokenizer) - else: - return self.base_model.save_embeddings(path, name, tokenizer) - - def add_embeddings(self, name, tokenizer, reference_embedding=None, reference_tokenizer=None, embedding_dim=None): - if self.base_model is None: - return super().add_embeddings(name, tokenizer, reference_embedding, reference_tokenizer, embedding_dim) - else: - return self.base_model.add_embeddings( - name, tokenizer, reference_embedding, reference_tokenizer, embedding_dim - ) - - def set_active_embeddings(self, name): - if self.base_model is None: - return super().set_active_embeddings(name) - else: - return self.base_model.set_active_embeddings(name) - - def delete_embeddings(self, name): - if self.base_model is None: - return super().delete_embeddings(name) - else: - return self.base_model.delete_embeddings(name) diff --git a/src/transformers/adapters/models/bart/adapter_model.py b/src/transformers/adapters/models/bart/adapter_model.py index 8dc3abeb1..596adb2bf 100644 --- a/src/transformers/adapters/models/bart/adapter_model.py +++ b/src/transformers/adapters/models/bart/adapter_model.py @@ -19,12 +19,13 @@ QuestionAnsweringHead, Seq2SeqLMHead, ) +from ...model_mixin import EmbeddingAdaptersWrapperMixin @add_start_docstrings( "BART Model with the option to add multiple flexible prediction heads on top.", BART_START_DOCSTRING ) -class BartAdapterModel(ModelWithFlexibleHeadsAdaptersMixin, BartPretrainedModel): +class BartAdapterModel(EmbeddingAdaptersWrapperMixin, ModelWithFlexibleHeadsAdaptersMixin, BartPretrainedModel): def __init__(self, config: BartConfig, **kwargs): super().__init__(config, **kwargs) self.model = BartModel(config) diff --git a/src/transformers/adapters/models/bert/adapter_model.py b/src/transformers/adapters/models/bert/adapter_model.py index fc285c91d..567201910 100644 
--- a/src/transformers/adapters/models/bert/adapter_model.py +++ b/src/transformers/adapters/models/bert/adapter_model.py @@ -14,13 +14,14 @@ QuestionAnsweringHead, TaggingHead, ) +from ...model_mixin import EmbeddingAdaptersWrapperMixin @add_start_docstrings( """Bert Model transformer with the option to add multiple flexible heads on top.""", BERT_START_DOCSTRING, ) -class BertAdapterModel(ModelWithFlexibleHeadsAdaptersMixin, BertPreTrainedModel): +class BertAdapterModel(EmbeddingAdaptersWrapperMixin, ModelWithFlexibleHeadsAdaptersMixin, BertPreTrainedModel): def __init__(self, config): super().__init__(config) diff --git a/src/transformers/adapters/models/deberta/adapter_model.py b/src/transformers/adapters/models/deberta/adapter_model.py index 1fc2c0586..310bad4de 100644 --- a/src/transformers/adapters/models/deberta/adapter_model.py +++ b/src/transformers/adapters/models/deberta/adapter_model.py @@ -9,12 +9,13 @@ QuestionAnsweringHead, TaggingHead, ) +from ...model_mixin import EmbeddingAdaptersWrapperMixin @add_start_docstrings( """Deberta Model transformer with the option to add multiple flexible heads on top.""", ) -class DebertaAdapterModel(ModelWithFlexibleHeadsAdaptersMixin, DebertaPreTrainedModel): +class DebertaAdapterModel(EmbeddingAdaptersWrapperMixin, ModelWithFlexibleHeadsAdaptersMixin, DebertaPreTrainedModel): _keys_to_ignore_on_load_unexpected = [r"cls.predictions.bias"] def __init__(self, config): diff --git a/src/transformers/adapters/models/debertaV2/adapter_model.py b/src/transformers/adapters/models/debertaV2/adapter_model.py index 18e1c4abe..bca062036 100644 --- a/src/transformers/adapters/models/debertaV2/adapter_model.py +++ b/src/transformers/adapters/models/debertaV2/adapter_model.py @@ -10,12 +10,15 @@ QuestionAnsweringHead, TaggingHead, ) +from ...model_mixin import EmbeddingAdaptersWrapperMixin @add_start_docstrings( """Deberta v2 Model transformer with the option to add multiple flexible heads on top.""", ) -class DebertaV2AdapterModel(ModelWithFlexibleHeadsAdaptersMixin, DebertaV2PreTrainedModel): +class DebertaV2AdapterModel( + EmbeddingAdaptersWrapperMixin, ModelWithFlexibleHeadsAdaptersMixin, DebertaV2PreTrainedModel +): _keys_to_ignore_on_load_unexpected = [r"cls.predictions.bias"] def __init__(self, config): diff --git a/src/transformers/adapters/models/distilbert/adapter_model.py b/src/transformers/adapters/models/distilbert/adapter_model.py index 7c135487e..5c47f8ba2 100644 --- a/src/transformers/adapters/models/distilbert/adapter_model.py +++ b/src/transformers/adapters/models/distilbert/adapter_model.py @@ -20,13 +20,16 @@ QuestionAnsweringHead, TaggingHead, ) +from ...model_mixin import EmbeddingAdaptersWrapperMixin @add_start_docstrings( """DistilBert Model transformer with the option to add multiple flexible heads on top.""", DISTILBERT_START_DOCSTRING, ) -class DistilBertAdapterModel(ModelWithFlexibleHeadsAdaptersMixin, DistilBertPreTrainedModel): +class DistilBertAdapterModel( + EmbeddingAdaptersWrapperMixin, ModelWithFlexibleHeadsAdaptersMixin, DistilBertPreTrainedModel +): def __init__(self, config): super().__init__(config) self.distilbert = DistilBertModel(config) diff --git a/src/transformers/adapters/models/gpt2/adapter_model.py b/src/transformers/adapters/models/gpt2/adapter_model.py index 3d5628ee3..f0c588d27 100644 --- a/src/transformers/adapters/models/gpt2/adapter_model.py +++ b/src/transformers/adapters/models/gpt2/adapter_model.py @@ -13,6 +13,7 @@ MultiLabelClassificationHead, TaggingHead, ) +from ...model_mixin import 
EmbeddingAdaptersWrapperMixin logger = logging.getLogger(__name__) @@ -29,7 +30,7 @@ """, GPT2_START_DOCSTRING, ) -class GPT2AdapterModel(ModelWithFlexibleHeadsAdaptersMixin, GPT2PreTrainedModel): +class GPT2AdapterModel(EmbeddingAdaptersWrapperMixin, ModelWithFlexibleHeadsAdaptersMixin, GPT2PreTrainedModel): def __init__(self, config): super().__init__(config) self.transformer = GPT2Model(config) diff --git a/src/transformers/adapters/models/mbart/adapter_model.py b/src/transformers/adapters/models/mbart/adapter_model.py index bf01f7444..d513a5a20 100644 --- a/src/transformers/adapters/models/mbart/adapter_model.py +++ b/src/transformers/adapters/models/mbart/adapter_model.py @@ -19,12 +19,13 @@ QuestionAnsweringHead, Seq2SeqLMHead, ) +from ...model_mixin import EmbeddingAdaptersWrapperMixin @add_start_docstrings( "MBART Model with the option to add multiple flexible prediction heads on top.", MBART_START_DOCSTRING ) -class MBartAdapterModel(ModelWithFlexibleHeadsAdaptersMixin, MBartPreTrainedModel): +class MBartAdapterModel(EmbeddingAdaptersWrapperMixin, ModelWithFlexibleHeadsAdaptersMixin, MBartPreTrainedModel): def __init__(self, config: MBartConfig, **kwargs): super().__init__(config, **kwargs) self.model = MBartModel(config) diff --git a/src/transformers/adapters/models/roberta/adapter_model.py b/src/transformers/adapters/models/roberta/adapter_model.py index 84c0b8083..b8bfff31d 100644 --- a/src/transformers/adapters/models/roberta/adapter_model.py +++ b/src/transformers/adapters/models/roberta/adapter_model.py @@ -19,13 +19,14 @@ QuestionAnsweringHead, TaggingHead, ) +from ...model_mixin import EmbeddingAdaptersWrapperMixin @add_start_docstrings( """Roberta Model transformer with the option to add multiple flexible heads on top.""", ROBERTA_START_DOCSTRING, ) -class RobertaAdapterModel(ModelWithFlexibleHeadsAdaptersMixin, RobertaPreTrainedModel): +class RobertaAdapterModel(EmbeddingAdaptersWrapperMixin, ModelWithFlexibleHeadsAdaptersMixin, RobertaPreTrainedModel): def __init__(self, config): super().__init__(config) diff --git a/src/transformers/adapters/models/t5/adapter_model.py b/src/transformers/adapters/models/t5/adapter_model.py index 6f4b9bc45..80ef3741b 100644 --- a/src/transformers/adapters/models/t5/adapter_model.py +++ b/src/transformers/adapters/models/t5/adapter_model.py @@ -6,13 +6,14 @@ from ....models.t5.modeling_t5 import T5_INPUTS_DOCSTRING, T5_START_DOCSTRING, T5Model, T5PreTrainedModel from ....utils import add_start_docstrings, add_start_docstrings_to_model_forward from ...heads import ModelWithFlexibleHeadsAdaptersMixin, Seq2SeqLMHead +from ...model_mixin import EmbeddingAdaptersWrapperMixin logger = logging.getLogger(__name__) @add_start_docstrings("T5 Model with the option to add multiple flexible prediction heads on top.", T5_START_DOCSTRING) -class T5AdapterModel(ModelWithFlexibleHeadsAdaptersMixin, T5PreTrainedModel): +class T5AdapterModel(EmbeddingAdaptersWrapperMixin, ModelWithFlexibleHeadsAdaptersMixin, T5PreTrainedModel): def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index bf5f973a0..10276da75 100755 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -32,8 +32,9 @@ BartDecoderLayerAdaptersMixin, BartEncoderLayerAdaptersMixin, BartModelAdaptersMixin, + BartModelWithHeadsAdaptersMixin, ) -from ...adapters.model_mixin import InvertibleAdaptersMixin, ModelWithHeadsAdaptersMixin +from 
...adapters.model_mixin import InvertibleAdaptersMixin from ...adapters.prefix_tuning import PrefixTuningShim from ...modeling_outputs import ( BaseModelOutput, @@ -1302,7 +1303,7 @@ def forward( @add_start_docstrings( "The BART Model with a language modeling head. Can be used for summarization.", BART_START_DOCSTRING ) -class BartForConditionalGeneration(ModelWithHeadsAdaptersMixin, BartPretrainedModel): +class BartForConditionalGeneration(BartModelWithHeadsAdaptersMixin, BartPretrainedModel): base_model_prefix = "model" _keys_to_ignore_on_load_missing = [r"final_logits_bias", r"lm_head\.weight"] @@ -1472,7 +1473,7 @@ def _reorder_cache(past, beam_idx): """, BART_START_DOCSTRING, ) -class BartForSequenceClassification(ModelWithHeadsAdaptersMixin, BartPretrainedModel): +class BartForSequenceClassification(BartModelWithHeadsAdaptersMixin, BartPretrainedModel): def __init__(self, config: BartConfig, **kwargs): super().__init__(config, **kwargs) self.model = BartModel(config) @@ -1600,7 +1601,7 @@ def forward( """, BART_START_DOCSTRING, ) -class BartForQuestionAnswering(ModelWithHeadsAdaptersMixin, BartPretrainedModel): +class BartForQuestionAnswering(BartModelWithHeadsAdaptersMixin, BartPretrainedModel): def __init__(self, config): super().__init__(config) @@ -1737,7 +1738,7 @@ def get_input_embeddings(self): return self.decoder.get_input_embeddings() -class BartForCausalLM(ModelWithHeadsAdaptersMixin, BartPretrainedModel): +class BartForCausalLM(BartModelWithHeadsAdaptersMixin, BartPretrainedModel): def __init__(self, config): super().__init__(config) decoder_config = copy.deepcopy(config) diff --git a/src/transformers/models/bert/modeling_bert.py b/src/transformers/models/bert/modeling_bert.py index 66efbd548..1bdd42ae5 100644 --- a/src/transformers/models/bert/modeling_bert.py +++ b/src/transformers/models/bert/modeling_bert.py @@ -32,8 +32,12 @@ from ...adapters.composition import adjust_tensors_for_parallel from ...adapters.context import ForwardContext from ...adapters.lora import Linear as LoRALinear -from ...adapters.mixins.bert import BertModelAdaptersMixin, BertOutputAdaptersMixin, BertSelfOutputAdaptersMixin -from ...adapters.model_mixin import ModelWithHeadsAdaptersMixin +from ...adapters.mixins.bert import ( + BertModelAdaptersMixin, + BertModelWithHeadsAdaptersMixin, + BertOutputAdaptersMixin, + BertSelfOutputAdaptersMixin, +) from ...adapters.prefix_tuning import PrefixTuningShim from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, @@ -1076,7 +1080,7 @@ def forward( """, BERT_START_DOCSTRING, ) -class BertForPreTraining(ModelWithHeadsAdaptersMixin, BertPreTrainedModel): +class BertForPreTraining(BertModelWithHeadsAdaptersMixin, BertPreTrainedModel): def __init__(self, config): super().__init__(config) @@ -1184,7 +1188,7 @@ def forward( @add_start_docstrings( """Bert Model with a `language modeling` head on top for CLM fine-tuning.""", BERT_START_DOCSTRING ) -class BertLMHeadModel(ModelWithHeadsAdaptersMixin, BertPreTrainedModel): +class BertLMHeadModel(BertModelWithHeadsAdaptersMixin, BertPreTrainedModel): _keys_to_ignore_on_load_unexpected = [r"pooler"] _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] @@ -1322,7 +1326,7 @@ def _reorder_cache(self, past, beam_idx): @add_start_docstrings("""Bert Model with a `language modeling` head on top.""", BERT_START_DOCSTRING) -class BertForMaskedLM(ModelWithHeadsAdaptersMixin, BertPreTrainedModel): +class BertForMaskedLM(BertModelWithHeadsAdaptersMixin, BertPreTrainedModel): 
_keys_to_ignore_on_load_unexpected = [r"pooler"] _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] @@ -1438,7 +1442,7 @@ def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_ """Bert Model with a `next sentence prediction (classification)` head on top.""", BERT_START_DOCSTRING, ) -class BertForNextSentencePrediction(ModelWithHeadsAdaptersMixin, BertPreTrainedModel): +class BertForNextSentencePrediction(BertModelWithHeadsAdaptersMixin, BertPreTrainedModel): def __init__(self, config): super().__init__(config) @@ -1542,7 +1546,7 @@ def forward( """, BERT_START_DOCSTRING, ) -class BertForSequenceClassification(ModelWithHeadsAdaptersMixin, BertPreTrainedModel): +class BertForSequenceClassification(BertModelWithHeadsAdaptersMixin, BertPreTrainedModel): def __init__(self, config): super().__init__(config) self.num_labels = config.num_labels @@ -1646,7 +1650,7 @@ def forward( """, BERT_START_DOCSTRING, ) -class BertForMultipleChoice(ModelWithHeadsAdaptersMixin, BertPreTrainedModel): +class BertForMultipleChoice(BertModelWithHeadsAdaptersMixin, BertPreTrainedModel): def __init__(self, config): super().__init__(config) @@ -1741,7 +1745,7 @@ def forward( """, BERT_START_DOCSTRING, ) -class BertForTokenClassification(ModelWithHeadsAdaptersMixin, BertPreTrainedModel): +class BertForTokenClassification(BertModelWithHeadsAdaptersMixin, BertPreTrainedModel): _keys_to_ignore_on_load_unexpected = [r"pooler"] @@ -1828,7 +1832,7 @@ def forward( """, BERT_START_DOCSTRING, ) -class BertForQuestionAnswering(ModelWithHeadsAdaptersMixin, BertPreTrainedModel): +class BertForQuestionAnswering(BertModelWithHeadsAdaptersMixin, BertPreTrainedModel): _keys_to_ignore_on_load_unexpected = [r"pooler"] diff --git a/src/transformers/models/deberta/modeling_deberta.py b/src/transformers/models/deberta/modeling_deberta.py index a3a291c5d..abecea8e7 100644 --- a/src/transformers/models/deberta/modeling_deberta.py +++ b/src/transformers/models/deberta/modeling_deberta.py @@ -27,8 +27,12 @@ from ...adapters.context import ForwardContext from ...adapters.lora import Linear as LoRALinear from ...adapters.lora import MergedLinear as LoRAMergedLinear -from ...adapters.mixins.bert import BertModelAdaptersMixin, BertOutputAdaptersMixin, BertSelfOutputAdaptersMixin -from ...adapters.model_mixin import ModelWithHeadsAdaptersMixin +from ...adapters.mixins.bert import ( + BertModelAdaptersMixin, + BertModelWithHeadsAdaptersMixin, + BertOutputAdaptersMixin, + BertSelfOutputAdaptersMixin, +) from ...adapters.prefix_tuning import PrefixTuningShim from ...modeling_outputs import ( BaseModelOutput, @@ -1017,7 +1021,7 @@ def forward( @add_start_docstrings("""DeBERTa Model with a `language modeling` head on top.""", DEBERTA_START_DOCSTRING) -class DebertaForMaskedLM(ModelWithHeadsAdaptersMixin, DebertaPreTrainedModel): +class DebertaForMaskedLM(BertModelWithHeadsAdaptersMixin, DebertaPreTrainedModel): _keys_to_ignore_on_load_unexpected = [r"pooler"] _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] @@ -1165,7 +1169,7 @@ def forward(self, sequence_output, **kwargs): """, DEBERTA_START_DOCSTRING, ) -class DebertaForSequenceClassification(ModelWithHeadsAdaptersMixin, DebertaPreTrainedModel): +class DebertaForSequenceClassification(BertModelWithHeadsAdaptersMixin, DebertaPreTrainedModel): def __init__(self, config): super().__init__(config) @@ -1284,7 +1288,7 @@ def forward( """, DEBERTA_START_DOCSTRING, ) -class 
DebertaForTokenClassification(ModelWithHeadsAdaptersMixin, DebertaPreTrainedModel): +class DebertaForTokenClassification(BertModelWithHeadsAdaptersMixin, DebertaPreTrainedModel): _keys_to_ignore_on_load_unexpected = [r"pooler"] def __init__(self, config): @@ -1360,7 +1364,7 @@ def forward( """, DEBERTA_START_DOCSTRING, ) -class DebertaForQuestionAnswering(ModelWithHeadsAdaptersMixin, DebertaPreTrainedModel): +class DebertaForQuestionAnswering(BertModelWithHeadsAdaptersMixin, DebertaPreTrainedModel): _keys_to_ignore_on_load_unexpected = [r"pooler"] def __init__(self, config): diff --git a/src/transformers/models/deberta_v2/modeling_deberta_v2.py b/src/transformers/models/deberta_v2/modeling_deberta_v2.py index 3b1fb2188..c2f1a130a 100644 --- a/src/transformers/models/deberta_v2/modeling_deberta_v2.py +++ b/src/transformers/models/deberta_v2/modeling_deberta_v2.py @@ -27,8 +27,12 @@ from ...adapters.composition import adjust_tensors_for_parallel from ...adapters.context import ForwardContext from ...adapters.lora import Linear as LoRALinear -from ...adapters.mixins.bert import BertModelAdaptersMixin, BertOutputAdaptersMixin, BertSelfOutputAdaptersMixin -from ...adapters.model_mixin import ModelWithHeadsAdaptersMixin +from ...adapters.mixins.bert import ( + BertModelAdaptersMixin, + BertModelWithHeadsAdaptersMixin, + BertOutputAdaptersMixin, + BertSelfOutputAdaptersMixin, +) from ...adapters.prefix_tuning import PrefixTuningShim from ...modeling_outputs import ( BaseModelOutput, @@ -1115,7 +1119,7 @@ def forward( @add_start_docstrings("""DeBERTa Model with a `language modeling` head on top.""", DEBERTA_START_DOCSTRING) # Copied from transformers.models.deberta.modeling_deberta.DebertaForMaskedLM with Deberta->DebertaV2 -class DebertaV2ForMaskedLM(ModelWithHeadsAdaptersMixin, DebertaV2PreTrainedModel): +class DebertaV2ForMaskedLM(BertModelWithHeadsAdaptersMixin, DebertaV2PreTrainedModel): _keys_to_ignore_on_load_unexpected = [r"pooler"] _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] @@ -1265,7 +1269,7 @@ def forward(self, sequence_output, **kwargs): DEBERTA_START_DOCSTRING, ) # Copied from transformers.models.deberta.modeling_deberta.DebertaForSequenceClassification with Deberta->DebertaV2 -class DebertaV2ForSequenceClassification(ModelWithHeadsAdaptersMixin, DebertaV2PreTrainedModel): +class DebertaV2ForSequenceClassification(BertModelWithHeadsAdaptersMixin, DebertaV2PreTrainedModel): def __init__(self, config): super().__init__(config) @@ -1385,7 +1389,7 @@ def forward( DEBERTA_START_DOCSTRING, ) # Copied from transformers.models.deberta.modeling_deberta.DebertaForTokenClassification with Deberta->DebertaV2 -class DebertaV2ForTokenClassification(ModelWithHeadsAdaptersMixin, DebertaV2PreTrainedModel): +class DebertaV2ForTokenClassification(BertModelWithHeadsAdaptersMixin, DebertaV2PreTrainedModel): _keys_to_ignore_on_load_unexpected = [r"pooler"] def __init__(self, config): @@ -1462,7 +1466,7 @@ def forward( DEBERTA_START_DOCSTRING, ) # Copied from transformers.models.deberta.modeling_deberta.DebertaForQuestionAnswering with Deberta->DebertaV2 -class DebertaV2ForQuestionAnswering(ModelWithHeadsAdaptersMixin, DebertaV2PreTrainedModel): +class DebertaV2ForQuestionAnswering(BertModelWithHeadsAdaptersMixin, DebertaV2PreTrainedModel): _keys_to_ignore_on_load_unexpected = [r"pooler"] def __init__(self, config): @@ -1562,7 +1566,7 @@ def forward( """, DEBERTA_START_DOCSTRING, ) -class DebertaV2ForMultipleChoice(ModelWithHeadsAdaptersMixin, DebertaV2PreTrainedModel): 
+class DebertaV2ForMultipleChoice(BertModelWithHeadsAdaptersMixin, DebertaV2PreTrainedModel): def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/distilbert/modeling_distilbert.py b/src/transformers/models/distilbert/modeling_distilbert.py index 5697c0d1a..ca324228b 100755 --- a/src/transformers/models/distilbert/modeling_distilbert.py +++ b/src/transformers/models/distilbert/modeling_distilbert.py @@ -33,8 +33,11 @@ from ...adapters.composition import adjust_tensors_for_parallel from ...adapters.context import ForwardContext from ...adapters.lora import Linear as LoRALinear -from ...adapters.mixins.distilbert import DistilBertModelAdaptersMixin, DistilBertTransfomerBlockAdaptersMixin -from ...adapters.model_mixin import ModelWithHeadsAdaptersMixin +from ...adapters.mixins.distilbert import ( + DistilBertModelAdaptersMixin, + DistilBertModelWithHeadsAdaptersMixin, + DistilBertTransfomerBlockAdaptersMixin, +) from ...adapters.prefix_tuning import PrefixTuningShim from ...deepspeed import is_deepspeed_zero3_enabled from ...modeling_outputs import ( @@ -597,7 +600,7 @@ def forward( """DistilBert Model with a `masked language modeling` head on top.""", DISTILBERT_START_DOCSTRING, ) -class DistilBertForMaskedLM(ModelWithHeadsAdaptersMixin, DistilBertPreTrainedModel): +class DistilBertForMaskedLM(DistilBertModelWithHeadsAdaptersMixin, DistilBertPreTrainedModel): def __init__(self, config: PretrainedConfig): super().__init__(config) @@ -704,7 +707,7 @@ def forward( """, DISTILBERT_START_DOCSTRING, ) -class DistilBertForSequenceClassification(ModelWithHeadsAdaptersMixin, DistilBertPreTrainedModel): +class DistilBertForSequenceClassification(DistilBertModelWithHeadsAdaptersMixin, DistilBertPreTrainedModel): def __init__(self, config: PretrainedConfig): super().__init__(config) self.num_labels = config.num_labels @@ -822,7 +825,7 @@ def forward( """, DISTILBERT_START_DOCSTRING, ) -class DistilBertForQuestionAnswering(ModelWithHeadsAdaptersMixin, DistilBertPreTrainedModel): +class DistilBertForQuestionAnswering(DistilBertModelWithHeadsAdaptersMixin, DistilBertPreTrainedModel): def __init__(self, config: PretrainedConfig): super().__init__(config) @@ -939,7 +942,7 @@ def forward( """, DISTILBERT_START_DOCSTRING, ) -class DistilBertForTokenClassification(ModelWithHeadsAdaptersMixin, DistilBertPreTrainedModel): +class DistilBertForTokenClassification(DistilBertModelWithHeadsAdaptersMixin, DistilBertPreTrainedModel): def __init__(self, config: PretrainedConfig): super().__init__(config) self.num_labels = config.num_labels @@ -1034,7 +1037,7 @@ def forward( """, DISTILBERT_START_DOCSTRING, ) -class DistilBertForMultipleChoice(ModelWithHeadsAdaptersMixin, DistilBertPreTrainedModel): +class DistilBertForMultipleChoice(DistilBertModelWithHeadsAdaptersMixin, DistilBertPreTrainedModel): def __init__(self, config: PretrainedConfig): super().__init__(config) diff --git a/src/transformers/models/gpt2/modeling_gpt2.py b/src/transformers/models/gpt2/modeling_gpt2.py index 87d0cba0e..86f848433 100644 --- a/src/transformers/models/gpt2/modeling_gpt2.py +++ b/src/transformers/models/gpt2/modeling_gpt2.py @@ -38,8 +38,11 @@ from ...adapters.context import ForwardContext from ...adapters.lora import Linear as LoRALinear from ...adapters.lora import MergedLinear as LoRAMergedLinear -from ...adapters.mixins.gpt2 import GPT2DecoderBlockAdaptersMixin, GPT2ModelAdapterMixin -from ...adapters.model_mixin import ModelWithHeadsAdaptersMixin +from ...adapters.mixins.gpt2 import ( + 
GPT2DecoderBlockAdaptersMixin, + GPT2ModelAdapterMixin, + GPT2ModelWithHeadsAdaptersMixin, +) from ...adapters.prefix_tuning import PrefixTuningShim from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, @@ -972,7 +975,7 @@ def custom_forward(*inputs): """, GPT2_START_DOCSTRING, ) -class GPT2LMHeadModel(ModelWithHeadsAdaptersMixin, GPT2PreTrainedModel): +class GPT2LMHeadModel(GPT2ModelWithHeadsAdaptersMixin, GPT2PreTrainedModel): _keys_to_ignore_on_load_missing = [r"attn.masked_bias", r"attn.bias", r"lm_head.weight"] def __init__(self, config): @@ -1143,7 +1146,7 @@ def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> """, GPT2_START_DOCSTRING, ) -class GPT2DoubleHeadsModel(ModelWithHeadsAdaptersMixin, GPT2PreTrainedModel): +class GPT2DoubleHeadsModel(GPT2ModelWithHeadsAdaptersMixin, GPT2PreTrainedModel): _keys_to_ignore_on_load_missing = [r"attn.masked_bias", r"attn.bias", r"lm_head.weight"] def __init__(self, config): @@ -1357,7 +1360,7 @@ def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> """, GPT2_START_DOCSTRING, ) -class GPT2ForSequenceClassification(ModelWithHeadsAdaptersMixin, GPT2PreTrainedModel): +class GPT2ForSequenceClassification(GPT2ModelWithHeadsAdaptersMixin, GPT2PreTrainedModel): _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"lm_head\.weight"] def __init__(self, config): @@ -1485,7 +1488,7 @@ def forward( """, GPT2_START_DOCSTRING, ) -class GPT2ForTokenClassification(ModelWithHeadsAdaptersMixin, GPT2PreTrainedModel): +class GPT2ForTokenClassification(GPT2ModelWithHeadsAdaptersMixin, GPT2PreTrainedModel): def __init__(self, config): super().__init__(config) self.num_labels = config.num_labels diff --git a/src/transformers/models/mbart/modeling_mbart.py b/src/transformers/models/mbart/modeling_mbart.py index 364b89bb6..be896ee64 100755 --- a/src/transformers/models/mbart/modeling_mbart.py +++ b/src/transformers/models/mbart/modeling_mbart.py @@ -31,8 +31,9 @@ BartDecoderLayerAdaptersMixin, BartEncoderLayerAdaptersMixin, BartModelAdaptersMixin, + BartModelWithHeadsAdaptersMixin, ) -from ...adapters.model_mixin import InvertibleAdaptersMixin, ModelWithHeadsAdaptersMixin +from ...adapters.model_mixin import InvertibleAdaptersMixin from ...adapters.prefix_tuning import PrefixTuningShim from ...modeling_outputs import ( BaseModelOutput, @@ -1296,7 +1297,7 @@ def forward( @add_start_docstrings( "The MBART Model with a language modeling head. 
Can be used for summarization.", MBART_START_DOCSTRING ) -class MBartForConditionalGeneration(ModelWithHeadsAdaptersMixin, MBartPreTrainedModel): +class MBartForConditionalGeneration(BartModelWithHeadsAdaptersMixin, MBartPreTrainedModel): base_model_prefix = "model" _keys_to_ignore_on_load_missing = [ r"final_logits_bias", @@ -1470,7 +1471,7 @@ def _reorder_cache(past, beam_idx): """, MBART_START_DOCSTRING, ) -class MBartForSequenceClassification(ModelWithHeadsAdaptersMixin, MBartPreTrainedModel): +class MBartForSequenceClassification(BartModelWithHeadsAdaptersMixin, MBartPreTrainedModel): def __init__(self, config: MBartConfig, **kwargs): super().__init__(config, **kwargs) self.model = MBartModel(config) @@ -1599,7 +1600,7 @@ def forward( """, MBART_START_DOCSTRING, ) -class MBartForQuestionAnswering(ModelWithHeadsAdaptersMixin, MBartPreTrainedModel): +class MBartForQuestionAnswering(BartModelWithHeadsAdaptersMixin, MBartPreTrainedModel): def __init__(self, config): super().__init__(config) @@ -1739,7 +1740,7 @@ def get_input_embeddings(self) -> nn.Module: # Copied from transformers.models.bart.modeling_bart.BartForCausalLM with Bart->MBart, facebook/bart-base->facebook/mbart-large-cc25 -class MBartForCausalLM(ModelWithHeadsAdaptersMixin, MBartPreTrainedModel): +class MBartForCausalLM(BartModelWithHeadsAdaptersMixin, MBartPreTrainedModel): def __init__(self, config): super().__init__(config) decoder_config = copy.deepcopy(config) diff --git a/src/transformers/models/roberta/modeling_roberta.py b/src/transformers/models/roberta/modeling_roberta.py index fdd0b16b1..82a4d3154 100644 --- a/src/transformers/models/roberta/modeling_roberta.py +++ b/src/transformers/models/roberta/modeling_roberta.py @@ -28,8 +28,12 @@ from ...adapters.composition import adjust_tensors_for_parallel from ...adapters.context import ForwardContext from ...adapters.lora import Linear as LoRALinear -from ...adapters.mixins.bert import BertModelAdaptersMixin, BertOutputAdaptersMixin, BertSelfOutputAdaptersMixin -from ...adapters.model_mixin import ModelWithHeadsAdaptersMixin +from ...adapters.mixins.bert import ( + BertModelAdaptersMixin, + BertModelWithHeadsAdaptersMixin, + BertOutputAdaptersMixin, + BertSelfOutputAdaptersMixin, +) from ...adapters.prefix_tuning import PrefixTuningShim from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, @@ -900,7 +904,7 @@ def forward( @add_start_docstrings( """RoBERTa Model with a `language modeling` head on top for CLM fine-tuning.""", ROBERTA_START_DOCSTRING ) -class RobertaForCausalLM(ModelWithHeadsAdaptersMixin, RobertaPreTrainedModel): +class RobertaForCausalLM(BertModelWithHeadsAdaptersMixin, RobertaPreTrainedModel): _keys_to_ignore_on_save = [r"lm_head.decoder.weight", r"lm_head.decoder.bias"] _keys_to_ignore_on_load_missing = [r"position_ids", r"lm_head.decoder.weight", r"lm_head.decoder.bias"] _keys_to_ignore_on_load_unexpected = [r"pooler"] @@ -1055,7 +1059,7 @@ def _reorder_cache(self, past, beam_idx): @add_start_docstrings("""RoBERTa Model with a `language modeling` head on top.""", ROBERTA_START_DOCSTRING) -class RobertaForMaskedLM(ModelWithHeadsAdaptersMixin, RobertaPreTrainedModel): +class RobertaForMaskedLM(BertModelWithHeadsAdaptersMixin, RobertaPreTrainedModel): _keys_to_ignore_on_save = [r"lm_head.decoder.weight", r"lm_head.decoder.bias"] _keys_to_ignore_on_load_missing = [r"position_ids", r"lm_head.decoder.weight", r"lm_head.decoder.bias"] _keys_to_ignore_on_load_unexpected = [r"pooler"] @@ -1192,7 +1196,7 @@ def _tie_weights(self): """, 
ROBERTA_START_DOCSTRING, ) -class RobertaForSequenceClassification(ModelWithHeadsAdaptersMixin, RobertaPreTrainedModel): +class RobertaForSequenceClassification(BertModelWithHeadsAdaptersMixin, RobertaPreTrainedModel): _keys_to_ignore_on_load_missing = [r"position_ids"] def __init__(self, config): @@ -1292,7 +1296,7 @@ def forward( """, ROBERTA_START_DOCSTRING, ) -class RobertaForMultipleChoice(ModelWithHeadsAdaptersMixin, RobertaPreTrainedModel): +class RobertaForMultipleChoice(BertModelWithHeadsAdaptersMixin, RobertaPreTrainedModel): _keys_to_ignore_on_load_missing = [r"position_ids"] def __init__(self, config): @@ -1385,7 +1389,7 @@ def forward( """, ROBERTA_START_DOCSTRING, ) -class RobertaForTokenClassification(ModelWithHeadsAdaptersMixin, RobertaPreTrainedModel): +class RobertaForTokenClassification(BertModelWithHeadsAdaptersMixin, RobertaPreTrainedModel): _keys_to_ignore_on_load_unexpected = [r"pooler"] _keys_to_ignore_on_load_missing = [r"position_ids"] @@ -1494,7 +1498,7 @@ def forward(self, features, **kwargs): """, ROBERTA_START_DOCSTRING, ) -class RobertaForQuestionAnswering(ModelWithHeadsAdaptersMixin, RobertaPreTrainedModel): +class RobertaForQuestionAnswering(BertModelWithHeadsAdaptersMixin, RobertaPreTrainedModel): _keys_to_ignore_on_load_unexpected = [r"pooler"] _keys_to_ignore_on_load_missing = [r"position_ids"] diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py index 470392015..5f638e226 100644 --- a/src/transformers/models/t5/modeling_t5.py +++ b/src/transformers/models/t5/modeling_t5.py @@ -34,9 +34,10 @@ T5CrossAttentionLayerAdaptersMixin, T5FFLayerAdaptersMixin, T5ModelAdaptersMixin, + T5ModelWithHeadsAdaptersMixin, T5SelfAttentionLayerAdaptersMixin, ) -from ...adapters.model_mixin import InvertibleAdaptersMixin, ModelWithHeadsAdaptersMixin +from ...adapters.model_mixin import InvertibleAdaptersMixin from ...adapters.prefix_tuning import PrefixTuningShim from ...modeling_outputs import ( BaseModelOutput, @@ -1507,7 +1508,7 @@ def forward( @add_start_docstrings("""T5 Model with a `language modeling` head on top.""", T5_START_DOCSTRING) -class T5ForConditionalGeneration(ModelWithHeadsAdaptersMixin, T5ModelAdaptersMixin, T5PreTrainedModel): +class T5ForConditionalGeneration(T5ModelWithHeadsAdaptersMixin, T5ModelAdaptersMixin, T5PreTrainedModel): _keys_to_ignore_on_load_missing = [ r"encoder\.embed_tokens\.weight", r"decoder\.embed_tokens\.weight", diff --git a/tests_adapters/test_adapter_embeddings.py b/tests_adapters/test_adapter_embeddings.py index 3c0c5daf7..d88b7d203 100644 --- a/tests_adapters/test_adapter_embeddings.py +++ b/tests_adapters/test_adapter_embeddings.py @@ -160,3 +160,7 @@ def test_reference_embedding(self): test = test_embedding(input_test) self.assertTrue(torch.equal(default, test)) + + # activate for training + model.add_adapter("test") + model.train_adapter("test", train_embeddings=True) From aa2860b817ba9d00500ae9f1c5eeb1b4e32d34ec Mon Sep 17 00:00:00 2001 From: calpt Date: Thu, 7 Jul 2022 14:26:02 +0200 Subject: [PATCH 3/3] T5 fix --- src/transformers/adapters/mixins/t5.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/adapters/mixins/t5.py b/src/transformers/adapters/mixins/t5.py index 20250a9eb..e00f7b612 100644 --- a/src/transformers/adapters/mixins/t5.py +++ b/src/transformers/adapters/mixins/t5.py @@ -5,7 +5,6 @@ from ..layer import AdapterLayer from ..model_mixin import ( EmbeddingAdaptersMixin, - EmbeddingAdaptersWrapperMixin, InvertibleAdaptersMixin, 
ModelAdaptersMixin, ModelWithHeadsAdaptersMixin, @@ -53,5 +52,6 @@ def _init_adapter_modules(self): super()._init_adapter_modules() -class T5ModelWithHeadsAdaptersMixin(EmbeddingAdaptersWrapperMixin, ModelWithHeadsAdaptersMixin): +# EmbeddingAdaptersWrapperMixin not required here as base and heads model are identical +class T5ModelWithHeadsAdaptersMixin(ModelWithHeadsAdaptersMixin): pass
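The embedding size fix in [PATCH 2/3] creates the embedding matrix with len(tokenizer) instead of tokenizer.vocab_size because tokens added via add_tokens() are not counted in vocab_size. A minimal illustration (the checkpoint name is illustrative):

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    tokenizer.add_tokens(["<new_token>"])
    # vocab_size excludes added tokens, len(tokenizer) includes them,
    # so the embedding matrix must be sized with len(tokenizer)
    print(tokenizer.vocab_size, len(tokenizer))  # e.g. 30522 vs. 30523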