Mamba / FalconMamba: Fix mamba left padding (#32677)
* fix mamba left padding

* Apply suggestions from code review

Co-authored-by: Pablo Montalvo <[email protected]>

* fix copies

* test with `inputs_embeds`

* Update src/transformers/models/falcon_mamba/modeling_falcon_mamba.py

Co-authored-by: Arthur <[email protected]>

* copies

* clarify

* fix last comments

* remove

---------

Co-authored-by: Pablo Montalvo <[email protected]>
Co-authored-by: Arthur <[email protected]>
3 people authored Aug 19, 2024
1 parent 59e8f19 commit 93e538a
Showing 4 changed files with 129 additions and 22 deletions.
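For context, here is a minimal sketch of the left-padded batched generation this fix targets. It is not part of the commit; the checkpoint name and pad-token handling are assumptions.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "tiiuae/falcon-mamba-7b"  # assumed checkpoint; any Mamba/FalconMamba model works
tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="left")
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token  # assumption: reuse EOS as the pad token
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)

prompts = ["Hello, my name is", "The theory of relativity states that"]
inputs = tokenizer(prompts, return_tensors="pt", padding=True)

# With this fix, the attention_mask zeroes out left-padding positions inside the
# Mamba mixer, so shorter prompts no longer accumulate state from pad tokens.
outputs = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))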
53 changes: 44 additions & 9 deletions src/transformers/models/falcon_mamba/modeling_falcon_mamba.py
@@ -155,6 +155,7 @@ def cuda_kernels_forward(
hidden_states: torch.Tensor,
cache_params: Optional[MambaCache] = None,
cache_position: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
):
# 1. Gated MLP's linear projection
projected_states = self.in_proj(hidden_states).transpose(1, 2)
@@ -179,6 +180,9 @@ def cuda_kernels_forward(
else:
hidden_states, gate = projected_states.chunk(2, dim=1)

if attention_mask is not None:
hidden_states = hidden_states * attention_mask.unsqueeze(1)

# 2. Convolution sequence transformation
conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2))
if cache_params is not None and cache_position[0] > 0:
@@ -200,6 +204,9 @@ def cuda_kernels_forward(
hidden_states, conv_weights, self.conv1d.bias, activation=self.activation
)

if attention_mask is not None:
hidden_states = hidden_states * attention_mask.unsqueeze(1)

# 3. State Space Model sequence transformation
# 3.a. input varying initialization of time_step, B and C
ssm_parameters = self.x_proj(hidden_states.transpose(1, 2))
@@ -259,13 +266,17 @@ def slow_forward(
input_states,
cache_params: Optional[MambaCache] = None,
cache_position: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
):
batch_size, seq_len, _ = input_states.shape
dtype = input_states.dtype
# 1. Gated MLP's linear projection
projected_states = self.in_proj(input_states).transpose(1, 2) # [batch, 2 * intermediate_size, seq_len]
hidden_states, gate = projected_states.chunk(2, dim=1)

if attention_mask is not None:
hidden_states = hidden_states * attention_mask.unsqueeze(1)

# 2. Convolution sequence transformation
if cache_params is not None:
ssm_state = cache_params.ssm_states[self.layer_idx].clone()
@@ -294,6 +305,9 @@ def slow_forward(
)
hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len]) # [batch, intermediate_size, seq_len]

if attention_mask is not None:
hidden_states = hidden_states * attention_mask.unsqueeze(1)

# 3. State Space Model sequence transformation
# 3.a. Selection: [batch, seq_len, self.time_step_rank + self.ssm_state_size * 2]
ssm_parameters = self.x_proj(hidden_states.transpose(1, 2))
@@ -355,10 +369,11 @@ def forward(
hidden_states,
cache_params: Optional[MambaCache] = None,
cache_position: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
):
if is_fast_path_available and "cuda" in self.x_proj.weight.device.type and not torch._dynamo.is_compiling():
return self.cuda_kernels_forward(hidden_states, cache_params, cache_position)
return self.slow_forward(hidden_states, cache_params, cache_position)
return self.cuda_kernels_forward(hidden_states, cache_params, cache_position, attention_mask)
return self.slow_forward(hidden_states, cache_params, cache_position, attention_mask)


# Copied from transformers.models.mamba.modeling_mamba.MambaRMSNorm with Mamba->FalconMamba
@@ -396,13 +411,16 @@ def forward(
hidden_states,
cache_params: Optional[MambaCache] = None,
cache_position: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
):
residual = hidden_states
hidden_states = self.norm(hidden_states.to(dtype=self.norm.weight.dtype))
if self.residual_in_fp32:
residual = residual.to(torch.float32)

hidden_states = self.mixer(hidden_states, cache_params=cache_params, cache_position=cache_position)
hidden_states = self.mixer(
hidden_states, cache_params=cache_params, cache_position=cache_position, attention_mask=attention_mask
)
hidden_states = residual + hidden_states
return hidden_states

@@ -601,14 +619,13 @@ def set_input_embeddings(self, new_embeddings):
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.LongTensor] = None, # Ignored arg
inputs_embeds: Optional[torch.LongTensor] = None,
cache_params: Optional[MambaCache] = None,
use_cache: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs, # `attention_mask` is passed by the tokenizer and we don't want it
attention_mask: Optional[torch.LongTensor] = None,
) -> Union[Tuple, FalconMambaOutput]:
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -649,10 +666,15 @@ def forward(
for mixer_block in self.layers:
if self.gradient_checkpointing and self.training:
hidden_states = self._gradient_checkpointing_func(
mixer_block.__call__, hidden_states, cache_params, cache_position
mixer_block.__call__, hidden_states, cache_params, cache_position, attention_mask
)
else:
hidden_states = mixer_block(hidden_states, cache_params=cache_params, cache_position=cache_position)
hidden_states = mixer_block(
hidden_states,
cache_params=cache_params,
cache_position=cache_position,
attention_mask=attention_mask,
)

if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
@@ -712,6 +734,13 @@ def _update_model_kwargs_for_generation(
and model_kwargs["cache_position"] is not None
):
model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + num_new_tokens

if "attention_mask" in model_kwargs:
attention_mask = model_kwargs["attention_mask"]
model_kwargs["attention_mask"] = torch.cat(
[attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
)

return model_kwargs

def prepare_inputs_for_generation(
@@ -721,6 +750,7 @@ def prepare_inputs_for_generation(
use_cache=None,
cache_params: Optional[MambaCache] = None,
cache_position: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
**kwargs,
):
if use_cache:
@@ -733,6 +763,10 @@
)
if cache_position[0] > 0:
input_ids = input_ids[:, -1].unsqueeze(-1)

if attention_mask is not None:
attention_mask = None

else:
# we initialize the `cache_position` to full size of `conv_states` at prefill stage
# considering padding will be applied when input length is shorter, and truncation
@@ -750,6 +784,7 @@
"cache_params": cache_params,
"use_cache": use_cache,
"cache_position": cache_position,
"attention_mask": attention_mask,
}
)
return model_inputs
@@ -760,11 +795,10 @@ def prepare_inputs_for_generation(
output_type=FalconMambaCausalLMOutput,
config_class=_CONFIG_FOR_DOC,
)
# Ignore copy
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.LongTensor] = None, # Ignored copy
attention_mask: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
cache_params: Optional[MambaCache] = None,
labels: Optional[torch.LongTensor] = None,
@@ -790,6 +824,7 @@ def forward(
return_dict=return_dict,
use_cache=use_cache,
cache_position=cache_position,
attention_mask=attention_mask,
)
hidden_states = falcon_mamba_outputs[0]

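The core of the fix in both mixers is the masking applied before and after the causal convolution. A standalone illustration (not from the commit) of how the [batch, seq_len] mask broadcasts over the channel dimension:

import torch

batch, channels, seq_len = 2, 4, 5
hidden_states = torch.randn(batch, channels, seq_len)
# Left padding: the first two positions of the second sequence are padding.
attention_mask = torch.tensor([[1, 1, 1, 1, 1],
                               [0, 0, 1, 1, 1]])

masked = hidden_states * attention_mask.unsqueeze(1)  # mask becomes [batch, 1, seq_len]
assert torch.all(masked[1, :, :2] == 0)  # padded positions contribute zeros to conv/SSM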
50 changes: 43 additions & 7 deletions src/transformers/models/mamba/modeling_mamba.py
@@ -136,6 +136,7 @@ def cuda_kernels_forward(
hidden_states: torch.Tensor,
cache_params: Optional[MambaCache] = None,
cache_position: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
):
# 1. Gated MLP's linear projection
projected_states = self.in_proj(hidden_states).transpose(1, 2)
@@ -160,6 +161,9 @@ def cuda_kernels_forward(
else:
hidden_states, gate = projected_states.chunk(2, dim=1)

if attention_mask is not None:
hidden_states = hidden_states * attention_mask.unsqueeze(1)

# 2. Convolution sequence transformation
conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2))
if cache_params is not None and cache_position[0] > 0:
@@ -181,6 +185,9 @@ def cuda_kernels_forward(
hidden_states, conv_weights, self.conv1d.bias, activation=self.activation
)

if attention_mask is not None:
hidden_states = hidden_states * attention_mask.unsqueeze(1)

# 3. State Space Model sequence transformation
# 3.a. input varying initialization of time_step, B and C
ssm_parameters = self.x_proj(hidden_states.transpose(1, 2))
@@ -226,13 +233,16 @@ def cuda_kernels_forward(
return contextualized_states

# fmt: off
def slow_forward(self, input_states, cache_params: Optional[MambaCache]=None, cache_position:Optional[torch.LongTensor]=None):
def slow_forward(self, input_states, cache_params: Optional[MambaCache]=None, cache_position:Optional[torch.LongTensor]=None, attention_mask: Optional[torch.LongTensor] = None):
batch_size, seq_len, _ = input_states.shape
dtype = input_states.dtype
# 1. Gated MLP's linear projection
projected_states = self.in_proj(input_states).transpose(1, 2) # [batch, 2 * intermediate_size, seq_len]
hidden_states, gate = projected_states.chunk(2, dim=1)

if attention_mask is not None:
hidden_states = hidden_states * attention_mask.unsqueeze(1)

# 2. Convolution sequence transformation
if cache_params is not None:
ssm_state = cache_params.ssm_states[self.layer_idx].clone()
@@ -261,6 +271,9 @@ def slow_forward(self, input_states, cache_params: Optional[MambaCache]=None, ca
)
hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len]) # [batch, intermediate_size, seq_len]

if attention_mask is not None:
hidden_states = hidden_states * attention_mask.unsqueeze(1)

# 3. State Space Model sequence transformation
# 3.a. Selection: [batch, seq_len, self.time_step_rank + self.ssm_state_size * 2]
ssm_parameters = self.x_proj(hidden_states.transpose(1, 2))
@@ -306,10 +319,11 @@ def forward(
hidden_states,
cache_params: Optional[MambaCache] = None,
cache_position: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
):
if is_fast_path_available and "cuda" in self.x_proj.weight.device.type and not torch._dynamo.is_compiling():
return self.cuda_kernels_forward(hidden_states, cache_params, cache_position)
return self.slow_forward(hidden_states, cache_params, cache_position)
return self.cuda_kernels_forward(hidden_states, cache_params, cache_position, attention_mask)
return self.slow_forward(hidden_states, cache_params, cache_position, attention_mask)


class MambaRMSNorm(nn.Module):
@@ -346,13 +360,16 @@ def forward(
hidden_states,
cache_params: Optional[MambaCache] = None,
cache_position: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
):
residual = hidden_states
hidden_states = self.norm(hidden_states.to(dtype=self.norm.weight.dtype))
if self.residual_in_fp32:
residual = residual.to(torch.float32)

hidden_states = self.mixer(hidden_states, cache_params=cache_params, cache_position=cache_position)
hidden_states = self.mixer(
hidden_states, cache_params=cache_params, cache_position=cache_position, attention_mask=attention_mask
)
hidden_states = residual + hidden_states
return hidden_states

@@ -563,7 +580,7 @@ def forward(
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs, # `attention_mask` is passed by the tokenizer and we don't want it
attention_mask: Optional[torch.LongTensor] = None,
) -> Union[Tuple, MambaOutput]:
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -605,10 +622,15 @@ def forward(
for mixer_block in self.layers:
if self.gradient_checkpointing and self.training:
hidden_states = self._gradient_checkpointing_func(
mixer_block.__call__, hidden_states, cache_params, cache_position
mixer_block.__call__, hidden_states, cache_params, cache_position, attention_mask
)
else:
hidden_states = mixer_block(hidden_states, cache_params=cache_params, cache_position=cache_position)
hidden_states = mixer_block(
hidden_states,
cache_params=cache_params,
cache_position=cache_position,
attention_mask=attention_mask,
)

if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
@@ -668,6 +690,12 @@ def _update_model_kwargs_for_generation(
):
model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + num_new_tokens

if "attention_mask" in model_kwargs:
attention_mask = model_kwargs["attention_mask"]
model_kwargs["attention_mask"] = torch.cat(
[attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
)

return model_kwargs

def prepare_inputs_for_generation(
@@ -677,6 +705,7 @@ def prepare_inputs_for_generation(
use_cache=None,
cache_params: Optional[MambaCache] = None,
cache_position: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
**kwargs,
):
if use_cache:
@@ -689,6 +718,10 @@
)
if cache_position[0] > 0:
input_ids = input_ids[:, -1].unsqueeze(-1)

if attention_mask is not None:
attention_mask = None

else:
# we initialize the `cache_position` to full size of `conv_states` at prefill stage
# considering padding will be applied when input length is shorter, and truncation
@@ -706,6 +739,7 @@
"cache_params": cache_params,
"use_cache": use_cache,
"cache_position": cache_position,
"attention_mask": attention_mask,
}
)
return model_inputs
@@ -719,6 +753,7 @@ def prepare_inputs_for_generation(
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
cache_params: Optional[MambaCache] = None,
labels: Optional[torch.LongTensor] = None,
@@ -744,6 +779,7 @@ def forward(
return_dict=return_dict,
use_cache=use_cache,
cache_position=cache_position,
attention_mask=attention_mask,
)
hidden_states = mamba_outputs[0]

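The generation-side change keeps the mask in sync during decoding: `_update_model_kwargs_for_generation` appends an unmasked position for every newly generated token, while `prepare_inputs_for_generation` drops the mask once the cache is primed (`cache_position[0] > 0`). A minimal illustration of the mask extension, using standalone tensors rather than the real generate loop:

import torch

attention_mask = torch.tensor([[0, 0, 1, 1, 1],
                               [1, 1, 1, 1, 1]])
# One decoding step: every sequence gains one real (unmasked) token.
attention_mask = torch.cat(
    [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
)
print(attention_mask)  # the appended column is all ones; shape is now [2, 6]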