[Kernels] Update Triton kernels to 2.1.0 and add flash-decoding for llama token attention (hpcaitech#4965)

* adding flash-decoding

* clean

* adding kernel

* adding flash-decoding

* add integration

* add

* adding kernel

* adding kernel

* adding triton 2.1.0 features for inference

* update bloom triton kernel

* remove useless vllm kernels

* clean codes

* fix

* adding files

* fix readme

* update llama flash-decoding

---------

Co-authored-by: cuiqing.li <[email protected]>
2 people authored and flybird11111 committed Nov 9, 2023
1 parent 9fce43b commit 62eb99f
Showing 13 changed files with 226 additions and 374 deletions.
15 changes: 13 additions & 2 deletions colossalai/inference/README.md
@@ -34,11 +34,13 @@ In this section we discuss how the colossal inference works and integrates with
- [x] policy
- [x] context forward
- [x] token forward
- [x] Support flash-decoding
- [ ] Replace the kernels with `faster-transformer` in token-forward stage
- [ ] Support all models
- [x] Llama
- [x] Llama-2
- [x] Bloom
- [ ] Chatglm2
- [x] Chatglm2
- [ ] Benchmarking for all models

## Get started
@@ -68,6 +70,12 @@ git clone https://github.com/ModelTC/lightllm
git checkout 28c1267cfca536b7b4f28e921e03de735b003039
cd lightllm
pip3 install -e .

# also, install xformers from source:
pip install ninja
# Set TORCH_CUDA_ARCH_LIST if running and building on different GPU types
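# e.g. export TORCH_CUDA_ARCH_LIST="8.0"  # illustrative value for A100-class GPUs; use your target GPU's compute capability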
pip install -v -U git+https://github.com/facebookresearch/xformers.git@main#egg=xformers

```

### Docker
@@ -89,7 +97,10 @@ git checkout 28c1267cfca536b7b4f28e921e03de735b003039
cd lightllm
pip3 install -e .
# install xformers from source
pip install ninja
# Set TORCH_CUDA_ARCH_LIST if running and building on different GPU types
pip install -v -U git+https://github.com/facebookresearch/xformers.git@main#egg=xformers
```
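
If the build succeeds, a quick sanity check (not part of the original README) is to import the package and confirm CUDA is visible; `python -m xformers.info` prints the full build report, including which fused attention kernels are available:

```python
# hypothetical post-install check, assuming the steps above completed without errors
import torch
import xformers

print(xformers.__version__, torch.cuda.is_available())
```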

### Dive into fast-inference!
1 change: 1 addition & 0 deletions colossalai/inference/tensor_parallel/engine.py
@@ -311,6 +311,7 @@ def prepare_batch_state(self, inputs) -> BatchInferState:
            seq_start_indexes[i] = start_index
            start_index += curr_seq_len
            max_len_in_batch = curr_seq_len if curr_seq_len > max_len_in_batch else max_len_in_batch

        block_loc = torch.empty((batch_size, self.max_input_len + self.max_output_len), dtype=torch.long, device="cuda")
        batch_infer_state = BatchInferState(batch_size, max_len_in_batch)
        batch_infer_state.seq_len = seq_lengths.to("cuda")
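For reference, the hunk above only does packed-layout bookkeeping: each sequence's start offset is the running sum of the previous lengths, and `max_len_in_batch` tracks the longest sequence. A minimal sketch of the same computation with tensor ops (variable values are illustrative, not from the repo):

```python
import torch

# per-sequence prompt lengths for one batch (illustrative values)
seq_lengths = torch.tensor([5, 9, 3])

# start offset of each sequence in the packed layout, and the longest sequence in the batch
seq_start_indexes = torch.cumsum(seq_lengths, dim=0) - seq_lengths
max_len_in_batch = int(seq_lengths.max())

print(seq_start_indexes.tolist(), max_len_in_batch)  # [0, 5, 14] 9
```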
11 changes: 10 additions & 1 deletion colossalai/inference/tensor_parallel/modeling/bloom.py
@@ -19,6 +19,12 @@
from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState
from colossalai.kernel.triton import bloom_context_attn_fwd, copy_kv_cache_to_dest, token_attention_fwd

try:
    from lightllm.models.bloom.triton_kernel.context_flashattention_nopad import context_attention_fwd as lightllm_bloom_context_attention_fwd
    HAS_LIGHTLLM_KERNEL = True
except:
    HAS_LIGHTLLM_KERNEL = False


def generate_alibi(n_head, dtype=torch.float16):
"""
@@ -460,7 +466,10 @@ def bloom_attention_forward(
        # output = self.output[:batch_size*q_length, :, :]
        output = torch.empty_like(q)

        bloom_context_attn_fwd(q, k, v, output, b_start_loc, b_seq_len, max_input_len, alibi)
        if HAS_LIGHTLLM_KERNEL:
            # the lightllm kernel takes alibi right after the output tensor, unlike the in-repo kernel, which takes it last
            lightllm_bloom_context_attention_fwd(q, k, v, output, alibi, b_start_loc, b_seq_len, max_input_len)
        else:
            bloom_context_attn_fwd(q, k, v, output, b_start_loc, b_seq_len, max_input_len, alibi)

        context_layer = output.view(batch_size, q_length, H * D_HEAD)
    else:
Expand Down
182 changes: 99 additions & 83 deletions colossalai/inference/tensor_parallel/modeling/llama.py
@@ -1,4 +1,6 @@
from typing import List, Optional, Tuple
import math
import copy

import torch
from transformers.modeling_outputs import BaseModelOutputWithPast
@@ -10,31 +12,25 @@

from ._utils import copy_kv_to_mem_cache

try:
    from vllm import layernorm_ops, pos_encoding_ops

    rms_norm = layernorm_ops.rms_norm
    rotary_embedding_neox = pos_encoding_ops.rotary_embedding_neox
    HAS_VLLM_KERNERL = True
except:
    print("fall back to original rotary_embedding_neox of huggingface")
    print("install vllm from https://github.com/vllm-project/vllm to accelerate your inference")
    print(
        "if falied to install vllm, please use this branch to install: https://github.com/tiandiao123/vllm/tree/setup_branch"
    )
    HAS_VLLM_KERNERL = False

try:
    from lightllm.models.llama2.triton_kernel.context_flashattention_nopad import (
        context_attention_fwd as lightllm_llama2_context_attention_fwd,
    )
    from lightllm.models.llama.triton_kernel.context_flashattention_nopad import context_attention_fwd as lightllm_context_attention_fwd
    from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd as llama_rotary_embedding_fwd

    HAS_LIGHTLLM_KERNEL = True
except:
    print("please install lightllm from source to run inference: https://github.com/ModelTC/lightllm")
    HAS_LIGHTLLM_KERNEL = False

try:
    from flash_attn import flash_attn_with_kvcache
    HAS_FLASH_KERNEL = True
except:
    HAS_FLASH_KERNEL = False
print("please install flash attentiom from https://github.com/Dao-AILab/flash-attention")


def rotate_half(x):
"""Rotates half the hidden dims of the input."""
@@ -54,6 +50,71 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed

def llama_triton_context_attention(query_states, key_states, value_states, attn_output, infer_state, num_key_value_groups=1):
    if num_key_value_groups == 1:
        if HAS_LIGHTLLM_KERNEL is False:
            llama_context_attn_fwd(
                query_states,
                key_states,
                value_states,
                attn_output,
                infer_state.start_loc,
                infer_state.seq_len,
                # infer_state.cache_manager.past_key_values_length,
                infer_state.max_len_in_batch,
            )
        else:
            lightllm_context_attention_fwd(
                query_states,
                key_states,
                value_states,
                attn_output,
                infer_state.start_loc,
                infer_state.seq_len,
                # infer_state.cache_manager.past_key_values_length,
                infer_state.max_len_in_batch,
            )
    else:
        assert HAS_LIGHTLLM_KERNEL is True, "You have to install lightllm kernels to run llama2 model"
        lightllm_llama2_context_attention_fwd(
            query_states,
            key_states,
            value_states,
            attn_output,
            infer_state.start_loc,
            infer_state.seq_len,
            # infer_state.cache_manager.past_key_values_length,
            infer_state.max_len_in_batch,
        )


def llama_triton_token_attention(query_states, attn_output, infer_state, num_key_value_groups=1):
    assert HAS_LIGHTLLM_KERNEL is True, "You have to install lightllm kernel to run token attention for llama models"
    if num_key_value_groups == 1:
        token_attention_fwd(
            query_states,
            infer_state.cache_manager.key_buffer[infer_state.decode_layer_id],
            infer_state.cache_manager.value_buffer[infer_state.decode_layer_id],
            attn_output,
            infer_state.block_loc,
            infer_state.start_loc,
            infer_state.seq_len,
            # infer_state.cache_manager.past_key_values_length,
            infer_state.max_len_in_batch,
        )
    else:
        Llama2TokenAttentionForwards.token_attn(
            query_states,
            infer_state.cache_manager.key_buffer[infer_state.decode_layer_id],
            infer_state.cache_manager.value_buffer[infer_state.decode_layer_id],
            attn_output,
            infer_state.block_loc,
            infer_state.start_loc,
            infer_state.seq_len,
            # infer_state.cache_manager.past_key_values_length,
            infer_state.max_len_in_batch,
            infer_state.other_kv_index,
        )


class LlamaInferenceForwards:
"""
@@ -204,7 +265,8 @@ def llama_model_forward(
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )

    @staticmethod
    def llama_decoder_layer_forward(
        self: LlamaDecoderLayer,
@@ -247,6 +309,7 @@ def llama_decoder_layer_forward(
            outputs += (present_key_value,)

        return outputs

    @staticmethod
    def llama_flash_attn_kvcache_forward(
@@ -295,27 +358,8 @@ def llama_flash_attn_kvcache_forward(
                infer_state.cache_manager,
            )
            attn_output = torch.empty_like(query_states)

            if self.num_key_value_groups == 1:
                llama_context_attn_fwd(
                    query_states,
                    key_states,
                    value_states,
                    attn_output,
                    infer_state.start_loc,
                    infer_state.seq_len,
                    infer_state.max_len_in_batch,
                )
            else:
                lightllm_llama2_context_attention_fwd(
                    query_states,
                    key_states,
                    value_states,
                    attn_output,
                    infer_state.start_loc,
                    infer_state.seq_len,
                    infer_state.max_len_in_batch,
                )

            llama_triton_context_attention(query_states, key_states, value_states, attn_output, infer_state, num_key_value_groups=self.num_key_value_groups)
        else:
            if infer_state.decode_is_contiguous:
                # if decode is contiguous, then we copy to key cache and value cache in cache manager directly
@@ -337,35 +381,26 @@
                    infer_state.decode_mem_index,
                    infer_state.cache_manager,
                )

            # second token and follows
            # kv = torch.stack((key_states, value_states), dim=2)
            # (batch_size, seqlen, nheads, headdim)
            attn_output = torch.empty_like(query_states)

            if self.num_key_value_groups == 1:
                token_attention_fwd(
                    query_states,
                    infer_state.cache_manager.key_buffer[infer_state.decode_layer_id],
                    infer_state.cache_manager.value_buffer[infer_state.decode_layer_id],
                    attn_output,
                    infer_state.block_loc,
                    infer_state.start_loc,
                    infer_state.seq_len,
                    infer_state.max_len_in_batch,
                )

            # the flag is overridden here so that the flash-decoding path below is always taken
            HAS_LIGHTLLM_KERNEL = False
            if HAS_LIGHTLLM_KERNEL:
                attn_output = torch.empty_like(query_states)
                llama_triton_token_attention(query_states, attn_output, infer_state, num_key_value_groups=self.num_key_value_groups)
            else:
                Llama2TokenAttentionForwards.token_attn(
                    query_states,
                    infer_state.cache_manager.key_buffer[infer_state.decode_layer_id],
                    infer_state.cache_manager.value_buffer[infer_state.decode_layer_id],
                    attn_output,
                    infer_state.block_loc,
                    infer_state.start_loc,
                    infer_state.seq_len,
                    infer_state.max_len_in_batch,
                    infer_state.other_kv_index,
                )
                heads_per_group = self.num_heads // self.num_key_value_heads
                cache_k = infer_state.cache_manager.key_buffer[infer_state.decode_layer_id]
                cache_v = infer_state.cache_manager.value_buffer[infer_state.decode_layer_id]

                query_states = query_states.view(bsz, -1, self.num_heads, self.head_dim)
                copy_cache_k = cache_k.view(bsz, -1, self.num_key_value_heads, self.head_dim)
                copy_cache_v = cache_v.view(bsz, -1, self.num_key_value_heads, self.head_dim)

                attn_output = flash_attn_with_kvcache(
                    q=query_states,
                    k_cache=copy_cache_k,
                    v_cache=copy_cache_v,
                    softmax_scale=1 / math.sqrt(self.head_dim),
                    causal=True,
                )

            attn_output = attn_output.view(bsz, q_len, self.hidden_size)

@@ -374,22 +409,3 @@
        # return past_key_value as None
        return attn_output, None, None


def get_llama_vllm_rmsnorm_forward():
    if HAS_VLLM_KERNERL:

        def _vllm_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor):
            x = hidden_states
            out = torch.empty_like(x)
            rms_norm(
                out,
                x,
                self.weight.data,
                self.variance_epsilon,
            )

            return out

        return _vllm_rmsnorm_forward
    else:
        return None
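
The heart of the new decode path in `llama_flash_attn_kvcache_forward` is the `flash_attn_with_kvcache` call (flash-decoding). Below is a self-contained sketch of that call outside the ColossalAI wiring; the shapes mirror the diff, while `cache_seqlens` is an extra argument supported by flash-attn that the code above does not pass, and all sizes are illustrative:

```python
import math

import torch
from flash_attn import flash_attn_with_kvcache

# illustrative sizes: 32 query heads sharing 8 KV heads (grouped-query attention)
bsz, num_heads, num_kv_heads, head_dim, max_seq_len = 2, 32, 8, 128, 1024

# one new query token per sequence; KV cache laid out as (batch, seqlen, kv_heads, head_dim)
q = torch.randn(bsz, 1, num_heads, head_dim, dtype=torch.float16, device="cuda")
k_cache = torch.zeros(bsz, max_seq_len, num_kv_heads, head_dim, dtype=torch.float16, device="cuda")
v_cache = torch.zeros_like(k_cache)

# number of tokens already written to the cache for each sequence
cache_seqlens = torch.full((bsz,), 17, dtype=torch.int32, device="cuda")

attn_output = flash_attn_with_kvcache(
    q=q,
    k_cache=k_cache,
    v_cache=v_cache,
    cache_seqlens=cache_seqlens,  # restrict attention to the valid prefix of the cache
    softmax_scale=1.0 / math.sqrt(head_dim),
    causal=True,
)
print(attn_output.shape)  # (bsz, 1, num_heads, head_dim); GQA is handled because num_kv_heads < num_heads
```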
5 changes: 1 addition & 4 deletions colossalai/inference/tensor_parallel/policies/llama.py
@@ -9,7 +9,7 @@
from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy

from ..modeling._utils import init_to_get_rotary
from ..modeling.llama import LlamaInferenceForwards, get_llama_vllm_rmsnorm_forward
from ..modeling.llama import LlamaInferenceForwards

try:
    from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward as lightllm_rmsnorm_forward
@@ -105,9 +105,6 @@ def module_policy(self):
        infer_forward = None
        if HAS_TRITON_RMSNORM:
            infer_forward = get_triton_rmsnorm_forward()
        else:
            # NOTE: adding rms_norm from cuda kernels caused precision issue, fix @tiandiao123
            infer_forward = get_llama_vllm_rmsnorm_forward()

        if infer_forward is not None:
            method_replacement = {"forward": partial(infer_forward)}
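
For context on what the swapped-in kernel computes: `rmsnorm_forward` implements standard RMSNorm. A plain-PyTorch reference follows (an illustration only, not the Triton kernel or lightllm's implementation; `eps` stands in for `LlamaRMSNorm.variance_epsilon`):

```python
import torch


def rmsnorm_reference(hidden_states: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # scale each hidden vector by the reciprocal of its root-mean-square, then apply the learned weight
    variance = hidden_states.pow(2).mean(dim=-1, keepdim=True)
    return hidden_states * torch.rsqrt(variance + eps) * weight
```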