Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Mamba / FalconMamba: Fix mamba left padding #32677

Merged
merged 10 commits into from
Aug 19, 2024
53 changes: 44 additions & 9 deletions src/transformers/models/falcon_mamba/modeling_falcon_mamba.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@ def cuda_kernels_forward(
hidden_states: torch.Tensor,
cache_params: Optional[MambaCache] = None,
cache_position: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
):
# 1. Gated MLP's linear projection
projected_states = self.in_proj(hidden_states).transpose(1, 2)
Expand All @@ -179,6 +180,9 @@ def cuda_kernels_forward(
else:
hidden_states, gate = projected_states.chunk(2, dim=1)

if attention_mask is not None:
hidden_states = hidden_states * attention_mask.unsqueeze(1)

# 2. Convolution sequence transformation
conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2))
if cache_params is not None and cache_position[0] > 0:
Expand All @@ -200,6 +204,9 @@ def cuda_kernels_forward(
hidden_states, conv_weights, self.conv1d.bias, activation=self.activation
)

if attention_mask is not None:
hidden_states = hidden_states * attention_mask.unsqueeze(1)

# 3. State Space Model sequence transformation
# 3.a. input varying initialization of time_step, B and C
ssm_parameters = self.x_proj(hidden_states.transpose(1, 2))
Expand Down Expand Up @@ -259,13 +266,17 @@ def slow_forward(
input_states,
cache_params: Optional[MambaCache] = None,
cache_position: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
):
batch_size, seq_len, _ = input_states.shape
dtype = input_states.dtype
# 1. Gated MLP's linear projection
projected_states = self.in_proj(input_states).transpose(1, 2) # [batch, 2 * intermediate_size, seq_len]
hidden_states, gate = projected_states.chunk(2, dim=1)

if attention_mask is not None:
hidden_states = hidden_states * attention_mask.unsqueeze(1)

# 2. Convolution sequence transformation
if cache_params is not None:
ssm_state = cache_params.ssm_states[self.layer_idx].clone()
Expand Down Expand Up @@ -294,6 +305,9 @@ def slow_forward(
)
hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len]) # [batch, intermediate_size, seq_len]

if attention_mask is not None:
hidden_states = hidden_states * attention_mask.unsqueeze(1)

# 3. State Space Model sequence transformation
# 3.a. Selection: [batch, seq_len, self.time_step_rank + self.ssm_state_size * 2]
ssm_parameters = self.x_proj(hidden_states.transpose(1, 2))
Expand Down Expand Up @@ -355,10 +369,11 @@ def forward(
hidden_states,
cache_params: Optional[MambaCache] = None,
cache_position: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
):
if is_fast_path_available and "cuda" in self.x_proj.weight.device.type and not torch._dynamo.is_compiling():
return self.cuda_kernels_forward(hidden_states, cache_params, cache_position)
return self.slow_forward(hidden_states, cache_params, cache_position)
return self.cuda_kernels_forward(hidden_states, cache_params, cache_position, attention_mask)
return self.slow_forward(hidden_states, cache_params, cache_position, attention_mask)


# Copied from transformers.models.mamba.modeling_mamba.MambaRMSNorm with Mamba->FalconMamba
Expand Down Expand Up @@ -396,13 +411,16 @@ def forward(
hidden_states,
cache_params: Optional[MambaCache] = None,
cache_position: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
):
residual = hidden_states
hidden_states = self.norm(hidden_states.to(dtype=self.norm.weight.dtype))
if self.residual_in_fp32:
residual = residual.to(torch.float32)

hidden_states = self.mixer(hidden_states, cache_params=cache_params, cache_position=cache_position)
hidden_states = self.mixer(
hidden_states, cache_params=cache_params, cache_position=cache_position, attention_mask=attention_mask
)
hidden_states = residual + hidden_states
return hidden_states

