[Awq] Add llava fused modules support #28239
Conversation
Thank you so much, this PR and #28032 got it working well for me now!
Thanks a lot! More discussion on the attended / non-attended tokens is needed IMO! 🤗
# In case a user passes a `AwqConfig` with `do_fuse=True` for models that have
# a `modules_to_not_convert` attribute we need to manually set that attribute into the
# passed `quantization_config`
elif (
    quantization_config.modules_to_not_convert is None
    and "modules_to_not_convert" in config.quantization_config
):
    quantization_config.modules_to_not_convert = config.quantization_config["modules_to_not_convert"]
That seems a bit odd to me. It should either be done in the integration (I know you don't have access to the config there) or when you init the `quantization_config`; you should use `config.quantization_config`, no? (At some point merging kwargs?)
I think we always rely on `quantization_config`; we do merge kwargs, but the other way around (from `quantization_config` to `config.quantization_config`) with `get_loading_attributes()`. The scenario above happens only in the specific case where users pass `do_fuse=True` and a non-`None` value is present in `config.quantization_config["modules_to_not_convert"]`. I think it is a good idea to think of a way to harmonize how kwargs are merged between `config.quantization_config` and `quantization_config`, but that might be slightly out of the scope of this PR since I would need to do it for all the quantization schemes we support. I propose to do that properly in a follow-up PR.
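For concreteness, here is a standalone sketch of that scenario, using plain Python stand-ins rather than the actual transformers objects (the module name is just an example value):

```python
# Stand-in for the user-passed AwqConfig: do_fuse=True, but the user did not
# set modules_to_not_convert themselves.
class DummyQuantConfig:
    def __init__(self, do_fuse=False, modules_to_not_convert=None):
        self.do_fuse = do_fuse
        self.modules_to_not_convert = modules_to_not_convert


# Stand-in for config.quantization_config as serialized in the checkpoint.
checkpoint_quant_config = {"modules_to_not_convert": ["multi_modal_projector"]}

quantization_config = DummyQuantConfig(do_fuse=True)

# Mirror of the branch quoted above: pull modules_to_not_convert from the
# checkpoint config only when the user left it unset.
if (
    quantization_config.modules_to_not_convert is None
    and "modules_to_not_convert" in checkpoint_quant_config
):
    quantization_config.modules_to_not_convert = checkpoint_quant_config["modules_to_not_convert"]

print(quantization_config.modules_to_not_convert)  # ['multi_modal_projector']
```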
Alright let's keep that in mind
valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
new_batch_index = batch_index[valid_indices]
new_non_attended_tokens = non_attended_tokens[valid_indices]
Not a fan of adding custom code only to handle custom usages. There should be a more general way of handling these things (why use the extended attention mask and not just the attention mask, why not use the past key value length, etc.).
This is not to handle custom usage; it happens when a past key value with pad tokens lands on indices that are larger than the extended attention mask shape: #28032 (comment) & #28239 (comment). This can mainly happen in batched generation with long sequence lengths, and it specifically happens for autoawq fused modules because the dummy past key values are initialized with all zeros: https://github.com/casper-hansen/AutoAWQ/blob/a3db8099a234a46a21bf5e46340da60da6992e0c/awq/modules/fused/attn.py#L238
In any case I don't think this will cause any harm since it just filters out indices of pad tokens (which are not attended anyway) that are out of the extended attention mask range, and I confirmed all slow tests pass.
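To make the edge case concrete, here is a small standalone sketch of what the filtering does (the tensor values are made up for illustration and are not taken from a real generation run):

```python
import torch

# Suppose the extended attention mask covers 8 positions, but the dummy past
# key values (pre-filled with zeros by the fused AWQ attention) make some
# padded positions look non-attended at indices >= 8.
extended_attention_mask = torch.ones(2, 8)
batch_index = torch.tensor([0, 0, 1, 1])
non_attended_tokens = torch.tensor([3, 9, 5, 10])  # 9 and 10 are out of range

# Keep only the non-attended indices that actually fall inside the mask;
# out-of-range entries correspond to pad tokens that are never attended anyway.
valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
new_batch_index = batch_index[valid_indices]
new_non_attended_tokens = non_attended_tokens[valid_indices]

print(new_batch_index.tolist(), new_non_attended_tokens.tolist())  # [0, 1] [3, 5]
```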
Alright, I thought this was already solved. No worries then, it's just that tensor indexing might slow things down a bit, but it is required anyway. I think a refactor might help:
- Init the embeddings with a different value (like -1, which might not happen as often as zeros) when we compute the image indexes (see the toy sketch below)
- Correctly update the attention mask when merging to make sure we keep track of what we computed

I'd be in favor of moving this fix to another PR maybe? WDYT?
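For illustration only, a toy sketch of the sentinel idea; the shapes, the -1 fill value, and the detection rule are assumptions for this example, not the actual Llava or AutoAWQ cache layout:

```python
import torch

# If the dummy cache is pre-filled with a value that real activations rarely
# take (e.g. -1) instead of zeros, unwritten / padded cache slots can be
# detected without colliding with genuine zero activations.
hidden_dim = 4
dummy_cache = torch.full((1, 2, 6, hidden_dim), -1.0)      # sentinel-filled cache
dummy_cache[0, :, :3, :] = torch.randn(2, 3, hidden_dim)   # 3 positions actually written

# Positions still equal to the sentinel are the non-attended ones.
batch_index, non_attended_tokens = torch.where(dummy_cache[:, 0, :, 0] == -1.0)
print(non_attended_tokens.tolist())  # [3, 4, 5]
```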
I see, thanks!
I am happy to explore:
> Init the embeddings with a different value (like -1, which might not happen as often as zeros) when we compute the image indexes

in another PR!
> I'd be in favor of moving this fix to another PR maybe? WDYT?

That might not be ideal, because if this fix is not introduced, users cannot run Llava + fused modules :/ I'll address the points you shared in a follow-up PR!
thanks
Thanks for your reviews @ArthurZucker! Merging! I'll address the points you shared in #28239 (comment) in another PR, as stated in my reply.
* add llava + fused modules
* Update src/transformers/models/llava/modeling_llava.py

Co-authored-by: Arthur <[email protected]>
What does this PR do?
This PR adds Llava + fused modules support for blazing fast text generation using Llava + AWQ!
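As a rough usage sketch: the checkpoint id below is a placeholder and the fusing arguments are assumptions for illustration, not a prescription (autoawq needs to be installed):

```python
import requests
import torch
from PIL import Image
from transformers import AutoProcessor, AwqConfig, LlavaForConditionalGeneration

# Hypothetical AWQ-quantized Llava checkpoint; replace with a real one.
model_id = "org/llava-1.5-7b-awq"

# Enable fused modules on top of the AWQ checkpoint.
quantization_config = AwqConfig(do_fuse=True, fuse_max_seq_len=2048)

processor = AutoProcessor.from_pretrained(model_id)
model = LlavaForConditionalGeneration.from_pretrained(
    model_id,
    quantization_config=quantization_config,
    torch_dtype=torch.float16,
    device_map="cuda:0",
)

image = Image.open(requests.get("https://llava-vl.github.io/static/images/view.jpg", stream=True).raw)
prompt = "USER: <image>\nWhat is shown in this image? ASSISTANT:"

inputs = processor(text=prompt, images=image, return_tensors="pt").to("cuda:0")
output = model.generate(**inputs, max_new_tokens=64)
print(processor.decode(output[0], skip_special_tokens=True))
```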
This PR also fixes the issue #28032 (comment) pointed out by a user when a custom past key value is passed to the model: keeping only the indexes that fall inside the range of `extended_attention_mask` fixes it. A slow test was also added.
Can also confirm all Llava slow tests pass!
cc @casper-hansen