diff --git a/src/transformers/adapters/context.py b/src/transformers/adapters/context.py
index 05261351a..2357f552b 100644
--- a/src/transformers/adapters/context.py
+++ b/src/transformers/adapters/context.py
@@ -78,7 +78,7 @@ class ForwardContext:
     # thread-local storage that holds a stack of active contexts
     storage = threading.local()

-    context_attributes = ["adapter_gating_scores", "adapter_fusion_attentions"]
+    context_attributes = ["adapter_gating_scores", "adapter_fusion_attentions", "adapter_input_parallelized"]

     def __init__(self, model, *args, **kwargs):
         # If the model has a method ``forward_context()``, use it to create the context.
diff --git a/src/transformers/adapters/heads/base.py b/src/transformers/adapters/heads/base.py
index dcea9bc98..83f5e3295 100644
--- a/src/transformers/adapters/heads/base.py
+++ b/src/transformers/adapters/heads/base.py
@@ -31,6 +31,10 @@ class MultiHeadOutput(ModelOutput):
     head_outputs: List[ModelOutput] = None
     loss: Optional[torch.FloatTensor] = None

+    @property
+    def logits(self):
+        return torch.vstack([outputs["logits"] for outputs in self.head_outputs])
+
     def __getitem__(self, k):
         # with number indices the head output at that position is accessed
         # e.g output[1] is equivalent to output.head_outputs[1]
diff --git a/src/transformers/adapters/model_mixin.py b/src/transformers/adapters/model_mixin.py
index 95b07ac84..9e247c60c 100644
--- a/src/transformers/adapters/model_mixin.py
+++ b/src/transformers/adapters/model_mixin.py
@@ -779,6 +779,11 @@ def forward_context(self, context: ForwardContext, *args, **kwargs):
             return

         context.adapters_parallelized = False
+        # Check if already parallelized in encoder
+        adapter_input_parallelized = kwargs.pop("adapter_input_parallelized", None)
+        if adapter_input_parallelized:
+            if active_adapters.parallel_channels > 1:
+                context.adapters_parallelized = True
         # Add the shared parameters for the active adapters to the context
         context.shared_parameters = {
             name: param for name, param in self.shared_parameters.items() if name in active_adapters.flatten()
diff --git a/src/transformers/adapters/models/bart/adapter_model.py b/src/transformers/adapters/models/bart/adapter_model.py
index a6ce6c827..a696cba8d 100644
--- a/src/transformers/adapters/models/bart/adapter_model.py
+++ b/src/transformers/adapters/models/bart/adapter_model.py
@@ -89,6 +89,7 @@ def forward(
             past_key_values=past_key_values,
             output_adapter_gating_scores=output_adapter_gating_scores,
             output_adapter_fusion_attentions=output_adapter_fusion_attentions,
+            adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False),
         )
         # sequence classification based on last token in sequence
         x = outputs[0]  # last hidden state
@@ -139,6 +140,7 @@ def prepare_inputs_for_generation(
             "decoder_head_mask": decoder_head_mask,
             "cross_attn_head_mask": cross_attn_head_mask,
             "use_cache": use_cache,  # change this to avoid caching (presumably for debugging)
+            "adapter_input_parallelized": kwargs.pop("adapter_input_parallelized", False),
         }

     # Copied from BartForConditionalGeneration
diff --git a/src/transformers/adapters/models/beit/adapter_model.py b/src/transformers/adapters/models/beit/adapter_model.py
index c3c1260f7..bfc702953 100644
--- a/src/transformers/adapters/models/beit/adapter_model.py
+++ b/src/transformers/adapters/models/beit/adapter_model.py
@@ -48,6 +48,7 @@ def forward(
             return_dict=return_dict,
             output_adapter_gating_scores=output_adapter_gating_scores,
             output_adapter_fusion_attentions=output_adapter_fusion_attentions,
+            adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False),
         )

         # BERT & RoBERTa return the pooled output as second item, we don't need that in these heads
diff --git a/src/transformers/adapters/models/bert/adapter_model.py b/src/transformers/adapters/models/bert/adapter_model.py
index 0f6d92143..1a14fdc99 100644
--- a/src/transformers/adapters/models/bert/adapter_model.py
+++ b/src/transformers/adapters/models/bert/adapter_model.py
@@ -72,6 +72,7 @@ def forward(
             return_dict=return_dict,
             output_adapter_gating_scores=output_adapter_gating_scores,
             output_adapter_fusion_attentions=output_adapter_fusion_attentions,
+            adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False),
         )
         # BERT & RoBERTa return the pooled output as second item, we don't need that in these heads
         if not return_dict:
@@ -94,6 +95,24 @@ def forward(
         # in case no head is used just return the output of the base model (including pooler output)
         return outputs

+    # Copied from BertLMHeadModel
+    def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
+        input_shape = input_ids.shape
+        # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
+        if attention_mask is None:
+            attention_mask = input_ids.new_ones(input_shape)
+
+        # cut decoder_input_ids if past is used
+        if past is not None:
+            input_ids = input_ids[:, -1:]
+
+        return {
+            "input_ids": input_ids,
+            "attention_mask": attention_mask,
+            "past_key_values": past,
+            "adapter_input_parallelized": model_kwargs.pop("adapter_input_parallelized", False),
+        }
+
     head_types = {
         "classification": ClassificationHead,
         "multilabel_classification": MultiLabelClassificationHead,
diff --git a/src/transformers/adapters/models/deberta/adapter_model.py b/src/transformers/adapters/models/deberta/adapter_model.py
index 49e791be3..a05bf9d37 100644
--- a/src/transformers/adapters/models/deberta/adapter_model.py
+++ b/src/transformers/adapters/models/deberta/adapter_model.py
@@ -65,6 +65,7 @@ def forward(
             return_dict=return_dict,
             output_adapter_gating_scores=output_adapter_gating_scores,
             output_adapter_fusion_attentions=output_adapter_fusion_attentions,
+            adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False),
         )
         # BERT & RoBERTa return the pooled output as second item, we don't need that in these heads
         if not return_dict:
diff --git a/src/transformers/adapters/models/debertaV2/adapter_model.py b/src/transformers/adapters/models/debertaV2/adapter_model.py
index 6c19c44a6..2c5395c74 100644
--- a/src/transformers/adapters/models/debertaV2/adapter_model.py
+++ b/src/transformers/adapters/models/debertaV2/adapter_model.py
@@ -68,6 +68,7 @@ def forward(
             return_dict=return_dict,
             output_adapter_gating_scores=output_adapter_gating_scores,
             output_adapter_fusion_attentions=output_adapter_fusion_attentions,
+            adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False),
         )
         # BERT & RoBERTa return the pooled output as second item, we don't need that in these heads
         if not return_dict:
diff --git a/src/transformers/adapters/models/distilbert/adapter_model.py b/src/transformers/adapters/models/distilbert/adapter_model.py
index 4f8c9fa7b..8bfcd9ff5 100644
--- a/src/transformers/adapters/models/distilbert/adapter_model.py
+++ b/src/transformers/adapters/models/distilbert/adapter_model.py
@@ -94,6 +94,7 @@ def forward(
             return_dict=return_dict,
             output_adapter_gating_scores=output_adapter_gating_scores,
             output_adapter_fusion_attentions=output_adapter_fusion_attentions,
+            adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False),
         )

         outputs = self.forward_head(
@@ -102,6 +103,24 @@ def forward(

         return outputs

+    # Copied from RobertaForCausalLM
+    def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
+        input_shape = input_ids.shape
+        # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
+        if attention_mask is None:
+            attention_mask = input_ids.new_ones(input_shape)
+
+        # cut decoder_input_ids if past is used
+        if past is not None:
+            input_ids = input_ids[:, -1:]
+
+        return {
+            "input_ids": input_ids,
+            "attention_mask": attention_mask,
+            "past_key_values": past,
+            "adapter_input_parallelized": model_kwargs.pop("adapter_input_parallelized", False),
+        }
+
     head_types = {
         "classification": ClassificationHead,
         "multilabel_classification": MultiLabelClassificationHead,
diff --git a/src/transformers/adapters/models/gpt2/adapter_model.py b/src/transformers/adapters/models/gpt2/adapter_model.py
index 7cd9680a3..f75afd7c1 100644
--- a/src/transformers/adapters/models/gpt2/adapter_model.py
+++ b/src/transformers/adapters/models/gpt2/adapter_model.py
@@ -81,6 +81,7 @@ def forward(
             return_dict=return_dict,
             output_adapter_gating_scores=output_adapter_gating_scores,
             output_adapter_fusion_attentions=output_adapter_fusion_attentions,
+            adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False),
         )

         batch_size = outputs[0].shape[0]
@@ -139,6 +140,7 @@ def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
             "position_ids": position_ids,
             "attention_mask": attention_mask,
             "token_type_ids": token_type_ids,
+            "adapter_input_parallelized": kwargs.pop("adapter_input_parallelized", False),
         }

     head_types = {
diff --git a/src/transformers/adapters/models/gptj/adapter_model.py b/src/transformers/adapters/models/gptj/adapter_model.py
index f281e95f6..7d37f94bd 100644
--- a/src/transformers/adapters/models/gptj/adapter_model.py
+++ b/src/transformers/adapters/models/gptj/adapter_model.py
@@ -77,6 +77,7 @@ def forward(
             return_dict=return_dict,
             output_adapter_gating_scores=output_adapter_gating_scores,
             output_adapter_fusion_attentions=output_adapter_fusion_attentions,
+            adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False),
         )

         batch_size = outputs[0].shape[0]
@@ -135,6 +136,7 @@ def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
             "position_ids": position_ids,
             "attention_mask": attention_mask,
             "token_type_ids": token_type_ids,
+            "adapter_input_parallelized": kwargs.pop("adapter_input_parallelized", False),
         }

     head_types = {
diff --git a/src/transformers/adapters/models/mbart/adapter_model.py b/src/transformers/adapters/models/mbart/adapter_model.py
index ca1a2f3fb..a9b05182d 100644
--- a/src/transformers/adapters/models/mbart/adapter_model.py
+++ b/src/transformers/adapters/models/mbart/adapter_model.py
@@ -89,6 +89,7 @@ def forward(
             past_key_values=past_key_values,
             output_adapter_gating_scores=output_adapter_gating_scores,
             output_adapter_fusion_attentions=output_adapter_fusion_attentions,
+            adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False),
         )
         # sequence classification based on last token in sequence
         x = outputs[0]  # last hidden state
@@ -139,6 +140,7 @@ def prepare_inputs_for_generation(
             "decoder_head_mask": decoder_head_mask,
             "cross_attn_head_mask": cross_attn_head_mask,
             "use_cache": use_cache,  # change this to avoid caching (presumably for debugging)
+            "adapter_input_parallelized": kwargs.pop("adapter_input_parallelized", False),
kwargs.pop("adapter_input_parallelized", False), } # Copied from MBartForConditionalGeneration diff --git a/src/transformers/adapters/models/roberta/adapter_model.py b/src/transformers/adapters/models/roberta/adapter_model.py index 70ef64378..a21699d07 100644 --- a/src/transformers/adapters/models/roberta/adapter_model.py +++ b/src/transformers/adapters/models/roberta/adapter_model.py @@ -77,6 +77,7 @@ def forward( return_dict=return_dict, output_adapter_gating_scores=output_adapter_gating_scores, output_adapter_fusion_attentions=output_adapter_fusion_attentions, + adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False), ) # BERT & RoBERTa return the pooled output as second item, we don't need that in these heads if not return_dict: @@ -99,6 +100,24 @@ def forward( # in case no head is used just return the output of the base model (including pooler output) return outputs + # Copied from RobertaForCausalLM + def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs): + input_shape = input_ids.shape + # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly + if attention_mask is None: + attention_mask = input_ids.new_ones(input_shape) + + # cut decoder_input_ids if past is used + if past is not None: + input_ids = input_ids[:, -1:] + + return { + "input_ids": input_ids, + "attention_mask": attention_mask, + "past_key_values": past, + "adapter_input_parallelized": model_kwargs.pop("adapter_input_parallelized", False), + } + head_types = { "classification": ClassificationHead, "multilabel_classification": MultiLabelClassificationHead, diff --git a/src/transformers/adapters/models/t5/adapter_model.py b/src/transformers/adapters/models/t5/adapter_model.py index c47eb36ad..87492c81e 100644 --- a/src/transformers/adapters/models/t5/adapter_model.py +++ b/src/transformers/adapters/models/t5/adapter_model.py @@ -80,6 +80,7 @@ def forward( return_dict=return_dict, output_adapter_gating_scores=output_adapter_gating_scores, output_adapter_fusion_attentions=output_adapter_fusion_attentions, + adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False), ) sequence_output = model_output[0] # ToDo move head to device for parallel forward pass @@ -118,7 +119,6 @@ def prepare_inputs_for_generation( encoder_outputs=None, **kwargs ): - # cut decoder_input_ids if past is used if past is not None: input_ids = input_ids[:, -1:] @@ -132,6 +132,7 @@ def prepare_inputs_for_generation( "decoder_head_mask": decoder_head_mask, "cross_attn_head_mask": cross_attn_head_mask, "use_cache": use_cache, + "adapter_input_parallelized": kwargs.pop("adapter_input_parallelized", False), } # Copied from T5ForConditionalGeneration diff --git a/src/transformers/adapters/models/vit/adapter_model.py b/src/transformers/adapters/models/vit/adapter_model.py index 6a5692dc1..6b85df39e 100644 --- a/src/transformers/adapters/models/vit/adapter_model.py +++ b/src/transformers/adapters/models/vit/adapter_model.py @@ -48,6 +48,7 @@ def forward( return_dict=return_dict, output_adapter_gating_scores=output_adapter_gating_scores, output_adapter_fusion_attentions=output_adapter_fusion_attentions, + adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False), ) # BERT & RoBERTa return the pooled output as second item, we don't need that in these heads diff --git a/src/transformers/dependency_versions_table.py b/src/transformers/dependency_versions_table.py index f9d8f8808..635c27c8f 100644 --- 
+++ b/src/transformers/dependency_versions_table.py
@@ -22,7 +22,7 @@
     "fugashi": "fugashi>=1.0",
     "GitPython": "GitPython<3.1.19",
     "hf-doc-builder": "hf-doc-builder>=0.3.0",
-    "huggingface-hub": "huggingface-hub>=0.1.0,<0.8.0",
+    "huggingface-hub": "huggingface-hub>=0.1.0,<1.0",
     "importlib_metadata": "importlib_metadata",
     "ipadic": "ipadic>=1.0.0,<2.0",
     "isort": "isort>=5.5.4",
diff --git a/src/transformers/generation_utils.py b/src/transformers/generation_utils.py
index c3c0d63b7..11b130556 100644
--- a/src/transformers/generation_utils.py
+++ b/src/transformers/generation_utils.py
@@ -23,6 +23,7 @@
 import torch.distributed as dist
 from torch import nn

+from .adapters.composition import adjust_tensors_for_parallel
 from .adapters.context import ForwardContext
 from .generation_beam_constraints import Constraint, DisjunctiveConstraint, PhrasalConstraint
 from .generation_beam_search import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer
@@ -1198,6 +1199,17 @@ def generate(
             # if decoder-only then inputs_tensor has to be `input_ids`
             input_ids = inputs_tensor

+        # Pre-replicate inputs for parallel adapters to avoid issues within generation code
+        if (
+            hasattr(self.config, "adapters")
+            and self.config.adapters.active_setup
+            and self.config.adapters.active_setup.parallel_channels > 1
+        ):
+            input_ids = input_ids.repeat(self.config.adapters.active_setup.parallel_channels, 1)
+            model_kwargs["adapter_input_parallelized"] = True
+            (attention_mask,) = adjust_tensors_for_parallel(input_ids, model_kwargs["attention_mask"])
+            model_kwargs["attention_mask"] = attention_mask
+
         # 5. Prepare `max_length` depending on other stopping criteria.
         input_ids_seq_length = input_ids.shape[-1]
         if max_length is None and max_new_tokens is None:
diff --git a/tests_adapters/test_adapter_composition.py b/tests_adapters/test_adapter_composition.py
index a5f0f0a14..ea772bd75 100644
--- a/tests_adapters/test_adapter_composition.py
+++ b/tests_adapters/test_adapter_composition.py
@@ -15,6 +15,7 @@
     Trainer,
     TrainingArguments,
 )
+from transformers.adapters import ADAPTER_MODEL_MAPPING
 from transformers.adapters.composition import BatchSplit, Fuse, Parallel, Split, Stack, parse_composition
 from transformers.testing_utils import require_torch, torch_device

@@ -245,6 +246,33 @@ def test_batch_split_with_heads(self):
             )
         )

+    def test_parallel_generate(self):
+        if self.config_class not in ADAPTER_MODEL_MAPPING or (
+            not hasattr(ADAPTER_MODEL_MAPPING[self.config_class], "add_seq2seq_lm_head")
+            and not hasattr(ADAPTER_MODEL_MAPPING[self.config_class], "add_causal_lm_head")
+        ):
+            self.skipTest("No seq2seq or causal language model head")
+
+        model1 = AutoAdapterModel.from_config(self.config())
+        model1.add_adapter("adapter1")
+        model1.add_adapter("adapter2")
+        if hasattr(model1, "add_seq2seq_lm_head"):
+            model1.add_seq2seq_lm_head("adapter1")
+            model1.add_seq2seq_lm_head("adapter2")
+        else:
+            model1.add_causal_lm_head("adapter1")
+            model1.add_causal_lm_head("adapter2")
+        model1.set_active_adapters(Parallel("adapter1", "adapter2"))
+        model1.to(torch_device)
+
+        seq_output_length = 32
+
+        # Finally, also check if generation works properly
+        input_ids = self.get_input_samples((1, 4), config=model1.config)["input_ids"]
+        input_ids = input_ids.to(torch_device)
+        generated = model1.generate(input_ids, max_length=seq_output_length)
+        self.assertLessEqual(generated.shape, (2, seq_output_length))
+

 class ParallelTrainingMixin:
     def create_twin_adapters(self, model, name):
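
Usage sketch (not part of the patch): the snippet below mirrors the new test_parallel_generate test above and shows how the parallel generation path added in this diff would be exercised from user code. It assumes the AutoAdapterModel API used in the test; the checkpoint name, adapter names, and prompt are illustrative.

from transformers import AutoAdapterModel, AutoTokenizer
from transformers.adapters.composition import Parallel

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoAdapterModel.from_pretrained("gpt2")

# Two adapters, each with its own causal LM head (names are illustrative).
model.add_adapter("adapter1")
model.add_adapter("adapter2")
model.add_causal_lm_head("adapter1")
model.add_causal_lm_head("adapter2")
model.set_active_adapters(Parallel("adapter1", "adapter2"))

# A single input sequence: with this patch, generate() replicates it across the
# parallel channels up front and passes adapter_input_parallelized=True through
# the forward context, so the model does not re-replicate the inputs later.
input_ids = tokenizer("Adapters are", return_tensors="pt").input_ids
generated = model.generate(input_ids, max_length=16)

# Expected: one generated sequence per parallel channel, i.e. a batch of 2.
print(generated.shape)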