
Mamba / FalconMamba: Fix mamba left padding #32677

Merged: 10 commits from the fix-mamba-padding branch merged into huggingface:main on Aug 19, 2024

Conversation

@younesbelkada (Contributor) commented Aug 14, 2024

What does this PR do?

As pointed out in #32080 (comment), it is important to zero out the hidden states that correspond to the pad tokens before and after the causal convolution, so that pad tokens have no impact on the calculated hidden states.
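
For illustration, here is a minimal sketch of the idea (not the actual modeling code in transformers; the helper name, shapes, and call sites are assumptions for the example): the hidden states are multiplied by the attention mask so that every padded position is exactly zero, both before the causal convolution (so pad tokens cannot leak into real tokens through the conv kernel) and after it (so the SSM scan does not accumulate contributions from padding).

import torch

def zero_out_padding(hidden_states, attention_mask):
    # hidden_states: (batch, seq_len, hidden_size); attention_mask: (batch, seq_len), 1 = real token, 0 = pad
    if attention_mask is not None and not torch.all(attention_mask == 1):
        hidden_states = hidden_states * attention_mask.unsqueeze(-1).to(hidden_states.dtype)
    return hidden_states

# Illustrative call sites inside the mixer block:
#   hidden_states = zero_out_padding(hidden_states, attention_mask)  # before the causal conv
#   hidden_states = causal_conv(hidden_states)
#   hidden_states = zero_out_padding(hidden_states, attention_mask)  # after the causal conv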

This can be empirically verified by comparing generation quality before and after this fix (note that FalconMamba uses left padding by default):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "tiiuae/falcon-mamba-7b"
tok = AutoTokenizer.from_pretrained(model_id)
tok.pad_token_id = tok.eos_token_id  # reuse EOS as the pad token

texts = [
    "Hello today",
    "Hello my name is Younes and today"
]

inputs = tok(texts, return_tensors="pt", padding=True, return_token_type_ids=False).to(0)  # the shorter prompt gets left-padded
model = AutoModelForCausalLM.from_pretrained(model_id, device_map=0, torch_dtype=torch.bfloat16)

out = model.generate(**inputs, max_new_tokens=20)
print(tok.batch_decode(out, skip_special_tokens=True))

Before the fix:

Hello today I'm.\nI'm.\n Hello today.\n Hello today.\n Hello today

After the fix:

Hello today I'm going to show you how to make a 3D model of a house.\n

Propagated the changes to Mamba1 as well.

cc @ArthurZucker @molbap

@molbap (Contributor) left a comment:

Thanks @younesbelkada for adding the states zeroing-out! 😁 Left a couple of comments, mostly curious about some situations that were edge cases for Mamba 2.

@vasqu (Contributor) commented Aug 14, 2024

Can we propagate this to Jamba as well? :D Thanks for this fix ❤️

@molbap (Contributor) left a comment:

LGTM! pinging @ArthurZucker for merging 🙂

@ArthurZucker (Collaborator) left a comment:

Thanks for adding a test 🤗

# In case cache is not used, manually add a new column in the attention mask
if not use_cache and attention_mask is not None and input_ids.shape != attention_mask.shape:
    pad_length = input_ids.shape[-1] - attention_mask.shape[-1]
    attention_mask = torch.cat([attention_mask, torch.ones_like(input_ids[:, :pad_length])], dim=-1)
Collaborator:

Not sure I understand why we are adding a [1] x batch_size column? past_length is usually going to be 1 - current_generation_token, so imagine 20 input ids, then -19 to slice the input_ids? Unless input_ids is 20, but then it always has the same shape as the mask.

@younesbelkada (Author):

This is for users that run generation with use_cache=False; it makes sure to manually update the attention mask, because this is done nowhere else except here.
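
A toy example of the shape mismatch this handles (the shapes below are made up for illustration): without a cache, generate keeps appending tokens to input_ids, but the attention mask supplied by the user still only covers the prompt, so the model pads it with ones for the newly generated (never padded) positions.

import torch

batch, prompt_len, new_tokens = 2, 5, 3
input_ids = torch.ones(batch, prompt_len + new_tokens, dtype=torch.long)  # prompt + generated tokens
attention_mask = torch.ones(batch, prompt_len, dtype=torch.long)          # still prompt-sized

pad_length = input_ids.shape[-1] - attention_mask.shape[-1]               # 3 missing columns
attention_mask = torch.cat([attention_mask, torch.ones_like(input_ids[:, :pad_length])], dim=-1)
print(attention_mask.shape)  # torch.Size([2, 8])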

Collaborator:

Then this is more a problem with generate, as it should pass the correct attention mask 😓

@ArthurZucker (Collaborator) left a comment:

Will include this in the patch 🤗

Comment on lines 732 to 736
# In case cache is not used, manually update the attention mask
if not use_cache and attention_mask is not None and input_ids.shape != attention_mask.shape:
    past_length = input_ids.shape[-1] - attention_mask.shape[-1]
    attention_mask = torch.cat([attention_mask, torch.ones_like(input_ids[:, :past_length])], dim=-1)

Collaborator:

That's the only thing bothering me, as generate with use_cache=False should not alter the attention mask being passed.

@younesbelkada (Author):

Yes, fixed it!

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@@ -557,6 +574,7 @@ def set_input_embeddings(self, new_embeddings):
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
Collaborator:

This is breaking (having it as the second positional argument).
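
A generic illustration of why this is breaking (the signatures below are made up, not the actual Mamba ones): any caller that passes the second argument positionally now silently feeds it into attention_mask.

def forward_old(input_ids, inputs_embeds=None, use_cache=None):
    ...

def forward_new(input_ids, attention_mask=None, inputs_embeds=None, use_cache=None):
    ...

# Written against the old signature, forward_old(ids, embeds) bound `embeds` to inputs_embeds;
# with the new signature, forward_new(ids, embeds) silently binds `embeds` to attention_mask instead.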

@younesbelkada (Author):

Yes, fixed it.

Comment on lines +738 to +742
if "attention_mask" in model_kwargs:
attention_mask = model_kwargs["attention_mask"]
model_kwargs["attention_mask"] = torch.cat(
[attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
)
Collaborator:

Good catch!

@ArthurZucker merged commit 93e538a into huggingface:main on Aug 19, 2024
21 checks passed
@younesbelkada deleted the fix-mamba-padding branch on August 19, 2024 at 14:01
ArthurZucker added a commit that referenced this pull request Aug 20, 2024
* fix mamba left padding

* Apply suggestions from code review

Co-authored-by: Pablo Montalvo <[email protected]>

* fix copies

* test with `inputs_embeds`

* Update src/transformers/models/falcon_mamba/modeling_falcon_mamba.py

Co-authored-by: Arthur <[email protected]>

* copies

* clarify

* fix last comments

* remove

---------

Co-authored-by: Pablo Montalvo <[email protected]>
Co-authored-by: Arthur <[email protected]>
ArthurZucker added a commit that referenced this pull request Aug 20, 2024
@vasqu mentioned this pull request Aug 21, 2024