fix assisted decoding assistant model inputs #27503

Merged · 23 commits · Nov 27, 2023
Changes from 14 commits
84 changes: 35 additions & 49 deletions src/transformers/generation/utils.py
@@ -4541,6 +4541,16 @@ def assisted_decoding(
model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
)

# prepare assistant model's keys of inputs
assistant_kwargs = copy.copy(model_kwargs)
if assistant_model.config.is_encoder_decoder:
input_ids_key = "decoder_input_ids"
attention_key = "decoder_attention_mask"
assistant_kwargs["encoder_outputs"] = assistant_kwargs.pop("assistant_encoder_outputs")
else:
input_ids_key = "input_ids"
attention_key = "attention_mask"

# keep track of which sequences are already finished
unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)

@@ -4566,62 +4576,36 @@ def assisted_decoding(
# `.generate()` call if we decide to add `past_key_values` as a possible output of generate, as we
# need access to the assistant cache to secure strong speedups.
candidate_input_ids = input_ids
assistant_attention_mask = model_kwargs.get("attention_mask", None)
assistant_decoder_attention_mask = model_kwargs.get("decoder_attention_mask", None)
assistant_encoder_outputs = (model_kwargs.get("assistant_encoder_outputs", None),)
for _ in range(int(num_assistant_tokens)):
# 1.1. use the assistant model to obtain the next candidate logits
# 1.1 prepare assistant model inputs
assistant_inputs = assistant_model.prepare_inputs_for_generation(
candidate_input_ids,
attention_mask=assistant_attention_mask,
decoder_attention_mask=assistant_decoder_attention_mask,
encoder_outputs=assistant_encoder_outputs,
past_key_values=model_kwargs.get("assistant_past_key_values", None),
**assistant_kwargs,
)
if assistant_inputs.get("past_key_values", None) is not None:
if assistant_model.config.is_encoder_decoder:
input_ids_len = assistant_inputs["decoder_input_ids"].shape[-1]
else:
input_ids_len = assistant_inputs["input_ids"].shape[-1]

if input_ids_len not in (1, 2):
raise ValueError("The length of the input ids in assistant inputs should be 1 or 2")
# 1.2. check if the input ids length is correct
has_past_key_values = assistant_inputs.get("past_key_values", None) is not None
if has_past_key_values and assistant_inputs[input_ids_key].shape[-1] not in (1, 2):
raise ValueError("The length of the input ids in assistant inputs should be 1 or 2")

# 1.3. use the assistant model to obtain the next candidate logits
assistant_model_outputs = assistant_model(**assistant_inputs)

# 1.2. greedily select the next candidate token
model_kwargs["assistant_past_key_values"] = assistant_model_outputs.past_key_values
# 1.4. greedily select the next candidate token
if len(logits_processor) > 0:
assistant_model_outputs.logits[:, -1, :] = logits_processor(
candidate_input_ids, assistant_model_outputs.logits[:, -1, :]
)

new_token = assistant_model_outputs.logits[:, -1, :].argmax(dim=-1)
candidate_input_ids = torch.cat((candidate_input_ids, new_token[:, None]), dim=-1)
if assistant_model.config.is_encoder_decoder and assistant_decoder_attention_mask is not None:
assistant_decoder_attention_mask = torch.cat(
(
assistant_decoder_attention_mask,
torch.ones(
[1, 1],
dtype=assistant_decoder_attention_mask.dtype,
device=assistant_decoder_attention_mask.device,
),
),
dim=-1,
)
elif not assistant_model.config.is_encoder_decoder and assistant_attention_mask is not None:
assistant_attention_mask = torch.cat(
(
assistant_attention_mask,
torch.ones(
[1, 1], dtype=assistant_attention_mask.dtype, device=assistant_attention_mask.device
),
),
dim=-1,
)

# 1.3. stop assistant generation on EOS
# 1.5. update assistant model inputs
if assistant_kwargs.get(attention_key, None) is not None:
mask = assistant_kwargs[attention_key]
assistant_kwargs[attention_key] = torch.cat([mask, mask.new_ones((mask.shape[0], 1))], dim=-1)
assistant_kwargs["past_key_values"] = assistant_model_outputs.past_key_values

# 1.6. stop assistant generation on EOS
if eos_token_id_tensor is not None:
last_assistant_token_is_eos = new_token.tile(eos_token_id_tensor.shape[0], 1)
last_assistant_token_is_eos = (
@@ -4693,8 +4677,8 @@ def assisted_decoding(
# 5.3. Discard past key values relative to unused assistant tokens
new_cache_size = new_cur_len - 1
outputs.past_key_values = _crop_past_key_values(self, outputs.past_key_values, new_cache_size)
model_kwargs["assistant_past_key_values"] = _crop_past_key_values(
assistant_model, model_kwargs["assistant_past_key_values"], new_cache_size - 1
assistant_kwargs["past_key_values"] = _crop_past_key_values(
assistant_model, assistant_kwargs["past_key_values"], new_cache_size - 1
) # the assistant does not have the token after the last match, hence the -1

# 6. Adjust the max number of assistant tokens to use in the next iteration. This is a simple heuristic,
@@ -4755,12 +4739,14 @@ def assisted_decoding(
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
)

# Update attention_mask for the assistant's next round of generations
if n_matches > 0 and model_kwargs.get("attention_mask", None) is not None:
attention_mask = model_kwargs["attention_mask"]
model_kwargs["attention_mask"] = torch.cat(
[attention_mask, attention_mask.new_ones((attention_mask.shape[0], n_matches))], dim=-1
)
# Update assistant_kwargs for the assistant's next round of generations
if n_matches > 0:
model_kwargs = self._extend_attention_mask(model_kwargs, new_cur_len)
model_kwargs = self._extend_token_type_ids(model_kwargs, new_cur_len)
if attention_key in assistant_kwargs:
assistant_kwargs[attention_key] = model_kwargs.get(attention_key, None)
if "token_type_ids" in assistant_kwargs:
assistant_kwargs["token_type_ids"] = model_kwargs.get("token_type_ids", None)

# if eos_token was found in one sentence, set sentence to finished
if eos_token_id_tensor is not None:
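The core simplification in this file is that the assistant now keeps its own `assistant_kwargs` dict and extends whichever attention mask applies (`decoder_attention_mask` for encoder-decoder assistants, `attention_mask` otherwise) by one column per drafted token. A minimal standalone sketch of that per-step update, not the library code itself; the `is_encoder_decoder` flag and tensor shapes below are illustrative assumptions:

import torch

is_encoder_decoder = True  # assumed flag; the real code reads assistant_model.config.is_encoder_decoder
attention_key = "decoder_attention_mask" if is_encoder_decoder else "attention_mask"

# Pretend the assistant has already seen 5 tokens.
assistant_kwargs = {attention_key: torch.ones((1, 5), dtype=torch.long)}

# After each drafted token, extend the relevant mask by a single column of ones,
# regardless of the assistant's architecture.
mask = assistant_kwargs[attention_key]
assistant_kwargs[attention_key] = torch.cat([mask, mask.new_ones((mask.shape[0], 1))], dim=-1)

print(assistant_kwargs[attention_key].shape)  # torch.Size([1, 6])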
1 change: 0 additions & 1 deletion tests/generation/test_utils.py
@@ -1576,7 +1576,6 @@ def test_assisted_decoding_matches_greedy_search(self):
for output in (output_greedy, output_assisted):
self._check_outputs(output, input_ids, model.config, use_cache=True)

@unittest.skip("Failing for a lot of models du to attention mask size missmatch. Works well when standalone.")
def test_assisted_decoding_sample(self):
# Seeded assisted decoding will not match sample for the same seed, as the forward pass does not return the
# exact same logits (the forward pass of the main model, now with several tokens at once, has causal masking).
4 changes: 0 additions & 4 deletions tests/models/nllb_moe/test_modeling_nllb_moe.py
@@ -348,10 +348,6 @@ def test_get_loss(self):
self.assertIsNotNone(model(**input_dict)["encoder_router_logits"][1])
self.assertIsNotNone(model(**input_dict)["decoder_router_logits"][0])

@unittest.skip("Test does not fail individually but fails on the CI @ArthurZucker looking into it")
def test_assisted_decoding_sample(self):
pass


@require_torch
@require_sentencepiece
4 changes: 0 additions & 4 deletions tests/models/switch_transformers/test_modeling_switch_transformers.py
@@ -726,10 +726,6 @@ def test_generate_with_head_masking(self):
def test_disk_offload(self):
pass

@unittest.skip("Test does not fail individually but fails on the CI @ArthurZucker looking into it")
def test_assisted_decoding_sample(self):
pass


class SwitchTransformersEncoderOnlyModelTester:
def __init__(
4 changes: 0 additions & 4 deletions tests/models/t5/test_modeling_t5.py
@@ -1036,10 +1036,6 @@ def test_model_fp16_forward(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model_fp16_forward(*config_and_inputs)

@unittest.skip("Test does not fail individually but fails on the CI @ArthurZucker looking into it")
def test_assisted_decoding_sample(self):
pass


def use_task_specific_params(model, task):
model.config.update(model.config.task_specific_params[task])
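For context, the code path these changes exercise is the `assistant_model` argument of `generate`. A usage sketch, with the two OPT checkpoints below as assumed stand-ins; any main/assistant pair that shares a tokenizer should work:

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-1.3b")
model = AutoModelForCausalLM.from_pretrained("facebook/opt-1.3b")
assistant_model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")  # small draft model

inputs = tokenizer("Assisted decoding lets a small model draft tokens", return_tensors="pt")
# The assistant drafts candidate tokens that the main model then verifies in one forward pass.
outputs = model.generate(**inputs, assistant_model=assistant_model, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))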