Expand Down Expand Up @@ -601,14 +619,13 @@ def set_input_embeddings(self, new_embeddings):
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.LongTensor] = None, # Ignored arg
inputs_embeds: Optional[torch.LongTensor] = None,
cache_params: Optional[MambaCache] = None,
use_cache: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs, # `attention_mask` is passed by the tokenizer and we don't want it
attention_mask: Optional[torch.LongTensor] = None,
) -> Union[Tuple, FalconMambaOutput]:
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
Expand Down Expand Up @@ -649,10 +666,15 @@ def forward(
for mixer_block in self.layers:
if self.gradient_checkpointing and self.training:
hidden_states = self._gradient_checkpointing_func(
mixer_block.__call__, hidden_states, cache_params, cache_position
mixer_block.__call__, hidden_states, cache_params, cache_position, attention_mask
)
else:
hidden_states = mixer_block(hidden_states, cache_params=cache_params, cache_position=cache_position)
hidden_states = mixer_block(
hidden_states,
cache_params=cache_params,
cache_position=cache_position,
attention_mask=attention_mask,
)

if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
Expand Down Expand Up @@ -712,6 +734,13 @@ def _update_model_kwargs_for_generation(
and model_kwargs["cache_position"] is not None
):
model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + num_new_tokens

if "attention_mask" in model_kwargs:
attention_mask = model_kwargs["attention_mask"]
model_kwargs["attention_mask"] = torch.cat(
[attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
)
Comment on lines +738 to +742
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good catch !


return model_kwargs

def prepare_inputs_for_generation(
Expand All @@ -721,6 +750,7 @@ def prepare_inputs_for_generation(
use_cache=None,
cache_params: Optional[MambaCache] = None,
cache_position: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
**kwargs,
):
if use_cache:
Expand All @@ -733,6 +763,10 @@ def prepare_inputs_for_generation(
)
if cache_position[0] > 0:
input_ids = input_ids[:, -1].unsqueeze(-1)

if attention_mask is not None:
attention_mask = None

else:
# we initialize the `cache_position` to full size of `conv_states` at prefill stage
# considering padding will be applied when input length is shorter, and truncation
Expand All @@ -750,6 +784,7 @@ def prepare_inputs_for_generation(
"cache_params": cache_params,
"use_cache": use_cache,
"cache_position": cache_position,
"attention_mask": attention_mask,
}
)
return model_inputs
Expand All @@ -760,11 +795,10 @@ def prepare_inputs_for_generation(
output_type=FalconMambaCausalLMOutput,
config_class=_CONFIG_FOR_DOC,
)
# Ignore copy
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.LongTensor] = None, # Ignored copy
attention_mask: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
cache_params: Optional[MambaCache] = None,
labels: Optional[torch.LongTensor] = None,
Expand All @@ -790,6 +824,7 @@ def forward(
return_dict=return_dict,
use_cache=use_cache,
cache_position=cache_position,
attention_mask=attention_mask,
)
hidden_states = falcon_mamba_outputs[0]

Expand Down
50 changes: 43 additions & 7 deletions src/transformers/models/mamba/modeling_mamba.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ def cuda_kernels_forward(
hidden_states: torch.Tensor,
cache_params: Optional[MambaCache] = None,
cache_position: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
):
# 1. Gated MLP's linear projection
projected_states = self.in_proj(hidden_states).transpose(1, 2)
Expand All @@ -160,6 +161,9 @@ def cuda_kernels_forward(
else:
hidden_states, gate = projected_states.chunk(2, dim=1)

if attention_mask is not None:
hidden_states = hidden_states * attention_mask.unsqueeze(1)

# 2. Convolution sequence transformation
conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2))
if cache_params is not None and cache_position[0] > 0:
Expand All @@ -181,6 +185,9 @@ def cuda_kernels_forward(
hidden_states, conv_weights, self.conv1d.bias, activation=self.activation
)

if attention_mask is not None:
hidden_states = hidden_states * attention_mask.unsqueeze(1)

# 3. State Space Model sequence transformation
# 3.a. input varying initialization of time_step, B and C
ssm_parameters = self.x_proj(hidden_states.transpose(1, 2))
Expand Down Expand Up @@ -226,13 +233,16 @@ def cuda_kernels_forward(
return contextualized_states

# fmt: off
def slow_forward(self, input_states, cache_params: Optional[MambaCache]=None, cache_position:Optional[torch.LongTensor]=None):
def slow_forward(self, input_states, cache_params: Optional[MambaCache]=None, cache_position:Optional[torch.LongTensor]=None, attention_mask: Optional[torch.LongTensor] = None):
batch_size, seq_len, _ = input_states.shape
dtype = input_states.dtype
# 1. Gated MLP's linear projection
projected_states = self.in_proj(input_states).transpose(1, 2) # [batch, 2 * intermediate_size, seq_len]
hidden_states, gate = projected_states.chunk(2, dim=1)

if attention_mask is not None:
hidden_states = hidden_states * attention_mask.unsqueeze(1)

# 2. Convolution sequence transformation
if cache_params is not None:
ssm_state = cache_params.ssm_states[self.layer_idx].clone()
Expand Down Expand Up @@ -261,6 +271,9 @@ def slow_forward(self, input_states, cache_params: Optional[MambaCache]=None, ca
)
hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len]) # [batch, intermediate_size, seq_len]

