diff --git a/README.md b/README.md index b96f307547..92b8d1622b 100644 --- a/README.md +++ b/README.md @@ -36,7 +36,7 @@ Thus, most files in this repository are direct copies from the HuggingFace Trans ## Installation -`adapter-transformers` currently supports **Python 3.6+** and **PyTorch 1.3.1+**. +`adapter-transformers` currently supports **Python 3.7+** and **PyTorch 1.3.1+**. After [installing PyTorch](https://pytorch.org/get-started/locally/), you can install `adapter-transformers` from PyPI ... ``` @@ -74,10 +74,11 @@ Currently, adapter-transformers integrates all architectures and methods listed | AdapterDrop | [Rücklé et al. (2021)](https://arxiv.org/pdf/2010.11918.pdf) | [Notebook](https://colab.research.google.com/github/Adapter-Hub/adapter-transformers/blob/master/notebooks/05_Adapter_Drop_Training.ipynb) | | MAD-X 2.0,
Embedding training | [Pfeiffer et al. (2021)](https://arxiv.org/pdf/2012.15562.pdf) | [Docs: Embeddings](https://docs.adapterhub.ml/embeddings.html), [Notebook](https://colab.research.google.com/github/Adapter-Hub/adapter-transformers/blob/master/notebooks/08_NER_Wikiann.ipynb) | | Prefix Tuning | [Li and Liang (2021)](https://arxiv.org/pdf/2101.00190.pdf) | [Docs](https://docs.adapterhub.ml/overview.html#prefix-tuning) | -| Parallel adapters,
Mix-and-Match adapters | [He et al. (2021)](https://arxiv.org/pdf/2110.04366.pdf) | [Docs](https://docs.adapterhub.ml/overview.html#combinations-mix-and-match-adapters) | +| Parallel adapters,
Mix-and-Match adapters | [He et al. (2021)](https://arxiv.org/pdf/2110.04366.pdf) | [Docs](https://docs.adapterhub.ml/overview.html#mix-and-match-adapters) |
 | Compacter | [Mahabadi et al. (2021)](https://arxiv.org/pdf/2106.04647.pdf) | [Docs](https://docs.adapterhub.ml/overview.html#compacter) |
 | LoRA | [Hu et al. (2021)](https://arxiv.org/pdf/2106.09685.pdf) | [Docs](https://docs.adapterhub.ml/overview.html#lora) |
-| (IA)^3 | [Liu et al. (2022)](https://arxiv.org/pdf/2205.05638.pdf) | [Docs](https://docs.adapterhub.ml/overview.html#ia3) |
+| (IA)^3 | [Liu et al. (2022)](https://arxiv.org/pdf/2205.05638.pdf) | [Docs](https://docs.adapterhub.ml/overview.html#ia-3) |
+| UniPELT | [Mao et al. (2022)](https://arxiv.org/pdf/2110.07577.pdf) | [Docs](https://docs.adapterhub.ml/overview.html#unipelt) |
 
 ## Supported Models
 
diff --git a/adapter_docs/adapter_composition.md b/adapter_docs/adapter_composition.md
index 2bd472658c..906c504436 100644
--- a/adapter_docs/adapter_composition.md
+++ b/adapter_docs/adapter_composition.md
@@ -109,6 +109,32 @@ To learn how training an _AdapterFusion_ layer works, check out [this Colab note
 In v1.x of `adapter-transformers`, fusing adapters was done using a nested list of adapter names, i.e. the example from above would be defined as `[["d", "e", "f"]]`.
 For backwards compatibility, you can still do this, although it is recommended to use the new syntax.
 
+#### Retrieving AdapterFusion attentions
+
+Finally, it is possible to retrieve the attention scores computed by each fusion layer in a forward pass of the model.
+These scores can be used for analyzing the fused adapter blocks and can serve as the basis for visualizations similar to those in the AdapterFusion paper.
+You can collect the fusion attention scores by passing `output_adapter_fusion_attentions=True` to the model forward call.
+The scores for each layer will then be saved in the `adapter_fusion_attentions` attribute of the output:
+
+```python
+outputs = model(**inputs, output_adapter_fusion_attentions=True)
+attention_scores = outputs.adapter_fusion_attentions
+```
+Note that this parameter is only available to base model classes and [AdapterModel classes](prediction_heads.md#adaptermodel-classes).
+In the example, `attention_scores` holds a dictionary of the following form:
+```
+{
+    '<fusion_name>': {
+        <layer_id>: {
+            '<module_location>': np.array([...]),
+            ...
+        },
+        ...
+    },
+    ...
+}
+```
+
 ## `Split`
 
 ```{eval-rst}
diff --git a/adapter_docs/classes/adapter_config.rst b/adapter_docs/classes/adapter_config.rst
index 75da8973c5..e08be42807 100644
--- a/adapter_docs/classes/adapter_config.rst
+++ b/adapter_docs/classes/adapter_config.rst
@@ -65,6 +65,9 @@ Combined configurations
 .. autoclass:: transformers.MAMConfig
     :members:
 
+.. 
autoclass:: transformers.UniPELTConfig + :members: + Adapter Fusion ~~~~~~~~~~~~~~~ diff --git a/adapter_docs/conf.py b/adapter_docs/conf.py index 3f1a200632..a24f18b66d 100644 --- a/adapter_docs/conf.py +++ b/adapter_docs/conf.py @@ -26,7 +26,7 @@ docs_versions = [ "adapters1.1.1", "adapters2.3.0", - "adapters3.0.1", + "adapters3.1.0", ] diff --git a/adapter_docs/img/compacter.png b/adapter_docs/img/compacter.png new file mode 100644 index 0000000000..a7c2786e96 Binary files /dev/null and b/adapter_docs/img/compacter.png differ diff --git a/adapter_docs/img/ia3.png b/adapter_docs/img/ia3.png new file mode 100644 index 0000000000..f335d132be Binary files /dev/null and b/adapter_docs/img/ia3.png differ diff --git a/adapter_docs/img/lora.png b/adapter_docs/img/lora.png new file mode 100644 index 0000000000..310420fb10 Binary files /dev/null and b/adapter_docs/img/lora.png differ diff --git a/adapter_docs/img/prefix.png b/adapter_docs/img/prefix.png new file mode 100644 index 0000000000..59d2909933 Binary files /dev/null and b/adapter_docs/img/prefix.png differ diff --git a/adapter_docs/img/unipelt.png b/adapter_docs/img/unipelt.png new file mode 100644 index 0000000000..110a19ead9 Binary files /dev/null and b/adapter_docs/img/unipelt.png differ diff --git a/adapter_docs/installation.md b/adapter_docs/installation.md index 5364bcee95..ac6c72dfa0 100644 --- a/adapter_docs/installation.md +++ b/adapter_docs/installation.md @@ -1,7 +1,7 @@ # Installation Our *adapter-transformers* package is a drop-in replacement for Huggingface's *transformers* library. -It currently supports Python 3.6+ and PyTorch 1.3.1+. You will have to [install PyTorch](https://pytorch.org/get-started/locally/) first. +It currently supports Python 3.7+ and PyTorch 1.3.1+. You will have to [install PyTorch](https://pytorch.org/get-started/locally/) first. ```{eval-rst} .. important:: diff --git a/adapter_docs/overview.md b/adapter_docs/overview.md index ae0e3e04b1..8ba89a8fc0 100644 --- a/adapter_docs/overview.md +++ b/adapter_docs/overview.md @@ -130,6 +130,15 @@ _Papers:_ _Configuration class_: [`PrefixTuningConfig`](transformers.PrefixTuningConfig) +```{eval-rst} +.. figure:: img/prefix.png + :height: 300 + :align: center + :alt: Illustration of Prefix Tuning. + + Illustration of the Prefix Tuning method within one Transformer layer. Trained components are colored in shades of magenta. +``` + Prefix Tuning ([Li and Liang, 2021](https://aclanthology.org/2021.acl-long.353.pdf)) introduces new parameters in the multi-head attention blocks in each Transformer layer. More, specifically, it prepends trainable prefix vectors $P^K$ and $P^V$ to the keys and values of the attention head input, each of a configurable prefix length $l$ (`prefix_length` attribute): @@ -162,6 +171,15 @@ _Papers:_ _Configuration class_: [`CompacterConfig`](transformers.CompacterConfig), [`CompacterPlusPlusConfig`](transformers.CompacterPlusPlusConfig) +```{eval-rst} +.. figure:: img/compacter.png + :height: 300 + :align: center + :alt: Illustration of Compacter. + + Illustration of the Compacter method within one Transformer layer. Trained components are colored in shades of magenta. +``` + The Compacter architecture proposed by [Mahabadi et al., 2021](https://arxiv.org/pdf/2106.04647.pdf) is similar to the bottleneck adapter architecture. It only exchanges the linear down- and up-projection with a PHM layer. Unlike the linear layer, the PHM layer constructs its weight matrix from two smaller matrices, which reduces the number of parameters. 
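Purely as an illustration of that idea, the following is a minimal, stand-alone sketch (not the library's actual PHM layer implementation; the division count `n=4` and the 768x48 bottleneck shapes are assumptions for the example) of how a PHM weight can be assembled from smaller matrices via Kronecker products:

```python
import torch

def phm_weight(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    # A: (n, n, n)              -- n small "rule" matrices (shared across layers in Compacter)
    # B: (n, d_in//n, d_out//n) -- n per-layer matrices
    # The full (d_in x d_out) weight is the sum of the n Kronecker products.
    return torch.stack([torch.kron(A[i], B[i]) for i in range(A.shape[0])]).sum(dim=0)

n, d_in, d_out = 4, 768, 48  # illustrative sizes for a down-projection
A = torch.randn(n, n, n)
B = torch.randn(n, d_in // n, d_out // n)
W = phm_weight(A, B)
print(W.shape)  # torch.Size([768, 48])
```

Because the factors are much smaller than the full matrix, the parameter count drops from `d_in * d_out` to roughly `n^3 + d_in * d_out / n`.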
@@ -187,6 +205,15 @@ _Papers:_ _Configuration class_: [`LoRAConfig`](transformers.LoRAConfig) +```{eval-rst} +.. figure:: img/lora.png + :height: 300 + :align: center + :alt: Illustration of LoRA. + + Illustration of the LoRA method within one Transformer layer. Trained components are colored in shades of magenta. +``` + Low-Rank Adaptation (LoRA) is an efficient fine-tuning technique proposed by [Hu et al. (2021)](https://arxiv.org/pdf/2106.09685.pdf). LoRA injects trainable low-rank decomposition matrices into the layers of a pre-trained model. For any model layer expressed as a matrix multiplication of the form $h = W_0 x$, it therefore performs a reparameterization, such that: @@ -229,6 +256,15 @@ _Papers:_ _Configuration class_: [`IA3Config`](transformers.IA3Config) +```{eval-rst} +.. figure:: img/ia3.png + :height: 300 + :align: center + :alt: Illustration of (IA)^3. + + Illustration of the (IA)^3 method within one Transformer layer. Trained components are colored in shades of magenta. +``` + _Infused Adapter by Inhibiting and Amplifying Inner Activations ((IA)^3)_ is an efficient fine-tuning method proposed within the _T-Few_ fine-tuning approach by [Liu et al. (2022)](https://arxiv.org/pdf/2205.05638.pdf). (IA)^3 introduces trainable vectors $l_W$ into different components of a Transformer model which perform element-wise rescaling of inner model activations. For any model layer expressed as a matrix multiplication of the form $h = W x$, it therefore performs an element-wise multiplication with $l_W$, such that: @@ -271,7 +307,7 @@ model.reset_adapter("ia3_adapter") _Papers:_ - [Few-Shot Parameter-Efficient Fine-Tuning is Better and Cheaper than In-Context Learning](https://arxiv.org/pdf/2205.05638.pdf) (Liu et al., 2022) -## Combinations - Mix-and-Match Adapters +## Method Combinations _Configuration class_: [`ConfigUnion`](transformers.ConfigUnion) @@ -290,6 +326,10 @@ config = ConfigUnion( model.add_adapter("union_adapter", config=config) ``` +### Mix-and-Match Adapters + +_Configuration class_: [`MAMConfig`](transformers.MAMConfig) + [He et al. (2021)](https://arxiv.org/pdf/2110.04366.pdf) study various variants and combinations of efficient fine-tuning methods. Among others, they propose _Mix-and-Match Adapters_ as a combination of Prefix Tuning and parallel bottleneck adapters. This configuration is supported by adapter-transformers out-of-the-box: @@ -315,3 +355,79 @@ model.add_adapter("mam_adapter", config=config) _Papers:_ - [Towards a Unified View of Parameter-Efficient Transfer Learning](https://arxiv.org/pdf/2110.04366.pdf) (He et al., 2021) + +### UniPELT + +_Configuration class_: [`UniPELTConfig`](transformers.UniPELTConfig) + +```{eval-rst} +.. figure:: img/unipelt.png + :height: 300 + :align: center + :alt: Illustration of UniPELT. + + Illustration of the UniPELT method within one Transformer layer. Trained components are colored in shades of magenta. +``` + +An approach similar to the work of [He et al. (2021)](https://arxiv.org/pdf/2110.04366.pdf) is taken by [Mao et al. (2022)](https://arxiv.org/pdf/2110.07577.pdf) in their _UniPELT_ framework. +They, too, combine multiple efficient fine-tuning methods, namely LoRA, Prefix Tuning and bottleneck adapters, in a single unified setup. +_UniPELT_ additionally introduces a gating mechanism that controls the activation of the different submodules. 
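As a rough, self-contained sketch of such a gating module (mirroring the `use_gating` logic added to `modeling.py` in this diff, i.e. a sigmoid over a linear projection of the layer input, averaged over the sequence; the tensor sizes below are illustrative assumptions, and the precise UniPELT formulation follows next):

```python
import torch
from torch import nn

class AdapterGate(nn.Module):
    # Computes one scalar gate per sample from the Transformer layer input
    # and uses it to scale an adapter module's output.
    def __init__(self, hidden_size: int):
        super().__init__()
        self.gate = nn.Linear(hidden_size, 1)

    def forward(self, layer_input: torch.Tensor, adapter_output: torch.Tensor) -> torch.Tensor:
        gate = torch.sigmoid(self.gate(layer_input))  # (batch, seq_len, 1)
        gate = gate.mean(dim=1, keepdim=True)         # average over the sequence
        return adapter_output * gate                  # scale the adapter's contribution

# illustrative usage with made-up dimensions
x = torch.randn(2, 16, 768)            # layer input
adapter_out = torch.randn(2, 16, 768)  # output of some adapter module
gated = AdapterGate(768)(x, adapter_out)
```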
+
+Concretely, for each adapted module $m$, UniPELT adds a trainable gating value $\mathcal{G}_m \in (0, 1)$ that is computed via a feed-forward network ($W_{\mathcal{G}_m}$) and sigmoid activation ($\sigma$) from the Transformer layer input states ($x$):
+
+$$\mathcal{G}_m \leftarrow \sigma(W_{\mathcal{G}_m} \cdot x)$$
+
+These gating values are then used to scale the output activations of the injected adapter modules, e.g. for a LoRA layer:
+
+$$
+h \leftarrow W_0 x + \mathcal{G}_{LoRA} B A x
+$$
+
+In the configuration classes of `adapter-transformers`, these gating mechanisms can be activated via `use_gating=True`.
+The full UniPELT setup can be instantiated using `UniPELTConfig`[^unipelt]:
+
+[^unipelt]: Note that the implementation of UniPELT in `adapter-transformers` follows the implementation in the original code, which is slightly different from the description in the paper. See [here](https://github.com/morningmoni/UniPELT/issues/1) for more.
+
+```python
+from transformers.adapters import UniPELTConfig
+
+config = UniPELTConfig()
+model.add_adapter("unipelt", config=config)
+```
+
+which is identical to the following `ConfigUnion`:
+
+```python
+from transformers.adapters import ConfigUnion, LoRAConfig, PrefixTuningConfig, PfeifferConfig
+
+config = ConfigUnion(
+    LoRAConfig(r=8, use_gating=True),
+    PrefixTuningConfig(prefix_length=10, use_gating=True),
+    PfeifferConfig(reduction_factor=16, use_gating=True),
+)
+model.add_adapter("unipelt", config=config)
+```
+
+Finally, as the gating values for each adapter module might provide interesting insights for analysis, `adapter-transformers` comes with an integrated mechanism for returning all gating values computed during a model forward pass via the `output_adapter_gating_scores` parameter:
+
+```python
+outputs = model(**inputs, output_adapter_gating_scores=True)
+gating_scores = outputs.adapter_gating_scores
+```
+Note that this parameter is only available to base model classes and [AdapterModel classes](prediction_heads.md#adaptermodel-classes).
+In the example, `gating_scores` holds a dictionary of the following form:
+```
+{
+    '<adapter_name>': {
+        <layer_id>: {
+            '<module_location>': np.array([...]),
+            ...
+        },
+        ...
+    },
+    ...
+}
+```
+
+_Papers:_
+- [UNIPELT: A Unified Framework for Parameter-Efficient Language Model Tuning](https://arxiv.org/pdf/2110.07577.pdf) (Mao et al., 2022)
diff --git a/setup.py b/setup.py
index 5f46998a2c..7ba7d41b29 100644
--- a/setup.py
+++ b/setup.py
@@ -116,7 +116,7 @@
     "fugashi>=1.0",
     "GitPython<3.1.19",
     "hf-doc-builder>=0.3.0",
-    "huggingface-hub>=0.1.0,<0.8.0",
+    "huggingface-hub>=0.1.0,<1.0",
     "importlib_metadata",
     "ipadic>=1.0.0,<2.0",
     "isort>=5.5.4",
@@ -417,7 +417,7 @@ def run(self):
 
 setup(
     name="adapter-transformers",
-    version="3.1.0a1",
+    version="3.1.0",
     author="Jonas Pfeiffer, Andreas Rücklé, Clifton Poth, Hannah Sterz, based on work by the HuggingFace team and community",
     author_email="pfeiffer@ukp.tu-darmstadt.de",
     description="A friendly fork of HuggingFace's Transformers, adding Adapters to PyTorch language models",
diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py
index d03da16858..f28c722ba5 100755
--- a/src/transformers/__init__.py
+++ b/src/transformers/__init__.py
@@ -22,7 +22,7 @@
 # to defer the actual importing for when the objects are requested. This way `import transformers` provides the names
 # in the namespace without actually importing anything (and especially none of the backends).
 
-__version__ = "4.21.2" +__version__ = "4.21.3" from typing import TYPE_CHECKING @@ -2063,6 +2063,7 @@ "StaticAdapterFusionConfig", "T5AdapterModel", "T5ModelWithHeads", + "UniPELTConfig", "ViTAdapterModel", "XLMRobertaAdapterModel", "XLMRobertaModelWithHeads", @@ -4602,6 +4603,7 @@ StaticAdapterFusionConfig, T5AdapterModel, T5ModelWithHeads, + UniPELTConfig, ViTAdapterModel, XLMRobertaAdapterModel, XLMRobertaModelWithHeads, diff --git a/src/transformers/adapters/__init__.py b/src/transformers/adapters/__init__.py index a4327fc981..9df65b7ff3 100644 --- a/src/transformers/adapters/__init__.py +++ b/src/transformers/adapters/__init__.py @@ -16,7 +16,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "3.1.0a1" +__version__ = "3.1.0" from typing import TYPE_CHECKING @@ -57,6 +57,7 @@ "PfeifferInvConfig", "PrefixTuningConfig", "StaticAdapterFusionConfig", + "UniPELTConfig", ], "context": [ "AdapterSetup", @@ -176,6 +177,7 @@ PfeifferInvConfig, PrefixTuningConfig, StaticAdapterFusionConfig, + UniPELTConfig, ) from .context import AdapterSetup, ForwardContext from .heads import ( diff --git a/src/transformers/adapters/configuration.py b/src/transformers/adapters/configuration.py index 16013d9350..d1179d28e9 100644 --- a/src/transformers/adapters/configuration.py +++ b/src/transformers/adapters/configuration.py @@ -162,6 +162,9 @@ class AdapterConfig(AdapterConfigBase): Scaling factor to use for scaled addition of adapter outputs as done by He et al. (2021). Can bei either a constant factor (float) or the string "learned", in which case the scaling factor is learned. Defaults to 1.0. + use_gating (:obj:`bool`, optional): + Place a trainable gating module besides the added parameter module to control module activation. This is + e.g. used for UniPELT. Defaults to False. residual_before_ln (:obj:`bool`, optional): If True, take the residual connection around the adapter bottleneck before the layer normalization. Only applicable if :obj:`original_ln_before` is True. @@ -224,6 +227,7 @@ class AdapterConfig(AdapterConfigBase): init_weights: str = "bert" is_parallel: bool = False scaling: Union[float, str] = 1.0 + use_gating: bool = False residual_before_ln: bool = True adapter_residual_before_ln: bool = False inv_adapter: Optional[str] = None @@ -362,6 +366,12 @@ class PrefixTuningConfig(AdapterConfigBase): non_linearity (str): If flat=False, the non-linearity used in the bottleneck MLP. dropout (float): The dropout rate used in the prefix tuning layer. leave_out (List[int]): The IDs of the layers (starting at 0) where NO prefix should be added. + use_gating (:obj:`bool`, optional): + Place a trainable gating module besides the added parameter module to control module activation. This is + e.g. used for UniPELT. Defaults to False. + shared_gating (: + obj:`bool`, optional): Whether to use a shared gate for the prefixes of all attention matrices. Only + applicable if `use_gating=True`. Defaults to True. """ architecture: Optional[str] = "prefix_tuning" @@ -375,6 +385,8 @@ class PrefixTuningConfig(AdapterConfigBase): bottleneck_size: int = 512 non_linearity: str = "tanh" dropout: float = 0.0 + use_gating: bool = False + shared_gating: bool = True @dataclass(eq=False) @@ -402,6 +414,10 @@ class LoRAConfig(AdapterConfigBase): (IA)^3). "scale" can only be used together with r=1. Defaults to "add". init_weights (:obj:`str`, optional): Initialization method for the weights of the LoRA modules. 
Currently, this can be either "lora" (default) or "bert". + use_gating (:obj:`bool`, optional): + Place a trainable gating module besides the added parameter module to control module activation. This is + e.g. used for UniPELT. Defaults to False. Note that modules with use_gating=True cannot be merged using + `merge_adapter()`. """ architecture: Optional[str] = "lora" @@ -416,6 +432,7 @@ class LoRAConfig(AdapterConfigBase): attn_matrices: List[str] = field(default_factory=lambda: ["q", "v"]) composition_mode: str = "add" init_weights: str = "lora" + use_gating: bool = False @dataclass(eq=False) @@ -427,8 +444,8 @@ class IA3Config(LoRAConfig): """ selfattn_lora: bool = True - intermediate_lora: bool = False - output_lora: bool = True + intermediate_lora: bool = True + output_lora: bool = False r: int = 1 alpha: int = 1 @@ -436,6 +453,7 @@ class IA3Config(LoRAConfig): attn_matrices: List[str] = field(default_factory=lambda: ["k", "v"]) composition_mode: str = "scale" init_weights: str = "ia3" + use_gating: bool = False class ConfigUnion(AdapterConfigBase): @@ -548,6 +566,26 @@ def adapter(self): return self[1] +class UniPELTConfig(ConfigUnion): + """ + The UniPELT adapter architecture proposed by Mao et al. (2022). See https://arxiv.org/pdf/2110.07577.pdf. + """ + + def __init__( + self, + prefix_tuning: Optional[PrefixTuningConfig] = None, + adapter: Optional[AdapterConfig] = None, + lora: Optional[LoRAConfig] = None, + ): + components = [ + prefix_tuning or PrefixTuningConfig(prefix_length=10), + adapter or PfeifferConfig(reduction_factor=16), + lora or LoRAConfig(r=8), + ] + + super().__init__(*[c.replace(use_gating=True) for c in components]) + + ADAPTER_CONFIG_MAP = { "pfeiffer": PfeifferConfig(), "houlsby": HoulsbyConfig(), @@ -562,6 +600,7 @@ def adapter(self): "lora": LoRAConfig(), "ia3": IA3Config(), "mam": MAMConfig(), + "unipelt": UniPELTConfig(), } DEFAULT_ADAPTER_CONFIG = "pfeiffer" diff --git a/src/transformers/adapters/context.py b/src/transformers/adapters/context.py index 6503487ccd..05261351ac 100644 --- a/src/transformers/adapters/context.py +++ b/src/transformers/adapters/context.py @@ -78,6 +78,8 @@ class ForwardContext: # thread-local storage that holds a stack of active contexts storage = threading.local() + context_attributes = ["adapter_gating_scores", "adapter_fusion_attentions"] + def __init__(self, model, *args, **kwargs): # If the model has a method ``forward_context()``, use it to create the context. 
if hasattr(model, "forward_context"): @@ -99,8 +101,21 @@ def wrap(cls, f): @functools.wraps(f) def wrapper_func(self, *args, **kwargs): if self.config.adapters is not None: - with cls(self, *args, **kwargs): + with cls(self, *args, **kwargs) as ctx: + kwargs = { + k: v for k, v in kwargs.items() if k.replace("output_", "") not in cls.context_attributes + } results = f(self, *args, **kwargs) + + # append output attributes + if isinstance(results, tuple): + for attr in cls.context_attributes: + if getattr(ctx, "output_" + attr, False): + results = results + (dict(getattr(ctx, attr)),) + else: + for attr in cls.context_attributes: + if getattr(ctx, "output_" + attr, False): + results[attr] = dict(getattr(ctx, attr)) return results else: return f(self, *args, **kwargs) diff --git a/src/transformers/adapters/heads/base.py b/src/transformers/adapters/heads/base.py index 25df781f91..dcea9bc984 100644 --- a/src/transformers/adapters/heads/base.py +++ b/src/transformers/adapters/heads/base.py @@ -18,7 +18,7 @@ ) from ...utils import ModelOutput from ..composition import AdapterCompositionBlock, BatchSplit, Parallel, parse_heads_from_composition -from ..context import AdapterSetup +from ..context import AdapterSetup, ForwardContext from ..model_mixin import ModelWithHeadsAdaptersMixin from ..modeling import Activation_Function_Class @@ -790,7 +790,7 @@ def _get_head_input(outputs, cls_out, batch): if all("loss" in out and out["loss"] is not None for out in head_outputs) else None ) - return MultiHeadOutput(head_outputs=head_outputs, loss=combined_loss) + return_output = MultiHeadOutput(head_outputs=head_outputs, loss=combined_loss) elif self.has_parallel_adapters or isinstance(self.active_head, Parallel): if len(self.active_head) != self.config.adapters.active_setup.parallel_channels: raise ValueError("The number of parallel adapters and the number of active heads must match.") @@ -807,16 +807,22 @@ def _get_head_input(outputs, cls_out, batch): if all("loss" in out and out["loss"] is not None for out in head_outputs) else None ) - return MultiHeadOutput(head_outputs=head_outputs, loss=combined_loss) + return_output = MultiHeadOutput(head_outputs=head_outputs, loss=combined_loss) elif len(used_heads) > 1: head_outputs = [] for head in used_heads: head_module = self.heads[head] head_outputs.append(head_module(all_outputs, cls_output, attention_mask, return_dict, **kwargs)) - return head_outputs + return_output = MultiHeadOutput(head_outputs=head_outputs) else: head_module = self.heads[used_heads[0]] - return head_module(all_outputs, cls_output, attention_mask, return_dict, **kwargs) + return_output = head_module(all_outputs, cls_output, attention_mask, return_dict, **kwargs) + + if isinstance(return_output, ModelOutput): + for attr in ForwardContext.context_attributes: + if attr not in return_output and attr in all_outputs: + return_output[attr] = all_outputs[attr] + return return_output def get_labels_dict(self, head_name=None): """ diff --git a/src/transformers/adapters/layer.py b/src/transformers/adapters/layer.py index 4499f71e43..63a3c81013 100644 --- a/src/transformers/adapters/layer.py +++ b/src/transformers/adapters/layer.py @@ -1,6 +1,7 @@ from abc import ABC, abstractmethod from typing import List, Mapping, Union +import numpy as np import torch from torch import nn @@ -43,6 +44,31 @@ def get_active_setup(self, module_dict): else: return None + def _store_gating_score(self, adapter_name, gating_score): + context = ForwardContext.get_context() + if context.output_adapter_gating_scores: + 
gating_cache = context.adapter_gating_scores + if self.layer_idx not in gating_cache[adapter_name]: + gating_cache[adapter_name][self.layer_idx] = {} + gating_score = gating_score.detach().squeeze().cpu().numpy() + if len(gating_score.shape) == 0: + gating_score = np.expand_dims(gating_score, axis=0) + cache_score = gating_cache[adapter_name][self.layer_idx].get(self.location_key, None) + if cache_score is not None: + gating_cache[adapter_name][self.layer_idx][self.location_key] = np.column_stack( + (cache_score, gating_score) + ) + else: + gating_cache[adapter_name][self.layer_idx][self.location_key] = gating_score + + def _store_fusion_attentions(self, fusion_name, attentions): + context = ForwardContext.get_context() + if context.output_adapter_fusion_attentions: + attention_cache = context.adapter_fusion_attentions + if self.layer_idx not in attention_cache[fusion_name]: + attention_cache[fusion_name][self.layer_idx] = {} + attention_cache[fusion_name][self.layer_idx][self.location_key] = attentions + @abstractmethod def add_adapter(self, adapter_name: str, layer_idx: int): raise NotImplementedError() @@ -202,7 +228,12 @@ def adapter_stack(self, adapter_setup: Stack, hidden_states, input_tensor, layer elif adapter_stack_layer in self.adapters: adapter_layer = self.adapters[adapter_stack_layer] hidden_states, _, residual = adapter_layer.pre_forward(hidden_states, input_tensor, layer_norm) - hidden_states, _, up = adapter_layer(hidden_states, residual_input=residual) + context = ForwardContext.get_context() + layer_output = adapter_layer( + hidden_states, residual_input=residual, output_gating=context.output_adapter_gating_scores + ) + hidden_states, up = layer_output[0], layer_output[2] + self._store_gating_score(adapter_stack_layer, layer_output[-1]) # as this stack might be part of a fusion block, return the adapter up-projection output here # together with the final output (with potential residuals & norms) if we reached the last block of the stack if i == len(adapter_setup) - 1: @@ -217,6 +248,8 @@ def adapter_fusion(self, adapter_setup: Fuse, hidden_states, input_tensor, layer """ Performs adapter fusion with the given adapters for the given input. 
""" + context = ForwardContext.get_context() + # config of _last_ fused adapter is significant fusion_config = self.config.adapters.get_fusion(adapter_setup.name) last_adapter = self.adapters[adapter_setup.last()] @@ -235,7 +268,11 @@ def adapter_fusion(self, adapter_setup: Fuse, hidden_states, input_tensor, layer # Case 2: We have a single adapter which is part of this module -> forward pass elif adapter_block in self.adapters: adapter_layer = self.adapters[adapter_block] - _, _, up = adapter_layer(hidden_states, residual_input=residual) + layer_output = adapter_layer( + hidden_states, residual_input=residual, output_gating=context.output_adapter_gating_scores + ) + up = layer_output[2] + self._store_gating_score(adapter_block, layer_output[-1]) up_list.append(up) # Case 3: nesting other composition blocks is invalid elif isinstance(adapter_block, AdapterCompositionBlock): @@ -250,12 +287,18 @@ def adapter_fusion(self, adapter_setup: Fuse, hidden_states, input_tensor, layer up_list = torch.stack(up_list) up_list = up_list.permute(1, 2, 0, 3) - hidden_states = self.adapter_fusion_layer[adapter_setup.name]( + fusion_output = self.adapter_fusion_layer[adapter_setup.name]( query, up_list, up_list, residual, + output_attentions=context.output_adapter_fusion_attentions, ) + if context.output_adapter_fusion_attentions: + hidden_states = fusion_output[0] + self._store_fusion_attentions(adapter_setup.name, fusion_output[-1]) + else: + hidden_states = fusion_output return hidden_states @@ -300,7 +343,14 @@ def adapter_split(self, adapter_setup: Split, hidden_states, input_tensor, layer # Case 4: We have a single adapter which is part of this module -> forward pass elif adapter_block in self.adapters: adapter_layer = self.adapters[adapter_block] - split_hidden_states[i], _, _ = adapter_layer(split_hidden_states[i], residual_input=split_residual[i]) + context = ForwardContext.get_context() + layer_output = adapter_layer( + split_hidden_states[i], + residual_input=split_residual[i], + output_gating=context.output_adapter_gating_scores, + ) + split_hidden_states[i] = layer_output[0] + self._store_gating_score(adapter_block, layer_output[-1]) # Case 5: nesting other composition blocks is invalid elif isinstance(adapter_block, AdapterCompositionBlock): raise ValueError( @@ -365,10 +415,14 @@ def adapter_parallel(self, adapter_setup: Parallel, hidden_states, input_tensor, # Case 3: We have a single adapter which is part of this module -> forward pass elif child in self.adapters: adapter_layer = self.adapters[child] - child_hidden_states, _, _ = adapter_layer( + context = ForwardContext.get_context() + layer_output = adapter_layer( hidden_states[i * orig_batch_size : (i + 1) * orig_batch_size], residual_input=residual[i * orig_batch_size : (i + 1) * orig_batch_size], + output_gating=context.output_adapter_gating_scores, ) + child_hidden_states = layer_output[0] + self._store_gating_score(child, layer_output[-1]) children_hidden.append(child_hidden_states) # Case 4: nesting other composition blocks is invalid elif isinstance(child, AdapterCompositionBlock): @@ -436,10 +490,14 @@ def adapter_batchsplit(self, adapter_setup: BatchSplit, hidden_states, input_ten elif adapter_block in self.adapters: adapter_layer = self.adapters[adapter_block] - child, _, _ = adapter_layer( - hidden_states[batch_idx[0] : batch_idx[1]], residual_input=residual[batch_idx[0] : batch_idx[1]] + context = ForwardContext.get_context() + layer_output = adapter_layer( + hidden_states[batch_idx[0] : batch_idx[1]], + 
residual_input=residual[batch_idx[0] : batch_idx[1]], + output_gating=context.output_adapter_gating_scores, ) - children_hidden.append(child) + children_hidden.append(layer_output[0]) + self._store_gating_score(adapter_block, layer_output[-1]) # Case 5: nesting other composition blocks is invalid elif isinstance(adapter_block, AdapterCompositionBlock): raise ValueError( diff --git a/src/transformers/adapters/loading.py b/src/transformers/adapters/loading.py index a66f98b8a0..0eaf58dae7 100644 --- a/src/transformers/adapters/loading.py +++ b/src/transformers/adapters/loading.py @@ -305,6 +305,7 @@ def filter_func(self, adapter_name): lambda x: "_adapters.{}.".format(adapter_name) in x or ".adapters.{}.".format(adapter_name) in x or ".prefix_tunings.{}.".format(adapter_name) in x + or ".prefix_gates.{}.".format(adapter_name) in x or ".loras.{}.".format(adapter_name) in x ) @@ -348,6 +349,7 @@ def rename_func(self, old_name, new_name): lambda k: self._rename_legacy_weights(k) .replace("adapters.{}.".format(old_name), "adapters.{}.".format(new_name)) .replace(".prefix_tunings.{}.".format(old_name), ".prefix_tunings.{}.".format(new_name)) + .replace(".prefix_gates.{}.".format(old_name), ".prefix_gates.{}.".format(new_name)) .replace(".loras.{}.".format(old_name), ".loras.{}.".format(new_name)) ) diff --git a/src/transformers/adapters/lora.py b/src/transformers/adapters/lora.py index 587a57b332..657cde095b 100644 --- a/src/transformers/adapters/lora.py +++ b/src/transformers/adapters/lora.py @@ -22,12 +22,14 @@ def __init__( lora_A_shape, lora_B_shape, config: LoRAConfig, + gating_heads: int = 1, ): super().__init__() self.r = config.r self.lora_alpha = config.alpha self.composition_mode = config.composition_mode self.attn_matrices = config.attn_matrices + self.use_gating = config.use_gating # Optional dropout if config.dropout > 0.0: self.lora_dropout = nn.Dropout(p=config.dropout) @@ -43,28 +45,39 @@ def __init__( self.lora_B = nn.Parameter(torch.zeros(lora_B_shape)) self.scaling = self.lora_alpha / self.r + if self.use_gating: + self.gate = nn.Linear(lora_A_shape[-1], gating_heads) + if config.init_weights == "lora": # initialize A the same way as the default for nn.Linear and B to zero if self.composition_mode == "add": nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5)) nn.init.zeros_(self.lora_B) + if self.use_gating: + nn.init.normal_(self.gate.weight, std=0.02) elif config.init_weights == "bert": if self.composition_mode == "add": nn.init.normal_(self.lora_A, std=0.02) nn.init.normal_(self.lora_B, std=0.02) + if self.use_gating: + nn.init.normal_(self.gate.weight, std=0.02) elif config.init_weights == "ia3": if self.composition_mode == "add": nn.init.ones_(self.lora_A) nn.init.ones_(self.lora_B) + if self.use_gating: + nn.init.normal_(self.gate.weight, std=0.02) else: raise ValueError("Unknown init_weights type: {}".format(config.init_weights)) - def com(self, weights: torch.Tensor, added: torch.Tensor) -> torch.Tensor: + def com(self, weights: torch.Tensor, added: torch.Tensor, scaling=None) -> torch.Tensor: """Performs the composition operation between existing and injected weights.""" + if scaling is None: + scaling = self.scaling if self.composition_mode == "add": - return weights + added * self.scaling + return weights + added * scaling elif self.composition_mode == "scale": - return weights * (added * self.scaling) + return weights * (added * scaling) else: raise ValueError("Invalid composition mode.") @@ -105,7 +118,11 @@ def add_adapter(self, adapter_name: str, layer_idx: int) 
-> bool: location_key=self.location_key, ) if lora_config is not None and self._check_lora_location(lora_config): - lora = LoRA(*self._get_lora_shapes(lora_config), lora_config) + lora = LoRA( + *self._get_lora_shapes(lora_config), + lora_config, + gating_heads=self.get_n_heads(lora_config), + ) lora.train(self.training) self.loras[adapter_name] = lora return True @@ -176,7 +193,7 @@ def T(w): self.weight.data = lora.com_inv(self.weight.data, delta_w) self.merged = None - def _compute_adapted_weight(self, lora): + def _compute_adapted_weight(self, lora, scaling=None): def T(w): return torch.t(w) if self.fan_in_fan_out else w @@ -187,7 +204,7 @@ def T(w): delta_w = T(lora.lora_B) else: delta_w = T(lora.lora_B @ lora.lora_A) - weight = lora.com(weight, delta_w) + weight = lora.com(weight, delta_w, scaling=scaling) return weight @@ -197,6 +214,8 @@ def merge_adapter(self, name: str): return # already merged elif not self.merged: lora = self.loras[name] + if lora.use_gating: + raise ValueError("Cannot merge LoRA layer with gating.") self.weight.data = self._compute_adapted_weight(lora) self.merged = name elif self.merged != name: @@ -204,7 +223,7 @@ def merge_adapter(self, name: str): def forward(self, x: torch.Tensor): def T(w): - return torch.t(w) if self.fan_in_fan_out else w + return torch.transpose(w, -2, -1) if self.fan_in_fan_out else w if not self.merged: adapter_setup = self.get_active_setup(self.loras) @@ -218,7 +237,13 @@ def T(w): delta_w = lora.lora_B.view(1, 1, -1) else: delta_w = lora.lora_dropout(x) @ torch.t(lora.lora_A) @ torch.t(lora.lora_B) - result = lora.com(result, delta_w) + if lora.use_gating: + gate = torch.sigmoid(lora.gate(x)) + gate = torch.mean(gate, dim=1).unsqueeze(-1) + self._store_gating_score(adapter_setup[0], gate) + else: + gate = None + result = lora.com(result, delta_w, scaling=gate) return result else: raise ValueError(f"Invalid adapter setup. Cannot use {adapter_setup} with LoRA.") @@ -307,7 +332,7 @@ def T(w): self.weight.data = lora.com_inv(self.weight.data, T(self.pad(delta_w, lora))) self.merged = None - def _compute_adapted_weight(self, lora): + def _compute_adapted_weight(self, name, lora): def T(w): return w if self.fan_in_fan_out else torch.t(w) @@ -331,7 +356,9 @@ def merge_adapter(self, name: str): return # already merged elif not self.merged: lora = self.loras[name] - self.weight.data = self._compute_adapted_weight(lora) + if lora.use_gating: + raise ValueError("Cannot merge LoRA layer with gating.") + self.weight.data = self._compute_adapted_weight(name, lora) self.merged = name elif self.merged != name: raise ValueError("LoRALayer already has a merged LoRA module. Please reset it first.") @@ -355,8 +382,17 @@ def T(w): after_A.transpose(-2, -1), lora.lora_B.unsqueeze(-1), groups=sum(lora.enable_lora) ).transpose(-2, -1) delta_w = after_B - # result shape: x x - result = lora.com(result, self.pad(delta_w, lora)) + if lora.use_gating: + gate = torch.sigmoid(lora.gate(x)) + gate = torch.mean(gate, dim=1) + self._store_gating_score(adapter_setup[0], gate) + gate = self.pad( + gate.repeat_interleave(self.out_features // 3, dim=-1), lora, fill_value=1 + ).unsqueeze(1) + else: + gate = None + # result = (batch_size, seq_len, head_dim * 3) + result = lora.com(result, self.pad(delta_w, lora), scaling=gate) return result else: raise ValueError(f"Invalid adapter setup. 
Cannot use {adapter_setup} with LoRA.") diff --git a/src/transformers/adapters/model_mixin.py b/src/transformers/adapters/model_mixin.py index 260b679f8e..95b07ac849 100644 --- a/src/transformers/adapters/model_mixin.py +++ b/src/transformers/adapters/model_mixin.py @@ -785,6 +785,11 @@ def forward_context(self, context: ForwardContext, *args, **kwargs): } context.prefix_states = self.base_model.prefix_tuning(*args, **kwargs) + # Adapter gating and attention outputs + context.output_adapter_gating_scores = kwargs.get("output_adapter_gating_scores", False) + context.output_adapter_fusion_attentions = kwargs.get("output_adapter_fusion_attentions", False) + context.adapter_gating_scores = defaultdict(dict) + context.adapter_fusion_attentions = defaultdict(dict) def get_fusion_regularization_loss(self): reg_loss = 0.0 diff --git a/src/transformers/adapters/modeling.py b/src/transformers/adapters/modeling.py index 11a3bf1aaa..2d6ac3b38d 100644 --- a/src/transformers/adapters/modeling.py +++ b/src/transformers/adapters/modeling.py @@ -46,6 +46,7 @@ def __init__( self.add_layer_norm_before = config["ln_before"] self.add_layer_norm_after = config["ln_after"] self.adapter_residual_before_ln = config["adapter_residual_before_ln"] + self.use_gating = config["use_gating"] # Params related to input & output of adapter self.residual_before_ln = config["residual_before_ln"] @@ -104,16 +105,23 @@ def __init__( if self.add_layer_norm_after: self.adapter_norm_after = nn.LayerNorm(self.input_size) + if self.use_gating: + self.gate = nn.Linear(self.input_size, 1) + # if we want to initialize with the bert strategy then this function is called for all the linear layers if config["init_weights"] == "bert": self.adapter_down.apply(self.init_bert_weights) self.adapter_up.apply(self.init_bert_weights) + if self.use_gating: + self.gate.apply(self.init_bert_weights) elif config["init_weights"] == "mam_adapter": with torch.no_grad(): nn.init.kaiming_uniform_(self.adapter_down[0].weight, a=math.sqrt(5)) nn.init.zeros_(self.adapter_up.weight) nn.init.zeros_(self.adapter_down[0].bias) nn.init.zeros_(self.adapter_up.bias) + if self.use_gating: + self.gate.apply(self.init_bert_weights) else: raise ValueError("Unknown init_weights type: {}".format(config["init_weights"])) @@ -157,13 +165,19 @@ def pre_forward( return hidden_states, query, residual - def forward(self, x, residual_input): # , residual_input=None): + def forward(self, x, residual_input, output_gating=False): down = self.adapter_down(x) up = self.adapter_up(down) up = up * self.scaling output = up + if self.use_gating: + # x.shape = (batch_size, seq_len, hidden_size) + gate = torch.sigmoid(self.gate(x)) + gate = torch.mean(gate, dim=1).unsqueeze(-1) + output = output * gate + # apply residual connection before layer norm if configured in this way if self.adapter_residual_before_ln: output = output + residual_input @@ -176,6 +190,8 @@ def forward(self, x, residual_input): # , residual_input=None): if not self.adapter_residual_before_ln: output = output + residual_input + if self.use_gating and output_gating: + return output, down, up, gate return output, down, up def post_forward(self, hidden_states, input_hidden_states, input_tensor, layer_norm): @@ -246,7 +262,7 @@ def pre_forward( query = input_tensor return input_tensor, query, input_tensor - def forward(self, x, residual_input): + def forward(self, x, residual_input, output_gating=False): down = self.adapter_down(x) up = self.adapter_up(down) @@ -254,10 +270,18 @@ def forward(self, x, residual_input): 
output = up + if self.use_gating: + # x.shape = (batch_size, seq_len, hidden_size) + gate = torch.sigmoid(self.gate(x)) + gate = torch.mean(gate, dim=1).unsqueeze(-1) + output = output * gate + # apply layer norm if available if self.add_layer_norm_after: output = self.adapter_norm_after(output) + if self.use_gating and output_gating: + return output, down, up, gate return output, down, up def post_forward(self, hidden_states, input_hidden_states, input_tensor, layer_norm): @@ -332,7 +356,7 @@ def __init__( self.T = 1.0 self.reduction = self.T / 1000.0 - def forward(self, query, key, value, residual): + def forward(self, query, key, value, residual, output_attentions: bool = False): if self.config["residual_before"]: value += residual[:, :, None, :].repeat(1, 1, value.size(2), 1) @@ -362,9 +386,6 @@ def forward(self, query, key, value, residual): attention_probs = nn.Softmax(dim=-1)(attention_scores / self.T) self.T = max(self.T - self.reduction, 1.0) - if not self.training: - self.recent_attention = attention_probs.detach().cpu().numpy() - context_layer = torch.squeeze(torch.matmul(attention_probs.unsqueeze(2), value_layer), dim=2) if self.config["value"] and not self.config["value_before_softmax"]: @@ -376,7 +397,11 @@ def forward(self, query, key, value, residual): if not self.config["residual_before"]: context_layer += residual - return context_layer + if output_attentions: + attention_probs = attention_probs.detach().cpu().numpy() + return context_layer, attention_probs + else: + return context_layer # Invertible Adapters diff --git a/src/transformers/adapters/models/bart/adapter_model.py b/src/transformers/adapters/models/bart/adapter_model.py index 596adb2bf6..a6ce6c827c 100644 --- a/src/transformers/adapters/models/bart/adapter_model.py +++ b/src/transformers/adapters/models/bart/adapter_model.py @@ -57,6 +57,8 @@ def forward( return_dict=None, past_key_values=None, head=None, + output_adapter_gating_scores=False, + output_adapter_fusion_attentions=False, **kwargs ): r""" @@ -85,6 +87,8 @@ def forward( output_hidden_states=output_hidden_states, return_dict=return_dict, past_key_values=past_key_values, + output_adapter_gating_scores=output_adapter_gating_scores, + output_adapter_fusion_attentions=output_adapter_fusion_attentions, ) # sequence classification based on last token in sequence x = outputs[0] # last hidden state diff --git a/src/transformers/adapters/models/bert/adapter_model.py b/src/transformers/adapters/models/bert/adapter_model.py index 5672019105..0f6d921435 100644 --- a/src/transformers/adapters/models/bert/adapter_model.py +++ b/src/transformers/adapters/models/bert/adapter_model.py @@ -44,6 +44,8 @@ def forward( output_hidden_states=None, return_dict=None, head=None, + output_adapter_gating_scores=False, + output_adapter_fusion_attentions=False, **kwargs ): input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None @@ -68,6 +70,8 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + output_adapter_gating_scores=output_adapter_gating_scores, + output_adapter_fusion_attentions=output_adapter_fusion_attentions, ) # BERT & RoBERTa return the pooled output as second item, we don't need that in these heads if not return_dict: diff --git a/src/transformers/adapters/models/deberta/adapter_model.py b/src/transformers/adapters/models/deberta/adapter_model.py index 310bad4dec..49e791be39 100644 --- a/src/transformers/adapters/models/deberta/adapter_model.py +++ 
b/src/transformers/adapters/models/deberta/adapter_model.py @@ -38,6 +38,8 @@ def forward( output_hidden_states=None, return_dict=None, head=None, + output_adapter_gating_scores=False, + output_adapter_fusion_attentions=False, **kwargs ): input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None @@ -61,6 +63,8 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + output_adapter_gating_scores=output_adapter_gating_scores, + output_adapter_fusion_attentions=output_adapter_fusion_attentions, ) # BERT & RoBERTa return the pooled output as second item, we don't need that in these heads if not return_dict: diff --git a/src/transformers/adapters/models/debertaV2/adapter_model.py b/src/transformers/adapters/models/debertaV2/adapter_model.py index bca062036b..6c19c44a69 100644 --- a/src/transformers/adapters/models/debertaV2/adapter_model.py +++ b/src/transformers/adapters/models/debertaV2/adapter_model.py @@ -41,6 +41,8 @@ def forward( output_hidden_states=None, return_dict=None, head=None, + output_adapter_gating_scores=False, + output_adapter_fusion_attentions=False, **kwargs ): input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None @@ -64,6 +66,8 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + output_adapter_gating_scores=output_adapter_gating_scores, + output_adapter_fusion_attentions=output_adapter_fusion_attentions, ) # BERT & RoBERTa return the pooled output as second item, we don't need that in these heads if not return_dict: diff --git a/src/transformers/adapters/models/distilbert/adapter_model.py b/src/transformers/adapters/models/distilbert/adapter_model.py index 5c47f8ba26..4f8c9fa7be 100644 --- a/src/transformers/adapters/models/distilbert/adapter_model.py +++ b/src/transformers/adapters/models/distilbert/adapter_model.py @@ -70,6 +70,8 @@ def forward( output_hidden_states=None, return_dict=None, head=None, + output_adapter_gating_scores=False, + output_adapter_fusion_attentions=False, **kwargs ): return_dict = return_dict if return_dict is not None else self.config.use_return_dict @@ -90,6 +92,8 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + output_adapter_gating_scores=output_adapter_gating_scores, + output_adapter_fusion_attentions=output_adapter_fusion_attentions, ) outputs = self.forward_head( diff --git a/src/transformers/adapters/models/gpt2/adapter_model.py b/src/transformers/adapters/models/gpt2/adapter_model.py index 55a275d7eb..7cd9680a37 100644 --- a/src/transformers/adapters/models/gpt2/adapter_model.py +++ b/src/transformers/adapters/models/gpt2/adapter_model.py @@ -59,6 +59,8 @@ def forward( output_hidden_states=None, return_dict=None, head=None, + output_adapter_gating_scores=False, + output_adapter_fusion_attentions=False, **kwargs ): return_dict = return_dict if return_dict is not None else self.config.use_return_dict @@ -77,6 +79,8 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + output_adapter_gating_scores=output_adapter_gating_scores, + output_adapter_fusion_attentions=output_adapter_fusion_attentions, ) batch_size = outputs[0].shape[0] diff --git a/src/transformers/adapters/models/mbart/adapter_model.py b/src/transformers/adapters/models/mbart/adapter_model.py index d513a5a209..ca1a2f3fb0 100644 --- 
a/src/transformers/adapters/models/mbart/adapter_model.py +++ b/src/transformers/adapters/models/mbart/adapter_model.py @@ -57,6 +57,8 @@ def forward( return_dict=None, past_key_values=None, head=None, + output_adapter_gating_scores=False, + output_adapter_fusion_attentions=False, **kwargs ): r""" @@ -85,6 +87,8 @@ def forward( output_hidden_states=output_hidden_states, return_dict=return_dict, past_key_values=past_key_values, + output_adapter_gating_scores=output_adapter_gating_scores, + output_adapter_fusion_attentions=output_adapter_fusion_attentions, ) # sequence classification based on last token in sequence x = outputs[0] # last hidden state diff --git a/src/transformers/adapters/models/roberta/adapter_model.py b/src/transformers/adapters/models/roberta/adapter_model.py index b8bfff31df..70ef643786 100644 --- a/src/transformers/adapters/models/roberta/adapter_model.py +++ b/src/transformers/adapters/models/roberta/adapter_model.py @@ -49,6 +49,8 @@ def forward( output_hidden_states=None, return_dict=None, head=None, + output_adapter_gating_scores=False, + output_adapter_fusion_attentions=False, **kwargs ): input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None @@ -73,6 +75,8 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + output_adapter_gating_scores=output_adapter_gating_scores, + output_adapter_fusion_attentions=output_adapter_fusion_attentions, ) # BERT & RoBERTa return the pooled output as second item, we don't need that in these heads if not return_dict: diff --git a/src/transformers/adapters/models/t5/adapter_model.py b/src/transformers/adapters/models/t5/adapter_model.py index 80ef3741b6..c47eb36ad9 100644 --- a/src/transformers/adapters/models/t5/adapter_model.py +++ b/src/transformers/adapters/models/t5/adapter_model.py @@ -53,6 +53,8 @@ def forward( output_hidden_states=None, return_dict=None, head=None, + output_adapter_gating_scores=False, + output_adapter_fusion_attentions=False, **kwargs ): return_dict = return_dict if return_dict is not None else self.config.use_return_dict @@ -76,6 +78,8 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + output_adapter_gating_scores=output_adapter_gating_scores, + output_adapter_fusion_attentions=output_adapter_fusion_attentions, ) sequence_output = model_output[0] # ToDo move head to device for parallel forward pass diff --git a/src/transformers/adapters/models/vit/adapter_model.py b/src/transformers/adapters/models/vit/adapter_model.py index d2510909d6..6a5692dc16 100644 --- a/src/transformers/adapters/models/vit/adapter_model.py +++ b/src/transformers/adapters/models/vit/adapter_model.py @@ -33,6 +33,8 @@ def forward( interpolate_pos_encoding: Optional[bool] = None, return_dict: Optional[bool] = None, head=None, + output_adapter_gating_scores=False, + output_adapter_fusion_attentions=False, **kwargs, ): return_dict = return_dict if return_dict is not None else self.config.use_return_dict @@ -44,6 +46,8 @@ def forward( output_hidden_states=output_hidden_states, interpolate_pos_encoding=interpolate_pos_encoding, return_dict=return_dict, + output_adapter_gating_scores=output_adapter_gating_scores, + output_adapter_fusion_attentions=output_adapter_fusion_attentions, ) # BERT & RoBERTa return the pooled output as second item, we don't need that in these heads diff --git a/src/transformers/adapters/prefix_tuning.py b/src/transformers/adapters/prefix_tuning.py 
index b98f938741..0dceffd697 100644 --- a/src/transformers/adapters/prefix_tuning.py +++ b/src/transformers/adapters/prefix_tuning.py @@ -237,6 +237,7 @@ def __init__(self, location_key: str, config): self.config = config self.location_key = location_key self.prefixes = {} + self.prefix_gates = nn.ModuleDict() def set_pool(self, pool: PrefixTuningPool): self.__setattr__("pool", pool) @@ -258,10 +259,18 @@ def add_adapter(self, adapter_name: str, layer_idx: int): prefix_id = self.pool.indicate_prefix(adapter_name, self.location_key) self.prefixes[adapter_name] = prefix_id + if prefix_tuning_config.use_gating: + gate_outputs = 1 if prefix_tuning_config.shared_gating else 2 + gate = nn.Linear(self.config.hidden_size, gate_outputs) + gate.weight.data.normal_(mean=0.0, std=0.02) + self.prefix_gates[adapter_name] = gate + def delete_adapter(self, adapter_name: str): self.pool.delete_prefix(adapter_name) if adapter_name in self.prefixes: del self.prefixes[adapter_name] + if adapter_name in self.prefix_gates: + del self.prefix_gates[adapter_name] def add_fusion_layer(self, adapter_names: Union[List, str]): pass # not applicable to prefix tuning @@ -273,17 +282,25 @@ def enable_adapters(self, adapter_setup: AdapterCompositionBlock, unfreeze_adapt if unfreeze_adapters: for prefix_tuning_name in adapter_setup.flatten(): self.pool.enable_prefix(prefix_tuning_name) + if prefix_tuning_name in self.prefix_gates: + for param in self.prefix_gates[prefix_tuning_name].parameters(): + param.requires_grad = unfreeze_adapters def get_adapter(self, adapter_name): + return_dict = nn.ModuleDict() # Make sure to only return params once if adapter_name in self.prefixes and self.prefixes[adapter_name] == 0: prefix_module = self.pool.get_prefix(adapter_name) if prefix_module is not None: - return prefix_module[self.location_key] + return_dict["prefix"] = prefix_module[self.location_key] + if adapter_name in self.prefix_gates: + return_dict["gate"] = self.prefix_gates[adapter_name] + if len(return_dict) > 0: + return return_dict return None - def forward(self, key_states, value_states, attention_mask=None, invert_mask=True): + def forward(self, key_states, value_states, residual_input, attention_mask=None, invert_mask=True): adapter_setup = self.get_active_setup(self.prefixes) if adapter_setup is not None: if len(adapter_setup) == 1: @@ -295,10 +312,20 @@ def forward(self, key_states, value_states, attention_mask=None, invert_mask=Tru # Retrieve pre-computed prefix states from context context = ForwardContext.get_context() + # batch_size x n_heads x prefix_length x n_embd_per_head prefix_keys, prefix_values = context.prefix_states[prefix_tuning_name][self.location_key][ prefix_id ] + if prefix_tuning_name in self.prefix_gates: + gate = self.prefix_gates[prefix_tuning_name] + gate_output = torch.mean(torch.sigmoid(gate(residual_input)), dim=1) + self._store_gating_score(prefix_tuning_name, gate_output) + gate_output_key = gate_output[:, 0].view(-1, 1, 1, 1) + gate_output_value = gate_output[:, -1].view(-1, 1, 1, 1) + key_states = key_states * gate_output_key + value_states = value_states * gate_output_value + key_states = torch.cat([prefix_keys, key_states], dim=2) value_states = torch.cat([prefix_values, value_states], dim=2) if attention_mask is not None: diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index 3a9757ec6c..518978b69d 100755 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -237,7 +237,9 @@ 
def forward( proj_shape = (bsz * self.num_heads, -1, self.head_dim) query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) - key_states, value_states, attention_mask = self.prefix_tuning(key_states, value_states, attention_mask) + key_states, value_states, attention_mask = self.prefix_tuning( + key_states, value_states, hidden_states, attention_mask + ) key_states = key_states.view(*proj_shape) value_states = value_states.view(*proj_shape) diff --git a/src/transformers/models/bert/modeling_bert.py b/src/transformers/models/bert/modeling_bert.py index 1c367aad01..dd17f489b4 100644 --- a/src/transformers/models/bert/modeling_bert.py +++ b/src/transformers/models/bert/modeling_bert.py @@ -335,7 +335,9 @@ def forward( # if encoder bi-directional self-attention `past_key_value` is always `None` past_key_value = (key_layer, value_layer) - key_layer, value_layer, attention_mask = self.prefix_tuning(key_layer, value_layer, attention_mask) + key_layer, value_layer, attention_mask = self.prefix_tuning( + key_layer, value_layer, hidden_states, attention_mask + ) # Take the dot product between "query" and "key" to get the raw attention scores. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) diff --git a/src/transformers/models/deberta/modeling_deberta.py b/src/transformers/models/deberta/modeling_deberta.py index e753dd85b7..cea1b033c8 100644 --- a/src/transformers/models/deberta/modeling_deberta.py +++ b/src/transformers/models/deberta/modeling_deberta.py @@ -655,7 +655,9 @@ def linear(w, b, x): query_layer = query_layer + self.transpose_for_scores(self.q_bias[None, None, :]) value_layer = value_layer + self.transpose_for_scores(self.v_bias[None, None, :]) - key_layer, value_layer, attention_mask = self.prefix_tuning(key_layer, value_layer, attention_mask, False) + key_layer, value_layer, attention_mask = self.prefix_tuning( + key_layer, value_layer, hidden_states, attention_mask, False + ) rel_att = None # Take the dot product between "query" and "key" to get the raw attention scores. 
diff --git a/src/transformers/models/deberta_v2/modeling_deberta_v2.py b/src/transformers/models/deberta_v2/modeling_deberta_v2.py index 071ef71f06..87ea225318 100644 --- a/src/transformers/models/deberta_v2/modeling_deberta_v2.py +++ b/src/transformers/models/deberta_v2/modeling_deberta_v2.py @@ -714,7 +714,7 @@ def forward( value_layer = self.transpose_for_scores_extended(self.value_proj(hidden_states), self.num_attention_heads) key_layer, value_layer, attention_mask = self.prefix_tuning( - key_layer, value_layer, attention_mask, False + key_layer, value_layer, hidden_states, attention_mask, False ) # [:, 0, :, 0]) key_layer = key_layer.contiguous().view(-1, key_layer.size(2), key_layer.size(-1)) diff --git a/src/transformers/models/distilbert/modeling_distilbert.py b/src/transformers/models/distilbert/modeling_distilbert.py index 231778fa13..ae2d9ec7b1 100755 --- a/src/transformers/models/distilbert/modeling_distilbert.py +++ b/src/transformers/models/distilbert/modeling_distilbert.py @@ -216,7 +216,7 @@ def unshape(x: torch.Tensor) -> torch.Tensor: k = shape(self.k_lin(key)) # (bs, n_heads, k_length, dim_per_head) v = shape(self.v_lin(value)) # (bs, n_heads, k_length, dim_per_head) - k, v, mask = self.prefix_tuning(k, v, mask, invert_mask=False) + k, v, mask = self.prefix_tuning(k, v, value, mask, invert_mask=False) mask_reshp = (bs, 1, 1, k.size(2)) diff --git a/src/transformers/models/gpt2/modeling_gpt2.py b/src/transformers/models/gpt2/modeling_gpt2.py index 6dfdf0f088..ae3669fe4f 100644 --- a/src/transformers/models/gpt2/modeling_gpt2.py +++ b/src/transformers/models/gpt2/modeling_gpt2.py @@ -355,7 +355,7 @@ def forward( else: present = None - key, value, attention_mask = self.prefix_tuning(key, value, attention_mask) + key, value, attention_mask = self.prefix_tuning(key, value, hidden_states, attention_mask) if self.reorder_and_upcast_attn: attn_output, attn_weights = self._upcast_and_reordered_attn(query, key, value, attention_mask, head_mask) diff --git a/src/transformers/models/mbart/modeling_mbart.py b/src/transformers/models/mbart/modeling_mbart.py index 44dc64a35c..e817ad073f 100755 --- a/src/transformers/models/mbart/modeling_mbart.py +++ b/src/transformers/models/mbart/modeling_mbart.py @@ -244,7 +244,9 @@ def forward( proj_shape = (bsz * self.num_heads, -1, self.head_dim) query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) - key_states, value_states, attention_mask = self.prefix_tuning(key_states, value_states, attention_mask) + key_states, value_states, attention_mask = self.prefix_tuning( + key_states, value_states, hidden_states, attention_mask + ) key_states = key_states.view(*proj_shape) value_states = value_states.view(*proj_shape) diff --git a/src/transformers/models/roberta/modeling_roberta.py b/src/transformers/models/roberta/modeling_roberta.py index 8d35c27bb3..4cb40cbe2b 100644 --- a/src/transformers/models/roberta/modeling_roberta.py +++ b/src/transformers/models/roberta/modeling_roberta.py @@ -246,7 +246,9 @@ def forward( # if encoder bi-directional self-attention `past_key_value` is always `None` past_key_value = (key_layer, value_layer) - key_layer, value_layer, attention_mask = self.prefix_tuning(key_layer, value_layer, attention_mask) + key_layer, value_layer, attention_mask = self.prefix_tuning( + key_layer, value_layer, hidden_states, attention_mask + ) # Take the dot product between "query" and "key" to get the raw attention scores. 
         attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py
index 2039a03708..bf5e995557 100644
--- a/src/transformers/models/t5/modeling_t5.py
+++ b/src/transformers/models/t5/modeling_t5.py
@@ -309,7 +309,7 @@ def forward(self, hidden_states):
 class T5DenseGatedActDense(nn.Module):
     def __init__(self, config: T5Config):
         super().__init__()
-        self.wi_0 = LoRALinear(config.d_model, config.d_ff, "intermediate", config, bias=False)
+        self.wi_0 = nn.Linear(config.d_model, config.d_ff, bias=False)
         self.wi_1 = LoRALinear(config.d_model, config.d_ff, "intermediate", config, bias=False)
         self.wo = LoRALinear(config.d_ff, config.d_model, "output", config, bias=False)
         self.dropout = nn.Dropout(config.dropout_rate)
@@ -523,7 +523,7 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value):
         present_key_value_state = (key_states, value_states) if (self.is_decoder and use_cache) else None
 
-        key_states, value_states, mask = self.prefix_tuning(key_states, value_states, mask)
+        key_states, value_states, mask = self.prefix_tuning(key_states, value_states, hidden_states, mask)
 
         key_length = key_states.size(2)
 
         # compute scores
diff --git a/src/transformers/models/vit/modeling_vit.py b/src/transformers/models/vit/modeling_vit.py
index e08378de3c..d291fe9c37 100644
--- a/src/transformers/models/vit/modeling_vit.py
+++ b/src/transformers/models/vit/modeling_vit.py
@@ -223,7 +223,7 @@ def forward(
         value_layer = self.transpose_for_scores(self.value(hidden_states))
         query_layer = self.transpose_for_scores(mixed_query_layer)
 
-        key_layer, value_layer, _ = self.prefix_tuning(key_layer, value_layer)
+        key_layer, value_layer, _ = self.prefix_tuning(key_layer, value_layer, hidden_states)
 
         # Take the dot product between "query" and "key" to get the raw attention scores.
         attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
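Across all of the modeling changes above, the prefix-tuning hook gains one extra argument: the attention layer's `hidden_states`. The sketch below is illustrative only and uses hypothetical names (it is not the repository's actual prefix-tuning shim); it shows why the layer input has to be threaded through: a UniPELT-style gate scores the incoming hidden states to decide how strongly the prepended prefix keys/values should contribute.

```python
# Hypothetical sketch, not the library's implementation: names, shapes and the
# gating scheme are assumptions made for illustration.
import torch
import torch.nn as nn


class GatedPrefixHook(nn.Module):
    """Prepends trainable prefix key/value vectors; the extra hidden_states
    argument feeds a UniPELT-style gate that scales the prefix contribution."""

    def __init__(self, hidden_size: int, n_heads: int, head_dim: int, prefix_len: int = 30):
        super().__init__()
        self.prefix_keys = nn.Parameter(torch.randn(n_heads, prefix_len, head_dim))
        self.prefix_values = nn.Parameter(torch.randn(n_heads, prefix_len, head_dim))
        self.gate = nn.Linear(hidden_size, 1)

    def forward(self, key_states, value_states, hidden_states, attention_mask=None):
        # key_states / value_states: (bsz, n_heads, seq_len, head_dim)
        # hidden_states: (bsz, seq_len, hidden_size) -- the newly added argument
        bsz = key_states.size(0)

        # One scalar gate per example, computed from the layer input.
        gate = torch.sigmoid(self.gate(hidden_states)).mean(dim=1).view(bsz, 1, 1, 1)

        prefix_k = gate * self.prefix_keys.unsqueeze(0).expand(bsz, -1, -1, -1)
        prefix_v = gate * self.prefix_values.unsqueeze(0).expand(bsz, -1, -1, -1)
        key_states = torch.cat([prefix_k, key_states], dim=2)
        value_states = torch.cat([prefix_v, value_states], dim=2)

        if attention_mask is not None:
            # Additive mask convention: zeros let the prefix positions be attended to.
            prefix_mask = attention_mask.new_zeros(attention_mask.shape[:-1] + (prefix_k.size(2),))
            attention_mask = torch.cat([prefix_mask, attention_mask], dim=-1)

        return key_states, value_states, attention_mask
```

Without `hidden_states`, the hook only sees the key/value projections and has nothing to compute such a gate from, which is why every call site in the diff now passes the layer input as well.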
diff --git a/tests_adapters/methods/__init__.py b/tests_adapters/methods/__init__.py
index 3559c930d3..f40a688e58 100644
--- a/tests_adapters/methods/__init__.py
+++ b/tests_adapters/methods/__init__.py
@@ -22,3 +22,4 @@
 from .test_ia3 import IA3TestMixin
 from .test_lora import LoRATestMixin
 from .test_prefix_tuning import PrefixTuningTestMixin
+from .test_unipelt import UniPELTTestMixin
diff --git a/tests_adapters/methods/base.py b/tests_adapters/methods/base.py
index 7eefc8686a..d3944114be 100644
--- a/tests_adapters/methods/base.py
+++ b/tests_adapters/methods/base.py
@@ -127,7 +127,7 @@ def run_forward_test(self, model, adapter_config):
         self.assertEqual(len(output_1), len(output_2))
         self.assertTrue(torch.equal(output_1[0], output_2[0]))
 
-        self.assertEqual(len(output_1), len(base_output))
+        self.assertGreaterEqual(len(output_1), len(base_output))
         self.assertFalse(torch.equal(output_1[0], base_output[0]))
 
     def run_load_test(self, adapter_config):
@@ -191,14 +191,14 @@ def run_full_model_load_test(self, adapter_config):
         self.assertEqual(len(output1), len(output2))
         self.assertTrue(torch.equal(output1[0], output2[0]))
 
-    def trainings_run(self, model):
+    def trainings_run(self, model, lr=1.0, steps=20):
         # setup dataset
         train_dataset = self.dataset()
         training_args = TrainingArguments(
             output_dir="./examples",
             do_train=True,
-            learning_rate=1.0,
-            max_steps=20,
+            learning_rate=lr,
+            max_steps=steps,
             no_cuda=True,
             per_device_train_batch_size=2,
             remove_unused_columns=False,
diff --git a/tests_adapters/methods/test_lora.py b/tests_adapters/methods/test_lora.py
index ba4f047ecf..8d35579528 100644
--- a/tests_adapters/methods/test_lora.py
+++ b/tests_adapters/methods/test_lora.py
@@ -36,6 +36,6 @@ def test_train_lora(self):
 
     def test_merge_lora(self):
         self.run_merge_test(LoRAConfig(init_weights="bert"))
-    
+
     def test_reset_lora(self):
         self.run_reset_test(LoRAConfig(init_weights="bert"))
diff --git a/tests_adapters/methods/test_unipelt.py b/tests_adapters/methods/test_unipelt.py
new file mode 100644
index 0000000000..507aa7dd1e
--- /dev/null
+++ b/tests_adapters/methods/test_unipelt.py
@@ -0,0 +1,59 @@
+import torch
+
+from transformers.adapters import UniPELTConfig
+from transformers.testing_utils import require_torch, torch_device
+
+from .base import AdapterMethodBaseTestMixin
+
+
+@require_torch
+class UniPELTTestMixin(AdapterMethodBaseTestMixin):
+    def test_add_unipelt(self):
+        model = self.get_model()
+        self.run_add_test(model, UniPELTConfig(), ["loras.{name}.", "adapters.{name}.", "prefix_tunings.{name}."])
+
+    def test_delete_unipelt(self):
+        model = self.get_model()
+        self.run_delete_test(model, UniPELTConfig(), ["loras.{name}.", "adapters.{name}.", "prefix_tunings.{name}."])
+
+    def test_get_unipelt(self):
+        model = self.get_model()
+        self.run_get_test(model, UniPELTConfig())
+
+    def test_forward_unipelt(self):
+        model = self.get_model()
+        self.run_forward_test(model, UniPELTConfig())
+
+    def test_load_unipelt(self):
+        self.run_load_test(UniPELTConfig())
+
+    def test_load_full_model_unipelt(self):
+        self.run_full_model_load_test(UniPELTConfig())
+
+    def test_train_unipelt(self):
+        self.run_train_test(
+            UniPELTConfig(), ["loras.{name}.", "adapters.{name}.", "prefix_tunings.{name}.", "prefix_gates.{name}."]
+        )
+
+    def test_output_adapter_gating_scores_unipelt(self):
+        model = self.get_model()
+        model.eval()
+
+        adapter_config = UniPELTConfig()
+        name = adapter_config.__class__.__name__
+        model.add_adapter(name, config=adapter_config)
+        model.to(torch_device)
+
+        input_data = self.get_input_samples(config=model.config)
+
+        model.set_active_adapters([name])
+        output_1 = model(**input_data, output_adapter_gating_scores=True)
+
+        self.assertEqual(len(output_1[0]), self.default_input_samples_shape[0])
+        self.assertTrue(hasattr(output_1, "adapter_gating_scores"))
+        gating_scores = output_1.adapter_gating_scores[name]
+        self.assertEqual(len(list(model.iter_layers())), len(gating_scores))
+        for k, per_layer_scores in gating_scores.items():
+            self.assertGreaterEqual(len(per_layer_scores), 3)
+            for k, v in per_layer_scores.items():
+                self.assertEqual(self.default_input_samples_shape[0], v.shape[0], k)
diff --git a/tests_adapters/test_adapter_fusion_common.py b/tests_adapters/test_adapter_fusion_common.py
index b1a8c79ef9..703d971cdd 100644
--- a/tests_adapters/test_adapter_fusion_common.py
+++ b/tests_adapters/test_adapter_fusion_common.py
@@ -195,3 +195,26 @@ def test_adapter_fusion_save_with_head(self):
         output2 = model2(**in_data)
         self.assertEqual(len(output1), len(output2))
         self.assertTrue(torch.equal(output1[0], output2[0]))
+
+    def test_output_adapter_fusion_attentions(self):
+        model = self.get_model()
+        model.eval()
+
+        model.add_adapter("a")
+        model.add_adapter("b")
+        model.add_adapter_fusion(["a", "b"])
+        model.to(torch_device)
+
+        input_data = self.get_input_samples(config=model.config)
+
+        model.set_active_adapters(Fuse("a", "b"))
+        output_1 = model(**input_data, output_adapter_fusion_attentions=True)
+
+        self.assertEqual(len(output_1[0]), self.default_input_samples_shape[0])
+        self.assertTrue(hasattr(output_1, "adapter_fusion_attentions"))
+        attention_scores = output_1.adapter_fusion_attentions["a,b"]
+        self.assertEqual(len(list(model.iter_layers())), len(attention_scores))
+        for k, per_layer_scores in attention_scores.items():
+            self.assertEqual(len(per_layer_scores), 1)
+            for k, v in per_layer_scores.items():
+                self.assertEqual(self.default_input_samples_shape[0], v.shape[0], k)
diff --git a/tests_adapters/test_bart.py b/tests_adapters/test_bart.py
index d5c0abe848..41dd3beab5 100644
--- a/tests_adapters/test_bart.py
+++ b/tests_adapters/test_bart.py
@@ -4,7 +4,14 @@
 from transformers import BartAdapterModel
 from transformers.testing_utils import require_torch
 
-from .methods import BottleneckAdapterTestMixin, LoRATestMixin, CompacterTestMixin, PrefixTuningTestMixin, IA3TestMixin
+from .methods import (
+    BottleneckAdapterTestMixin,
+    CompacterTestMixin,
+    IA3TestMixin,
+    LoRATestMixin,
+    PrefixTuningTestMixin,
+    UniPELTTestMixin,
+)
 from .test_adapter import AdapterTestBase, make_config
 from .test_adapter_backward_compability import CompabilityTestMixin
 from .test_adapter_composition import ParallelAdapterInferenceTestMixin, ParallelTrainingMixin
@@ -45,6 +52,7 @@ class BartAdapterTest(
     IA3TestMixin,
     LoRATestMixin,
     PrefixTuningTestMixin,
+    UniPELTTestMixin,
     AdapterFusionModelTestMixin,
     CompabilityTestMixin,
     EmbeddingTestMixin,
diff --git a/tests_adapters/test_bert.py b/tests_adapters/test_bert.py
index e25d92be84..e79a53d0df 100644
--- a/tests_adapters/test_bert.py
+++ b/tests_adapters/test_bert.py
@@ -4,7 +4,14 @@
 from transformers import BertAdapterModel
 from transformers.testing_utils import require_torch
 
-from .methods import BottleneckAdapterTestMixin, CompacterTestMixin, LoRATestMixin, PrefixTuningTestMixin, IA3TestMixin
+from .methods import (
+    BottleneckAdapterTestMixin,
+    CompacterTestMixin,
+    IA3TestMixin,
+    LoRATestMixin,
+    PrefixTuningTestMixin,
+    UniPELTTestMixin,
+)
 from .test_adapter import AdapterTestBase, make_config
 from .test_adapter_backward_compability import CompabilityTestMixin
 from .test_adapter_composition import ParallelAdapterInferenceTestMixin, ParallelTrainingMixin
@@ -42,6 +49,7 @@ class BertAdapterTest(
     IA3TestMixin,
     LoRATestMixin,
     PrefixTuningTestMixin,
+    UniPELTTestMixin,
     EmbeddingTestMixin,
     AdapterFusionModelTestMixin,
     CompabilityTestMixin,
diff --git a/tests_adapters/test_deberta.py b/tests_adapters/test_deberta.py
index 990ac0d77f..bdd64e31d6 100644
--- a/tests_adapters/test_deberta.py
+++ b/tests_adapters/test_deberta.py
@@ -4,7 +4,7 @@
 from transformers import DebertaAdapterModel
 from transformers.testing_utils import require_torch
 
-from .methods import BottleneckAdapterTestMixin, CompacterTestMixin, LoRATestMixin, PrefixTuningTestMixin, IA3TestMixin
+from .methods import BottleneckAdapterTestMixin, UniPELTTestMixin, CompacterTestMixin, IA3TestMixin, LoRATestMixin, PrefixTuningTestMixin
 from .test_adapter import AdapterTestBase, make_config
 from .test_adapter_backward_compability import CompabilityTestMixin
 from .test_adapter_composition import ParallelAdapterInferenceTestMixin, ParallelTrainingMixin
@@ -49,6 +49,7 @@ class DebertaAdapterTest(
     IA3TestMixin,
     LoRATestMixin,
     PrefixTuningTestMixin,
+    UniPELTTestMixin,
     EmbeddingTestMixin,
     ParallelTrainingMixin,
diff --git a/tests_adapters/test_debertaV2.py b/tests_adapters/test_debertaV2.py
index c146b1bcdd..26ab7cde68 100644
--- a/tests_adapters/test_debertaV2.py
+++ b/tests_adapters/test_debertaV2.py
@@ -4,7 +4,7 @@
 from transformers import DebertaV2AdapterModel
 from transformers.testing_utils import require_torch
 
-from .methods import BottleneckAdapterTestMixin, CompacterTestMixin, LoRATestMixin, PrefixTuningTestMixin, IA3TestMixin
+from .methods import BottleneckAdapterTestMixin, UniPELTTestMixin, CompacterTestMixin, IA3TestMixin, LoRATestMixin, PrefixTuningTestMixin
 from .test_adapter import AdapterTestBase, make_config
 from .test_adapter_backward_compability import CompabilityTestMixin
 from .test_adapter_composition import ParallelAdapterInferenceTestMixin, ParallelTrainingMixin
@@ -50,6 +50,7 @@ class DebertaV2AdapterTest(
     IA3TestMixin,
     LoRATestMixin,
     PrefixTuningTestMixin,
+    UniPELTTestMixin,
     EmbeddingTestMixin,
     ParallelTrainingMixin,
diff --git a/tests_adapters/test_distilbert.py b/tests_adapters/test_distilbert.py
index 6748971587..759de3f80b 100644
--- a/tests_adapters/test_distilbert.py
+++ b/tests_adapters/test_distilbert.py
@@ -4,7 +4,7 @@
 from transformers import DistilBertAdapterModel
 from transformers.testing_utils import require_torch
 
-from .methods import BottleneckAdapterTestMixin, LoRATestMixin, CompacterTestMixin, PrefixTuningTestMixin, IA3TestMixin
+from .methods import BottleneckAdapterTestMixin, UniPELTTestMixin, CompacterTestMixin, IA3TestMixin, LoRATestMixin, PrefixTuningTestMixin
 from .test_adapter import AdapterTestBase, make_config
 from .test_adapter_backward_compability import CompabilityTestMixin
 from .test_adapter_composition import ParallelAdapterInferenceTestMixin, ParallelTrainingMixin
@@ -42,6 +42,7 @@ class DistilBertAdapterTest(
     IA3TestMixin,
     LoRATestMixin,
     PrefixTuningTestMixin,
+    UniPELTTestMixin,
     EmbeddingTestMixin,
     CompabilityTestMixin,
     AdapterFusionModelTestMixin,
diff --git a/tests_adapters/test_encoder_decoder.py b/tests_adapters/test_encoder_decoder.py
index 170127e690..f1a48cc631 100644
--- a/tests_adapters/test_encoder_decoder.py
+++ b/tests_adapters/test_encoder_decoder.py
@@ -1,7 +1,14 @@
 from tests.models.encoder_decoder.test_modeling_encoder_decoder import *  # Imported to execute model tests
 from transformers import AutoModelForSeq2SeqLM, BertConfig
 
-from .methods import BottleneckAdapterTestMixin, CompacterTestMixin, LoRATestMixin, PrefixTuningTestMixin, IA3TestMixin
+from .methods import (
+    BottleneckAdapterTestMixin,
+    CompacterTestMixin,
+    IA3TestMixin,
+    LoRATestMixin,
+    PrefixTuningTestMixin,
+    UniPELTTestMixin,
+)
 from .test_adapter import AdapterTestBase
 from .test_adapter_fusion_common import AdapterFusionModelTestMixin
@@ -37,6 +44,7 @@ class EncoderDecoderAdapterTest(
     IA3TestMixin,
     LoRATestMixin,
     PrefixTuningTestMixin,
+    UniPELTTestMixin,
     AdapterFusionModelTestMixin,
     EncoderDecoderAdapterTestBase,
     unittest.TestCase,
@@ -65,3 +73,11 @@ def forward_pre_hook(module, input):
 
         self.assertEqual((1, 128, model.config.decoder.vocab_size), out[0].shape)
         self.assertEqual(2, calls)
+
+    def test_output_adapter_gating_scores_unipelt(self):
+        # TODO currently not supported
+        self.skipTest("Not implemented.")
+
+    def test_output_adapter_fusion_attentions(self):
+        # TODO currently not supported
+        self.skipTest("Not implemented.")
diff --git a/tests_adapters/test_gpt2.py b/tests_adapters/test_gpt2.py
index 8f8fdd146f..2de9fb7a43 100644
--- a/tests_adapters/test_gpt2.py
+++ b/tests_adapters/test_gpt2.py
@@ -4,7 +4,14 @@
 from transformers import GPT2AdapterModel
 from transformers.testing_utils import require_torch
 
-from .methods import BottleneckAdapterTestMixin, CompacterTestMixin, LoRATestMixin, PrefixTuningTestMixin, IA3TestMixin
+from .methods import (
+    BottleneckAdapterTestMixin,
+    CompacterTestMixin,
+    IA3TestMixin,
+    LoRATestMixin,
+    PrefixTuningTestMixin,
+    UniPELTTestMixin,
+)
 from .test_adapter import AdapterTestBase, make_config
 from .test_adapter_backward_compability import CompabilityTestMixin
 from .test_adapter_composition import ParallelAdapterInferenceTestMixin, ParallelTrainingMixin
@@ -43,6 +50,7 @@ class GPT2AdapterTest(
     IA3TestMixin,
     LoRATestMixin,
     PrefixTuningTestMixin,
+    UniPELTTestMixin,
     EmbeddingTestMixin,
     CompabilityTestMixin,
     AdapterFusionModelTestMixin,
diff --git a/tests_adapters/test_mbart.py b/tests_adapters/test_mbart.py
index 0fbe71f629..0e4912d22b 100644
--- a/tests_adapters/test_mbart.py
+++ b/tests_adapters/test_mbart.py
@@ -4,7 +4,7 @@
 from transformers import MBartAdapterModel
 from transformers.testing_utils import require_torch
 
-from .methods import BottleneckAdapterTestMixin, LoRATestMixin, CompacterTestMixin, PrefixTuningTestMixin, IA3TestMixin
+from .methods import BottleneckAdapterTestMixin, UniPELTTestMixin, CompacterTestMixin, IA3TestMixin, LoRATestMixin, PrefixTuningTestMixin
 from .test_adapter import AdapterTestBase, make_config
 from .test_adapter_composition import ParallelAdapterInferenceTestMixin
 from .test_adapter_conversion import ModelClassConversionTestMixin
@@ -44,6 +44,7 @@ class MBartAdapterTest(
     IA3TestMixin,
     LoRATestMixin,
     PrefixTuningTestMixin,
+    UniPELTTestMixin,
     AdapterFusionModelTestMixin,
     PredictionHeadModelTestMixin,
     ParallelAdapterInferenceTestMixin,
diff --git a/tests_adapters/test_roberta.py b/tests_adapters/test_roberta.py
index 8d7c2cb355..432e6eea31 100644
--- a/tests_adapters/test_roberta.py
+++ b/tests_adapters/test_roberta.py
@@ -4,7 +4,7 @@
 from transformers import RobertaAdapterModel
 from transformers.testing_utils import require_torch
 
-from .methods import BottleneckAdapterTestMixin, CompacterTestMixin, LoRATestMixin, PrefixTuningTestMixin, IA3TestMixin
+from .methods import BottleneckAdapterTestMixin, UniPELTTestMixin, CompacterTestMixin, IA3TestMixin, LoRATestMixin, PrefixTuningTestMixin
 from .test_adapter import AdapterTestBase, make_config
 from .test_adapter_backward_compability import CompabilityTestMixin
 from .test_adapter_composition import ParallelAdapterInferenceTestMixin
@@ -42,6 +42,7 @@ class RobertaAdapterTest(
     IA3TestMixin,
     LoRATestMixin,
     PrefixTuningTestMixin,
+    UniPELTTestMixin,
     AdapterFusionModelTestMixin,
     CompabilityTestMixin,
     PredictionHeadModelTestMixin,
diff --git a/tests_adapters/test_t5.py b/tests_adapters/test_t5.py
index 5ac59ddaee..0f5f134bc3 100644
--- a/tests_adapters/test_t5.py
+++ b/tests_adapters/test_t5.py
@@ -3,10 +3,17 @@
 from datasets import load_dataset
 
 from tests.models.t5.test_modeling_t5 import *
-from transformers import T5AdapterModel, AutoTokenizer
+from transformers import AutoTokenizer, T5AdapterModel
 from transformers.testing_utils import require_torch
 
-from .methods import BottleneckAdapterTestMixin, LoRATestMixin, CompacterTestMixin, PrefixTuningTestMixin, IA3TestMixin
+from .methods import (
+    BottleneckAdapterTestMixin,
+    CompacterTestMixin,
+    IA3TestMixin,
+    LoRATestMixin,
+    PrefixTuningTestMixin,
+    UniPELTTestMixin,
+)
 from .test_adapter import AdapterTestBase, make_config
 from .test_adapter_backward_compability import CompabilityTestMixin
 from .test_adapter_composition import ParallelAdapterInferenceTestMixin, ParallelTrainingMixin
@@ -92,6 +99,7 @@ class T5AdapterTest(
     IA3TestMixin,
     LoRATestMixin,
     PrefixTuningTestMixin,
+    UniPELTTestMixin,
     EmbeddingTestMixin,
     CompabilityTestMixin,
     ParallelAdapterInferenceTestMixin,
diff --git a/tests_adapters/test_vit.py b/tests_adapters/test_vit.py
index fc4c1c79a4..f5b0d9c04b 100644
--- a/tests_adapters/test_vit.py
+++ b/tests_adapters/test_vit.py
@@ -5,7 +5,14 @@
 from transformers import ViTAdapterModel
 from transformers.testing_utils import require_torch
 
-from .methods import BottleneckAdapterTestMixin, CompacterTestMixin, LoRATestMixin, PrefixTuningTestMixin, IA3TestMixin
+from .methods import (
+    BottleneckAdapterTestMixin,
+    CompacterTestMixin,
+    IA3TestMixin,
+    LoRATestMixin,
+    PrefixTuningTestMixin,
+    UniPELTTestMixin,
+)
 from .test_adapter import VisionAdapterTestBase, make_config
 from .test_adapter_backward_compability import CompabilityTestMixin
 from .test_adapter_composition import ParallelAdapterInferenceTestMixin, ParallelTrainingMixin
@@ -43,6 +50,7 @@ class ViTAdapterTest(
     IA3TestMixin,
     LoRATestMixin,
     PrefixTuningTestMixin,
+    UniPELTTestMixin,
     AdapterFusionModelTestMixin,
     CompabilityTestMixin,
     PredictionHeadModelTestMixin,
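The new `UniPELTTestMixin` exercises the same public workflow a user would follow. As a quick orientation (not part of the patch), the snippet below mirrors what `test_output_adapter_gating_scores_unipelt` does against a concrete checkpoint; the checkpoint name and the adapter name `"unipelt"` are illustrative choices.

```python
# Usage sketch mirroring the new UniPELT tests; "bert-base-uncased" and the
# adapter name "unipelt" are assumptions made for illustration.
import torch
from transformers import AutoTokenizer, BertAdapterModel
from transformers.adapters import UniPELTConfig

model = BertAdapterModel.from_pretrained("bert-base-uncased")
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

# UniPELT adds LoRA, bottleneck adapter and prefix-tuning modules plus their gates.
model.add_adapter("unipelt", config=UniPELTConfig())
model.set_active_adapters(["unipelt"])
model.eval()

inputs = tokenizer("UniPELT combines several adapter methods.", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs, output_adapter_gating_scores=True)

# Per-layer gating scores collected during the forward pass, keyed by adapter name.
gating_scores = outputs.adapter_gating_scores["unipelt"]
for layer, per_module_scores in gating_scores.items():
    print(layer, {module: scores.shape for module, scores in per_module_scores.items()})
```

The encoder-decoder model is the one exception for now: its test class explicitly skips both `test_output_adapter_gating_scores_unipelt` and `test_output_adapter_fusion_attentions`, matching the TODOs in the diff.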