
Commit

[Refactor] Integrated some lightllm kernels into token-attention (hpcaitech#4946)

* add some req for inference

* clean codes

* add codes

* add some lightllm deps

* clean codes

* hello

* delete rms files

* add some comments

* add comments

* add doc

* add lightllm deps

* add lightllm chatglm2 kernels

* add lightllm chatglm2 kernels

* replace rotary embedding with lightllm kernel

* add some comments

* add some comments

* add some comments

* add

* replace fwd kernel att1

* fix an arg

* add

* add

* fix token attention

* add some comments

* clean codes

* modify comments

* fix readme

* fix bug

* fix bug

---------

Co-authored-by: cuiqing.li <[email protected]>
Co-authored-by: CjhHa1 <[email protected]>
3 people authored and flybird11111 committed Nov 10, 2023
1 parent b22923b commit 3e81321
Showing 20 changed files with 158 additions and 1,553 deletions.
19 changes: 18 additions & 1 deletion colossalai/inference/README.md
@@ -4,7 +4,7 @@

## Introduction

`Colossal Inference` is a module that contains colossal-ai designed inference framework, featuring high performance, steady and easy usability. `Colossal Inference` incorporated the advantages of the latest open-source inference systems, including TGI, vLLM, FasterTransformer, LightLLM and flash attention. while combining the design of Colossal AI, especially Shardformer, to reduce the learning curve for users.
`Colossal Inference` is a module containing Colossal-AI's inference framework, featuring high performance, stability, and ease of use. `Colossal Inference` incorporates the strengths of the latest open-source inference systems, including LightLLM, TGI, vLLM, FasterTransformer, and flash attention, while combining the design of Colossal-AI, especially Shardformer, to reduce the learning curve for users.

## Design

@@ -62,6 +62,12 @@ triton==2.0.0.dev20221202
vllm
# to install flash-attention, please use commit hash: 67ae6fd74b4bc99c36b2ce524cf139c35663793c
flash-attention

# install lightllm since we depend on lightllm triton kernels
git clone https://github.com/ModelTC/lightllm
cd lightllm
git checkout 28c1267cfca536b7b4f28e921e03de735b003039
pip3 install -e .
```
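
After installing these requirements, a quick import check can confirm that the lightllm Triton kernels this commit relies on are available. The module paths below are the ones referenced elsewhere in this diff; the checking script itself is only an illustrative sketch, not part of the committed code.

```python
# Sanity check (illustrative): verify the lightllm kernels referenced in this commit import cleanly.
def check_lightllm_kernels() -> bool:
    try:
        from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd  # noqa: F401
        from lightllm.models.llama2.triton_kernel.context_flashattention_nopad import (  # noqa: F401
            context_attention_fwd,
        )
        from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward  # noqa: F401
        return True
    except ImportError as err:
        print(f"lightllm kernels unavailable: {err}")
        return False


if __name__ == "__main__":
    print("lightllm kernels found:", check_lightllm_kernels())
```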

### Docker
@@ -73,6 +79,17 @@ You can use `docker run` to set up the environment in a Docker container
docker pull hpcaitech/colossalai-inference:v2
docker run -it --gpus all --name ANY_NAME -v $PWD:/workspace -w /workspace hpcaitech/colossalai-inference:v2 /bin/bash
# enter into docker container
cd /path/to/ColossalAI
pip install -e .
# install lightllm
git clone https://github.com/ModelTC/lightllm
cd lightllm
git checkout 28c1267cfca536b7b4f28e921e03de735b003039
pip3 install -e .
```

### Dive into fast-inference!
3 changes: 2 additions & 1 deletion colossalai/inference/tensor_parallel/batch_infer_state.py
@@ -5,7 +5,7 @@

from .kvcache_manager import MemoryManager


# adapted from: lightllm/server/router/model_infer/infer_batch.py
@dataclass
class BatchInferState:
r"""
@@ -41,6 +41,7 @@ def total_token_num(self):
def set_cache_manager(self, manager: MemoryManager):
self.cache_manager = manager

# adapted from: https://github.com/ModelTC/lightllm/blob/28c1267cfca536b7b4f28e921e03de735b003039/lightllm/common/infer_utils.py#L1
@staticmethod
def init_block_loc(
b_loc: torch.Tensor, seq_len: torch.Tensor, max_len_in_batch: int, alloc_mem_index: torch.Tensor
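
The body of `init_block_loc` is truncated in this view. As a rough sketch of what block-location initialization of this kind does, each request's row of `b_loc` is filled with its allocated KV-cache slot indices, right-aligned to `max_len_in_batch`. The function below is a hypothetical reimplementation for orientation, not the committed code:

```python
import torch


def init_block_loc_sketch(
    b_loc: torch.Tensor,            # (batch_size, max_len_in_batch) table of KV-cache slot indices
    seq_len: torch.Tensor,          # (batch_size,) current length of each sequence
    max_len_in_batch: int,
    alloc_mem_index: torch.Tensor,  # flat tensor of freshly allocated slot indices
) -> None:
    start = 0
    for i, cur_len in enumerate(seq_len.tolist()):
        # Right-align each sequence so its last token sits in the final column.
        b_loc[i, max_len_in_batch - cur_len : max_len_in_batch] = alloc_mem_index[start : start + cur_len]
        start += cur_len
```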
10 changes: 6 additions & 4 deletions colossalai/inference/tensor_parallel/kvcache_manager.py
@@ -1,7 +1,9 @@
# Adapted from lightllm/common/mem_manager.py
# of the ModelTC/lightllm GitHub repository
# https://github.com/ModelTC/lightllm/blob/050af3ce65edca617e2f30ec2479397d5bb248c9/lightllm/common/mem_manager.py

"""
Referred to / modified from lightllm/common/mem_manager.py
of the ModelTC/lightllm GitHub repository
https://github.com/ModelTC/lightllm/blob/050af3ce65edca617e2f30ec2479397d5bb248c9/lightllm/common/mem_manager.py
We slightly changed it to make it suitable for our Colossal-AI Shardformer TP-engine design.
"""
import torch
from transformers.utils import logging

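
For orientation, the attention kernels used in this commit address the KV cache through a flat pool of per-token slots handed out by the memory manager. Below is a minimal, hypothetical slot allocator illustrating that idea; the class and method names are illustrative and are not the adapted `MemoryManager` API.

```python
import torch


class NaiveKVCacheSlotPool:
    """Hands out indices into a flat pool of KV-cache token slots (illustrative only)."""

    def __init__(self, size: int):
        self.free_mask = torch.ones(size, dtype=torch.bool)

    def alloc(self, need: int) -> torch.Tensor:
        free_idx = torch.nonzero(self.free_mask).flatten()
        assert free_idx.numel() >= need, "KV cache pool exhausted"
        chosen = free_idx[:need]
        self.free_mask[chosen] = False
        return chosen

    def free(self, idx: torch.Tensor) -> None:
        self.free_mask[idx] = True
```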
18 changes: 12 additions & 6 deletions colossalai/inference/tensor_parallel/modeling/chatglm2.py
@@ -6,8 +6,6 @@
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast

from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState
from colossalai.kernel.triton.context_attention import llama2_context_attn_fwd
from colossalai.kernel.triton.rotary_embedding_kernel import Llama2Forwards
from colossalai.kernel.triton.token_attention_kernel import Llama2TokenAttentionForwards
from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import (
ChatGLMForConditionalGeneration,
@@ -20,6 +18,14 @@

from ._utils import copy_kv_to_mem_cache

try:
    from lightllm.models.llama2.triton_kernel.context_flashattention_nopad import (
        context_attention_fwd as lightllm_llama2_context_attention_fwd,
    )
    from lightllm.models.chatglm2.triton_kernel.rotary_emb import rotary_emb_fwd as chatglm2_rotary_emb_fwd

    HAS_LIGHTLLM_KERNEL = True
except ImportError:
    print("please install lightllm from source to run inference: https://github.com/ModelTC/lightllm")
    HAS_LIGHTLLM_KERNEL = False


# This function is the same as the Llama model's init_to_get_rotary; we should move both into _utils.py
def _init_to_get_rotary(self, base=10000):
@@ -433,17 +439,17 @@ def chatglm_flash_attn_kvcache_forward(

cos, sin = infer_state.position_cos, infer_state.position_sin

Llama2Forwards.rotary_emb_fwd(
chatglm2_rotary_emb_fwd(
query_layer.view(-1, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head), cos, sin
)
if self.multi_query_attention:
Llama2Forwards.rotary_emb_fwd(
chatglm2_rotary_emb_fwd(
key_layer.view(-1, self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head),
cos,
sin,
)
else:
Llama2Forwards.rotary_emb_fwd(
chatglm2_rotary_emb_fwd(
key_layer.view(-1, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head),
cos,
sin,
Expand Down Expand Up @@ -474,7 +480,7 @@ def chatglm_flash_attn_kvcache_forward(
attn_output = torch.empty_like(query_layer.view(-1, self.projection_size))

# NOTE: no bug in context attn fwd (this note can be removed)
llama2_context_attn_fwd(
lightllm_llama2_context_attention_fwd(
query_layer,
key_layer,
value_layer,
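
The rotary-embedding kernel swapped in above mutates the query/key views in place, given the per-token `cos`/`sin` tables from `infer_state`. For reference, a plain PyTorch version of one common RoPE formulation is sketched below (the LLaMA-style half-split convention; the exact pairing used by lightllm's chatglm2 kernel may differ, so treat this as a conceptual reference only):

```python
import torch


def rope_reference(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    """Rotary embedding for x of shape (num_tokens, num_heads, head_dim).

    cos/sin have shape (num_tokens, head_dim // 2); half-split ("rotate half") convention.
    """
    half = x.shape[-1] // 2
    x0, x1 = x[..., :half], x[..., half:]
    cos = cos[:, None, :]  # broadcast over the head dimension
    sin = sin[:, None, :]
    return torch.cat((x0 * cos - x1 * sin, x0 * sin + x1 * cos), dim=-1)
```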
25 changes: 16 additions & 9 deletions colossalai/inference/tensor_parallel/modeling/llama.py
@@ -5,12 +5,7 @@
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel, LlamaRMSNorm

from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState
from colossalai.kernel.triton import (
llama2_context_attn_fwd,
llama_context_attn_fwd,
rotary_embedding_fwd,
token_attention_fwd,
)
from colossalai.kernel.triton import llama_context_attn_fwd, token_attention_fwd
from colossalai.kernel.triton.token_attention_kernel import Llama2TokenAttentionForwards

from ._utils import copy_kv_to_mem_cache
@@ -29,6 +24,17 @@
)
HAS_VLLM_KERNERL = False

try:
    from lightllm.models.llama2.triton_kernel.context_flashattention_nopad import (
        context_attention_fwd as lightllm_llama2_context_attention_fwd,
    )
    from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd as llama_rotary_embedding_fwd

    HAS_LIGHTLLM_KERNEL = True
except ImportError:
    print("please install lightllm from source to run inference: https://github.com/ModelTC/lightllm")
    HAS_LIGHTLLM_KERNEL = False


def rotate_half(x):
"""Rotates half the hidden dims of the input."""
@@ -280,8 +286,8 @@ def llama_flash_attn_kvcache_forward(
cos, sin = infer_state.position_cos, infer_state.position_sin
# print("shape ", cos.shape, query_states.view(-1, self.num_heads, self.head_dim).shape, )

rotary_embedding_fwd(query_states.view(-1, self.num_heads, self.head_dim), cos, sin)
rotary_embedding_fwd(key_states.view(-1, self.num_key_value_heads, self.head_dim), cos, sin)
llama_rotary_embedding_fwd(query_states.view(-1, self.num_heads, self.head_dim), cos, sin)
llama_rotary_embedding_fwd(key_states.view(-1, self.num_key_value_heads, self.head_dim), cos, sin)

query_states = query_states.reshape(-1, self.num_heads, self.head_dim)
key_states = key_states.reshape(-1, self.num_key_value_heads, self.head_dim)
@@ -312,7 +318,7 @@ def llama_flash_attn_kvcache_forward(
infer_state.cache_manager.past_key_values_length,
)
else:
llama2_context_attn_fwd(
lightllm_llama2_context_attention_fwd(
query_states,
key_states,
value_states,
@@ -371,6 +377,7 @@ def llama_flash_attn_kvcache_forward(
infer_state.cache_manager.past_key_values_length,
infer_state.other_kv_index,
)

attn_output = attn_output.view(bsz, q_len, self.hidden_size)

attn_output = self.o_proj(attn_output)
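
The token-attention path above attends each sequence's single new query token against its cached keys and values. Conceptually, for one sequence and ignoring the non-padded batching and `b_loc` indexing that the Triton kernel handles, it computes the following (a naive PyTorch reference, not the kernel itself):

```python
import math

import torch


def token_attention_reference(q: torch.Tensor, k_cache: torch.Tensor, v_cache: torch.Tensor) -> torch.Tensor:
    """q: (num_heads, head_dim); k_cache, v_cache: (seq_len, num_heads, head_dim)."""
    head_dim = q.shape[-1]
    # Scores of the single decode token against every cached key: (num_heads, seq_len)
    scores = torch.einsum("hd,shd->hs", q, k_cache) / math.sqrt(head_dim)
    probs = torch.softmax(scores, dim=-1)
    # Weighted sum of cached values: (num_heads, head_dim)
    return torch.einsum("hs,shd->hd", probs, v_cache)
```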
6 changes: 2 additions & 4 deletions colossalai/inference/tensor_parallel/policies/llama.py
@@ -12,8 +12,7 @@
from ..modeling.llama import LlamaInferenceForwards, get_llama_vllm_rmsnorm_forward

try:
from colossalai.kernel.triton import rmsnorm_forward

from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward as lightllm_rmsnorm_forward
HAS_TRITON_RMSNORM = True
except:
print("you should install triton from https://github.com/openai/triton")
@@ -22,9 +21,8 @@

def get_triton_rmsnorm_forward():
if HAS_TRITON_RMSNORM:

def _triton_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor):
return rmsnorm_forward(hidden_states, self.weight.data, self.variance_epsilon)
return lightllm_rmsnorm_forward(hidden_states, self.weight.data, self.variance_epsilon)

return _triton_rmsnorm_forward
else:
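
The `rmsnorm_forward(hidden_states, weight, eps)` call wired in above computes standard RMSNorm. A plain PyTorch equivalent is shown for comparison (a reference sketch under that assumption, not the Triton kernel):

```python
import torch


def rmsnorm_reference(hidden_states: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    # Normalize by the root mean square over the last dimension, then scale by the learned weight.
    variance = hidden_states.float().pow(2).mean(dim=-1, keepdim=True)
    normed = hidden_states.float() * torch.rsqrt(variance + eps)
    return (weight * normed).to(hidden_states.dtype)
```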
7 changes: 1 addition & 6 deletions colossalai/kernel/triton/__init__.py
@@ -9,26 +9,21 @@

# Import errors may occur even if triton is installed.
if HAS_TRITON:
from .context_attention import bloom_context_attn_fwd, llama2_context_attn_fwd, llama_context_attn_fwd
from .context_attention import bloom_context_attn_fwd, llama_context_attn_fwd
from .copy_kv_cache_dest import copy_kv_cache_to_dest
from .fused_layernorm import layer_norm
from .gptq_triton import gptq_fused_linear_triton
from .int8_rotary_embedding_kernel import int8_rotary_embedding_fwd
from .rms_norm import rmsnorm_forward
from .rotary_embedding_kernel import rotary_embedding_fwd
from .smooth_attention import smooth_llama_context_attn_fwd, smooth_token_attention_fwd
from .softmax import softmax
from .token_attention_kernel import token_attention_fwd

__all__ = [
"llama_context_attn_fwd",
"llama2_context_attn_fwd",
"bloom_context_attn_fwd",
"softmax",
"layer_norm",
"rmsnorm_forward",
"copy_kv_cache_to_dest",
"rotary_embedding_fwd",
"token_attention_fwd",
"gptq_fused_linear_triton",
"int8_rotary_embedding_fwd",