[inference] refactor code for smoothquant #4902

Merged · 1 commit · Oct 13, 2023
9 changes: 3 additions & 6 deletions colossalai/inference/quant/smoothquant/models/base_model.py
@@ -14,7 +14,6 @@
import torch.nn as nn
import transformers
from safetensors.torch import save_file as safe_save
from torch import device
from tqdm import tqdm
from transformers import AutoConfig, AutoModelForCausalLM, PreTrainedModel
from transformers.modeling_utils import no_init_weights
@@ -24,8 +23,6 @@
from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState
from colossalai.inference.tensor_parallel.kvcache_manager import MemoryManager

CPU = device("cpu")

SUPPORTED_MODELS = ["llama"]


@@ -204,7 +201,7 @@ def save_quantized(
if not self.quantized:
raise EnvironmentError("can only save quantized model, please execute .quantize first.")

self.model.to(CPU)
self.model.to("cpu")

model_base_name = model_basename # or f"smooth-"
if use_safetensors:
@@ -431,7 +428,7 @@ def from_quantized(

model_save_name = resolved_archive_file

# == step2: convert model to gptq-model (replace Linear with QuantLinear) == #
# == step2: convert model to quantized-model (replace Linear) == #
def skip(*args, **kwargs):
pass

@@ -463,10 +460,10 @@ def skip(*args, **kwargs):
model.model.register_buffer("_sin_cached", sin)
model.tie_weights()

# == step3: load checkpoint into quantized-model == #
accelerate.utils.modeling.load_checkpoint_in_model(
model, checkpoint=model_save_name, offload_state_dict=True, offload_buffers=True
)
model = model.to("cuda")

# == step4: set seqlen == #
model_config = model.config.to_dict()
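
For reference, a minimal sketch of the accelerate-based load pattern used in `from_quantized` above: construct the model first, then stream the saved weights into it with `accelerate.utils.modeling.load_checkpoint_in_model`. The helper name and the assumption that the caller handles device placement afterwards are illustrative, not code from this PR.

```python
# Sketch only: mirrors the checkpoint-loading pattern shown in from_quantized.
# `load_quantized_checkpoint` is an illustrative helper, not part of the PR.
import accelerate
import torch


def load_quantized_checkpoint(model: torch.nn.Module, checkpoint_path: str) -> torch.nn.Module:
    # accelerate streams tensors from the saved file into the existing model,
    # offloading the temporary state dict and buffers to CPU to limit peak memory.
    accelerate.utils.modeling.load_checkpoint_in_model(
        model, checkpoint=checkpoint_path, offload_state_dict=True, offload_buffers=True
    )
    # Device placement (e.g. model.cuda()) is assumed to happen at the call site.
    return model
```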
52 changes: 14 additions & 38 deletions colossalai/inference/quant/smoothquant/models/llama.py
@@ -1,5 +1,3 @@
# Code modified from smoothquant: https://github.com/mit-han-lab/smoothquant

import math
import os
import types
@@ -92,7 +90,7 @@ def pack(
out_input_scale: float,
):
int8_module = LLamaSmoothquantAttention(module.hidden_size, module.num_heads)
# self.register_buffer("attn_input_scale", torch.tensor([1.0]))

int8_module.attn_input_scale = torch.tensor([attn_input_scale])

int8_module.q_output_scale = torch.tensor([q_output_scale])
@@ -107,10 +105,6 @@ def pack(
int8_module.v_proj = W8A8B8O8Linear.from_float(module.v_proj, attn_input_scale, v_output_scale)
int8_module.o_proj = W8A8BFP32OFP32Linear.from_float(module.o_proj, out_input_scale)

# int8_module.q_proj = module.q_proj
# int8_module.k_proj = module.k_proj
# int8_module.v_proj = module.v_proj
# int8_module.o_proj = module.o_proj
int8_module.out_input_scale = torch.tensor([out_input_scale])

return int8_module
@@ -259,10 +253,8 @@ def forward(self, x):
@staticmethod
def from_float(module: torch.nn.LayerNorm, output_scale: float):
assert module.weight.shape[0] == module.weight.numel()
# assert module.bias.shape[0] == module.bias.numel()
q_module = LlamaLayerNormQ(module.weight.shape[0], module.variance_epsilon)
q_module.weight = module.weight / output_scale
# q_module.bias = module.bias / output_scale
return q_module


@@ -346,9 +338,6 @@ def pack(
out_input_scale,
)

# int8_decoder_layer.input_layernorm = module.input_layernorm
# int8_decoder_layer.self_attn = module.self_attn

int8_decoder_layer.post_attention_layernorm = LlamaLayerNormQ.from_float(
module.post_attention_layernorm, gate_input_scale
)
@@ -360,9 +349,6 @@ def pack(
down_input_scale,
)

# int8_decoder_layer.post_attention_layernorm = module.post_attention_layernorm
# int8_decoder_layer.mlp = module.mlp

return int8_decoder_layer

def forward(
@@ -641,8 +627,6 @@ def llama_model_forward(
infer_state.decode_is_contiguous = False
alloc_mem = infer_state.cache_manager.alloc(batch_size)
infer_state.decode_mem_index = alloc_mem
# infer_state.decode_key_buffer = torch.empty((batch_size, self.tp_head_num_, self.head_dim_), dtype=torch.float16, device="cuda")
# infer_state.decode_value_buffer = torch.empty((batch_size, self.tp_head_num_, self.head_dim_), dtype=torch.float16, device="cuda")
infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index

if position_ids is None:
@@ -673,11 +657,7 @@ def llama_model_forward(
hidden_states = inputs_embeds

if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
raise NotImplementedError("gradient checkpointing is not supported in training mode")

if past_key_values_length == 0:
position_cos = torch.index_select(self._cos_cached, 0, position_ids.view(-1)).view(
@@ -701,20 +681,17 @@ def llama_model_forward(

past_key_value = past_key_values[idx] if past_key_values is not None else None

if self.gradient_checkpointing and self.training:
raise NotImplementedError("not implement gradient_checkpointing and training options ")
else:
layer_outputs = decoder_layer(
hidden_states,
rotary_emb=(position_cos, position_sin),
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
padding_mask=padding_mask,
infer_state=infer_state,
)
layer_outputs = decoder_layer(
hidden_states,
rotary_emb=(position_cos, position_sin),
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
padding_mask=padding_mask,
infer_state=infer_state,
)

hidden_states = layer_outputs[0]
infer_state.decode_layer_id += 1
@@ -836,13 +813,12 @@ def quantized(
scale_dict["q_rotary_output_scale"] = (
act_dict[f"model.layers.{idx}.self_attn.q_apply_rotary"]["output"] / 127
)

scale_dict["k_rotary_output_scale"] = (
act_dict[f"model.layers.{idx}.self_attn.k_apply_rotary"]["output"] / 127
)

scale_dict["out_input_scale"] = act_dict[f"model.layers.{idx}.self_attn.o_proj"]["input"] / 127
# mlp scales

scale_dict["gate_input_scale"] = act_dict[f"model.layers.{idx}.mlp.gate_proj"]["input"] / 127
scale_dict["up_input_scale"] = act_dict[f"model.layers.{idx}.mlp.up_proj"]["input"] / 127
scale_dict["down_input_scale"] = act_dict[f"model.layers.{idx}.mlp.down_proj"]["input"] / 127
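
The `quantized` hunk above fills `scale_dict` by dividing each recorded activation maximum by 127, the usual symmetric int8 scale. A minimal sketch of that pattern follows; the helper names and the assumption that calibration records per-module activation maxima are illustrative, not code from this PR.

```python
# Sketch of the per-tensor int8 scale pattern used in scale_dict above:
# calibration records the max absolute activation per hooked module, and the
# scale maps that range onto the symmetric int8 interval [-127, 127].
import torch


def int8_scale(max_abs_activation: float) -> float:
    # scale chosen so that |x| <= max_abs_activation quantizes into [-127, 127]
    return max_abs_activation / 127


def quantize_activation(x: torch.Tensor, scale: float) -> torch.Tensor:
    # symmetric per-tensor quantization using the calibrated scale
    return torch.clamp(torch.round(x / scale), -127, 127).to(torch.int8)
```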
5 changes: 4 additions & 1 deletion examples/inference/smoothquant_llama.py
@@ -2,6 +2,7 @@
import os

import torch
from datasets import load_dataset
from transformers import LlamaTokenizer

from colossalai.inference.quant.smoothquant.models.llama import SmoothLlamaForCausalLM
@@ -47,13 +48,15 @@ def main():
if not os.path.exists(dataset_path):
print(f"Cannot find the dataset at {args.dataset_path}")
raise FileNotFoundError
dataset = load_dataset("json", data_files=dataset_path, split="train")

model.quantized(tokenizer, dataset, num_samples=num_samples, seq_len=seq_len)
model = model.cuda()
model.quantized(tokenizer, dataset_path, num_samples=num_samples, seq_len=seq_len)

model.save_quantized(output_path, model_basename="llama-7b")

model = SmoothLlamaForCausalLM.from_quantized(output_path, model_basename="llama-7b")
model = model.cuda()

generate_kwargs = dict(max_new_tokens=16, do_sample=False, use_cache=True)
input_tokens = tokenizer(["today is "], return_tensors="pt").to("cuda")
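
Taken together, the example now follows a calibrate → save → reload flow on GPU. A hedged end-to-end sketch of that flow is below; everything outside the diff (`from_pretrained`, the placeholder paths, and the sample counts) is an assumption, not taken from this PR.

```python
# Sketch of the updated example flow; names outside the diff are assumptions.
import torch
from datasets import load_dataset
from transformers import LlamaTokenizer

from colossalai.inference.quant.smoothquant.models.llama import SmoothLlamaForCausalLM


def run(model_path: str, dataset_path: str, output_path: str) -> None:
    tokenizer = LlamaTokenizer.from_pretrained(model_path)
    # from_pretrained is assumed to mirror the example script's model construction.
    model = SmoothLlamaForCausalLM.from_pretrained(model_path)
    model = model.cuda()  # calibration runs on GPU, as in the updated example

    # The caller now loads the dataset and passes it in, instead of a raw path.
    dataset = load_dataset("json", data_files=dataset_path, split="train")
    model.quantized(tokenizer, dataset, num_samples=512, seq_len=512)

    model.save_quantized(output_path, model_basename="llama-7b")

    # Reload the int8 checkpoint and place it on GPU explicitly at the call site.
    quant_model = SmoothLlamaForCausalLM.from_quantized(output_path, model_basename="llama-7b")
    quant_model = quant_model.cuda()

    input_tokens = tokenizer(["today is "], return_tensors="pt").to("cuda")
    out = quant_model.generate(**input_tokens, max_new_tokens=16, do_sample=False, use_cache=True)
    print(tokenizer.decode(out[0]))
```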