AwqConfig class #132

Merged · 9 commits · Oct 31, 2023

Changes from all commits
8 changes: 5 additions & 3 deletions README.md
@@ -74,6 +74,7 @@ The detailed support list:
| ---------| ----------------------------|
| LLaMA-2 | 7B/13B/70B |
| LLaMA | 7B/13B/30B/65B |
| Mistral | 7B |
| Vicuna | 7B/13B |
| MPT | 7B/30B |
| Falcon | 7B/40B |
@@ -97,6 +98,8 @@ There are two versions of AWQ: GEMM and GEMV. Both names relate to how matrix mu

### Examples

More examples can be found in the [examples directory](examples).

<details>

<summary>Quantization</summary>
@@ -109,7 +112,7 @@ from transformers import AutoTokenizer

model_path = 'lmsys/vicuna-7b-v1.5'
quant_path = 'vicuna-7b-v1.5-awq'
quant_config = { "zero_point": True, "q_group_size": 128, "w_bit": 4 }
quant_config = { "zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM" }

# Load model
model = AutoAWQForCausalLM.from_pretrained(model_path)
@@ -134,10 +137,9 @@ from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer, TextStreamer

quant_path = "casperhansen/vicuna-7b-v1.5-awq"
quant_file = "awq_model_w4_g128.pt"

# Load model
model = AutoAWQForCausalLM.from_quantized(quant_path, quant_file, fuse_layers=True)
model = AutoAWQForCausalLM.from_quantized(quant_path, fuse_layers=True)
tokenizer = AutoTokenizer.from_pretrained(quant_path, trust_remote_code=True)
streamer = TextStreamer(tokenizer, skip_special_tokens=True)

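The updated inference example above stops at the streamer setup; a short continuation sketch follows. It is not part of this diff: the prompt and generation length are hypothetical placeholders, and the quantized model is assumed to sit on a CUDA device after loading.

```python
# Sketch only (not part of this diff): continues the inference example above.
# Prompt and max_new_tokens are placeholders.
prompt = "What is AWQ quantization?"
tokens = tokenizer(prompt, return_tensors="pt").input_ids.cuda()

# BaseAWQForCausalLM exposes generate(*args, **kwargs) (see awq/models/base.py below),
# so the usual transformers generation arguments such as streamer apply.
output = model.generate(tokens, streamer=streamer, max_new_tokens=128)
```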
89 changes: 89 additions & 0 deletions awq/models/_config.py
@@ -0,0 +1,89 @@
import os
import json
import logging
from typing import Dict
from dataclasses import dataclass, field, fields
from transformers.utils.hub import PushToHubMixin, cached_file

@dataclass
class AwqConfig(PushToHubMixin):
quant_method: str = field(default="awq")
zero_point: bool = field(default=True)
q_group_size: int = field(default=128)
w_bit: int = field(default=4)
version: str = field(default="GEMM")
config_file_name = "quant_config.json"

def save_pretrained(self, save_dir: str, **kwargs):
logging.warning(
"`quant_config.json` is being deprecated in the future"
" in favor of quantization_config in config.json."
)
with open(os.path.join(save_dir, self.config_file_name), "w+", encoding="utf-8") as file:
file.write(json.dumps(self.to_dict(), indent=4))

@classmethod
def from_dict(cls, quant_config: Dict={}):
if not quant_config:
quant_config = cls()
else:
quant_config = cls(**quant_config)

return quant_config

@classmethod
def from_pretrained(cls, save_dir: str, **kwargs):
cache_dir = kwargs.pop("cache_dir", None)
force_download = kwargs.pop("force_download", False)
resume_download = kwargs.pop("resume_download", False)
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", False)
use_auth_token = kwargs.pop("use_auth_token", None)
revision = kwargs.pop("revision", None)
subfolder = kwargs.pop("subfolder", None)
commit_hash = kwargs.pop("_commit_hash", None)

if os.path.isdir(save_dir): # Local
resolved_config_file = os.path.join(save_dir, cls.config_file_name)
else: # Remote
resolved_config_file = cached_file(
save_dir,
cls.config_file_name,
cache_dir=cache_dir,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
use_auth_token=use_auth_token,
revision=revision,
local_files_only=local_files_only,
subfolder=subfolder,
_raise_exceptions_for_missing_entries=False,
_raise_exceptions_for_connection_errors=False,
_commit_hash=commit_hash,
)

if os.path.exists(resolved_config_file):
with open(resolved_config_file, 'r', encoding="utf-8") as file:
loaded_config = json.loads(file.read())
quant_config = cls(**loaded_config)
else:
quant_config = cls()

return quant_config

def to_dict(self):
return {
"zero_point": self.zero_point,
"q_group_size": self.q_group_size,
"w_bit": self.w_bit,
"version": self.version
}

def to_transformers_dict(self):
return {
"quant_method": self.quant_method,
"zero_point": self.zero_point,
"group_size": self.q_group_size,
"bits": self.w_bit,
"version": self.version.lower(),
}
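The new `AwqConfig` dataclass centralizes the quantization settings (`zero_point`, `q_group_size`, `w_bit`, `version`) and serializes them in two shapes: the legacy `quant_config.json` layout via `to_dict()`, and the transformers-style `quantization_config` layout via `to_transformers_dict()`. Below is a minimal sketch of that round trip, using only the methods defined above; the module path `awq.models._config` is taken from this diff.

```python
# Sketch only: exercises the AwqConfig class added in this PR.
from awq.models._config import AwqConfig

# Build a config from a plain dict, as BaseAWQForCausalLM.quantize() now does internally.
cfg = AwqConfig.from_dict(
    {"zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM"}
)

# Legacy layout, written to quant_config.json by save_pretrained().
print(cfg.to_dict())
# {'zero_point': True, 'q_group_size': 128, 'w_bit': 4, 'version': 'GEMM'}

# Transformers-style layout, embedded as quantization_config in config.json.
print(cfg.to_transformers_dict())
# {'quant_method': 'awq', 'zero_point': True, 'group_size': 128, 'bits': 4, 'version': 'gemm'}
```

An empty or missing dict simply falls back to the defaults declared in the dataclass fields.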
8 changes: 3 additions & 5 deletions awq/models/aquila.py
@@ -1,6 +1,5 @@
## Reference from llama.py
from .base import BaseAWQForCausalLM
from typing import Dict
from transformers.models.llama.modeling_llama import (
LlamaDecoderLayer as AquilaDecoderLayer,
LlamaForCausalLM as AquilaForCausalLM,
@@ -14,8 +13,8 @@ class AquilaAWQForCausalLM(BaseAWQForCausalLM):
max_new_tokens_key = "max_position_embeddings"

@staticmethod
def fuse_layers(model: AquilaForCausalLM, quant_config: Dict):
fuser = AquilaFuser(model, quant_config)
def fuse_layers(model: AquilaForCausalLM):
fuser = AquilaFuser(model)
fuser.fuse_attention()
fuser.fuse_rmsnorm()
fuser.fuse_mlp()
@@ -82,9 +81,8 @@ def get_layers_for_scaling(module: AquilaDecoderLayer, input_feat, module_kwargs
from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV

class AquilaFuser:
def __init__(self, model, quant_config):
def __init__(self, model):
self.model = model
self.quant_config = quant_config

self.attention_modules: List[Tuple[str, AquilaAttention]] = [
(name, module) for name, module in self.model.named_modules()
45 changes: 17 additions & 28 deletions awq/models/base.py
@@ -4,26 +4,27 @@
import torch
import torch.nn as nn
from tqdm import tqdm
from typing import List, Union, Dict
from typing import List, Union
from safetensors.torch import save_file
from awq.models._config import AwqConfig
from awq.modules.act import ScaledActivation
from huggingface_hub import snapshot_download
from awq.quantize.quantizer import AwqQuantizer
from awq.utils.utils import simple_dispatch_model
from transformers.modeling_utils import shard_checkpoint
from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV
from awq.utils.module import get_named_linears, set_op_by_name
from transformers import AutoModelForCausalLM, AutoConfig, PreTrainedModel
from accelerate import init_empty_weights, load_checkpoint_in_model, infer_auto_device_map


class BaseAWQForCausalLM(nn.Module):
def __init__(self, model, model_type, is_quantized, quant_config):
super().__init__()
self.model:PreTrainedModel = model
self.model_type:str = model_type
self.is_quantized:bool = is_quantized
self.search_result = None
self.quant_config: Dict = quant_config
self.quant_config: AwqConfig = quant_config

def to(self, device: str):
return self.model.to(device)
@@ -39,18 +40,17 @@ def generate(self, *args, **kwargs):
def quantize(self, tokenizer=None, quant_config={},
calib_data: Union[str, List[str]]="pileval",
split="train", text_column="text"):
self.quant_config = quant_config
quant_config["version"] = "GEMM" if 'version' not in quant_config.keys() else quant_config["version"]
self.quant_config: AwqConfig = AwqConfig.from_dict(quant_config)

quantizer = AwqQuantizer(
self, self.model, tokenizer, quant_config["w_bit"], quant_config["q_group_size"],
quant_config["version"], calib_data, split, text_column
self, self.model, tokenizer, self.quant_config.w_bit, self.quant_config.q_group_size,
self.quant_config.version, calib_data, split, text_column
)
quantizer.quantize()
self.is_quantized = True

@staticmethod
def fuse_layers(model, quant_config):
def fuse_layers(model):
pass

def save_quantized(self, save_dir, safetensors=False, shard_size="10GB"):
@@ -61,8 +61,10 @@ class EmptyModule(nn.Module):
def __init__(self): super(EmptyModule, self).__init__()
def forward(self, x): return x

# Save model files with empty state dict
# Save model and config files with empty state dict
self.model.config.quantization_config = self.quant_config.to_transformers_dict()
self.model.save_pretrained(save_dir, state_dict=EmptyModule().state_dict())
self.quant_config.save_pretrained(save_dir)

# Remove empty state dict
os.remove(f'{save_dir}/pytorch_model.bin')
@@ -89,10 +91,6 @@ def forward(self, x): return x
if index is not None:
with open(f'{save_dir}/{model_name}.index.json', 'w+') as file:
file.write(json.dumps(index, indent=4))

# Save config
with open(f'{save_dir}/quant_config.json', 'w+') as file:
file.write(json.dumps(self.quant_config, indent=4))


@classmethod
@@ -146,7 +144,7 @@ def from_quantized(self, model_path, model_type, model_filename='',
model = AutoModelForCausalLM.from_config(config=config, torch_dtype=torch_dtype, trust_remote_code=trust_remote_code)

# Prepare WQLinear layers, replace nn.Linear
self._load_quantized_modules(self, model, quant_config, quant_config["version"])
self._load_quantized_modules(self, model, quant_config, quant_config.version)

model.tie_weights()

@@ -169,7 +167,7 @@ def from_quantized(self, model_path, model_type, model_filename='',

# Dispatch to devices
if fuse_layers:
self.fuse_layers(model, quant_config)
self.fuse_layers(model)

# Offloading dispatch
from accelerate import dispatch_model
@@ -201,16 +199,7 @@ def _load_config(self, model_path, model_filename, safetensors=False,

# [STEP 2] Load config and set sequence length
# TODO: Create BaseAWQConfig class
quant_config_path = f'{model_path}/quant_config.json'
if os.path.exists(quant_config_path):
with open(quant_config_path, 'r') as file:
quant_config = json.loads(file.read())

if "version" not in quant_config.keys():
quant_config["version"] = version
else:
# Default config that works for most models
quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4, "version": version}
quant_config = AwqConfig.from_pretrained(model_path)

# Load model config and set max generation length
if max_new_tokens is None and hasattr(self, 'max_new_tokens_key'):
@@ -225,7 +214,7 @@ def _load_quantized_modules(self, model, quant_config, version):

def _load_quantized_modules(self, model, quant_config, version):
# Real quantization of weights
assert quant_config["zero_point"], "We only support zero_point quantization now."
assert quant_config.zero_point, "We only support zero_point quantization now."

# Get blocks of model
layers = self.get_model_layers(model)
@@ -248,8 +237,8 @@ def _load_quantized_modules(self, model, quant_config, version):

q_linear = q_linear_module.from_linear(
module,
quant_config['w_bit'],
quant_config['q_group_size'],
quant_config.w_bit,
quant_config.q_group_size,
True
)
q_linear.to(next(layer.parameters()).device)
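Taken together, the `base.py` changes replace the hand-rolled `quant_config` dict with an `AwqConfig` instance: `quantize()` normalizes the user-supplied dict through `AwqConfig.from_dict()`, `save_quantized()` persists it both as `quant_config.json` and as `quantization_config` inside `config.json`, and `_load_config()` resolves it again through `AwqConfig.from_pretrained()`. A minimal end-to-end sketch under those assumptions; the model and output paths are placeholders mirroring the README example.

```python
# Sketch only: quantize-and-reload flow after this PR; paths are placeholders.
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer

model_path = "lmsys/vicuna-7b-v1.5"   # placeholder base model (see README example)
quant_path = "vicuna-7b-v1.5-awq"     # placeholder output directory

model = AutoAWQForCausalLM.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

# quantize() converts this plain dict into an AwqConfig via AwqConfig.from_dict().
model.quantize(
    tokenizer,
    quant_config={"zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM"},
)

# save_quantized() writes quant_config.json and also embeds
# quant_config.to_transformers_dict() as quantization_config in config.json.
model.save_quantized(quant_path)

# As in the updated README example, no weights filename is passed here; the
# quantization settings are resolved through AwqConfig.from_pretrained(),
# falling back to the dataclass defaults if no config file is found.
model = AutoAWQForCausalLM.from_quantized(quant_path, fuse_layers=True)
```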
3 changes: 1 addition & 2 deletions awq/models/falcon.py
@@ -1,12 +1,11 @@
from .base import BaseAWQForCausalLM
from typing import Dict
from transformers.models.falcon.modeling_falcon import FalconDecoderLayer as OldFalconDecoderLayer, FalconForCausalLM, FalconAttention

class FalconAWQForCausalLM(BaseAWQForCausalLM):
layer_type = "FalconDecoderLayer"

@staticmethod
def fuse_layers(model: FalconForCausalLM, quant_config: Dict):
def fuse_layers(model: FalconForCausalLM):
fuser = FalconFuser(model)

# TODO: Implement correctly fused modules for Falcon 40B and Falcon 180B
8 changes: 3 additions & 5 deletions awq/models/llama.py
@@ -1,14 +1,13 @@
from .base import BaseAWQForCausalLM
from typing import Dict
from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaForCausalLM

class LlamaAWQForCausalLM(BaseAWQForCausalLM):
layer_type = "LlamaDecoderLayer"
max_new_tokens_key = "max_position_embeddings"

@staticmethod
def fuse_layers(model: LlamaForCausalLM, quant_config: Dict):
fuser = LlamaFuser(model, quant_config)
def fuse_layers(model: LlamaForCausalLM):
fuser = LlamaFuser(model)
fuser.fuse_attention()
fuser.fuse_rmsnorm()
fuser.fuse_mlp()
@@ -76,9 +75,8 @@ def get_layers_for_scaling(module: LlamaDecoderLayer, input_feat, module_kwargs)
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaRMSNorm, LlamaMLP

class LlamaFuser:
def __init__(self, model, quant_config):
def __init__(self, model):
self.model = model
self.quant_config = quant_config

self.attention_modules: List[Tuple[str, LlamaAttention]] = [
(name, module) for name, module in self.model.named_modules()
8 changes: 3 additions & 5 deletions awq/models/mistral.py
@@ -1,4 +1,3 @@
from typing import Dict
from .base import BaseAWQForCausalLM
from transformers.models.mistral.modeling_mistral import MistralDecoderLayer, MistralForCausalLM

@@ -7,8 +6,8 @@ class MistralAWQForCausalLM(BaseAWQForCausalLM):
max_new_tokens_key = "max_position_embeddings"

@staticmethod
def fuse_layers(model: MistralForCausalLM, quant_config: Dict):
fuser = MistralFuser(model, quant_config)
def fuse_layers(model: MistralForCausalLM):
fuser = MistralFuser(model)
fuser.fuse_attention()
fuser.fuse_rmsnorm()
fuser.fuse_mlp()
@@ -76,9 +75,8 @@ def get_layers_for_scaling(module: MistralDecoderLayer, input_feat, module_kwarg
from transformers.models.mistral.modeling_mistral import MistralAttention, MistralRMSNorm, MistralMLP

class MistralFuser:
def __init__(self, model, quant_config):
def __init__(self, model):
self.model = model
self.quant_config = quant_config

self.attention_modules: List[Tuple[str, MistralAttention]] = [
(name, module) for name, module in self.model.named_modules()
3 changes: 1 addition & 2 deletions awq/models/mpt.py
@@ -1,13 +1,12 @@
from .base import BaseAWQForCausalLM
from typing import Dict
from transformers.models.mpt.modeling_mpt import MptBlock as OldMptBlock, MptForCausalLM

class MptAWQForCausalLM(BaseAWQForCausalLM):
layer_type = "MPTBlock"
max_new_tokens_key = "max_seq_len"

@staticmethod
def fuse_layers(model: MptForCausalLM, quant_config: Dict):
def fuse_layers(model: MptForCausalLM):
fuser = MptFuser(model)
fuser.fuse_transformer()
