fix assisted decoding assistant model inputs #27503
Conversation
@jiqing-feng thank you for promptly working on a fix! 🤗
I suspected this could be an issue, but I was a bit lazy and decided to rely on the tests. The catch is that the tests are stochastic (they are sample-based), so we had a lucky run in the past CI.
Regarding the PR itself: a few minor nits, and then it should be ready to go!
@jiqing-feng If possible, I would also like to revert these temporary changes in this PR :)
🤗 thanks for the fix, we had to skip it in #27508 as well! (Only the relevant test)
Hi @gante @ArthurZucker . I think I have fixed all the comments and also added the tests you mentioned. Would you please help me review it? Thx! BTW, the failed test is not related to my changes.
Thank you for iterating 👍
@@ -759,10 +759,6 @@ def test_pt_tf_model_equivalence(self, allow_missing_keys=True):
# Allow missing keys since TF doesn't cache the sinusoidal embeddings in an attribute
super().test_pt_tf_model_equivalence(allow_missing_keys=allow_missing_keys)

@unittest.skip("Test failing, @RocketNight is looking into it")
Actually, we need to keep this skip, it is what is causing the failure in CI!
Done.
Thanks for adding!
Before we merge, can we update the non-slow test to be more robust? If the past CI was green because of a lucky sample, how do we know that this PR fixes it and that it wasn't just another lucky run? E.g. can we set a seed which we know causes a failure on main and passes here?
@amyeroberts I'll have to think harder about assisted generation test robustness, as there are two conflicting effects in place.
On top of that, pinning a seed to a previous failure does not protect against undetected failures in future models or flags. My suggestion would be: I'll work on test robustness today, and we merge this fix as is. WDYT?
For my own understanding, why wouldn't a seed resolve the issues in randomness here? I'm guessing the tests are using
Agreed - but it should make sure that this one passes! For any future models or flags we should add new tests. In terms of tests to add - this relates back to my previous request here. It seems that the PR broke things for a specific model type (encoder-decoder). Are there tests which do not rely on randomness that we can add to make sure the API itself works?
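(For illustration only, not the test that was added in the PR: a minimal sketch of the kind of deterministic check being discussed. The checkpoints below are arbitrary small models chosen here as an assumption; with do_sample=False, assisted generation is expected to reproduce the plain greedy output exactly, so no lucky sample or seed hunt is needed.)

```python
# Sketch of a determinism-based check: with greedy decoding, generation with an
# assistant model should match generation without one.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
assistant = AutoModelForCausalLM.from_pretrained("distilgpt2")  # smaller draft model, same vocab

inputs = tokenizer("The quick brown fox", return_tensors="pt")

baseline = model.generate(**inputs, do_sample=False, max_new_tokens=10)
assisted = model.generate(**inputs, do_sample=False, max_new_tokens=10, assistant_model=assistant)

# Greedy decoding is deterministic, so assisted decoding must not change the output.
assert torch.equal(baseline, assisted)
```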
Hey @jiqing-feng, there is sadly still a bug with speculative decoding. The following doesn't work:

from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
from transformers import AutoModelForCausalLM
from datasets import load_dataset
import time
import torch
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
model_id = "openai/whisper-large-v2"
model = AutoModelForSpeechSeq2Seq.from_pretrained(
model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
)
model.to(device)
processor = AutoProcessor.from_pretrained(model_id)
assistant_model_id = "distil-whisper/distil-large-v2"
assistant_model = AutoModelForCausalLM.from_pretrained(
assistant_model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
)
assistant_model.to(device)
dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
sample = dataset[0]["audio"]
input_features = processor(sample["array"], return_tensors="pt").input_features.to("cuda").to(torch.float16)
# warm-up
_ = model.generate(input_features, assistant_model=assistant_model)
start_time = time.time()
out = model.generate(input_features, assistant_model=assistant_model)
# out = model.generate(input_features)
print(time.time() - start_time)
@amyeroberts There is something odd here. We have a mixin test that should be catching API issues. I'm looking into it to figure out what's wrong.
The following code snippet also needs to work:

- assistant_model_id = "distil-whisper/distil-large-v2"
- assistant_model = AutoModelForCausalLM.from_pretrained(
-     assistant_model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
- )
+ assistant_model_id = "openai/whisper-tiny"
+ assistant_model = AutoModelForSpeechSeq2Seq.from_pretrained(
+     assistant_model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
+ )

But I think it does already
I ran the following script on my CPU device, and it works well.

from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
from transformers import AutoModelForCausalLM
from datasets import load_dataset
import time
import torch
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
model_id = "openai/whisper-large-v2"
model = AutoModelForSpeechSeq2Seq.from_pretrained(
model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
)
model.to(device)
processor = AutoProcessor.from_pretrained(model_id)
assistant_model_id = "openai/whisper-tiny"
assistant_model = AutoModelForSpeechSeq2Seq.from_pretrained(
assistant_model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
)
assistant_model.to(device)
dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
sample = dataset[0]["audio"]
input_features = processor(sample["array"], return_tensors="pt").input_features.to(device).to(torch_dtype)
# warm-up
_ = model.generate(input_features, assistant_model=assistant_model)
start_time = time.time()
out = model.generate(input_features, assistant_model=assistant_model)
# out = model.generate(input_features)
print(time.time() - start_time)
Hey @jiqing-feng, thanks so much for quickly jumping on fixing the problem here 🙏 It sadly still doesn't fix Whisper distillation as per the code snippet above. To make sure Distil-Whisper works again on "main", we have now reverted the PR here: #27523 and also added two slow tests that should now be run every time we make changes to assisted decoding:

RUN_SLOW=1 pytest tests/models/whisper/test_modeling_whisper.py -k "distil" -sv

It would be amazing if you could maybe try to open a new PR that is rebased on current "main" with all your nice changes and in which all fast tests as well as the slow tests pass:

RUN_SLOW=1 pytest tests/models/whisper/test_modeling_whisper.py -k "distil" -sv

Very sorry about the duplicated work here
Hi @patrickvonplaten. There is no need to open a new PR; I have fixed the conflicts. There might be a mistake (here): I see that you use
Hi @jiqing-feng 👋 I have strengthened the test suite for assisted generation and did a small post mortem on why we didn't catch the issue in our tests in this PR. Let's merge that PR first and then rebase here, to ensure we don't break CI again 🤗 Again, apologies on our end for not having robust enough test coverage!
@jiqing-feng the improved assisted generation tests were merged 🤗
Hi @gante. I also updated my code base. Would you please help to merge this PR? Thx.
Hi @jiqing-feng 👋 I got it working on my end, without the change you added to the Whisper test (which we must revert). It is a non-trivial set of changes, so I'm going to detail the entire diff :)
def _prepare_attention_mask(model_kwargs: Dict[str, Any], new_length: int, is_encoder_decoder: bool) -> Dict[str, Any]:
    """Expands or crops the model's mask for decoding purposes, to the defined length"""
    mask_key = "decoder_attention_mask" if is_encoder_decoder else "attention_mask"
    if mask_key not in model_kwargs:
        return model_kwargs

    mask = model_kwargs[mask_key]
    mask_length_diff = new_length - mask.shape[1]
    if mask_length_diff < 0:
        model_kwargs[mask_key] = mask[:, :mask_length_diff]
    elif mask_length_diff > 0:
        model_kwargs[mask_key] = torch.cat([mask, mask.new_ones((mask.shape[0], mask_length_diff))], dim=-1)
    return model_kwargs


def _prepare_token_type_ids(model_kwargs: Dict[str, Any], new_length: int) -> Dict[str, Any]:
    """Expands or crops the model's token_type_ids for decoding purposes, to the defined length"""
    if "token_type_ids" not in model_kwargs or model_kwargs["token_type_ids"] is None:
        return model_kwargs

    token_type_ids = model_kwargs["token_type_ids"]
    final_token_type = token_type_ids[:, -1].unsqueeze(-1)
    type_length_diff = new_length - token_type_ids.shape[1]
    if type_length_diff < 0:
        token_type_ids = token_type_ids[:, :type_length_diff]
    elif type_length_diff > 0:
        token_type_copies = final_token_type.repeat(1, type_length_diff)
        model_kwargs["token_type_ids"] = torch.cat([model_kwargs["token_type_ids"], token_type_copies], dim=-1)
    return model_kwargs
assistant_kwargs = _prepare_attention_mask(
    assistant_kwargs, new_cur_len, assistant_model.config.is_encoder_decoder
)
assistant_kwargs = _prepare_token_type_ids(assistant_kwargs, new_cur_len)

candidate_kwargs = copy.copy(model_kwargs)
candidate_kwargs = _prepare_attention_mask(
    candidate_kwargs, candidate_input_ids.shape[1], self.config.is_encoder_decoder
)
candidate_kwargs = _prepare_token_type_ids(candidate_kwargs, candidate_input_ids.shape[1])
model_inputs = self.prepare_inputs_for_generation(candidate_input_ids, **candidate_kwargs)

assistant_kwargs = copy.copy(model_kwargs)
if assistant_model.config.is_encoder_decoder:
    # both are encoder-decoder
    input_ids_key = "decoder_input_ids"
    attention_key = "decoder_attention_mask"
    assistant_kwargs["encoder_outputs"] = assistant_kwargs.pop("assistant_encoder_outputs")
elif "assistant_encoder_outputs" in assistant_kwargs:
    # special case for encoder-decoder with decoder-only assistant (like DistilWhisper)
    input_ids_key = "input_ids"
    attention_key = "attention_mask"
    assistant_kwargs["attention_mask"] = assistant_kwargs.get(
        "decoder_attention_mask",
        torch.ones((input_ids.shape[0], 1), device=input_ids.device, dtype=torch.long),
    )
    assistant_kwargs["encoder_outputs"] = assistant_kwargs.pop("assistant_encoder_outputs")
else:
    # both are decoder-only
    input_ids_key = "input_ids"
    attention_key = "attention_mask"

All these changes will make
Hi @gante. Thanks for your review; I have applied all the changes you proposed. Would you please help check and merge it? Thx!
Perfect, thank you for working on the changes 💪
@jiqing-feng if possible, it would be nice to delete the now unused _extend_attention_mask and _extend_token_type_ids functions :)
@amyeroberts I've confirmed on my end that all relevant tests are passing:
RUN_SLOW=1 py.test tests/models/whisper/ -k speculative
py.test tests/ -k test_assisted_decoding_matches_greedy_search
py.test tests/ -k test_assisted_decoding_sample
Do you mean delete these 2 functions and replace all
@jiqing-feng yes
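(Aside, to make the cleanup concrete: the helper below is copied from the diff earlier in this thread and exercised on a toy kwargs dict. Unlike the old extend-only helpers, it both crops and expands, which is why the remaining call sites can switch to it. The toy tensor values are purely illustrative.)

```python
import torch
from typing import Any, Dict

# Copied from the diff shared above in this PR thread.
def _prepare_attention_mask(model_kwargs: Dict[str, Any], new_length: int, is_encoder_decoder: bool) -> Dict[str, Any]:
    """Expands or crops the model's mask for decoding purposes, to the defined length"""
    mask_key = "decoder_attention_mask" if is_encoder_decoder else "attention_mask"
    if mask_key not in model_kwargs:
        return model_kwargs
    mask = model_kwargs[mask_key]
    mask_length_diff = new_length - mask.shape[1]
    if mask_length_diff < 0:
        model_kwargs[mask_key] = mask[:, :mask_length_diff]
    elif mask_length_diff > 0:
        model_kwargs[mask_key] = torch.cat([mask, mask.new_ones((mask.shape[0], mask_length_diff))], dim=-1)
    return model_kwargs

kwargs = {"attention_mask": torch.ones((1, 4), dtype=torch.long)}
print(_prepare_attention_mask(dict(kwargs), 6, False)["attention_mask"].shape)  # expanded to (1, 6)
print(_prepare_attention_mask(dict(kwargs), 3, False)["attention_mask"].shape)  # cropped to (1, 3)
```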
Thanks for iterating!
@gante Thanks for running the tests! Do these also cover the tests that were breaking previously for whisper? Happy to merge once we know it's whisper compatible 🤗
Yes, it is
@jiqing-feng thank you for bearing with us 🤗
@gante D'oh sorry - PR blindness 🤦 Thanks for merging and thanks again @jiqing-feng for all the work iterating on this PR!
In the last PR, we didn't consider the decoder_attention_mask while updating model_kwargs, see here. This PR has fixed it. Furthermore, I also use a cleaner way to process the assistant model's inputs.
Hi @gante, would you please help review this PR? Thx!
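(A toy illustration of the failure mode described above, not the PR's actual code: when assisted decoding appends candidate tokens to the decoder_input_ids of an encoder-decoder model, the decoder_attention_mask kept in model_kwargs has to grow by the same amount, otherwise the decoder inputs and the mask fall out of sync. The tensor values are made up.)

```python
import torch

# Decoder state of an encoder-decoder model mid-generation (illustrative values).
decoder_input_ids = torch.tensor([[50258, 50259, 50359]])    # 3 tokens so far
decoder_attention_mask = torch.ones_like(decoder_input_ids)  # shape (1, 3)

# The assistant model proposes 2 candidate tokens.
candidates = torch.tensor([[440, 264]])
decoder_input_ids = torch.cat([decoder_input_ids, candidates], dim=-1)  # now (1, 5)

# Without also extending the mask, the shapes no longer match; extending it is the fix.
length_diff = decoder_input_ids.shape[1] - decoder_attention_mask.shape[1]
decoder_attention_mask = torch.cat(
    [decoder_attention_mask, decoder_attention_mask.new_ones((1, length_diff))], dim=-1
)
assert decoder_attention_mask.shape == decoder_input_ids.shape
```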