Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Mamba / FalconMamba: Fix mamba left padding #32677

Merged
merged 10 commits into from
Aug 19, 2024
15 changes: 8 additions & 7 deletions src/transformers/models/falcon_mamba/modeling_falcon_mamba.py
Original file line number Diff line number Diff line change
Expand Up @@ -619,14 +619,13 @@ def set_input_embeddings(self, new_embeddings):
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.LongTensor] = None, # Ignored arg
inputs_embeds: Optional[torch.LongTensor] = None,
cache_params: Optional[MambaCache] = None,
use_cache: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs, # `attention_mask` is passed by the tokenizer and we don't want it
attention_mask: Optional[torch.LongTensor] = None, # Ignored arg
) -> Union[Tuple, FalconMambaOutput]:
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
Expand Down Expand Up @@ -735,6 +734,13 @@ def _update_model_kwargs_for_generation(
and model_kwargs["cache_position"] is not None
):
model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + num_new_tokens

if "attention_mask" in model_kwargs:
attention_mask = model_kwargs["attention_mask"]
model_kwargs["attention_mask"] = torch.cat(
[attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
)
Comment on lines +738 to +742
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good catch !


return model_kwargs

def prepare_inputs_for_generation(
Expand Down Expand Up @@ -773,11 +779,6 @@ def prepare_inputs_for_generation(
else:
model_inputs = {"input_ids": input_ids.contiguous()}

# In case cache is not used, manually update the attention mask
if not use_cache and attention_mask is not None and input_ids.shape != attention_mask.shape:
past_length = input_ids.shape[-1] - attention_mask.shape[-1]
attention_mask = torch.cat([attention_mask, torch.ones_like(input_ids[:, :past_length])], dim=-1)

model_inputs.update(
{
"cache_params": cache_params,
Expand Down
14 changes: 7 additions & 7 deletions src/transformers/models/mamba/modeling_mamba.py
Original file line number Diff line number Diff line change
Expand Up @@ -574,14 +574,13 @@ def set_input_embeddings(self, new_embeddings):
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.LongTensor] = None,
cache_params: Optional[MambaCache] = None,
use_cache: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs, # `attention_mask` is passed by the tokenizer and we don't want it
attention_mask: Optional[torch.LongTensor] = None, # Ignored arg
) -> Union[Tuple, MambaOutput]:
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
Expand Down Expand Up @@ -691,6 +690,12 @@ def _update_model_kwargs_for_generation(
):
model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + num_new_tokens

if "attention_mask" in model_kwargs:
attention_mask = model_kwargs["attention_mask"]
model_kwargs["attention_mask"] = torch.cat(
[attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
)

return model_kwargs

def prepare_inputs_for_generation(
Expand Down Expand Up @@ -729,11 +734,6 @@ def prepare_inputs_for_generation(
else:
model_inputs = {"input_ids": input_ids.contiguous()}

# In case cache is not used, manually update the attention mask
if not use_cache and attention_mask is not None and input_ids.shape != attention_mask.shape:
past_length = input_ids.shape[-1] - attention_mask.shape[-1]
attention_mask = torch.cat([attention_mask, torch.ones_like(input_ids[:, :past_length])], dim=-1)

model_inputs.update(
{
"cache_params": cache_params,
Expand Down