Skip to content

Commit

Permalink
feat(bench): add AWQ setup
Browse files Browse the repository at this point in the history
  • Loading branch information
dacorvo committed Mar 21, 2024
1 parent 5b2bb9b commit 96871c1
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 2 deletions.
5 changes: 4 additions & 1 deletion bench/generation/evaluate_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from metrics.perplexity import perplexity
from metrics.prediction import prediction_accuracy

from setup.awq import setup as awq_setup
from setup.bnb import setup as bnb_setup
from setup.hqq import setup as hqq_setup
from setup.quanto import setup as quanto_setup
Expand All @@ -32,6 +33,8 @@ def evaluate(
):
if quantizer == "quanto":
model, tokenizer = quanto_setup(model_id, weights, activations, batch_size, device)
elif quantizer == "awq":
model, tokenizer = awq_setup(model_id, weights, activations)
elif quantizer == "bnb":
model, tokenizer = bnb_setup(model_id, weights, activations, device)
elif quantizer == "hqq":
Expand All @@ -57,7 +60,7 @@ def main():
)
parser.add_argument("--device", type=str, default=None, help="The device to use for generation.")
parser.add_argument("--metric", type=str, default="prediction", choices=["latency", "prediction", "perplexity"])
parser.add_argument("--quantizer", type=str, default="quanto", choices=["quanto", "bnb", "hqq"])
parser.add_argument("--quantizer", type=str, default="quanto", choices=["quanto", "awq", "bnb", "hqq"])
parser.add_argument(
"--weights",
type=str,
Expand Down
3 changes: 2 additions & 1 deletion bench/generation/metrics/latency.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ def elapsed_time(self, other):
do_sample=False,
eos_token_id=None, # This is required for min_new_tokens to actually have an effect.
)
model.generation_config.eos_token_id = None # greedy_search falls back on this eos_token_id that we need to set to None as well for min_new_tokens to have an effect.
if getattr(model, "generation_config", None) is not None:
model.generation_config.eos_token_id = None # greedy_search falls back on this eos_token_id that we need to set to None as well for min_new_tokens to have an effect.

synchronize(device)
if device.type == "cuda":
Expand Down
77 changes: 77 additions & 0 deletions bench/generation/setup/awq.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer


def prepare_inputs_for_generation(input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs):
if past_key_values is not None:
cache_length = past_length = past_key_values[0][0].shape[2]
max_cache_length = None

# Keep only the unprocessed tokens:
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
# some of the inputs are exclusivelly passed as part of the cache (e.g. when passing input_embeds as
# input)
if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
# 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
# input_ids based on the past_length.
elif past_length < input_ids.shape[1]:
input_ids = input_ids[:, past_length:]
# 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.

# If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
if (
max_cache_length is not None
and attention_mask is not None
and cache_length + input_ids.shape[1] > max_cache_length
):
attention_mask = attention_mask[:, -max_cache_length:]

position_ids = kwargs.get("position_ids", None)
if attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
position_ids = position_ids[:, -input_ids.shape[1] :]

# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids}

model_inputs.update(
{
"position_ids": position_ids,
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),
"attention_mask": attention_mask,
}
)
return model_inputs


def setup(model_id: str, weights: str, activations: str, group_size: int = 64, version="GEMV"):
if activations != "none":
raise ValueError("Activation quantization is not supported by HQQ")
if weights != "int4":
raise ValueError("AWQ only supports int4 weights.")
quant_config = {"zero_point": True, "q_group_size": group_size, "w_bit": 4, "version": version}
# Load model
model = AutoAWQForCausalLM.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
tokenizer.pad_token_id = tokenizer.eos_token_id
tokenizer.padding_side = "left"
# Quantize
model.quantize(tokenizer, quant_config=quant_config)
# We need to save otherwise it doesn't work
quant_path = model_id.replace("/", "-") + f"_{group_size}_{version}"
model.save_quantized(quant_path)
# Reload model
model = AutoAWQForCausalLM.from_quantized(quant_path)
# Hack: force transformers 4.36.2 behaviour
model.model.prepare_inputs_for_generation = prepare_inputs_for_generation
# Hack because AWQ models are not transformers models
model.device = next(model.parameters()).device
return model, tokenizer

0 comments on commit 96871c1

Please sign in to comment.