From a41cf88e9b3e4efc93d97b62bba2a1d8cd48141d Mon Sep 17 00:00:00 2001
From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com>
Date: Tue, 17 Oct 2023 10:48:24 +0800
Subject: [PATCH] [format] applied code formatting on changed files in pull request 4908 (#4918)

Co-authored-by: github-actions
---
 .../utils/flash_attention_patch.py | 19 +++++++------------
 1 file changed, 7 insertions(+), 12 deletions(-)

diff --git a/applications/Colossal-LLaMA-2/colossal_llama2/utils/flash_attention_patch.py b/applications/Colossal-LLaMA-2/colossal_llama2/utils/flash_attention_patch.py
index 111659b2d928..1926ec78aba8 100644
--- a/applications/Colossal-LLaMA-2/colossal_llama2/utils/flash_attention_patch.py
+++ b/applications/Colossal-LLaMA-2/colossal_llama2/utils/flash_attention_patch.py
@@ -6,25 +6,20 @@
 
 import torch
 import torch.nn.functional as F
+from einops import rearrange
+from flash_attn.bert_padding import pad_input, unpad_input
+from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_kvpacked_func
+from flash_attn.ops.rms_norm import rms_norm
 from transformers.models.llama.modeling_llama import (
-    LlamaRMSNorm,
     LlamaAttention,
-    LlamaModel,
     LlamaForCausalLM,
+    LlamaModel,
+    LlamaRMSNorm,
     apply_rotary_pos_emb,
     repeat_kv,
 )
 
 from colossalai.logging import get_dist_logger
-from einops import rearrange
-
-from flash_attn.bert_padding import pad_input, unpad_input
-from flash_attn.flash_attn_interface import (
-    flash_attn_func,
-    flash_attn_varlen_kvpacked_func,
-)
-from flash_attn.ops.rms_norm import rms_norm
-
 
 logger = get_dist_logger()
 
@@ -65,7 +60,7 @@ def attention_forward(
     past_key_value: Optional[Tuple[torch.Tensor]] = None,
     output_attentions: bool = False,
     use_cache: bool = False,
-    **kwargs
+    **kwargs,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
     """
     Re-define LLaMA-2 `LlamaAttention` forward method using flash-attention.
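
For context, the file touched by this patch redefines the LlamaAttention forward pass with flash-attention kernels (hence the flash_attn_func and flash_attn_varlen_kvpacked_func imports); the hunks above only regroup imports and add a trailing comma after **kwargs. Below is a minimal sketch of how such a patched forward is typically bound onto a model. The helper name apply_attention_patch and its wiring are illustrative assumptions, not code from this PR.

# Minimal sketch (assumption): bind a patched forward onto every LlamaAttention
# module via MethodType, replacing the stock attention implementation in place.
from types import MethodType

from transformers.models.llama.modeling_llama import LlamaAttention, LlamaForCausalLM


def apply_attention_patch(model: LlamaForCausalLM, attention_forward) -> None:
    """Monkey-patch each LlamaAttention module with the given forward function."""
    for module in model.modules():
        if isinstance(module, LlamaAttention):
            # MethodType binds the function so `self` resolves to this module,
            # and the instance attribute shadows the class-level forward.
            module.forward = MethodType(attention_forward, module)

A caller would invoke apply_attention_patch(model, attention_forward) once, after loading the checkpoint and before training or inference, so every attention layer routes through the flash-attention path.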