
Commit

[format] applied code formatting on changed files in pull request 4908 (hpcaitech#4918)

Co-authored-by: github-actions <[email protected]>
2 people authored and flybird11111 committed Nov 10, 2023
1 parent 0f976c4 commit e8d9da2
Showing 1 changed file with 7 additions and 12 deletions.
@@ -6,25 +6,20 @@

import torch
import torch.nn.functional as F
+from einops import rearrange
+from flash_attn.bert_padding import pad_input, unpad_input
+from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_kvpacked_func
+from flash_attn.ops.rms_norm import rms_norm
from transformers.models.llama.modeling_llama import (
-    LlamaRMSNorm,
     LlamaAttention,
-    LlamaModel,
     LlamaForCausalLM,
+    LlamaModel,
+    LlamaRMSNorm,
     apply_rotary_pos_emb,
     repeat_kv,
)

from colossalai.logging import get_dist_logger
-from einops import rearrange
-
-from flash_attn.bert_padding import pad_input, unpad_input
-from flash_attn.flash_attn_interface import (
-    flash_attn_func,
-    flash_attn_varlen_kvpacked_func,
-)
-from flash_attn.ops.rms_norm import rms_norm


logger = get_dist_logger()

@@ -65,7 +60,7 @@ def attention_forward(
    past_key_value: Optional[Tuple[torch.Tensor]] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
-    **kwargs
+    **kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    """
    Re-define LLaMA-2 `LlamaAttention` forward method using flash-attention.
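For readers landing on this diff, the sketch below illustrates the general shape of a flash-attention override of `LlamaAttention.forward` such as the `attention_forward` touched here. It is a simplified, hypothetical example rather than the actual function in this file: it assumes fp16/bf16 CUDA tensors, skips attention-mask/padding handling and the KV cache, and relies on attribute names (`q_proj`, `rotary_emb`, `num_key_value_groups`, ...) from the late-2023 Hugging Face LLaMA module; `sketch_attention_forward` is an invented name.

# Minimal, illustrative sketch (NOT the ColossalAI implementation in this file).
from typing import Optional, Tuple

import torch
from flash_attn.flash_attn_interface import flash_attn_func
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv


def sketch_attention_forward(
    self,
    hidden_states: torch.Tensor,
    position_ids: Optional[torch.LongTensor] = None,
    **kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    bsz, q_len, _ = hidden_states.size()
    if position_ids is None:
        position_ids = torch.arange(q_len, device=hidden_states.device).unsqueeze(0)

    # Project to q/k/v and split heads: (bsz, n_heads, q_len, head_dim).
    q = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
    k = self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
    v = self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

    # Apply rotary position embeddings, then expand grouped KV heads to match
    # the number of query heads (GQA -> full multi-head layout).
    cos, sin = self.rotary_emb(v, seq_len=q_len)
    q, k = apply_rotary_pos_emb(q, k, cos, sin, position_ids)
    k = repeat_kv(k, self.num_key_value_groups)
    v = repeat_kv(v, self.num_key_value_groups)

    # flash_attn_func expects (bsz, seq_len, n_heads, head_dim); the fused
    # kernel replaces the explicit softmax(QK^T / sqrt(d)) V computation.
    q, k, v = (t.transpose(1, 2).contiguous() for t in (q, k, v))
    attn_output = flash_attn_func(q, k, v, dropout_p=0.0, causal=True)

    attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim)
    return self.o_proj(attn_output), None, None

Judging by the imports kept in this file, the real `attention_forward` additionally handles padded, variable-length batches via `unpad_input`/`pad_input` and `flash_attn_varlen_kvpacked_func`; the sketch covers only the dense, purely causal case. Patching would then amount to assigning `LlamaAttention.forward = sketch_attention_forward`.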
