Mamba & RecurrentGemma: enable strict signature #31549

Merged 3 commits on Jul 8, 2024
63 changes: 27 additions & 36 deletions src/transformers/generation/utils.py
@@ -2642,13 +2642,12 @@ def _sample(
             # prepare model inputs
             model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

+            # prepare variable output controls (note: some models won't accept all output controls)
+            model_inputs.update({"output_attentions": output_attentions} if output_attentions else {})
+            model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {})
Collaborator commented on lines +2646 to +2647:
yesssss I think I have a PR open where I did this! Finally!


             # forward pass to get next token
-            outputs = self(
-                **model_inputs,
-                return_dict=True,
-                output_attentions=output_attentions,
-                output_hidden_states=output_hidden_states,
-            )
+            outputs = self(**model_inputs, return_dict=True)

             if synced_gpus and this_peer_finished:
                 continue  # don't waste resources running the code we don't need
@@ -2869,6 +2868,10 @@ def _beam_search(
         while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
             model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

+            # prepare variable output controls (note: some models won't accept all output controls)
+            model_inputs.update({"output_attentions": output_attentions} if output_attentions else {})
+            model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {})
+
             # if sequential is True, split the input to batches of batch_size and run sequentially
             if sequential:
                 if any(
@@ -2894,24 +2897,13 @@
                    model_inputs, split_size=batch_size, full_batch_size=batch_beam_size
                 )
                 outputs_per_sub_batch = [
-                    self(
-                        **inputs_per_sub_batch,
-                        return_dict=True,
-                        output_attentions=output_attentions,
-                        output_hidden_states=output_hidden_states,
-                    )
-                    for inputs_per_sub_batch in inputs_per_sub_batches
+                    self(**inputs_per_sub_batch, return_dict=True) for inputs_per_sub_batch in inputs_per_sub_batches
                 ]

                outputs = stack_model_outputs(outputs_per_sub_batch)

             else:  # Unchanged original behavior
-                outputs = self(
-                    **model_inputs,
-                    return_dict=True,
-                    output_attentions=output_attentions,
-                    output_hidden_states=output_hidden_states,
-                )
+                outputs = self(**model_inputs, return_dict=True)

             if synced_gpus and this_peer_finished:
                 cur_len = cur_len + 1
@@ -3191,12 +3183,12 @@ def _group_beam_search(

            # do one decoder step on all beams of all sentences in batch
             model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
-            outputs = self(
-                **model_inputs,
-                return_dict=True,
-                output_attentions=output_attentions,
-                output_hidden_states=output_hidden_states,
-            )
+
+            # prepare variable output controls (note: some models won't accept all output controls)
+            model_inputs.update({"output_attentions": output_attentions} if output_attentions else {})
+            model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {})
+
+            outputs = self(**model_inputs, return_dict=True)

             if synced_gpus and this_peer_finished:
                 cur_len = cur_len + 1
@@ -3472,12 +3464,11 @@ def _constrained_beam_search(
         while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
             model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

-            outputs = self(
-                **model_inputs,
-                return_dict=True,
-                output_attentions=output_attentions,
-                output_hidden_states=output_hidden_states,
-            )
+            # prepare variable output controls (note: some models won't accept all output controls)
+            model_inputs.update({"output_attentions": output_attentions} if output_attentions else {})
+            model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {})
+
+            outputs = self(**model_inputs, return_dict=True)

             if synced_gpus and this_peer_finished:
                 cur_len = cur_len + 1
@@ -3740,11 +3731,11 @@ def _assisted_decoding(
model_inputs["num_logits_to_keep"] = candidate_length + 1

# 2.2. Run a forward pass on the candidate sequence
outputs = self(
**model_inputs,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
# prepare variable output controls (note: some models won't accept all output controls)
model_inputs.update({"output_attentions": output_attentions} if output_attentions else {})
model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {})

outputs = self(**model_inputs)

# 2.3. Process the new logits
new_logits = outputs.logits[:, -candidate_length - 1 :] # excludes the input prompt if present
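The same "prepare variable output controls" pattern appears in every decoding loop above. For readers outside the diff context, here is a minimal, self-contained sketch of why the conditional update matters, using toy stand-ins rather than the real transformers classes: a model with a strict forward signature, like Mamba's, never receives an output control it cannot accept, while explicitly requesting an unsupported control still fails loudly instead of being silently dropped.

```python
# Toy stand-ins (not the real transformers classes) showing the effect of the
# "prepare variable output controls" pattern used in the diff above.
from typing import Optional


class StrictModel:
    """A model whose forward, like Mamba's, accepts output_hidden_states but not output_attentions."""

    def __call__(self, input_ids, output_hidden_states: Optional[bool] = None, return_dict: bool = True):
        return {"logits": [0.0], "hidden_states": [] if output_hidden_states else None}


def decoding_step(model, input_ids, output_attentions=False, output_hidden_states=False):
    model_inputs = {"input_ids": input_ids}
    # Only forward the controls the caller actually requested, so a strict
    # signature never sees an unexpected keyword argument.
    model_inputs.update({"output_attentions": output_attentions} if output_attentions else {})
    model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {})
    return model(**model_inputs, return_dict=True)


print(decoding_step(StrictModel(), input_ids=[1, 2, 3]))                             # works
print(decoding_step(StrictModel(), input_ids=[1, 2, 3], output_hidden_states=True))  # works
try:
    decoding_step(StrictModel(), input_ids=[1, 2, 3], output_attentions=True)
except TypeError as exc:
    # Explicitly requesting an unsupported control still fails loudly.
    print(f"TypeError: {exc}")
```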
2 changes: 0 additions & 2 deletions src/transformers/models/mamba/modeling_mamba.py
@@ -545,7 +545,6 @@ def forward(
         use_cache: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
-        **kwargs,  # `attention_mask` is passed by the tokenizer and we don't want it
@gante (Member, Author) commented on Jun 22, 2024:
alternatively, we can accept attention_mask and raise an exception when it is not None or not all ones
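A rough sketch of that alternative, assuming a simplified stand-in for the Mamba forward; the signature and error message here are illustrative, not what the PR ships:

```python
# Illustrative only: a Mamba-style forward that accepts the tokenizer-provided
# attention_mask, but only as a no-op (None or all ones); anything else errors.
from typing import Optional

import torch


def forward(
    input_ids: torch.LongTensor,
    attention_mask: Optional[torch.Tensor] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
):
    if attention_mask is not None and not bool(torch.all(attention_mask == 1)):
        raise ValueError(
            "This model does not support attention masking; pass attention_mask=None "
            "or an all-ones mask."
        )
    ...  # the usual Mamba forward computation would follow here
```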

Collaborator:
Removing this will break FSDP :( See #31161

@gante (Member, Author):
@amyeroberts I had a look and it should be fine: this PR removes **kwargs from the model class (e.g. MambaModel), while the FSDP PR ensures there are **kwargs in the decoder layers (e.g. FalconDecoderLayer).

We can see on main that the models themselves don't have **kwargs, even after the FSDP fix (e.g. llama) 🤗
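To make the distinction concrete, a toy sketch with hypothetical classes, not the actual transformers implementations:

```python
# Toy illustration of where **kwargs lives after this PR and after the FSDP fix.
import torch.nn as nn


class ToyDecoderLayer(nn.Module):
    # Decoder *layers* keep **kwargs (what the FSDP PR, #31161, relies on),
    # so wrapper-injected arguments can pass through harmlessly.
    def forward(self, hidden_states, **kwargs):
        return hidden_states


class ToyModel(nn.Module):
    # The *model* forward has a strict signature (what this PR enforces),
    # so an unsupported argument such as attention_mask fails fast.
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([ToyDecoderLayer() for _ in range(2)])

    def forward(self, inputs_embeds, use_cache=None, output_hidden_states=None, return_dict=None):
        hidden_states = inputs_embeds
        for layer in self.layers:
            hidden_states = layer(hidden_states)
        return hidden_states
```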

Collaborator:
OK!

     ) -> Union[Tuple, MambaOutput]:
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -673,7 +672,6 @@ def forward(
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         use_cache: Optional[bool] = None,
-        **kwargs,  # for now we need this for generation
     ) -> Union[Tuple, MambaCausalLMOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py
@@ -684,7 +684,6 @@ def forward(
         use_cache: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
-        **kwargs,
     ) -> Union[Tuple, BaseModelOutputWithNoAttention]:
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -823,7 +822,6 @@ def forward(
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         use_cache: Optional[bool] = None,
-        **kwargs,  # for now we need this for generation
     ) -> Union[Tuple, CausalLMOutput]:
         r"""
         Args: