Commit: reset shardformer llama
Xu-Kai committed Aug 31, 2023
1 parent 9ced713 commit e79cd58
Showing 2 changed files with 13 additions and 26 deletions.
20 changes: 12 additions & 8 deletions colossalai/inference/tensor_parallel/modeling/llama.py
@@ -4,11 +4,11 @@
import torch
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.models.llama.modeling_llama import (
- LlamaAttention,
- LlamaDecoderLayer,
- LlamaModel,
- apply_rotary_pos_emb,
- LlamaRMSNorm
+ LlamaAttention,
+ LlamaDecoderLayer,
+ LlamaModel,
+ LlamaRMSNorm,
+ apply_rotary_pos_emb,
)

from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState
@@ -17,7 +17,7 @@
from colossalai.kernel.triton.token_attention_kernel import token_attention_fwd

try:
- from vllm import pos_encoding_ops, layernorm_ops
+ from vllm import layernorm_ops, pos_encoding_ops
rms_norm = layernorm_ops.rms_norm
rotary_embedding_neox = pos_encoding_ops.rotary_embedding_neox
HAS_VLLM_KERNERL = True
@@ -255,7 +255,9 @@ def llama_flash_attn_kvcache_forward(
if HAS_VLLM_KERNERL:
cos_sin_cache = torch.cat((cos, sin), dim=-1)
rotary_embedding_neox(position_ids, query_states, key_states_transposed, self.head_dim, cos_sin_cache)
key_states = key_states_transposed.transpose(1, 2)
else:
# TODO: there are some issues with the original rotary_embedding_neox of huggingface
query_states, key_states = apply_rotary_pos_emb(query_states, key_states_transposed, cos, sin, position_ids)

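For reference, the else branch falls back to the HuggingFace rotary position embedding, which rotates query/key vectors with the standard rotate_half formulation. A minimal sketch of that computation (illustrative only; names and shapes are assumptions, not the exact transformers or ColossalAI code):

import torch

def rotate_half(x):
    # Split the head dimension in half, swap the halves, and negate the second half.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb_sketch(q, k, cos, sin, position_ids):
    # cos/sin are per-position tables; q/k are assumed to be [batch, heads, seq_len, head_dim].
    cos = cos[position_ids].unsqueeze(1)
    sin = sin[position_ids].unsqueeze(1)
    q_rot = (q * cos) + (rotate_half(q) * sin)
    k_rot = (k * cos) + (rotate_half(k) * sin)
    return q_rot, k_rot

The vllm branch above appears to apply the equivalent neox-style rotation in place, reading the concatenated cos_sin_cache instead of separate cos and sin tensors.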
@@ -313,9 +315,11 @@ def _copy_kv_to_mem_cache(layer_id, key_buffer, value_buffer, context_mem_index, mem_manager):
# return past_key_value as None
return attn_output, None, None


def get_llama_vllm_rmsnorm_forward():

if HAS_VLLM_KERNERL:

def _vllm_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor):
x = hidden_states
out = torch.empty_like(x)
@@ -330,4 +334,4 @@ def _vllm_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor):

return _vllm_rmsnorm_forward
else:
return None
return None
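For context, the fused vllm rms_norm bound in the try block above writes its result into the preallocated out buffer. The unfused computation it replaces is roughly the standard LLaMA RMSNorm, sketched here as a reference (not the fused kernel):

import torch

def rmsnorm_reference(hidden_states: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Normalize by the root mean square over the last dimension, then scale by the learned weight.
    input_dtype = hidden_states.dtype
    hidden_states = hidden_states.to(torch.float32)
    variance = hidden_states.pow(2).mean(-1, keepdim=True)
    hidden_states = hidden_states * torch.rsqrt(variance + eps)
    return weight * hidden_states.to(input_dtype)

Returning None when the kernel is unavailable presumably lets the caller keep the default LlamaRMSNorm.forward instead of swapping in the fused version.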
19 changes: 1 addition & 18 deletions colossalai/shardformer/modeling/llama.py
@@ -19,18 +19,6 @@
from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention
from colossalai.pipeline.stage_manager import PipelineStageManager

- try:
- from vllm import pos_encoding_ops
- rotary_embedding_neox = pos_encoding_ops.rotary_embedding_neox
- HAS_VLLM_KERNERL = True
- except:
- print("fall back to original rotary_embedding_neox of huggingface")
- print("install vllm from https://github.com/vllm-project/vllm to accelerate your inference")
- print(
- "if falied to install vllm, please use this branch to install: https://github.com/tiandiao123/vllm/tree/setup_branch"
- )
- HAS_VLLM_KERNERL = False


class LlamaPipelineForwards:
'''
@@ -434,11 +422,7 @@ def forward(

cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)

- if HAS_VLLM_KERNERL:
- cos_sin_cache = torch.cat((cos, sin), dim=-1)
- rotary_embedding_neox(position_ids, query_states, key_states, self.head_dim, cos_sin_cache)
- else:
- query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

if past_key_value is not None:
# reuse k, v, self_attention
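The "reuse k, v" step referenced above follows the usual HuggingFace pattern of concatenating cached keys and values with the current step's along the sequence dimension. A minimal illustrative sketch (assuming [batch, heads, seq_len, head_dim] tensors; not the exact ColossalAI code):

import torch

def append_kv_cache(key_states, value_states, past_key_value):
    # Prepend keys/values cached from earlier decoding steps to the freshly computed ones.
    if past_key_value is not None:
        key_states = torch.cat([past_key_value[0], key_states], dim=2)
        value_states = torch.cat([past_key_value[1], value_states], dim=2)
    return key_states, value_states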
@@ -473,4 +457,3 @@ def forward(
return attn_output, None, past_key_value

return forward
