
Commit

update, using sep_group
zhangyuqin1998 committed May 31, 2024
1 parent 63b2be8 commit ab562b7
Showing 7 changed files with 242 additions and 817 deletions.
20 changes: 13 additions & 7 deletions paddlenlp/transformers/llama/fusion_ops.py
@@ -52,9 +52,18 @@ def swiglu(x, y=None):
flash_attention = None

from paddlenlp.transformers.ring_flash_attention import RingFlashAttention
from paddlenlp.transformers.context_parallel_utils import split_inputs_sequence_dim_load_balance

def fusion_rope(query_states, key_states, value_states, hidden_states, position_ids, past_key_value, rotary_emb, cp_parallel_degree=-1):

def fusion_rope(
query_states,
key_states,
value_states,
hidden_states,
position_ids,
past_key_value,
rotary_emb,
cp_parallel_degree=-1,
):
if get_env_device() != "gcu":
assert past_key_value is None, "fuse rotary not support cache kv for now"
batch_size, seq_length, num_heads, head_dim = query_states.shape
@@ -64,9 +73,6 @@ def fusion_rope(query_states, key_states, value_states, hidden_states, position_
kv_seq_len *= cp_parallel_degree

if get_env_device() != "gcu":
cos, sin = rotary_emb(value_states, seq_len=kv_seq_len)
if cp_parallel_degree > 1:
cos = split_inputs_sequence_dim_load_balance(cos)
sin = split_inputs_sequence_dim_load_balance(sin)
if get_env_device() == "npu":
query_states = core.eager._run_custom_op("fused_rope", query_states, cos, sin)[0]
key_states = core.eager._run_custom_op("fused_rope", key_states, cos, sin)[0]
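
Note on the hunks above: the rotary tables are built for the global sequence length (hence kv_seq_len *= cp_parallel_degree) and then cut back down so each context-parallel rank keeps only the cos/sin rows matching its local shard. The diff does not show how split_inputs_sequence_dim_load_balance lays that shard out; the sketch below assumes the load-balanced ("zigzag") split commonly paired with ring attention, and every name in it other than cp_parallel_degree is illustrative, not taken from PaddleNLP.

import numpy as np

def split_sequence_load_balance(x, cp_degree, rank, axis=1):
    # Assumed layout: cut the sequence axis into 2 * cp_degree equal chunks and
    # give rank r the pair (r, 2 * cp_degree - 1 - r), mixing early and late
    # positions so causal-attention work is balanced across ranks.
    chunks = np.split(x, 2 * cp_degree, axis=axis)
    return np.concatenate([chunks[rank], chunks[2 * cp_degree - 1 - rank]], axis=axis)

# Toy check with cp_degree=2 and a global sequence of 8 positions: the rotary
# table is built for all 8 positions, then re-split so rank 0's cos rows line
# up with the query/key shard that rank already holds.
cos_full = np.arange(8).reshape(1, 8, 1, 1)   # stand-in for a [1, seq, 1, dim] cos table
cos_rank0 = split_sequence_load_balance(cos_full, cp_degree=2, rank=0)
print(cos_rank0.squeeze())                    # [0 1 6 7] under the assumed layout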
@@ -165,7 +171,7 @@ def fusion_flash_attention(
attention_mask = attention_mask.cast(alibi.dtype) + alibi
if get_env_device() == "npu":
if config.cp_parallel_degree > 1:
raise ValueError(f"Context parallel is not implemented for npu")
raise ValueError("Context parallel is not implemented for npu")

attn_output = core.eager._run_custom_op(
"flash_attention_npu",
query_states,
…
)[0]
elif get_env_device() == "gcu":
if config.cp_parallel_degree > 1:
raise ValueError(f"Context parallel is not implemented for gcu")
raise ValueError("Context parallel is not implemented for gcu")

attn_output = core.eager._run_custom_op(
"fused_sdp_flash_attention_gcu",
query_states,
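
The two hunks above also drop the unnecessary f-prefixes from the constant error messages and guard the custom-op backends: ring/context parallelism is only wired into the default flash-attention path, so NPU and GCU reject cp_parallel_degree > 1 outright. A minimal sketch of that guard pattern follows; the helper name and device strings are assumptions, not the module's real dispatch code.

def check_context_parallel_support(device: str, cp_parallel_degree: int) -> None:
    # Mirrors the guards above: better to fail loudly than to run attention
    # over one rank's sequence shard as if it were the full sequence.
    if cp_parallel_degree > 1 and device in ("npu", "gcu"):
        raise ValueError(f"Context parallel is not implemented for {device}")

check_context_parallel_support("gpu", cp_parallel_degree=2)    # passes
# check_context_parallel_support("npu", cp_parallel_degree=2)  # raises ValueError
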
4 changes: 1 addition & 3 deletions paddlenlp/transformers/llama/modeling.py
@@ -99,7 +99,6 @@ def swiglu(x, y=None):
]



def _get_interleave(n):
def _get_interleave_power_of_2(n):
start = 2 ** (-(2 ** -(math.log2(n) - 3)))
@@ -956,7 +955,7 @@ def forward(
position_ids,
past_key_value,
self.rotary_emb,
self.cp_parallel_degree
self.config.cp_parallel_degree,
)

else:
@@ -972,7 +971,6 @@ def forward(
)
else:
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)

query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

# [bs, seq_len, num_head, head_dim]