diff --git a/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py b/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py
index 6557f4d7a7701a..07374fe1dfd7b5 100644
--- a/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py
+++ b/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py
@@ -155,6 +155,7 @@ def cuda_kernels_forward(
         hidden_states: torch.Tensor,
         cache_params: Optional[MambaCache] = None,
         cache_position: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.LongTensor] = None,
     ):
         # 1. Gated MLP's linear projection
         projected_states = self.in_proj(hidden_states).transpose(1, 2)
@@ -179,6 +180,9 @@ def cuda_kernels_forward(
         else:
             hidden_states, gate = projected_states.chunk(2, dim=1)

+            if attention_mask is not None:
+                hidden_states = hidden_states * attention_mask.unsqueeze(1)
+
             # 2. Convolution sequence transformation
             conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2))
             if cache_params is not None and cache_position[0] > 0:
@@ -200,6 +204,9 @@ def cuda_kernels_forward(
                     hidden_states, conv_weights, self.conv1d.bias, activation=self.activation
                 )

+            if attention_mask is not None:
+                hidden_states = hidden_states * attention_mask.unsqueeze(1)
+
             # 3. State Space Model sequence transformation
             # 3.a. input varying initialization of time_step, B and C
             ssm_parameters = self.x_proj(hidden_states.transpose(1, 2))
@@ -259,6 +266,7 @@ def slow_forward(
         input_states,
         cache_params: Optional[MambaCache] = None,
         cache_position: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.LongTensor] = None,
     ):
         batch_size, seq_len, _ = input_states.shape
         dtype = input_states.dtype
@@ -266,6 +274,9 @@ def slow_forward(
         projected_states = self.in_proj(input_states).transpose(1, 2)  # [batch, 2 * intermediate_size, seq_len]
         hidden_states, gate = projected_states.chunk(2, dim=1)

+        if attention_mask is not None:
+            hidden_states = hidden_states * attention_mask.unsqueeze(1)
+
         # 2. Convolution sequence transformation
         if cache_params is not None:
             ssm_state = cache_params.ssm_states[self.layer_idx].clone()
@@ -294,6 +305,9 @@ def slow_forward(
             )
             hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len])  # [batch, intermediate_size, seq_len]

+        if attention_mask is not None:
+            hidden_states = hidden_states * attention_mask.unsqueeze(1)
+
         # 3. State Space Model sequence transformation
         # 3.a. Selection:  [batch, seq_len, self.time_step_rank + self.ssm_state_size * 2]
         ssm_parameters = self.x_proj(hidden_states.transpose(1, 2))
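Reviewer note (not part of the patch): the hunks above multiply the projected hidden states by the padding mask both before and after the causal convolution. A minimal, standalone sketch of what that operation does is below; the tensor sizes are made up for illustration and only the shape convention ([batch, intermediate_size, seq_len] states vs. [batch, seq_len] mask) comes from the modeling code.

```python
import torch

batch, intermediate_size, seq_len = 2, 8, 5
hidden_states = torch.randn(batch, intermediate_size, seq_len)

# 1 = real token, 0 = padding; the second sequence is left-padded by two tokens
attention_mask = torch.tensor([[1, 1, 1, 1, 1],
                               [0, 0, 1, 1, 1]])

# [batch, seq_len] -> [batch, 1, seq_len] so the mask broadcasts over the channel dim
masked = hidden_states * attention_mask.unsqueeze(1)

assert torch.all(masked[1, :, :2] == 0)          # padded positions are zeroed
assert torch.equal(masked[0], hidden_states[0])  # fully-real rows are untouched
```

The multiplication is repeated after the depthwise convolution because the causal conv mixes each position with up to `conv_kernel - 1` previous positions, so without the second pass the (left) padding would leak into the first real tokens.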
@@ -355,10 +369,11 @@ def forward(
         hidden_states,
         cache_params: Optional[MambaCache] = None,
         cache_position: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.LongTensor] = None,
     ):
         if is_fast_path_available and "cuda" in self.x_proj.weight.device.type and not torch._dynamo.is_compiling():
-            return self.cuda_kernels_forward(hidden_states, cache_params, cache_position)
-        return self.slow_forward(hidden_states, cache_params, cache_position)
+            return self.cuda_kernels_forward(hidden_states, cache_params, cache_position, attention_mask)
+        return self.slow_forward(hidden_states, cache_params, cache_position, attention_mask)


 # Copied from transformers.models.mamba.modeling_mamba.MambaRMSNorm with Mamba->FalconMamba
@@ -396,13 +411,16 @@ def forward(
         hidden_states,
         cache_params: Optional[MambaCache] = None,
         cache_position: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.LongTensor] = None,
     ):
         residual = hidden_states
         hidden_states = self.norm(hidden_states.to(dtype=self.norm.weight.dtype))
         if self.residual_in_fp32:
             residual = residual.to(torch.float32)

-        hidden_states = self.mixer(hidden_states, cache_params=cache_params, cache_position=cache_position)
+        hidden_states = self.mixer(
+            hidden_states, cache_params=cache_params, cache_position=cache_position, attention_mask=attention_mask
+        )
         hidden_states = residual + hidden_states
         return hidden_states

@@ -601,14 +619,13 @@ def set_input_embeddings(self, new_embeddings):
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
-        attention_mask: Optional[torch.LongTensor] = None,  # Ignored arg
         inputs_embeds: Optional[torch.LongTensor] = None,
         cache_params: Optional[MambaCache] = None,
         use_cache: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
-        **kwargs,  # `attention_mask` is passed by the tokenizer and we don't want it
+        attention_mask: Optional[torch.LongTensor] = None,
     ) -> Union[Tuple, FalconMambaOutput]:
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -649,10 +666,15 @@ def forward(
         for mixer_block in self.layers:
             if self.gradient_checkpointing and self.training:
                 hidden_states = self._gradient_checkpointing_func(
-                    mixer_block.__call__, hidden_states, cache_params, cache_position
+                    mixer_block.__call__, hidden_states, cache_params, cache_position, attention_mask
                 )
             else:
-                hidden_states = mixer_block(hidden_states, cache_params=cache_params, cache_position=cache_position)
+                hidden_states = mixer_block(
+                    hidden_states,
+                    cache_params=cache_params,
+                    cache_position=cache_position,
+                    attention_mask=attention_mask,
+                )

             if output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states,)
@@ -712,6 +734,13 @@ def _update_model_kwargs_for_generation(
             and model_kwargs["cache_position"] is not None
         ):
             model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + num_new_tokens
+
+        if "attention_mask" in model_kwargs:
+            attention_mask = model_kwargs["attention_mask"]
+            model_kwargs["attention_mask"] = torch.cat(
+                [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
+            )
+
         return model_kwargs

     def prepare_inputs_for_generation(
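Reviewer note (not part of the patch): the `_update_model_kwargs_for_generation` hunk above grows the attention mask by one "real token" column per decoding step, mirroring what generation does for attention-based models. A standalone sketch with toy values, assuming a left-padded two-row batch:

```python
import torch

attention_mask = torch.tensor([[1, 1, 1],
                               [0, 1, 1]])  # prompt mask, second row left-padded

for _ in range(2):  # two decoding steps
    attention_mask = torch.cat(
        [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
    )

print(attention_mask)
# tensor([[1, 1, 1, 1, 1],
#         [0, 1, 1, 1, 1]])
```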
@@ -721,6 +750,7 @@ def prepare_inputs_for_generation(
         input_ids,
         inputs_embeds=None,
         use_cache=None,
         cache_params: Optional[MambaCache] = None,
         cache_position: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.LongTensor] = None,
         **kwargs,
     ):
         if use_cache:
@@ -733,6 +763,10 @@ def prepare_inputs_for_generation(
             )
             if cache_position[0] > 0:
                 input_ids = input_ids[:, -1].unsqueeze(-1)
+
+                if attention_mask is not None:
+                    attention_mask = None
+
         else:
             # we initialize the `cache_position` to full size of `conv_states` at prefill stage
             # considering padding will be applied when input length is shorter, and truncation
@@ -750,6 +784,7 @@ def prepare_inputs_for_generation(
                 "cache_params": cache_params,
                 "use_cache": use_cache,
                 "cache_position": cache_position,
+                "attention_mask": attention_mask,
             }
         )
         return model_inputs
@@ -760,11 +795,10 @@ def prepare_inputs_for_generation(
         output_type=FalconMambaCausalLMOutput,
         config_class=_CONFIG_FOR_DOC,
     )
-    # Ignore copy
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
-        attention_mask: Optional[torch.LongTensor] = None,  # Ignored copy
+        attention_mask: Optional[torch.LongTensor] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         cache_params: Optional[MambaCache] = None,
         labels: Optional[torch.LongTensor] = None,
@@ -790,6 +824,7 @@ def forward(
             return_dict=return_dict,
             use_cache=use_cache,
             cache_position=cache_position,
+            attention_mask=attention_mask,
         )
         hidden_states = falcon_mamba_outputs[0]
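Reviewer note (not part of the patch): why `prepare_inputs_for_generation` drops the mask once decoding starts. After prefill (`cache_position[0] > 0`) only the last generated token is fed, so every position in the step input is a real token and there is nothing left to mask; the recurrent conv/SSM caches already contain the correctly masked prompt. A sketch of that split, with `step_inputs` as a hypothetical stand-in for the logic in the hunk (not the actual method) and illustrative shapes:

```python
import torch

def step_inputs(input_ids, attention_mask, cache_position):
    # hypothetical helper mirroring the branch above
    if cache_position[0] > 0:          # decoding: feed only the last token
        input_ids = input_ids[:, -1:]  # [batch, 1]
        attention_mask = None          # a single real token needs no mask
    return input_ids, attention_mask

ids = torch.tensor([[7, 7, 1, 2, 3], [1, 2, 3, 4, 5]])
mask = torch.tensor([[0, 0, 1, 1, 1], [1, 1, 1, 1, 1]])

print(step_inputs(ids, mask, torch.tensor([0]))[1])  # prefill: mask kept
print(step_inputs(ids, mask, torch.tensor([5]))[1])  # decode: None
```

The same pattern is mirrored for plain Mamba in the file below.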
diff --git a/src/transformers/models/mamba/modeling_mamba.py b/src/transformers/models/mamba/modeling_mamba.py
index 23ab3ab142d075..14a3dea1d1ccf8 100644
--- a/src/transformers/models/mamba/modeling_mamba.py
+++ b/src/transformers/models/mamba/modeling_mamba.py
@@ -136,6 +136,7 @@ def cuda_kernels_forward(
         hidden_states: torch.Tensor,
         cache_params: Optional[MambaCache] = None,
         cache_position: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.LongTensor] = None,
     ):
         # 1. Gated MLP's linear projection
         projected_states = self.in_proj(hidden_states).transpose(1, 2)
@@ -160,6 +161,9 @@ def cuda_kernels_forward(
         else:
             hidden_states, gate = projected_states.chunk(2, dim=1)

+            if attention_mask is not None:
+                hidden_states = hidden_states * attention_mask.unsqueeze(1)
+
             # 2. Convolution sequence transformation
             conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2))
             if cache_params is not None and cache_position[0] > 0:
@@ -181,6 +185,9 @@ def cuda_kernels_forward(
                     hidden_states, conv_weights, self.conv1d.bias, activation=self.activation
                 )

+            if attention_mask is not None:
+                hidden_states = hidden_states * attention_mask.unsqueeze(1)
+
             # 3. State Space Model sequence transformation
             # 3.a. input varying initialization of time_step, B and C
             ssm_parameters = self.x_proj(hidden_states.transpose(1, 2))
@@ -226,13 +233,16 @@ def cuda_kernels_forward(
         return contextualized_states

     # fmt: off
-    def slow_forward(self, input_states, cache_params: Optional[MambaCache]=None, cache_position:Optional[torch.LongTensor]=None):
+    def slow_forward(self, input_states, cache_params: Optional[MambaCache]=None, cache_position:Optional[torch.LongTensor]=None, attention_mask: Optional[torch.LongTensor] = None):
         batch_size, seq_len, _ = input_states.shape
         dtype = input_states.dtype
         # 1. Gated MLP's linear projection
         projected_states = self.in_proj(input_states).transpose(1, 2)  # [batch, 2 * intermediate_size, seq_len]
         hidden_states, gate = projected_states.chunk(2, dim=1)

+        if attention_mask is not None:
+            hidden_states = hidden_states * attention_mask.unsqueeze(1)
+
         # 2. Convolution sequence transformation
         if cache_params is not None:
             ssm_state = cache_params.ssm_states[self.layer_idx].clone()
@@ -261,6 +271,9 @@ def slow_forward(self, input_states, cache_params: Optional[MambaCache]=None, ca
             )
             hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len])  # [batch, intermediate_size, seq_len]

+        if attention_mask is not None:
+            hidden_states = hidden_states * attention_mask.unsqueeze(1)
+
         # 3. State Space Model sequence transformation
         # 3.a. Selection:  [batch, seq_len, self.time_step_rank + self.ssm_state_size * 2]
         ssm_parameters = self.x_proj(hidden_states.transpose(1, 2))
@@ -306,10 +319,11 @@ def forward(
         hidden_states,
         cache_params: Optional[MambaCache] = None,
         cache_position: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.LongTensor] = None,
     ):
         if is_fast_path_available and "cuda" in self.x_proj.weight.device.type and not torch._dynamo.is_compiling():
-            return self.cuda_kernels_forward(hidden_states, cache_params, cache_position)
-        return self.slow_forward(hidden_states, cache_params, cache_position)
+            return self.cuda_kernels_forward(hidden_states, cache_params, cache_position, attention_mask)
+        return self.slow_forward(hidden_states, cache_params, cache_position, attention_mask)


 class MambaRMSNorm(nn.Module):
@@ -346,13 +360,16 @@ def forward(
         hidden_states,
         cache_params: Optional[MambaCache] = None,
         cache_position: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.LongTensor] = None,
     ):
         residual = hidden_states
         hidden_states = self.norm(hidden_states.to(dtype=self.norm.weight.dtype))
         if self.residual_in_fp32:
             residual = residual.to(torch.float32)

-        hidden_states = self.mixer(hidden_states, cache_params=cache_params, cache_position=cache_position)
+        hidden_states = self.mixer(
+            hidden_states, cache_params=cache_params, cache_position=cache_position, attention_mask=attention_mask
+        )
         hidden_states = residual + hidden_states
         return hidden_states

@@ -563,7 +580,7 @@ def forward(
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
-        **kwargs,  # `attention_mask` is passed by the tokenizer and we don't want it
+        attention_mask: Optional[torch.LongTensor] = None,
     ) -> Union[Tuple, MambaOutput]:
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -605,10 +622,15 @@ def forward(
         for mixer_block in self.layers:
             if self.gradient_checkpointing and self.training:
                 hidden_states = self._gradient_checkpointing_func(
-                    mixer_block.__call__, hidden_states, cache_params, cache_position
+                    mixer_block.__call__, hidden_states, cache_params, cache_position, attention_mask
                 )
             else:
-                hidden_states = mixer_block(hidden_states, cache_params=cache_params, cache_position=cache_position)
+                hidden_states = mixer_block(
+                    hidden_states,
+                    cache_params=cache_params,
+                    cache_position=cache_position,
+                    attention_mask=attention_mask,
+                )

             if output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states,)
@@ -668,6 +690,12 @@ def _update_model_kwargs_for_generation(
         ):
             model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + num_new_tokens

+        if "attention_mask" in model_kwargs:
+            attention_mask = model_kwargs["attention_mask"]
+            model_kwargs["attention_mask"] = torch.cat(
+                [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
+            )
+
         return model_kwargs

     def prepare_inputs_for_generation(
@@ -677,6 +705,7 @@ def prepare_inputs_for_generation(
         use_cache=None,
         cache_params: Optional[MambaCache] = None,
         cache_position: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.LongTensor] = None,
         **kwargs,
     ):
         if use_cache:
@@ -689,6 +718,10 @@ def prepare_inputs_for_generation(
             )
             if cache_position[0] > 0:
                 input_ids = input_ids[:, -1].unsqueeze(-1)
+
+                if attention_mask is not None:
+                    attention_mask = None
+
         else:
             # we initialize the `cache_position` to full size of `conv_states` at prefill stage
             # considering padding will be applied when input length is shorter, and truncation
@@ -706,6 +739,7 @@ def prepare_inputs_for_generation(
                 "cache_params": cache_params,
                 "use_cache": use_cache,
                 "cache_position": cache_position,
+                "attention_mask": attention_mask,
             }
         )
         return model_inputs
@@ -719,6 +753,7 @@ def prepare_inputs_for_generation(
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.LongTensor] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         cache_params: Optional[MambaCache] = None,
         labels: Optional[torch.LongTensor] = None,
@@ -744,6 +779,7 @@ def forward(
             return_dict=return_dict,
             use_cache=use_cache,
             cache_position=cache_position,
+            attention_mask=attention_mask,
         )
         hidden_states = mamba_outputs[0]
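Reviewer note (not part of the patch): end-to-end usage this change enables, mirroring the new integration test below but with a smaller checkpoint and an explicit `padding_side="left"`. The checkpoint choice and the explicit padding side are assumptions on my part; the test relies on the Falcon-Mamba tokenizer's defaults instead.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "state-spaces/mamba-130m-hf"  # any Mamba / FalconMamba checkpoint should behave the same
tok = AutoTokenizer.from_pretrained(model_id, padding_side="left")
tok.pad_token_id = tok.eos_token_id

model = AutoModelForCausalLM.from_pretrained(model_id)

texts = ["Hello today", "Hello my name is Younes and today"]
inputs = tok(texts, return_tensors="pt", padding=True)

# attention_mask is now forwarded to the mixer blocks instead of being ignored,
# so the shorter, padded prompt no longer contaminates the conv/SSM state.
out = model.generate(**inputs, max_new_tokens=20)
print(tok.batch_decode(out, skip_special_tokens=True))
```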
diff --git a/tests/models/falcon_mamba/test_modeling_falcon_mamba.py b/tests/models/falcon_mamba/test_modeling_falcon_mamba.py
index 8e7c456e4a383b..d75014f370d29f 100644
--- a/tests/models/falcon_mamba/test_modeling_falcon_mamba.py
+++ b/tests/models/falcon_mamba/test_modeling_falcon_mamba.py
@@ -101,6 +101,7 @@ def prepare_config_and_inputs(
         self, gradient_checkpointing=False, scale_attn_by_inverse_layer_idx=False, reorder_and_upcast_attn=False
     ):
         input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
+        attention_mask = ids_tensor([self.batch_size, self.seq_length], 1)

         sequence_labels = None
         token_labels = None
@@ -119,7 +120,7 @@ def prepare_config_and_inputs(
         return (
             config,
             input_ids,
-            None,
+            attention_mask,
             sequence_labels,
             token_labels,
             choice_labels,
@@ -153,6 +154,7 @@ def prepare_config_and_inputs_for_decoder(self):
         (
             config,
             input_ids,
+            attention_mask,
             sequence_labels,
             token_labels,
             choice_labels,
@@ -161,6 +163,7 @@ def prepare_config_and_inputs_for_decoder(self):
         ) = self.prepare_config_and_inputs()

         return (
             config,
             input_ids,
+            attention_mask,
             sequence_labels,
             token_labels,
             choice_labels,
@@ -253,12 +256,12 @@ def prepare_config_and_inputs_for_common(self):
         (
             config,
             input_ids,
-            _,
+            attention_mask,
             sequence_labels,
             token_labels,
             choice_labels,
         ) = self.prepare_config_and_inputs()

-        inputs_dict = {"input_ids": input_ids}
+        inputs_dict = {"input_ids": input_ids, "attention_mask": attention_mask}

         return config, inputs_dict
@@ -491,3 +494,33 @@ def test_generation_torch_compile(self):
             self.tokenizer.batch_decode(out, skip_special_tokens=False)[0],
             "Hello today I am going to show you how to make a simple and easy to make paper plane.\nStep",
         )
+
+    def test_batched_generation(self):
+        model_id = "tiiuae/falcon-mamba-7b"
+        tok = AutoTokenizer.from_pretrained(model_id)
+        tok.pad_token_id = tok.eos_token_id
+
+        texts = ["Hello today", "Hello my name is Younes and today"]
+
+        EXPECTED_OUTPUT = [
+            "Hello today I'm going to show you how to make a 3D model of a house.\n",
+            "Hello my name is Younes and today I will be talking about the topic of “The importance of the internet in our life”.\n",
+        ]
+
+        inputs = tok(texts, return_tensors="pt", padding=True, return_token_type_ids=False).to(torch_device)
+        model = AutoModelForCausalLM.from_pretrained(model_id, device_map=0, torch_dtype=torch.bfloat16)
+
+        out = model.generate(**inputs, max_new_tokens=20)
+        out = tok.batch_decode(out, skip_special_tokens=True)
+
+        self.assertListEqual(out, EXPECTED_OUTPUT)
+
+        # We test the same generations with inputs_embeds
+        with torch.no_grad():
+            inputs_embeds = model.get_input_embeddings()(inputs.pop("input_ids"))
+
+        inputs["inputs_embeds"] = inputs_embeds
+        out = model.generate(**inputs, max_new_tokens=20)
+        out = tok.batch_decode(out, skip_special_tokens=True)
+
+        self.assertListEqual(out, EXPECTED_OUTPUT)
diff --git a/tests/models/mamba/test_modeling_mamba.py b/tests/models/mamba/test_modeling_mamba.py
index e7e3a7242cddc4..54d35917556f6d 100644
--- a/tests/models/mamba/test_modeling_mamba.py
+++ b/tests/models/mamba/test_modeling_mamba.py
@@ -94,6 +94,7 @@ def prepare_config_and_inputs(
         self, gradient_checkpointing=False, scale_attn_by_inverse_layer_idx=False, reorder_and_upcast_attn=False
     ):
         input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
+        attention_mask = ids_tensor([self.batch_size, self.seq_length], 1)

         sequence_labels = None
         token_labels = None
@@ -112,7 +113,7 @@ def prepare_config_and_inputs(
         return (
             config,
             input_ids,
-            None,
+            attention_mask,
             sequence_labels,
             token_labels,
             choice_labels,
@@ -146,6 +147,7 @@ def prepare_config_and_inputs_for_decoder(self):
         (
             config,
             input_ids,
+            attention_mask,
             sequence_labels,
             token_labels,
             choice_labels,
@@ -154,6 +156,7 @@ def prepare_config_and_inputs_for_decoder(self):
         ) = self.prepare_config_and_inputs()

         return (
             config,
             input_ids,
+            attention_mask,
             sequence_labels,
             token_labels,
             choice_labels,
@@ -246,12 +249,12 @@ def prepare_config_and_inputs_for_common(self):
         (
             config,
             input_ids,
-            _,
+            attention_mask,
             sequence_labels,
             token_labels,
             choice_labels,
         ) = self.prepare_config_and_inputs()

-        inputs_dict = {"input_ids": input_ids}
+        inputs_dict = {"input_ids": input_ids, "attention_mask": attention_mask}

         return config, inputs_dict
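Reviewer note (not part of the patch): a quick local check that complements the new tests. With the mask plumbed through, a left-padded prompt should produce (nearly) the same next-token logits as the same prompt without padding. The checkpoint, prompt, padded length, and tolerance below are arbitrary choices on my part, not values taken from this PR.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "state-spaces/mamba-130m-hf"
tok = AutoTokenizer.from_pretrained(model_id, padding_side="left")
tok.pad_token_id = tok.eos_token_id
model = AutoModelForCausalLM.from_pretrained(model_id).eval()

plain = tok(["Hello today"], return_tensors="pt")
padded = tok(["Hello today"], return_tensors="pt", padding="max_length", max_length=8)

with torch.no_grad():
    logits_plain = model(**plain).logits[:, -1]    # logits at the last real token
    logits_padded = model(**padded).logits[:, -1]  # same token, after left padding

print(torch.allclose(logits_plain, logits_padded, atol=1e-4))
```

If the masking in the mixer blocks is doing its job, the two calls should agree up to numerical noise, because the zeroed padded positions behave like the implicit zero padding of the causal convolution and contribute nothing to the SSM state.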