
Upgrade Transformers to v4.41.x (adapter-hub#712)
Changes needed for sync:
- BERT / ViT: copy & adapt the new SDPA attention classes
- Update the copied `_prepare_encoder_decoder_kwargs_for_generation` in the model mixin
- Adjust 2-dim attention masks for prompt tuning
dainis-boumber committed Aug 25, 2024
1 parent 798a29c commit c6b7972
Showing 6 changed files with 175 additions and 20 deletions.
setup.py (2 changes: 1 addition & 1 deletion)
@@ -60,7 +60,7 @@
"sphinx-multiversion==0.2.4",
"timeout-decorator",
"torch>=1.10,!=1.12.0",
"transformers~=4.40.2",
"transformers~=4.41.2",
]


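The only change here is the dependency pin bump. As a quick illustration (not part of the diff), the compatible-release operator `~=4.41.2` accepts any 4.41.x release at or above 4.41.2 and rejects 4.42; a sketch using the `packaging` library, which is assumed to be available:

```python
# Illustrative only: what the compatible-release pin "~=4.41.2" accepts.
from packaging.specifiers import SpecifierSet

spec = SpecifierSet("~=4.41.2")  # equivalent to >=4.41.2, ==4.41.*
print("4.41.2" in spec)  # True
print("4.41.9" in spec)  # True
print("4.42.0" in spec)  # False
```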
src/adapters/methods/prefix_tuning.py (4 changes: 1 addition & 3 deletions)
@@ -430,10 +430,8 @@ def pad_and_concat(self, states: List[PrefixTuningState]) -> PrefixTuningState:
value_states = F.pad(value_states, pad_size, "constant", self.model_config.pad_token_id)

# pad attention mask
if pad_length > 0:
if pad_length > 0 and attention_mask is not None:
# Masking the padded tokens only works correctly if attention_mask is set
# We assume this to be the case at this point
assert attention_mask is not None, "Attention mask must be set for prefix tuning"
attention_mask = F.pad(
attention_mask,
(max_prefix_length - attention_mask.shape[-1], 0),
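The guard now skips mask padding when no attention mask is present instead of asserting. A minimal sketch of the padding step under that guard, assuming a 2-d mask of shape (batch, seq) and the same left-padding layout as the diff; `left_pad_attention_mask` is an illustrative helper, not the library's API:

```python
# Minimal sketch, not the library's code: left-pad an attention mask up to a
# common prefix length, and skip the step entirely when no mask is given.
import torch
import torch.nn.functional as F

def left_pad_attention_mask(attention_mask, max_prefix_length, pad_value=0.0):
    if attention_mask is None:
        return None  # mirrors the new guard: nothing to pad
    pad_length = max_prefix_length - attention_mask.shape[-1]
    if pad_length <= 0:
        return attention_mask
    # (left, right) padding on the last dimension, as in the diff's F.pad call
    return F.pad(attention_mask, (pad_length, 0), "constant", pad_value)

mask = torch.ones(2, 5)
print(left_pad_attention_mask(mask, 8).shape)  # torch.Size([2, 8])
```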
src/adapters/model_mixin.py (20 changes: 19 additions & 1 deletion)
@@ -9,7 +9,9 @@
import torch
from torch import nn

from transformers import GenerationConfig
from transformers.modeling_outputs import ModelOutput
from transformers.utils import is_accelerate_available

from .composition import AdapterCompositionBlock, Fuse, Stack, parse_composition
from .configuration import ADAPTER_CONFIG_MAP, AdapterConfig, AdapterFusionConfig, BnConfig
@@ -29,6 +31,8 @@

logger = logging.getLogger(__name__)

if is_accelerate_available():
from accelerate.hooks import AlignDevicesHook, add_hook_to_module

class InvertibleAdaptersMixin:
"""Mixin for Transformer models adding invertible adapters."""
@@ -1263,10 +1267,21 @@ def reset_adapter(self):

# HACK Copied from transformers/generation/utils.py
def _prepare_encoder_decoder_kwargs_for_generation(
self, inputs_tensor: torch.Tensor, model_kwargs, model_input_name: Optional[str] = None
self,
inputs_tensor: torch.Tensor,
model_kwargs,
model_input_name: Optional[str],
generation_config: GenerationConfig,
) -> Dict[str, Any]:
# 1. get encoder
encoder = self.get_encoder()
# Compatibility with Accelerate big model inference: we need the encoder to output things on the same device
# as the inputs.
if hasattr(self, "hf_device_map"):
if hasattr(encoder, "_hf_hook"):
encoder._hf_hook.io_same_device = True
else:
add_hook_to_module(encoder, AlignDevicesHook(io_same_device=True))

# 2. prepare encoder args and encoder kwargs from model kwargs
irrelevant_prefix = ["decoder_", "cross_attn", "use_cache"]
@@ -1285,6 +1300,9 @@ def _prepare_encoder_decoder_kwargs_for_generation(
if argument in encoder_signature or argument == "adapter_input_parallelized"
}

encoder_kwargs["output_attentions"] = generation_config.output_attentions
encoder_kwargs["output_hidden_states"] = generation_config.output_hidden_states

# 3. make sure that encoder returns `ModelOutput`
model_input_name = model_input_name if model_input_name is not None else self.main_input_name
encoder_kwargs["return_dict"] = True
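In transformers v4.41 this generation helper receives the `GenerationConfig` explicitly, so the copied version now reads `output_attentions` and `output_hidden_states` from it rather than from `model_kwargs`. A hedged sketch of that filtering step; the function and variable names below are illustrative, not the mixin's actual attributes:

```python
# Sketch of the kwarg-filtering idea in the copied helper (illustrative names).
from typing import Any, Dict
from transformers import GenerationConfig

def build_encoder_kwargs(
    model_kwargs: Dict[str, Any],
    generation_config: GenerationConfig,
    encoder_signature: set,
) -> Dict[str, Any]:
    irrelevant_prefix = ["decoder_", "cross_attn", "use_cache"]
    encoder_kwargs = {
        name: value
        for name, value in model_kwargs.items()
        if not any(name.startswith(p) for p in irrelevant_prefix)
        and (name in encoder_signature or name == "adapter_input_parallelized")
    }
    # New with the v4.41 sync: these flags come from the generation config.
    encoder_kwargs["output_attentions"] = generation_config.output_attentions
    encoder_kwargs["output_hidden_states"] = generation_config.output_hidden_states
    encoder_kwargs["return_dict"] = True
    return encoder_kwargs
```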
src/adapters/models/bert/modeling_bert.py (106 changes: 105 additions & 1 deletion)
@@ -23,13 +23,16 @@
import torch.utils.checkpoint
from torch import nn

from transformers.models.bert.modeling_bert import BertOutput, BertSelfAttention, BertSelfOutput
from transformers.models.bert.modeling_bert import BertOutput, BertSdpaSelfAttention, BertSelfAttention, BertSelfOutput
from transformers.utils import logging

from ...composition import adjust_tensors_for_parallel, match_attn_matrices_for_parallel
from ...utils import prefix_attention_mask
from .mixin_bert import BertOutputAdaptersMixin, BertSelfAttentionAdaptersMixin, BertSelfOutputAdaptersMixin


logger = logging.get_logger(__name__)

class BertSelfAttentionWithAdapters(BertSelfAttentionAdaptersMixin, BertSelfAttention):
def forward(
self,
@@ -142,6 +145,107 @@ def forward(
return outputs


class BertSdpaSelfAttentionWithAdapters(BertSelfAttentionAdaptersMixin, BertSdpaSelfAttention):
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor]:
attention_mask = prefix_attention_mask(attention_mask, [2, 3]) # type: ignore

if self.position_embedding_type != "absolute" or output_attentions or head_mask is not None:
# TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once implemented.
logger.warning_once(
"BertSdpaSelfAttention is used but `torch.nn.functional.scaled_dot_product_attention` does not support"
" non-absolute `position_embedding_type` or `output_attentions=True` or `head_mask`. Falling back to"
" the manual attention implementation, but specifying the manual implementation will be required from"
" Transformers version v5.0.0 onwards. This warning can be removed using the argument"
' `attn_implementation="eager"` when loading the model.'
)
return super().forward(
hidden_states,
attention_mask,
head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
)

bsz, tgt_len, _ = hidden_states.size()

# If this is instantiated as a cross-attention module, the keys and values come from an encoder; the attention
# mask needs to be such that the encoder's padding tokens are not attended to.
is_cross_attention = encoder_hidden_states is not None

current_states = encoder_hidden_states if is_cross_attention else hidden_states
attention_mask = encoder_attention_mask if is_cross_attention else attention_mask

# Check `seq_length` of `past_key_value` == `len(current_states)` to support prefix tuning
if is_cross_attention and past_key_value and past_key_value[0].shape[2] == current_states.shape[1]:
key_layer, value_layer = past_key_value
else:
key_layer = self.transpose_for_scores(self.key(current_states))
value_layer = self.transpose_for_scores(self.value(current_states))
if past_key_value is not None and not is_cross_attention:
key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
value_layer = torch.cat([past_key_value[1], value_layer], dim=2)

query_layer = self.transpose_for_scores(self.query(hidden_states))
query_layer, key_layer, value_layer = match_attn_matrices_for_parallel(query_layer, key_layer, value_layer)
(attention_mask,) = adjust_tensors_for_parallel(query_layer, attention_mask)

if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
# Further calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_layer, value_layer)

key_layer, value_layer, attention_mask = self.prefix_tuning(
key_layer, value_layer, hidden_states, attention_mask
)
(query_layer,) = adjust_tensors_for_parallel(key_layer, query_layer)
bsz = query_layer.size(0)

# SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom
# attn_mask, so we need to call `.contiguous()` here. This was fixed in torch==2.2.0.
# Reference: https://github.com/pytorch/pytorch/issues/112577
if self.require_contiguous_qkv and query_layer.device.type == "cuda" and attention_mask is not None:
query_layer = query_layer.contiguous()
key_layer = key_layer.contiguous()
value_layer = value_layer.contiguous()

# The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal
# mask in case tgt_len == 1.
is_causal = self.is_decoder and attention_mask is None and tgt_len > 1

attn_output = torch.nn.functional.scaled_dot_product_attention(
query_layer,
key_layer,
value_layer,
attn_mask=attention_mask,
dropout_p=self.dropout_prob if self.training else 0.0,
is_causal=is_causal,
)

attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape(bsz, tgt_len, self.all_head_size)

outputs = (attn_output,)
if self.is_decoder:
outputs = outputs + (past_key_value,)
return outputs


class BertSelfOutputWithAdapters(BertSelfOutputAdaptersMixin, BertSelfOutput):
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
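The new `BertSdpaSelfAttentionWithAdapters` class routes attention through `torch.nn.functional.scaled_dot_product_attention` after the prefix-tuning and parallel-composition adjustments. A standalone sketch of that call pattern with made-up shapes, assuming PyTorch >= 2.0; it is not the adapter class itself:

```python
# Standalone sketch of the SDPA call pattern used above (PyTorch >= 2.0 assumed).
import torch
import torch.nn.functional as F

bsz, num_heads, tgt_len, head_dim = 2, 12, 16, 64
query = torch.randn(bsz, num_heads, tgt_len, head_dim)
key = torch.randn(bsz, num_heads, tgt_len, head_dim)
value = torch.randn(bsz, num_heads, tgt_len, head_dim)

# Additive float mask broadcastable to (bsz, num_heads, tgt_len, src_len):
# 0.0 keeps a position, a large negative value masks it out.
attn_mask = torch.zeros(bsz, 1, 1, tgt_len)

attn_output = F.scaled_dot_product_attention(
    query,
    key,
    value,
    attn_mask=attn_mask,
    dropout_p=0.0,
    is_causal=False,  # the class only sets True when decoding without a mask and tgt_len > 1
)

# Merge heads back as the class does: (bsz, heads, tgt, head_dim) -> (bsz, tgt, heads * head_dim)
attn_output = attn_output.transpose(1, 2).reshape(bsz, tgt_len, num_heads * head_dim)
print(attn_output.shape)  # torch.Size([2, 16, 768])
```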
src/adapters/models/vit/modeling_vit.py (34 changes: 33 additions & 1 deletion)
@@ -23,7 +23,7 @@
from torch import nn

from adapters.composition import adjust_tensors_for_parallel, match_attn_matrices_for_parallel
from transformers.models.vit.modeling_vit import ViTLayer, ViTOutput, ViTSelfAttention
from transformers.models.vit.modeling_vit import ViTLayer, ViTOutput, ViTSdpaSelfAttention, ViTSelfAttention

from .mixin_vit import ViTLayerAdaptersMixin, ViTOutputAdaptersMixin, ViTSelfAttentionAdaptersMixin

@@ -70,6 +70,38 @@ def forward(
return outputs


class ViTSdpaSelfAttentionWithAdapters(ViTSelfAttentionAdaptersMixin, ViTSdpaSelfAttention):
def forward(
self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
mixed_query_layer = self.query(hidden_states)

key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
query_layer = self.transpose_for_scores(mixed_query_layer)

query_layer, key_layer, value_layer = match_attn_matrices_for_parallel(query_layer, key_layer, value_layer)

key_layer, value_layer, _ = self.prefix_tuning(key_layer, value_layer, hidden_states)
(query_layer,) = adjust_tensors_for_parallel(key_layer, query_layer)

context_layer = torch.nn.functional.scaled_dot_product_attention(
query_layer,
key_layer,
value_layer,
head_mask,
self.attention_probs_dropout_prob if self.training else 0.0,
is_causal=False,
scale=None,
)

context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(new_context_layer_shape)

return context_layer, None


class ViTOutputWithAdapters(ViTOutputAdaptersMixin, ViTOutput):
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
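ViT has no decoder or attention mask, so the adapted SDPA forward only splits heads, applies prefix tuning, and merges heads back. A small sketch of the split/merge that `transpose_for_scores` and the final `permute`/`view` perform, with assumed ViT-Base shapes:

```python
# Sketch of the head split/merge around the SDPA call (ViT-Base shapes assumed).
import torch

bsz, seq_len, hidden_size, num_heads = 2, 197, 768, 12
head_dim = hidden_size // num_heads

hidden_states = torch.randn(bsz, seq_len, hidden_size)

# Split: (bsz, seq, hidden) -> (bsz, heads, seq, head_dim), as transpose_for_scores does
split = hidden_states.view(bsz, seq_len, num_heads, head_dim).permute(0, 2, 1, 3)

# ... scaled_dot_product_attention would run here ...

# Merge: (bsz, heads, seq, head_dim) -> (bsz, seq, hidden), as in the forward above
merged = split.permute(0, 2, 1, 3).contiguous().view(bsz, seq_len, hidden_size)
print(merged.shape)  # torch.Size([2, 197, 768])
```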
src/adapters/utils.py (29 changes: 16 additions & 13 deletions)
@@ -865,7 +865,7 @@ def get_adapter_info(adapter_id: str, source: str = "ah") -> Optional[AdapterInf
raise ValueError("Please specify either 'ah' or 'hf' as source.")


def prefix_attention_mask(attention_mask, dim: int = 3, prefix_value: int = 0):
def prefix_attention_mask(attention_mask, dim: Union[int, List[int]] = 3, prefix_value: int = 0):
"""
Adds a prefix to an attention mask. The length of the prefix is determined by the `prefix_attention_mask_length`
attribute in the ForwardContext.
@@ -890,18 +890,21 @@ def prefix_attention_mask(attention_mask, dim: int = 3, prefix_value: int = 0):
and forward_context is not None
and getattr(forward_context, "prompt_tokens_length", None) is not None
):
# Create a tensor of ones with the desired shape
ones_shape = list(attention_mask.shape)
ones_shape[dim] = forward_context.prompt_tokens_length

prefix_attention_mask = torch.full(
ones_shape,
prefix_value,
dtype=attention_mask.dtype,
).to(attention_mask.device)

# Concatenate the prefix_attention_mask along the specified dimension
attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=dim)
if isinstance(dim, int):
dim = [dim]
for d in dim:
# Create a tensor of ones with the desired shape
ones_shape = list(attention_mask.shape)
ones_shape[d] = forward_context.prompt_tokens_length

prefix_attention_mask = torch.full(
ones_shape,
prefix_value,
dtype=attention_mask.dtype,
).to(attention_mask.device)

# Concatenate the prefix_attention_mask along the specified dimension
attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=d)

return attention_mask

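With `dim` now accepting a list, prompt tuning can extend both the query and key axes of an already-expanded 4-d mask in one call. An illustrative standalone version of the loop above; `prepend_prefix` is a made-up name, and the real function reads the prefix length from the `ForwardContext`:

```python
# Illustrative standalone version of the widened-`dim` behaviour (made-up helper name).
from typing import List, Union

import torch

def prepend_prefix(
    attention_mask: torch.Tensor,
    dim: Union[int, List[int]],
    prefix_len: int,
    prefix_value: int = 0,
) -> torch.Tensor:
    if isinstance(dim, int):
        dim = [dim]
    for d in dim:
        # Build a block of `prefix_value` with the prompt length along dimension d
        shape = list(attention_mask.shape)
        shape[d] = prefix_len
        prefix = torch.full(shape, prefix_value, dtype=attention_mask.dtype, device=attention_mask.device)
        attention_mask = torch.cat((prefix, attention_mask), dim=d)
    return attention_mask

mask = torch.zeros(2, 1, 16, 16)              # (bsz, 1, tgt_len, src_len) extended mask
print(prepend_prefix(mask, [2, 3], 5).shape)  # torch.Size([2, 1, 21, 21])
```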