if attention_mask is not None:
hidden_states = hidden_states * attention_mask.unsqueeze(1)
molbap marked this conversation as resolved.
Show resolved Hide resolved

# 3. State Space Model sequence transformation
# 3.a. Selection: [batch, seq_len, self.time_step_rank + self.ssm_state_size * 2]
ssm_parameters = self.x_proj(hidden_states.transpose(1, 2))
Expand Down Expand Up @@ -306,10 +319,11 @@ def forward(
hidden_states,
cache_params: Optional[MambaCache] = None,
cache_position: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
):
if is_fast_path_available and "cuda" in self.x_proj.weight.device.type and not torch._dynamo.is_compiling():
return self.cuda_kernels_forward(hidden_states, cache_params, cache_position)
return self.slow_forward(hidden_states, cache_params, cache_position)
return self.cuda_kernels_forward(hidden_states, cache_params, cache_position, attention_mask)
return self.slow_forward(hidden_states, cache_params, cache_position, attention_mask)


class MambaRMSNorm(nn.Module):
Expand Down Expand Up @@ -346,13 +360,16 @@ def forward(
hidden_states,
cache_params: Optional[MambaCache] = None,
cache_position: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
):
residual = hidden_states
hidden_states = self.norm(hidden_states.to(dtype=self.norm.weight.dtype))
if self.residual_in_fp32:
residual = residual.to(torch.float32)

hidden_states = self.mixer(hidden_states, cache_params=cache_params, cache_position=cache_position)
hidden_states = self.mixer(
hidden_states, cache_params=cache_params, cache_position=cache_position, attention_mask=attention_mask
)
hidden_states = residual + hidden_states
return hidden_states

Expand Down Expand Up @@ -563,7 +580,7 @@ def forward(
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs, # `attention_mask` is passed by the tokenizer and we don't want it
attention_mask: Optional[torch.LongTensor] = None,
) -> Union[Tuple, MambaOutput]:
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
Expand Down Expand Up @@ -605,10 +622,15 @@ def forward(
for mixer_block in self.layers:
if self.gradient_checkpointing and self.training:
hidden_states = self._gradient_checkpointing_func(
mixer_block.__call__, hidden_states, cache_params, cache_position
mixer_block.__call__, hidden_states, cache_params, cache_position, attention_mask
)
else:
hidden_states = mixer_block(hidden_states, cache_params=cache_params, cache_position=cache_position)
hidden_states = mixer_block(
hidden_states,
cache_params=cache_params,
cache_position=cache_position,
attention_mask=attention_mask,
)

if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
Expand Down Expand Up @@ -668,6 +690,12 @@ def _update_model_kwargs_for_generation(
):
model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + num_new_tokens

if "attention_mask" in model_kwargs:
attention_mask = model_kwargs["attention_mask"]
model_kwargs["attention_mask"] = torch.cat(
[attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
)

return model_kwargs

def prepare_inputs_for_generation(
Expand All @@ -677,6 +705,7 @@ def prepare_inputs_for_generation(
use_cache=None,
cache_params: Optional[MambaCache] = None,
cache_position: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
**kwargs,
):
if use_cache:
Expand All @@ -689,6 +718,10 @@ def prepare_inputs_for_generation(
)
if cache_position[0] > 0:
input_ids = input_ids[:, -1].unsqueeze(-1)

if attention_mask is not None:
attention_mask = None

else:
# we initialize the `cache_position` to full size of `conv_states` at prefill stage
# considering padding will be applied when input length is shorter, and truncation
Expand All @@ -706,6 +739,7 @@ def prepare_inputs_for_generation(
"cache_params": cache_params,
"use_cache": use_cache,
"cache_position": cache_position,
"attention_mask": attention_mask,
}
)
return model_inputs
Expand All @@ -719,6 +753,7 @@ def prepare_inputs_for_generation(
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
cache_params: Optional[MambaCache] = None,
labels: Optional[torch.LongTensor] = None,
Expand All @@ -744,6 +779,7 @@ def forward(
return_dict=return_dict,
use_cache=use_cache,
cache_position=cache_position,
attention_mask=attention_mask,
)
hidden_states = mamba_outputs[0]

Expand Down
Loading
Loading