diff --git a/Makefile b/Makefile index a4333c6634..04b9374133 100644 --- a/Makefile +++ b/Makefile @@ -13,7 +13,6 @@ quality: python utils/custom_init_isort.py --check_only python utils/sort_auto_mappings.py --check_only flake8 $(check_dirs) - doc-builder style src/adapters docs --max_len 119 --check_only --path_to_docs docs python utils/check_inits.py # Format source code automatically and check is there are any problems left that need manual fixing @@ -21,7 +20,6 @@ quality: extra_style_checks: python utils/custom_init_isort.py python utils/sort_auto_mappings.py - doc-builder style src/adapters docs --max_len 119 --path_to_docs docs # this target runs checks on all files and potentially modifies some of them diff --git a/hf_transformers b/hf_transformers index 514de24abf..a7cab3c283 160000 --- a/hf_transformers +++ b/hf_transformers @@ -1 +1 @@ -Subproject commit 514de24abfd4416aeba6a6455ad5920f57f3567d +Subproject commit a7cab3c283312b8d4de5df3bbe719971e24f4281 diff --git a/setup.py b/setup.py index 424eab7d1c..6d1d5e72c1 100644 --- a/setup.py +++ b/setup.py @@ -21,7 +21,8 @@ # We try to follow their general layout wherever sensible. _deps = [ - "accelerate>=0.20.3", + "accelerate>=0.21.0", + "beautifulsoup4", "black==22.3", # after updating to black 2023, also update Python version in pyproject.toml to 3.7 "datasets!=2.5.0", "dill<0.3.5", @@ -29,7 +30,6 @@ "evaluate>=0.2.0", "flake8>=3.8.3", "GitPython<3.1.19", - "hf-doc-builder>=0.3.0", "isort>=5.5.4", "Jinja2==2.11.3", "nltk", @@ -49,7 +49,6 @@ "rouge-score!=0.0.7,!=0.0.8,!=0.1,!=0.1.1", "sacrebleu>=1.4.12,<2.0.0", "sacremoses", - "safetensors>=0.2.1", "scikit-learn", "sentencepiece>=0.1.91,!=0.1.92", "sphinx-copybutton", @@ -61,8 +60,7 @@ "sphinx-multiversion", "timeout-decorator", "torch>=1.10,!=1.12.0", - "transformers~=4.35.2", - "beautifulsoup4", + "transformers~=4.36.0", ] @@ -103,17 +101,15 @@ def deps_list(*pkgs): "rouge-score", "nltk", "GitPython", - "hf-doc-builder", "protobuf", # Can be removed once we can unpin protobuf "sacremoses", "rjieba", - "safetensors", "beautifulsoup4", "pillow", "accelerate", ) -extras["quality"] = deps_list("black", "datasets", "isort", "flake8", "GitPython", "hf-doc-builder") +extras["quality"] = deps_list("black", "datasets", "isort", "flake8", "GitPython") extras["docs"] = deps_list( "docutils", diff --git a/src/adapters/heads/base.py b/src/adapters/heads/base.py index 343b2774af..0f35752b00 100644 --- a/src/adapters/heads/base.py +++ b/src/adapters/heads/base.py @@ -102,7 +102,11 @@ def build(self, model): for i, module in enumerate(pred_head): self.add_module(str(i), module) - self.apply(model._init_weights) + # We need to import the current value of _init_weights at each execution to determine if weights init is disabled. 
+ from transformers.modeling_utils import _init_weights + + if _init_weights: + self.apply(model._init_weights) self.train(model.training) # make sure training mode is consistent def get_output_embeddings(self): diff --git a/src/adapters/models/bart/modeling_bart.py b/src/adapters/models/bart/modeling_bart.py index 28bf37bd7c..b347fddf07 100644 --- a/src/adapters/models/bart/modeling_bart.py +++ b/src/adapters/models/bart/modeling_bart.py @@ -20,11 +20,15 @@ from torch import nn from transformers.models.bart.modeling_bart import BartAttention, BartDecoderLayer, BartEncoderLayer +from transformers.utils import logging from ...composition import adjust_tensors_for_parallel, adjust_tensors_for_parallel_, match_attn_matrices_for_parallel from .mixin_bart import BartAttentionAdaptersMixin, BartDecoderLayerAdaptersMixin, BartEncoderLayerAdaptersMixin +logger = logging.get_logger(__name__) + + class BartAttentionWithAdapters(BartAttentionAdaptersMixin, BartAttention): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -96,10 +100,9 @@ def forward( bsz = query_states.size(0) proj_shape = (bsz * self.num_heads, -1, self.head_dim) - query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) - key_states = key_states.view(*proj_shape) - value_states = value_states.view(*proj_shape) + key_states = key_states.reshape(*proj_shape) + value_states = value_states.reshape(*proj_shape) src_len = key_states.size(1) attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) @@ -145,7 +148,7 @@ def forward( if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" + f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is" f" {attn_output.size()}" ) @@ -153,7 +156,7 @@ def forward( attn_output = attn_output.transpose(1, 2) # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be - # partitioned aross GPUs when using tensor-parallelism. + # partitioned across GPUs when using tensor-parallelism. 
attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) attn_output = self.out_proj(attn_output) @@ -161,6 +164,233 @@ def forward( return attn_output, attn_weights_reshaped, past_key_value +class BartFlashAttention2WithAdapters(BartAttentionAdaptersMixin, BartAttention): + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + # BartFlashAttention2 attention does not support output_attentions + if output_attentions: + raise ValueError("BartFlashAttention2 attention does not support output_attentions") + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + + bsz, q_len, _ = hidden_states.size() + + # get query proj + query_states = self._reshape(self.q_proj(hidden_states), -1, bsz) + # get key, value proj + # `past_key_value[0].shape[2] == key_value_states.shape[1]` + # is checking that the `sequence_length` of the `past_key_value` is the same as + # the provided `key_value_states` to support prefix tuning + if ( + is_cross_attention + and past_key_value is not None + and past_key_value[0].shape[2] == key_value_states.shape[1] + ): + # reuse k,v, cross_attentions + key_states = past_key_value[0].transpose(1, 2) + value_states = past_key_value[1].transpose(1, 2) + elif is_cross_attention: + # cross_attentions + key_states = self._reshape(self.k_proj(key_value_states), -1, bsz) + value_states = self._reshape(self.v_proj(key_value_states), -1, bsz) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self._reshape(self.k_proj(hidden_states), -1, bsz) + value_states = self._reshape(self.v_proj(hidden_states), -1, bsz) + key_states = torch.cat([past_key_value[0].transpose(1, 2), key_states], dim=1) + value_states = torch.cat([past_key_value[1].transpose(1, 2), value_states], dim=1) + else: + # self_attention + key_states = self._reshape(self.k_proj(hidden_states), -1, bsz) + value_states = self._reshape(self.v_proj(hidden_states), -1, bsz) + + query_states, key_states, value_states = match_attn_matrices_for_parallel( + query_states, key_states, value_states + ) + (attention_mask,) = adjust_tensors_for_parallel(query_states, attention_mask) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. 
Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states.transpose(1, 2), value_states.transpose(1, 2)) + + key_states, value_states, attention_mask = self.prefix_tuning( + key_states, value_states, hidden_states, attention_mask + ) + (query_states,) = adjust_tensors_for_parallel(key_states, query_states) + bsz = query_states.size(0) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (LlamaRMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + # Handle the case where the model is quantized + if hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + "The input hidden states seems to be silently casted in float32, this might be related to the fact" + " you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = self._flash_attention_forward( + query_states, key_states, value_states, attention_mask, q_len, dropout=self.dropout + ) + + attn_output = attn_output.reshape(bsz, q_len, -1) + attn_output = self.out_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class BartSdpaAttentionWithAdapters(BartAttentionAdaptersMixin, BartAttention): + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + if output_attentions or layer_head_mask is not None: + # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "BartModel is using BartSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not" + " support `output_attentions=True` or `layer_head_mask` not None. Falling back to the manual attention" + " implementation, but specifying the manual implementation will be required from Transformers version" + ' v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when' + " loading the model." 
+ ) + return super().forward( + hidden_states, + key_value_states=key_value_states, + past_key_value=past_key_value, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + output_attentions=output_attentions, + ) + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + + bsz, tgt_len, _ = hidden_states.size() + + # get query proj + query_states = self.q_proj(hidden_states) + # get key, value proj + # `past_key_value[0].shape[2] == key_value_states.shape[1]` + # is checking that the `sequence_length` of the `past_key_value` is the same as + # the provided `key_value_states` to support prefix tuning + if ( + is_cross_attention + and past_key_value is not None + and past_key_value[0].shape[2] == key_value_states.shape[1] + ): + # reuse k,v, cross_attentions + key_states = past_key_value[0] + value_states = past_key_value[1] + elif is_cross_attention: + # cross_attentions + key_states = self._shape(self.k_proj(key_value_states), -1, bsz) + value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + else: + # self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + + query_states, key_states, value_states = match_attn_matrices_for_parallel( + query_states, key_states, value_states + ) + (attention_mask,) = adjust_tensors_for_parallel(query_states, attention_mask) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states, value_states) + + key_states, value_states, attention_mask = self.prefix_tuning( + key_states, value_states, hidden_states, attention_mask + ) + (query_states,) = adjust_tensors_for_parallel(key_states, query_states) + bsz = query_states.size(0) + + query_states = self._shape(query_states, tgt_len, bsz) + + # NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask, + # but we are fine here as `_shape` do call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577 + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=attention_mask, + dropout_p=self.dropout if self.training else 0.0, + # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1. 
+ is_causal=self.is_causal and attention_mask is None and tgt_len > 1, + ) + + if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2) + + # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be + # partitioned across GPUs when using tensor-parallelism. + attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, None, past_key_value + + class BartEncoderLayerWithAdapters(BartEncoderLayerAdaptersMixin, BartEncoderLayer): def forward( self, diff --git a/src/adapters/models/beit/mixin_beit.py b/src/adapters/models/beit/mixin_beit.py index 608d9c5cc4..d4490a33d3 100644 --- a/src/adapters/models/beit/mixin_beit.py +++ b/src/adapters/models/beit/mixin_beit.py @@ -55,3 +55,7 @@ def iter_layers(self) -> Iterable[Tuple[int, nn.Module]]: def set_input_embeddings(self, value): self.embeddings.patch_embeddings = value + + def post_embedding_forward(self, module, args, outputs): + embedding_output, tup = outputs + return super().post_embedding_forward(module, args, embedding_output), tup diff --git a/src/adapters/models/llama/modeling_llama.py b/src/adapters/models/llama/modeling_llama.py index 3b22e5ae13..baefac75a9 100644 --- a/src/adapters/models/llama/modeling_llama.py +++ b/src/adapters/models/llama/modeling_llama.py @@ -19,9 +19,11 @@ # limitations under the License. """ PyTorch LLaMA model.""" import math +import warnings from typing import Optional, Tuple import torch +import torch.nn.functional as F import torch.utils.checkpoint from torch import nn @@ -30,7 +32,8 @@ adjust_tensors_for_parallel_, match_attn_matrices_for_parallel, ) -from transformers.models.llama.modeling_llama import apply_rotary_pos_emb +from transformers.cache_utils import Cache +from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, apply_rotary_pos_emb, repeat_kv from transformers.utils import logging from .mixin_llama import LlamaAttentionMixin, LlamaDecoderLayerMixin @@ -39,7 +42,7 @@ logger = logging.get_logger(__name__) -class LlamaAttentionWithAdapters(nn.Module, LlamaAttentionMixin): +class LlamaAttentionWithAdapters(LlamaAttentionMixin, LlamaAttention): """Multi-headed attention from 'Attention Is All You Need' paper""" def forward( @@ -47,15 +50,44 @@ def forward( hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, + **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. 
Please make sure use" + " `attention_mask` instead.`" + ) + bsz, q_len, _ = hidden_states.size() - query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + if self.config.pretraining_tp > 1: + key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp + query_slices = self.q_proj.weight.split( + (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0 + ) + key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) + value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) + + query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)] + query_states = torch.cat(query_states, dim=-1) + + key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)] + key_states = torch.cat(key_states, dim=-1) + + value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)] + value_states = torch.cat(value_states, dim=-1) + + else: + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) query_states, key_states, value_states = match_attn_matrices_for_parallel( query_states, key_states, value_states @@ -64,17 +96,22 @@ def forward( kv_seq_len = key_states.shape[-2] if past_key_value is not None: - kv_seq_len += past_key_value[0].shape[-2] + if self.layer_idx is None: + raise ValueError( + "The cache structure has changed since version v4.36. If you are using" + f" {self.__class__.__name__} for auto-regressive decoding with k/v caching, please make sure to" + " initialize the attention class with a layer index." 
+ ) + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) - # [bsz, nh, t, hd] if past_key_value is not None: - # reuse k, v, self_attention - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) + cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - past_key_value = (key_states, value_states) if use_cache else None + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) key_states, value_states, attention_mask = self.prefix_tuning( key_states, value_states, hidden_states, attention_mask @@ -85,7 +122,7 @@ def forward( # Make adjustments since (parallel) prefix tuning changes the attention mask kv_seq_len = key_states.shape[-2] - bsz = attention_mask.shape[0] + bsz = key_states.shape[0] if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): raise ValueError( @@ -99,10 +136,10 @@ def forward( f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" ) attn_weights = attn_weights + attention_mask - attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)) # upcast attention to fp32 attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) attn_output = torch.matmul(attn_weights, value_states) if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): @@ -111,9 +148,120 @@ def forward( f" {attn_output.size()}" ) - attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + if self.config.pretraining_tp > 1: + attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2) + o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1) + attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)]) + else: + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class LlamaFlashAttention2WithAdapters(LlamaAttentionMixin, LlamaAttention): + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + # LlamaFlashAttention2 attention does not support output_attentions + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. 
Please make sure use" + " `attention_mask` instead.`" + ) + + # overwrite attention_mask with padding_mask + attention_mask = kwargs.pop("padding_mask") + + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + query_states, key_states, value_states = match_attn_matrices_for_parallel( + query_states, key_states, value_states + ) + (attention_mask,) = adjust_tensors_for_parallel(query_states, attention_mask) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + key_states, value_states, attention_mask = self.prefix_tuning( + key_states, value_states, hidden_states, attention_mask + ) + (query_states,) = adjust_tensors_for_parallel(key_states, query_states) + + # Make adjustments since (parallel) prefix tuning changes the attention mask + kv_seq_len = key_states.shape[-2] + bsz = key_states.shape[0] + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view. + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + dropout_rate = self.attention_dropout if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (LlamaRMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + # Handle the case where the model is quantized + if hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + "The input hidden states seems to be silently casted in float32, this might be related to the fact" + " you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." 
+ ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = self._flash_attention_forward( + query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate + ) + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() attn_output = self.o_proj(attn_output) if not output_attentions: @@ -122,7 +270,105 @@ def forward( return attn_output, attn_weights, past_key_value -class LlamaDecoderLayerWithAdapters(nn.Module, LlamaDecoderLayerMixin): +class LlamaSdpaAttentionWithAdapters(LlamaAttentionMixin, LlamaAttention): + + # Adapted from LlamaAttention.forward + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if output_attentions: + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "LlamaModel is using LlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does" + " not support `output_attentions=True`. Falling back to the manual attention implementation, but" + " specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This" + ' warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + query_states, key_states, value_states = match_attn_matrices_for_parallel( + query_states, key_states, value_states + ) + (attention_mask,) = adjust_tensors_for_parallel(query_states, attention_mask) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + key_states, value_states, attention_mask = self.prefix_tuning( + key_states, value_states, hidden_states, attention_mask + ) + (query_states,) = adjust_tensors_for_parallel(key_states, query_states) + + # Make adjustments since (parallel) prefix tuning changes the attention mask + kv_seq_len = key_states.shape[-2] + bsz = key_states.shape[0] + + if attention_mask is not None: 
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if query_states.device.type == "cuda" and attention_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=attention_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. + is_causal=self.is_causal and attention_mask is None and q_len > 1, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_value + + +class LlamaDecoderLayerWithAdapters(LlamaDecoderLayerMixin, LlamaDecoderLayer): def forward( self, hidden_states: torch.Tensor, diff --git a/tests/test_llama.py b/tests/test_llama.py index 2fd455c174..9cb6fcfda2 100644 --- a/tests/test_llama.py +++ b/tests/test_llama.py @@ -56,7 +56,7 @@ class LlamaAdapterTest( @require_torch -class BertClassConversionTest( +class LlamaClassConversionTest( ModelClassConversionTestMixin, LlamaAdapterTestBase, unittest.TestCase,
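The sketches below are illustrative companions to the patch, not part of the diff. The `_init_weights` guard added in `src/adapters/heads/base.py` depends on the module-level flag that `transformers.modeling_utils.no_init_weights` toggles during `from_pretrained`; importing it inside the method picks up the flag's current value instead of the value at module import time. A minimal sketch of the pattern, assuming transformers >= 4.36 (the helper name `maybe_init_head` is illustrative):

```python
# Minimal sketch: read the `_init_weights` flag at call time so that heads built
# while `no_init_weights` is active (e.g. inside `from_pretrained`) skip random
# re-initialization. `maybe_init_head` is an illustrative helper, not adapters API.
import transformers.modeling_utils as modeling_utils


def maybe_init_head(head_module, init_fn):
    # A top-of-file `from transformers.modeling_utils import _init_weights` would
    # freeze whatever value the flag had at import time; re-reading the module
    # attribute (or re-importing inside the function, as the patch does) sees the
    # value currently set by the `no_init_weights` context manager.
    if modeling_utils._init_weights:
        head_module.apply(init_fn)
```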
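Both `BartSdpaAttentionWithAdapters` and `LlamaSdpaAttentionWithAdapters` hand the score computation to `torch.nn.functional.scaled_dot_product_attention`. A minimal sketch of the call pattern and of the `is_causal` caveat mentioned in the inline comments (all sizes are illustrative):

```python
# Sketch of the SDPA call used by the new attention classes: the fused kernel
# replaces the manual matmul + softmax path. `is_causal` is only set when no
# explicit mask is given, and never for single-token (q_len == 1) decode steps,
# matching AttentionMaskConverter.to_causal_4d.
import torch
import torch.nn.functional as F

bsz, n_heads, q_len, kv_len, head_dim = 2, 4, 3, 3, 16
query = torch.randn(bsz, n_heads, q_len, head_dim)
key = torch.randn(bsz, n_heads, kv_len, head_dim)
value = torch.randn(bsz, n_heads, kv_len, head_dim)
attn_mask = None  # with padding this would be a (bsz, 1, q_len, kv_len) additive mask

attn_output = F.scaled_dot_product_attention(
    query, key, value,
    attn_mask=attn_mask,
    dropout_p=0.0,
    is_causal=attn_mask is None and q_len > 1,
)
assert attn_output.shape == (bsz, n_heads, q_len, head_dim)
```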
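The Llama attention variants now expect the `Cache` objects introduced in transformers 4.36 instead of `(key, value)` tuples: previously cached length is queried with `get_usable_length` and new states are appended with `update`. A minimal sketch of that flow with `DynamicCache`, assuming the 4.36 `cache_utils` API (sizes are illustrative):

```python
# Sketch of the transformers >= 4.36 cache flow the updated Llama attention relies
# on: `get_usable_length` reports how many tokens are already cached, `update`
# appends the new key/value states and returns the concatenated tensors.
import torch
from transformers.cache_utils import DynamicCache

cache = DynamicCache()
layer_idx = 0
bsz, n_kv_heads, head_dim = 1, 4, 8

for step_len in (5, 1, 1):  # prompt, then two single-token decode steps
    k = torch.randn(bsz, n_kv_heads, step_len, head_dim)
    v = torch.randn(bsz, n_kv_heads, step_len, head_dim)
    kv_seq_len = step_len + cache.get_usable_length(step_len, layer_idx)
    k_all, v_all = cache.update(k, v, layer_idx)
    assert k_all.shape[-2] == kv_seq_len  # 5, then 6, then 7
```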
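`repeat_kv` appears in the eager and SDPA Llama paths because Llama-2-style checkpoints can use grouped-query attention: keys and values are projected to `num_key_value_heads` and tiled to `num_heads` before the attention product. A quick shape check (sizes are illustrative):

```python
# Shape-only check of repeat_kv: k/v leave the projection with num_key_value_heads
# and are expanded to num_heads before attention.
import torch
from transformers.models.llama.modeling_llama import repeat_kv

bsz, kv_heads, q_heads, seq_len, head_dim = 2, 2, 8, 5, 16
key_states = torch.randn(bsz, kv_heads, seq_len, head_dim)
key_states = repeat_kv(key_states, q_heads // kv_heads)
assert key_states.shape == (bsz, q_heads, seq_len, head_dim)
```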
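The `pretraining_tp > 1` branch ported into `LlamaAttentionWithAdapters` computes the q/k/v projections slice by slice; concatenating the partial `F.linear` results is numerically equivalent to the single full projection. A small sanity sketch of that equivalence (sizes are illustrative):

```python
# Splitting a projection's output rows into `tp` slices and concatenating the
# per-slice F.linear results reproduces the full projection.
import torch
import torch.nn.functional as F

hidden_size, n_heads, head_dim, tp = 32, 4, 8, 2
hidden_states = torch.randn(2, 5, hidden_size)
q_weight = torch.randn(n_heads * head_dim, hidden_size)

full = F.linear(hidden_states, q_weight)
slices = q_weight.split((n_heads * head_dim) // tp, dim=0)
partial = torch.cat([F.linear(hidden_states, s) for s in slices], dim=-1)
assert torch.allclose(full, partial, atol=1e-6)
```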