Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Compute flash data layout info once and for all when possible #4

Open
wants to merge 2 commits into
base: add-flash-attn-2
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 36 additions & 12 deletions src/transformers/models/llama/modeling_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
# limitations under the License.
""" PyTorch LLaMA model."""
import math
from typing import List, Optional, Tuple, Union
from typing import List, Optional, Tuple, Union, Dict

import torch
import torch.nn.functional as F
Expand Down Expand Up @@ -323,6 +323,7 @@ def forward(
output_attentions: bool = False,
use_cache: bool = False,
padding_mask: Optional[torch.LongTensor] = None,
flash_kwargs: None = None
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()

Expand Down Expand Up @@ -478,6 +479,7 @@ def forward(
output_attentions: bool = False,
use_cache: bool = False,
padding_mask: Optional[torch.LongTensor] = None,
flash_kwargs: Optional[Dict] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
# LlamaFlashAttention attention does not support output_attentions
output_attentions = False
Expand Down Expand Up @@ -519,9 +521,12 @@ def forward(
# when training.
dropout_rate = 0.0 # if not self.training else self.attn_dropout

# contains at least one padding token
if padding_mask is not None:
indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(padding_mask)
# contains at least one masked token
if flash_kwargs["masking"]:
indices_k = flash_kwargs["indices_k"]
cu_seqlens_k = flash_kwargs["cu_seqlens_k"]
max_seqlen_in_batch_k = flash_kwargs["max_seqlen_in_batch_k"]

key_states = index_first_axis(rearrange(key_states, "b s ... -> (b s) ..."), indices_k)
value_states = index_first_axis(rearrange(value_states, "b s ... -> (b s) ..."), indices_k)

Expand All @@ -533,11 +538,9 @@ def forward(
indices_q = indices_k
elif q_len == 1:
max_seqlen_in_batch_q = 1
cu_seqlens_q = torch.arange(
bsz + 1, dtype=torch.int32, device=query_states.device
) # There is a memcpy here, that is very bad.
indices_q = cu_seqlens_q[:-1]
query_states = query_states.squeeze(1)
cu_seqlens_q = flash_kwargs["cu_seqlens_q"]
indices_q = flash_kwargs["indices_q"]
query_states = query_states.squeeze(1) # [batch_size, 1, num_heads, head_dim] -> [batch_size, num_heads, head_dim]
else:
# The -q_len: slice assumes left padding.
padding_mask = padding_mask[:, -q_len:]
Expand Down Expand Up @@ -591,6 +594,7 @@ def forward(
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
padding_mask: Optional[torch.LongTensor] = None,
flash_kwargs: Optional[Dict] = None,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
Args:
Expand Down Expand Up @@ -619,6 +623,7 @@ def forward(
output_attentions=output_attentions,
use_cache=use_cache,
padding_mask=padding_mask,
flash_kwargs=flash_kwargs,
)
hidden_states = residual + hidden_states

Expand Down Expand Up @@ -770,6 +775,7 @@ def __init__(self, config: LlamaConfig):
self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

self.gradient_checkpointing = False
self._flash = getattr(config, "_flash_attn_2_enabled", False)
# Initialize weights and apply final processing
self.post_init()

Expand Down Expand Up @@ -864,9 +870,26 @@ def forward(
else:
padding_mask = None

attention_mask = self._prepare_decoder_attention_mask(
attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
)

flash_kwargs = None
if not self._flash:
attention_mask = self._prepare_decoder_attention_mask(
attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
)
else:
flash_kwargs = {}
flash_kwargs["masking"] = padding_mask is not None

if padding_mask is not None:
indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(padding_mask)
flash_kwargs["indices_k"] = indices_k
flash_kwargs["cu_seqlens_k"] = cu_seqlens_k
flash_kwargs["max_seqlen_in_batch_k"] = max_seqlen_in_batch_k
if seq_length == 1:
flash_kwargs["cu_seqlens_q"] = torch.arange(
batch_size + 1, dtype=torch.int32, device=input_ids.device
) # There is a memcpy here, that is very bad. At least happening only once.
flash_kwargs["indices_q"] = flash_kwargs["cu_seqlens_q"][:-1]

hidden_states = inputs_embeds

Expand Down Expand Up @@ -909,6 +932,7 @@ def custom_forward(*inputs):
output_attentions=output_attentions,
use_cache=use_cache,
padding_mask=padding_mask,
flash_kwargs=flash_kwargs,
)

hidden_states = layer_outputs[0]
Expand Down
Loading