From 2e456e95f4029919d58ab12f00f4d1b1b9276397 Mon Sep 17 00:00:00 2001 From: Timo Imhof Date: Thu, 11 Apr 2024 10:05:15 +0200 Subject: [PATCH 1/3] Add check for shape mismatch of position bias and scores when using prefix tuning with AdapterDrop --- src/adapters/models/t5/modeling_t5.py | 94 ++++++++++++++------------- 1 file changed, 48 insertions(+), 46 deletions(-) diff --git a/src/adapters/models/t5/modeling_t5.py b/src/adapters/models/t5/modeling_t5.py index b366b9ceb..c403aa397 100644 --- a/src/adapters/models/t5/modeling_t5.py +++ b/src/adapters/models/t5/modeling_t5.py @@ -37,7 +37,6 @@ T5StackAdaptersMixin, ) - logger = logging.get_logger(__name__) @@ -53,16 +52,16 @@ def forward(self, hidden_states): class T5AttentionWithAdapters(T5AttentionAdaptersMixin, T5Attention): def forward( - self, - hidden_states, - mask=None, - key_value_states=None, - position_bias=None, - past_key_value=None, - layer_head_mask=None, - query_length=None, - use_cache=False, - output_attentions=False, + self, + hidden_states, + mask=None, + key_value_states=None, + position_bias=None, + past_key_value=None, + layer_head_mask=None, + query_length=None, + use_cache=False, + output_attentions=False, ): """ Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states). @@ -76,8 +75,8 @@ def forward( if past_key_value is not None: assert ( - len(past_key_value) == 2 - ), f"past_key_value should have 2 past states: keys and values. Got { len(past_key_value)} past states" + len(past_key_value) == 2 + ), f"past_key_value should have 2 past states: keys and values. Got {len(past_key_value)} past states" real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length key_length = real_seq_length if key_value_states is None else key_value_states.shape[1] @@ -144,7 +143,10 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value): query_states, key_states.transpose(3, 2) ) # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 - if position_bias is None: + # For Prefix Tuning, when training with AdapterDrop, we must additionally check if the dimensions of the + # position_bias given from previous layers match the dimensions of the scores of the current layer to make + # sure that the position_bias is adequately recomputed if previous layers have been skipped. 
+ if position_bias is None or position_bias.shape != scores.shape: if not self.has_relative_attention_bias: position_bias = torch.zeros( (1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype @@ -157,7 +159,7 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value): # if key and values are already calculated # we want only the last query position bias if past_key_value is not None: - position_bias = position_bias[:, :, -hidden_states.size(1) :, :] + position_bias = position_bias[:, :, -hidden_states.size(1):, :] if mask is not None: position_bias = position_bias + mask # (batch_size, n_heads, seq_length, key_length) @@ -193,14 +195,14 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value): class T5LayerSelfAttentionWithAdapters(T5SelfAttentionLayerAdaptersMixin, T5LayerSelfAttention): def forward( - self, - hidden_states, - attention_mask=None, - position_bias=None, - layer_head_mask=None, - past_key_value=None, - use_cache=False, - output_attentions=False, + self, + hidden_states, + attention_mask=None, + position_bias=None, + layer_head_mask=None, + past_key_value=None, + use_cache=False, + output_attentions=False, ): normed_hidden_states = self.layer_norm(hidden_states) attention_output = self.SelfAttention( @@ -221,16 +223,16 @@ def forward( class T5LayerCrossAttentionWithAdapters(T5CrossAttentionLayerAdaptersMixin, T5LayerCrossAttention): def forward( - self, - hidden_states, - key_value_states, - attention_mask=None, - position_bias=None, - layer_head_mask=None, - past_key_value=None, - use_cache=False, - query_length=None, - output_attentions=False, + self, + hidden_states, + key_value_states, + attention_mask=None, + position_bias=None, + layer_head_mask=None, + past_key_value=None, + use_cache=False, + query_length=None, + output_attentions=False, ): normed_hidden_states = self.layer_norm(hidden_states) attention_output = self.EncDecAttention( @@ -253,19 +255,19 @@ def forward( class T5StackWithAdapters(T5StackAdaptersMixin, T5Stack): def forward( - self, - input_ids=None, - attention_mask=None, - encoder_hidden_states=None, - encoder_attention_mask=None, - inputs_embeds=None, - head_mask=None, - cross_attn_head_mask=None, - past_key_values=None, - use_cache=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, + self, + input_ids=None, + attention_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + inputs_embeds=None, + head_mask=None, + cross_attn_head_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, ): # Model parallel if self.model_parallel: From 123b146d34012247764d7fac6d0522df91c8bfa6 Mon Sep 17 00:00:00 2001 From: Timo Imhof Date: Thu, 11 Apr 2024 10:27:21 +0200 Subject: [PATCH 2/3] make style --- src/adapters/models/t5/modeling_t5.py | 87 ++++++++++++++------------- 1 file changed, 44 insertions(+), 43 deletions(-) diff --git a/src/adapters/models/t5/modeling_t5.py b/src/adapters/models/t5/modeling_t5.py index c403aa397..31d708d1b 100644 --- a/src/adapters/models/t5/modeling_t5.py +++ b/src/adapters/models/t5/modeling_t5.py @@ -37,6 +37,7 @@ T5StackAdaptersMixin, ) + logger = logging.get_logger(__name__) @@ -52,16 +53,16 @@ def forward(self, hidden_states): class T5AttentionWithAdapters(T5AttentionAdaptersMixin, T5Attention): def forward( - self, - hidden_states, - mask=None, - key_value_states=None, - position_bias=None, - past_key_value=None, - layer_head_mask=None, - 
query_length=None, - use_cache=False, - output_attentions=False, + self, + hidden_states, + mask=None, + key_value_states=None, + position_bias=None, + past_key_value=None, + layer_head_mask=None, + query_length=None, + use_cache=False, + output_attentions=False, ): """ Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states). @@ -75,7 +76,7 @@ def forward( if past_key_value is not None: assert ( - len(past_key_value) == 2 + len(past_key_value) == 2 ), f"past_key_value should have 2 past states: keys and values. Got {len(past_key_value)} past states" real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length @@ -159,7 +160,7 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value): # if key and values are already calculated # we want only the last query position bias if past_key_value is not None: - position_bias = position_bias[:, :, -hidden_states.size(1):, :] + position_bias = position_bias[:, :, -hidden_states.size(1) :, :] if mask is not None: position_bias = position_bias + mask # (batch_size, n_heads, seq_length, key_length) @@ -195,14 +196,14 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value): class T5LayerSelfAttentionWithAdapters(T5SelfAttentionLayerAdaptersMixin, T5LayerSelfAttention): def forward( - self, - hidden_states, - attention_mask=None, - position_bias=None, - layer_head_mask=None, - past_key_value=None, - use_cache=False, - output_attentions=False, + self, + hidden_states, + attention_mask=None, + position_bias=None, + layer_head_mask=None, + past_key_value=None, + use_cache=False, + output_attentions=False, ): normed_hidden_states = self.layer_norm(hidden_states) attention_output = self.SelfAttention( @@ -223,16 +224,16 @@ def forward( class T5LayerCrossAttentionWithAdapters(T5CrossAttentionLayerAdaptersMixin, T5LayerCrossAttention): def forward( - self, - hidden_states, - key_value_states, - attention_mask=None, - position_bias=None, - layer_head_mask=None, - past_key_value=None, - use_cache=False, - query_length=None, - output_attentions=False, + self, + hidden_states, + key_value_states, + attention_mask=None, + position_bias=None, + layer_head_mask=None, + past_key_value=None, + use_cache=False, + query_length=None, + output_attentions=False, ): normed_hidden_states = self.layer_norm(hidden_states) attention_output = self.EncDecAttention( @@ -255,19 +256,19 @@ def forward( class T5StackWithAdapters(T5StackAdaptersMixin, T5Stack): def forward( - self, - input_ids=None, - attention_mask=None, - encoder_hidden_states=None, - encoder_attention_mask=None, - inputs_embeds=None, - head_mask=None, - cross_attn_head_mask=None, - past_key_values=None, - use_cache=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, + self, + input_ids=None, + attention_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + inputs_embeds=None, + head_mask=None, + cross_attn_head_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, ): # Model parallel if self.model_parallel: From 363d60dd56a849763632a00e70acb10c41942970 Mon Sep 17 00:00:00 2001 From: Timo Imhof Date: Fri, 12 Apr 2024 10:50:06 +0200 Subject: [PATCH 3/3] Update: - specify the sequence length dimension - Improve explaining comment --- src/adapters/models/t5/modeling_t5.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/src/adapters/models/t5/modeling_t5.py 
b/src/adapters/models/t5/modeling_t5.py
index 31d708d1b..03d9f2797 100644
--- a/src/adapters/models/t5/modeling_t5.py
+++ b/src/adapters/models/t5/modeling_t5.py
@@ -144,10 +144,13 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value):
             query_states, key_states.transpose(3, 2)
         ) # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9
 
-        # For Prefix Tuning, when training with AdapterDrop, we must additionally check if the dimensions of the
-        # position_bias given from previous layers match the dimensions of the scores of the current layer to make
-        # sure that the position_bias is adequately recomputed if previous layers have been skipped.
-        if position_bias is None or position_bias.shape != scores.shape:
+        # For Prefix Tuning, when training with AdapterDrop, we must additionally check that the sequence lengths of
+        # both positional encoding and the scores account for the prefix tokens.
+        # This is because the positional encoding is calculated only once in the beginning and then used for all layers.
+        # However, if the encoding was calculated without the prefix tokens due to AdapterDrop having dropped an
+        # adapter layer in the beginning, the positional encoding will be shorter than the scores, resulting in a
+        # dimension mismatch when adding the positional encoding to the scores.
+        if position_bias is None or position_bias.shape[3] != scores.shape[3]:
             if not self.has_relative_attention_bias:
                 position_bias = torch.zeros(
                     (1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype
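
To make the failure mode described in the comment concrete, the standalone sketch below reproduces the shape situation the final commit guards against. It is illustrative only and not code from the adapters library: the helper name needs_position_bias_recompute, the toy sizes, and the prefix length of 30 are assumptions chosen for the example; the tensors follow the T5 attention layout (batch_size, n_heads, query_length, key_length).

import torch

# Toy sizes, chosen only for illustration (not taken from the patch).
batch_size, n_heads, query_len, key_len, prefix_len = 2, 8, 16, 16, 30


def needs_position_bias_recompute(position_bias, scores):
    # Hypothetical helper mirroring the condition added in PATCH 3/3: only the
    # key-length dimension (dim 3) is compared, because prefix tokens lengthen
    # the keys/values, while the other dimensions are not affected by the prefix.
    return position_bias is None or position_bias.shape[3] != scores.shape[3]


# Bias cached by a layer whose prefix was skipped by AdapterDrop: it covers only key_len keys.
cached_position_bias = torch.zeros(1, n_heads, query_len, key_len)

# Scores of a later layer that does apply the prefix: the key length grows by prefix_len.
scores = torch.zeros(batch_size, n_heads, query_len, key_len + prefix_len)

print(needs_position_bias_recompute(cached_position_bias, scores))  # True -> recompute the bias
print(cached_position_bias.shape[3], scores.shape[3])  # 16 vs. 46: adding them would fail to broadcast

Restricting the comparison to dim 3, as the final commit does, also avoids flagging a bias that is merely broadcastable rather than stale: a freshly computed bias of shape (1, n_heads, q, k) adds cleanly to scores of shape (batch_size, n_heads, q, k), yet the full-shape check from the first commit would treat it as a mismatch.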