diff --git a/docs/classes/models/mistral.rst b/docs/classes/models/mistral.rst
new file mode 100644
index 0000000000..5aee051ff7
--- /dev/null
+++ b/docs/classes/models/mistral.rst
@@ -0,0 +1,31 @@
+Mistral
+-----------------------------------------------------------------------------------------------------------------------
+
+The Mistral model was proposed in `Mistral 7B <https://arxiv.org/abs/2310.06825>`__ by
+Albert Q. Jiang, Alexandre Sablayrolles, Arthur Mensch, Chris Bamford, Devendra Singh Chaplot, Diego de las Casas,
+Florian Bressand, Gianna Lengyel, Guillaume Lample, Lucile Saulnier, Lélio Renard Lavaud, Marie-Anne Lachaux,
+Pierre Stock, Teven Le Scao, Thibaut Lavril, Thomas Wang, Timothée Lacroix, William El Sayed.
+It is a foundation language model with 7.3B parameters.
+
+The abstract from the paper is the following:
+
+*We introduce Mistral 7B, a 7-billion-parameter language model engineered for
+superior performance and efficiency. Mistral 7B outperforms the best open 13B
+model (Llama 2) across all evaluated benchmarks, and the best released 34B
+model (Llama 1) in reasoning, mathematics, and code generation. Our model
+leverages grouped-query attention (GQA) for faster inference, coupled with sliding
+window attention (SWA) to effectively handle sequences of arbitrary length with a
+reduced inference cost. We also provide a model fine-tuned to follow instructions,
+Mistral 7B - Instruct, that surpasses Llama 2 13B - chat model both on human and
+automated benchmarks. Our models are released under the Apache 2.0 license.*
+
+Code: https://github.com/mistralai/mistral-src
+Webpage: https://mistral.ai/news/announcing-mistral-7b/
+
+
+MistralAdapterModel
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: adapters.MistralAdapterModel
+    :members:
+    :inherited-members: MistralPreTrainedModel
diff --git a/docs/contributing/adding_adapters_to_a_model.md b/docs/contributing/adding_adapters_to_a_model.md
index 0c9164ff22..fbae1d3512 100644
--- a/docs/contributing/adding_adapters_to_a_model.md
+++ b/docs/contributing/adding_adapters_to_a_model.md
@@ -47,7 +47,7 @@ Now that we have discussed the purpose of every file in `src/adapters/models/<model_name>`,
   `<model_name>AdapterModel` to the `ADAPTER_MODEL_MAPPING_NAMES` mapping in `src/adapters/models/auto/adapter_model.py` and to `src/adapters/__init__.py`.
     - Define the classes to be added to Python's import structure in `src/adapters/models/<model_name>/__init__.py`. This will likely only be the `<model_name>AdapterModel`.
 6. **Adapt the config classes:**
-    - Adapt the config class to the requirements of adapters in `src/transformers/adapters/wrappers/configuration.py`.
+    - Adapt the config class to the requirements of adapters in `src/adapters/wrappers/configuration.py`.
    - There are some naming differences in the config attributes of different model architectures. The adapter implementation requires some additional attributes with a specific name to be available. These currently are `num_attention_heads`, `hidden_size`, `hidden_dropout_prob` and `attention_probs_dropout_prob` as in the `BertConfig` class. If your model config does not provide these, add corresponding mappings to `CONFIG_CLASS_KEYS_MAPPING`.
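A minimal usage sketch of the `MistralAdapterModel` class documented in the new `mistral.rst` page above (not part of the diff): the checkpoint name is the one used in the tests added by this PR, while the adapter name and the `seq_bn` bottleneck config shorthand are illustrative assumptions.

```python
# Hypothetical usage sketch for MistralAdapterModel (not part of the diff).
from adapters import MistralAdapterModel

model = MistralAdapterModel.from_pretrained("mistralai/Mistral-7B-v0.1")

# Add a bottleneck adapter plus a causal LM head, then activate them for training.
model.add_adapter("my_adapter", config="seq_bn")  # config shorthand assumed
model.add_causal_lm_head("my_adapter")
model.train_adapter("my_adapter")
```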
diff --git a/docs/model_overview.md b/docs/model_overview.md index a5ba7c4e8c..ca4688c93e 100644 --- a/docs/model_overview.md +++ b/docs/model_overview.md @@ -10,28 +10,29 @@ The table below further shows which model architectures support which adaptation E.g., for BERT, this means adapters provides a ``BertAdapterModel`` class, but you can also use ``BertModel``, ``BertForSequenceClassification`` etc. together with adapters. ``` -| Model | (Bottleneck)
Adapters | Prefix
Tuning | LoRA | Compacter | Adapter
Fusion | Invertible
Adapters | Parallel
block | Prompt
Tuning | -| --------------------------------------- | -| - | - | - | - | - | - |- | -| [ALBERT](classes/models/albert.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | -| [BART](classes/models/bart.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | -| [BEIT](classes/models/beit.html) | ✅ | ✅ | ✅ | ✅ | ✅ | | | ✅ | +| Model | (Bottleneck)
Adapters | Prefix
Tuning | LoRA | Compacter | Adapter
Fusion | Invertible
Adapters | Parallel
block | Prompt
Tuning | +|--------------------------------------------------------| -| - | - | - | - | - | - |- | +| [ALBERT](classes/models/albert.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | +| [BART](classes/models/bart.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | +| [BEIT](classes/models/beit.html) | ✅ | ✅ | ✅ | ✅ | ✅ | | | ✅ | | [BERT-Generation](classes/models/bert-generation.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | -| [BERT](classes/models/bert.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | -| [CLIP](classes/models/clip.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | | -| [DeBERTa](classes/models/deberta.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | -| [DeBERTa-v2](classes/models/debertaV2.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | -| [DistilBERT](classes/models/distilbert.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | -| [Electra](classes/models/electra.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | -| [Encoder Decoder](classes/models/encoderdecoder.html) | (*) | (*) | (*) | (*) | (*) | (*) | | | -| [GPT-2](classes/models/gpt2.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | -| [GPT-J](classes/models/gptj.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | -| [Llama](classes/models/llama.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | -| [MBart](classes/models/mbart.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | -| [RoBERTa](classes/models/roberta.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | -| [T5](classes/models/t5.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | -| [ViT](classes/models/vit.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | -| [XLM-RoBERTa](classes/models/xlmroberta.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | -| [X-MOD](classes/models/xmod.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | +| [BERT](classes/models/bert.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | +| [CLIP](classes/models/clip.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | | +| [DeBERTa](classes/models/deberta.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | +| [DeBERTa-v2](classes/models/debertaV2.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | +| [DistilBERT](classes/models/distilbert.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | +| [Electra](classes/models/electra.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | +| [Encoder Decoder](classes/models/encoderdecoder.html) | (*) | (*) | (*) | (*) | (*) | (*) | | | +| [GPT-2](classes/models/gpt2.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | +| [GPT-J](classes/models/gptj.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | +| [Llama](classes/models/llama.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | +| [MBart](classes/models/mbart.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | +| [Mistral](classes/models/mistral.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | +| [RoBERTa](classes/models/roberta.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | +| [T5](classes/models/t5.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | +| [ViT](classes/models/vit.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | +| [XLM-RoBERTa](classes/models/xlmroberta.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | +| [X-MOD](classes/models/xmod.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | (*) If the used encoder and decoder model class are supported. 
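The updated overview table marks Mistral as supporting bottleneck adapters, prefix tuning, LoRA, Compacter, AdapterFusion, invertible adapters, and parallel blocks (but not prompt tuning). As the page notes, the plain Hugging Face model classes can also be used together with adapters by wrapping an existing instance; a hedged sketch of that path for Mistral, with an illustrative model name and LoRA hyperparameters:

```python
# Sketch: enabling adapter support on a plain Hugging Face Mistral model
# via adapters.init() and adding a LoRA adapter (hyperparameters illustrative).
import adapters
from adapters import LoRAConfig
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1")
adapters.init(model)  # adds adapter support to the existing model instance in-place

model.add_adapter("lora_adapter", config=LoRAConfig(r=8, alpha=16))
model.set_active_adapters("lora_adapter")
```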
diff --git a/src/adapters/__init__.py b/src/adapters/__init__.py index 942617e1d3..9687acba02 100644 --- a/src/adapters/__init__.py +++ b/src/adapters/__init__.py @@ -105,6 +105,7 @@ "models.gptj": ["GPTJAdapterModel"], "models.llama": ["LlamaAdapterModel"], "models.mbart": ["MBartAdapterModel"], + "models.mistral": ["MistralAdapterModel"], "models.roberta": ["RobertaAdapterModel"], "models.t5": ["T5AdapterModel"], "models.vit": ["ViTAdapterModel"], @@ -207,6 +208,7 @@ from .models.gptj import GPTJAdapterModel from .models.llama import LlamaAdapterModel from .models.mbart import MBartAdapterModel + from .models.mistral import MistralAdapterModel from .models.roberta import RobertaAdapterModel from .models.t5 import T5AdapterModel from .models.vit import ViTAdapterModel diff --git a/src/adapters/composition.py b/src/adapters/composition.py index 937fae2685..92bc720f5c 100644 --- a/src/adapters/composition.py +++ b/src/adapters/composition.py @@ -134,6 +134,7 @@ def __init__( "xlm-roberta", "bert-generation", "llama", + "mistral", "electra", "xmod", ], diff --git a/src/adapters/head_utils.py b/src/adapters/head_utils.py index 7673857adc..ac58949e92 100644 --- a/src/adapters/head_utils.py +++ b/src/adapters/head_utils.py @@ -612,6 +612,23 @@ }, "layers": ["lm_head"], }, + # Mistral + "MistralForSequenceClassification": { + "config": { + "head_type": "classification", + "layers": 1, + "dropout_prob": 0, + "activation_function": None, + "bias": False, + }, + "layers": [None, "score"], + }, + "MistralForCausalLM": { + "config": { + "head_type": "causal_lm", + }, + "layers": ["lm_head"], + }, "ElectraForTokenClassification": { "config": { "head_type": "tagging", diff --git a/src/adapters/models/__init__.py b/src/adapters/models/__init__.py index dd48552d23..e1314c8bf4 100644 --- a/src/adapters/models/__init__.py +++ b/src/adapters/models/__init__.py @@ -18,6 +18,7 @@ from .gpt2.mixin_gpt2 import GPT2ModelAdapterMixin from .gptj.mixin_gptj import GPTJMLPAdaptersMixin, GPTJModelAdapterMixin from .llama.mixin_llama import LlamaModelAdapterMixin +from .mistral.mixin_mistral import MistralModelAdapterMixin from .t5.mixin_t5 import ( T5BlockAdaptersMixin, T5ForCondiditionalGenerationWithHeadsMixin, @@ -78,4 +79,5 @@ "BertGenerationEncoder": BertModelAdaptersMixin, "BertGenerationLayer": BertLayerAdaptersMixin, "LlamaModel": LlamaModelAdapterMixin, + "MistralModel": MistralModelAdapterMixin, } diff --git a/src/adapters/models/auto/adapter_model.py b/src/adapters/models/auto/adapter_model.py index 5ff84de483..a168ee3177 100644 --- a/src/adapters/models/auto/adapter_model.py +++ b/src/adapters/models/auto/adapter_model.py @@ -23,6 +23,7 @@ ("gptj", "GPTJAdapterModel"), ("llama", "LlamaAdapterModel"), ("mbart", "MBartAdapterModel"), + ("mistral", "MistralAdapterModel"), ("roberta", "RobertaAdapterModel"), ("t5", "T5AdapterModel"), ("vit", "ViTAdapterModel"), diff --git a/src/adapters/models/mistral/__init__.py b/src/adapters/models/mistral/__init__.py new file mode 100644 index 0000000000..912e5efe5d --- /dev/null +++ b/src/adapters/models/mistral/__init__.py @@ -0,0 +1,39 @@ +# flake8: noqa +# There's no way to ignore "F401 '...' imported but unused" warnings in this +# module, but to preserve other warnings. So, don't check this module at all. + +# Copyright 2020 The Adapter-Hub Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import TYPE_CHECKING
+
+from transformers.utils import _LazyModule
+
+
+_import_structure = {
+    "adapter_model": ["MistralAdapterModel"],
+}
+
+
+if TYPE_CHECKING:
+    from .adapter_model import MistralAdapterModel
+
+else:
+    import sys
+
+    sys.modules[__name__] = _LazyModule(
+        __name__,
+        globals()["__file__"],
+        _import_structure,
+    )
diff --git a/src/adapters/models/mistral/adapter_model.py b/src/adapters/models/mistral/adapter_model.py
new file mode 100644
index 0000000000..3b4ab30c54
--- /dev/null
+++ b/src/adapters/models/mistral/adapter_model.py
@@ -0,0 +1,214 @@
+import logging
+
+import torch
+
+from transformers.models.mistral.modeling_mistral import MISTRAL_START_DOCSTRING, MistralModel, MistralPreTrainedModel
+from transformers.utils import add_start_docstrings
+
+from ...composition import adjust_tensors_for_parallel
+from ...heads import (
+    CausalLMHead,
+    ClassificationHead,
+    ModelWithFlexibleHeadsAdaptersMixin,
+    MultiLabelClassificationHead,
+    TaggingHead,
+)
+from ...model_mixin import EmbeddingAdaptersWrapperMixin
+from ...wrappers import init
+
+
+logger = logging.getLogger(__name__)
+
+
+@add_start_docstrings(
+    """
+The Mistral Model that allows the loading of different heads for different tasks. This enables a flexible use of the
+models and adapters. Since this class does classification on the last token, it requires knowing the position of the
+last token. If a :obj:`pad_token_id` is defined in the configuration, it finds the last token that is not a padding
+token in each row. If no :obj:`pad_token_id` is defined, it simply takes the last value in each row of the batch. Since
+it cannot guess the padding tokens when :obj:`inputs_embeds` are passed instead of :obj:`input_ids`, it does the same
+(take the last value in each row of the batch).
+""", + MISTRAL_START_DOCSTRING, +) +class MistralAdapterModel(EmbeddingAdaptersWrapperMixin, ModelWithFlexibleHeadsAdaptersMixin, MistralPreTrainedModel): + _tied_weights_keys = [] # needs to be empty since Mistral does not yet support prompt tuning + + def __init__(self, config): + super().__init__(config) + self.model = MistralModel(config) + init(self.model) + + self._init_head_modules() + + self.init_weights() + + # Model parallel + self.model_parallel = False + self.device_map = None + self.post_init() + + def forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + past_key_values=None, + inputs_embeds=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + head=None, + output_adapter_gating_scores=False, + output_adapter_fusion_attentions=False, + **kwargs + ): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs, context = self.model( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + return_dict=return_dict, + output_hidden_states=output_hidden_states, + output_adapter_gating_scores=output_adapter_gating_scores, + output_adapter_fusion_attentions=output_adapter_fusion_attentions, + adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False), + output_context=True, + ) + # required e.g. for prompt tuning in all models + kwargs["context"] = context + + batch_size = outputs[0].shape[0] + + if self.config.pad_token_id is None: + # TODO-AH: this may result in unexpected behavior for classification. Find a better way to do this? + sequence_lengths = -1 + else: + if input_ids is not None: + sequence_lengths = torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1 + (sequence_lengths,) = adjust_tensors_for_parallel(outputs[0], sequence_lengths) + else: + sequence_lengths = -1 + logger.warning( + f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be " + "unexpected if using padding tokens in conjunction with `inputs_embeds.`" + ) + + cls_logits = outputs[0][range(batch_size), sequence_lengths] + + outputs = self.forward_head( + outputs, + head_name=head, + cls_output=cls_logits, + attention_mask=attention_mask, + return_dict=return_dict, + **kwargs, + ) + + return outputs + + head_types = { + "causal_lm": CausalLMHead, + "tagging": TaggingHead, + "classification": ClassificationHead, + } + + def add_causal_lm_head(self, head_name, overwrite_ok=False): + """ + Adds a causal language modeling head on top of the model. + + Args: + head_name (str): The name of the head. + overwrite_ok (bool, optional): Force overwrite if a head with the same name exists. Defaults to False. + """ + head = CausalLMHead(self, head_name) + self.add_prediction_head(head, overwrite_ok=overwrite_ok) + + def add_classification_head( + self, + head_name, + num_labels=2, + layers=2, + activation_function="tanh", + overwrite_ok=False, + multilabel=False, + id2label=None, + use_pooler=False, + dropout_prob=0, + ): + """ + Adds a sequence classification head on top of the model. + + Args: + head_name (str): The name of the head. 
+ num_labels (int, optional): Number of classification labels. Defaults to 2. + layers (int, optional): Number of layers. Defaults to 2. + activation_function (str, optional): Activation function. Defaults to 'tanh'. + overwrite_ok (bool, optional): Force overwrite if a head with the same name exists. Defaults to False. + multilabel (bool, optional): Enable multilabel classification setup. Defaults to False. + """ + + if multilabel: + head = MultiLabelClassificationHead( + self, + head_name, + num_labels, + layers, + activation_function, + id2label, + use_pooler, + dropout_prob=dropout_prob, + ) + else: + head = ClassificationHead( + self, + head_name, + num_labels, + layers, + activation_function, + id2label, + use_pooler, + dropout_prob=dropout_prob, + ) + self.add_prediction_head(head, overwrite_ok) + + def prepare_inputs_for_generation( + self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs + ): + if past_key_values: + input_ids = input_ids[:, -1:] + + position_ids = kwargs.get("position_ids", None) + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -1].unsqueeze(-1) + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update( + { + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "attention_mask": attention_mask, + "adapter_input_parallelized": kwargs.pop("adapter_input_parallelized", False), + } + ) + return model_inputs diff --git a/src/adapters/models/mistral/mixin_mistral.py b/src/adapters/models/mistral/mixin_mistral.py new file mode 100644 index 0000000000..9acd17995b --- /dev/null +++ b/src/adapters/models/mistral/mixin_mistral.py @@ -0,0 +1,46 @@ +from typing import Iterable, Tuple + +import torch.nn as nn + +from ...methods.bottleneck import BottleneckLayer +from ...methods.lora import LoRALinear +from ...methods.prefix_tuning import PrefixTuningLayer +from ...model_mixin import EmbeddingAdaptersMixin, InvertibleAdaptersMixin, ModelBaseAdaptersMixin + + +class MistralAttentionMixin: + def init_adapters(self, model_config, adapters_config): + self.q_proj = LoRALinear.wrap(self.q_proj, "selfattn", model_config, adapters_config, attn_key="q") + self.k_proj = LoRALinear.wrap(self.k_proj, "selfattn", model_config, adapters_config, attn_key="k") + self.v_proj = LoRALinear.wrap(self.v_proj, "selfattn", model_config, adapters_config, attn_key="v") + + self.prefix_tuning = PrefixTuningLayer("self_prefix", model_config, adapters_config) + + +class MistralDecoderLayerMixin: + def init_adapters(self, model_config, adapters_config): + # Wrap layers for LoRA + self.mlp.down_proj = LoRALinear.wrap(self.mlp.down_proj, "intermediate", model_config, adapters_config) + self.mlp.up_proj = LoRALinear.wrap(self.mlp.up_proj, "output", model_config, adapters_config) + + self.attention_adapters = BottleneckLayer("mh_adapter") + self.output_adapters = BottleneckLayer("output_adapter") + + +class MistralModelAdapterMixin(EmbeddingAdaptersMixin, InvertibleAdaptersMixin, ModelBaseAdaptersMixin): + support_prompt_tuning = False + + def init_adapters(self, model_config, 
adapters_config): + super().init_adapters(model_config, adapters_config) + + # Register hook for post embedding forward + self.embed_tokens.register_forward_hook(self.post_embedding_forward) + + def iter_layers(self) -> Iterable[Tuple[int, nn.Module]]: + for i, layer in enumerate(self.layers): + yield i, layer + + def post_embedding_forward(self, module, args, embedding_output): + embedding_output = self.invertible_adapters_forward(embedding_output) + # Prompt tuning not yet supported + return embedding_output diff --git a/src/adapters/models/mistral/modeling_mistral.py b/src/adapters/models/mistral/modeling_mistral.py new file mode 100644 index 0000000000..1ab5c269e4 --- /dev/null +++ b/src/adapters/models/mistral/modeling_mistral.py @@ -0,0 +1,179 @@ +# coding=utf-8 +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch Mistral model.""" +import math +from typing import Optional, Tuple + +import torch +import torch.utils.checkpoint +from torch import nn + +from adapters.composition import ( + adjust_tensors_for_parallel, + adjust_tensors_for_parallel_, + match_attn_matrices_for_parallel, +) +from transformers.models.mistral.modeling_mistral import apply_rotary_pos_emb +from transformers.utils import logging + +from .mixin_mistral import MistralAttentionMixin, MistralDecoderLayerMixin + + +logger = logging.get_logger(__name__) + + +class MistralAttentionWithAdapters(nn.Module, MistralAttentionMixin): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + + query_states, key_states, value_states = match_attn_matrices_for_parallel( + query_states, key_states, value_states + ) + (attention_mask,) = adjust_tensors_for_parallel(query_states, attention_mask) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb(query_states, 
key_states, cos, sin, position_ids) + # [bsz, nh, t, hd] + + if past_key_value is not None: + # reuse k, v, self_attention + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + + past_key_value = (key_states, value_states) if use_cache else None + + key_states, value_states, attention_mask = self.prefix_tuning( + key_states, value_states, hidden_states, attention_mask + ) + (query_states,) = adjust_tensors_for_parallel(key_states, query_states) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + # Make adjustments since (parallel) prefix tuning changes the attention mask + kv_seq_len = key_states.shape[-2] + bsz = attention_mask.shape[0] + + if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): + raise ValueError( + f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights + attention_mask + attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)) + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class MistralDecoderLayerWithAdapters(nn.Module, MistralDecoderLayerMixin): + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). 
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + """ + + adjust_tensors_for_parallel_(hidden_states, attention_mask, position_ids) + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + hidden_states = self.attention_adapters(hidden_states, residual, None) + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = self.output_adapters(hidden_states, residual, None) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs diff --git a/tests/models/test_mistral.py b/tests/models/test_mistral.py new file mode 100644 index 0000000000..be66648c37 --- /dev/null +++ b/tests/models/test_mistral.py @@ -0,0 +1,12 @@ +# flake8: noqa: F403,F405 +from adapters import MistralAdapterModel +from hf_transformers.tests.models.mistral.test_modeling_mistral import * +from transformers.testing_utils import require_torch + +from .base import AdapterModelTesterMixin + + +@require_torch +class MistralAdapterModelTest(AdapterModelTesterMixin, MistralModelTest): + all_model_classes = (MistralAdapterModel,) + fx_compatible = False diff --git a/tests/test_mistral.py b/tests/test_mistral.py new file mode 100644 index 0000000000..8c79c755d0 --- /dev/null +++ b/tests/test_mistral.py @@ -0,0 +1,64 @@ +import unittest + +from transformers.models.mistral.configuration_mistral import MistralConfig +from transformers.testing_utils import require_torch + +from .composition.test_parallel import ParallelAdapterInferenceTestMixin, ParallelTrainingMixin +from .methods import ( + BottleneckAdapterTestMixin, + CompacterTestMixin, + IA3TestMixin, + LoRATestMixin, + PrefixTuningTestMixin, + UniPELTTestMixin, +) +from .test_adapter import AdapterTestBase, make_config +from .test_adapter_backward_compability import CompabilityTestMixin +from .test_adapter_conversion import ModelClassConversionTestMixin +from .test_adapter_embeddings import EmbeddingTestMixin +from .test_adapter_fusion_common import AdapterFusionModelTestMixin +from .test_adapter_heads import PredictionHeadModelTestMixin + + +class MistralAdapterTestBase(AdapterTestBase): + config_class = MistralConfig + config = make_config( + MistralConfig, + hidden_size=32, + num_hidden_layers=5, + num_attention_heads=8, + intermediate_size=37, + hidden_act="gelu", + hidden_dropout_prob=0.1, + pad_token_id=0, + ) + tokenizer_name = "mistralai/Mistral-7B-v0.1" + + +@require_torch +class MistralAdapterTest( + BottleneckAdapterTestMixin, + CompacterTestMixin, + IA3TestMixin, + LoRATestMixin, + PrefixTuningTestMixin, + UniPELTTestMixin, + EmbeddingTestMixin, + AdapterFusionModelTestMixin, + CompabilityTestMixin, + PredictionHeadModelTestMixin, + ParallelAdapterInferenceTestMixin, + ParallelTrainingMixin, + MistralAdapterTestBase, + unittest.TestCase, +): + pass + + +@require_torch +class BertClassConversionTest( + ModelClassConversionTestMixin, + MistralAdapterTestBase, + unittest.TestCase, +): + pass
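Beyond the diff itself, a standalone toy trace of the last-token selection that `MistralAdapterModel.forward` uses to feed classification heads may help; the tensors below are invented for illustration and are not library code.

```python
# Toy illustration of the sequence_lengths / cls_output selection used in
# MistralAdapterModel.forward (values are made up for demonstration).
import torch

pad_token_id = 0
input_ids = torch.tensor([[5, 7, 9, 0, 0],
                          [3, 4, 0, 0, 0]])

# Index of the last non-padding token per row: number of non-pad tokens minus one.
sequence_lengths = torch.ne(input_ids, pad_token_id).sum(-1) - 1  # tensor([2, 1])

hidden_states = torch.randn(2, 5, 8)  # (batch, seq_len, hidden)
cls_output = hidden_states[range(input_ids.shape[0]), sequence_lengths]  # (batch, hidden)
```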
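Likewise, the on-the-fly `position_ids` construction in `prepare_inputs_for_generation` can be traced on a small left-padded batch; the values are invented for illustration.

```python
# Toy trace of the position_ids construction used during generation.
import torch

attention_mask = torch.tensor([[0, 0, 1, 1, 1],   # left-padded sequence
                               [1, 1, 1, 1, 1]])

position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
# position_ids -> tensor([[1, 1, 0, 1, 2],
#                         [0, 1, 2, 3, 4]])
```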
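Finally, since the paper abstract highlights grouped-query attention, the sketch below reproduces the `repeat_kv` expansion that the upstream `transformers` Mistral attention applies to its `num_key_value_heads` key/value heads before the attention product. The helper mirrors the upstream implementation; the example shapes are assumptions taken from the published Mistral-7B configuration.

```python
# Illustration of grouped-query attention head expansion (repeat_kv) as used by
# the upstream Mistral attention; shapes are example values, not library state.
import torch


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """(batch, num_kv_heads, seq_len, head_dim) -> (batch, num_kv_heads * n_rep, seq_len, head_dim)."""
    batch, num_kv_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_kv_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_kv_heads * n_rep, slen, head_dim)


num_heads, num_kv_heads, head_dim = 32, 8, 128  # Mistral-7B values
key_states = torch.randn(2, num_kv_heads, 10, head_dim)
key_states = repeat_kv(key_states, num_heads // num_kv_heads)
assert key_states.shape == (2, num_heads, 10, head_dim)
```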