In assisted decoding, pass model_kwargs to model's forward call (fix prepare_inputs_for_generation in all models) #25242
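Every per-model change in this diff applies the same prefix-trimming pattern to `prepare_inputs_for_generation`: instead of unconditionally slicing the input down to the last token, the model now drops only the tokens already covered by `past_key_values`, which is what assisted decoding needs when it submits several candidate tokens in one forward call. A minimal standalone sketch of that logic (the helper name and toy tensors are illustrative, not part of the diff):

```python
import torch


def trim_to_uncached_tokens(input_ids: torch.Tensor, past_length: int) -> torch.Tensor:
    """Drop the prefix of input_ids that is already covered by the cache."""
    if input_ids.shape[1] > past_length:
        # Assisted decoding (or any caller passing the full sequence plus candidates):
        # keep every token the cache has not seen yet.
        remove_prefix_length = past_length
    else:
        # Old behavior: some generation methods already pass only the last input ID.
        remove_prefix_length = input_ids.shape[1] - 1
    return input_ids[:, remove_prefix_length:]


# Toy example: 7 tokens total, 4 already in the cache -> the 3 candidate tokens are kept.
ids = torch.arange(7).unsqueeze(0)
print(trim_to_uncached_tokens(ids, past_length=4))  # tensor([[4, 5, 6]])
```

In practice this path is exercised when `generate()` is called with an `assistant_model`, which is how assisted decoding is triggered.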

91 changes: 51 additions & 40 deletions src/transformers/generation/utils.py
@@ -1297,6 +1297,43 @@ def _validate_generated_length(self, generation_config, input_ids_length, has_de
UserWarning,
)

def _extend_attention_mask(self, model_kwargs: Dict[str, Any], new_mask_length: int) -> Dict[str, Any]:
if self.config.is_encoder_decoder:
key = "decoder_attention_mask"
else:
key = "attention_mask"

if key not in model_kwargs:
return model_kwargs

mask = model_kwargs[key]
mask_extension_length = new_mask_length - mask.shape[1]

if mask_extension_length < 0:
raise ValueError("Cannot extend attention mask to a length less than it already is")

model_kwargs[key] = torch.cat(
[mask, mask.new_ones((mask.shape[0], mask_extension_length))],
dim=-1,
)

return model_kwargs

def _extend_token_type_ids(self, model_kwargs: Dict[str, Any], new_length: int) -> Dict[str, Any]:
if "token_type_ids" not in model_kwargs or model_kwargs["token_type_ids"] is None:
return model_kwargs

token_type_ids = model_kwargs["token_type_ids"]
final_token_type = token_type_ids[:, -1].unsqueeze(-1)
extension_length = new_length - token_type_ids.shape[1]
token_type_copies = final_token_type.repeat(1, extension_length)
model_kwargs["token_type_ids"] = torch.cat(
[model_kwargs["token_type_ids"], token_type_copies],
dim=-1,
)

return model_kwargs

@torch.no_grad()
def generate(
self,
@@ -4441,47 +4478,21 @@ def assisted_decoding(
# `candidate_length + 1` relevant logits from this process: in the event that all candidates are correct,
# we use this forward pass to also pick the subsequent logits in the original model.

# 2.1. Run a forward pass on the candidate sequence
if "past_key_values" in model_kwargs:
model_attn = torch.ones_like(candidate_input_ids)
model_input_ids = candidate_input_ids[:, -candidate_length - 1 :]
if self.config.is_encoder_decoder:
outputs = self(
decoder_input_ids=model_input_ids,
decoder_attention_mask=model_attn,
past_key_values=model_kwargs["past_key_values"],
encoder_outputs=model_kwargs["encoder_outputs"],
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
use_cache=True,
)
else:
outputs = self(
model_input_ids,
attention_mask=model_attn,
past_key_values=model_kwargs["past_key_values"],
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
use_cache=True,
)
else:
if self.config.is_encoder_decoder:
outputs = self(
decoder_input_ids=candidate_input_ids,
encoder_outputs=model_kwargs["encoder_outputs"],
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
use_cache=True,
)
else:
outputs = self(
candidate_input_ids,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
use_cache=True,
)
# 2.1. Prepare the model inputs
candidate_kwargs = copy.copy(model_kwargs)
candidate_kwargs = self._extend_attention_mask(candidate_kwargs, candidate_input_ids.shape[1])
candidate_kwargs = self._extend_token_type_ids(candidate_kwargs, candidate_input_ids.shape[1])

model_inputs = self.prepare_inputs_for_generation(candidate_input_ids, **candidate_kwargs)

# 2.2. Run a forward pass on the candidate sequence
outputs = self(
**model_inputs,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)

# 2.2. Process the new logits
# 2.3. Process the new logits
new_logits = outputs.logits[:, -candidate_length - 1 :] # excludes the input prompt if present
if len(logits_processor) > 0:
for i in range(candidate_length):
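The net effect of the new helpers above: before the single forward pass on the candidate sequence, `_extend_attention_mask` pads the mask with ones and `_extend_token_type_ids` repeats the final token type, so the kwargs handed to `prepare_inputs_for_generation` match `candidate_input_ids` in length. A rough standalone sketch with toy shapes (decoder-only case; this reimplements the helpers inline purely for illustration):

```python
import torch

# Five tokens have been processed so far; the assistant proposed three more,
# so candidate_input_ids would have length 8.
model_kwargs = {
    "attention_mask": torch.ones(1, 5, dtype=torch.long),
    "token_type_ids": torch.tensor([[0, 0, 0, 1, 1]]),
}
new_length = 8

# _extend_attention_mask: pad the mask with ones up to the candidate length.
mask = model_kwargs["attention_mask"]
model_kwargs["attention_mask"] = torch.cat(
    [mask, mask.new_ones((mask.shape[0], new_length - mask.shape[1]))], dim=-1
)

# _extend_token_type_ids: repeat the final token type for the new positions.
token_type_ids = model_kwargs["token_type_ids"]
model_kwargs["token_type_ids"] = torch.cat(
    [token_type_ids, token_type_ids[:, -1:].repeat(1, new_length - token_type_ids.shape[1])], dim=-1
)

print(model_kwargs["attention_mask"])  # tensor([[1, 1, 1, 1, 1, 1, 1, 1]])
print(model_kwargs["token_type_ids"])  # tensor([[0, 0, 0, 1, 1, 1, 1, 1]])
```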
15 changes: 12 additions & 3 deletions src/transformers/models/bark/modeling_bark.py
@@ -483,9 +483,18 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwarg
position_ids = kwargs.get("position_ids", None)

if past_key_values is not None:
# only last token for inputs_ids if past is defined in kwargs
# Omit tokens covered by past_key_values
seq_len = input_ids.shape[1]
input_ids = input_ids[:, [-1]]
past_length = past_key_values[0][0].shape[2]

# Some generation methods already pass only the last input ID
if input_ids.shape[1] > past_length:
remove_prefix_length = past_length
else:
# Default to old behavior: keep only final ID
remove_prefix_length = input_ids.shape[1] - 1

input_ids = input_ids[:, remove_prefix_length:]

# input_embeds have already been used and is not required anymore
input_embeds = None
@@ -507,7 +516,7 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwarg
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
position_ids = position_ids[:, -1].unsqueeze(-1)
position_ids = position_ids[:, -input_ids.shape[1] :]
else:
position_ids = None

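Because `input_ids` can now retain several uncached tokens, the `position_ids` slice in the hunk above changes from the single last position to the last `input_ids.shape[1]` positions. A toy illustration of the difference (assumed shapes only, independent of the actual Bark module):

```python
import torch

attention_mask = torch.ones(1, 8, dtype=torch.long)  # 8 tokens seen in total
input_ids = torch.tensor([[101, 102, 103]])          # 3 uncached candidate tokens

position_ids = attention_mask.long().cumsum(-1) - 1  # [[0, 1, 2, 3, 4, 5, 6, 7]]
position_ids.masked_fill_(attention_mask == 0, 1)

# Old slice: a single position, which no longer lines up with 3 input tokens.
old = position_ids[:, -1].unsqueeze(-1)              # tensor([[7]])

# New slice: one position per remaining input token.
new = position_ids[:, -input_ids.shape[1] :]         # tensor([[5, 6, 7]])
print(old, new)
```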
22 changes: 20 additions & 2 deletions src/transformers/models/bart/modeling_bart.py
@@ -1443,7 +1443,16 @@ def prepare_inputs_for_generation(
):
# cut decoder_input_ids if past_key_values is used
if past_key_values is not None:
decoder_input_ids = decoder_input_ids[:, -1:]
past_length = past_key_values[0][0].shape[2]

# Some generation methods already pass only the last input ID
if decoder_input_ids.shape[1] > past_length:
remove_prefix_length = past_length
else:
# Default to old behavior: keep only final ID
remove_prefix_length = decoder_input_ids.shape[1] - 1

decoder_input_ids = decoder_input_ids[:, remove_prefix_length:]

return {
"input_ids": None, # encoder_outputs is defined. input_ids not needed
@@ -1934,7 +1943,16 @@ def prepare_inputs_for_generation(
attention_mask = input_ids.new_ones(input_ids.shape)

if past_key_values:
input_ids = input_ids[:, -1:]
past_length = past_key_values[0][0].shape[2]

# Some generation methods already pass only the last input ID
if input_ids.shape[1] > past_length:
remove_prefix_length = past_length
else:
# Default to old behavior: keep only final ID
remove_prefix_length = input_ids.shape[1] - 1

input_ids = input_ids[:, remove_prefix_length:]
# first step, decoder_cached_states are empty
return {
"input_ids": input_ids, # encoder_outputs is defined. input_ids not needed
11 changes: 10 additions & 1 deletion src/transformers/models/bert/modeling_bert.py
@@ -1282,7 +1282,16 @@ def prepare_inputs_for_generation(

# cut decoder_input_ids if past_key_values is used
if past_key_values is not None:
input_ids = input_ids[:, -1:]
past_length = past_key_values[0][0].shape[2]

# Some generation methods already pass only the last input ID
if input_ids.shape[1] > past_length:
remove_prefix_length = past_length
else:
# Default to old behavior: keep only final ID
remove_prefix_length = input_ids.shape[1] - 1

input_ids = input_ids[:, remove_prefix_length:]

return {
"input_ids": input_ids,
src/transformers/models/bert_generation/modeling_bert_generation.py
@@ -993,9 +993,18 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attenti
if attention_mask is None:
attention_mask = input_ids.new_ones(input_shape)

# cut decoder_input_ids if past is used
# cut decoder_input_ids if past_key_values is used
if past_key_values is not None:
input_ids = input_ids[:, -1:]
past_length = past_key_values[0][0].shape[2]

# Some generation methods already pass only the last input ID
if input_ids.shape[1] > past_length:
remove_prefix_length = past_length
else:
# Default to old behavior: keep only final ID
remove_prefix_length = input_ids.shape[1] - 1

input_ids = input_ids[:, remove_prefix_length:]

return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values}

13 changes: 11 additions & 2 deletions src/transformers/models/big_bird/modeling_big_bird.py
@@ -2628,9 +2628,18 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attenti
if attention_mask is None:
attention_mask = input_ids.new_ones(input_shape)

# cut decoder_input_ids if past is used
# cut decoder_input_ids if past_key_values is used
if past_key_values is not None:
input_ids = input_ids[:, -1:]
past_length = past_key_values[0][0].shape[2]

# Some generation methods already pass only the last input ID
if input_ids.shape[1] > past_length:
remove_prefix_length = past_length
else:
# Default to old behavior: keep only final ID
remove_prefix_length = input_ids.shape[1] - 1

input_ids = input_ids[:, remove_prefix_length:]

return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values}

src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py
@@ -2627,7 +2627,16 @@ def prepare_inputs_for_generation(
):
# cut decoder_input_ids if past_key_values is used
if past_key_values is not None:
decoder_input_ids = decoder_input_ids[:, -1:]
past_length = past_key_values[0][0].shape[2]

# Some generation methods already pass only the last input ID
if decoder_input_ids.shape[1] > past_length:
remove_prefix_length = past_length
else:
# Default to old behavior: keep only final ID
remove_prefix_length = decoder_input_ids.shape[1] - 1

decoder_input_ids = decoder_input_ids[:, remove_prefix_length:]

return {
"input_ids": None, # encoder_outputs is defined. input_ids not needed
15 changes: 12 additions & 3 deletions src/transformers/models/biogpt/modeling_biogpt.py
@@ -729,9 +729,18 @@ def forward(
def prepare_inputs_for_generation(
self, input_ids, attention_mask, inputs_embeds=None, past_key_values=None, **kwargs
):
# only last token for inputs_ids if past is defined in kwargs
if past_key_values:
input_ids = input_ids[:, -1].unsqueeze(-1)
# only last tokens for inputs_ids if past is defined in kwargs
if past_key_values is not None:
past_length = past_key_values[0][0].shape[2]

# Some generation methods already pass only the last input ID
if input_ids.shape[1] > past_length:
remove_prefix_length = past_length
else:
# Default to old behavior: keep only final ID
remove_prefix_length = input_ids.shape[1] - 1

input_ids = input_ids[:, remove_prefix_length:]

if inputs_embeds is not None and past_key_values is None:
model_inputs = {"inputs_embeds": inputs_embeds}
22 changes: 20 additions & 2 deletions src/transformers/models/blenderbot/modeling_blenderbot.py
@@ -1392,7 +1392,16 @@ def prepare_inputs_for_generation(
):
# cut decoder_input_ids if past is used
if past_key_values is not None:
decoder_input_ids = decoder_input_ids[:, -1:]
past_length = past_key_values[0][0].shape[2]

# Some generation methods already pass only the last input ID
if decoder_input_ids.shape[1] > past_length:
remove_prefix_length = past_length
else:
# Default to old behavior: keep only final ID
remove_prefix_length = decoder_input_ids.shape[1] - 1

decoder_input_ids = decoder_input_ids[:, remove_prefix_length:]

return {
"input_ids": None, # encoder_outputs is defined. input_ids not needed
@@ -1622,7 +1631,16 @@ def prepare_inputs_for_generation(
attention_mask = input_ids.new_ones(input_ids.shape)

if past_key_values:
input_ids = input_ids[:, -1:]
past_length = past_key_values[0][0].shape[2]

# Some generation methods already pass only the last input ID
if input_ids.shape[1] > past_length:
remove_prefix_length = past_length
else:
# Default to old behavior: keep only final ID
remove_prefix_length = input_ids.shape[1] - 1

input_ids = input_ids[:, remove_prefix_length:]
# first step, decoder_cached_states are empty
return {
"input_ids": input_ids, # encoder_outputs is defined. input_ids not needed
src/transformers/models/blenderbot_small/modeling_blenderbot_small.py
@@ -1359,7 +1359,16 @@ def prepare_inputs_for_generation(
):
# cut decoder_input_ids if past is used
if past_key_values is not None:
decoder_input_ids = decoder_input_ids[:, -1:]
past_length = past_key_values[0][0].shape[2]

# Some generation methods already pass only the last input ID
if decoder_input_ids.shape[1] > past_length:
remove_prefix_length = past_length
else:
# Default to old behavior: keep only final ID
remove_prefix_length = decoder_input_ids.shape[1] - 1

decoder_input_ids = decoder_input_ids[:, remove_prefix_length:]

return {
"input_ids": None, # encoder_outputs is defined. input_ids not needed
@@ -1589,7 +1598,16 @@ def prepare_inputs_for_generation(
attention_mask = input_ids.new_ones(input_ids.shape)

if past_key_values:
input_ids = input_ids[:, -1:]
past_length = past_key_values[0][0].shape[2]

# Some generation methods already pass only the last input ID
if input_ids.shape[1] > past_length:
remove_prefix_length = past_length
else:
# Default to old behavior: keep only final ID
remove_prefix_length = input_ids.shape[1] - 1

input_ids = input_ids[:, remove_prefix_length:]
# first step, decoder_cached_states are empty
return {
"input_ids": input_ids, # encoder_outputs is defined. input_ids not needed
11 changes: 10 additions & 1 deletion src/transformers/models/blip/modeling_blip_text.py
@@ -920,7 +920,16 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attenti

# cut decoder_input_ids if past_key_values is used
if past_key_values is not None:
input_ids = input_ids[:, -1:]
past_length = past_key_values[0][0].shape[2]

# Some generation methods already pass only the last input ID
if input_ids.shape[1] > past_length:
remove_prefix_length = past_length
else:
# Default to old behavior: keep only final ID
remove_prefix_length = input_ids.shape[1] - 1

input_ids = input_ids[:, remove_prefix_length:]

return {
"input_ids": input_ids,
15 changes: 12 additions & 3 deletions src/transformers/models/bloom/modeling_bloom.py
@@ -844,9 +844,18 @@ def prepare_inputs_for_generation(
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs,
) -> dict:
# only last token for input_ids if past is not None
if past_key_values:
input_ids = input_ids[:, -1].unsqueeze(-1)
# only last tokens for input_ids if past is not None
if past_key_values is not None:
past_length = past_key_values[0][0].shape[2]

# Some generation methods already pass only the last input ID
if input_ids.shape[1] > past_length:
remove_prefix_length = past_length
else:
# Default to old behavior: keep only final ID
remove_prefix_length = input_ids.shape[1] - 1

input_ids = input_ids[:, remove_prefix_length:]

# the cache may be in the standard format (e.g. in contrastive search), convert to bloom's format if needed
if past_key_values[0][0].shape[0] == input_ids.shape[0]: