Fix DeBERTa prefix tuning with relative attention enabled (#451)
calpt authored Nov 24, 2022
1 parent 068286d commit d713de3
Showing 4 changed files with 34 additions and 6 deletions.
src/transformers/models/deberta/modeling_deberta.py (7 additions, 2 deletions)

@@ -655,6 +655,7 @@ def linear(w, b, x):
         query_layer = query_layer + self.transpose_for_scores(self.q_bias[None, None, :])
         value_layer = value_layer + self.transpose_for_scores(self.v_bias[None, None, :])
 
+        orig_key_layer = key_layer  # save this for relative attention
         key_layer, value_layer, attention_mask = self.prefix_tuning(
             key_layer, value_layer, hidden_states, attention_mask, False
         )
@@ -667,10 +668,14 @@ def linear(w, b, x):
         attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
         if self.relative_attention:
             rel_embeddings = self.pos_dropout(rel_embeddings)
-            rel_att = self.disentangled_att_bias(query_layer, key_layer, relative_pos, rel_embeddings, scale_factor)
+            rel_att = self.disentangled_att_bias(
+                query_layer, orig_key_layer, relative_pos, rel_embeddings, scale_factor
+            )
 
         if rel_att is not None:
-            attention_scores = attention_scores + rel_att
+            rel_att_padded = torch.zeros_like(attention_scores)
+            rel_att_padded[:, :, :, -rel_att.size(-1) :] = rel_att
+            attention_scores = attention_scores + rel_att_padded
 
         # bxhxlxd
         if self.talking_head:
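The underlying bug: self.prefix_tuning prepends prefix vectors to the key and value layers, so attention_scores gains extra key columns that the relative-position bias knows nothing about. The fix computes the bias against the saved pre-prefix keys and zero-pads it over the prefix positions before adding it to the scores. A minimal sketch of that alignment with dummy tensors (the shapes are illustrative assumptions, not taken from the model):

    import torch

    # Illustrative shapes only: scores are [batch, heads, query_len, key_len],
    # with prefix_len extra key positions prepended by prefix tuning.
    batch, heads, q_len, k_len, prefix_len = 2, 4, 8, 8, 3

    attention_scores = torch.randn(batch, heads, q_len, k_len + prefix_len)
    rel_att = torch.randn(batch, heads, q_len, k_len)  # bias covers real tokens only

    # Adding rel_att directly would fail (key dims differ). Zero-pad so the
    # bias lands on the real keys, which sit at the end of the key axis;
    # the prefix slots at the front receive no relative-position bias.
    rel_att_padded = torch.zeros_like(attention_scores)
    rel_att_padded[:, :, :, -rel_att.size(-1):] = rel_att
    attention_scores = attention_scores + rel_att_padded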
src/transformers/models/deberta_v2/modeling_deberta_v2.py (7 additions, 2 deletions)

@@ -713,6 +713,9 @@ def forward(
         key_layer = self.transpose_for_scores_extended(self.key_proj(hidden_states), self.num_attention_heads)
         value_layer = self.transpose_for_scores_extended(self.value_proj(hidden_states), self.num_attention_heads)
 
+        orig_key_layer = key_layer.contiguous().view(
+            -1, key_layer.size(2), key_layer.size(-1)
+        )  # save this for relative attention
         key_layer, value_layer, attention_mask = self.prefix_tuning(
             key_layer, value_layer, hidden_states, attention_mask, False
         )  # [:, 0, :, 0])
@@ -732,11 +735,13 @@ def forward(
         if self.relative_attention:
             rel_embeddings = self.pos_dropout(rel_embeddings)
             rel_att = self.disentangled_attention_bias(
-                query_layer, key_layer, relative_pos, rel_embeddings, scale_factor
+                query_layer, orig_key_layer, relative_pos, rel_embeddings, scale_factor
             )
 
         if rel_att is not None:
-            attention_scores = attention_scores + rel_att
+            rel_att_padded = torch.zeros_like(attention_scores)
+            rel_att_padded[:, :, -rel_att.size(2) :] = rel_att
+            attention_scores = attention_scores + rel_att_padded
         attention_scores = attention_scores
         attention_scores = attention_scores.view(
             -1, self.num_attention_heads, attention_scores.size(-2), attention_scores.size(-1)
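Same fix, but DeBERTa-v2 folds the batch and head axes together, so the scores are 3-D and the key axis is dim 2; the saved key copy is flattened to the layout the bias function works on. A sketch of the shape handling, again with assumed dummy sizes (and assuming transpose_for_scores_extended yields [batch, heads, seq_len, head_dim]):

    import torch

    batch, heads, k_len, head_dim, q_len, prefix_len = 2, 4, 8, 16, 8, 3

    # Flatten the saved copy to [batch * heads, k_len, head_dim],
    # mirroring the .contiguous().view(...) call in the diff above.
    key_layer = torch.randn(batch, heads, k_len, head_dim)
    orig_key_layer = key_layer.contiguous().view(-1, key_layer.size(2), key_layer.size(-1))
    print(orig_key_layer.shape)  # torch.Size([8, 8, 16])

    # Scores are 3-D here, so the real-key slice is taken along dim 2.
    attention_scores = torch.randn(batch * heads, q_len, k_len + prefix_len)
    rel_att = torch.randn(batch * heads, q_len, k_len)

    rel_att_padded = torch.zeros_like(attention_scores)
    rel_att_padded[:, :, -rel_att.size(2):] = rel_att
    attention_scores = attention_scores + rel_att_padded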
tests_adapters/test_deberta.py (10 additions, 1 deletion)

@@ -4,7 +4,14 @@
 from transformers import DebertaAdapterModel
 from transformers.testing_utils import require_torch
 
-from .methods import BottleneckAdapterTestMixin, UniPELTTestMixin, CompacterTestMixin, IA3TestMixin, LoRATestMixin, PrefixTuningTestMixin
+from .methods import (
+    BottleneckAdapterTestMixin,
+    CompacterTestMixin,
+    IA3TestMixin,
+    LoRATestMixin,
+    PrefixTuningTestMixin,
+    UniPELTTestMixin,
+)
 from .test_adapter import AdapterTestBase, make_config
 from .test_adapter_backward_compability import CompabilityTestMixin
 from .test_adapter_composition import ParallelAdapterInferenceTestMixin, ParallelTrainingMixin
@@ -32,6 +39,8 @@ class DebertaAdapterTestBase(AdapterTestBase):
         num_attention_heads=4,
         intermediate_size=37,
         hidden_act="gelu",
+        relative_attention=True,
+        pos_att_type="p2c|c2p",
     )
     tokenizer_name = "microsoft/deberta-base"
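With relative attention off (the DebertaConfig default), the prefix-tuning tests never reached disentangled_att_bias, so the bug went unnoticed; enabling it makes the whole test mixin suite exercise the fixed path. For illustration, a plain config with the newly added options would look roughly like this (only the kwargs visible in this diff are shown; other sizes are omitted):

    from transformers import DebertaConfig

    # "p2c|c2p" enables both position-to-content and content-to-position
    # disentangled-attention terms; DebertaConfig splits the string on "|".
    config = DebertaConfig(
        num_attention_heads=4,
        intermediate_size=37,
        hidden_act="gelu",
        relative_attention=True,
        pos_att_type="p2c|c2p",
    )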
tests_adapters/test_debertaV2.py (10 additions, 1 deletion)

@@ -4,7 +4,14 @@
 from transformers import DebertaV2AdapterModel
 from transformers.testing_utils import require_torch
 
-from .methods import BottleneckAdapterTestMixin, UniPELTTestMixin, CompacterTestMixin, IA3TestMixin, LoRATestMixin, PrefixTuningTestMixin
+from .methods import (
+    BottleneckAdapterTestMixin,
+    CompacterTestMixin,
+    IA3TestMixin,
+    LoRATestMixin,
+    PrefixTuningTestMixin,
+    UniPELTTestMixin,
+)
 from .test_adapter import AdapterTestBase, make_config
 from .test_adapter_backward_compability import CompabilityTestMixin
 from .test_adapter_composition import ParallelAdapterInferenceTestMixin, ParallelTrainingMixin
@@ -32,6 +39,8 @@ class DebertaV2AdapterTestBase(AdapterTestBase):
         num_attention_heads=4,
         intermediate_size=37,
         hidden_act="gelu",
+        relative_attention=True,
+        pos_att_type="p2c|c2p",
     )
     tokenizer_name = "microsoft/deberta-v3-base"
