[inference] add reference and fix some bugs (#4937)
* add reference and fix some bugs

* update gptq init

---------

Co-authored-by: Xu Kai <[email protected]>
Xu-Kai and Xu Kai authored Oct 20, 2023
1 parent b8e770c commit 785802e
Showing 7 changed files with 24 additions and 10 deletions.
6 changes: 6 additions & 0 deletions colossalai/inference/quant/smoothquant/models/base_model.py
@@ -132,6 +132,7 @@ def collect_act_dict(self, model, tokenizer, dataset, act_dict, device, num_samp
mean_scale = np.mean([v["input"] for v in act_dict.values()])
pbar.set_description(f"Mean input scale: {mean_scale:.2f}")

# Adapted from https://github.com/mit-han-lab/smoothquant/blob/main/smoothquant/calibration.py
def get_act_scales(self, model, tokenizer, dataset, num_samples=512, seq_len=512):
model.eval()
device = next(model.parameters()).device
@@ -163,6 +164,7 @@ def stat_input_hook(m, x, y, name):

return act_scales

# Adapted from https://github.com/mit-han-lab/smoothquant/blob/main/smoothquant/smooth.py
@torch.no_grad()
def smooth_ln_fcs(self, ln, fcs, act_scales, alpha=0.5):
if not isinstance(fcs, list):
@@ -189,6 +191,7 @@ def smooth_ln_fcs(self, ln, fcs, act_scales, alpha=0.5):
def create_quantized_model(model):
raise NotImplementedError("Not implement create_quantized_model method")

# Adapted from AutoGPTQ: https://github.com/PanQiWei/AutoGPTQ/blob/main/auto_gptq/modeling/_base.py
def save_quantized(
self,
save_dir: str,
@@ -249,6 +252,7 @@ def save_quantized(

self.model.config.save_pretrained(save_dir)

# Adapted from AutoGPTQ: https://github.com/PanQiWei/AutoGPTQ/blob/main/auto_gptq/modeling/_base.py
def save_pretrained(
self,
save_dir: str,
@@ -260,6 +264,7 @@ def save_pretrained(
warnings.warn("you are using save_pretrained, which will re-direct to save_quantized.")
self.save_quantized(save_dir, use_safetensors, safetensors_metadata)

# Adapted from AutoGPTQ: https://github.com/PanQiWei/AutoGPTQ/blob/main/auto_gptq/modeling/_base.py
@classmethod
def from_pretrained(
cls,
@@ -354,6 +359,7 @@ def skip(*args, **kwargs):

return cls(model, False)

# Adapted from AutoGPTQ: https://github.com/PanQiWei/AutoGPTQ/blob/main/auto_gptq/modeling/_base.py
@classmethod
def from_quantized(
cls,
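The references added above annotate the calibration, smoothing, and serialization methods of BaseSmoothForCausalLM. As a rough, hypothetical sketch of how those pieces could be wired together (only the class and method names come from this diff; the import path, calibration-data format, and keyword arguments are assumptions):

# Hypothetical end-to-end sketch, not the project's documented API: calibrate
# and save a SmoothQuant LLaMA model using the methods annotated in this diff.
import torch
from transformers import AutoTokenizer, LlamaForCausalLM

from colossalai.inference.quant.smoothquant.models.llama import SmoothLlamaForCausalLM

model_path = "path/to/llama"                      # placeholder checkpoint
calib_dataset = ["placeholder calibration text"]  # placeholder; real format may differ

tokenizer = AutoTokenizer.from_pretrained(model_path)
base_model = LlamaForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16)
smooth_model = SmoothLlamaForCausalLM(base_model, quantized=False)

# Per-channel activation scales over the calibration set
# (adapted upstream from mit-han-lab/smoothquant, per the comments above).
act_scales = smooth_model.get_act_scales(base_model, tokenizer, calib_dataset,
                                         num_samples=512, seq_len=512)

# smooth_ln_fcs / create_quantized_model would then fold activation outliers
# into the weights and swap in the INT8 modules before saving.

# AutoGPTQ-style serialization of the quantized checkpoint.
smooth_model.save_quantized("llama-smoothquant", use_safetensors=True)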
2 changes: 2 additions & 0 deletions colossalai/inference/quant/smoothquant/models/linear.py
@@ -62,6 +62,7 @@ def from_float(module: torch.nn.Linear, input_scale)
return int8_module


# modified from torch-int: https://github.com/Guangxuan-Xiao/torch-int/blob/main/torch_int/nn/linear.py
class W8A8B8O8Linear(torch.nn.Module):
# For qkv_proj
def __init__(self, in_features, out_features, alpha=1.0, beta=1.0):
@@ -117,6 +118,7 @@ def from_float(module: torch.nn.Linear, input_scale, output_scale):
return int8_module


# modified from torch-int: https://github.com/Guangxuan-Xiao/torch-int/blob/main/torch_int/nn/linear.py
class W8A8BFP32OFP32Linear(torch.nn.Module):
# For fc2 and out_proj
def __init__(self, in_features, out_features, alpha=1.0, beta=1.0):
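The torch-int references above point at CUDA-backed static INT8 linear layers (W8A8B8O8Linear for qkv_proj, W8A8BFP32OFP32Linear for fc2 and out_proj). As a plain-PyTorch reminder of the arithmetic such a layer performs, independent of the fused kernels (the real modules fold these rescalings into their alpha/beta factors and may differ in detail):

# CPU reference for a static W8A8 linear: quantize -> int8 matmul -> rescale.
# Illustration only; the real W8A8B8O8Linear / W8A8BFP32OFP32Linear modules
# dispatch to fused CUDA kernels with the scales folded into alpha/beta.
import torch

def w8a8_linear_reference(x_fp, weight_fp, input_scale, weight_scale):
    # Statically quantize activations and weights to int8.
    x_q = torch.clamp(torch.round(x_fp / input_scale), -128, 127).to(torch.int8)
    w_q = torch.clamp(torch.round(weight_fp / weight_scale), -128, 127).to(torch.int8)
    # Integer matmul accumulated in int32, as the fused kernel does on GPU.
    acc = x_q.to(torch.int32) @ w_q.to(torch.int32).t()
    # Rescale the int32 accumulator back to floating point.
    return acc.to(torch.float32) * (input_scale * weight_scale)

x = torch.randn(4, 64)
w = torch.randn(32, 64)
out_q = w8a8_linear_reference(x, w, x.abs().max().item() / 127, w.abs().max().item() / 127)
print((out_q - x @ w.t()).abs().max())  # bounded quantization error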
3 changes: 3 additions & 0 deletions colossalai/inference/quant/smoothquant/models/llama.py
@@ -419,6 +419,7 @@ def forward(self, x, cos, sin, position_ids):
return x_embed


# Adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
def llama_decoder_layer_forward(
self,
hidden_states: torch.Tensor,
@@ -559,6 +560,7 @@ def init_to_get_rotary(config, base=10000, use_elem=False):
return _cos_cached, _sin_cached


# Adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
def llama_model_forward(
self,
@@ -729,6 +731,7 @@ class SmoothLlamaForCausalLM(BaseSmoothForCausalLM):
def __init__(self, model: PreTrainedModel, quantized: bool = False):
super().__init__(model, quantized)

# Adapted from https://github.com/mit-han-lab/smoothquant/blob/main/smoothquant/calibration.py
def get_act_dict(
self,
tokenizer,
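init_to_get_rotary in the hunk above precomputes the rotary-embedding cos/sin caches consumed by the inference kernels. For reference, the standard rotary cache construction looks like the sketch below; the exact handling of use_elem, dtype, and maximum sequence length in ColossalAI's helper may differ:

# Standard rotary-embedding cache construction (illustrative sketch; the exact
# behaviour of init_to_get_rotary, e.g. its use_elem flag, may differ).
import torch

def build_rotary_cache(head_dim: int, max_positions: int, base: float = 10000.0):
    # Inverse frequencies for each dimension pair: base^(-2i / head_dim).
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
    # Angles are the outer product of position indices and inverse frequencies.
    t = torch.arange(max_positions).float()
    freqs = torch.outer(t, inv_freq)           # (max_positions, head_dim // 2)
    return torch.cos(freqs), torch.sin(freqs)  # _cos_cached, _sin_cached

cos_cached, sin_cached = build_rotary_cache(head_dim=128, max_positions=2048)
print(cos_cached.shape)  # torch.Size([2048, 64])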
7 changes: 6 additions & 1 deletion colossalai/inference/tensor_parallel/engine.py
@@ -21,6 +21,8 @@
"BloomForCausalLM",
"ChatGLMModel",
"ChatGLMForConditionalGeneration",
"LlamaGPTQForCausalLM",
"BloomGPTQForCausalLM",
]


@@ -213,11 +215,14 @@ def _shard_model_by(self, shardformer: ShardFormer, model: nn.Module) -> None:
), "Discrepancy between the tp size of TPInferEngine and the tp size of shard config"
model_name = model.__class__.__name__
assert model_name in self.supported_models, f"Unsupported model cls {model_name} for TP inference."

model = model.model if self.shard_config.inference_gptq else model

policy = get_autopolicy(model, inference_only=True)
self.model, _ = shardformer.optimize(model, policy)

if self.shard_config.inference_gptq:
self._post_init_gptq_buffer(model)
self._post_init_gptq_buffer(self.model)

self.model = self.model.cuda()

1 change: 1 addition & 0 deletions colossalai/kernel/triton/gptq_triton.py
@@ -267,6 +267,7 @@ def cai_gptq_matmul_248_kernel(
tl.store(c_ptrs, accumulator, mask=c_mask)


# Adapted from AutoGPTQ auto_gptq: https://github.com/PanQiWei/AutoGPTQ
@autotune(
configs=[
triton.Config(
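cai_gptq_matmul_248_kernel (adapted from AutoGPTQ) multiplies activations against weights stored as packed low-bit integers with per-group scales and zero points. The self-contained sketch below shows the general pack/unpack/dequantize round trip for 4-bit weights in plain PyTorch; the bit layout, grouping, and zero-point convention are simplified assumptions and do not reproduce AutoGPTQ's exact format:

# Simplified 4-bit weight pack/dequant round trip, illustrating what a GPTQ
# matmul_248-style kernel undoes on the fly. The packing layout and zero-point
# handling here are deliberately simplified and are NOT AutoGPTQ's exact format.
import torch

def pack_int4(q):
    # q: values in [0, 15]; pack 8 of them into each int32 along dim 0.
    q = q.to(torch.int32).reshape(q.shape[0] // 8, 8, q.shape[1])
    shifts = (4 * torch.arange(8, dtype=torch.int32)).view(1, 8, 1)
    return (q << shifts).sum(dim=1)

def unpack_int4(packed):
    shifts = (4 * torch.arange(8, dtype=torch.int32)).view(1, 8, 1)
    return ((packed.unsqueeze(1) >> shifts) & 0xF).reshape(-1, packed.shape[1])

# Fake a quantized weight: per-column scale, fixed zero point of 8.
w_fp = torch.randn(64, 32)
scale = w_fp.abs().amax(dim=0, keepdim=True) / 7.0
q = torch.clamp(torch.round(w_fp / scale) + 8, 0, 15).to(torch.uint8)

packed = pack_int4(q)                        # int32 tensor of shape (8, 32)
w_deq = (unpack_int4(packed).float() - 8) * scale

x = torch.randn(4, 64)
print((x @ w_deq - x @ w_fp).abs().max())    # bounded quantization error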
12 changes: 6 additions & 6 deletions colossalai/kernel/triton/smooth_attention.py
@@ -13,10 +13,10 @@

if HAS_TRITON:
"""
this function is modified from
https://github.com/ModelTC/lightllm/blob/f093edc20683ac3ea1bca3fb5d8320a0dd36cf7b/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py#L10
these functions are modified from https://github.com/ModelTC/lightllm
"""

# Adapted from https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py
@triton.jit
def _context_flash_attention_kernel(
Q,
@@ -145,20 +145,16 @@ def _context_flash_attention_kernel(
tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len)
return



@torch.no_grad()
def smooth_llama_context_attn_fwd(
q, k, v, o, q_input_scale, k_input_scale, v_input_scale, pv_output_scale, b_start_loc, b_seq_len, max_input_len
):

BLOCK = 128
# shape constraints
Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
assert Lq == Lk, "context process only supports equal query, key, value length"
assert Lk == Lv, "context process only supports equal query, key, value length"
assert Lk in {16, 32, 64, 128}
BLOCK_N = 128
sm_scale = 1.0 / math.sqrt(Lk)
batch, head = b_seq_len.shape[0], q.shape[1]
grid = (batch, head, triton.cdiv(max_input_len, BLOCK))
@@ -203,6 +199,7 @@ def smooth_llama_context_attn_fwd(
)
return

# Adapted from https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/token_attention_nopad_att1.py
@triton.jit
def _token_attn_1_kernel(
Q,
@@ -264,6 +261,7 @@ def _token_attn_1_kernel(
tl.store(attn_out + off_o, att_value, mask=offs_n_new < current_batch_end_index)
return

# Adapted from https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/token_attention_nopad_att1.py
@triton.jit
def _token_attn_1_alibi_kernel(
Q,
@@ -413,6 +411,7 @@ def token_attn_fwd_1(
)
return

# Adapted from https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/token_attention_softmax_and_reducev.py
@triton.jit
def _token_attn_softmax_fwd(
softmax_logics,
@@ -479,6 +478,7 @@ def token_attn_softmax_fwd(softmax_logics, kv_cache_start_loc, kv_cache_seqlen,
)
return

# Adapted from https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/token_attention_nopad_att1.py
@triton.jit
def _token_attn_2_kernel(
Prob,
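The _token_attn_1_kernel, _token_attn_softmax_fwd, and _token_attn_2_kernel kernels referenced above split decode-time attention into score computation, softmax, and value reduction over the cached keys/values. The snippet below is a CPU reference of that three-step decomposition for a single query token; it omits the KV-cache paging, quantization scales, and the ALiBi variant that the real kernels handle:

# CPU reference for the three-step token-attention decomposition the Triton
# kernels fuse: (1) q @ K^T scores, (2) softmax, (3) probs @ V.
# KV-cache paging, int8 scales and the ALiBi variant are omitted.
import math
import torch

def token_attention_reference(q, k_cache, v_cache):
    # q: (head, head_dim) for the current token;
    # k_cache, v_cache: (seq_len, head, head_dim) cached keys/values.
    sm_scale = 1.0 / math.sqrt(q.shape[-1])
    # Step 1 (_token_attn_1): logits of the new token against the cache.
    logits = torch.einsum("hd,shd->hs", q, k_cache) * sm_scale
    # Step 2 (_token_attn_softmax_fwd): normalize over the cached sequence.
    probs = torch.softmax(logits, dim=-1)
    # Step 3 (_token_attn_2): weighted sum of cached values.
    return torch.einsum("hs,shd->hd", probs, v_cache)

out = token_attention_reference(torch.randn(32, 128),
                                torch.randn(256, 32, 128),
                                torch.randn(256, 32, 128))
print(out.shape)  # torch.Size([32, 128])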
3 changes: 0 additions & 3 deletions examples/inference/gptq_llama.py
@@ -8,7 +8,6 @@

import colossalai
from colossalai.inference.tensor_parallel.engine import TPInferEngine
from colossalai.inference.tensor_parallel.modeling._utils import init_to_get_rotary
from colossalai.logging import disable_existing_loggers
from colossalai.shardformer import ShardConfig
from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn
@@ -50,8 +49,6 @@ def run_llama_test(args):
quantized_model_dir, device=torch.cuda.current_device(), inject_fused_attention=False
)

init_to_get_rotary(model.model.model, base=10000)

model_config = model.config
shard_config = ShardConfig(
enable_tensor_parallelism=True if args.tp_size > 1 else False, inference_only=True, inference_gptq=True
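With init_to_get_rotary removed from the example above, the remaining GPTQ inference flow in examples/inference/gptq_llama.py is roughly the sketch below. The ShardConfig flags and the from_quantized() arguments mirror the diff; the distributed launch is omitted, and TPInferEngine's batch/length arguments are assumptions:

# Rough sketch of the GPTQ inference flow after this change (single process;
# colossalai launch / spawn is omitted). TPInferEngine's max_batch_size /
# max_input_len / max_output_len argument names are assumptions.
import torch
from auto_gptq import AutoGPTQForCausalLM

from colossalai.inference.tensor_parallel.engine import TPInferEngine
from colossalai.shardformer import ShardConfig

quantized_model_dir = "path/to/llama-gptq"   # placeholder
tp_size = 1

model = AutoGPTQForCausalLM.from_quantized(
    quantized_model_dir, device=torch.cuda.current_device(), inject_fused_attention=False
)

shard_config = ShardConfig(
    enable_tensor_parallelism=tp_size > 1, inference_only=True, inference_gptq=True
)

# The engine unwraps the GPTQ wrapper, applies the inference shard policy, and,
# per the engine.py fix above, initializes the GPTQ buffers on the sharded model.
infer_engine = TPInferEngine(model, shard_config, max_batch_size=4, max_input_len=128, max_output_len=64)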
