Fix Training Error with AdapterDrop and Prefix Tuning (#673)
Fixes #669 

Changes in this PR:
- Avoid throwing a `RuntimeError` due to a dimension mismatch occurring when the positional encoding computed by layers dropped by AdapterDrop is passed to layers modified by prefix tuning (see the sketch below).
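
The following sketch is not part of the commit; it is a minimal, self-contained illustration (with made-up head counts and sequence lengths) of the failure mode and of the shape check the patch adds. The idea: the relative position bias is computed once by an early attention layer and reused by later layers, so if that early layer's prefix-tuning adapter was skipped by AdapterDrop, the cached bias lacks the prefix positions and can no longer be added to the scores of later layers that do see the prefix keys.

```python
import torch

# Hypothetical shapes for illustration only (not taken from the commit):
n_heads, seq_len, prefix_len = 8, 16, 4

# Bias computed by a layer whose prefix-tuning adapter was dropped by AdapterDrop,
# so it covers only the original key positions.
position_bias = torch.zeros(1, n_heads, seq_len, seq_len)

# Attention scores of a later, non-dropped layer whose keys include the prefix tokens.
scores = torch.randn(1, n_heads, seq_len, seq_len + prefix_len)

# Old check: recompute the bias only when it is missing entirely.
# scores + position_bias  ->  RuntimeError: shapes (.., 16, 16) and (.., 16, 20) don't broadcast

# New check (mirroring the patch): also recompute when the key lengths disagree.
if position_bias is None or position_bias.shape[3] != scores.shape[3]:
    position_bias = torch.zeros(
        1, n_heads, seq_len, scores.shape[3], device=scores.device, dtype=scores.dtype
    )

out = scores + position_bias  # shapes now match: (1, 8, 16, 20)
```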
TimoImhof committed Apr 12, 2024
1 parent 35fd1c2 commit 95cf6bd
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions src/adapters/models/t5/modeling_t5.py
@@ -77,7 +77,7 @@ def forward(
         if past_key_value is not None:
             assert (
                 len(past_key_value) == 2
-            ), f"past_key_value should have 2 past states: keys and values. Got { len(past_key_value)} past states"
+            ), f"past_key_value should have 2 past states: keys and values. Got {len(past_key_value)} past states"
             real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length
 
         key_length = real_seq_length if key_value_states is None else key_value_states.shape[1]
@@ -144,7 +144,13 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value):
             query_states, key_states.transpose(3, 2)
         )  # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9
 
-        if position_bias is None:
+        # For Prefix Tuning, when training with AdapterDrop, we must additionally check that the sequence lengths of
+        # both positional encoding and the scores account for the prefix tokens.
+        # This is because the positional encoding is calculated only once in the beginning and then used for all layers.
+        # However, if the encoding was calculated without the prefix tokens due to AdapterDrop having dropped an
+        # adapter layer in the beginning, the positional encoding will be shorter than the scores, resulting in a
+        # dimension mismatch when adding the positional encoding to the scores.
+        if position_bias is None or position_bias.shape[3] != scores.shape[3]:
             if not self.has_relative_attention_bias:
                 position_bias = torch.zeros(
                     (1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype
