diff --git a/colossalai/inference/README.md b/colossalai/inference/README.md index ba6c95ce8832..d0c281e057b3 100644 --- a/colossalai/inference/README.md +++ b/colossalai/inference/README.md @@ -4,7 +4,7 @@ ## Introduction -`Colossal Inference` is a module that contains colossal-ai designed inference framework, featuring high performance, steady and easy usability. `Colossal Inference` incorporated the advantages of the latest open-source inference systems, including TGI, vLLM, FasterTransformer, LightLLM and flash attention. while combining the design of Colossal AI, especially Shardformer, to reduce the learning curve for users. +`Colossal Inference` is a module that contains the Colossal-AI designed inference framework, featuring high performance, stability and ease of use. `Colossal Inference` incorporates the advantages of the latest open-source inference systems, including LightLLM, TGI, vLLM, FasterTransformer and flash attention, while combining the design of Colossal-AI, especially Shardformer, to reduce the learning curve for users. ## Design @@ -62,6 +62,12 @@ triton==2.0.0.dev20221202 vllm # for install flash-attention, please use commit hash: 67ae6fd74b4bc99c36b2ce524cf139c35663793c flash-attention + +# install lightllm since we depend on its triton kernels +git clone https://github.com/ModelTC/lightllm +cd lightllm +git checkout 28c1267cfca536b7b4f28e921e03de735b003039 +pip3 install -e . ``` ### Docker @@ -73,6 +79,17 @@ You can use docker run to use docker container to set-up environment docker pull hpcaitech/colossalai-inference:v2 docker run -it --gpus all --name ANY_NAME -v $PWD:/workspace -w /workspace hpcaitech/colossalai-inference:v2 /bin/bash +# inside the docker container, install Colossal-AI from source +cd /path/to/ColossalAI +pip install -e . + +# install lightllm +git clone https://github.com/ModelTC/lightllm +cd lightllm +git checkout 28c1267cfca536b7b4f28e921e03de735b003039 +pip3 install -e . + + ``` ### Dive into fast-inference! 
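For readers following the `Dive into fast-inference!` section above, the benchmark script touched later in this patch (`examples/inference/bench_llama.py`) shows how the engine is driven end to end. The snippet below is a minimal sketch of that flow, not part of the patch itself; the import paths for `TPInferEngine` and `ShardConfig`, the launch call, and the checkpoint path are assumptions and may need adjusting to your Colossal-AI version.

```
# Minimal usage sketch mirroring examples/inference/bench_llama.py from this diff.
# Import locations and the launch call are assumptions; adjust to your Colossal-AI version.
import torch
from transformers import LlamaForCausalLM, LlamaTokenizer

import colossalai
from colossalai.inference import TPInferEngine   # assumed import path
from colossalai.shardformer import ShardConfig   # assumed import path

# initialize the distributed runtime (assumes launch via torchrun; API may differ across versions)
colossalai.launch_from_torch(config={})

model_path = "/path/to/llama"  # hypothetical local checkpoint
tokenizer = LlamaTokenizer.from_pretrained(model_path)
tokenizer.pad_token_id = tokenizer.unk_token_id
model = LlamaForCausalLM.from_pretrained(model_path, pad_token_id=tokenizer.eos_token_id).half()

# single-GPU setup: tensor parallelism disabled, inference-only sharding
shard_config = ShardConfig(enable_tensor_parallelism=False, inference_only=True)
# positional args follow the benchmark: max_batch_size, max_input_len, max_output_len
engine = TPInferEngine(model, shard_config, 4, 256, 128)

inputs = tokenizer(["Introduce Shardformer briefly."], return_tensors="pt", padding=True)
inputs = {k: v.cuda() for k, v in inputs.items()}
outputs = engine.generate(inputs, max_new_tokens=64, do_sample=False)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
```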
diff --git a/colossalai/inference/tensor_parallel/batch_infer_state.py b/colossalai/inference/tensor_parallel/batch_infer_state.py index ac185f1b6529..de150311cc08 100644 --- a/colossalai/inference/tensor_parallel/batch_infer_state.py +++ b/colossalai/inference/tensor_parallel/batch_infer_state.py @@ -5,7 +5,7 @@ from .kvcache_manager import MemoryManager - +# adapted from: lightllm/server/router/model_infer/infer_batch.py @dataclass class BatchInferState: r""" @@ -41,6 +41,7 @@ def total_token_num(self): def set_cache_manager(self, manager: MemoryManager): self.cache_manager = manager + # adapted from: https://github.com/ModelTC/lightllm/blob/28c1267cfca536b7b4f28e921e03de735b003039/lightllm/common/infer_utils.py#L1 @staticmethod def init_block_loc( b_loc: torch.Tensor, seq_len: torch.Tensor, max_len_in_batch: int, alloc_mem_index: torch.Tensor diff --git a/colossalai/inference/tensor_parallel/kvcache_manager.py b/colossalai/inference/tensor_parallel/kvcache_manager.py index e74a3a491a7b..c9e7aaae0844 100644 --- a/colossalai/inference/tensor_parallel/kvcache_manager.py +++ b/colossalai/inference/tensor_parallel/kvcache_manager.py @@ -1,7 +1,9 @@ -# Adapted from lightllm/common/mem_manager.py -# of the ModelTC/lightllm GitHub repository -# https://github.com/ModelTC/lightllm/blob/050af3ce65edca617e2f30ec2479397d5bb248c9/lightllm/common/mem_manager.py - +""" +Adapted from lightllm/common/mem_manager.py +of the ModelTC/lightllm GitHub repository +https://github.com/ModelTC/lightllm/blob/050af3ce65edca617e2f30ec2479397d5bb248c9/lightllm/common/mem_manager.py +We slightly changed it to make it suitable for our Colossal-AI Shardformer TP-engine design. +""" import torch from transformers.utils import logging diff --git a/colossalai/inference/tensor_parallel/modeling/chatglm2.py b/colossalai/inference/tensor_parallel/modeling/chatglm2.py index 4b1bc601f436..b8274d3c660f 100644 --- a/colossalai/inference/tensor_parallel/modeling/chatglm2.py +++ b/colossalai/inference/tensor_parallel/modeling/chatglm2.py @@ -6,8 +6,6 @@ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState -from colossalai.kernel.triton.context_attention import llama2_context_attn_fwd -from colossalai.kernel.triton.rotary_embedding_kernel import Llama2Forwards from colossalai.kernel.triton.token_attention_kernel import Llama2TokenAttentionForwards from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ( ChatGLMForConditionalGeneration, @@ -20,6 +18,14 @@ from ._utils import copy_kv_to_mem_cache +try: + from lightllm.models.llama2.triton_kernel.context_flashattention_nopad import context_attention_fwd as lightllm_llama2_context_attention_fwd + from lightllm.models.chatglm2.triton_kernel.rotary_emb import rotary_emb_fwd as chatglm2_rotary_emb_fwd + HAS_LIGHTLLM_KERNEL = True +except: + print("please install lightllm from source to run inference: https://github.com/ModelTC/lightllm") + HAS_LIGHTLLM_KERNEL = False + # This func is same as Llama model init_to_get_rotary, we should move them into _utils.py def _init_to_get_rotary(self, base=10000): @@ -433,17 +439,17 @@ def chatglm_flash_attn_kvcache_forward( cos, sin = infer_state.position_cos, infer_state.position_sin - Llama2Forwards.rotary_emb_fwd( + chatglm2_rotary_emb_fwd( query_layer.view(-1, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head), cos, sin ) if self.multi_query_attention: - 
Llama2Forwards.rotary_emb_fwd( + chatglm2_rotary_emb_fwd( key_layer.view(-1, self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head), cos, sin, ) else: - Llama2Forwards.rotary_emb_fwd( + chatglm2_rotary_emb_fwd( key_layer.view(-1, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head), cos, sin, @@ -474,7 +480,7 @@ def chatglm_flash_attn_kvcache_forward( attn_output = torch.empty_like(query_layer.view(-1, self.projection_size)) # NOTE: no bug in context attn fwd (del it ) - llama2_context_attn_fwd( + lightllm_llama2_context_attention_fwd( query_layer, key_layer, value_layer, diff --git a/colossalai/inference/tensor_parallel/modeling/llama.py b/colossalai/inference/tensor_parallel/modeling/llama.py index ac4ae72f3d18..a3937f6f10ba 100644 --- a/colossalai/inference/tensor_parallel/modeling/llama.py +++ b/colossalai/inference/tensor_parallel/modeling/llama.py @@ -5,12 +5,7 @@ from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel, LlamaRMSNorm from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState -from colossalai.kernel.triton import ( - llama2_context_attn_fwd, - llama_context_attn_fwd, - rotary_embedding_fwd, - token_attention_fwd, -) +from colossalai.kernel.triton import llama_context_attn_fwd, token_attention_fwd from colossalai.kernel.triton.token_attention_kernel import Llama2TokenAttentionForwards from ._utils import copy_kv_to_mem_cache @@ -29,6 +24,17 @@ ) HAS_VLLM_KERNERL = False +try: + from lightllm.models.llama2.triton_kernel.context_flashattention_nopad import ( + context_attention_fwd as lightllm_llama2_context_attention_fwd, + ) + from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd as llama_rotary_embedding_fwd + + HAS_LIGHTLLM_KERNEL = True +except: + print("please install lightllm from source to run inference: https://github.com/ModelTC/lightllm") + HAS_LIGHTLLM_KERNEL = False + def rotate_half(x): """Rotates half the hidden dims of the input.""" @@ -280,8 +286,8 @@ def llama_flash_attn_kvcache_forward( cos, sin = infer_state.position_cos, infer_state.position_sin # print("shape ", cos.shape, query_states.view(-1, self.num_heads, self.head_dim).shape, ) - rotary_embedding_fwd(query_states.view(-1, self.num_heads, self.head_dim), cos, sin) - rotary_embedding_fwd(key_states.view(-1, self.num_key_value_heads, self.head_dim), cos, sin) + llama_rotary_embedding_fwd(query_states.view(-1, self.num_heads, self.head_dim), cos, sin) + llama_rotary_embedding_fwd(key_states.view(-1, self.num_key_value_heads, self.head_dim), cos, sin) query_states = query_states.reshape(-1, self.num_heads, self.head_dim) key_states = key_states.reshape(-1, self.num_key_value_heads, self.head_dim) @@ -312,7 +318,7 @@ def llama_flash_attn_kvcache_forward( infer_state.cache_manager.past_key_values_length, ) else: - llama2_context_attn_fwd( + lightllm_llama2_context_attention_fwd( query_states, key_states, value_states, @@ -371,6 +377,7 @@ def llama_flash_attn_kvcache_forward( infer_state.cache_manager.past_key_values_length, infer_state.other_kv_index, ) + attn_output = attn_output.view(bsz, q_len, self.hidden_size) attn_output = self.o_proj(attn_output) diff --git a/colossalai/inference/tensor_parallel/policies/llama.py b/colossalai/inference/tensor_parallel/policies/llama.py index 507c1203dd6b..7e163efe0173 100644 --- a/colossalai/inference/tensor_parallel/policies/llama.py +++ b/colossalai/inference/tensor_parallel/policies/llama.py @@ -12,8 +12,7 @@ from 
..modeling.llama import LlamaInferenceForwards, get_llama_vllm_rmsnorm_forward try: - from colossalai.kernel.triton import rmsnorm_forward - + from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward as lightllm_rmsnorm_forward HAS_TRITON_RMSNORM = True except: print("you should install triton from https://github.com/openai/triton") @@ -22,9 +21,8 @@ def get_triton_rmsnorm_forward(): if HAS_TRITON_RMSNORM: - def _triton_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor): - return rmsnorm_forward(hidden_states, self.weight.data, self.variance_epsilon) + return lightllm_rmsnorm_forward(hidden_states, self.weight.data, self.variance_epsilon) return _triton_rmsnorm_forward else: diff --git a/colossalai/kernel/triton/__init__.py b/colossalai/kernel/triton/__init__.py index 27351a686d2f..1fe292289f3d 100644 --- a/colossalai/kernel/triton/__init__.py +++ b/colossalai/kernel/triton/__init__.py @@ -9,26 +9,21 @@ # There may exist import error even if we have triton installed. if HAS_TRITON: - from .context_attention import bloom_context_attn_fwd, llama2_context_attn_fwd, llama_context_attn_fwd + from .context_attention import bloom_context_attn_fwd, llama_context_attn_fwd from .copy_kv_cache_dest import copy_kv_cache_to_dest from .fused_layernorm import layer_norm from .gptq_triton import gptq_fused_linear_triton from .int8_rotary_embedding_kernel import int8_rotary_embedding_fwd - from .rms_norm import rmsnorm_forward - from .rotary_embedding_kernel import rotary_embedding_fwd from .smooth_attention import smooth_llama_context_attn_fwd, smooth_token_attention_fwd from .softmax import softmax from .token_attention_kernel import token_attention_fwd __all__ = [ "llama_context_attn_fwd", - "llama2_context_attn_fwd", "bloom_context_attn_fwd", "softmax", "layer_norm", - "rmsnorm_forward", "copy_kv_cache_to_dest", - "rotary_embedding_fwd", "token_attention_fwd", "gptq_fused_linear_triton", "int8_rotary_embedding_fwd", diff --git a/colossalai/kernel/triton/context_attention.py b/colossalai/kernel/triton/context_attention.py index 01d54566483a..1b4f6e44b0f2 100644 --- a/colossalai/kernel/triton/context_attention.py +++ b/colossalai/kernel/triton/context_attention.py @@ -238,329 +238,5 @@ def llama_context_attn_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len): num_warps=num_warps, num_stages=1, ) - return - - @triton.jit - def _fwd_kernel_latest( - Q, - K, - V, - sm_scale, - B_Start_Loc, - B_Seqlen, - Out, - stride_qbs, - stride_qh, - stride_qd, - stride_kbs, - stride_kh, - stride_kd, - stride_vbs, - stride_vh, - stride_vd, - stride_obs, - stride_oh, - stride_od, - kv_group_num, - BLOCK_M: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr, - ): - cur_batch = tl.program_id(0) - cur_head = tl.program_id(1) - start_m = tl.program_id(2) - - cur_kv_head = cur_head // kv_group_num - - cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) - - block_start_loc = BLOCK_M * start_m - - # initialize offsets - offs_n = tl.arange(0, BLOCK_N) - offs_d = tl.arange(0, BLOCK_DMODEL) - offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - off_q = ( - (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs - + cur_head * stride_qh - + offs_d[None, :] * stride_qd - ) - off_k = offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None] * stride_kd - off_v = offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + offs_d[None, :] * stride_vd - - q = tl.load(Q + off_q, mask=offs_m[:, None] < 
cur_batch_seq_len, other=0.0) - - k_ptrs = K + off_k - v_ptrs = V + off_v - - # initialize pointer to m and l - m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") - l_i = tl.zeros([BLOCK_M], dtype=tl.float32) - acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) - - block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0) - - for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N): - start_n = tl.multiple_of(start_n, BLOCK_N) - # -- compute qk ---- - k = tl.load( - k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs, - mask=(start_n + offs_n[None, :]) < cur_batch_seq_len, - other=0.0, - ) - # mask = tl.load(mask_ptrs + start_n, mask=start_n + offs_n < cur_batch_end_loc, other=0.0) - - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - qk += tl.dot(q, k) - qk *= sm_scale - qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf")) - - # -- compute m_ij, p, l_ij - m_ij = tl.max(qk, 1) - p = tl.exp(qk - m_ij[:, None]) - l_ij = tl.sum(p, 1) - # -- update m_i and l_i - m_i_new = tl.maximum(m_i, m_ij) - alpha = tl.exp(m_i - m_i_new) - beta = tl.exp(m_ij - m_i_new) - l_i_new = alpha * l_i + beta * l_ij - # -- update output accumulator -- - # scale p - p_scale = beta / l_i_new - p = p * p_scale[:, None] - # scale acc - acc_scale = l_i / l_i_new * alpha - acc = acc * acc_scale[:, None] - # update acc - v = tl.load( - v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs, - mask=(start_n + offs_n[:, None]) < cur_batch_seq_len, - other=0.0, - ) - - p = p.to(v.dtype) - acc += tl.dot(p, v) - # update m_i and l_i - l_i = l_i_new - m_i = m_i_new - # initialize pointers to output - off_o = ( - (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs - + cur_head * stride_oh - + offs_d[None, :] * stride_od - ) - out_ptrs = Out + off_o - tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len) - return - - @triton.jit - def _fwd_kernel_old( - Q, - K, - V, - sm_scale, - B_Start_Loc, - B_Seqlen, - TMP, # NOTE: TMP is a scratchpad buffer to workaround a compiler bug - Out, - stride_qbs, - stride_qh, - stride_qd, - stride_kbs, - stride_kh, - stride_kd, - stride_vbs, - stride_vh, - stride_vd, - stride_obs, - stride_oh, - stride_od, - stride_tmp_b, - stride_tmp_h, - stride_tmp_s, - kv_group_num, - BLOCK_M: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr, - ): - cur_batch = tl.program_id(0) - cur_head = tl.program_id(1) - start_m = tl.program_id(2) - - cur_kv_head = cur_head // kv_group_num - - cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) - - block_start_loc = BLOCK_M * start_m - - # initialize offsets - offs_n = tl.arange(0, BLOCK_N) - offs_d = tl.arange(0, BLOCK_DMODEL) - offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - off_q = ( - (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs - + cur_head * stride_qh - + offs_d[None, :] * stride_qd - ) - off_k = offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None] * stride_kd - off_v = offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + offs_d[None, :] * stride_vd - q = tl.load(Q + off_q, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0) - - k_ptrs = K + off_k - v_ptrs = V + off_v - - t_ptrs = TMP + cur_batch * stride_tmp_b + cur_head * stride_tmp_h + offs_m * stride_tmp_s - # t_ptrs = TMP + offs_m - m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") - l_i = tl.zeros([BLOCK_M], dtype=tl.float32) - acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], 
dtype=tl.float32) - - block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0) - - for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N): - start_n = tl.multiple_of(start_n, BLOCK_N) - # -- compute qk ---- - k = tl.load( - k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs, - mask=(start_n + offs_n[None, :]) < cur_batch_seq_len, - other=0.0, - ) - - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - qk += tl.dot(q, k) - qk *= sm_scale - qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf")) - - m_ij = tl.max(qk, 1) - p = tl.exp(qk - m_ij[:, None]) - l_ij = tl.sum(p, 1) - # -- update m_i and l_i - m_i_new = tl.maximum(m_i, m_ij) - alpha = tl.exp(m_i - m_i_new) - beta = tl.exp(m_ij - m_i_new) - l_i_new = alpha * l_i + beta * l_ij - # -- update output accumulator -- - # scale p - p_scale = beta / l_i_new - p = p * p_scale[:, None] - # scale acc - acc_scale = l_i / l_i_new * alpha - tl.store(t_ptrs, acc_scale) - acc_scale = tl.load(t_ptrs) # BUG: have to store and immediately load - acc = acc * acc_scale[:, None] - # update acc - v = tl.load( - v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs, - mask=(start_n + offs_n[:, None]) < cur_batch_seq_len, - other=0.0, - ) - - p = p.to(v.dtype) - acc += tl.dot(p, v) - # update m_i and l_i - l_i = l_i_new - m_i = m_i_new - # initialize pointers to output - off_o = ( - (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs - + cur_head * stride_oh - + offs_d[None, :] * stride_od - ) - out_ptrs = Out + off_o - tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len) - - return - - @torch.no_grad() - def llama2_context_attn_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len): - if triton.__version__ >= "2.1.0": - BLOCK = 128 - # shape constraints - Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] - assert Lq == Lk and Lk == Lv - assert Lk in {16, 32, 64, 128} - sm_scale = 1.0 / (Lq**0.5) # 计算scale系数 - batch, head = b_seq_len.shape[0], q.shape[1] - kv_group_num = q.shape[1] // k.shape[1] - - grid = (batch, head, triton.cdiv(max_input_len, BLOCK)) # batch, head, - - num_warps = 4 if Lk <= 64 else 8 - _fwd_kernel_latest[grid]( - q, - k, - v, - sm_scale, - b_start_loc, - b_seq_len, - o, - q.stride(0), - q.stride(1), - q.stride(2), - k.stride(0), - k.stride(1), - k.stride(2), - v.stride(0), - v.stride(1), - v.stride(2), - o.stride(0), - o.stride(1), - o.stride(2), - kv_group_num=kv_group_num, - BLOCK_M=BLOCK, - BLOCK_DMODEL=Lk, - BLOCK_N=BLOCK, - num_warps=num_warps, - num_stages=1, - ) - return - - elif triton.__version__ == "2.0.0": - BLOCK = 128 - # shape constraints - Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] - assert Lq == Lk and Lk == Lv - assert Lk in {16, 32, 64, 128} - - sm_scale = 1.0 / (Lq**0.5) - batch, head = b_seq_len.shape[0], q.shape[1] - kv_group_num = q.shape[1] // k.shape[1] - - grid = (batch, head, triton.cdiv(max_input_len, BLOCK)) - tmp = torch.empty((batch, head, max_input_len + 256), device=q.device, dtype=torch.float32) - num_warps = 4 if Lk <= 64 else 8 - # num_warps = 4 - _fwd_kernel_old[grid]( - q, - k, - v, - sm_scale, - b_start_loc, - b_seq_len, - tmp, - o, - q.stride(0), - q.stride(1), - q.stride(2), - k.stride(0), - k.stride(1), - k.stride(2), - v.stride(0), - v.stride(1), - v.stride(2), - o.stride(0), - o.stride(1), - o.stride(2), - tmp.stride(0), - tmp.stride(1), - tmp.stride(2), - kv_group_num=kv_group_num, - BLOCK_M=BLOCK, - BLOCK_DMODEL=Lk, - BLOCK_N=BLOCK, - num_warps=num_warps, - num_stages=1, - ) - return + + 
return \ No newline at end of file diff --git a/colossalai/kernel/triton/copy_kv_cache_dest.py b/colossalai/kernel/triton/copy_kv_cache_dest.py index 02edcc9a903a..0ce6b09e54dc 100644 --- a/colossalai/kernel/triton/copy_kv_cache_dest.py +++ b/colossalai/kernel/triton/copy_kv_cache_dest.py @@ -11,6 +11,7 @@ if HAS_TRITON: + # adapted from https://github.com/ModelTC/lightllm/blob/5c559dd7981ed67679a08a1e09a88fb4c1550b3a/lightllm/common/triton_kernel/destindex_copy_kv.py @triton.jit def _fwd_copy_kv_cache_dest( kv_cache_ptr, @@ -42,6 +43,7 @@ def _fwd_copy_kv_cache_dest( tl.store(o_ptrs, k, mask=offs_h[:, None] < head_num) return + # adepted from https://github.com/ModelTC/lightllm/blob/5c559dd7981ed67679a08a1e09a88fb4c1550b3a/lightllm/common/triton_kernel/destindex_copy_kv.py @torch.no_grad() def copy_kv_cache_to_dest(k_ptr, dest_index_ptr, out): seq_len = dest_index_ptr.shape[0] diff --git a/colossalai/kernel/triton/rms_norm.py b/colossalai/kernel/triton/rms_norm.py deleted file mode 100644 index d5d6f9d85df1..000000000000 --- a/colossalai/kernel/triton/rms_norm.py +++ /dev/null @@ -1,71 +0,0 @@ -import torch - -try: - import triton - import triton.language as tl - - HAS_TRITON = True -except ImportError: - HAS_TRITON = False - print("please install triton from https://github.com/openai/triton") - - -if HAS_TRITON: - """ - this kernel function is modified from - https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/rmsnorm.py - """ - - @triton.jit - def _rms_norm_fwd_fused( - X, # pointer to the input - Y, # pointer to the output - W, # pointer to the weights - stride, # how much to increase the pointer when moving by 1 row - N, # number of columns in X - eps, # epsilon to avoid division by zero - BLOCK_SIZE: tl.constexpr, - ): - # Map the program id to the row of X and Y it should compute. 
- row = tl.program_id(0) - Y += row * stride - X += row * stride - # Compute variance - _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32) - for off in range(0, N, BLOCK_SIZE): - cols = off + tl.arange(0, BLOCK_SIZE) - x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32) - _var += x * x - var = tl.sum(_var, axis=0) / N - rstd = 1 / tl.sqrt(var + eps) - # Normalize and apply linear transformation - for off in range(0, N, BLOCK_SIZE): - cols = off + tl.arange(0, BLOCK_SIZE) - mask = cols < N - w = tl.load(W + cols, mask=mask).to(tl.float32) - x = tl.load(X + cols, mask=mask, other=0.0).to(tl.float32) - x_hat = x * rstd - y = x_hat * w - # Write output - tl.store(Y + cols, y.to(tl.float16), mask=mask) - - def rmsnorm_forward(x, weight, eps): - # allocate output - y = torch.empty_like(x) - # reshape input data into 2D tensor - x_arg = x.view(-1, x.shape[-1]) - M, N = x_arg.shape - # Less than 64KB per feature: enqueue fused kernel - MAX_FUSED_SIZE = 65536 // x.element_size() - BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) - # print("BLOCK_SIZE:", BLOCK_SIZE) - if N > BLOCK_SIZE: - raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") - # heuristics for number of warps - num_warps = min(max(BLOCK_SIZE // 256, 1), 8) - # print(BLOCK_SIZE, num_warps, "block_size, numwarps") - BLOCK_SIZE = 128 * 2 * 2 * 2 * 2 * 2 * 2 * 2 - num_warps = 8 - # enqueue kernel - _rms_norm_fwd_fused[(M,)](x_arg, y, weight, x_arg.stride(0), N, eps, BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps) - return y diff --git a/colossalai/kernel/triton/rotary_embedding_kernel.py b/colossalai/kernel/triton/rotary_embedding_kernel.py deleted file mode 100644 index fd74ba817551..000000000000 --- a/colossalai/kernel/triton/rotary_embedding_kernel.py +++ /dev/null @@ -1,212 +0,0 @@ -# Adapted from ModelTC https://github.com/ModelTC/lightllm -import torch -import triton -import triton.language as tl - - -@triton.jit -def _rotary_kernel( - q, - Cos, - Sin, - q_bs_stride, - q_h_stride, - q_d_stride, - cos_bs_stride, - cos_d_stride, - total_len, - HEAD_NUM: tl.constexpr, - BLOCK_HEAD: tl.constexpr, - BLOCK_SEQ: tl.constexpr, - HEAD_DIM: tl.constexpr, -): - current_head_index = tl.program_id(0) - current_seq_index = tl.program_id(1) - - current_head_range = current_head_index * BLOCK_HEAD + tl.arange(0, BLOCK_HEAD) - current_seq_range = current_seq_index * BLOCK_SEQ + tl.arange(0, BLOCK_SEQ) - - dim_range0 = tl.arange(0, HEAD_DIM // 2) - dim_range1 = tl.arange(HEAD_DIM // 2, HEAD_DIM) - - off_q0 = ( - current_seq_range[:, None, None] * q_bs_stride - + current_head_range[None, :, None] * q_h_stride - + dim_range0[None, None, :] * q_d_stride - ) - off_q1 = ( - current_seq_range[:, None, None] * q_bs_stride - + current_head_range[None, :, None] * q_h_stride - + dim_range1[None, None, :] * q_d_stride - ) - - off_dimcos_sin = current_seq_range[:, None, None] * cos_bs_stride + dim_range0[None, None, :] * cos_d_stride - - q0 = tl.load( - q + off_q0, - mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM), - other=0.0, - ) - q1 = tl.load( - q + off_q1, - mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM), - other=0.0, - ) - - cos = tl.load(Cos + off_dimcos_sin, mask=current_seq_range[:, None, None] < total_len, other=0.0) - sin = tl.load(Sin + off_dimcos_sin, mask=current_seq_range[:, None, None] < total_len, other=0.0) - - out0 = q0 * cos - q1 * sin - out1 = q0 * sin + q1 * cos - - tl.store( - q + off_q0, - out0, 
- mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM), - ) - tl.store( - q + off_q1, - out1, - mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM), - ) - - return - - -@torch.no_grad() -def rotary_embedding_fwd(q, cos, sin): - total_len = q.shape[0] - head_num = q.shape[1] - head_dim = q.shape[2] - assert q.shape[0] == cos.shape[0] and q.shape[0] == sin.shape[0], f"q shape {q.shape} cos shape {cos.shape}" - BLOCK_HEAD = 4 - BLOCK_SEQ = 32 - grid = (triton.cdiv(head_num, BLOCK_HEAD), triton.cdiv(total_len, BLOCK_SEQ)) - if head_dim >= 128: - num_warps = 8 - else: - num_warps = 4 - - _rotary_kernel[grid]( - q, - cos, - sin, - q.stride(0), - q.stride(1), - q.stride(2), - cos.stride(0), - cos.stride(1), - total_len, - HEAD_NUM=head_num, - BLOCK_HEAD=BLOCK_HEAD, - BLOCK_SEQ=BLOCK_SEQ, - HEAD_DIM=head_dim, - num_warps=num_warps, - num_stages=1, - ) - return - - -class Llama2Forwards: - @staticmethod - @triton.jit - def _rotary_kernel( - Q, - Cos, - Sin, - stride_qbs, - stride_qh, - stride_qd, - stride_cosbs, - stride_cosd, - stride_sinbs, - stride_sind, - max_total_len, - H, # N_CTX - BLOCK_HEAD: tl.constexpr, - BLOCK_SEQ: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, - ): - cur_head_index = tl.program_id(0) - cur_seq_index = tl.program_id(1) - - cur_head_range = cur_head_index * BLOCK_HEAD + tl.arange(0, BLOCK_HEAD) - cur_seq_range = cur_seq_index * BLOCK_SEQ + tl.arange(0, BLOCK_SEQ) - - dim_range0 = tl.arange(0, BLOCK_DMODEL // 2) * 2 - dim_range1 = dim_range0 + 1 - off_q0 = ( - cur_seq_range[:, None, None] * stride_qbs - + cur_head_range[None, :, None] * stride_qh - + dim_range0[None, None, :] * stride_qd - ) - off_q1 = ( - cur_seq_range[:, None, None] * stride_qbs - + cur_head_range[None, :, None] * stride_qh - + dim_range1[None, None, :] * stride_qd - ) - - cos_range = tl.arange(0, BLOCK_DMODEL // 2) - off_dimcos_sin = cur_seq_range[:, None, None] * stride_cosbs + cos_range[None, None, :] * stride_cosd - - q0 = tl.load( - Q + off_q0, - mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < H), - other=0.0, - ) - q1 = tl.load( - Q + off_q1, - mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < H), - other=0.0, - ) - - cos = tl.load(Cos + off_dimcos_sin, mask=cur_seq_range[:, None, None] < max_total_len, other=0.0) - sin = tl.load(Sin + off_dimcos_sin, mask=cur_seq_range[:, None, None] < max_total_len, other=0.0) - - out0 = q0 * cos - q1 * sin - out1 = q0 * sin + q1 * cos - - tl.store( - Q + off_q0, out0, mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < H) - ) - tl.store( - Q + off_q1, out1, mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < H) - ) - - return - - @staticmethod - @torch.no_grad() - def rotary_emb_fwd(q, cos, sin): - total_len = q.shape[0] - head_num = q.shape[1] - head_dim = q.shape[2] // 2 - assert q.shape[0] == cos.shape[0] and q.shape[0] == sin.shape[0], f"q shape {q.shape} cos shape {cos.shape}" - BLOCK_HEAD = 4 - BLOCK_SEQ = 32 - grid = (triton.cdiv(head_num, BLOCK_HEAD), triton.cdiv(total_len, BLOCK_SEQ)) - if head_dim >= 128: - num_warps = 8 - else: - num_warps = 4 - - Llama2Forwards._rotary_kernel[grid]( - q, - cos, - sin, - q.stride(0), - q.stride(1), - q.stride(2), - cos.stride(0), - cos.stride(1), - sin.stride(0), - sin.stride(1), - total_len, - head_num, - BLOCK_HEAD=BLOCK_HEAD, - BLOCK_SEQ=BLOCK_SEQ, - BLOCK_DMODEL=head_dim, - 
num_warps=num_warps, - num_stages=1, - ) - return diff --git a/colossalai/kernel/triton/self_attention_nofusion.py b/colossalai/kernel/triton/self_attention_nofusion.py index 4b56c8afd67f..50d6786bd940 100644 --- a/colossalai/kernel/triton/self_attention_nofusion.py +++ b/colossalai/kernel/triton/self_attention_nofusion.py @@ -12,6 +12,7 @@ from .qkv_matmul_kernel import qkv_gemm_4d_kernel from .softmax import softmax_kernel + # adpeted from https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/ops/transformer/inference/triton/triton_matmul_kernel.py#L312 def self_attention_forward_without_fusion( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, input_mask: torch.Tensor, scale: float ): @@ -141,6 +142,7 @@ def self_attention_forward_without_fusion( ) return output.view(batches, -1, d_model) + # modified from https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/ops/transformer/inference/triton/attention.py#L212 def self_attention_compute_using_triton( qkv, input_mask, layer_past, alibi, scale, head_size, triangular=False, use_flash=False ): diff --git a/colossalai/kernel/triton/token_attention_kernel.py b/colossalai/kernel/triton/token_attention_kernel.py index c27394f0f9cf..8dc919bad125 100644 --- a/colossalai/kernel/triton/token_attention_kernel.py +++ b/colossalai/kernel/triton/token_attention_kernel.py @@ -12,401 +12,78 @@ HAS_TRITON = False print("please install triton from https://github.com/openai/triton") -if HAS_TRITON: - - @triton.jit - def _token_attn_1_kernel( - Q, - K, - sm_scale, - kv_cache_loc, - kv_cache_start_loc, - kv_cache_seqlen, - max_kv_cache_len, - attn_out, - kv_cache_loc_b_stride, - kv_cache_loc_s_stride, - q_batch_stride, - q_head_stride, - q_head_dim_stride, - k_batch_stride, - k_head_stride, - k_head_dim_stride, - attn_head_stride, - attn_batch_stride, - HEAD_DIM: tl.constexpr, - BLOCK_N: tl.constexpr, - ): - current_batch = tl.program_id(0) - current_head = tl.program_id(1) - start_n = tl.program_id(2) - - offs_d = tl.arange(0, HEAD_DIM) - current_batch_seq_len = tl.load(kv_cache_seqlen + current_batch) - current_batch_in_all_start_index = tl.load(kv_cache_start_loc + current_batch) - - current_batch_start_index = max_kv_cache_len - current_batch_seq_len - current_batch_end_index = max_kv_cache_len - - off_q = current_batch * q_batch_stride + current_head * q_head_stride + offs_d * q_head_dim_stride - - offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N) - - block_stard_index = start_n * BLOCK_N - block_mask = tl.where(block_stard_index < current_batch_seq_len, 1, 0) - - for start_mark in range(0, block_mask, 1): - q = tl.load(Q + off_q + start_mark) - offs_n_new = current_batch_start_index + offs_n - k_loc = tl.load( - kv_cache_loc + kv_cache_loc_b_stride * current_batch + kv_cache_loc_s_stride * offs_n_new, - mask=offs_n_new < current_batch_end_index, - other=0, - ) - off_k = k_loc[:, None] * k_batch_stride + current_head * k_head_stride + offs_d[None, :] * k_head_dim_stride - k = tl.load(K + off_k, mask=offs_n_new[:, None] < current_batch_end_index, other=0.0) - att_value = tl.sum(q[None, :] * k, 1) - att_value *= sm_scale - off_o = current_head * attn_head_stride + (current_batch_in_all_start_index + offs_n) * attn_batch_stride - tl.store(attn_out + off_o, att_value, mask=offs_n_new < current_batch_end_index) - return - - @triton.jit - def _token_attn_1_alibi_kernel( - Q, - K, - sm_scale, - alibi, - kv_cache_loc, - kv_cache_start_loc, - kv_cache_seqlen, - max_kv_cache_len, - attn_out, - kv_cache_loc_b_stride, - kv_cache_loc_s_stride, - 
q_batch_stride, - q_head_stride, - q_head_dim_stride, - k_batch_stride, - k_head_stride, - k_head_dim_stride, - attn_head_stride, - attn_batch_stride, - HEAD_DIM: tl.constexpr, - BLOCK_N: tl.constexpr, - ): - current_batch = tl.program_id(0) - current_head = tl.program_id(1) - start_n = tl.program_id(2) - - offs_d = tl.arange(0, HEAD_DIM) - current_batch_seq_len = tl.load(kv_cache_seqlen + current_batch) - current_batch_in_all_start_index = tl.load(kv_cache_start_loc + current_batch) - - current_batch_start_index = max_kv_cache_len - current_batch_seq_len - current_batch_end_index = max_kv_cache_len - - off_q = current_batch * q_batch_stride + current_head * q_head_stride + offs_d * q_head_dim_stride - - offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N) - - block_stard_index = start_n * BLOCK_N - block_mask = tl.where(block_stard_index < current_batch_seq_len, 1, 0) +try: + from lightllm.models.llama2.triton_kernel.token_attention_nopad_att1 import ( + token_att_fwd as lightllm_llama2_token_att_fwd, + ) + from lightllm.models.llama2.triton_kernel.token_attention_nopad_reduceV import ( + token_att_fwd2 as lightllm_llama2_token_att_fwd2, + ) + from lightllm.models.llama2.triton_kernel.token_attention_nopad_softmax import ( + token_softmax_fwd as lightllm_llama2_token_softmax_fwd, + ) + + from lightllm.models.llama.triton_kernel.token_attention_nopad_reduceV import token_att_fwd2 as lightllm_llama_token_att_fw2 + from lightllm.models.llama.triton_kernel.token_attention_nopad_att1 import token_att_fwd as lightllm_llama_token_att_fwd + from lightllm.models.llama.triton_kernel.token_attention_nopad_softmax import token_softmax_fwd as lightllm_llama_token_softmax_fwd + from lightllm.models.bloom.triton_kernel.token_attention_nopad_att1 import token_att_fwd as lightllm_bloom_token_att_fwd + + HAS_TRITON_TOKEN_ATTENTION = True +except ImportError: + print("unable to import lightllm kernels") + HAS_TRITON_TOKEN_ATTENTION = False - for start_mark in range(0, block_mask, 1): - alibi_m = tl.load(alibi + current_head) - q = tl.load(Q + off_q + start_mark) - offs_n_new = current_batch_start_index + offs_n - k_loc = tl.load( - kv_cache_loc + kv_cache_loc_b_stride * current_batch + kv_cache_loc_s_stride * offs_n_new, - mask=offs_n_new < current_batch_end_index, - other=0, - ) - off_k = k_loc[:, None] * k_batch_stride + current_head * k_head_stride + offs_d[None, :] * k_head_dim_stride - k = tl.load(K + off_k, mask=offs_n_new[:, None] < current_batch_end_index, other=0.0) - att_value = tl.sum(q[None, :] * k, 1) - att_value *= sm_scale - att_value -= alibi_m * (current_batch_seq_len - 1 - offs_n) - off_o = current_head * attn_head_stride + (current_batch_in_all_start_index + offs_n) * attn_batch_stride - tl.store(attn_out + off_o, att_value, mask=offs_n_new < current_batch_end_index) - return +if HAS_TRITON: @torch.no_grad() - def token_attn_fwd_1( - q, k, attn_out, kv_cache_loc, kv_cache_start_loc, kv_cache_seqlen, max_kv_cache_len, alibi=None + def token_attention_fwd( + q, k, v, attn_out, kv_cache_loc, kv_cache_start_loc, kv_cache_seq_len, max_len_in_batch, alibi=None ): - BLOCK = 32 - # shape constraints - q_head_dim, k_head_dim = q.shape[-1], k.shape[-1] - assert q_head_dim == k_head_dim - assert k_head_dim in {16, 32, 64, 128} - sm_scale = 1.0 / (k_head_dim**0.5) - - batch, head_num = kv_cache_loc.shape[0], q.shape[1] - - grid = (batch, head_num, triton.cdiv(max_kv_cache_len, BLOCK)) + head_num = k.shape[1] + batch_size = kv_cache_seq_len.shape[0] + calcu_shape1 = (batch_size, head_num, k.shape[2]) + 
total_token_num = k.shape[0] - num_warps = 4 if k_head_dim <= 64 else 8 - num_warps = 2 + att_m_tensor = torch.empty((head_num, total_token_num), dtype=q.dtype, device="cuda") - if alibi is not None: - _token_attn_1_alibi_kernel[grid]( - q, + if alibi is None: + lightllm_llama_token_att_fwd( + q.view(calcu_shape1), k, - sm_scale, - alibi, + att_m_tensor, kv_cache_loc, kv_cache_start_loc, - kv_cache_seqlen, - max_kv_cache_len, - attn_out, - kv_cache_loc.stride(0), - kv_cache_loc.stride(1), - q.stride(0), - q.stride(1), - q.stride(2), - k.stride(0), - k.stride(1), - k.stride(2), - attn_out.stride(0), - attn_out.stride(1), - HEAD_DIM=k_head_dim, - BLOCK_N=BLOCK, - num_warps=num_warps, - num_stages=1, + kv_cache_seq_len, + max_len_in_batch, ) else: - _token_attn_1_kernel[grid]( - q, + lightllm_bloom_token_att_fwd( + q.view(calcu_shape1), k, - sm_scale, + att_m_tensor, + alibi, kv_cache_loc, kv_cache_start_loc, - kv_cache_seqlen, - max_kv_cache_len, - attn_out, - kv_cache_loc.stride(0), - kv_cache_loc.stride(1), - q.stride(0), - q.stride(1), - q.stride(2), - k.stride(0), - k.stride(1), - k.stride(2), - attn_out.stride(0), - attn_out.stride(1), - HEAD_DIM=k_head_dim, - BLOCK_N=BLOCK, - num_warps=num_warps, - num_stages=1, - ) - return - - @triton.jit - def _token_attn_softmax_fwd( - softmax_logics, - kv_cache_start_loc, - kv_cache_seqlen, - softmax_prob_out, - logics_head_dim_stride, - logics_batch_stride, - prob_head_dim_stride, - prob_batch_stride, - BLOCK_SIZE: tl.constexpr, - ): - current_batch = tl.program_id(0) - current_head = tl.program_id(1) - - col_offsets = tl.arange(0, BLOCK_SIZE) - current_batch_seq_len = tl.load(kv_cache_seqlen + current_batch) - current_batch_in_all_start_index = tl.load(kv_cache_start_loc + current_batch) - - row = tl.load( - softmax_logics - + current_head * logics_head_dim_stride - + (current_batch_in_all_start_index + col_offsets) * logics_batch_stride, - mask=col_offsets < current_batch_seq_len, - other=-float("inf"), - ).to(tl.float32) - - row_minus_max = row - tl.max(row, axis=0) - numerator = tl.exp(row_minus_max) - denominator = tl.sum(numerator, axis=0) - softmax_output = numerator / denominator - - tl.store( - softmax_prob_out - + current_head * prob_head_dim_stride - + (current_batch_in_all_start_index + col_offsets) * prob_batch_stride, - softmax_output, - mask=col_offsets < current_batch_seq_len, - ) - return - - @torch.no_grad() - def token_attn_softmax_fwd(softmax_logics, kv_cache_start_loc, kv_cache_seqlen, softmax_prob_out, max_kv_cache_len): - BLOCK_SIZE = triton.next_power_of_2(max_kv_cache_len) - batch, head_num = kv_cache_start_loc.shape[0], softmax_logics.shape[0] - - num_warps = 4 - if BLOCK_SIZE >= 2048: - num_warps = 8 - if BLOCK_SIZE >= 4096: - num_warps = 16 - - _token_attn_softmax_fwd[(batch, head_num)]( - softmax_logics, - kv_cache_start_loc, - kv_cache_seqlen, - softmax_prob_out, - softmax_logics.stride(0), - softmax_logics.stride(1), - softmax_prob_out.stride(0), - softmax_prob_out.stride(1), - num_warps=num_warps, - BLOCK_SIZE=BLOCK_SIZE, - ) - return - - @triton.jit - def _token_attn_2_kernel( - Prob, - V, - attn_out, - kv_cache_loc, - kv_cache_start_loc, - kv_cache_seqlen, - max_kv_cache_len, - kv_cache_loc_b_stride, - kv_cache_loc_s_stride, - prob_head_dim_stride, - prob_batch_stride, - v_batch_stride, - v_head_stride, - v_head_dim_stride, - attn_out_batch_stride, - attn_out_head_stride, - attn_out_head_dim_stride, - HEAD_DIM: tl.constexpr, - BLOCK_N: tl.constexpr, - ): - current_batch = tl.program_id(0) - current_head = 
tl.program_id(1) - - offs_n = tl.arange(0, BLOCK_N) - offs_d = tl.arange(0, HEAD_DIM) - current_batch_seq_len = tl.load(kv_cache_seqlen + current_batch) - current_batch_start_index = max_kv_cache_len - current_batch_seq_len - current_batch_in_all_start_index = tl.load(kv_cache_start_loc + current_batch) - - v_loc_off = current_batch * kv_cache_loc_b_stride + (current_batch_start_index + offs_n) * kv_cache_loc_s_stride - p_offs = current_head * prob_head_dim_stride + (current_batch_in_all_start_index + offs_n) * prob_batch_stride - v_offs = current_head * v_head_stride + offs_d[None, :] * v_head_dim_stride - - acc = tl.zeros([HEAD_DIM], dtype=tl.float32) - for start_n in range(0, current_batch_seq_len, BLOCK_N): - start_n = tl.multiple_of(start_n, BLOCK_N) - p_value = tl.load( - Prob + p_offs + start_n * kv_cache_loc_s_stride, - mask=(start_n + offs_n) < current_batch_seq_len, - other=0.0, - ) - v_loc = tl.load( - kv_cache_loc + v_loc_off + start_n * kv_cache_loc_s_stride, - mask=(start_n + offs_n) < current_batch_seq_len, - other=0.0, - ) - v_value = tl.load( - V + v_offs + v_loc[:, None] * v_batch_stride, - mask=(start_n + offs_n[:, None]) < current_batch_seq_len, - other=0.0, + kv_cache_seq_len, + max_len_in_batch, ) - acc += tl.sum(p_value[:, None] * v_value, 0) - - acc = acc.to(tl.float16) - off_o = ( - current_batch * attn_out_batch_stride - + current_head * attn_out_head_stride - + offs_d * attn_out_head_dim_stride - ) - out_ptrs = attn_out + off_o - tl.store(out_ptrs, acc) - return - - @torch.no_grad() - def token_attn_fwd_2(prob, v, attn_out, kv_cache_loc, kv_cache_start_loc, kv_cache_seqlen, max_kv_cache_len): - if triton.__version__ >= "2.1.0": - BLOCK = 128 - else: - BLOCK = 64 - batch, head = kv_cache_loc.shape[0], v.shape[1] - grid = (batch, head) - num_warps = 4 - dim = v.shape[-1] - - _token_attn_2_kernel[grid]( - prob, - v, - attn_out, - kv_cache_loc, - kv_cache_start_loc, - kv_cache_seqlen, - max_kv_cache_len, - kv_cache_loc.stride(0), - kv_cache_loc.stride(1), - prob.stride(0), - prob.stride(1), - v.stride(0), - v.stride(1), - v.stride(2), - attn_out.stride(0), - attn_out.stride(1), - attn_out.stride(2), - HEAD_DIM=dim, - BLOCK_N=BLOCK, - num_warps=num_warps, - num_stages=1, - ) - return - - @torch.no_grad() - def token_attention_fwd( - q, k, v, attn_out, kv_cache_loc, kv_cache_start_loc, kv_cache_seq_len, max_len_in_batch, alibi=None - ): - head_num = k.shape[1] - batch_size = kv_cache_seq_len.shape[0] - calcu_shape1 = (batch_size, head_num, k.shape[2]) - total_token_num = k.shape[0] - - att_m_tensor = torch.empty((head_num, total_token_num), dtype=q.dtype, device="cuda") - - token_attn_fwd_1( - q.view(calcu_shape1), - k, - att_m_tensor, - kv_cache_loc, - kv_cache_start_loc, - kv_cache_seq_len, - max_len_in_batch, - alibi=alibi, - ) prob = torch.empty_like(att_m_tensor) - token_attn_softmax_fwd(att_m_tensor, kv_cache_start_loc, kv_cache_seq_len, prob, max_len_in_batch) + lightllm_llama_token_softmax_fwd(att_m_tensor, kv_cache_start_loc, kv_cache_seq_len, prob, max_len_in_batch) att_m_tensor = None - token_attn_fwd_2( + lightllm_llama_token_att_fw2( prob, v, attn_out.view(calcu_shape1), kv_cache_loc, kv_cache_start_loc, kv_cache_seq_len, max_len_in_batch ) - prob = None - return class Llama2TokenAttentionForwards: @staticmethod @triton.jit + + # this function is adapted from https://github.com/ModelTC/lightllm/blob/5c559dd7981ed67679a08a1e09a88fb4c1550b3a/lightllm/models/llama2/triton_kernel/token_attention_nopad_softmax.py#L8 def _fwd_kernel( Logics, V, @@ -478,6 +155,7 
@@ def _fwd_kernel( tl.store(out_ptrs, acc) return + # this function is adapted from https://github.com/ModelTC/lightllm/blob/5c559dd7981ed67679a08a1e09a88fb4c1550b3a/lightllm/models/llama2/triton_kernel/token_attention_nopad_softmax.py#L36 @staticmethod @torch.no_grad() def token_softmax_reducev_fwd(logics, v, o, b_loc, b_start_loc, b_seq_len, max_input_len, other_kv_index): @@ -514,277 +192,6 @@ def token_softmax_reducev_fwd(logics, v, o, b_loc, b_start_loc, b_seq_len, max_i ) return - @staticmethod - @triton.jit - def _fwd_kernel_token_softmax( - Logics, - B_Start_Loc, - B_Seqlen, - Prob_Out, - stride_logic_h, - stride_logic_bs, - stride_prob_h, - stride_prob_bs, - BLOCK_SIZE: tl.constexpr, - ): - cur_batch = tl.program_id(0) - cur_head = tl.program_id(1) - - col_offsets = tl.arange(0, BLOCK_SIZE) - cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) - - row = tl.load( - Logics + cur_head * stride_logic_h + (cur_batch_in_all_start_index + col_offsets) * stride_logic_bs, - mask=col_offsets < cur_batch_seq_len, - other=-float("inf"), - ).to(tl.float32) - - row_minus_max = row - tl.max(row, axis=0) - numerator = tl.exp(row_minus_max) - denominator = tl.sum(numerator, axis=0) - softmax_output = numerator / denominator - - tl.store( - Prob_Out + cur_head * stride_prob_h + (cur_batch_in_all_start_index + col_offsets) * stride_prob_bs, - softmax_output, - mask=col_offsets < cur_batch_seq_len, - ) - return - - @staticmethod - @torch.no_grad() - def token_softmax_fwd(Logics, B_Start_Loc, B_Seqlen, Prob_Out, max_input_len): - BLOCK_SIZE = triton.next_power_of_2(max_input_len) - batch, head_num = B_Start_Loc.shape[0], Logics.shape[0] - - num_warps = 4 - if BLOCK_SIZE >= 2048: - num_warps = 8 - if BLOCK_SIZE >= 4096: - num_warps = 16 - - Llama2TokenAttentionForwards._fwd_kernel_token_softmax[(batch, head_num)]( - Logics, - B_Start_Loc, - B_Seqlen, - Prob_Out, - Logics.stride(0), - Logics.stride(1), - Prob_Out.stride(0), - Prob_Out.stride(1), - num_warps=num_warps, - BLOCK_SIZE=BLOCK_SIZE, - ) - return - - @staticmethod - @triton.jit - def _fwd_kernel_token_att1( - Q, - K, - sm_scale, - B_Loc, - B_Start_Loc, - B_Seqlen, - max_input_len, - Att_Out, - stride_b_loc_b, - stride_b_loc_s, - stride_qbs, - stride_qh, - stride_qd, - stride_kbs, - stride_kh, - stride_kd, - att_stride_h, - att_stride_bs, - kv_group_num, - BLOCK_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr, - ): - cur_batch = tl.program_id(0) - cur_head = tl.program_id(1) - start_n = tl.program_id(2) - - cur_kv_head = cur_head // kv_group_num - - offs_d = tl.arange(0, BLOCK_DMODEL) - cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) - - cur_batch_start_index = max_input_len - cur_batch_seq_len - cur_batch_end_index = max_input_len - - off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d * stride_qd - - offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N) - - block_stard_index = start_n * BLOCK_N - block_mask = tl.where(block_stard_index < cur_batch_seq_len, 1, 0) - - for start_mark in range(0, block_mask, 1): - q = tl.load(Q + off_q + start_mark) - offs_n_new = cur_batch_start_index + offs_n - k_loc = tl.load( - B_Loc + stride_b_loc_b * cur_batch + stride_b_loc_s * offs_n_new, - mask=offs_n_new < cur_batch_end_index, - other=0, - ) - off_k = k_loc[:, None] * stride_kbs + cur_kv_head * stride_kh + offs_d[None, :] * stride_kd - k = tl.load(K + off_k, mask=offs_n_new[:, None] < cur_batch_end_index, other=0.0) - 
att_value = tl.sum(q[None, :] * k, 1) - att_value *= sm_scale - off_o = cur_head * att_stride_h + (cur_batch_in_all_start_index + offs_n) * att_stride_bs - tl.store(Att_Out + off_o, att_value, mask=offs_n_new < cur_batch_end_index) - return - - @staticmethod - @torch.no_grad() - def token_att_fwd(q, k, att_out, B_Loc, B_Start_Loc, B_Seqlen, max_input_len): - BLOCK = 32 - # shape constraints - Lq, Lk = q.shape[-1], k.shape[-1] - assert Lq == Lk - assert Lk in {16, 32, 64, 128} - sm_scale = 1.0 / (Lk**0.5) - - batch, head_num = B_Loc.shape[0], q.shape[1] - - grid = (batch, head_num, triton.cdiv(max_input_len, BLOCK)) - kv_group_num = q.shape[1] // k.shape[1] - - num_warps = 4 if Lk <= 64 else 8 - num_warps = 2 - - Llama2TokenAttentionForwards._fwd_kernel_token_att1[grid]( - q, - k, - sm_scale, - B_Loc, - B_Start_Loc, - B_Seqlen, - max_input_len, - att_out, - B_Loc.stride(0), - B_Loc.stride(1), - q.stride(0), - q.stride(1), - q.stride(2), - k.stride(0), - k.stride(1), - k.stride(2), - att_out.stride(0), - att_out.stride(1), - kv_group_num=kv_group_num, - BLOCK_DMODEL=Lk, - BLOCK_N=BLOCK, - num_warps=num_warps, - num_stages=1, - ) - return - - @staticmethod - @triton.jit - def _fwd_kernel_token_att2( - Prob, - V, - Out, - B_Loc, - B_Start_Loc, - B_Seqlen, - max_input_len, # B_Start_Loc cumsum of input lens if continuous - stride_b_loc_b, - stride_b_loc_s, - stride_ph, - stride_pbs, - stride_vbs, - stride_vh, - stride_vd, - stride_obs, - stride_oh, - stride_od, - kv_group_num, - BLOCK_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr, - ): - cur_batch = tl.program_id(0) - cur_head = tl.program_id(1) - - cur_kv_head = cur_head // kv_group_num - - offs_n = tl.arange(0, BLOCK_N) - offs_d = tl.arange(0, BLOCK_DMODEL) - cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - cur_batch_start_index = max_input_len - cur_batch_seq_len - cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) - - v_loc_off = cur_batch * stride_b_loc_b + (cur_batch_start_index + offs_n) * stride_b_loc_s - p_offs = cur_head * stride_ph + (cur_batch_in_all_start_index + offs_n) * stride_pbs - v_offs = cur_kv_head * stride_vh + offs_d[None, :] * stride_vd - - acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32) - for start_n in range(0, cur_batch_seq_len, BLOCK_N): - start_n = tl.multiple_of(start_n, BLOCK_N) - p_value = tl.load( - Prob + p_offs + start_n * stride_b_loc_s, mask=(start_n + offs_n) < cur_batch_seq_len, other=0.0 - ) - v_loc = tl.load( - B_Loc + v_loc_off + start_n * stride_b_loc_s, mask=(start_n + offs_n) < cur_batch_seq_len, other=0.0 - ) - v_value = tl.load( - V + v_offs + v_loc[:, None] * stride_vbs, - mask=(start_n + offs_n[:, None]) < cur_batch_seq_len, - other=0.0, - ) - acc += tl.sum(p_value[:, None] * v_value, 0) - - acc = acc.to(tl.float16) - off_o = cur_batch * stride_obs + cur_head * stride_oh + offs_d * stride_od - out_ptrs = Out + off_o - tl.store(out_ptrs, acc) - return - - @staticmethod - @torch.no_grad() - def token_att_fwd2(prob, v, out, B_Loc, B_Start_Loc, B_Seqlen, max_input_len): - if triton.__version__ >= "2.1.0": - BLOCK = 128 - else: - BLOCK = 64 - batch, head = B_Loc.shape[0], prob.shape[0] - grid = (batch, head) - num_warps = 4 - dim = v.shape[-1] - - kv_group_num = prob.shape[0] // v.shape[1] - - Llama2TokenAttentionForwards._fwd_kernel_token_att2[grid]( - prob, - v, - out, - B_Loc, - B_Start_Loc, - B_Seqlen, - max_input_len, - B_Loc.stride(0), - B_Loc.stride(1), - prob.stride(0), - prob.stride(1), - v.stride(0), - v.stride(1), - v.stride(2), - out.stride(0), - out.stride(1), - 
out.stride(2), - kv_group_num=kv_group_num, - BLOCK_DMODEL=dim, - BLOCK_N=BLOCK, - num_warps=num_warps, - num_stages=1, - ) - return - # this is the interface of llama2 attn forward @staticmethod @torch.no_grad() @@ -796,7 +203,7 @@ def token_attn( calcu_shape1 = (batch_size, head_num, head_dim) att_m_tensor = torch.empty((head_num, total_token_num), dtype=q.dtype, device="cuda") - Llama2TokenAttentionForwards.token_att_fwd( + lightllm_llama2_token_att_fwd( q, k, att_m_tensor, @@ -808,12 +215,12 @@ def token_attn( if triton.__version__ == "2.0.0": prob = torch.empty_like(att_m_tensor) - Llama2TokenAttentionForwards.token_softmax_fwd( + lightllm_llama2_token_softmax_fwd( att_m_tensor, kv_cache_start_loc, kv_cache_seq_len, prob, max_len_in_batch ) att_m_tensor = None - Llama2TokenAttentionForwards.token_att_fwd2( + lightllm_llama2_token_att_fwd2( prob, v, attn_out.view(calcu_shape1), diff --git a/examples/inference/bench_llama.py b/examples/inference/bench_llama.py index 90d49f6a264a..0ca1953c6a41 100644 --- a/examples/inference/bench_llama.py +++ b/examples/inference/bench_llama.py @@ -3,7 +3,6 @@ import time import torch -from torch.profiler import ProfilerActivity, profile, record_function from transformers import LlamaForCausalLM, LlamaTokenizer import colossalai @@ -16,6 +15,7 @@ def print_perf_stats(latency_set, config, bs, warmup=3): + torch.cuda.empty_cache() # trim warmup queries latency_set = list(latency_set) latency_set = latency_set[warmup:] @@ -38,24 +38,29 @@ def run_llama_test(args): max_batch_size = args.batch_size max_input_len = args.input_len max_output_len = args.output_len + args.test_mode + + print("max_batch_size : " + str(max_batch_size)) tokenizer = LlamaTokenizer.from_pretrained(llama_model_path) tokenizer.pad_token_id = tokenizer.unk_token_id model = LlamaForCausalLM.from_pretrained(llama_model_path, pad_token_id=tokenizer.eos_token_id) model = model.half() - model_config = model.config + model.config shard_config = ShardConfig(enable_tensor_parallelism=True if args.tp_size > 1 else False, inference_only=True) infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len) - generate_kwargs = dict(max_new_tokens=max_output_len, do_sample=False) + generate_kwargs = dict(max_new_tokens=1, do_sample=False) input_tokens = { "input_ids": torch.randint(1, 1000, (max_batch_size, max_input_len), device="cuda"), "attention_mask": torch.ones((max_batch_size, max_input_len), device="cuda"), } iters = 10 - times = [] + prefill_times = [] + + warmup = 3 for i in range(iters): torch.cuda.synchronize() @@ -65,17 +70,39 @@ def run_llama_test(args): end = time.time() out_len = outputs.shape[1] print("generation time {} s".format(str(end - start))) + print(out_len - max_input_len) + prefill_times.append((end - start) / (out_len - max_input_len)) + + prefill_times = prefill_times[warmup:] + prefill_time_avg = sum(prefill_times) / len(prefill_times) + generate_kwargs = dict(max_new_tokens=max_output_len, do_sample=False) + + times = [] + decoder_times = [] + for i in range(iters): + torch.cuda.synchronize() + start = time.time() + outputs = infer_engine.generate(input_tokens, **generate_kwargs) + torch.cuda.synchronize() + end = time.time() + out_len = outputs.shape[1] + print("generation time {} s".format(str(end - start))) + print(out_len - max_input_len) times.append((end - start) / (out_len - max_input_len)) + if args.test_mode == "decoder_test": + decoder_times.append((end - start - prefill_time_avg) / (out_len - max_input_len - 1)) + + times = 
times[warmup:] + latency = sum(times) / len(times) + print("total process latency is : " + str(latency) + " s") + print("total throughput is : " + str(1 / latency * max_batch_size)) - print("outputs, ", len(outputs)) - print_perf_stats(times, model_config, max_batch_size) + if args.test_mode == "decoder_test": + decoder_times = decoder_times[warmup:] + latency = sum(decoder_times) / len(decoder_times) - with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True) as prof: - with record_function("model_inference"): - torch.cuda.synchronize() - outputs = infer_engine.generate(input_tokens, **generate_kwargs) - torch.cuda.synchronize() - print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10)) + print("decoder process latency is : " + str(latency) + " s") + print("decoder throughput is : " + str(1 / latency * max_batch_size)) def check_llama(rank, world_size, port, args): @@ -95,8 +122,11 @@ def test_llama(args): parser.add_argument("-p", "--path", type=str, help="Model path", required=True) parser.add_argument("-tp", "--tp_size", type=int, default=1, help="Tensor parallel size") parser.add_argument("-b", "--batch_size", type=int, default=16, help="Maximum batch size") - parser.add_argument("--input_len", type=int, default=1024, help="Maximum input length") + parser.add_argument("--input_len", type=int, default=256, help="Maximum input length") parser.add_argument("--output_len", type=int, default=128, help="Maximum output length") + parser.add_argument( + "--test_mode", type=str, help="Test mode", default="e2e_test", choices=["e2e_test", "decoder_test"] + ) args = parser.parse_args() diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 9aa5f2822e40..19cb7a154a01 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -11,3 +11,6 @@ ninja torch>=1.12 safetensors einops +sentencepiece +google +protobuf diff --git a/tests/test_infer_ops/triton/test_llama2_token_attn.py b/tests/test_infer_ops/triton/test_llama2_token_attn.py deleted file mode 100644 index 0537a3d76129..000000000000 --- a/tests/test_infer_ops/triton/test_llama2_token_attn.py +++ /dev/null @@ -1,63 +0,0 @@ -import pytest -import torch -from packaging import version - -try: - pass - - from colossalai.kernel.triton.token_attention_kernel import Llama2TokenAttentionForwards - - HAS_TRITON = True -except ImportError: - HAS_TRITON = False - print("please install triton from https://github.com/openai/triton") - -TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4") - - -def torch_att(xq, xk, xv, bs, seqlen, num_head, head_dim): - xq = xq.view(bs, 1, num_head, head_dim) - xk = xk.view(bs, seqlen, num_head, head_dim) - xv = xv.view(bs, seqlen, num_head, head_dim) - - logics = torch.sum(xq * xk, dim=3, keepdim=False) * 1 / (head_dim**0.5) - prob = torch.softmax(logics, dim=1) - prob = prob.view(bs, seqlen, num_head, 1) - - return torch.sum(prob * xv, dim=1, keepdim=False) - - -@pytest.mark.skipif( - not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4" -) -def test(): - Z, head_num, seq_len, head_dim = 2, 32, 2048, 128 - dtype = torch.float16 - - # attn out: 2,4096 - q = torch.empty((Z, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2) - k = torch.empty((Z * seq_len, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.4, std=0.2) - v = torch.empty((Z * seq_len, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.3, 
std=0.2) - o = torch.empty((Z, head_num, head_dim), dtype=dtype, device="cuda") - max_kv_cache_len = seq_len - kv_cache_start_loc = torch.zeros((Z,), dtype=torch.int32, device="cuda") - kv_cache_loc = torch.zeros((Z, seq_len), dtype=torch.int32, device="cuda") - kv_cache_seq_len = torch.ones((Z,), dtype=torch.int32, device="cuda") - other_kv_index = 2048 - - kv_cache_seq_len[:] = seq_len - kv_cache_start_loc[0] = 0 - kv_cache_start_loc[1] = seq_len - - for i in range(Z): - kv_cache_loc[i, :] = torch.arange(i * seq_len, (i + 1) * seq_len, dtype=torch.int32, device="cuda") - - Llama2TokenAttentionForwards.token_attn( - q, k, v, o, kv_cache_loc, kv_cache_start_loc, kv_cache_seq_len, max_kv_cache_len, other_kv_index - ) - torch_out = torch_att(q, k, v, Z, seq_len, head_num, head_dim) - assert torch.allclose(torch_out, o, atol=1e-3, rtol=0) - - -if __name__ == "__main__": - test() diff --git a/tests/test_infer_ops/triton/test_rotary_embedding.py b/tests/test_infer_ops/triton/test_rotary_embedding.py deleted file mode 100644 index 7e05ccafbfc4..000000000000 --- a/tests/test_infer_ops/triton/test_rotary_embedding.py +++ /dev/null @@ -1,55 +0,0 @@ -# Adapted from ModelTC https://github.com/ModelTC/lightllm - - -import pytest -import torch -from packaging import version - -try: - pass - - from colossalai.kernel.triton.rotary_embedding_kernel import rotary_embedding_fwd - - HAS_TRITON = True -except ImportError: - HAS_TRITON = False - print("please install triton from https://github.com/openai/triton") - -TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4") - - -def torch_rotary_emb(x, cos, sin): - seq_len, h, dim = x.shape - x0 = x[:, :, 0 : dim // 2] - x1 = x[:, :, dim // 2 : dim] - cos = cos.view((seq_len, 1, dim // 2)) - sin = sin.view((seq_len, 1, dim // 2)) - o0 = x0 * cos - x1 * sin - o1 = x0 * sin + x1 * cos - return torch.cat((o0, o1), dim=-1) - - -@pytest.mark.skipif( - not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4" -) -def test_rotary_emb(): - SEQ_LEN = 1 - HEAD_NUM = 32 - HEAD_DIM = 128 - dtype = torch.half - # create data - x_shape = (SEQ_LEN, HEAD_NUM, HEAD_DIM) - x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device="cuda") - cos_shape = (SEQ_LEN, HEAD_DIM // 2) - cos = -1.2 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda") - sin = -2.0 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda") - # forward pass - y_torch = torch_rotary_emb(x, cos, sin) - rotary_embedding_fwd(x, cos, sin) - y_triton = x - # compare - assert torch.allclose(y_torch, y_triton, atol=1e-2, rtol=0) - - -if __name__ == "__main__": - test_rotary_emb() diff --git a/tests/test_infer_ops/triton/test_token_attn_1.py b/tests/test_infer_ops/triton/test_token_attn_1.py deleted file mode 100644 index fc5f8cd6c9dc..000000000000 --- a/tests/test_infer_ops/triton/test_token_attn_1.py +++ /dev/null @@ -1,74 +0,0 @@ -import math - -import pytest -import torch -from packaging import version - -try: - pass - - from colossalai.kernel.triton.token_attention_kernel import token_attn_fwd_1 - - HAS_TRITON = True -except ImportError: - HAS_TRITON = False - print("please install triton from https://github.com/openai/triton") - -TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4") - - -def torch_attn(xq, xk, bs, seqlen, num_head, head_dim): - xq = xq.view(bs, 1, num_head, head_dim) - xk = xk.view(bs, seqlen, num_head, head_dim) - keys = xk - xq = xq.transpose(1, 2) - keys = keys.transpose(1, 2) - scores = ( - 
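
# --- illustrative note on the deleted rotary-embedding test ---------------------
# With the colossalai rotary kernel test removed, the same style of check can be
# pointed at the lightllm rotary kernel now imported by the model files (see the
# chatglm2_rotary_emb_fwd import in modeling/chatglm2.py). The import path below
# follows that file; the in-place (q, cos, sin) call signature is assumed from
# its usage there, so treat this as a sketch rather than a drop-in test.
import torch

def torch_rotary_emb(x, cos, sin):
    seq_len, h, dim = x.shape
    x0, x1 = x[..., : dim // 2], x[..., dim // 2 :]
    cos = cos.view(seq_len, 1, dim // 2)
    sin = sin.view(seq_len, 1, dim // 2)
    return torch.cat((x0 * cos - x1 * sin, x0 * sin + x1 * cos), dim=-1)

def check_lightllm_rotary(seq_len=4, heads=32, dim=128):
    from lightllm.models.chatglm2.triton_kernel.rotary_emb import rotary_emb_fwd
    x = torch.randn(seq_len, heads, dim, dtype=torch.half, device="cuda")
    cos = torch.randn(seq_len, dim // 2, dtype=torch.half, device="cuda")
    sin = torch.randn(seq_len, dim // 2, dtype=torch.half, device="cuda")
    expected = torch_rotary_emb(x, cos, sin)
    rotary_emb_fwd(x, cos, sin)  # assumed to rotate x in place, as used above
    assert torch.allclose(expected, x, atol=1e-2, rtol=0)
# -------------------------------------------------------------------------------
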
(torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(head_dim)).squeeze().transpose(0, 1).reshape(num_head, -1) - ) - return scores - - -def torch_attn_1(xq, xk, seqlen, num_head, head_dim): - xq = xq.view(1, num_head, head_dim) - xk = xk.view(seqlen, num_head, head_dim) - logics = torch.sum(xq * xk, dim=-1, keepdim=False) - - logics = logics.transpose(0, 1) / math.sqrt(head_dim) - return logics - - -@pytest.mark.skipif( - not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4" -) -def test_attn_1(): - pass - - batch_size, seq_len, head_num, head_dim = 17, 1025, 12, 128 - - dtype = torch.float16 - - q = torch.empty((batch_size, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2) - k = torch.empty((batch_size * seq_len, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2) - attn_out = torch.empty((head_num, batch_size * seq_len), dtype=dtype, device="cuda") - - b_loc = torch.zeros((batch_size, seq_len), dtype=torch.int32, device="cuda") - kv_cache_start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda") - kv_cache_seq_len = torch.zeros((batch_size,), dtype=torch.int32, device="cuda") - - for i in range(batch_size): - kv_cache_start_loc[i] = i * seq_len - kv_cache_seq_len[i] = seq_len - b_loc[i] = i * seq_len + torch.arange(0, seq_len, dtype=torch.int32, device="cuda") - - token_attn_fwd_1(q, k, attn_out, b_loc, kv_cache_start_loc, kv_cache_seq_len, seq_len) - - torch_out = torch_attn(q, k, batch_size, seq_len, head_num, head_dim).squeeze() - o = attn_out.squeeze() - print("max ", torch.max(torch.abs(torch_out - o))) - print("mean ", torch.mean(torch.abs(torch_out - o))) - assert torch.allclose(torch_out, o, atol=1e-2, rtol=0) - - -if __name__ == "__main__": - test_attn_1() diff --git a/tests/test_infer_ops/triton/test_token_attn_2.py b/tests/test_infer_ops/triton/test_token_attn_2.py deleted file mode 100644 index 2dd756f2ba91..000000000000 --- a/tests/test_infer_ops/triton/test_token_attn_2.py +++ /dev/null @@ -1,63 +0,0 @@ -import pytest -import torch -from packaging import version - -try: - pass - - from colossalai.kernel.triton.token_attention_kernel import token_attn_fwd_2 - - HAS_TRITON = True -except ImportError: - HAS_TRITON = False - print("please install triton from https://github.com/openai/triton") - -TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4") - - -def torch_attn(V, P, bs, seqlen, num_head, head_dim): - V = V.view(bs, seqlen, num_head, head_dim).transpose(1, 2) - P = P.reshape(num_head, bs, 1, seqlen).transpose(0, 1) - attn_out = torch.matmul(P, V) - - return attn_out - - -@pytest.mark.skipif( - not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4" -) -def test_token_attn_2(): - pass - - batch_size, seq_len, head_num, head_dim = 17, 1025, 12, 128 - dtype = torch.float16 - - V = torch.empty((batch_size * seq_len, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.1, std=10) - Prob = ( - torch.empty((head_num, batch_size * seq_len), dtype=dtype, device="cuda") - .normal_(mean=0.4, std=0.2) - .reshape(head_num, batch_size, seq_len) - .softmax(-1) - .reshape(head_num, batch_size * seq_len) - ) - attn_out = torch.empty((batch_size, head_num, head_dim), dtype=dtype, device="cuda") - - kv_cache_start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda") - kv_cache_seq_len = torch.zeros((batch_size,), dtype=torch.int32, device="cuda") - kv_cache_loc = 
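
# --- illustrative note on the two deleted staged-attention tests ----------------
# test_token_attn_1.py and test_token_attn_2.py covered the score and
# weighted-sum halves of token attention separately; the llama2 path above now
# runs the same three stages through the lightllm kernels
# (token_att_fwd -> token_softmax_fwd -> token_att_fwd2), with the intermediate
# laid out as (head_num, total_token_num). A rough PyTorch equivalent of that
# staging, simplified to equal-length sequences, is sketched below; the function
# name is illustrative.
import math
import torch

def staged_token_attention(q, k, v, bs, seqlen):
    head, dim = q.shape[1], q.shape[2]
    # stage 1: per-head scores against the packed cache -> (head, bs * seqlen)
    scores = (q.view(bs, 1, head, dim) * k.view(bs, seqlen, head, dim)).sum(-1)
    scores = (scores / math.sqrt(dim)).permute(2, 0, 1).reshape(head, bs * seqlen)
    # stage 2: softmax over each sequence's own cached tokens
    prob = torch.softmax(scores.view(head, bs, seqlen), dim=-1)
    # stage 3: probability-weighted sum over V, back to (bs, head, dim)
    v = v.view(bs, seqlen, head, dim).permute(2, 0, 1, 3)  # (head, bs, seqlen, dim)
    return torch.einsum("hbs,hbsd->bhd", prob, v)
# -------------------------------------------------------------------------------
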
torch.zeros((batch_size, seq_len), dtype=torch.int32, device="cuda") - for i in range(batch_size): - kv_cache_start_loc[i] = i * seq_len - kv_cache_seq_len[i] = seq_len - kv_cache_loc[i] = i * seq_len + torch.arange(0, seq_len, dtype=torch.int32, device="cuda") - - token_attn_fwd_2(Prob, V, attn_out, kv_cache_loc, kv_cache_start_loc, kv_cache_seq_len, seq_len) - - torch_out = torch_attn(V, Prob, batch_size, seq_len, head_num, head_dim).squeeze() - o = attn_out - print("max ", torch.max(torch.abs(torch_out - o))) - print("mean ", torch.mean(torch.abs(torch_out - o))) - assert torch.allclose(torch_out, o, atol=1e-2, rtol=0) - - -if __name__ == "__main__": - test_token_attn_2() diff --git a/tests/test_infer_ops/triton/test_token_attn_fwd.py b/tests/test_infer_ops/triton/test_token_attn_fwd.py index 9c7a53798317..a7fc3d29b77a 100644 --- a/tests/test_infer_ops/triton/test_token_attn_fwd.py +++ b/tests/test_infer_ops/triton/test_token_attn_fwd.py @@ -3,16 +3,13 @@ from packaging import version try: - pass - from colossalai.kernel.triton.token_attention_kernel import token_attention_fwd - HAS_TRITON = True except ImportError: HAS_TRITON = False print("please install triton from https://github.com/openai/triton") -TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4") +TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) >= version.parse("11.6") def torch_att(xq, xk, xv, bs, seqlen, num_head, head_dim):
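
# --- illustrative note on the tightened version gate ----------------------------
# The kept test above now requires CUDA >= 11.6 (previously > 11.4) in addition
# to a successful triton import. A minimal, self-contained sketch of that gating
# logic; the helper name is illustrative.
import torch
from packaging import version

def triton_kernels_supported() -> bool:
    try:
        import triton  # noqa: F401
    except ImportError:
        return False
    if torch.version.cuda is None:  # CPU-only builds report no CUDA version
        return False
    return version.parse(torch.version.cuda) >= version.parse("11.6")
# -------------------------------------------------------------------------------
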