diff --git a/README.md b/README.md
index b96f307547..92b8d1622b 100644
--- a/README.md
+++ b/README.md
@@ -36,7 +36,7 @@ Thus, most files in this repository are direct copies from the HuggingFace Trans
## Installation
-`adapter-transformers` currently supports **Python 3.6+** and **PyTorch 1.3.1+**.
+`adapter-transformers` currently supports **Python 3.7+** and **PyTorch 1.3.1+**.
After [installing PyTorch](https://pytorch.org/get-started/locally/), you can install `adapter-transformers` from PyPI ...
```
@@ -74,10 +74,11 @@ Currently, adapter-transformers integrates all architectures and methods listed
| AdapterDrop | [Rücklé et al. (2021)](https://arxiv.org/pdf/2010.11918.pdf) | [Notebook](https://colab.research.google.com/github/Adapter-Hub/adapter-transformers/blob/master/notebooks/05_Adapter_Drop_Training.ipynb) |
| MAD-X 2.0,<br>Embedding training | [Pfeiffer et al. (2021)](https://arxiv.org/pdf/2012.15562.pdf) | [Docs: Embeddings](https://docs.adapterhub.ml/embeddings.html), [Notebook](https://colab.research.google.com/github/Adapter-Hub/adapter-transformers/blob/master/notebooks/08_NER_Wikiann.ipynb) |
| Prefix Tuning | [Li and Liang (2021)](https://arxiv.org/pdf/2101.00190.pdf) | [Docs](https://docs.adapterhub.ml/overview.html#prefix-tuning) |
-| Parallel adapters,<br>Mix-and-Match adapters | [He et al. (2021)](https://arxiv.org/pdf/2110.04366.pdf) | [Docs](https://docs.adapterhub.ml/overview.html#combinations-mix-and-match-adapters) |
+| Parallel adapters,<br>Mix-and-Match adapters | [He et al. (2021)](https://arxiv.org/pdf/2110.04366.pdf) | [Docs](https://docs.adapterhub.ml/overview.html#mix-and-match-adapters) |
| Compacter | [Mahabadi et al. (2021)](https://arxiv.org/pdf/2106.04647.pdf) | [Docs](https://docs.adapterhub.ml/overview.html#compacter) |
| LoRA | [Hu et al. (2021)](https://arxiv.org/pdf/2106.09685.pdf) | [Docs](https://docs.adapterhub.ml/overview.html#lora) |
-| (IA)^3 | [Liu et al. (2022)](https://arxiv.org/pdf/2205.05638.pdf) | [Docs](https://docs.adapterhub.ml/overview.html#ia3) |
+| (IA)^3 | [Liu et al. (2022)](https://arxiv.org/pdf/2205.05638.pdf) | [Docs](https://docs.adapterhub.ml/overview.html#ia-3) |
+| UniPELT | [Mao et al. (2022)](https://arxiv.org/pdf/2110.07577.pdf) | [Docs](https://docs.adapterhub.ml/overview.html#unipelt) |
## Supported Models
diff --git a/adapter_docs/adapter_composition.md b/adapter_docs/adapter_composition.md
index 2bd472658c..906c504436 100644
--- a/adapter_docs/adapter_composition.md
+++ b/adapter_docs/adapter_composition.md
@@ -109,6 +109,32 @@ To learn how training an _AdapterFusion_ layer works, check out [this Colab note
In v1.x of `adapter-transformers`, fusing adapters was done using a nested list of adapter names, i.e. the example from above would be defined as `[["d", "e", "f"]]`.
For backwards compatibility, you can still do this, although it is recommended to use the new syntax.
+#### Retrieving AdapterFusion attentions
+
+Finally, it is possible to retrieve the attention scores computed by each fusion layer in a forward pass of the model.
+These scores can be used for analyzing the fused adapter blocks and can serve as the basis for visualizations similar to those in the AdapterFusion paper.
+You can collect the fusion attention scores by passing `output_adapter_fusion_attentions=True` to the model forward call.
+The scores for each layer will then be saved in the `adapter_fusion_attentions` attribute of the output:
+
+```python
+outputs = model(**inputs, output_adapter_fusion_attentions=True)
+attention_scores = outputs.adapter_fusion_attentions
+```
+Note that this parameter is only available to base model classes and [AdapterModel classes](prediction_heads.md#adaptermodel-classes).
+In the example, `attention_scores` holds a dictionary of the following form:
+```
+{
+    '<fusion_name>': {
+        <layer_id>: {
+            '<module_location>': np.array([...]),
+ ...
+ },
+ ...
+ },
+ ...
+}
+```
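+
+For instance (a minimal sketch, assuming a fusion of the adapters `"d"`, `"e"` and `"f"` from the example above, registered under the fusion name `"d,e,f"`), the scores can be inspected per layer and location:
+
+```python
+# hypothetical inspection of the returned fusion attention scores
+for layer_id, locations in attention_scores["d,e,f"].items():
+    for location, scores in locations.items():
+        # each entry is a numpy array of attention weights over the fused adapters
+        print(f"layer {layer_id}, {location}: {scores.shape}")
+```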
+
## `Split`
```{eval-rst}
diff --git a/adapter_docs/classes/adapter_config.rst b/adapter_docs/classes/adapter_config.rst
index 75da8973c5..e08be42807 100644
--- a/adapter_docs/classes/adapter_config.rst
+++ b/adapter_docs/classes/adapter_config.rst
@@ -65,6 +65,9 @@ Combined configurations
.. autoclass:: transformers.MAMConfig
:members:
+.. autoclass:: transformers.UniPELTConfig
+ :members:
+
Adapter Fusion
~~~~~~~~~~~~~~~
diff --git a/adapter_docs/conf.py b/adapter_docs/conf.py
index 3f1a200632..a24f18b66d 100644
--- a/adapter_docs/conf.py
+++ b/adapter_docs/conf.py
@@ -26,7 +26,7 @@
docs_versions = [
"adapters1.1.1",
"adapters2.3.0",
- "adapters3.0.1",
+ "adapters3.1.0",
]
diff --git a/adapter_docs/img/compacter.png b/adapter_docs/img/compacter.png
new file mode 100644
index 0000000000..a7c2786e96
Binary files /dev/null and b/adapter_docs/img/compacter.png differ
diff --git a/adapter_docs/img/ia3.png b/adapter_docs/img/ia3.png
new file mode 100644
index 0000000000..f335d132be
Binary files /dev/null and b/adapter_docs/img/ia3.png differ
diff --git a/adapter_docs/img/lora.png b/adapter_docs/img/lora.png
new file mode 100644
index 0000000000..310420fb10
Binary files /dev/null and b/adapter_docs/img/lora.png differ
diff --git a/adapter_docs/img/prefix.png b/adapter_docs/img/prefix.png
new file mode 100644
index 0000000000..59d2909933
Binary files /dev/null and b/adapter_docs/img/prefix.png differ
diff --git a/adapter_docs/img/unipelt.png b/adapter_docs/img/unipelt.png
new file mode 100644
index 0000000000..110a19ead9
Binary files /dev/null and b/adapter_docs/img/unipelt.png differ
diff --git a/adapter_docs/installation.md b/adapter_docs/installation.md
index 5364bcee95..ac6c72dfa0 100644
--- a/adapter_docs/installation.md
+++ b/adapter_docs/installation.md
@@ -1,7 +1,7 @@
# Installation
Our *adapter-transformers* package is a drop-in replacement for Huggingface's *transformers* library.
-It currently supports Python 3.6+ and PyTorch 1.3.1+. You will have to [install PyTorch](https://pytorch.org/get-started/locally/) first.
+It currently supports Python 3.7+ and PyTorch 1.3.1+. You will have to [install PyTorch](https://pytorch.org/get-started/locally/) first.
```{eval-rst}
.. important::
diff --git a/adapter_docs/overview.md b/adapter_docs/overview.md
index ae0e3e04b1..8ba89a8fc0 100644
--- a/adapter_docs/overview.md
+++ b/adapter_docs/overview.md
@@ -130,6 +130,15 @@ _Papers:_
_Configuration class_: [`PrefixTuningConfig`](transformers.PrefixTuningConfig)
+```{eval-rst}
+.. figure:: img/prefix.png
+ :height: 300
+ :align: center
+ :alt: Illustration of Prefix Tuning.
+
+ Illustration of the Prefix Tuning method within one Transformer layer. Trained components are colored in shades of magenta.
+```
+
Prefix Tuning ([Li and Liang, 2021](https://aclanthology.org/2021.acl-long.353.pdf)) introduces new parameters in the multi-head attention blocks in each Transformer layer.
More specifically, it prepends trainable prefix vectors $P^K$ and $P^V$ to the keys and values of the attention head input, each of a configurable prefix length $l$ (`prefix_length` attribute):
@@ -162,6 +171,15 @@ _Papers:_
_Configuration class_: [`CompacterConfig`](transformers.CompacterConfig), [`CompacterPlusPlusConfig`](transformers.CompacterPlusPlusConfig)
+```{eval-rst}
+.. figure:: img/compacter.png
+ :height: 300
+ :align: center
+ :alt: Illustration of Compacter.
+
+ Illustration of the Compacter method within one Transformer layer. Trained components are colored in shades of magenta.
+```
+
The Compacter architecture proposed by [Mahabadi et al., 2021](https://arxiv.org/pdf/2106.04647.pdf)
is similar to the bottleneck adapter architecture. It only exchanges the linear down- and
up-projection with a PHM layer. Unlike the linear layer, the PHM layer constructs its weight matrix from two smaller matrices, which reduces the number of parameters.
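
A rough sketch of the PHM idea (a toy example with made-up dimensions; Compacter additionally factorizes the per-layer matrices into low-rank components):

```python
import torch

# construct a (d_in x d_out) weight as a sum of Kronecker products of much smaller matrices
n, d_in, d_out = 4, 768, 48
A = torch.randn(n, n, n)                   # small matrices shared across layers
B = torch.randn(n, d_in // n, d_out // n)  # per-layer matrices
W = sum(torch.kron(A[i], B[i]) for i in range(n))  # shape: (d_in, d_out)
```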
@@ -187,6 +205,15 @@ _Papers:_
_Configuration class_: [`LoRAConfig`](transformers.LoRAConfig)
+```{eval-rst}
+.. figure:: img/lora.png
+ :height: 300
+ :align: center
+ :alt: Illustration of LoRA.
+
+ Illustration of the LoRA method within one Transformer layer. Trained components are colored in shades of magenta.
+```
+
Low-Rank Adaptation (LoRA) is an efficient fine-tuning technique proposed by [Hu et al. (2021)](https://arxiv.org/pdf/2106.09685.pdf).
LoRA injects trainable low-rank decomposition matrices into the layers of a pre-trained model.
For any model layer expressed as a matrix multiplication of the form $h = W_0 x$, it therefore performs a reparameterization, such that:
@@ -229,6 +256,15 @@ _Papers:_
_Configuration class_: [`IA3Config`](transformers.IA3Config)
+```{eval-rst}
+.. figure:: img/ia3.png
+ :height: 300
+ :align: center
+ :alt: Illustration of (IA)^3.
+
+ Illustration of the (IA)^3 method within one Transformer layer. Trained components are colored in shades of magenta.
+```
+
_Infused Adapter by Inhibiting and Amplifying Inner Activations ((IA)^3)_ is an efficient fine-tuning method proposed within the _T-Few_ fine-tuning approach by [Liu et al. (2022)](https://arxiv.org/pdf/2205.05638.pdf).
(IA)^3 introduces trainable vectors $l_W$ into different components of a Transformer model which perform element-wise rescaling of inner model activations.
For any model layer expressed as a matrix multiplication of the form $h = W x$, it therefore performs an element-wise multiplication with $l_W$, such that:
@@ -271,7 +307,7 @@ model.reset_adapter("ia3_adapter")
_Papers:_
- [Few-Shot Parameter-Efficient Fine-Tuning is Better and Cheaper than In-Context Learning](https://arxiv.org/pdf/2205.05638.pdf) (Liu et al., 2022)
-## Combinations - Mix-and-Match Adapters
+## Method Combinations
_Configuration class_: [`ConfigUnion`](transformers.ConfigUnion)
@@ -290,6 +326,10 @@ config = ConfigUnion(
model.add_adapter("union_adapter", config=config)
```
+### Mix-and-Match Adapters
+
+_Configuration class_: [`MAMConfig`](transformers.MAMConfig)
+
[He et al. (2021)](https://arxiv.org/pdf/2110.04366.pdf) study various variants and combinations of efficient fine-tuning methods.
Among others, they propose _Mix-and-Match Adapters_ as a combination of Prefix Tuning and parallel bottleneck adapters.
This configuration is supported by adapter-transformers out-of-the-box:
@@ -315,3 +355,79 @@ model.add_adapter("mam_adapter", config=config)
_Papers:_
- [Towards a Unified View of Parameter-Efficient Transfer Learning](https://arxiv.org/pdf/2110.04366.pdf) (He et al., 2021)
+
+### UniPELT
+
+_Configuration class_: [`UniPELTConfig`](transformers.UniPELTConfig)
+
+```{eval-rst}
+.. figure:: img/unipelt.png
+ :height: 300
+ :align: center
+ :alt: Illustration of UniPELT.
+
+ Illustration of the UniPELT method within one Transformer layer. Trained components are colored in shades of magenta.
+```
+
+An approach similar to the work of [He et al. (2021)](https://arxiv.org/pdf/2110.04366.pdf) is taken by [Mao et al. (2022)](https://arxiv.org/pdf/2110.07577.pdf) in their _UniPELT_ framework.
+They, too, combine multiple efficient fine-tuning methods, namely LoRA, Prefix Tuning and bottleneck adapters, in a single unified setup.
+_UniPELT_ additionally introduces a gating mechanism that controls the activation of the different submodules.
+
+Concretely, for each adapted module $m$, UniPELT adds a trainable gating value $\mathcal{G}_m \in (0, 1)$ that is computed via a feed-forward network ($W_{\mathcal{G}_m}$) and sigmoid activation ($\sigma$) from the Transformer layer input states ($x$):
+
+$$\mathcal{G}_m \leftarrow \sigma(W_{\mathcal{G}_m} \cdot x)$$
+
+These gating values are then used to scale the output activations of the injected adapter modules, e.g. for a LoRA layer:
+
+$$
+h \leftarrow W_0 x + \mathcal{G}_{LoRA} B A x
+$$
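+
+As a rough illustration (a minimal sketch in plain PyTorch of how such a gate can be computed, using the mean-over-sequence pooling applied in `adapter-transformers`; the dimensions are assumptions):
+
+```python
+import torch
+import torch.nn as nn
+
+batch_size, seq_len, hidden_size = 2, 16, 768
+x = torch.randn(batch_size, seq_len, hidden_size)  # Transformer layer input states
+
+gate_proj = nn.Linear(hidden_size, 1)               # W_G for one adapted module m
+gate = torch.sigmoid(gate_proj(x)).mean(dim=1)      # pooled over the sequence -> (batch_size, 1)
+
+# the gate value then scales the module's output activations, e.g. the LoRA delta B A x
+```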
+
+In the configuration classes of `adapter-transformers`, these gating mechanisms can be activated via `use_gating=True`.
+The full UniPELT setup can be instantiated using `UniPELTConfig`[^unipelt]:
+
+[^unipelt]: Note that the implementation of UniPELT in `adapter-transformers` follows the original code, which is slightly different from the description in the paper. See [here](https://github.com/morningmoni/UniPELT/issues/1) for more.
+
+```python
+from transformers.adapters import UniPELTConfig
+
+config = UniPELTConfig()
+model.add_adapter("unipelt", config=config)
+```
+
+which is identical to the following `ConfigUnion`:
+
+```python
+from transformers.adapters import ConfigUnion, LoRAConfig, PrefixTuningConfig, PfeifferConfig
+
+config = ConfigUnion(
+ LoRAConfig(r=8, use_gating=True),
+ PrefixTuningConfig(prefix_length=10, use_gating=True),
+ PfeifferConfig(reduction_factor=16, use_gating=True),
+)
+model.add_adapter("unipelt", config=config)
+```
+
+Finally, as the gating values for each adapter module might provide interesting insights for analysis, `adapter-transformers` comes with an integrated mechanism for returning all gating values computed during a model forward pass via the `output_adapter_gating_scores` parameter:
+
+```python
+outputs = model(**inputs, output_adapter_gating_scores=True)
+gating_scores = outputs.adapter_gating_scores
+```
+Note that this parameter is only available to base model classes and [AdapterModel classes](prediction_heads.md#adaptermodel-classes).
+In the example, `gating_scores` holds a dictionary of the following form:
+```
+{
+    '<adapter_name>': {
+        <layer_id>: {
+            '<module_location>': np.array([...]),
+ ...
+ },
+ ...
+ },
+ ...
+}
+```
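+
+For example (a minimal sketch, assuming the `"unipelt"` adapter added above), the average gate value per layer and location can be inspected like this:
+
+```python
+# hypothetical analysis of the returned gating scores
+for layer_id, locations in gating_scores["unipelt"].items():
+    for location, scores in locations.items():
+        print(f"layer {layer_id}, {location}: mean gate = {scores.mean():.3f}")
+```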
+
+_Papers:_
+- [UNIPELT: A Unified Framework for Parameter-Efficient Language Model Tuning](https://arxiv.org/pdf/2110.07577.pdf) (Mao et al., 2022)
diff --git a/setup.py b/setup.py
index 5f46998a2c..7ba7d41b29 100644
--- a/setup.py
+++ b/setup.py
@@ -116,7 +116,7 @@
"fugashi>=1.0",
"GitPython<3.1.19",
"hf-doc-builder>=0.3.0",
- "huggingface-hub>=0.1.0,<0.8.0",
+ "huggingface-hub>=0.1.0,<1.0",
"importlib_metadata",
"ipadic>=1.0.0,<2.0",
"isort>=5.5.4",
@@ -417,7 +417,7 @@ def run(self):
setup(
name="adapter-transformers",
- version="3.1.0a1",
+ version="3.1.0",
author="Jonas Pfeiffer, Andreas Rücklé, Clifton Poth, Hannah Sterz, based on work by the HuggingFace team and community",
author_email="pfeiffer@ukp.tu-darmstadt.de",
description="A friendly fork of HuggingFace's Transformers, adding Adapters to PyTorch language models",
diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py
index d03da16858..f28c722ba5 100755
--- a/src/transformers/__init__.py
+++ b/src/transformers/__init__.py
@@ -22,7 +22,7 @@
# to defer the actual importing for when the objects are requested. This way `import transformers` provides the names
# in the namespace without actually importing anything (and especially none of the backends).
-__version__ = "4.21.2"
+__version__ = "4.21.3"
from typing import TYPE_CHECKING
@@ -2063,6 +2063,7 @@
"StaticAdapterFusionConfig",
"T5AdapterModel",
"T5ModelWithHeads",
+ "UniPELTConfig",
"ViTAdapterModel",
"XLMRobertaAdapterModel",
"XLMRobertaModelWithHeads",
@@ -4602,6 +4603,7 @@
StaticAdapterFusionConfig,
T5AdapterModel,
T5ModelWithHeads,
+ UniPELTConfig,
ViTAdapterModel,
XLMRobertaAdapterModel,
XLMRobertaModelWithHeads,
diff --git a/src/transformers/adapters/__init__.py b/src/transformers/adapters/__init__.py
index a4327fc981..9df65b7ff3 100644
--- a/src/transformers/adapters/__init__.py
+++ b/src/transformers/adapters/__init__.py
@@ -16,7 +16,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-__version__ = "3.1.0a1"
+__version__ = "3.1.0"
from typing import TYPE_CHECKING
@@ -57,6 +57,7 @@
"PfeifferInvConfig",
"PrefixTuningConfig",
"StaticAdapterFusionConfig",
+ "UniPELTConfig",
],
"context": [
"AdapterSetup",
@@ -176,6 +177,7 @@
PfeifferInvConfig,
PrefixTuningConfig,
StaticAdapterFusionConfig,
+ UniPELTConfig,
)
from .context import AdapterSetup, ForwardContext
from .heads import (
diff --git a/src/transformers/adapters/configuration.py b/src/transformers/adapters/configuration.py
index 16013d9350..d1179d28e9 100644
--- a/src/transformers/adapters/configuration.py
+++ b/src/transformers/adapters/configuration.py
@@ -162,6 +162,9 @@ class AdapterConfig(AdapterConfigBase):
            Scaling factor to use for scaled addition of adapter outputs as done by He et al. (2021). Can be either a
constant factor (float) or the string "learned", in which case the scaling factor is learned. Defaults to
1.0.
+ use_gating (:obj:`bool`, optional):
+            Place a trainable gating module beside the added parameter module to control module activation. This is
+            used, e.g., by UniPELT. Defaults to False.
residual_before_ln (:obj:`bool`, optional):
If True, take the residual connection around the adapter bottleneck before the layer normalization. Only
applicable if :obj:`original_ln_before` is True.
@@ -224,6 +227,7 @@ class AdapterConfig(AdapterConfigBase):
init_weights: str = "bert"
is_parallel: bool = False
scaling: Union[float, str] = 1.0
+ use_gating: bool = False
residual_before_ln: bool = True
adapter_residual_before_ln: bool = False
inv_adapter: Optional[str] = None
@@ -362,6 +366,12 @@ class PrefixTuningConfig(AdapterConfigBase):
non_linearity (str): If flat=False, the non-linearity used in the bottleneck MLP.
dropout (float): The dropout rate used in the prefix tuning layer.
leave_out (List[int]): The IDs of the layers (starting at 0) where NO prefix should be added.
+ use_gating (:obj:`bool`, optional):
+            Place a trainable gating module beside the added parameter module to control module activation. This is
+            used, e.g., by UniPELT. Defaults to False.
+        shared_gating (:obj:`bool`, optional):
+            Whether to use a shared gate for the prefixes of all attention matrices. Only applicable if
+            `use_gating=True`. Defaults to True.
"""
architecture: Optional[str] = "prefix_tuning"
@@ -375,6 +385,8 @@ class PrefixTuningConfig(AdapterConfigBase):
bottleneck_size: int = 512
non_linearity: str = "tanh"
dropout: float = 0.0
+ use_gating: bool = False
+ shared_gating: bool = True
@dataclass(eq=False)
@@ -402,6 +414,10 @@ class LoRAConfig(AdapterConfigBase):
(IA)^3). "scale" can only be used together with r=1. Defaults to "add".
init_weights (:obj:`str`, optional): Initialization method for the weights of the LoRA modules.
Currently, this can be either "lora" (default) or "bert".
+ use_gating (:obj:`bool`, optional):
+            Place a trainable gating module beside the added parameter module to control module activation. This is
+            used, e.g., by UniPELT. Defaults to False. Note that modules with `use_gating=True` cannot be merged using
+            `merge_adapter()`.
"""
architecture: Optional[str] = "lora"
@@ -416,6 +432,7 @@ class LoRAConfig(AdapterConfigBase):
attn_matrices: List[str] = field(default_factory=lambda: ["q", "v"])
composition_mode: str = "add"
init_weights: str = "lora"
+ use_gating: bool = False
@dataclass(eq=False)
@@ -427,8 +444,8 @@ class IA3Config(LoRAConfig):
"""
selfattn_lora: bool = True
- intermediate_lora: bool = False
- output_lora: bool = True
+ intermediate_lora: bool = True
+ output_lora: bool = False
r: int = 1
alpha: int = 1
@@ -436,6 +453,7 @@ class IA3Config(LoRAConfig):
attn_matrices: List[str] = field(default_factory=lambda: ["k", "v"])
composition_mode: str = "scale"
init_weights: str = "ia3"
+ use_gating: bool = False
class ConfigUnion(AdapterConfigBase):
@@ -548,6 +566,26 @@ def adapter(self):
return self[1]
+class UniPELTConfig(ConfigUnion):
+ """
+ The UniPELT adapter architecture proposed by Mao et al. (2022). See https://arxiv.org/pdf/2110.07577.pdf.
+ """
+
+ def __init__(
+ self,
+ prefix_tuning: Optional[PrefixTuningConfig] = None,
+ adapter: Optional[AdapterConfig] = None,
+ lora: Optional[LoRAConfig] = None,
+ ):
+ components = [
+ prefix_tuning or PrefixTuningConfig(prefix_length=10),
+ adapter or PfeifferConfig(reduction_factor=16),
+ lora or LoRAConfig(r=8),
+ ]
+
+ super().__init__(*[c.replace(use_gating=True) for c in components])
+
+
ADAPTER_CONFIG_MAP = {
"pfeiffer": PfeifferConfig(),
"houlsby": HoulsbyConfig(),
@@ -562,6 +600,7 @@ def adapter(self):
"lora": LoRAConfig(),
"ia3": IA3Config(),
"mam": MAMConfig(),
+ "unipelt": UniPELTConfig(),
}
DEFAULT_ADAPTER_CONFIG = "pfeiffer"
diff --git a/src/transformers/adapters/context.py b/src/transformers/adapters/context.py
index 6503487ccd..05261351ac 100644
--- a/src/transformers/adapters/context.py
+++ b/src/transformers/adapters/context.py
@@ -78,6 +78,8 @@ class ForwardContext:
# thread-local storage that holds a stack of active contexts
storage = threading.local()
+ context_attributes = ["adapter_gating_scores", "adapter_fusion_attentions"]
+
def __init__(self, model, *args, **kwargs):
# If the model has a method ``forward_context()``, use it to create the context.
if hasattr(model, "forward_context"):
@@ -99,8 +101,21 @@ def wrap(cls, f):
@functools.wraps(f)
def wrapper_func(self, *args, **kwargs):
if self.config.adapters is not None:
- with cls(self, *args, **kwargs):
+ with cls(self, *args, **kwargs) as ctx:
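+                    # drop the adapter output flags (e.g. output_adapter_gating_scores) before calling the wrapped forward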
+ kwargs = {
+ k: v for k, v in kwargs.items() if k.replace("output_", "") not in cls.context_attributes
+ }
results = f(self, *args, **kwargs)
+
+ # append output attributes
+ if isinstance(results, tuple):
+ for attr in cls.context_attributes:
+ if getattr(ctx, "output_" + attr, False):
+ results = results + (dict(getattr(ctx, attr)),)
+ else:
+ for attr in cls.context_attributes:
+ if getattr(ctx, "output_" + attr, False):
+ results[attr] = dict(getattr(ctx, attr))
return results
else:
return f(self, *args, **kwargs)
diff --git a/src/transformers/adapters/heads/base.py b/src/transformers/adapters/heads/base.py
index 25df781f91..dcea9bc984 100644
--- a/src/transformers/adapters/heads/base.py
+++ b/src/transformers/adapters/heads/base.py
@@ -18,7 +18,7 @@
)
from ...utils import ModelOutput
from ..composition import AdapterCompositionBlock, BatchSplit, Parallel, parse_heads_from_composition
-from ..context import AdapterSetup
+from ..context import AdapterSetup, ForwardContext
from ..model_mixin import ModelWithHeadsAdaptersMixin
from ..modeling import Activation_Function_Class
@@ -790,7 +790,7 @@ def _get_head_input(outputs, cls_out, batch):
if all("loss" in out and out["loss"] is not None for out in head_outputs)
else None
)
- return MultiHeadOutput(head_outputs=head_outputs, loss=combined_loss)
+ return_output = MultiHeadOutput(head_outputs=head_outputs, loss=combined_loss)
elif self.has_parallel_adapters or isinstance(self.active_head, Parallel):
if len(self.active_head) != self.config.adapters.active_setup.parallel_channels:
raise ValueError("The number of parallel adapters and the number of active heads must match.")
@@ -807,16 +807,22 @@ def _get_head_input(outputs, cls_out, batch):
if all("loss" in out and out["loss"] is not None for out in head_outputs)
else None
)
- return MultiHeadOutput(head_outputs=head_outputs, loss=combined_loss)
+ return_output = MultiHeadOutput(head_outputs=head_outputs, loss=combined_loss)
elif len(used_heads) > 1:
head_outputs = []
for head in used_heads:
head_module = self.heads[head]
head_outputs.append(head_module(all_outputs, cls_output, attention_mask, return_dict, **kwargs))
- return head_outputs
+ return_output = MultiHeadOutput(head_outputs=head_outputs)
else:
head_module = self.heads[used_heads[0]]
- return head_module(all_outputs, cls_output, attention_mask, return_dict, **kwargs)
+ return_output = head_module(all_outputs, cls_output, attention_mask, return_dict, **kwargs)
+
+ if isinstance(return_output, ModelOutput):
+ for attr in ForwardContext.context_attributes:
+ if attr not in return_output and attr in all_outputs:
+ return_output[attr] = all_outputs[attr]
+ return return_output
def get_labels_dict(self, head_name=None):
"""
diff --git a/src/transformers/adapters/layer.py b/src/transformers/adapters/layer.py
index 4499f71e43..63a3c81013 100644
--- a/src/transformers/adapters/layer.py
+++ b/src/transformers/adapters/layer.py
@@ -1,6 +1,7 @@
from abc import ABC, abstractmethod
from typing import List, Mapping, Union
+import numpy as np
import torch
from torch import nn
@@ -43,6 +44,31 @@ def get_active_setup(self, module_dict):
else:
return None
+ def _store_gating_score(self, adapter_name, gating_score):
+ context = ForwardContext.get_context()
+ if context.output_adapter_gating_scores:
+ gating_cache = context.adapter_gating_scores
+ if self.layer_idx not in gating_cache[adapter_name]:
+ gating_cache[adapter_name][self.layer_idx] = {}
+ gating_score = gating_score.detach().squeeze().cpu().numpy()
+ if len(gating_score.shape) == 0:
+ gating_score = np.expand_dims(gating_score, axis=0)
+ cache_score = gating_cache[adapter_name][self.layer_idx].get(self.location_key, None)
+ if cache_score is not None:
+ gating_cache[adapter_name][self.layer_idx][self.location_key] = np.column_stack(
+ (cache_score, gating_score)
+ )
+ else:
+ gating_cache[adapter_name][self.layer_idx][self.location_key] = gating_score
+
+ def _store_fusion_attentions(self, fusion_name, attentions):
+ context = ForwardContext.get_context()
+ if context.output_adapter_fusion_attentions:
+ attention_cache = context.adapter_fusion_attentions
+ if self.layer_idx not in attention_cache[fusion_name]:
+ attention_cache[fusion_name][self.layer_idx] = {}
+ attention_cache[fusion_name][self.layer_idx][self.location_key] = attentions
+
@abstractmethod
def add_adapter(self, adapter_name: str, layer_idx: int):
raise NotImplementedError()
@@ -202,7 +228,12 @@ def adapter_stack(self, adapter_setup: Stack, hidden_states, input_tensor, layer
elif adapter_stack_layer in self.adapters:
adapter_layer = self.adapters[adapter_stack_layer]
hidden_states, _, residual = adapter_layer.pre_forward(hidden_states, input_tensor, layer_norm)
- hidden_states, _, up = adapter_layer(hidden_states, residual_input=residual)
+ context = ForwardContext.get_context()
+ layer_output = adapter_layer(
+ hidden_states, residual_input=residual, output_gating=context.output_adapter_gating_scores
+ )
+ hidden_states, up = layer_output[0], layer_output[2]
+ self._store_gating_score(adapter_stack_layer, layer_output[-1])
# as this stack might be part of a fusion block, return the adapter up-projection output here
# together with the final output (with potential residuals & norms) if we reached the last block of the stack
if i == len(adapter_setup) - 1:
@@ -217,6 +248,8 @@ def adapter_fusion(self, adapter_setup: Fuse, hidden_states, input_tensor, layer
"""
Performs adapter fusion with the given adapters for the given input.
"""
+ context = ForwardContext.get_context()
+
# config of _last_ fused adapter is significant
fusion_config = self.config.adapters.get_fusion(adapter_setup.name)
last_adapter = self.adapters[adapter_setup.last()]
@@ -235,7 +268,11 @@ def adapter_fusion(self, adapter_setup: Fuse, hidden_states, input_tensor, layer
# Case 2: We have a single adapter which is part of this module -> forward pass
elif adapter_block in self.adapters:
adapter_layer = self.adapters[adapter_block]
- _, _, up = adapter_layer(hidden_states, residual_input=residual)
+ layer_output = adapter_layer(
+ hidden_states, residual_input=residual, output_gating=context.output_adapter_gating_scores
+ )
+ up = layer_output[2]
+ self._store_gating_score(adapter_block, layer_output[-1])
up_list.append(up)
# Case 3: nesting other composition blocks is invalid
elif isinstance(adapter_block, AdapterCompositionBlock):
@@ -250,12 +287,18 @@ def adapter_fusion(self, adapter_setup: Fuse, hidden_states, input_tensor, layer
up_list = torch.stack(up_list)
up_list = up_list.permute(1, 2, 0, 3)
- hidden_states = self.adapter_fusion_layer[adapter_setup.name](
+ fusion_output = self.adapter_fusion_layer[adapter_setup.name](
query,
up_list,
up_list,
residual,
+ output_attentions=context.output_adapter_fusion_attentions,
)
+ if context.output_adapter_fusion_attentions:
+ hidden_states = fusion_output[0]
+ self._store_fusion_attentions(adapter_setup.name, fusion_output[-1])
+ else:
+ hidden_states = fusion_output
return hidden_states
@@ -300,7 +343,14 @@ def adapter_split(self, adapter_setup: Split, hidden_states, input_tensor, layer
# Case 4: We have a single adapter which is part of this module -> forward pass
elif adapter_block in self.adapters:
adapter_layer = self.adapters[adapter_block]
- split_hidden_states[i], _, _ = adapter_layer(split_hidden_states[i], residual_input=split_residual[i])
+ context = ForwardContext.get_context()
+ layer_output = adapter_layer(
+ split_hidden_states[i],
+ residual_input=split_residual[i],
+ output_gating=context.output_adapter_gating_scores,
+ )
+ split_hidden_states[i] = layer_output[0]
+ self._store_gating_score(adapter_block, layer_output[-1])
# Case 5: nesting other composition blocks is invalid
elif isinstance(adapter_block, AdapterCompositionBlock):
raise ValueError(
@@ -365,10 +415,14 @@ def adapter_parallel(self, adapter_setup: Parallel, hidden_states, input_tensor,
# Case 3: We have a single adapter which is part of this module -> forward pass
elif child in self.adapters:
adapter_layer = self.adapters[child]
- child_hidden_states, _, _ = adapter_layer(
+ context = ForwardContext.get_context()
+ layer_output = adapter_layer(
hidden_states[i * orig_batch_size : (i + 1) * orig_batch_size],
residual_input=residual[i * orig_batch_size : (i + 1) * orig_batch_size],
+ output_gating=context.output_adapter_gating_scores,
)
+ child_hidden_states = layer_output[0]
+ self._store_gating_score(child, layer_output[-1])
children_hidden.append(child_hidden_states)
# Case 4: nesting other composition blocks is invalid
elif isinstance(child, AdapterCompositionBlock):
@@ -436,10 +490,14 @@ def adapter_batchsplit(self, adapter_setup: BatchSplit, hidden_states, input_ten
elif adapter_block in self.adapters:
adapter_layer = self.adapters[adapter_block]
- child, _, _ = adapter_layer(
- hidden_states[batch_idx[0] : batch_idx[1]], residual_input=residual[batch_idx[0] : batch_idx[1]]
+ context = ForwardContext.get_context()
+ layer_output = adapter_layer(
+ hidden_states[batch_idx[0] : batch_idx[1]],
+ residual_input=residual[batch_idx[0] : batch_idx[1]],
+ output_gating=context.output_adapter_gating_scores,
)
- children_hidden.append(child)
+ children_hidden.append(layer_output[0])
+ self._store_gating_score(adapter_block, layer_output[-1])
# Case 5: nesting other composition blocks is invalid
elif isinstance(adapter_block, AdapterCompositionBlock):
raise ValueError(
diff --git a/src/transformers/adapters/loading.py b/src/transformers/adapters/loading.py
index a66f98b8a0..0eaf58dae7 100644
--- a/src/transformers/adapters/loading.py
+++ b/src/transformers/adapters/loading.py
@@ -305,6 +305,7 @@ def filter_func(self, adapter_name):
lambda x: "_adapters.{}.".format(adapter_name) in x
or ".adapters.{}.".format(adapter_name) in x
or ".prefix_tunings.{}.".format(adapter_name) in x
+ or ".prefix_gates.{}.".format(adapter_name) in x
or ".loras.{}.".format(adapter_name) in x
)
@@ -348,6 +349,7 @@ def rename_func(self, old_name, new_name):
lambda k: self._rename_legacy_weights(k)
.replace("adapters.{}.".format(old_name), "adapters.{}.".format(new_name))
.replace(".prefix_tunings.{}.".format(old_name), ".prefix_tunings.{}.".format(new_name))
+ .replace(".prefix_gates.{}.".format(old_name), ".prefix_gates.{}.".format(new_name))
.replace(".loras.{}.".format(old_name), ".loras.{}.".format(new_name))
)
diff --git a/src/transformers/adapters/lora.py b/src/transformers/adapters/lora.py
index 587a57b332..657cde095b 100644
--- a/src/transformers/adapters/lora.py
+++ b/src/transformers/adapters/lora.py
@@ -22,12 +22,14 @@ def __init__(
lora_A_shape,
lora_B_shape,
config: LoRAConfig,
+ gating_heads: int = 1,
):
super().__init__()
self.r = config.r
self.lora_alpha = config.alpha
self.composition_mode = config.composition_mode
self.attn_matrices = config.attn_matrices
+ self.use_gating = config.use_gating
# Optional dropout
if config.dropout > 0.0:
self.lora_dropout = nn.Dropout(p=config.dropout)
@@ -43,28 +45,39 @@ def __init__(
self.lora_B = nn.Parameter(torch.zeros(lora_B_shape))
self.scaling = self.lora_alpha / self.r
+ if self.use_gating:
+ self.gate = nn.Linear(lora_A_shape[-1], gating_heads)
+
if config.init_weights == "lora":
# initialize A the same way as the default for nn.Linear and B to zero
if self.composition_mode == "add":
nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
nn.init.zeros_(self.lora_B)
+ if self.use_gating:
+ nn.init.normal_(self.gate.weight, std=0.02)
elif config.init_weights == "bert":
if self.composition_mode == "add":
nn.init.normal_(self.lora_A, std=0.02)
nn.init.normal_(self.lora_B, std=0.02)
+ if self.use_gating:
+ nn.init.normal_(self.gate.weight, std=0.02)
elif config.init_weights == "ia3":
if self.composition_mode == "add":
nn.init.ones_(self.lora_A)
nn.init.ones_(self.lora_B)
+ if self.use_gating:
+ nn.init.normal_(self.gate.weight, std=0.02)
else:
raise ValueError("Unknown init_weights type: {}".format(config.init_weights))
- def com(self, weights: torch.Tensor, added: torch.Tensor) -> torch.Tensor:
+ def com(self, weights: torch.Tensor, added: torch.Tensor, scaling=None) -> torch.Tensor:
"""Performs the composition operation between existing and injected weights."""
+ if scaling is None:
+ scaling = self.scaling
if self.composition_mode == "add":
- return weights + added * self.scaling
+ return weights + added * scaling
elif self.composition_mode == "scale":
- return weights * (added * self.scaling)
+ return weights * (added * scaling)
else:
raise ValueError("Invalid composition mode.")
@@ -105,7 +118,11 @@ def add_adapter(self, adapter_name: str, layer_idx: int) -> bool:
location_key=self.location_key,
)
if lora_config is not None and self._check_lora_location(lora_config):
- lora = LoRA(*self._get_lora_shapes(lora_config), lora_config)
+ lora = LoRA(
+ *self._get_lora_shapes(lora_config),
+ lora_config,
+ gating_heads=self.get_n_heads(lora_config),
+ )
lora.train(self.training)
self.loras[adapter_name] = lora
return True
@@ -176,7 +193,7 @@ def T(w):
self.weight.data = lora.com_inv(self.weight.data, delta_w)
self.merged = None
- def _compute_adapted_weight(self, lora):
+ def _compute_adapted_weight(self, lora, scaling=None):
def T(w):
return torch.t(w) if self.fan_in_fan_out else w
@@ -187,7 +204,7 @@ def T(w):
delta_w = T(lora.lora_B)
else:
delta_w = T(lora.lora_B @ lora.lora_A)
- weight = lora.com(weight, delta_w)
+ weight = lora.com(weight, delta_w, scaling=scaling)
return weight
@@ -197,6 +214,8 @@ def merge_adapter(self, name: str):
return # already merged
elif not self.merged:
lora = self.loras[name]
+ if lora.use_gating:
+ raise ValueError("Cannot merge LoRA layer with gating.")
self.weight.data = self._compute_adapted_weight(lora)
self.merged = name
elif self.merged != name:
@@ -204,7 +223,7 @@ def merge_adapter(self, name: str):
def forward(self, x: torch.Tensor):
def T(w):
- return torch.t(w) if self.fan_in_fan_out else w
+ return torch.transpose(w, -2, -1) if self.fan_in_fan_out else w
if not self.merged:
adapter_setup = self.get_active_setup(self.loras)
@@ -218,7 +237,13 @@ def T(w):
delta_w = lora.lora_B.view(1, 1, -1)
else:
delta_w = lora.lora_dropout(x) @ torch.t(lora.lora_A) @ torch.t(lora.lora_B)
- result = lora.com(result, delta_w)
+ if lora.use_gating:
+ gate = torch.sigmoid(lora.gate(x))
+ gate = torch.mean(gate, dim=1).unsqueeze(-1)
+ self._store_gating_score(adapter_setup[0], gate)
+ else:
+ gate = None
+ result = lora.com(result, delta_w, scaling=gate)
return result
else:
raise ValueError(f"Invalid adapter setup. Cannot use {adapter_setup} with LoRA.")
@@ -307,7 +332,7 @@ def T(w):
self.weight.data = lora.com_inv(self.weight.data, T(self.pad(delta_w, lora)))
self.merged = None
- def _compute_adapted_weight(self, lora):
+ def _compute_adapted_weight(self, name, lora):
def T(w):
return w if self.fan_in_fan_out else torch.t(w)
@@ -331,7 +356,9 @@ def merge_adapter(self, name: str):
return # already merged
elif not self.merged:
lora = self.loras[name]
- self.weight.data = self._compute_adapted_weight(lora)
+ if lora.use_gating:
+ raise ValueError("Cannot merge LoRA layer with gating.")
+ self.weight.data = self._compute_adapted_weight(name, lora)
self.merged = name
elif self.merged != name:
raise ValueError("LoRALayer already has a merged LoRA module. Please reset it first.")
@@ -355,8 +382,17 @@ def T(w):
after_A.transpose(-2, -1), lora.lora_B.unsqueeze(-1), groups=sum(lora.enable_lora)
).transpose(-2, -1)
delta_w = after_B
-                    # result shape: <batch_size> x <seq_len> x <head_dim * 3>
- result = lora.com(result, self.pad(delta_w, lora))
+ if lora.use_gating:
+ gate = torch.sigmoid(lora.gate(x))
+ gate = torch.mean(gate, dim=1)
+ self._store_gating_score(adapter_setup[0], gate)
+ gate = self.pad(
+ gate.repeat_interleave(self.out_features // 3, dim=-1), lora, fill_value=1
+ ).unsqueeze(1)
+ else:
+ gate = None
+ # result = (batch_size, seq_len, head_dim * 3)
+ result = lora.com(result, self.pad(delta_w, lora), scaling=gate)
return result
else:
raise ValueError(f"Invalid adapter setup. Cannot use {adapter_setup} with LoRA.")
diff --git a/src/transformers/adapters/model_mixin.py b/src/transformers/adapters/model_mixin.py
index 260b679f8e..95b07ac849 100644
--- a/src/transformers/adapters/model_mixin.py
+++ b/src/transformers/adapters/model_mixin.py
@@ -785,6 +785,11 @@ def forward_context(self, context: ForwardContext, *args, **kwargs):
}
context.prefix_states = self.base_model.prefix_tuning(*args, **kwargs)
+ # Adapter gating and attention outputs
+ context.output_adapter_gating_scores = kwargs.get("output_adapter_gating_scores", False)
+ context.output_adapter_fusion_attentions = kwargs.get("output_adapter_fusion_attentions", False)
+ context.adapter_gating_scores = defaultdict(dict)
+ context.adapter_fusion_attentions = defaultdict(dict)
def get_fusion_regularization_loss(self):
reg_loss = 0.0
diff --git a/src/transformers/adapters/modeling.py b/src/transformers/adapters/modeling.py
index 11a3bf1aaa..2d6ac3b38d 100644
--- a/src/transformers/adapters/modeling.py
+++ b/src/transformers/adapters/modeling.py
@@ -46,6 +46,7 @@ def __init__(
self.add_layer_norm_before = config["ln_before"]
self.add_layer_norm_after = config["ln_after"]
self.adapter_residual_before_ln = config["adapter_residual_before_ln"]
+ self.use_gating = config["use_gating"]
# Params related to input & output of adapter
self.residual_before_ln = config["residual_before_ln"]
@@ -104,16 +105,23 @@ def __init__(
if self.add_layer_norm_after:
self.adapter_norm_after = nn.LayerNorm(self.input_size)
+ if self.use_gating:
+ self.gate = nn.Linear(self.input_size, 1)
+
# if we want to initialize with the bert strategy then this function is called for all the linear layers
if config["init_weights"] == "bert":
self.adapter_down.apply(self.init_bert_weights)
self.adapter_up.apply(self.init_bert_weights)
+ if self.use_gating:
+ self.gate.apply(self.init_bert_weights)
elif config["init_weights"] == "mam_adapter":
with torch.no_grad():
nn.init.kaiming_uniform_(self.adapter_down[0].weight, a=math.sqrt(5))
nn.init.zeros_(self.adapter_up.weight)
nn.init.zeros_(self.adapter_down[0].bias)
nn.init.zeros_(self.adapter_up.bias)
+ if self.use_gating:
+ self.gate.apply(self.init_bert_weights)
else:
raise ValueError("Unknown init_weights type: {}".format(config["init_weights"]))
@@ -157,13 +165,19 @@ def pre_forward(
return hidden_states, query, residual
- def forward(self, x, residual_input): # , residual_input=None):
+ def forward(self, x, residual_input, output_gating=False):
down = self.adapter_down(x)
up = self.adapter_up(down)
up = up * self.scaling
output = up
+ if self.use_gating:
+ # x.shape = (batch_size, seq_len, hidden_size)
+ gate = torch.sigmoid(self.gate(x))
+ gate = torch.mean(gate, dim=1).unsqueeze(-1)
+ output = output * gate
+
# apply residual connection before layer norm if configured in this way
if self.adapter_residual_before_ln:
output = output + residual_input
@@ -176,6 +190,8 @@ def forward(self, x, residual_input): # , residual_input=None):
if not self.adapter_residual_before_ln:
output = output + residual_input
+ if self.use_gating and output_gating:
+ return output, down, up, gate
return output, down, up
def post_forward(self, hidden_states, input_hidden_states, input_tensor, layer_norm):
@@ -246,7 +262,7 @@ def pre_forward(
query = input_tensor
return input_tensor, query, input_tensor
- def forward(self, x, residual_input):
+ def forward(self, x, residual_input, output_gating=False):
down = self.adapter_down(x)
up = self.adapter_up(down)
@@ -254,10 +270,18 @@ def forward(self, x, residual_input):
output = up
+ if self.use_gating:
+ # x.shape = (batch_size, seq_len, hidden_size)
+ gate = torch.sigmoid(self.gate(x))
+ gate = torch.mean(gate, dim=1).unsqueeze(-1)
+ output = output * gate
+
# apply layer norm if available
if self.add_layer_norm_after:
output = self.adapter_norm_after(output)
+ if self.use_gating and output_gating:
+ return output, down, up, gate
return output, down, up
def post_forward(self, hidden_states, input_hidden_states, input_tensor, layer_norm):
@@ -332,7 +356,7 @@ def __init__(
self.T = 1.0
self.reduction = self.T / 1000.0
- def forward(self, query, key, value, residual):
+ def forward(self, query, key, value, residual, output_attentions: bool = False):
if self.config["residual_before"]:
value += residual[:, :, None, :].repeat(1, 1, value.size(2), 1)
@@ -362,9 +386,6 @@ def forward(self, query, key, value, residual):
attention_probs = nn.Softmax(dim=-1)(attention_scores / self.T)
self.T = max(self.T - self.reduction, 1.0)
- if not self.training:
- self.recent_attention = attention_probs.detach().cpu().numpy()
-
context_layer = torch.squeeze(torch.matmul(attention_probs.unsqueeze(2), value_layer), dim=2)
if self.config["value"] and not self.config["value_before_softmax"]:
@@ -376,7 +397,11 @@ def forward(self, query, key, value, residual):
if not self.config["residual_before"]:
context_layer += residual
- return context_layer
+ if output_attentions:
+ attention_probs = attention_probs.detach().cpu().numpy()
+ return context_layer, attention_probs
+ else:
+ return context_layer
# Invertible Adapters
diff --git a/src/transformers/adapters/models/bart/adapter_model.py b/src/transformers/adapters/models/bart/adapter_model.py
index 596adb2bf6..a6ce6c827c 100644
--- a/src/transformers/adapters/models/bart/adapter_model.py
+++ b/src/transformers/adapters/models/bart/adapter_model.py
@@ -57,6 +57,8 @@ def forward(
return_dict=None,
past_key_values=None,
head=None,
+ output_adapter_gating_scores=False,
+ output_adapter_fusion_attentions=False,
**kwargs
):
r"""
@@ -85,6 +87,8 @@ def forward(
output_hidden_states=output_hidden_states,
return_dict=return_dict,
past_key_values=past_key_values,
+ output_adapter_gating_scores=output_adapter_gating_scores,
+ output_adapter_fusion_attentions=output_adapter_fusion_attentions,
)
# sequence classification based on last token in sequence
x = outputs[0] # last hidden state
diff --git a/src/transformers/adapters/models/bert/adapter_model.py b/src/transformers/adapters/models/bert/adapter_model.py
index 5672019105..0f6d921435 100644
--- a/src/transformers/adapters/models/bert/adapter_model.py
+++ b/src/transformers/adapters/models/bert/adapter_model.py
@@ -44,6 +44,8 @@ def forward(
output_hidden_states=None,
return_dict=None,
head=None,
+ output_adapter_gating_scores=False,
+ output_adapter_fusion_attentions=False,
**kwargs
):
input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
@@ -68,6 +70,8 @@ def forward(
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
+ output_adapter_gating_scores=output_adapter_gating_scores,
+ output_adapter_fusion_attentions=output_adapter_fusion_attentions,
)
# BERT & RoBERTa return the pooled output as second item, we don't need that in these heads
if not return_dict:
diff --git a/src/transformers/adapters/models/deberta/adapter_model.py b/src/transformers/adapters/models/deberta/adapter_model.py
index 310bad4dec..49e791be39 100644
--- a/src/transformers/adapters/models/deberta/adapter_model.py
+++ b/src/transformers/adapters/models/deberta/adapter_model.py
@@ -38,6 +38,8 @@ def forward(
output_hidden_states=None,
return_dict=None,
head=None,
+ output_adapter_gating_scores=False,
+ output_adapter_fusion_attentions=False,
**kwargs
):
input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
@@ -61,6 +63,8 @@ def forward(
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
+ output_adapter_gating_scores=output_adapter_gating_scores,
+ output_adapter_fusion_attentions=output_adapter_fusion_attentions,
)
# BERT & RoBERTa return the pooled output as second item, we don't need that in these heads
if not return_dict:
diff --git a/src/transformers/adapters/models/debertaV2/adapter_model.py b/src/transformers/adapters/models/debertaV2/adapter_model.py
index bca062036b..6c19c44a69 100644
--- a/src/transformers/adapters/models/debertaV2/adapter_model.py
+++ b/src/transformers/adapters/models/debertaV2/adapter_model.py
@@ -41,6 +41,8 @@ def forward(
output_hidden_states=None,
return_dict=None,
head=None,
+ output_adapter_gating_scores=False,
+ output_adapter_fusion_attentions=False,
**kwargs
):
input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
@@ -64,6 +66,8 @@ def forward(
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
+ output_adapter_gating_scores=output_adapter_gating_scores,
+ output_adapter_fusion_attentions=output_adapter_fusion_attentions,
)
# BERT & RoBERTa return the pooled output as second item, we don't need that in these heads
if not return_dict:
diff --git a/src/transformers/adapters/models/distilbert/adapter_model.py b/src/transformers/adapters/models/distilbert/adapter_model.py
index 5c47f8ba26..4f8c9fa7be 100644
--- a/src/transformers/adapters/models/distilbert/adapter_model.py
+++ b/src/transformers/adapters/models/distilbert/adapter_model.py
@@ -70,6 +70,8 @@ def forward(
output_hidden_states=None,
return_dict=None,
head=None,
+ output_adapter_gating_scores=False,
+ output_adapter_fusion_attentions=False,
**kwargs
):
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
@@ -90,6 +92,8 @@ def forward(
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
+ output_adapter_gating_scores=output_adapter_gating_scores,
+ output_adapter_fusion_attentions=output_adapter_fusion_attentions,
)
outputs = self.forward_head(
diff --git a/src/transformers/adapters/models/gpt2/adapter_model.py b/src/transformers/adapters/models/gpt2/adapter_model.py
index 55a275d7eb..7cd9680a37 100644
--- a/src/transformers/adapters/models/gpt2/adapter_model.py
+++ b/src/transformers/adapters/models/gpt2/adapter_model.py
@@ -59,6 +59,8 @@ def forward(
output_hidden_states=None,
return_dict=None,
head=None,
+ output_adapter_gating_scores=False,
+ output_adapter_fusion_attentions=False,
**kwargs
):
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
@@ -77,6 +79,8 @@ def forward(
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
+ output_adapter_gating_scores=output_adapter_gating_scores,
+ output_adapter_fusion_attentions=output_adapter_fusion_attentions,
)
batch_size = outputs[0].shape[0]
diff --git a/src/transformers/adapters/models/mbart/adapter_model.py b/src/transformers/adapters/models/mbart/adapter_model.py
index d513a5a209..ca1a2f3fb0 100644
--- a/src/transformers/adapters/models/mbart/adapter_model.py
+++ b/src/transformers/adapters/models/mbart/adapter_model.py
@@ -57,6 +57,8 @@ def forward(
return_dict=None,
past_key_values=None,
head=None,
+ output_adapter_gating_scores=False,
+ output_adapter_fusion_attentions=False,
**kwargs
):
r"""
@@ -85,6 +87,8 @@ def forward(
output_hidden_states=output_hidden_states,
return_dict=return_dict,
past_key_values=past_key_values,
+ output_adapter_gating_scores=output_adapter_gating_scores,
+ output_adapter_fusion_attentions=output_adapter_fusion_attentions,
)
# sequence classification based on last token in sequence
x = outputs[0] # last hidden state
diff --git a/src/transformers/adapters/models/roberta/adapter_model.py b/src/transformers/adapters/models/roberta/adapter_model.py
index b8bfff31df..70ef643786 100644
--- a/src/transformers/adapters/models/roberta/adapter_model.py
+++ b/src/transformers/adapters/models/roberta/adapter_model.py
@@ -49,6 +49,8 @@ def forward(
output_hidden_states=None,
return_dict=None,
head=None,
+ output_adapter_gating_scores=False,
+ output_adapter_fusion_attentions=False,
**kwargs
):
input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
@@ -73,6 +75,8 @@ def forward(
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
+ output_adapter_gating_scores=output_adapter_gating_scores,
+ output_adapter_fusion_attentions=output_adapter_fusion_attentions,
)
# BERT & RoBERTa return the pooled output as second item, we don't need that in these heads
if not return_dict:
diff --git a/src/transformers/adapters/models/t5/adapter_model.py b/src/transformers/adapters/models/t5/adapter_model.py
index 80ef3741b6..c47eb36ad9 100644
--- a/src/transformers/adapters/models/t5/adapter_model.py
+++ b/src/transformers/adapters/models/t5/adapter_model.py
@@ -53,6 +53,8 @@ def forward(
output_hidden_states=None,
return_dict=None,
head=None,
+ output_adapter_gating_scores=False,
+ output_adapter_fusion_attentions=False,
**kwargs
):
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
@@ -76,6 +78,8 @@ def forward(
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
+ output_adapter_gating_scores=output_adapter_gating_scores,
+ output_adapter_fusion_attentions=output_adapter_fusion_attentions,
)
sequence_output = model_output[0]
# ToDo move head to device for parallel forward pass
diff --git a/src/transformers/adapters/models/vit/adapter_model.py b/src/transformers/adapters/models/vit/adapter_model.py
index d2510909d6..6a5692dc16 100644
--- a/src/transformers/adapters/models/vit/adapter_model.py
+++ b/src/transformers/adapters/models/vit/adapter_model.py
@@ -33,6 +33,8 @@ def forward(
interpolate_pos_encoding: Optional[bool] = None,
return_dict: Optional[bool] = None,
head=None,
+ output_adapter_gating_scores=False,
+ output_adapter_fusion_attentions=False,
**kwargs,
):
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
@@ -44,6 +46,8 @@ def forward(
output_hidden_states=output_hidden_states,
interpolate_pos_encoding=interpolate_pos_encoding,
return_dict=return_dict,
+ output_adapter_gating_scores=output_adapter_gating_scores,
+ output_adapter_fusion_attentions=output_adapter_fusion_attentions,
)
# BERT & RoBERTa return the pooled output as second item, we don't need that in these heads
diff --git a/src/transformers/adapters/prefix_tuning.py b/src/transformers/adapters/prefix_tuning.py
index b98f938741..0dceffd697 100644
--- a/src/transformers/adapters/prefix_tuning.py
+++ b/src/transformers/adapters/prefix_tuning.py
@@ -237,6 +237,7 @@ def __init__(self, location_key: str, config):
self.config = config
self.location_key = location_key
self.prefixes = {}
+ self.prefix_gates = nn.ModuleDict()
def set_pool(self, pool: PrefixTuningPool):
self.__setattr__("pool", pool)
@@ -258,10 +259,18 @@ def add_adapter(self, adapter_name: str, layer_idx: int):
prefix_id = self.pool.indicate_prefix(adapter_name, self.location_key)
self.prefixes[adapter_name] = prefix_id
+ if prefix_tuning_config.use_gating:
+ gate_outputs = 1 if prefix_tuning_config.shared_gating else 2
+ gate = nn.Linear(self.config.hidden_size, gate_outputs)
+ gate.weight.data.normal_(mean=0.0, std=0.02)
+ self.prefix_gates[adapter_name] = gate
+
def delete_adapter(self, adapter_name: str):
self.pool.delete_prefix(adapter_name)
if adapter_name in self.prefixes:
del self.prefixes[adapter_name]
+ if adapter_name in self.prefix_gates:
+ del self.prefix_gates[adapter_name]
def add_fusion_layer(self, adapter_names: Union[List, str]):
pass # not applicable to prefix tuning
@@ -273,17 +282,25 @@ def enable_adapters(self, adapter_setup: AdapterCompositionBlock, unfreeze_adapt
if unfreeze_adapters:
for prefix_tuning_name in adapter_setup.flatten():
self.pool.enable_prefix(prefix_tuning_name)
+ if prefix_tuning_name in self.prefix_gates:
+ for param in self.prefix_gates[prefix_tuning_name].parameters():
+ param.requires_grad = unfreeze_adapters
def get_adapter(self, adapter_name):
+ return_dict = nn.ModuleDict()
# Make sure to only return params once
if adapter_name in self.prefixes and self.prefixes[adapter_name] == 0:
prefix_module = self.pool.get_prefix(adapter_name)
if prefix_module is not None:
- return prefix_module[self.location_key]
+ return_dict["prefix"] = prefix_module[self.location_key]
+ if adapter_name in self.prefix_gates:
+ return_dict["gate"] = self.prefix_gates[adapter_name]
+ if len(return_dict) > 0:
+ return return_dict
return None
- def forward(self, key_states, value_states, attention_mask=None, invert_mask=True):
+ def forward(self, key_states, value_states, residual_input, attention_mask=None, invert_mask=True):
adapter_setup = self.get_active_setup(self.prefixes)
if adapter_setup is not None:
if len(adapter_setup) == 1:
@@ -295,10 +312,20 @@ def forward(self, key_states, value_states, attention_mask=None, invert_mask=Tru
# Retrieve pre-computed prefix states from context
context = ForwardContext.get_context()
+ # batch_size x n_heads x prefix_length x n_embd_per_head
prefix_keys, prefix_values = context.prefix_states[prefix_tuning_name][self.location_key][
prefix_id
]
+ if prefix_tuning_name in self.prefix_gates:
+ gate = self.prefix_gates[prefix_tuning_name]
+ gate_output = torch.mean(torch.sigmoid(gate(residual_input)), dim=1)
+ self._store_gating_score(prefix_tuning_name, gate_output)
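+                    # with shared_gating=True a single gate output scales both keys and values; otherwise the two outputs gate them separately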
+ gate_output_key = gate_output[:, 0].view(-1, 1, 1, 1)
+ gate_output_value = gate_output[:, -1].view(-1, 1, 1, 1)
+ key_states = key_states * gate_output_key
+ value_states = value_states * gate_output_value
+
key_states = torch.cat([prefix_keys, key_states], dim=2)
value_states = torch.cat([prefix_values, value_states], dim=2)
if attention_mask is not None:
diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py
index 3a9757ec6c..518978b69d 100755
--- a/src/transformers/models/bart/modeling_bart.py
+++ b/src/transformers/models/bart/modeling_bart.py
@@ -237,7 +237,9 @@ def forward(
proj_shape = (bsz * self.num_heads, -1, self.head_dim)
query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
- key_states, value_states, attention_mask = self.prefix_tuning(key_states, value_states, attention_mask)
+ key_states, value_states, attention_mask = self.prefix_tuning(
+ key_states, value_states, hidden_states, attention_mask
+ )
key_states = key_states.view(*proj_shape)
value_states = value_states.view(*proj_shape)
diff --git a/src/transformers/models/bert/modeling_bert.py b/src/transformers/models/bert/modeling_bert.py
index 1c367aad01..dd17f489b4 100644
--- a/src/transformers/models/bert/modeling_bert.py
+++ b/src/transformers/models/bert/modeling_bert.py
@@ -335,7 +335,9 @@ def forward(
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_layer, value_layer)
- key_layer, value_layer, attention_mask = self.prefix_tuning(key_layer, value_layer, attention_mask)
+ key_layer, value_layer, attention_mask = self.prefix_tuning(
+ key_layer, value_layer, hidden_states, attention_mask
+ )
# Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
diff --git a/src/transformers/models/deberta/modeling_deberta.py b/src/transformers/models/deberta/modeling_deberta.py
index e753dd85b7..cea1b033c8 100644
--- a/src/transformers/models/deberta/modeling_deberta.py
+++ b/src/transformers/models/deberta/modeling_deberta.py
@@ -655,7 +655,9 @@ def linear(w, b, x):
query_layer = query_layer + self.transpose_for_scores(self.q_bias[None, None, :])
value_layer = value_layer + self.transpose_for_scores(self.v_bias[None, None, :])
- key_layer, value_layer, attention_mask = self.prefix_tuning(key_layer, value_layer, attention_mask, False)
+ key_layer, value_layer, attention_mask = self.prefix_tuning(
+ key_layer, value_layer, hidden_states, attention_mask, False
+ )
rel_att = None
# Take the dot product between "query" and "key" to get the raw attention scores.
diff --git a/src/transformers/models/deberta_v2/modeling_deberta_v2.py b/src/transformers/models/deberta_v2/modeling_deberta_v2.py
index 071ef71f06..87ea225318 100644
--- a/src/transformers/models/deberta_v2/modeling_deberta_v2.py
+++ b/src/transformers/models/deberta_v2/modeling_deberta_v2.py
@@ -714,7 +714,7 @@ def forward(
value_layer = self.transpose_for_scores_extended(self.value_proj(hidden_states), self.num_attention_heads)
key_layer, value_layer, attention_mask = self.prefix_tuning(
- key_layer, value_layer, attention_mask, False
+ key_layer, value_layer, hidden_states, attention_mask, False
) # [:, 0, :, 0])
key_layer = key_layer.contiguous().view(-1, key_layer.size(2), key_layer.size(-1))
diff --git a/src/transformers/models/distilbert/modeling_distilbert.py b/src/transformers/models/distilbert/modeling_distilbert.py
index 231778fa13..ae2d9ec7b1 100755
--- a/src/transformers/models/distilbert/modeling_distilbert.py
+++ b/src/transformers/models/distilbert/modeling_distilbert.py
@@ -216,7 +216,7 @@ def unshape(x: torch.Tensor) -> torch.Tensor:
k = shape(self.k_lin(key)) # (bs, n_heads, k_length, dim_per_head)
v = shape(self.v_lin(value)) # (bs, n_heads, k_length, dim_per_head)
- k, v, mask = self.prefix_tuning(k, v, mask, invert_mask=False)
+ k, v, mask = self.prefix_tuning(k, v, value, mask, invert_mask=False)
mask_reshp = (bs, 1, 1, k.size(2))
diff --git a/src/transformers/models/gpt2/modeling_gpt2.py b/src/transformers/models/gpt2/modeling_gpt2.py
index 6dfdf0f088..ae3669fe4f 100644
--- a/src/transformers/models/gpt2/modeling_gpt2.py
+++ b/src/transformers/models/gpt2/modeling_gpt2.py
@@ -355,7 +355,7 @@ def forward(
else:
present = None
- key, value, attention_mask = self.prefix_tuning(key, value, attention_mask)
+ key, value, attention_mask = self.prefix_tuning(key, value, hidden_states, attention_mask)
if self.reorder_and_upcast_attn:
attn_output, attn_weights = self._upcast_and_reordered_attn(query, key, value, attention_mask, head_mask)
diff --git a/src/transformers/models/mbart/modeling_mbart.py b/src/transformers/models/mbart/modeling_mbart.py
index 44dc64a35c..e817ad073f 100755
--- a/src/transformers/models/mbart/modeling_mbart.py
+++ b/src/transformers/models/mbart/modeling_mbart.py
@@ -244,7 +244,9 @@ def forward(
proj_shape = (bsz * self.num_heads, -1, self.head_dim)
query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
- key_states, value_states, attention_mask = self.prefix_tuning(key_states, value_states, attention_mask)
+ key_states, value_states, attention_mask = self.prefix_tuning(
+ key_states, value_states, hidden_states, attention_mask
+ )
key_states = key_states.view(*proj_shape)
value_states = value_states.view(*proj_shape)
diff --git a/src/transformers/models/roberta/modeling_roberta.py b/src/transformers/models/roberta/modeling_roberta.py
index 8d35c27bb3..4cb40cbe2b 100644
--- a/src/transformers/models/roberta/modeling_roberta.py
+++ b/src/transformers/models/roberta/modeling_roberta.py
@@ -246,7 +246,9 @@ def forward(
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_layer, value_layer)
- key_layer, value_layer, attention_mask = self.prefix_tuning(key_layer, value_layer, attention_mask)
+ key_layer, value_layer, attention_mask = self.prefix_tuning(
+ key_layer, value_layer, hidden_states, attention_mask
+ )
# Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py
index 2039a03708..bf5e995557 100644
--- a/src/transformers/models/t5/modeling_t5.py
+++ b/src/transformers/models/t5/modeling_t5.py
@@ -309,7 +309,7 @@ def forward(self, hidden_states):
class T5DenseGatedActDense(nn.Module):
def __init__(self, config: T5Config):
super().__init__()
- self.wi_0 = LoRALinear(config.d_model, config.d_ff, "intermediate", config, bias=False)
+ self.wi_0 = nn.Linear(config.d_model, config.d_ff, bias=False)
self.wi_1 = LoRALinear(config.d_model, config.d_ff, "intermediate", config, bias=False)
self.wo = LoRALinear(config.d_ff, config.d_model, "output", config, bias=False)
self.dropout = nn.Dropout(config.dropout_rate)
@@ -523,7 +523,7 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value):
present_key_value_state = (key_states, value_states) if (self.is_decoder and use_cache) else None
- key_states, value_states, mask = self.prefix_tuning(key_states, value_states, mask)
+ key_states, value_states, mask = self.prefix_tuning(key_states, value_states, hidden_states, mask)
key_length = key_states.size(2)
# compute scores
diff --git a/src/transformers/models/vit/modeling_vit.py b/src/transformers/models/vit/modeling_vit.py
index e08378de3c..d291fe9c37 100644
--- a/src/transformers/models/vit/modeling_vit.py
+++ b/src/transformers/models/vit/modeling_vit.py
@@ -223,7 +223,7 @@ def forward(
value_layer = self.transpose_for_scores(self.value(hidden_states))
query_layer = self.transpose_for_scores(mixed_query_layer)
- key_layer, value_layer, _ = self.prefix_tuning(key_layer, value_layer)
+ key_layer, value_layer, _ = self.prefix_tuning(key_layer, value_layer, hidden_states)
# Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
diff --git a/tests_adapters/methods/__init__.py b/tests_adapters/methods/__init__.py
index 3559c930d3..f40a688e58 100644
--- a/tests_adapters/methods/__init__.py
+++ b/tests_adapters/methods/__init__.py
@@ -22,3 +22,4 @@
from .test_ia3 import IA3TestMixin
from .test_lora import LoRATestMixin
from .test_prefix_tuning import PrefixTuningTestMixin
+from .test_unipelt import UniPELTTestMixin
diff --git a/tests_adapters/methods/base.py b/tests_adapters/methods/base.py
index 7eefc8686a..d3944114be 100644
--- a/tests_adapters/methods/base.py
+++ b/tests_adapters/methods/base.py
@@ -127,7 +127,7 @@ def run_forward_test(self, model, adapter_config):
self.assertEqual(len(output_1), len(output_2))
self.assertTrue(torch.equal(output_1[0], output_2[0]))
- self.assertEqual(len(output_1), len(base_output))
+ self.assertGreaterEqual(len(output_1), len(base_output))
self.assertFalse(torch.equal(output_1[0], base_output[0]))
def run_load_test(self, adapter_config):
@@ -191,14 +191,14 @@ def run_full_model_load_test(self, adapter_config):
self.assertEqual(len(output1), len(output2))
self.assertTrue(torch.equal(output1[0], output2[0]))
- def trainings_run(self, model):
+ def trainings_run(self, model, lr=1.0, steps=20):
# setup dataset
train_dataset = self.dataset()
training_args = TrainingArguments(
output_dir="./examples",
do_train=True,
- learning_rate=1.0,
- max_steps=20,
+ learning_rate=lr,
+ max_steps=steps,
no_cuda=True,
per_device_train_batch_size=2,
remove_unused_columns=False,
diff --git a/tests_adapters/methods/test_lora.py b/tests_adapters/methods/test_lora.py
index ba4f047ecf..8d35579528 100644
--- a/tests_adapters/methods/test_lora.py
+++ b/tests_adapters/methods/test_lora.py
@@ -36,6 +36,6 @@ def test_train_lora(self):
def test_merge_lora(self):
self.run_merge_test(LoRAConfig(init_weights="bert"))
-
+
def test_reset_lora(self):
self.run_reset_test(LoRAConfig(init_weights="bert"))
diff --git a/tests_adapters/methods/test_unipelt.py b/tests_adapters/methods/test_unipelt.py
new file mode 100644
index 0000000000..507aa7dd1e
--- /dev/null
+++ b/tests_adapters/methods/test_unipelt.py
@@ -0,0 +1,59 @@
+import torch
+
+from transformers.adapters import UniPELTConfig
+from transformers.testing_utils import require_torch, torch_device
+
+from .base import AdapterMethodBaseTestMixin
+
+
+@require_torch
+class UniPELTTestMixin(AdapterMethodBaseTestMixin):
+ def test_add_unipelt(self):
+ model = self.get_model()
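+ # UniPELT combines LoRA, a bottleneck adapter and prefix tuning, so parameters
+ # from all three groups should be created when the adapter is added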
+ self.run_add_test(model, UniPELTConfig(), ["loras.{name}.", "adapters.{name}.", "prefix_tunings.{name}."])
+
+ def test_delete_unipelt(self):
+ model = self.get_model()
+ self.run_delete_test(model, UniPELTConfig(), ["loras.{name}.", "adapters.{name}.", "prefix_tunings.{name}."])
+
+ def test_get_unipelt(self):
+ model = self.get_model()
+ self.run_get_test(model, UniPELTConfig())
+
+ def test_forward_unipelt(self):
+ model = self.get_model()
+ self.run_forward_test(model, UniPELTConfig())
+
+ def test_load_unipelt(self):
+ self.run_load_test(UniPELTConfig())
+
+ def test_load_full_model_unipelt(self):
+ self.run_full_model_load_test(UniPELTConfig())
+
+ def test_train_unipelt(self):
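+ # in addition to the three PELT modules, the learned prefix gate parameters
+ # are expected to be trainable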
+ self.run_train_test(
+ UniPELTConfig(), ["loras.{name}.", "adapters.{name}.", "prefix_tunings.{name}.", "prefix_gates.{name}."]
+ )
+
+ def test_output_adapter_gating_scores_unipelt(self):
+ model = self.get_model()
+ model.eval()
+
+ adapter_config = UniPELTConfig()
+ name = adapter_config.__class__.__name__
+ model.add_adapter(name, config=adapter_config)
+ model.to(torch_device)
+
+ input_data = self.get_input_samples(config=model.config)
+
+ model.set_active_adapters([name])
+ output_1 = model(**input_data, output_adapter_gating_scores=True)
+
+ self.assertEqual(len(output_1[0]), self.default_input_samples_shape[0])
+ self.assertTrue(hasattr(output_1, "adapter_gating_scores"))
+ gating_scores = output_1.adapter_gating_scores[name]
+ self.assertEqual(len(list(model.iter_layers())), len(gating_scores))
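+ # each layer should report gating scores for at least the three gated modules,
+ # with one score per input sample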
+ for k, per_layer_scores in gating_scores.items():
+ self.assertGreaterEqual(len(per_layer_scores), 3)
+ for k, v in per_layer_scores.items():
+ self.assertEqual(self.default_input_samples_shape[0], v.shape[0], k)
diff --git a/tests_adapters/test_adapter_fusion_common.py b/tests_adapters/test_adapter_fusion_common.py
index b1a8c79ef9..703d971cdd 100644
--- a/tests_adapters/test_adapter_fusion_common.py
+++ b/tests_adapters/test_adapter_fusion_common.py
@@ -195,3 +195,26 @@ def test_adapter_fusion_save_with_head(self):
output2 = model2(**in_data)
self.assertEqual(len(output1), len(output2))
self.assertTrue(torch.equal(output1[0], output2[0]))
+
+ def test_output_adapter_fusion_attentions(self):
+ model = self.get_model()
+ model.eval()
+
+ model.add_adapter("a")
+ model.add_adapter("b")
+ model.add_adapter_fusion(["a", "b"])
+ model.to(torch_device)
+
+ input_data = self.get_input_samples(config=model.config)
+
+ model.set_active_adapters(Fuse("a", "b"))
+ output_1 = model(**input_data, output_adapter_fusion_attentions=True)
+
+ self.assertEqual(len(output_1[0]), self.default_input_samples_shape[0])
+ self.assertTrue(hasattr(output_1, "adapter_fusion_attentions"))
+ attention_scores = output_1.adapter_fusion_attentions["a,b"]
+ self.assertEqual(len(list(model.iter_layers())), len(attention_scores))
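+ # one fusion attention tensor per layer for the "a,b" fusion, with the first
+ # dimension matching the number of input samples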
+ for k, per_layer_scores in attention_scores.items():
+ self.assertEqual(len(per_layer_scores), 1)
+ for k, v in per_layer_scores.items():
+ self.assertEqual(self.default_input_samples_shape[0], v.shape[0], k)
diff --git a/tests_adapters/test_bart.py b/tests_adapters/test_bart.py
index d5c0abe848..41dd3beab5 100644
--- a/tests_adapters/test_bart.py
+++ b/tests_adapters/test_bart.py
@@ -4,7 +4,14 @@
from transformers import BartAdapterModel
from transformers.testing_utils import require_torch
-from .methods import BottleneckAdapterTestMixin, LoRATestMixin, CompacterTestMixin, PrefixTuningTestMixin, IA3TestMixin
+from .methods import (
+ BottleneckAdapterTestMixin,
+ CompacterTestMixin,
+ IA3TestMixin,
+ LoRATestMixin,
+ PrefixTuningTestMixin,
+ UniPELTTestMixin,
+)
from .test_adapter import AdapterTestBase, make_config
from .test_adapter_backward_compability import CompabilityTestMixin
from .test_adapter_composition import ParallelAdapterInferenceTestMixin, ParallelTrainingMixin
@@ -45,6 +52,7 @@ class BartAdapterTest(
IA3TestMixin,
LoRATestMixin,
PrefixTuningTestMixin,
+ UniPELTTestMixin,
AdapterFusionModelTestMixin,
CompabilityTestMixin,
EmbeddingTestMixin,
diff --git a/tests_adapters/test_bert.py b/tests_adapters/test_bert.py
index e25d92be84..e79a53d0df 100644
--- a/tests_adapters/test_bert.py
+++ b/tests_adapters/test_bert.py
@@ -4,7 +4,14 @@
from transformers import BertAdapterModel
from transformers.testing_utils import require_torch
-from .methods import BottleneckAdapterTestMixin, CompacterTestMixin, LoRATestMixin, PrefixTuningTestMixin, IA3TestMixin
+from .methods import (
+ BottleneckAdapterTestMixin,
+ CompacterTestMixin,
+ IA3TestMixin,
+ LoRATestMixin,
+ PrefixTuningTestMixin,
+ UniPELTTestMixin,
+)
from .test_adapter import AdapterTestBase, make_config
from .test_adapter_backward_compability import CompabilityTestMixin
from .test_adapter_composition import ParallelAdapterInferenceTestMixin, ParallelTrainingMixin
@@ -42,6 +49,7 @@ class BertAdapterTest(
IA3TestMixin,
LoRATestMixin,
PrefixTuningTestMixin,
+ UniPELTTestMixin,
EmbeddingTestMixin,
AdapterFusionModelTestMixin,
CompabilityTestMixin,
diff --git a/tests_adapters/test_deberta.py b/tests_adapters/test_deberta.py
index 990ac0d77f..bdd64e31d6 100644
--- a/tests_adapters/test_deberta.py
+++ b/tests_adapters/test_deberta.py
@@ -4,7 +4,14 @@
from transformers import DebertaAdapterModel
from transformers.testing_utils import require_torch
-from .methods import BottleneckAdapterTestMixin, CompacterTestMixin, LoRATestMixin, PrefixTuningTestMixin, IA3TestMixin
+from .methods import (
+ BottleneckAdapterTestMixin,
+ CompacterTestMixin,
+ IA3TestMixin,
+ LoRATestMixin,
+ PrefixTuningTestMixin,
+ UniPELTTestMixin,
+)
from .test_adapter import AdapterTestBase, make_config
from .test_adapter_backward_compability import CompabilityTestMixin
from .test_adapter_composition import ParallelAdapterInferenceTestMixin, ParallelTrainingMixin
@@ -49,6 +49,7 @@ class DebertaAdapterTest(
IA3TestMixin,
LoRATestMixin,
PrefixTuningTestMixin,
+ UniPELTTestMixin,
EmbeddingTestMixin,
ParallelTrainingMixin,
diff --git a/tests_adapters/test_debertaV2.py b/tests_adapters/test_debertaV2.py
index c146b1bcdd..26ab7cde68 100644
--- a/tests_adapters/test_debertaV2.py
+++ b/tests_adapters/test_debertaV2.py
@@ -4,7 +4,14 @@
from transformers import DebertaV2AdapterModel
from transformers.testing_utils import require_torch
-from .methods import BottleneckAdapterTestMixin, CompacterTestMixin, LoRATestMixin, PrefixTuningTestMixin, IA3TestMixin
+from .methods import (
+ BottleneckAdapterTestMixin,
+ CompacterTestMixin,
+ IA3TestMixin,
+ LoRATestMixin,
+ PrefixTuningTestMixin,
+ UniPELTTestMixin,
+)
from .test_adapter import AdapterTestBase, make_config
from .test_adapter_backward_compability import CompabilityTestMixin
from .test_adapter_composition import ParallelAdapterInferenceTestMixin, ParallelTrainingMixin
@@ -50,6 +50,7 @@ class DebertaV2AdapterTest(
IA3TestMixin,
LoRATestMixin,
PrefixTuningTestMixin,
+ UniPELTTestMixin,
EmbeddingTestMixin,
ParallelTrainingMixin,
diff --git a/tests_adapters/test_distilbert.py b/tests_adapters/test_distilbert.py
index 6748971587..759de3f80b 100644
--- a/tests_adapters/test_distilbert.py
+++ b/tests_adapters/test_distilbert.py
@@ -4,7 +4,14 @@
from transformers import DistilBertAdapterModel
from transformers.testing_utils import require_torch
-from .methods import BottleneckAdapterTestMixin, LoRATestMixin, CompacterTestMixin, PrefixTuningTestMixin, IA3TestMixin
+from .methods import (
+ BottleneckAdapterTestMixin,
+ CompacterTestMixin,
+ IA3TestMixin,
+ LoRATestMixin,
+ PrefixTuningTestMixin,
+ UniPELTTestMixin,
+)
from .test_adapter import AdapterTestBase, make_config
from .test_adapter_backward_compability import CompabilityTestMixin
from .test_adapter_composition import ParallelAdapterInferenceTestMixin, ParallelTrainingMixin
@@ -42,6 +42,7 @@ class DistilBertAdapterTest(
IA3TestMixin,
LoRATestMixin,
PrefixTuningTestMixin,
+ UniPELTTestMixin,
EmbeddingTestMixin,
CompabilityTestMixin,
AdapterFusionModelTestMixin,
diff --git a/tests_adapters/test_encoder_decoder.py b/tests_adapters/test_encoder_decoder.py
index 170127e690..f1a48cc631 100644
--- a/tests_adapters/test_encoder_decoder.py
+++ b/tests_adapters/test_encoder_decoder.py
@@ -1,7 +1,14 @@
from tests.models.encoder_decoder.test_modeling_encoder_decoder import * # Imported to execute model tests
from transformers import AutoModelForSeq2SeqLM, BertConfig
-from .methods import BottleneckAdapterTestMixin, CompacterTestMixin, LoRATestMixin, PrefixTuningTestMixin, IA3TestMixin
+from .methods import (
+ BottleneckAdapterTestMixin,
+ CompacterTestMixin,
+ IA3TestMixin,
+ LoRATestMixin,
+ PrefixTuningTestMixin,
+ UniPELTTestMixin,
+)
from .test_adapter import AdapterTestBase
from .test_adapter_fusion_common import AdapterFusionModelTestMixin
@@ -37,6 +44,7 @@ class EncoderDecoderAdapterTest(
IA3TestMixin,
LoRATestMixin,
PrefixTuningTestMixin,
+ UniPELTTestMixin,
AdapterFusionModelTestMixin,
EncoderDecoderAdapterTestBase,
unittest.TestCase,
@@ -65,3 +73,11 @@ def forward_pre_hook(module, input):
self.assertEqual((1, 128, model.config.decoder.vocab_size), out[0].shape)
self.assertEqual(2, calls)
+
+ def test_output_adapter_gating_scores_unipelt(self):
+ # TODO currently not supported
+ self.skipTest("Not implemented.")
+
+ def test_output_adapter_fusion_attentions(self):
+ # TODO currently not supported
+ self.skipTest("Not implemented.")
diff --git a/tests_adapters/test_gpt2.py b/tests_adapters/test_gpt2.py
index 8f8fdd146f..2de9fb7a43 100644
--- a/tests_adapters/test_gpt2.py
+++ b/tests_adapters/test_gpt2.py
@@ -4,7 +4,14 @@
from transformers import GPT2AdapterModel
from transformers.testing_utils import require_torch
-from .methods import BottleneckAdapterTestMixin, CompacterTestMixin, LoRATestMixin, PrefixTuningTestMixin, IA3TestMixin
+from .methods import (
+ BottleneckAdapterTestMixin,
+ CompacterTestMixin,
+ IA3TestMixin,
+ LoRATestMixin,
+ PrefixTuningTestMixin,
+ UniPELTTestMixin,
+)
from .test_adapter import AdapterTestBase, make_config
from .test_adapter_backward_compability import CompabilityTestMixin
from .test_adapter_composition import ParallelAdapterInferenceTestMixin, ParallelTrainingMixin
@@ -43,6 +50,7 @@ class GPT2AdapterTest(
IA3TestMixin,
LoRATestMixin,
PrefixTuningTestMixin,
+ UniPELTTestMixin,
EmbeddingTestMixin,
CompabilityTestMixin,
AdapterFusionModelTestMixin,
diff --git a/tests_adapters/test_mbart.py b/tests_adapters/test_mbart.py
index 0fbe71f629..0e4912d22b 100644
--- a/tests_adapters/test_mbart.py
+++ b/tests_adapters/test_mbart.py
@@ -4,7 +4,14 @@
from transformers import MBartAdapterModel
from transformers.testing_utils import require_torch
-from .methods import BottleneckAdapterTestMixin, LoRATestMixin, CompacterTestMixin, PrefixTuningTestMixin, IA3TestMixin
+from .methods import (
+ BottleneckAdapterTestMixin,
+ CompacterTestMixin,
+ IA3TestMixin,
+ LoRATestMixin,
+ PrefixTuningTestMixin,
+ UniPELTTestMixin,
+)
from .test_adapter import AdapterTestBase, make_config
from .test_adapter_composition import ParallelAdapterInferenceTestMixin
from .test_adapter_conversion import ModelClassConversionTestMixin
@@ -44,6 +44,7 @@ class MBartAdapterTest(
IA3TestMixin,
LoRATestMixin,
PrefixTuningTestMixin,
+ UniPELTTestMixin,
AdapterFusionModelTestMixin,
PredictionHeadModelTestMixin,
ParallelAdapterInferenceTestMixin,
diff --git a/tests_adapters/test_roberta.py b/tests_adapters/test_roberta.py
index 8d7c2cb355..432e6eea31 100644
--- a/tests_adapters/test_roberta.py
+++ b/tests_adapters/test_roberta.py
@@ -4,7 +4,14 @@
from transformers import RobertaAdapterModel
from transformers.testing_utils import require_torch
-from .methods import BottleneckAdapterTestMixin, CompacterTestMixin, LoRATestMixin, PrefixTuningTestMixin, IA3TestMixin
+from .methods import (
+ BottleneckAdapterTestMixin,
+ CompacterTestMixin,
+ IA3TestMixin,
+ LoRATestMixin,
+ PrefixTuningTestMixin,
+ UniPELTTestMixin,
+)
from .test_adapter import AdapterTestBase, make_config
from .test_adapter_backward_compability import CompabilityTestMixin
from .test_adapter_composition import ParallelAdapterInferenceTestMixin
@@ -42,6 +42,7 @@ class RobertaAdapterTest(
IA3TestMixin,
LoRATestMixin,
PrefixTuningTestMixin,
+ UniPELTTestMixin,
AdapterFusionModelTestMixin,
CompabilityTestMixin,
PredictionHeadModelTestMixin,
diff --git a/tests_adapters/test_t5.py b/tests_adapters/test_t5.py
index 5ac59ddaee..0f5f134bc3 100644
--- a/tests_adapters/test_t5.py
+++ b/tests_adapters/test_t5.py
@@ -3,10 +3,17 @@
from datasets import load_dataset
from tests.models.t5.test_modeling_t5 import *
-from transformers import T5AdapterModel, AutoTokenizer
+from transformers import AutoTokenizer, T5AdapterModel
from transformers.testing_utils import require_torch
-from .methods import BottleneckAdapterTestMixin, LoRATestMixin, CompacterTestMixin, PrefixTuningTestMixin, IA3TestMixin
+from .methods import (
+ BottleneckAdapterTestMixin,
+ CompacterTestMixin,
+ IA3TestMixin,
+ LoRATestMixin,
+ PrefixTuningTestMixin,
+ UniPELTTestMixin,
+)
from .test_adapter import AdapterTestBase, make_config
from .test_adapter_backward_compability import CompabilityTestMixin
from .test_adapter_composition import ParallelAdapterInferenceTestMixin, ParallelTrainingMixin
@@ -92,6 +99,7 @@ class T5AdapterTest(
IA3TestMixin,
LoRATestMixin,
PrefixTuningTestMixin,
+ UniPELTTestMixin,
EmbeddingTestMixin,
CompabilityTestMixin,
ParallelAdapterInferenceTestMixin,
diff --git a/tests_adapters/test_vit.py b/tests_adapters/test_vit.py
index fc4c1c79a4..f5b0d9c04b 100644
--- a/tests_adapters/test_vit.py
+++ b/tests_adapters/test_vit.py
@@ -5,7 +5,14 @@
from transformers import ViTAdapterModel
from transformers.testing_utils import require_torch
-from .methods import BottleneckAdapterTestMixin, CompacterTestMixin, LoRATestMixin, PrefixTuningTestMixin, IA3TestMixin
+from .methods import (
+ BottleneckAdapterTestMixin,
+ CompacterTestMixin,
+ IA3TestMixin,
+ LoRATestMixin,
+ PrefixTuningTestMixin,
+ UniPELTTestMixin,
+)
from .test_adapter import VisionAdapterTestBase, make_config
from .test_adapter_backward_compability import CompabilityTestMixin
from .test_adapter_composition import ParallelAdapterInferenceTestMixin, ParallelTrainingMixin
@@ -43,6 +50,7 @@ class ViTAdapterTest(
IA3TestMixin,
LoRATestMixin,
PrefixTuningTestMixin,
+ UniPELTTestMixin,
AdapterFusionModelTestMixin,
CompabilityTestMixin,
PredictionHeadModelTestMixin,