From be6757c4c8112575eb01bf4fc6dfe4b57be2048b Mon Sep 17 00:00:00 2001
From: Xu Kai
Date: Fri, 29 Sep 2023 17:23:55 +0800
Subject: [PATCH] [inference] add int8 rotary embedding kernel for smoothquant (#4843)

* [shardformer] fix GPT2DoubleHeadsModel (#4703)
* [hotfix] Fix import error: colossal.kernel without triton installed (#4722)
* [hotfix] remove triton kernels from kernel init
* revise bloom/llama kernel imports for infer
* [shardformer] to fix whisper test failed due to significant accuracy differences. (#4710)
* [shardformer] fix whisper test failed
* [shardformer] fix whisper test failed
* [shardformer] fix whisper test failed
* [shardformer] fix whisper test failed
* [doc] fix llama2 code link (#4726)
* [doc] fix llama2 code link
* [doc] fix llama2 code link
* [doc] fix llama2 code link
* [doc] Add user document for Shardformer (#4702)
* create shardformer doc files
* add docstring for seq-parallel
* update ShardConfig docstring
* add links to llama example
* add outdated massage
* finish introduction & supporting information
* finish 'how shardformer works'
* finish shardformer.md English doc
* fix doctest fail
* add Chinese document
* [format] applied code formatting on changed files in pull request 4726 (#4727) Co-authored-by: github-actions
* [doc] add shardformer support matrix/update tensor parallel documents (#4728)
* add compatibility matrix for shardformer doc
* update tp doc
* Optimized some syntax errors in the documentation and code under applications/ (#4127) Co-authored-by: flybird11111 <1829166702@qq.com>
* [shardformer] update pipeline parallel document (#4725)
* [shardformer] update pipeline parallel document
* [shardformer] update pipeline parallel document
* [shardformer] update pipeline parallel document
* [shardformer] update pipeline parallel document
* [shardformer] update pipeline parallel document
* [shardformer] update pipeline parallel document
* [shardformer] update pipeline parallel document
* [shardformer] update pipeline parallel document
* [legacy] remove deterministic data loader test
* [shardformer] update seq parallel document (#4730)
* update doc of seq parallel
* fix typo
* [example] add gpt2 HybridParallelPlugin example (#4653)
* add gpt2 HybridParallelPlugin example
* update readme and testci
* update test ci
* fix test_ci bug
* update requirements
* add requirements
* update requirements
* add requirement
* rename file
* [doc] polish shardformer doc (#4735)
* arrange position of chapters
* fix typos in seq parallel doc
* [shardformer] add custom policy in hybrid parallel plugin (#4718)
* add custom policy
* update assert
* [example] llama2 add fine-tune example (#4673)
* [shardformer] update shardformer readme [shardformer] update shardformer readme [shardformer] update shardformer readme
* [shardformer] update llama2/opt finetune example and shardformer update to llama2
* [shardformer] update llama2/opt finetune example and shardformer update to llama2
* [shardformer] update llama2/opt finetune example and shardformer update to llama2
* [shardformer] change dataset
* [shardformer] change dataset
* [shardformer] fix CI
* [shardformer] fix
* [shardformer] fix
* [shardformer] fix
* [shardformer] fix
* [shardformer] fix [example] update opt example [example] resolve comments fix fix
* [example] llama2 add finetune example
* [example] llama2 add finetune example
* [example] llama2 add finetune example
* [example] llama2 add finetune example
* fix
* update llama2 example
* update llama2 example
* fix
* update llama2 example
* update llama2 example
* update llama2 example
* update llama2 example
* update llama2 example
* update llama2 example
* Update requirements.txt
* update llama2 example
* update llama2 example
* update llama2 example
* [doc] explaination of loading large pretrained models (#4741)
* [kernel] update triton init #4740 (#4740)
* [legacy] clean up legacy code (#4743)
* [legacy] remove outdated codes of pipeline (#4692)
* [legacy] remove cli of benchmark and update optim (#4690)
* [legacy] remove cli of benchmark and update optim
* [doc] fix cli doc test
* [legacy] fix engine clip grad norm
* [legacy] remove outdated colo tensor (#4694)
* [legacy] remove outdated colo tensor
* [test] fix test import
* [legacy] move outdated zero to legacy (#4696)
* [legacy] clean up utils (#4700)
* [legacy] clean up utils
* [example] update examples
* [legacy] clean up amp
* [legacy] fix amp module
* [legacy] clean up gpc (#4742)
* [legacy] clean up context
* [legacy] clean core, constants and global vars
* [legacy] refactor initialize
* [example] fix examples ci
* [example] fix examples ci
* [legacy] fix tests
* [example] fix gpt example
* [example] fix examples ci
* [devops] fix ci installation
* [example] fix examples ci
* [format] applied code formatting on changed files in pull request 4743 (#4750) Co-authored-by: github-actions
* [misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
* [doc] explain suitable use case for each plugin
* [doc] put individual plugin explanation in front
* [doc] add model examples for each plugin
* [doc] put native colossalai plugins first in description section
* [chat]: update rm, add wandb and fix bugs (#4471)
* feat: modify forward fn of critic and reward model
* feat: modify calc_action_log_probs
* to: add wandb in sft and rm trainer
* feat: update train_sft
* feat: update train_rm
* style: modify type annotation and add warning
* feat: pass tokenizer to ppo trainer
* to: modify trainer base and maker base
* feat: add wandb in ppo trainer
* feat: pass tokenizer to generate
* test: update generate fn tests
* test: update train tests
* fix: remove action_mask
* feat: remove unused code
* fix: fix wrong ignore_index
* fix: fix mock tokenizer
* chore: update requirements
* revert: modify make_experience
* fix: fix inference
* fix: add padding side
* style: modify _on_learn_batch_end
* test: use mock tokenizer
* fix: use bf16 to avoid overflow
* fix: fix workflow
* [chat] fix gemini strategy
* [chat] fix
* sync: update colossalai strategy
* fix: fix args and model dtype
* fix: fix checkpoint test
* fix: fix requirements
* fix: fix missing import and wrong arg
* fix: temporarily skip gemini test in stage 3
* style: apply pre-commit
* fix: temporarily skip gemini test in stage 1&2

---------

Co-authored-by: Mingyan Jiang <1829166702@qq.com>

* [shardformer] fix master param sync for hybrid plugin/rewrite unwrapping logic (#4758)
* fix master param sync for hybrid plugin
* rewrite unwrap for ddp/fsdp
* rewrite unwrap for zero/gemini
* rewrite unwrap for hybrid plugin
* fix geemini unwrap
* fix bugs
* [bug] fix get_default_parser in examples (#4764)
* [doc] clean up outdated docs (#4765)
* [doc] clean up outdated docs
* [doc] fix linking
* [doc] fix linking
* [doc] add shardformer doc to sidebar (#4768)
* [chat]: add lora merge weights config (#4766)
* feat: modify lora merge weights fn
* feat: add lora merge weights config
* [lazy] support torch 2.0 (#4763)
* [lazy] support _like methods and clamp
* [lazy] pass transformers models
* [lazy] fix device move and requires grad
* [lazy] fix requires grad and refactor api
* [lazy] fix requires grad
* [bug] Fix the version check bug in colossalai run when generating the cmd. (#4713)
* Fix the version check bug in colossalai run when generating the cmd.
* polish code
* [feature] add gptq for inference (#4754)
* [gptq] add gptq kernel (#4416)
* add gptq
* refactor code
* fix tests
* replace auto-gptq
* rname inferance/quant
* refactor test
* add auto-gptq as an option
* reset requirements
* change assert and check auto-gptq
* add import warnings
* change test flash attn version
* remove example
* change requirements of flash_attn
* modify tests
* [skip ci] change requirements-test
* [gptq] faster gptq cuda kernel (#4494)
* [skip ci] add cuda kernels
* add license
* [skip ci] fix max_input_len
* format files & change test size
* [skip ci]
* [gptq] add gptq tensor parallel (#4538)
* add gptq tensor parallel
* add gptq tp
* delete print
* add test gptq check
* add test auto gptq check
* [gptq] combine gptq and kv cache manager (#4706)
* combine gptq and kv cache manager
* add init bits
* delete useless code
* add model path
* delete usless print and update test
* delete usless import
* move option gptq to shard config
* change replace linear to shardformer
* update bloom policy
* delete useless code
* fix import bug and delete uselss code
* change colossalai/gptq to colossalai/quant/gptq
* update import linear for tests
* delete useless code and mv gptq_kernel to kernel directory
* fix triton kernel
* add triton import
* [inference] chatglm2 infer demo (#4724)
* add chatglm2
* add
* gather needed kernels
* fix some bugs
* finish context forward
* finish context stage
* fix
* add
* pause
* add
* fix bugs
* finish chatglm
* fix bug
* change some logic
* fix bugs
* change some logics
* add
* add
* add
* fix
* fix tests
* fix
* [release] update version (#4775)
* [release] update version
* [doc] revert versions
* initial commit: add colossal llama 2 (#4784)
* [feature] ColossalEval: Evaluation Pipeline for LLMs (#4786)
* Add ColossalEval
* Delete evaluate in Chat

---------

Co-authored-by: Xu Yuanchen
Co-authored-by: Tong Li

* [doc] add llama2 domain-specific solution news (#4789)
* [doc] add llama2 domain-specific solution news
* [fix] fix weekly runing example (#4787)
* [fix] fix weekly runing example
* [fix] fix weekly runing example
* [doc] polish shardformer doc (#4779)
* fix example format in docstring
* polish shardformer doc
* [checkpointio] support unsharded checkpointIO for hybrid parallel (#4774)
* support unsharded saving/loading for model
* support optimizer unsharded saving
* update doc
* support unsharded loading for optimizer
* small fix
* update readme
* [lazy] support from_pretrained (#4801)
* [lazy] patch from pretrained
* [lazy] fix from pretrained and add tests
* [devops] update ci
* update
* [hotfix] change llama2 Colossal-LLaMA-2 script filename (#4800) change filename: pretraining.py -> trainin.py there is no file named pretraing.py. wrong writing
* [misc] add last_epoch in CosineAnnealingWarmupLR (#4778)
* [doc] add lazy init docs (#4808)
* [hotfix] fix norm type error in zero optimizer (#4795)
* [hotfix] Correct several erroneous code comments (#4794)
* [format] applied code formatting on changed files in pull request 4595 (#4602) Co-authored-by: github-actions
* fix format (#4815)
* [chat] fix gemini strategy (#4698)
* [chat] fix gemini strategy
* [chat] fix gemini strategy
* [chat] fix gemini strategy
* [chat] fix gemini strategy
* g# This is a combination of 2 commits. [chat] fix gemini strategy fox
* [chat] fix gemini strategy update llama2 example [chat] fix gemini strategy
* [fix] fix gemini strategy
* [fix] fix gemini strategy
* [fix] fix gemini strategy
* [fix] fix gemini strategy
* [fix] fix gemini strategy
* [fix] fix gemini strategy
* [fix] fix gemini strategy
* [fix] fix gemini strategy
* [fix] fix gemini strategy
* [fix] fix gemini strategy
* fix
* fix
* fix
* fix
* fix
* Update train_prompts.py
* Update Qwen-7B results (#4821) Co-authored-by: Xu Yuanchen
* [doc] update slack link (#4823)
* add autotune (#4822)
* update Colossal (#4832)
* add int8 rotary embedding kernel
* remove useless code

---------

Co-authored-by: flybird11111 <1829166702@qq.com>
Co-authored-by: Yuanheng Zhao <54058983+yuanheng-zhao@users.noreply.github.com>
Co-authored-by: binmakeswell
Co-authored-by: Baizhou Zhang
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions
Co-authored-by: digger yu
Co-authored-by: Pengtai Xu
Co-authored-by: Bin Jia <45593998+FoolPlayer@users.noreply.github.com>
Co-authored-by: ppt0011 <143150326+ppt0011@users.noreply.github.com>
Co-authored-by: Xuanlei Zhao <43881818+oahzxl@users.noreply.github.com>
Co-authored-by: Hongxin Liu
Co-authored-by: Wenhao Chen
Co-authored-by: littsk <1214689160@qq.com>
Co-authored-by: Jianghai <72591262+CjhHa1@users.noreply.github.com>
Co-authored-by: Tong Li
Co-authored-by: Yuanchen <70520919+chengeharrison@users.noreply.github.com>
Co-authored-by: Xu Yuanchen
Co-authored-by: Desperado-Jia <502205863@qq.com>
Co-authored-by: Chandler-Bing
Co-authored-by: Yan haixu <40758050+hova88@users.noreply.github.com>
---
 colossalai/kernel/triton/__init__.py          |   2 +
 .../triton/int8_rotary_embedding_kernel.py    | 119 ++++++++++++++++++
 .../test_smoothquant/test_rotary_embedding.py |  59 +++++++++
 3 files changed, 180 insertions(+)
 create mode 100644 colossalai/kernel/triton/int8_rotary_embedding_kernel.py
 create mode 100644 tests/test_smoothquant/test_rotary_embedding.py

diff --git a/colossalai/kernel/triton/__init__.py b/colossalai/kernel/triton/__init__.py
index 87ea9cf6536e..09f7a5592253 100644
--- a/colossalai/kernel/triton/__init__.py
+++ b/colossalai/kernel/triton/__init__.py
@@ -7,6 +7,7 @@
     from .copy_kv_cache_dest import copy_kv_cache_to_dest
     from .fused_layernorm import layer_norm
     from .gptq_triton import gptq_fused_linear_triton
+    from .int8_rotary_embedding_kernel import int8_rotary_embedding_fwd
     from .rms_norm import rmsnorm_forward
     from .rotary_embedding_kernel import rotary_embedding_fwd
     from .softmax import softmax
@@ -22,6 +23,7 @@
         "rotary_embedding_fwd",
         "token_attention_fwd",
         "gptq_fused_linear_triton",
+        "int8_rotary_embedding_fwd",
     ]

 except ImportError:
diff --git a/colossalai/kernel/triton/int8_rotary_embedding_kernel.py b/colossalai/kernel/triton/int8_rotary_embedding_kernel.py
new file mode 100644
index 000000000000..1e2c5c427954
--- /dev/null
+++ b/colossalai/kernel/triton/int8_rotary_embedding_kernel.py
@@ -0,0 +1,119 @@
+# Adapted from ModelTC https://github.com/ModelTC/lightllm
+import torch
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def _rotary_kernel(
+    q,
+    input_scale,
+    output_scale,
+    Cos,
+    Sin,
+    q_bs_stride,
+    q_h_stride,
+    q_d_stride,
+    cos_bs_stride,
+    cos_d_stride,
+    total_len,
+    HEAD_NUM: tl.constexpr,
+    BLOCK_HEAD: tl.constexpr,
+    BLOCK_SEQ: tl.constexpr,
+    HEAD_DIM: tl.constexpr,
+):
+    current_head_index = tl.program_id(0)
+    current_seq_index = tl.program_id(1)
+
+    dim_range0 = tl.arange(0, HEAD_DIM // 2)
+    dim_range1 = tl.arange(HEAD_DIM // 2, HEAD_DIM)
+
+    current_head_range = current_head_index * BLOCK_HEAD + tl.arange(0, BLOCK_HEAD)
+    current_seq_range = current_seq_index * BLOCK_SEQ + tl.arange(0, BLOCK_SEQ)
+
+    off_q0 = (
+        current_seq_range[:, None, None] * q_bs_stride
+        + current_head_range[None, :, None] * q_h_stride
+        + dim_range0[None, None, :] * q_d_stride
+    )
+    off_q1 = (
+        current_seq_range[:, None, None] * q_bs_stride
+        + current_head_range[None, :, None] * q_h_stride
+        + dim_range1[None, None, :] * q_d_stride
+    )
+
+    off_dimcos_sin = current_seq_range[:, None, None] * cos_bs_stride + dim_range0[None, None, :] * cos_d_stride
+
+    q0 = tl.load(
+        q + off_q0,
+        mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM),
+        other=0.0,
+    )
+    q1 = tl.load(
+        q + off_q1,
+        mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM),
+        other=0.0,
+    )
+
+    cos = tl.load(Cos + off_dimcos_sin, mask=current_seq_range[:, None, None] < total_len, other=0.0)
+    sin = tl.load(Sin + off_dimcos_sin, mask=current_seq_range[:, None, None] < total_len, other=0.0)
+    in_scale = tl.load(input_scale)
+    o_scale = tl.load(output_scale)
+
+    q0 = q0.to(tl.float32) * in_scale
+    q1 = q1.to(tl.float32) * in_scale
+
+    out0 = (q0 * cos - q1 * sin) / o_scale
+    out1 = (q0 * sin + q1 * cos) / o_scale
+
+    # out0 = out0.to(tl.int8)
+    # out1 = out1.to(tl.int8)
+
+    tl.store(
+        q + off_q0,
+        out0,
+        mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM),
+    )
+    tl.store(
+        q + off_q1,
+        out1,
+        mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM),
+    )
+
+    return
+
+
+@torch.no_grad()
+def int8_rotary_embedding_fwd(q, cos, sin, input_scale, output_scale):
+    total_len = q.shape[0]
+    head_num = q.shape[1]
+    head_dim = q.shape[2]
+    assert q.shape[0] == cos.shape[0] and q.shape[0] == sin.shape[0], f"q shape {q.shape} cos shape {cos.shape}"
+    BLOCK_HEAD = 4
+    BLOCK_SEQ = 32
+    grid = (triton.cdiv(head_num, BLOCK_HEAD), triton.cdiv(total_len, BLOCK_SEQ))
+    if head_dim >= 128:
+        num_warps = 8
+    else:
+        num_warps = 4
+
+    _rotary_kernel[grid](
+        q,
+        input_scale,
+        output_scale,
+        cos,
+        sin,
+        q.stride(0),
+        q.stride(1),
+        q.stride(2),
+        cos.stride(0),
+        cos.stride(1),
+        total_len,
+        HEAD_NUM=head_num,
+        BLOCK_HEAD=BLOCK_HEAD,
+        BLOCK_SEQ=BLOCK_SEQ,
+        HEAD_DIM=head_dim,
+        num_warps=num_warps,
+        num_stages=1,
+    )
+    return
diff --git a/tests/test_smoothquant/test_rotary_embedding.py b/tests/test_smoothquant/test_rotary_embedding.py
new file mode 100644
index 000000000000..ee030065d66e
--- /dev/null
+++ b/tests/test_smoothquant/test_rotary_embedding.py
@@ -0,0 +1,59 @@
+# Adapted from ModelTC https://github.com/ModelTC/lightllm
+
+
+import pytest
+import torch
+from packaging import version
+
+try:
+    from colossalai.kernel.triton import int8_rotary_embedding_fwd
+
+    HAS_TRITON = True
+except ImportError:
+    HAS_TRITON = False
+    print("please install triton from https://github.com/openai/triton")
+
+TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4")
+
+
+def torch_rotary_emb(x, cos, sin):
+    seq_len, h, dim = x.shape
+    x0 = x[:, :, 0 : dim // 2]
+    x1 = x[:, :, dim // 2 : dim]
+    cos = cos.view((seq_len, 1, dim // 2))
+    sin = sin.view((seq_len, 1, dim // 2))
+    o0 = x0 * cos - x1 * sin
+    o1 = x0 * sin + x1 * cos
+    return torch.cat((o0, o1), dim=-1)
+
+
+@pytest.mark.skipif(
+    not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4"
+)
+def test_rotary_emb():
+    SEQ_LEN = 1
+    HEAD_NUM = 32
+    HEAD_DIM = 128
+    dtype = torch.float
+    # create data
+    x_shape = (SEQ_LEN, HEAD_NUM, HEAD_DIM)
+    x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device="cuda")
+    cos_shape = (SEQ_LEN, HEAD_DIM // 2)
+    cos = -1.2 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
+    sin = -2.0 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
+    # forward pass
+    y_torch = torch_rotary_emb(x, cos, sin)
+
+    input_scale = torch.max(torch.abs(x)) / 127
+    output_scale = torch.max(torch.abs(y_torch)) / 127
+
+    x = x / input_scale
+    x = x.to(torch.int8)
+
+    int8_rotary_embedding_fwd(x, cos, sin, input_scale, output_scale)
+    y_triton = x.to(torch.float) * output_scale
+    assert torch.allclose(y_triton, y_torch, atol=2e-1, rtol=1e-2, equal_nan=True)
+
+
+if __name__ == "__main__":
+    test_rotary_emb()