Mixtral: Mixture of Experts quantization (#251)
casper-hansen authored Dec 22, 2023
1 parent 2350a4d commit 5b9f3c4
Showing 12 changed files with 323 additions and 18 deletions.
1 change: 1 addition & 0 deletions awq/models/__init__.py
@@ -10,3 +10,4 @@
from .aquila import AquilaAWQForCausalLM
from .yi import YiAWQForCausalLM
from .qwen import QwenAWQForCausalLM
from .mixtral import MixtralAWQForCausalLM
1 change: 1 addition & 0 deletions awq/models/auto.py
@@ -14,6 +14,7 @@
"gptj": GPTJAWQForCausalLM,
"gpt_bigcode": GptBigCodeAWQForCausalLM,
"mistral": MistralAWQForCausalLM,
"mixtral": MixtralAWQForCausalLM,
"gpt_neox": GPTNeoXAWQForCausalLM,
"aquila": AquilaAWQForCausalLM,
"Yi": YiAWQForCausalLM,
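With "mixtral" added to the model map above, a Mixtral checkpoint goes through the standard AutoAWQ quantization flow. A minimal sketch of that README-style workflow follows; the output directory and the `modules_to_not_convert=["gate"]` entry (keeping the MoE router in full precision via the new exclusion path) are illustrative assumptions, not part of this diff.

```python
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer

model_path = "mistralai/Mixtral-8x7B-Instruct-v0.1"
quant_path = "mixtral-8x7b-awq"  # hypothetical output directory

# 4-bit GEMM quantization config; "modules_to_not_convert" is honored by the
# new exclude_layers_to_not_quantize path so the router ("gate") stays unquantized.
quant_config = {
    "zero_point": True,
    "q_group_size": 128,
    "w_bit": 4,
    "version": "GEMM",
    "modules_to_not_convert": ["gate"],
}

model = AutoAWQForCausalLM.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

model.quantize(tokenizer, quant_config=quant_config)
model.save_quantized(quant_path)
tokenizer.save_pretrained(quant_path)
```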
12 changes: 9 additions & 3 deletions awq/models/base.py
@@ -12,7 +12,11 @@
from awq.quantize.quantizer import AwqQuantizer
from transformers.modeling_utils import shard_checkpoint
from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV
from awq.utils.module import get_named_linears, set_op_by_name
from awq.utils.module import (
get_named_linears,
set_op_by_name,
exclude_layers_to_not_quantize,
)
from transformers import (
AutoModelForCausalLM,
AutoConfig,
@@ -24,7 +28,6 @@
infer_auto_device_map,
load_checkpoint_and_dispatch,
)
from accelerate.utils import get_balanced_memory

class BaseAWQForCausalLM(nn.Module):
def __init__(self, model, model_type, is_quantized, config, quant_config):
@@ -176,7 +179,7 @@ def _load_config(self, model_path, model_filename, safetensors=True,
if not os.path.isdir(model_path):
ignore_patterns = ["*msgpack*", "*h5*", "optimizer.pt"]
if safetensors:
ignore_patterns.extend(["*.pt*", "*.bin*"])
ignore_patterns.extend(["*.pt*", "*.bin*", "consolidated*"])
else:
ignore_patterns.append("*.safetensors*")

@@ -215,6 +218,9 @@ def _load_quantized_modules(self, model, quant_config, version):
# Get every linear layer in a block
named_linears = get_named_linears(layer)

# Filter out the linear layers we don't want to quantize
named_linears = exclude_layers_to_not_quantize(named_linears, quant_config.modules_to_not_convert)

# Replace activation functions
self._scale_activations(self, layer)

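`exclude_layers_to_not_quantize` is now imported from `awq.utils.module`, but that file's hunk is not among those shown. Judging from the per-instance `_exclude_layers_to_not_quantize` method it replaces in `awq/quantize/quantizer.py` further down, the shared helper is presumably close to this sketch (the `None` guard is an assumption):

```python
def exclude_layers_to_not_quantize(linear_layers, modules_to_not_convert):
    # Assumed guard: with no exclusion list, quantize everything.
    if modules_to_not_convert is None:
        return linear_layers

    filtered_layers = {}
    for name, linear_layer in linear_layers.items():
        # Keep a layer only if its name matches none of the excluded substrings.
        if not any(key in name for key in modules_to_not_convert):
            filtered_layers[name] = linear_layer
    return filtered_layers
```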
137 changes: 137 additions & 0 deletions awq/models/mixtral.py
@@ -0,0 +1,137 @@
import tqdm
from typing import List, Tuple
from .base import BaseAWQForCausalLM
from awq.utils.fused_utils import fuse_qkv
from awq.modules.fused.block import MixtralBlock
from awq.modules.fused.model import MixtralModel
from transformers.models.mixtral.modeling_mixtral import (
MixtralDecoderLayer as OldMixtralDecoderLayer,
MixtralForCausalLM as OldMixtralForCausalLM
)
from awq.modules.fused.mlp import QuantFusedMLP
from awq.modules.fused.norm import FasterTransformerRMSNorm

class MixtralAWQForCausalLM(BaseAWQForCausalLM):
layer_type = "MixtralDecoderLayer"
max_new_tokens_key = "max_position_embeddings"

@staticmethod
def fuse_layers(model: OldMixtralForCausalLM):
fuser = MixtralFuser(model)
# TODO: Fix perplexity on fusing Mixtral
#fuser.fuse_transformer()

@staticmethod
def get_model_layers(model: OldMixtralForCausalLM):
return model.model.layers

@staticmethod
def get_act_for_scaling(module):
return dict(
is_scalable=False
)

@staticmethod
def move_embed(model: OldMixtralForCausalLM, device: str):
model.model.embed_tokens = model.model.embed_tokens.to(device)

@staticmethod
def get_layers_for_scaling(module: OldMixtralDecoderLayer, input_feat, module_kwargs):
layers = []

# attention input
layers.append(dict(
prev_op=module.input_layernorm,
layers=[module.self_attn.q_proj,
module.self_attn.k_proj, module.self_attn.v_proj],
inp=input_feat['self_attn.q_proj'],
module2inspect=module.self_attn, kwargs=module_kwargs,
))

# attention out
if module.self_attn.v_proj.weight.shape == module.self_attn.o_proj.weight.shape:
layers.append(dict(
prev_op=module.self_attn.v_proj,
layers=[module.self_attn.o_proj],
inp=input_feat['self_attn.o_proj'],
))

# linear in
layers.append(dict(
prev_op=module.post_attention_layernorm,
layers=[
w for expert in module.block_sparse_moe.experts
for w in [expert.w1, expert.w3]
],
inp=input_feat['block_sparse_moe'],
module2inspect=module.block_sparse_moe,
))

# linear out
for i, expert in enumerate(module.block_sparse_moe.experts):
layers.append(dict(
prev_op=expert.w3,
layers=[expert.w2],
inp=input_feat[f'block_sparse_moe.experts.{i}.w2'],
))

return layers


class MixtralFuser:
def __init__(self, model: OldMixtralForCausalLM):
self.model = model

self.mixtral_blocks: List[Tuple[str, OldMixtralDecoderLayer]] = [
(name, module) for name, module in self.model.named_modules()
if 'MixtralDecoderLayer'.lower() in module.__class__.__name__.lower()
]

def fuse_transformer(self):
blocks = []

module: OldMixtralDecoderLayer
for module in tqdm.tqdm(self.model.model.layers, desc="Fusing layers..."):
device = next(iter(module.state_dict().values())).device
qkv = fuse_qkv(
module,
module.self_attn.q_proj,
module.self_attn.k_proj,
module.self_attn.v_proj
)
# Adapt to mixture of experts
for i in range(len(module.block_sparse_moe.experts)):
mlp = QuantFusedMLP(
gate_proj=module.block_sparse_moe.experts[i].w1,
down_proj=module.block_sparse_moe.experts[i].w2,
up_proj=module.block_sparse_moe.experts[i].w3
)
module.block_sparse_moe.experts[i] = mlp
norm_1 = FasterTransformerRMSNorm(
module.input_layernorm.weight,
module.input_layernorm.variance_epsilon
)
norm_2 = FasterTransformerRMSNorm(
module.post_attention_layernorm.weight,
module.post_attention_layernorm.variance_epsilon
)
blocks.append(MixtralBlock(
hidden_size=self.model.config.hidden_size,
n_heads=self.model.config.num_attention_heads,
n_kv_heads=self.model.config.num_key_value_heads,
qkv_layer=qkv,
o_proj=module.self_attn.o_proj,
moe=module.block_sparse_moe,
norm_1=norm_1,
norm_2=norm_2,
dev=device,
max_seq_len=self.model.config.max_new_tokens
))

self.model.model = MixtralModel(
self.model.config.vocab_size,
blocks,
self.model.model.embed_tokens,
self.model.model.norm,
)
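For context on the scaling groups in `get_layers_for_scaling` above: in the Hugging Face Mixtral implementation each expert is a SwiGLU-style MLP, so `w1` (gate) and `w3` (up) both read the post-attention-layernorm output, while `w2` (down) reads the elementwise product of the two; that is why `w1`/`w3` are scaled against the layernorm and each `w2` against its own expert's `w3`. A minimal reference sketch of one expert (reference math only, not the quantized module):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SwiGLUExpert(nn.Module):
    """Reference-only stand-in for one Mixtral expert (w1=gate, w3=up, w2=down)."""

    def __init__(self, hidden_size: int, ffn_dim: int):
        super().__init__()
        self.w1 = nn.Linear(hidden_size, ffn_dim, bias=False)
        self.w3 = nn.Linear(hidden_size, ffn_dim, bias=False)
        self.w2 = nn.Linear(ffn_dim, hidden_size, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # w2's input is silu(w1(x)) * w3(x), hence w2 is scaled against w3 above.
        return self.w2(F.silu(self.w1(x)) * self.w3(x))


expert = SwiGLUExpert(hidden_size=16, ffn_dim=32)
print(expert(torch.randn(2, 16)).shape)  # torch.Size([2, 16])
```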

34 changes: 34 additions & 0 deletions awq/modules/fused/block.py
@@ -2,6 +2,40 @@
import torch.nn as nn
from awq.modules.fused.attn import QuantAttentionFused

class MixtralBlock(nn.Module):
def __init__(
self, hidden_size, n_heads, n_kv_heads, qkv_layer, o_proj,
moe, norm_1, norm_2, dev, max_seq_len
):
super().__init__()
self.n_heads = n_heads
self.n_kv_heads = n_kv_heads
self.hidden_size = hidden_size
self.norm_1 = norm_1.to(dev)
self.attn = QuantAttentionFused(
self.hidden_size, self.n_heads, self.n_kv_heads, qkv_layer, o_proj,
dev=dev, max_seq_len=max_seq_len, use_alibi=False
).to(dev)
self.norm_2 = norm_2.to(dev)
self.moe = moe
self.device = dev

def forward(
self, hidden_states, past_key_value, attn_bias=None, attention_mask=None, is_causal=None
):
norm_out = self.norm_1(hidden_states)
attn_output, _, past_key_value = self.attn.forward(
hidden_states=norm_out,
past_key_value=past_key_value,
attention_mask=attention_mask
)

h = hidden_states.to(attn_output.device) + attn_output
out, _ = self.moe.forward(self.norm_2(h))
out = h + out

return out, None, past_key_value

class LlamaLikeBlock(nn.Module):
"""
LlamaLikeBlock is intended to be reused across blocks that have
5 changes: 4 additions & 1 deletion awq/modules/fused/mlp.py
@@ -36,7 +36,7 @@ def __init__(

self.activation = activation

def forward(self, x):
def forward(self, x, routing_weights=None):
out_shape = x.shape[:-1] + (self.intermediate_size,)
x = x.reshape(-1, x.shape[-1])
gate_output = self.linear(
@@ -57,6 +57,9 @@ def forward(self, x):
x = x.reshape(out_shape)
x = self.down_proj(x)

if routing_weights is not None:
x = routing_weights * x

return x


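The new `routing_weights` argument lets `QuantFusedMLP` stand in for a Mixtral expert, whose caller (the sparse-MoE block) scales each expert's output by the per-token router weight before accumulating. A minimal sketch of the weighting the added branch performs, with shapes chosen purely for illustration:

```python
import torch

n_tokens, hidden_size = 4, 8
expert_out = torch.randn(n_tokens, hidden_size)  # stand-in for the fused MLP output
routing_weights = torch.rand(n_tokens, 1)        # top-k gate weight per routed token

weighted = routing_weights * expert_out          # what `x = routing_weights * x` computes
assert weighted.shape == (n_tokens, hidden_size)
```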
59 changes: 57 additions & 2 deletions awq/modules/fused/model.py
@@ -2,8 +2,63 @@
import torch.nn as nn
from typing import List
from awq.utils import fused_utils
from transformers.modeling_outputs import BaseModelOutputWithPast
from awq.modules.fused.block import MPTBlock, FalconDecoderLayer, LlamaLikeBlock
from transformers.modeling_outputs import BaseModelOutputWithPast, MoeModelOutputWithPast
from awq.modules.fused.block import MPTBlock, FalconDecoderLayer, LlamaLikeBlock, MixtralBlock


class MixtralModel(nn.Module):
def __init__(self, vocab_size, blocks, embedding, norm):
super().__init__()
self.vocab_size = vocab_size
self.embedding = embedding
self.blocks: List[MixtralBlock] = nn.ModuleList(blocks)
self.norm = norm
self.last_forward_num_tokens = 0

@torch.inference_mode()
def forward(
self,
input_ids: torch.Tensor,
attn_bias=None,
attention_mask=None,
is_causal=None,
*args,
**kwargs,
):
input_ids, self.last_forward_num_tokens = fused_utils.prepare_input_ids(
input_ids, self.last_forward_num_tokens
)
_bsz, seqlen = input_ids.shape

fused_utils.prepare_cache(self.blocks, seqlen)

h = self.embedding(input_ids)

mask = fused_utils.prepare_attention_mask(
seqlen=seqlen,
start_pos=self.blocks[0].attn.start_pos,
device=input_ids.device,
type_as=h,
)

for layer in self.blocks:
h, mask = fused_utils.prepare_correct_devices(
layer,
h,
mask,
)
h, _, past_key_value = layer(h, None, attention_mask=mask, is_causal=is_causal)

h = self.norm(h)

return MoeModelOutputWithPast(
last_hidden_state=h,
past_key_values=past_key_value,
hidden_states=(),
attentions=(),
router_logits=(),
)


class LlamaLikeModel(nn.Module):
"""
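The fused `MixtralModel` above is what `MixtralFuser.fuse_transformer` would build (currently disabled pending the perplexity fix noted in `awq/models/mixtral.py`). Loading and running a quantized Mixtral checkpoint goes through the usual `from_quantized` path either way; a minimal sketch, assuming a hypothetical quantized directory and a CUDA device:

```python
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer

quant_path = "mixtral-8x7b-awq"  # hypothetical directory produced by save_quantized

# fuse_layers=True is effectively a no-op for Mixtral until the fuser is re-enabled.
model = AutoAWQForCausalLM.from_quantized(quant_path, fuse_layers=True)
tokenizer = AutoTokenizer.from_pretrained(quant_path, trust_remote_code=True)

tokens = tokenizer("Mixture of experts models are", return_tensors="pt").input_ids.cuda()
output = model.generate(tokens, max_new_tokens=64)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```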
22 changes: 13 additions & 9 deletions awq/quantize/quantizer.py
@@ -10,7 +10,13 @@
from awq.utils.calib_data import get_calib_dataset
from awq.quantize.scale import apply_scale, apply_clip
from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV
from awq.utils.module import append_str_prefix, get_op_name, get_named_linears, set_op_by_name
from awq.utils.module import (
append_str_prefix,
get_op_name,
get_named_linears,
set_op_by_name,
exclude_layers_to_not_quantize
)


class AwqQuantizer:
@@ -70,13 +76,6 @@ def pseudo_dequantize_tensor(self, w: nn.Linear, scales: torch.Tensor, zeros: to

return w

def _exclude_layers_to_not_quantize(self, linear_layers):
filtered_layers = {}
for name, linear_layer in linear_layers.items():
if not any(key in name for key in self.modules_to_not_convert):
filtered_layers[name] = linear_layer
return filtered_layers

def quantize(self):
for i in tqdm(range(len(self.modules)), desc="AWQ"):
# Move module and inputs to correct device
@@ -91,7 +90,7 @@ def quantize(self):
named_linears = get_named_linears(self.modules[i])

# Filter out the linear layers we don't want to quantize
named_linears = self._exclude_layers_to_not_quantize(named_linears)
named_linears = exclude_layers_to_not_quantize(named_linears, self.modules_to_not_convert)

input_feat = self._get_input_feat(self.modules[i], named_linears)
clear_memory()
@@ -387,6 +386,11 @@ def cache_input_hook(m, x, y, name, feat_dict):

input_feat = defaultdict(list)
handles = []

# FIXME: Workaround for Mixtral to use block_sparse_moe input features
if self.awq_model.model_type == "mixtral":
named_linears = {**named_linears, "block_sparse_moe": layer.block_sparse_moe}

for name in named_linears:
handles.append(named_linears[name].register_forward_hook(
functools.partial(cache_input_hook, name=name,
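The `block_sparse_moe` workaround exists because the expert `w1`/`w3` scaling group in `awq/models/mixtral.py` consumes `input_feat['block_sparse_moe']`, i.e. the input of the whole MoE module rather than of individual experts, which only ever see their routed subset of tokens. The hook machinery is plain PyTorch; a self-contained sketch of the caching pattern, with the `cache_input_hook` body reconstructed from how it is registered here (an assumption, since its definition sits outside the loaded hunks):

```python
import functools
from collections import defaultdict

import torch
import torch.nn as nn


def cache_input_hook(m, x, y, name, feat_dict):
    # Capture the first positional input of the hooked module under its name.
    feat_dict[name].append(x[0].detach().cpu())


input_feat = defaultdict(list)
module = nn.Linear(16, 16)  # stand-in for a hooked submodule such as block_sparse_moe

handle = module.register_forward_hook(
    functools.partial(cache_input_hook, name="block_sparse_moe", feat_dict=input_feat)
)
module(torch.randn(2, 16))  # a forward pass populates input_feat
handle.remove()

assert input_feat["block_sparse_moe"][0].shape == (2, 16)
```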