From ff9b517aefd052c04c91790ce1f2cef27de68548 Mon Sep 17 00:00:00 2001
From: Alex621Lin
Date: Thu, 20 Jul 2023 15:02:54 +0200
Subject: [PATCH] fix(server): llama v2 GPTQ (#648)

As per title & reported in
https://github.com/huggingface/text-generation-inference/issues/601#issuecomment-1641435956
https://huggingface.co/TheBloke/Llama-2-70B-chat-GPTQ/discussions/5

Test it:

```
GPTQ_BITS=4 GPTQ_GROUPSIZE=1 text-generation-launcher --model-id TheBloke/Llama-2-70B-chat-GPTQ --port 8080 --num-shard 4 --quantize gptq
```
&
```
curl 127.0.0.1:8080/generate \
    -X POST \
    -d '{"inputs":"hey llama","parameters":{"max_new_tokens":256}}' \
    -H 'Content-Type: application/json'
```
---
 .../custom_modeling/flash_llama_modeling.py   | 35 ++++++++++---------
 1 file changed, 19 insertions(+), 16 deletions(-)

diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
index b2bde28..84e22f7 100644
--- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -148,24 +148,27 @@ def forward(self, hidden_states, residual=None):
 
 
 def _load_gqa(config, prefix: str, weights):
-    w = [
-        weights.get_sharded(f"{prefix}.q_proj.weight", dim=0),
-        weights.get_sharded(f"{prefix}.k_proj.weight", dim=0),
-        weights.get_sharded(f"{prefix}.v_proj.weight", dim=0),
-    ]
-    weight = torch.cat(w, dim=0)
-    weight = weight.to(dtype=weights.dtype).to(device=weights.device)
-    bias = None
     assert config.hidden_size % config.num_attention_heads == 0
-    head_size = config.hidden_size // config.num_attention_heads
     assert config.num_attention_heads % weights.process_group.size() == 0
-    num_heads = config.num_attention_heads // weights.process_group.size()
-    num_key_value_heads = config.num_key_value_heads // weights.process_group.size()
-    assert list(weight.shape) == [
-        (num_heads + 2 * num_key_value_heads) * head_size,
-        config.hidden_size,
-    ], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"
-    return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize))
+
+    weight = weights.get_multi_weights_col(
+        prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
+        quantize=config.quantize,
+        dim=0
+    )
+
+    if config.quantize != "gptq":
+        weight = weight.to(dtype=weights.dtype).to(device=weights.device)
+
+        head_size = config.hidden_size // config.num_attention_heads
+        num_heads = config.num_attention_heads // weights.process_group.size()
+        num_key_value_heads = config.num_key_value_heads // weights.process_group.size()
+        assert list(weight.shape) == [
+            (num_heads + 2 * num_key_value_heads) * head_size,
+            config.hidden_size,
+        ], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"
+
+    return TensorParallelColumnLinear(get_linear(weight, bias=None, quantize=config.quantize))
 
 
 class FlashLlamaAttention(torch.nn.Module):
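
For context: with `get_multi_weights_col`, the GPTQ path no longer yields a single dense tensor, so the dtype/device cast and the fused q/k/v shape check are skipped for GPTQ. Below is a minimal sketch of that control flow, not part of the patch; `load_fused_qkv` and its parameters are illustrative names rather than TGI's API, and the packed description of GPTQ weights is an assumption.

```
# Minimal sketch mirroring the control flow of the patched _load_gqa.
# `load_fused_qkv` and its arguments are hypothetical names, not TGI's API.
import torch


def load_fused_qkv(weight, quantize, num_heads, num_kv_heads, head_size, hidden_size,
                   dtype=torch.float16, device="cpu"):
    if quantize != "gptq":
        # Unquantized path: `weight` is a plain torch.Tensor, so it can be cast
        # and sanity-checked against the expected fused q/k/v output dimension.
        weight = weight.to(dtype=dtype).to(device=device)
        assert list(weight.shape) == [
            (num_heads + 2 * num_kv_heads) * head_size,
            hidden_size,
        ], f"unexpected fused QKV shape: {list(weight.shape)}"
    # GPTQ path: `weight` holds packed quantized data (assumed: qweight, scales,
    # zeros and metadata), so it is returned untouched for the quantized linear
    # layer builder to unpack.
    return weight


# Example: 2 query heads, 1 KV head, head_size 4, hidden_size 16 (unquantized).
w = torch.randn((2 + 2 * 1) * 4, 16)
load_fused_qkv(w, quantize=None, num_heads=2, num_kv_heads=1, head_size=4, hidden_size=16)
```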