
Commit

In assisted decoding, pass model_kwargs to model's forward call (fix prepare_input_for_generation in all models) (huggingface#25242)

* In assisted decoding, pass model_kwargs to model's forward call

Previously, assisted decoding ignored any additional kwargs that it
did not explicitly handle. This was inconsistent with other
generation methods, which pass the model_kwargs through
prepare_inputs_for_generation and forward the returned dict to the
model's forward call.

The prepare_inputs_for_generation method needed to be amended in all
models, as it previously kept only the last input ID whenever
past_key_values was passed. (A usage sketch follows the change list
below.)

* Improve variable names in _extend_attention_mask

* Refactor extending token_type_ids into a function

* Replace deepcopy with copy to optimize performance

* Update new persimmon model with llama changes for assisted generation

* Update new mistral model for assisted generation with prepare_inputs_for_generation

* Update position_ids creation in falcon prepare_inputs_for_generation to support assisted generation
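
A minimal usage sketch of what the change enables (the model names and the token_type_ids value here are illustrative assumptions, not taken from the PR): extra model kwargs now reach the model's forward pass during assisted decoding as well.

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2-medium")
    assistant = AutoModelForCausalLM.from_pretrained("gpt2")  # smaller draft model

    inputs = tokenizer("The quick brown fox", return_tensors="pt")
    outputs = model.generate(
        **inputs,
        assistant_model=assistant,  # triggers assisted decoding
        token_type_ids=torch.zeros_like(inputs["input_ids"]),  # extra kwarg, previously dropped
        max_new_tokens=20,
    )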
sinking-point authored and EduardoPach committed Nov 18, 2023
1 parent bfa82b1 commit 04368d1
Showing 63 changed files with 911 additions and 179 deletions.
91 changes: 51 additions & 40 deletions src/transformers/generation/utils.py
@@ -1297,6 +1297,43 @@ def _validate_generated_length(self, generation_config, input_ids_length, has_de
UserWarning,
)

def _extend_attention_mask(self, model_kwargs: Dict[str, Any], new_mask_length: int) -> Dict[str, Any]:
if self.config.is_encoder_decoder:
key = "decoder_attention_mask"
else:
key = "attention_mask"

if key not in model_kwargs:
return model_kwargs

mask = model_kwargs[key]
mask_extension_length = new_mask_length - mask.shape[1]

if mask_extension_length < 0:
raise ValueError("Cannot extend attention mask to a length less than it already is")

model_kwargs[key] = torch.cat(
[mask, mask.new_ones((mask.shape[0], mask_extension_length))],
dim=-1,
)

return model_kwargs
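
A standalone sketch of the mask extension above, with hypothetical shapes:

    import torch

    mask = torch.ones(2, 5, dtype=torch.long)  # batch of 2, 5 attendable positions
    extension = mask.new_ones((mask.shape[0], 8 - mask.shape[1]))  # grow to length 8
    extended = torch.cat([mask, extension], dim=-1)  # shape (2, 8); candidate positions unmasked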

def _extend_token_type_ids(self, model_kwargs: Dict[str, Any], new_length: int) -> Dict[str, Any]:
if "token_type_ids" not in model_kwargs or model_kwargs["token_type_ids"] is None:
return model_kwargs

token_type_ids = model_kwargs["token_type_ids"]
final_token_type = token_type_ids[:, -1].unsqueeze(-1)
extension_length = new_length - token_type_ids.shape[1]
token_type_copies = final_token_type.repeat(1, extension_length)
model_kwargs["token_type_ids"] = torch.cat(
[model_kwargs["token_type_ids"], token_type_copies],
dim=-1,
)

return model_kwargs
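
Similarly for token types, the final segment id is repeated to cover the candidate tokens (hypothetical values):

    import torch

    token_type_ids = torch.tensor([[0, 0, 1, 1]])
    final = token_type_ids[:, -1].unsqueeze(-1)  # tensor([[1]])
    copies = final.repeat(1, 6 - token_type_ids.shape[1])  # extend to length 6
    extended = torch.cat([token_type_ids, copies], dim=-1)  # tensor([[0, 0, 1, 1, 1, 1]])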

@torch.no_grad()
def generate(
self,
@@ -4441,47 +4478,21 @@ def assisted_decoding(
# `candidate_length + 1` relevant logits from this process: in the event that all candidates are correct,
# we use this forward pass to also pick the subsequent logits in the original model.

# 2.1. Run a forward pass on the candidate sequence
if "past_key_values" in model_kwargs:
model_attn = torch.ones_like(candidate_input_ids)
model_input_ids = candidate_input_ids[:, -candidate_length - 1 :]
if self.config.is_encoder_decoder:
outputs = self(
decoder_input_ids=model_input_ids,
decoder_attention_mask=model_attn,
past_key_values=model_kwargs["past_key_values"],
encoder_outputs=model_kwargs["encoder_outputs"],
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
use_cache=True,
)
else:
outputs = self(
model_input_ids,
attention_mask=model_attn,
past_key_values=model_kwargs["past_key_values"],
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
use_cache=True,
)
else:
if self.config.is_encoder_decoder:
outputs = self(
decoder_input_ids=candidate_input_ids,
encoder_outputs=model_kwargs["encoder_outputs"],
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
use_cache=True,
)
else:
outputs = self(
candidate_input_ids,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
use_cache=True,
)
# 2.1. Prepare the model inputs
candidate_kwargs = copy.copy(model_kwargs)
candidate_kwargs = self._extend_attention_mask(candidate_kwargs, candidate_input_ids.shape[1])
candidate_kwargs = self._extend_token_type_ids(candidate_kwargs, candidate_input_ids.shape[1])

model_inputs = self.prepare_inputs_for_generation(candidate_input_ids, **candidate_kwargs)

# 2.2. Run a forward pass on the candidate sequence
outputs = self(
**model_inputs,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)

# 2.2. Process the new logits
# 2.3. Process the new logits
new_logits = outputs.logits[:, -candidate_length - 1 :] # excludes the input prompt if present
if len(logits_processor) > 0:
for i in range(candidate_length):
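
For context, a hedged paraphrase of the matching step that follows this hunk (not the verbatim implementation): the main model's greedy picks are compared against the assistant's candidates, and the longest agreeing prefix is kept, plus one token chosen by the main model.

    selected_tokens = new_logits.argmax(dim=-1)  # main model's choice at each position
    candidate_new_tokens = candidate_input_ids[:, -candidate_length:]  # assistant's draft
    n_matches = ((~(candidate_new_tokens == selected_tokens[:, :-1])).cumsum(dim=-1) < 1).sum()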
15 changes: 12 additions & 3 deletions src/transformers/models/bark/modeling_bark.py
@@ -483,9 +483,18 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwarg
position_ids = kwargs.get("position_ids", None)

if past_key_values is not None:
# only last token for inputs_ids if past is defined in kwargs
# Omit tokens covered by past_key_values
seq_len = input_ids.shape[1]
input_ids = input_ids[:, [-1]]
past_length = past_key_values[0][0].shape[2]

# Some generation methods already pass only the last input ID
if input_ids.shape[1] > past_length:
remove_prefix_length = past_length
else:
# Default to old behavior: keep only final ID
remove_prefix_length = input_ids.shape[1] - 1

input_ids = input_ids[:, remove_prefix_length:]

# input_embeds have already been used and is not required anymore
input_embeds = None
@@ -507,7 +516,7 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwarg
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
position_ids = position_ids[:, -1].unsqueeze(-1)
position_ids = position_ids[:, -input_ids.shape[1] :]
else:
position_ids = None
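
The same trimming rule recurs in every model file below; a standalone sketch with hypothetical values:

    import torch

    input_ids = torch.arange(10).unsqueeze(0)  # full 10-token sequence from assisted decoding
    past_length = 7  # tokens already covered by past_key_values

    if input_ids.shape[1] > past_length:
        remove_prefix_length = past_length  # keep only the uncached suffix: tokens 7..9
    else:
        remove_prefix_length = input_ids.shape[1] - 1  # old behavior: keep only the final token

    input_ids = input_ids[:, remove_prefix_length:]  # tensor([[7, 8, 9]])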

22 changes: 20 additions & 2 deletions src/transformers/models/bart/modeling_bart.py
@@ -1443,7 +1443,16 @@ def prepare_inputs_for_generation(
):
# cut decoder_input_ids if past_key_values is used
if past_key_values is not None:
decoder_input_ids = decoder_input_ids[:, -1:]
past_length = past_key_values[0][0].shape[2]

# Some generation methods already pass only the last input ID
if decoder_input_ids.shape[1] > past_length:
remove_prefix_length = past_length
else:
# Default to old behavior: keep only final ID
remove_prefix_length = decoder_input_ids.shape[1] - 1

decoder_input_ids = decoder_input_ids[:, remove_prefix_length:]

return {
"input_ids": None, # encoder_outputs is defined. input_ids not needed
@@ -1934,7 +1943,16 @@ def prepare_inputs_for_generation(
attention_mask = input_ids.new_ones(input_ids.shape)

if past_key_values:
input_ids = input_ids[:, -1:]
past_length = past_key_values[0][0].shape[2]

# Some generation methods already pass only the last input ID
if input_ids.shape[1] > past_length:
remove_prefix_length = past_length
else:
# Default to old behavior: keep only final ID
remove_prefix_length = input_ids.shape[1] - 1

input_ids = input_ids[:, remove_prefix_length:]
# first step, decoder_cached_states are empty
return {
"input_ids": input_ids, # encoder_outputs is defined. input_ids not needed
11 changes: 10 additions & 1 deletion src/transformers/models/bert/modeling_bert.py
@@ -1282,7 +1282,16 @@ def prepare_inputs_for_generation(

# cut decoder_input_ids if past_key_values is used
if past_key_values is not None:
input_ids = input_ids[:, -1:]
past_length = past_key_values[0][0].shape[2]

# Some generation methods already pass only the last input ID
if input_ids.shape[1] > past_length:
remove_prefix_length = past_length
else:
# Default to old behavior: keep only final ID
remove_prefix_length = input_ids.shape[1] - 1

input_ids = input_ids[:, remove_prefix_length:]

return {
"input_ids": input_ids,
src/transformers/models/bert_generation/modeling_bert_generation.py
@@ -993,9 +993,18 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attenti
if attention_mask is None:
attention_mask = input_ids.new_ones(input_shape)

# cut decoder_input_ids if past is used
# cut decoder_input_ids if past_key_values is used
if past_key_values is not None:
input_ids = input_ids[:, -1:]
past_length = past_key_values[0][0].shape[2]

# Some generation methods already pass only the last input ID
if input_ids.shape[1] > past_length:
remove_prefix_length = past_length
else:
# Default to old behavior: keep only final ID
remove_prefix_length = input_ids.shape[1] - 1

input_ids = input_ids[:, remove_prefix_length:]

return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values}

13 changes: 11 additions & 2 deletions src/transformers/models/big_bird/modeling_big_bird.py
@@ -2628,9 +2628,18 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attenti
if attention_mask is None:
attention_mask = input_ids.new_ones(input_shape)

# cut decoder_input_ids if past is used
# cut decoder_input_ids if past_key_values is used
if past_key_values is not None:
input_ids = input_ids[:, -1:]
past_length = past_key_values[0][0].shape[2]

# Some generation methods already pass only the last input ID
if input_ids.shape[1] > past_length:
remove_prefix_length = past_length
else:
# Default to old behavior: keep only final ID
remove_prefix_length = input_ids.shape[1] - 1

input_ids = input_ids[:, remove_prefix_length:]

return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values}

src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py
@@ -2627,7 +2627,16 @@ def prepare_inputs_for_generation(
):
# cut decoder_input_ids if past_key_values is used
if past_key_values is not None:
decoder_input_ids = decoder_input_ids[:, -1:]
past_length = past_key_values[0][0].shape[2]

# Some generation methods already pass only the last input ID
if decoder_input_ids.shape[1] > past_length:
remove_prefix_length = past_length
else:
# Default to old behavior: keep only final ID
remove_prefix_length = decoder_input_ids.shape[1] - 1

decoder_input_ids = decoder_input_ids[:, remove_prefix_length:]

return {
"input_ids": None, # encoder_outputs is defined. input_ids not needed
15 changes: 12 additions & 3 deletions src/transformers/models/biogpt/modeling_biogpt.py
@@ -729,9 +729,18 @@ def forward(
def prepare_inputs_for_generation(
self, input_ids, attention_mask, inputs_embeds=None, past_key_values=None, **kwargs
):
# only last token for inputs_ids if past is defined in kwargs
if past_key_values:
input_ids = input_ids[:, -1].unsqueeze(-1)
# only last tokens for input_ids if past is defined in kwargs
if past_key_values is not None:
past_length = past_key_values[0][0].shape[2]

# Some generation methods already pass only the last input ID
if input_ids.shape[1] > past_length:
remove_prefix_length = past_length
else:
# Default to old behavior: keep only final ID
remove_prefix_length = input_ids.shape[1] - 1

input_ids = input_ids[:, remove_prefix_length:]

if inputs_embeds is not None and past_key_values is None:
model_inputs = {"inputs_embeds": inputs_embeds}
22 changes: 20 additions & 2 deletions src/transformers/models/blenderbot/modeling_blenderbot.py
@@ -1392,7 +1392,16 @@ def prepare_inputs_for_generation(
):
# cut decoder_input_ids if past is used
if past_key_values is not None:
decoder_input_ids = decoder_input_ids[:, -1:]
past_length = past_key_values[0][0].shape[2]

# Some generation methods already pass only the last input ID
if decoder_input_ids.shape[1] > past_length:
remove_prefix_length = past_length
else:
# Default to old behavior: keep only final ID
remove_prefix_length = decoder_input_ids.shape[1] - 1

decoder_input_ids = decoder_input_ids[:, remove_prefix_length:]

return {
"input_ids": None, # encoder_outputs is defined. input_ids not needed
@@ -1622,7 +1631,16 @@ def prepare_inputs_for_generation(
attention_mask = input_ids.new_ones(input_ids.shape)

if past_key_values:
input_ids = input_ids[:, -1:]
past_length = past_key_values[0][0].shape[2]

# Some generation methods already pass only the last input ID
if input_ids.shape[1] > past_length:
remove_prefix_length = past_length
else:
# Default to old behavior: keep only final ID
remove_prefix_length = input_ids.shape[1] - 1

input_ids = input_ids[:, remove_prefix_length:]
# first step, decoder_cached_states are empty
return {
"input_ids": input_ids, # encoder_outputs is defined. input_ids not needed
src/transformers/models/blenderbot_small/modeling_blenderbot_small.py
@@ -1359,7 +1359,16 @@ def prepare_inputs_for_generation(
):
# cut decoder_input_ids if past is used
if past_key_values is not None:
decoder_input_ids = decoder_input_ids[:, -1:]
past_length = past_key_values[0][0].shape[2]

# Some generation methods already pass only the last input ID
if decoder_input_ids.shape[1] > past_length:
remove_prefix_length = past_length
else:
# Default to old behavior: keep only final ID
remove_prefix_length = decoder_input_ids.shape[1] - 1

decoder_input_ids = decoder_input_ids[:, remove_prefix_length:]

return {
"input_ids": None, # encoder_outputs is defined. input_ids not needed
@@ -1589,7 +1598,16 @@ def prepare_inputs_for_generation(
attention_mask = input_ids.new_ones(input_ids.shape)

if past_key_values:
input_ids = input_ids[:, -1:]
past_length = past_key_values[0][0].shape[2]

# Some generation methods already pass only the last input ID
if input_ids.shape[1] > past_length:
remove_prefix_length = past_length
else:
# Default to old behavior: keep only final ID
remove_prefix_length = input_ids.shape[1] - 1

input_ids = input_ids[:, remove_prefix_length:]
# first step, decoder_cached_states are empty
return {
"input_ids": input_ids, # encoder_outputs is defined. input_ids not needed
11 changes: 10 additions & 1 deletion src/transformers/models/blip/modeling_blip_text.py
@@ -920,7 +920,16 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attenti

# cut decoder_input_ids if past_key_values is used
if past_key_values is not None:
input_ids = input_ids[:, -1:]
past_length = past_key_values[0][0].shape[2]

# Some generation methods already pass only the last input ID
if input_ids.shape[1] > past_length:
remove_prefix_length = past_length
else:
# Default to old behavior: keep only final ID
remove_prefix_length = input_ids.shape[1] - 1

input_ids = input_ids[:, remove_prefix_length:]

return {
"input_ids": input_ids,
15 changes: 12 additions & 3 deletions src/transformers/models/bloom/modeling_bloom.py
@@ -844,9 +844,18 @@ def prepare_inputs_for_generation(
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs,
) -> dict:
# only last token for input_ids if past is not None
if past_key_values:
input_ids = input_ids[:, -1].unsqueeze(-1)
# only last tokens for input_ids if past is not None
if past_key_values is not None:
past_length = past_key_values[0][0].shape[2]

# Some generation methods already pass only the last input ID
if input_ids.shape[1] > past_length:
remove_prefix_length = past_length
else:
# Default to old behavior: keep only final ID
remove_prefix_length = input_ids.shape[1] - 1

input_ids = input_ids[:, remove_prefix_length:]

# the cache may be in the standard format (e.g. in contrastive search), convert to bloom's format if needed
if past_key_values[0][0].shape[0] == input_ids.shape[0]:
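
For reference, a hedged note on the shape check above (Bloom's cache layout as I understand it at the time of this commit):

    # Standard cache:  key.shape == (batch_size, num_heads, seq_len, head_dim)
    # Bloom's cache:   key.shape == (batch_size * num_heads, head_dim, seq_len)
    # so key.shape[0] == input_ids.shape[0] (== batch_size) implies the standard format.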
(Diff truncated: remaining changed files not shown.)
