Faster generation using AWQ + Fused modules #27411

Merged: 51 commits, Dec 5, 2023

Conversation

@younesbelkada (Contributor) commented Nov 9, 2023:

What does this PR do?

Introduces a new feature: fused module generation using the autoawq library. Users need to specify the modules they want to fuse inside fusing_mapping.

The API is as follows:

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, AwqConfig, TextStreamer

model_id = "TheBloke/Mistral-7B-OpenOrca-AWQ"
torch_device = "cuda" if torch.cuda.is_available() else "cpu"

quantization_config = AwqConfig(
    bits=4,
    do_fuse=True,
    fuse_max_seq_len=512,
)

model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=quantization_config).to(torch_device)
tokenizer = AutoTokenizer.from_pretrained(model_id)

streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

prompt_template = """\
<|im_start|>system
You are MistralOrca, a large language model trained by Alignment Lab AI. Write out your reasoning step-by-step to be sure you get the right answers!<|im_end|>
<|im_start|>user
{prompt}<|im_end|>
<|im_start|>assistant"""

prompt = "You're standing on the surface of the Earth. "\
        "You walk one mile south, one mile west and one mile north. "\
        "You end up exactly where you started. Where are you?"

tokenizer.pad_token = tokenizer.eos_token

inputs = tokenizer([prompt_template.format(prompt=prompt), prompt_template.format(prompt=prompt), prompt_template.format(prompt=prompt)], return_tensors="pt", padding=True).to(torch_device)

outputs = model.generate(**inputs, max_new_tokens=512)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))

Before this PR: (demo GIF: before-fusing)

After this PR: (demo GIF: fast-awq-generate)

TODOs:

cc @amyeroberts @casper-hansen @SunMarc

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

@younesbelkada (Contributor, Author) commented:

Before moving forward with tests and advanced docs, I would love early feedback on the API described in the PR description. cc @amyeroberts, whenever you have time, I would appreciate your feedback on this PR 🙏 Thanks!

@amyeroberts (Collaborator) left a comment:

Really nice piece of work! 🔥

My main comment is about the structure of the input arguments. It might just be my current understanding from the PR, so please do correct me if I'm wrong.

At the moment, it seems there are two types of arguments which configure the behaviour: the fusing_mapping and max_seq_len. Does max_seq_len have its own config value because we expect this to be the value users modify a lot?

For AwqConfig, do we foresee other fuse arguments being added to configure this behaviour? If so, we might want to bundle them all together into a single fuse_config which the AwqConfig owns.

Resolved review threads: src/transformers/modeling_utils.py (outdated), src/transformers/integrations/awq.py.
Comment on lines +148 to +151
current_fused_mapping["hidden_size"] = hidden_size
current_fused_mapping["num_attention_heads"] = num_attention_heads
current_fused_mapping["num_key_value_heads"] = num_key_value_heads
current_fused_mapping["max_seq_len"] = quantization_config.fuse_max_seq_len
Collaborator:

If hidden_size, num_attention_heads and num_key_value_heads are required arguments, it would be good to verify that these keys exist in the mapping from quantization_config.fusing_mapping.

younesbelkada (Contributor, Author):

I propose to do it in 0a08551, let me know what you think!

Collaborator:

Thanks!

# Handle hidden_size, num_attention_heads, num_key_value_heads on our own.
hidden_size = model.config.hidden_size
num_attention_heads = model.config.num_attention_heads
num_key_value_heads = getattr(model.config, "num_key_value_heads", num_attention_heads)
Collaborator:

This line indicates a possible area of brittleness in the current design: there are going to be many different variations of model attribute names which, if this becomes popular, we'll have to account for. Passing in quantization_config.fusing_mapping is great to provide flexibility for the mapping.

My understanding is that quantization_config.fuse_max_seq_len is passed in separately because the value will be specific to each model's configuration rather than e.g. architecture. The question is: why are params like "hidden_size" and "num_key_value_heads" not passed in with quantization_config.fusing_mapping? I think it would make more sense for them all to be passed in together, or to have two separate dictionaries: one to map layer names and the other to map model-specific configs.

younesbelkada (Contributor, Author):

I proposed to do that to handle different architecture variants (7B, 13B, etc.). I did not find a way to properly map that in AWQ_FUSED_MAPPINGS, so I decided to make it as architecture-agnostic as possible.
For now we only support Mistral and Llama if one passes fuse_modules; for other architectures, users need to manually create a mapping and pass it through AwqConfig. If we want to support other architectures in the future and face attribute errors, I propose to fix them directly in the corresponding config object by adding an attribute_map (https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/configuration_t5.py#L83) so that num_attention_heads and hidden_size are guaranteed to exist. What do you think?
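
A minimal sketch of the attribute_map idea, assuming a hypothetical MyArchConfig whose native attribute names differ from the canonical ones (the T5 config linked above uses the same mechanism to expose hidden_size as d_model):

from transformers import PretrainedConfig

class MyArchConfig(PretrainedConfig):
    # Hypothetical config: map the canonical names the fusing code asks for
    # to this architecture's own attribute names.
    attribute_map = {
        "hidden_size": "d_model",
        "num_attention_heads": "n_heads",
    }

    def __init__(self, d_model=4096, n_heads=32, **kwargs):
        self.d_model = d_model
        self.n_heads = n_heads
        super().__init__(**kwargs)

config = MyArchConfig()
print(config.hidden_size, config.num_attention_heads)  # 4096 32, resolved via attribute_map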

Collaborator:

Sounds good to me!

It handles almost all cases, and if there's anything more complex in the future we can handle it then, when we know more about the problem.

Resolved review thread: src/transformers/integrations/awq.py.
return current_fused_mapping


def fuse_awq_modules(model, quantization_config):
Collaborator:

This is a pretty big function. I'd split it up so that there are private functions for each layer replacement, which are then called within the big for-loop.
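
A minimal structural sketch of that split, with hypothetical helper names and stubbed bodies (the actual refactor is in cde53ef and may differ):

def _fuse_layernorm_layers(module, layernorm_names):
    # Replace each norm listed under the "layernorm" key with its fused counterpart.
    ...

def _fuse_mlp_layers(model, module, mlp_names):
    # Gather the projections listed under the "mlp" key and swap in a fused MLP.
    ...

def _fuse_attention_layers(model, module, attention_names, fused_mapping):
    # Concatenate q/k/v projections and swap in a fused attention module.
    ...

def fuse_awq_modules(model, quantization_config):
    fused_mapping = ...  # built from quantization_config, see current_fused_mapping above
    for name, module in model.named_modules():
        _fuse_layernorm_layers(module, fused_mapping["layernorm"])
        _fuse_mlp_layers(model, module, fused_mapping["mlp"])
        _fuse_attention_layers(model, module, fused_mapping["attention"], fused_mapping)
    return model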

younesbelkada (Contributor, Author):

Done in cde53ef

Collaborator:

Beautiful 🤩

).to(old_module.weight.device)
del old_module
# Replace MLP layers
if hasattr(module, fusing_mapping["mlp"][0]):
Collaborator:

Can we guarantee that fusing_mapping has an "mlp" key, and that it has at least one value?

younesbelkada (Contributor, Author):

I propose to do a more extensive check in b187c07, and I changed the logic there a bit so that these methods do nothing if an empty array is passed in these fields.

Comment on lines 193 to 195
gate_proj = getattr(module, fusing_mapping["mlp"][0])
up_proj = getattr(module, fusing_mapping["mlp"][1])
down_proj = getattr(module, fusing_mapping["mlp"][2])
Collaborator:

gate_, up_, down_ is highly specific to just a few models. Instead, we can generalise this to take a list of linear layers so that any MLP can be passed.
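
A minimal sketch of that generalisation, assuming the "mlp" entry becomes an arbitrary-length list of linear-layer names; the helper name is hypothetical:

def _collect_mlp_linears(module, mlp_layer_names):
    # Resolve every linear layer listed under the "mlp" key, in order,
    # instead of hard-coding gate_proj / up_proj / down_proj.
    layers = []
    for layer_name in mlp_layer_names:
        if not hasattr(module, layer_name):
            return None  # this module does not carry the MLP we are looking for
        layers.append(getattr(module, layer_name))
    return layers

# e.g. linears = _collect_mlp_linears(module, fusing_mapping["mlp"])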

down_proj = getattr(module, fusing_mapping["mlp"][2])

previous_device = gate_proj.qweight.device
activation_fn = ACT2FN[model.config.hidden_act]
Collaborator:

Same comment here regarding config param names: hidden_act is common, but it's not always the name used. Could we pass this in via the fuse config too?

@younesbelkada (Contributor, Author) commented Nov 22, 2023:

Hmm, do you think my suggestion here: #27411 (comment) could be applied in this case?
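
A minimal sketch of the two options under discussion, keeping the config lookup as a fallback and assuming a hypothetical optional "act" key in the fusing mapping; neither line is the merged implementation:

from transformers.activations import ACT2FN

def _resolve_activation(model_config, fusing_mapping):
    # Prefer an explicit (hypothetical) "act" entry in the mapping, otherwise fall
    # back to the common config attribute name; "silu" is the Llama/Mistral default.
    act_name = fusing_mapping.get("act", getattr(model_config, "hidden_act", "silu"))
    return ACT2FN[act_name]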

Resolved review thread (outdated): src/transformers/integrations/awq.py.
@amyeroberts (Collaborator) left a comment:

Thanks for iterating!

Great code, great tests, great docs 💯

Resolved review thread (outdated): src/transformers/utils/quantization_config.py.
Comment on lines +626 to +635
if self.do_fuse and self.modules_to_fuse is not None:
required_keys = [
"hidden_size",
"num_attention_heads",
"num_key_value_heads",
"mlp",
"attention",
"layernorm",
"use_alibi",
]
Collaborator:

Based on these keys, it would be cool to have a tool which automatically generates a config for a model, assuming you wanted to fuse all modules.
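
A minimal hand-written example of such a mapping, built from the required keys above; the layer names follow the Llama/Mistral convention and the sizes are illustrative for a Mistral-7B-style model:

from transformers import AwqConfig

# Custom fusing mapping covering every required key from the check above.
modules_to_fuse = {
    "attention": ["q_proj", "k_proj", "v_proj", "o_proj"],
    "mlp": ["gate_proj", "up_proj", "down_proj"],
    "layernorm": ["input_layernorm", "post_attention_layernorm", "norm"],
    "use_alibi": False,
    "hidden_size": 4096,
    "num_attention_heads": 32,
    "num_key_value_heads": 8,
}

quantization_config = AwqConfig(
    bits=4,
    do_fuse=True,
    fuse_max_seq_len=512,
    modules_to_fuse=modules_to_fuse,
)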

@SunMarc (Member) left a comment:

Thanks for making AWQ faster through fused modules 🔥. The design looks great and could easily be extended to other quantization schemes in the future. I left a few comments.

Resolved review threads: docs/source/en/main_classes/quantization.md (3, outdated), src/transformers/integrations/awq.py, src/transformers/modeling_utils.py (3, one outdated), tests/quantization/autoawq/test_awq.py (3).
@younesbelkada (Contributor, Author) commented:

Thanks @amyeroberts @SunMarc for your great reviews!
@SunMarc, I just want to get more clarification on this comment; otherwise, good to merge IMO!

@SunMarc (Member) left a comment:

Thanks for iterating on this! I've left a few comments about two points, but they are not blocking for this PR!

Resolved review threads: src/transformers/modeling_utils.py (2, one outdated).
@younesbelkada merged commit fdb85be into huggingface:main on Dec 5, 2023. 22 checks passed.
@younesbelkada deleted the awq-fused-modules branch on December 5, 2023 at 11:14.