fix generation input preparation (#1512)
* fix generation input preparation

* fix
echarlaix authored Nov 3, 2023
1 parent 308d282 commit 554e312
Showing 3 changed files with 141 additions and 100 deletions.
optimum/exporters/onnx/model_patcher.py (7 changes: 5 additions & 2 deletions)
@@ -412,8 +412,11 @@ def __enter__(self):

def __exit__(self, exc_type, exc_value, traceback):
# TODO: Remove this if once transformers if much above 4.35
if AttentionMaskConverter is not None:
AttentionMaskConverter._make_causal_mask = self.original_make_causal
# TODO: We should unpatch it - however `self._make_causal_mask` may still be called later which raises issues with this simple patch strategy.
# We need to find a proper solution.
# if AttentionMaskConverter is not None:
# AttentionMaskConverter._make_causal_mask = self.original_make_causal
pass

def __init__(
self,
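The `__exit__` above belongs to a patcher used as a context manager around ONNX export. As a rough, hypothetical sketch of that patch/restore pattern (the class and helper names below are illustrative, not taken from model_patcher.py), the commit effectively turns the restore step into a no-op:

```python
class _CausalMaskPatcher:
    """Hypothetical, simplified version of the monkey-patching pattern used
    during export; not the actual optimum implementation."""

    def __init__(self, target_cls, patched_fn):
        self.target_cls = target_cls
        self.patched_fn = patched_fn
        # Keep a handle on the original so it could, in principle, be restored.
        self.original_fn = getattr(target_cls, "_make_causal_mask", None)

    def __enter__(self):
        if self.original_fn is not None:
            # Swap in the export-friendly implementation while exporting.
            self.target_cls._make_causal_mask = self.patched_fn
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Left as a no-op, mirroring the commit: per the comment in the diff,
        # the patched function may still be called after the context exits,
        # so restoring the original here raises issues.
        pass


class _Converter:
    """Dummy stand-in for AttentionMaskConverter, only for this sketch."""

    @staticmethod
    def _make_causal_mask():
        return "original"


with _CausalMaskPatcher(_Converter, staticmethod(lambda: "patched")):
    assert _Converter._make_causal_mask() == "patched"
# The patch intentionally stays in place after the context exits.
assert _Converter._make_causal_mask() == "patched"
```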
optimum/onnxruntime/modeling_decoder.py (51 changes: 48 additions & 3 deletions)
Expand Up @@ -204,7 +204,6 @@ def forward(
loss = None
if self.use_cache:
if past_key_values is not None:
input_ids = input_ids[:, -1:]
# Flatten the past_key_values (no need to flatten for models using multi-query attn)
if self.config.model_type not in MULTI_QUERY_ATTN_MODELS:
past_key_values = tuple(
@@ -630,8 +629,17 @@ def _from_transformers(

# Adapted from transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel.prepare_inputs_for_generation
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
if past_key_values is not None:
past_length = past_key_values[0][0].shape[2]
# Some generation methods already pass only the last input ID
if input_ids.shape[1] > past_length:
remove_prefix_length = past_length
else:
# Default to old behavior: keep only final ID
remove_prefix_length = input_ids.shape[1] - 1
input_ids = input_ids[:, remove_prefix_length:]

# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
attention_mask = kwargs.get("attention_mask", None)
use_cache = kwargs.get("use_cache", None)

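The same trimming block recurs in every `prepare_inputs_for_generation` below (Bloom, OPT, MPT, Falcon, and the seq2seq models further down). As a standalone sketch of that logic, assuming a GPT-2-style cache layout where `past_key_values[0][0]` has shape `(batch, num_heads, past_length, head_dim)` (the helper name and toy tensors are illustrative, not part of the commit):

```python
import torch


def trim_input_ids(input_ids: torch.Tensor, past_key_values) -> torch.Tensor:
    """Sketch of the input trimming added in this commit."""
    if past_key_values is None:
        return input_ids  # first step: no cache yet, feed the whole prompt
    past_length = past_key_values[0][0].shape[2]  # length already covered by the cache
    if input_ids.shape[1] > past_length:
        # The caller passed the full sequence: drop everything the cache covers.
        remove_prefix_length = past_length
    else:
        # The caller already trimmed the input: keep only the final token,
        # matching the previous `input_ids[:, -1:]` behavior.
        remove_prefix_length = input_ids.shape[1] - 1
    return input_ids[:, remove_prefix_length:]


# Toy checks: a cache covering 5 tokens.
past = ((torch.zeros(1, 4, 5, 8), torch.zeros(1, 4, 5, 8)),)
assert trim_input_ids(torch.arange(6).unsqueeze(0), past).shape == (1, 1)  # full sequence in
assert trim_input_ids(torch.tensor([[42]]), past).shape == (1, 1)          # already trimmed
```

Compared with the previous unconditional `input_ids = input_ids[:, -1:]`, this also handles generation paths that pass back more tokens than the cache currently covers.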
@@ -667,6 +675,16 @@ def can_generate(self):
class ORTBloomForCausalLM(ORTModelForCausalLM):
# Adapted from transformers.models.bloom.modeling_bloom.BloomForCausalLM.prepare_inputs_for_generation
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
if past_key_values is not None:
past_length = past_key_values[0][0].shape[2]
# Some generation methods already pass only the last input ID
if input_ids.shape[1] > past_length:
remove_prefix_length = past_length
else:
# Default to old behavior: keep only final ID
remove_prefix_length = input_ids.shape[1] - 1
input_ids = input_ids[:, remove_prefix_length:]

attention_mask = kwargs.get("attention_mask", None)
use_cache = kwargs.get("use_cache", None)

@@ -706,6 +724,16 @@ def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) ->
class ORTOPTForCausalLM(ORTModelForCausalLM):
# Adapted from transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel.prepare_inputs_for_generation
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
if past_key_values is not None:
past_length = past_key_values[0][0].shape[2]
# Some generation methods already pass only the last input ID
if input_ids.shape[1] > past_length:
remove_prefix_length = past_length
else:
# Default to old behavior: keep only final ID
remove_prefix_length = input_ids.shape[1] - 1
input_ids = input_ids[:, remove_prefix_length:]

attention_mask = kwargs.get("attention_mask", None)
use_cache = kwargs.get("use_cache", None)

Expand All @@ -721,6 +749,16 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwarg
class ORTMPTForCausalLM(ORTModelForCausalLM):
# Adapted from transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel.prepare_inputs_for_generation
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
if past_key_values is not None:
past_length = past_key_values[0][0].shape[2]
# Some generation methods already pass only the last input ID
if input_ids.shape[1] > past_length:
remove_prefix_length = past_length
else:
# Default to old behavior: keep only final ID
remove_prefix_length = input_ids.shape[1] - 1
input_ids = input_ids[:, remove_prefix_length:]

attention_mask = kwargs.get("attention_mask", None)
use_cache = kwargs.get("use_cache", None)

@@ -794,7 +832,14 @@ def prepare_inputs_for_generation(
**kwargs,
) -> dict:
if past_key_values is not None:
input_ids = input_ids[:, -1:]
past_length = past_key_values[0][0].shape[2]
# Some generation methods already pass only the last input ID
if input_ids.shape[1] > past_length:
remove_prefix_length = past_length
else:
# Default to old behavior: keep only final ID
remove_prefix_length = input_ids.shape[1] - 1
input_ids = input_ids[:, remove_prefix_length:]

# the cache may be in the standard format (e.g. in contrastive search), convert to falcon's format if needed
if len(past_key_values[0][0].shape) == 4:
optimum/onnxruntime/modeling_seq2seq.py (183 changes: 88 additions & 95 deletions)
Expand Up @@ -1192,30 +1192,18 @@ def forward(
if encoder_outputs is None:
encoder_outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)

# Decode
if past_key_values is None or self.use_cache is False:
decoder_outputs = self.decoder(
input_ids=decoder_input_ids,
encoder_hidden_states=encoder_outputs.last_hidden_state,
encoder_attention_mask=attention_mask,
labels=labels,
)
elif self.use_merged is True:
decoder_outputs = self.decoder(
input_ids=decoder_input_ids[:, -1:],
encoder_hidden_states=encoder_outputs.last_hidden_state,
past_key_values=past_key_values,
encoder_attention_mask=attention_mask,
labels=labels,
)
else:
decoder_outputs = self.decoder_with_past(
input_ids=decoder_input_ids[:, -1:], # Cut decoder_input_ids if past is used
past_key_values=past_key_values,
encoder_hidden_states=encoder_outputs.last_hidden_state,
encoder_attention_mask=attention_mask,
labels=labels,
)
model = (
self.decoder
if past_key_values is None or not self.use_cache or self.use_merged
else self.decoder_with_past
)
decoder_outputs = model(
input_ids=decoder_input_ids,
past_key_values=past_key_values,
encoder_hidden_states=encoder_outputs.last_hidden_state,
encoder_attention_mask=attention_mask,
labels=labels,
)

return Seq2SeqLMOutput(
loss=decoder_outputs.get("loss", None),
Expand All @@ -1236,6 +1224,16 @@ def prepare_inputs_for_generation(
encoder_outputs=None,
**kwargs,
) -> Dict:
if past_key_values is not None:
past_length = past_key_values[0][0].shape[2]
# Some generation methods already pass only the last input ID
if input_ids.shape[1] > past_length:
remove_prefix_length = past_length
else:
# Default to old behavior: keep only final ID
remove_prefix_length = input_ids.shape[1] - 1
input_ids = input_ids[:, remove_prefix_length:]

return {
"decoder_input_ids": input_ids,
"past_key_values": past_key_values,
@@ -1331,28 +1329,18 @@ def forward(
if encoder_outputs is None:
encoder_outputs = self.encoder(input_features=input_features, attention_mask=attention_mask)

# Decode
if past_key_values is None or self.use_cache is False:
decoder_outputs = self.decoder(
input_ids=decoder_input_ids,
encoder_hidden_states=encoder_outputs.last_hidden_state,
labels=labels,
)
elif self.use_merged is True:
decoder_outputs = self.decoder(
input_ids=decoder_input_ids[:, -1:],
encoder_hidden_states=encoder_outputs.last_hidden_state,
past_key_values=past_key_values,
encoder_attention_mask=attention_mask,
labels=labels,
)
else:
decoder_outputs = self.decoder_with_past(
input_ids=decoder_input_ids[:, -1:], # Cut decoder_input_ids if past is used
past_key_values=past_key_values,
encoder_hidden_states=encoder_outputs.last_hidden_state,
labels=labels,
)
model = (
self.decoder
if past_key_values is None or not self.use_cache or self.use_merged
else self.decoder_with_past
)
decoder_outputs = model(
input_ids=decoder_input_ids,
past_key_values=past_key_values,
encoder_hidden_states=encoder_outputs.last_hidden_state,
encoder_attention_mask=attention_mask,
labels=labels,
)

return Seq2SeqLMOutput(
loss=decoder_outputs.get("loss", None),
Expand All @@ -1372,6 +1360,16 @@ def prepare_inputs_for_generation(
encoder_outputs=None,
**kwargs,
) -> Dict:
if past_key_values is not None:
past_length = past_key_values[0][0].shape[2]
# Some generation methods already pass only the last input ID
if input_ids.shape[1] > past_length:
remove_prefix_length = past_length
else:
# Default to old behavior: keep only final ID
remove_prefix_length = input_ids.shape[1] - 1
input_ids = input_ids[:, remove_prefix_length:]

return {
"decoder_input_ids": input_ids,
"past_key_values": past_key_values,
@@ -1526,27 +1524,17 @@ def forward(
if encoder_outputs is None:
encoder_outputs = self.encoder(pixel_values=pixel_values)

# Decode
if past_key_values is None or self.use_cache is False:
decoder_outputs = self.decoder(
input_ids=decoder_input_ids,
encoder_hidden_states=encoder_outputs.last_hidden_state,
labels=labels,
)
elif self.use_merged is True:
decoder_outputs = self.decoder(
input_ids=decoder_input_ids[:, -1:],
encoder_hidden_states=encoder_outputs.last_hidden_state,
past_key_values=past_key_values,
labels=labels,
)
else:
decoder_outputs = self.decoder_with_past(
input_ids=decoder_input_ids[:, -1:], # Cut decoder_input_ids if past is used
past_key_values=past_key_values,
encoder_hidden_states=encoder_outputs.last_hidden_state,
labels=labels,
)
model = (
self.decoder
if past_key_values is None or not self.use_cache or self.use_merged
else self.decoder_with_past
)
decoder_outputs = model(
input_ids=decoder_input_ids,
past_key_values=past_key_values,
encoder_hidden_states=encoder_outputs.last_hidden_state,
labels=labels,
)

return Seq2SeqLMOutput(
loss=decoder_outputs.get("loss", None),
Expand All @@ -1565,6 +1553,16 @@ def prepare_inputs_for_generation(
encoder_outputs=None,
**kwargs,
) -> Dict:
if past_key_values is not None:
past_length = past_key_values[0][0].shape[2]
# Some generation methods already pass only the last input ID
if input_ids.shape[1] > past_length:
remove_prefix_length = past_length
else:
# Default to old behavior: keep only final ID
remove_prefix_length = input_ids.shape[1] - 1
input_ids = input_ids[:, remove_prefix_length:]

return {
"decoder_input_ids": input_ids,
"past_key_values": past_key_values,
@@ -1641,34 +1639,19 @@ def forward(
else:
attention_mask = attention_mask.astype(np.int64)

# Decode
if past_key_values is None or self.use_cache is False:
decoder_outputs = self.decoder(
input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
past_key_values=past_key_values,
encoder_hidden_states=encoder_outputs.last_hidden_state,
encoder_attention_mask=attention_mask,
labels=labels,
)
elif self.use_merged is True:
decoder_outputs = self.decoder(
input_ids=decoder_input_ids[:, -1:],
decoder_attention_mask=decoder_attention_mask,
past_key_values=past_key_values,
encoder_hidden_states=encoder_outputs.last_hidden_state,
encoder_attention_mask=attention_mask,
labels=labels,
)
else:
decoder_outputs = self.decoder_with_past(
input_ids=decoder_input_ids[:, -1:], # Cut decoder_input_ids if past is used
decoder_attention_mask=decoder_attention_mask,
past_key_values=past_key_values,
encoder_hidden_states=encoder_outputs.last_hidden_state,
encoder_attention_mask=attention_mask,
labels=labels,
)
model = (
self.decoder
if past_key_values is None or not self.use_cache or self.use_merged
else self.decoder_with_past
)
decoder_outputs = model(
input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
past_key_values=past_key_values,
encoder_hidden_states=encoder_outputs.last_hidden_state,
encoder_attention_mask=attention_mask,
labels=labels,
)

return Seq2SeqLMOutput(
loss=decoder_outputs.get("loss", None),
Expand All @@ -1690,6 +1673,16 @@ def prepare_inputs_for_generation(
encoder_outputs=None,
**kwargs,
) -> Dict:
if past_key_values is not None:
past_length = past_key_values[0][0].shape[2]
# Some generation methods already pass only the last input ID
if input_ids.shape[1] > past_length:
remove_prefix_length = past_length
else:
# Default to old behavior: keep only final ID
remove_prefix_length = input_ids.shape[1] - 1
input_ids = input_ids[:, remove_prefix_length:]

if decoder_attention_mask is None:
decoder_attention_mask = torch.ones_like(input_ids).to(input_ids.device)

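Each of the seq2seq `forward` methods above replaces its if/elif/else branching with a single decoder selection. A minimal sketch of that dispatch, using a hypothetical stub class rather than the real ORT wrappers:

```python
from typing import Optional, Tuple


class _Seq2SeqStub:
    """Illustrative stand-in for the ORT seq2seq wrappers (not the real class)."""

    def __init__(self, use_cache: bool, use_merged: bool, decoder, decoder_with_past=None):
        self.use_cache = use_cache
        self.use_merged = use_merged
        self.decoder = decoder
        self.decoder_with_past = decoder_with_past

    def pick_decoder(self, past_key_values: Optional[Tuple]):
        # Mirrors the selection introduced in this commit:
        # - no cache yet, caching disabled, or a merged decoder ONNX -> self.decoder
        # - otherwise -> the dedicated decoder-with-past ONNX
        if past_key_values is None or not self.use_cache or self.use_merged:
            return self.decoder
        return self.decoder_with_past


# Example: a merged export always routes through the single decoder session.
stub = _Seq2SeqStub(use_cache=True, use_merged=True, decoder="decoder.onnx")
assert stub.pick_decoder(past_key_values=(("k", "v"),)) == "decoder.onnx"
```

With a merged decoder ONNX, one session serves both the first (cache-less) step and the subsequent cached steps, so `decoder_with_past` is only selected for unmerged exports with caching enabled. Note also that the `decoder_input_ids[:, -1:]` slicing formerly done inside `forward` now happens in `prepare_inputs_for_generation`, as shown in the hunks above.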
