Upgrade Transformers to v4.41.x #712

Merged: 5 commits, Jul 12, 2024
2 changes: 1 addition & 1 deletion hf_transformers
Submodule hf_transformers updated 1424 files
2 changes: 1 addition & 1 deletion setup.py
@@ -60,7 +60,7 @@
"sphinx-multiversion==0.2.4",
"timeout-decorator",
"torch>=1.10,!=1.12.0",
"transformers~=4.40.2",
"transformers~=4.41.2",
]


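The `~=4.41.2` specifier accepts any 4.41.x patch release at or above 4.41.2 but excludes 4.42. A minimal sketch for checking an installed version against this pin, assuming the `packaging` library is available:

from importlib.metadata import version

from packaging.specifiers import SpecifierSet

# "~=4.41.2" is shorthand for ">=4.41.2, ==4.41.*"
pin = SpecifierSet("~=4.41.2")
installed = version("transformers")
print(f"transformers {installed} satisfies {pin}: {installed in pin}")
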
4 changes: 1 addition & 3 deletions src/adapters/methods/prefix_tuning.py
@@ -436,10 +436,8 @@ def pad_and_concat(self, states: List[PrefixTuningState]) -> PrefixTuningState:
value_states = F.pad(value_states, pad_size, "constant", self.model_config.pad_token_id)

# pad attention mask
if pad_length > 0:
if pad_length > 0 and attention_mask is not None:
# Masking the padded tokens only works correctly if attention_mask is set
# We assume this to be the case at this point
assert attention_mask is not None, "Attention mask must be set for prefix tuning"
attention_mask = F.pad(
attention_mask,
(max_prefix_length - attention_mask.shape[-1], 0),
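The guard above replaces the previous assertion: when no attention mask is present, the state is now passed through unchanged instead of raising. A toy sketch of the left-padding step with made-up shapes (the fill value used by the library is cut off in this view, so 1.0 here is only illustrative):

import torch
import torch.nn.functional as F

attention_mask = torch.ones(2, 5)  # (batch, seq_len)
max_prefix_length = 8
pad_length = max_prefix_length - attention_mask.shape[-1]

if pad_length > 0 and attention_mask is not None:
    # Pad on the left of the last dimension so the mask covers the longer prefixed sequence.
    attention_mask = F.pad(attention_mask, (pad_length, 0), "constant", 1.0)

print(attention_mask.shape)  # torch.Size([2, 8])
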
21 changes: 19 additions & 2 deletions src/adapters/model_mixin.py
@@ -10,7 +10,9 @@
import torch
from torch import nn

from transformers import GenerationConfig
from transformers.modeling_outputs import ModelOutput
from transformers.utils import is_accelerate_available

from .composition import AdapterCompositionBlock, Fuse, Stack, parse_composition
from .configuration import ADAPTER_CONFIG_MAP, AdapterConfig, AdapterFusionConfig, BnConfig
@@ -29,6 +31,9 @@

logger = logging.getLogger(__name__)

if is_accelerate_available():
from accelerate.hooks import AlignDevicesHook, add_hook_to_module


class InvertibleAdaptersMixin:
"""Mixin for Transformer models adding invertible adapters."""
@@ -1264,10 +1269,21 @@ def reset_adapter(self):

# HACK Copied from transformers/generation/utils.py
def _prepare_encoder_decoder_kwargs_for_generation(
self, inputs_tensor: torch.Tensor, model_kwargs, model_input_name: Optional[str] = None
self,
inputs_tensor: torch.Tensor,
model_kwargs,
model_input_name: Optional[str],
generation_config: GenerationConfig,
) -> Dict[str, Any]:
# 1. get encoder
encoder = self.get_encoder()
# Compatibility with Accelerate big model inference: the encoder must produce its outputs on the same device
# as the inputs.
if hasattr(self, "hf_device_map"):
if hasattr(encoder, "_hf_hook"):
encoder._hf_hook.io_same_device = True
else:
add_hook_to_module(encoder, AlignDevicesHook(io_same_device=True))

# 2. prepare encoder args and encoder kwargs from model kwargs
irrelevant_prefix = ["decoder_", "cross_attn", "use_cache"]
@@ -1276,7 +1292,6 @@ def _prepare_encoder_decoder_kwargs_for_generation(
for argument, value in model_kwargs.items()
if not any(argument.startswith(p) for p in irrelevant_prefix)
}

encoder_signature = set(inspect.signature(encoder.forward).parameters)
encoder_accepts_wildcard = "kwargs" in encoder_signature or "model_kwargs" in encoder_signature
if not encoder_accepts_wildcard:
@@ -1285,6 +1300,8 @@
for argument, value in encoder_kwargs.items()
if argument in encoder_signature or argument == "adapter_input_parallelized"
}
encoder_kwargs["output_attentions"] = generation_config.output_attentions
encoder_kwargs["output_hidden_states"] = generation_config.output_hidden_states

# 3. make sure that encoder returns `ModelOutput`
model_input_name = model_input_name if model_input_name is not None else self.main_input_name
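With the new signature, the copied hook receives the `GenerationConfig` and forwards its output flags to the encoder in addition to filtering out decoder-specific kwargs. A standalone toy sketch of that filtering and flag propagation (not the library code itself):

from transformers import GenerationConfig

generation_config = GenerationConfig(output_attentions=False, output_hidden_states=True)

model_kwargs = {
    "attention_mask": [[1, 1, 1]],
    "decoder_input_ids": [[0]],  # dropped: starts with "decoder_"
    "use_cache": True,  # dropped: listed as an irrelevant prefix
}
irrelevant_prefix = ["decoder_", "cross_attn", "use_cache"]
encoder_kwargs = {
    k: v for k, v in model_kwargs.items() if not any(k.startswith(p) for p in irrelevant_prefix)
}
encoder_kwargs["output_attentions"] = generation_config.output_attentions
encoder_kwargs["output_hidden_states"] = generation_config.output_hidden_states
print(encoder_kwargs)  # only attention_mask plus the two output flags
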
107 changes: 106 additions & 1 deletion src/adapters/models/bert/modeling_bert.py
@@ -23,13 +23,17 @@
import torch.utils.checkpoint
from torch import nn

from transformers.models.bert.modeling_bert import BertOutput, BertSelfAttention, BertSelfOutput
from transformers.models.bert.modeling_bert import BertOutput, BertSdpaSelfAttention, BertSelfAttention, BertSelfOutput
from transformers.utils import logging

from ...composition import adjust_tensors_for_parallel, match_attn_matrices_for_parallel
from ...utils import prefix_attention_mask
from .mixin_bert import BertOutputAdaptersMixin, BertSelfAttentionAdaptersMixin, BertSelfOutputAdaptersMixin


logger = logging.get_logger(__name__)


class BertSelfAttentionWithAdapters(BertSelfAttentionAdaptersMixin, BertSelfAttention):
def forward(
self,
@@ -142,6 +146,107 @@ def forward(
return outputs


class BertSdpaSelfAttentionWithAdapters(BertSelfAttentionAdaptersMixin, BertSdpaSelfAttention):
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor]:
attention_mask = prefix_attention_mask(attention_mask, [2, 3]) # type: ignore

if self.position_embedding_type != "absolute" or output_attentions or head_mask is not None:
# TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once implemented.
logger.warning_once(
"BertSdpaSelfAttention is used but `torch.nn.functional.scaled_dot_product_attention` does not support"
" non-absolute `position_embedding_type` or `output_attentions=True` or `head_mask`. Falling back to"
" the manual attention implementation, but specifying the manual implementation will be required from"
" Transformers version v5.0.0 onwards. This warning can be removed using the argument"
' `attn_implementation="eager"` when loading the model.'
)
return super().forward(
hidden_states,
attention_mask,
head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
)

bsz, tgt_len, _ = hidden_states.size()

# If this is instantiated as a cross-attention module, the keys and values come from an encoder; the attention
# mask needs to be such that the encoder's padding tokens are not attended to.
is_cross_attention = encoder_hidden_states is not None

current_states = encoder_hidden_states if is_cross_attention else hidden_states
attention_mask = encoder_attention_mask if is_cross_attention else attention_mask

# Check `seq_length` of `past_key_value` == `len(current_states)` to support prefix tuning
if is_cross_attention and past_key_value and past_key_value[0].shape[2] == current_states.shape[1]:
key_layer, value_layer = past_key_value
else:
key_layer = self.transpose_for_scores(self.key(current_states))
value_layer = self.transpose_for_scores(self.value(current_states))
if past_key_value is not None and not is_cross_attention:
key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
value_layer = torch.cat([past_key_value[1], value_layer], dim=2)

query_layer = self.transpose_for_scores(self.query(hidden_states))
query_layer, key_layer, value_layer = match_attn_matrices_for_parallel(query_layer, key_layer, value_layer)
(attention_mask,) = adjust_tensors_for_parallel(query_layer, attention_mask)

if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
# Further calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_layer, value_layer)

key_layer, value_layer, attention_mask = self.prefix_tuning(
key_layer, value_layer, hidden_states, attention_mask
)
(query_layer,) = adjust_tensors_for_parallel(key_layer, query_layer)
bsz = query_layer.size(0)

# SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom
# attn_mask, so we need to call `.contiguous()` here. This was fixed in torch==2.2.0.
# Reference: https://github.com/pytorch/pytorch/issues/112577
if self.require_contiguous_qkv and query_layer.device.type == "cuda" and attention_mask is not None:
query_layer = query_layer.contiguous()
key_layer = key_layer.contiguous()
value_layer = value_layer.contiguous()

# The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal
# mask in case tgt_len == 1.
is_causal = self.is_decoder and attention_mask is None and tgt_len > 1

attn_output = torch.nn.functional.scaled_dot_product_attention(
query_layer,
key_layer,
value_layer,
attn_mask=attention_mask,
dropout_p=self.dropout_prob if self.training else 0.0,
is_causal=is_causal,
)

attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape(bsz, tgt_len, self.all_head_size)

outputs = (attn_output,)
if self.is_decoder:
outputs = outputs + (past_key_value,)
return outputs


class BertSelfOutputWithAdapters(BertSelfOutputAdaptersMixin, BertSelfOutput):
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
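`BertSdpaSelfAttentionWithAdapters` only comes into play when the model runs with the SDPA attention implementation; as the warning above explains, `output_attentions=True`, a head mask, or non-absolute position embeddings fall back to the eager path. A short sketch of choosing the implementation at load time (standard `from_pretrained` argument; the checkpoint name is just an example):

from transformers import AutoModel

# SDPA path: dispatches to the Sdpa attention classes wrapped in this PR.
model_sdpa = AutoModel.from_pretrained("bert-base-uncased", attn_implementation="sdpa")

# Eager path: needed e.g. for output_attentions=True.
model_eager = AutoModel.from_pretrained("bert-base-uncased", attn_implementation="eager")
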
34 changes: 33 additions & 1 deletion src/adapters/models/vit/modeling_vit.py
@@ -23,7 +23,7 @@
from torch import nn

from adapters.composition import adjust_tensors_for_parallel, match_attn_matrices_for_parallel
from transformers.models.vit.modeling_vit import ViTLayer, ViTOutput, ViTSelfAttention
from transformers.models.vit.modeling_vit import ViTLayer, ViTOutput, ViTSdpaSelfAttention, ViTSelfAttention

from .mixin_vit import ViTLayerAdaptersMixin, ViTOutputAdaptersMixin, ViTSelfAttentionAdaptersMixin

@@ -70,6 +70,38 @@ def forward(
return outputs


class ViTSdpaSelfAttentionWithAdapters(ViTSelfAttentionAdaptersMixin, ViTSdpaSelfAttention):
def forward(
self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
mixed_query_layer = self.query(hidden_states)

key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
query_layer = self.transpose_for_scores(mixed_query_layer)

query_layer, key_layer, value_layer = match_attn_matrices_for_parallel(query_layer, key_layer, value_layer)

key_layer, value_layer, _ = self.prefix_tuning(key_layer, value_layer, hidden_states)
(query_layer,) = adjust_tensors_for_parallel(key_layer, query_layer)

context_layer = torch.nn.functional.scaled_dot_product_attention(
query_layer,
key_layer,
value_layer,
head_mask,
self.attention_probs_dropout_prob if self.training else 0.0,
is_causal=False,
scale=None,
)

context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(new_context_layer_shape)

return context_layer, None


class ViTOutputWithAdapters(ViTOutputAdaptersMixin, ViTOutput):
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
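The ViT variant follows the same pattern: prefix tuning and parallel adjustments are applied to the projected key/value tensors before they are handed to `torch.nn.functional.scaled_dot_product_attention`. A minimal sketch of that call with toy tensor shapes:

import torch
import torch.nn.functional as F

batch, heads, seq_len, head_dim = 2, 12, 197, 64
query = torch.randn(batch, heads, seq_len, head_dim)
key = torch.randn(batch, heads, seq_len, head_dim)
value = torch.randn(batch, heads, seq_len, head_dim)

# attn_mask=None mirrors head_mask being None; dropout is only applied in training mode.
context = F.scaled_dot_product_attention(
    query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False
)
print(context.shape)  # torch.Size([2, 12, 197, 64])
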
29 changes: 16 additions & 13 deletions src/adapters/utils.py
@@ -865,7 +865,7 @@ def get_adapter_info(adapter_id: str, source: str = "ah") -> Optional[AdapterInfo]:
raise ValueError("Please specify either 'ah' or 'hf' as source.")


def prefix_attention_mask(attention_mask, dim: int = 3, prefix_value: int = 0):
def prefix_attention_mask(attention_mask, dim: Union[int, List[int]] = 3, prefix_value: int = 0):
"""
Adds a prefix to an attention mask. The length of the prefix is determined by the `prefix_attention_mask_length`
attribute in the ForwardContext.
@@ -890,18 +890,21 @@ def prefix_attention_mask(attention_mask, dim: int = 3, prefix_value: int = 0):
and forward_context is not None
and getattr(forward_context, "prompt_tokens_length", None) is not None
):
# Create a tensor of ones with the desired shape
ones_shape = list(attention_mask.shape)
ones_shape[dim] = forward_context.prompt_tokens_length

prefix_attention_mask = torch.full(
ones_shape,
prefix_value,
dtype=attention_mask.dtype,
).to(attention_mask.device)

# Concatenate the prefix_attention_mask along the specified dimension
attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=dim)
if isinstance(dim, int):
dim = [dim]
for d in dim:
# Create a tensor of ones with the desired shape
ones_shape = list(attention_mask.shape)
ones_shape[d] = forward_context.prompt_tokens_length

prefix_attention_mask = torch.full(
ones_shape,
prefix_value,
dtype=attention_mask.dtype,
).to(attention_mask.device)

# Concatenate the prefix_attention_mask along the specified dimension
attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=d)

return attention_mask

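With `dim` now accepting a list, the prefix is prepended along each requested mask dimension in turn, which is how the BERT SDPA path calls it via `prefix_attention_mask(attention_mask, [2, 3])`. A standalone toy illustration of the loop (no `ForwardContext` involved):

import torch

attention_mask = torch.zeros(2, 1, 4, 4)  # (batch, heads, query_len, key_len)
prompt_tokens_length = 3
prefix_value = 0

for d in [2, 3]:
    ones_shape = list(attention_mask.shape)
    ones_shape[d] = prompt_tokens_length
    prefix = torch.full(ones_shape, prefix_value, dtype=attention_mask.dtype)
    attention_mask = torch.cat((prefix, attention_mask), dim=d)

print(attention_mask.shape)  # torch.Size([2, 1, 7, 7])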