
fix assisted decoding assistant model inputs #27503

Merged · 23 commits · Nov 27, 2023

Conversation

jiqing-feng
Contributor

@jiqing-feng jiqing-feng commented Nov 15, 2023

In the last PR, we didn't account for the decoder_attention_mask when updating model_kwargs (see here). This PR fixes it.

Furthermore, this PR also uses a cleaner way to process the assistant model's inputs.
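
For context, a minimal sketch of the kind of update this refers to (hypothetical tensors, not the exact code in generation/utils.py): when candidate tokens are appended to decoder_input_ids during assisted decoding, the decoder_attention_mask in model_kwargs has to be extended to the same length, otherwise the two go out of sync.

import torch

# Hypothetical state during one assisted-decoding round (encoder-decoder model).
decoder_input_ids = torch.tensor([[0, 5, 9]])                 # 3 tokens generated so far
decoder_attention_mask = torch.ones_like(decoder_input_ids)   # shape (1, 3)

candidate_ids = torch.tensor([[7, 2]])                        # 2 tokens proposed by the assistant

# Append the candidates to the decoder inputs...
decoder_input_ids = torch.cat([decoder_input_ids, candidate_ids], dim=-1)

# ...and extend the mask accordingly, which is the step this PR makes sure happens.
decoder_attention_mask = torch.cat(
    [decoder_attention_mask, decoder_attention_mask.new_ones((1, candidate_ids.shape[1]))], dim=-1
)
print(decoder_input_ids.shape, decoder_attention_mask.shape)  # torch.Size([1, 5]) torch.Size([1, 5])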

Hi @gante, could you please review this PR? Thanks!

@jiqing-feng changed the title from "fix assisted decoding attention_cat" to "fix assisted decoding assistant model inputs" on Nov 15, 2023
Member

@gante gante left a comment


@jiqing-feng thank you for promptly working on a fix! 🤗

I suspected this could be an issue, but I was a bit lazy and decided to rely on the tests. The catch is that the tests are stochastic (they are sample-based), so we had a lucky run in the past CI.

Regarding the PR itself: a few minor nits, and then it should be ready to go!

src/transformers/generation/utils.py (4 review comments, outdated and resolved)
@gante
Member

gante commented Nov 15, 2023

@jiqing-feng If possible, I would also like to revert these temporary changes in this PR :)

@ArthurZucker
Collaborator

🤗 Thanks for the fix! We had to skip it in #27508 as well (only the relevant test).

@jiqing-feng
Contributor Author

Hi @gante @ArthurZucker, I think I have addressed all the comments and also added the tests you mentioned. Could you please review it? Thanks!

BTW, the failing test is not related to my changes.

Member

@gante gante left a comment


Thank you for iterating 👍

@@ -759,10 +759,6 @@ def test_pt_tf_model_equivalence(self, allow_missing_keys=True):
# Allow missing keys since TF doesn't cache the sinusoidal embeddings in an attribute
super().test_pt_tf_model_equivalence(allow_missing_keys=allow_missing_keys)

@unittest.skip("Test failing, @RocketNight is looking into it")
Member


Actually, we need to keep this skip; removing it is what is causing the failure in CI!

Contributor Author


Done.

Collaborator

@amyeroberts amyeroberts left a comment


Thanks for adding!

Before we merge, can we update the non-slow test to be more robust? If the past CI was green because of a lucky sample, how do we know that this PR fixes it and that this wasn't just another lucky run? E.g., can we set a seed which we know causes a failure on main and passes here?

@gante
Member

gante commented Nov 16, 2023

@amyeroberts I'll have to think harder about assisted generation test robustness, as there are two conflicting effects in place:

  1. In theory, assisted generation should yield the exact same outputs
  2. In practice, due to the matrix multiplication being shape-dependent (see here), there will be tiny fluctuations (a minimal sketch of this effect follows this comment). With random models, this means that the odds of a simple assisted vs non-assisted output check failing are high.

On top of that, pinning a seed to a previous failure does not prevent bad failure checks in future models or flags.

My suggestion would be: I'll work on test robustness today, and we merge this fix as is. WDYT?
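
A minimal sketch of the shape-dependent matrix-multiplication effect described in point 2 above (assuming a PyTorch backend; whether the difference is exactly zero depends on the hardware and on the kernels selected for each shape):

import torch

torch.manual_seed(0)
x = torch.randn(8, 256)
w = torch.randn(256, 256)

# The same first row, multiplied once as part of a batch and once on its own.
batched = (x @ w)[:1]
single = x[:1] @ w

# Different shapes can dispatch to different kernels that accumulate in a different
# order, so the two results may differ by a tiny amount (or be identical on some setups).
print((batched - single).abs().max())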

@amyeroberts
Collaborator

In practice, due to the matrix multiplication being shape-dependent (see #25420 (comment)), there will be tiny fluctuations. With random models, this means that the odds of a simple assisted vs non-assisted output check failing are high.

For my own understanding, why wouldn't a seed resolve the randomness issues here? I'm guessing the tests are using hf-internal-testing/tiny-random-model-name, which can change?

On top of that, pinning a seed to a previous failure does not prevent bad failure checks in future models or flags.

Agreed - but it should make sure that this one passes! For any future models or flags we should add new tests.

In terms of tests to add, this relates back to my previous request here. It seems that the previous PR broke things for a specific model type (encoder-decoder). Are there tests we can add, ones that do not rely on randomness, that make sure the API itself works?
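
One possible shape for such a randomness-free check (a hypothetical sketch, not an existing test in the suite; the tiny checkpoint name is assumed for illustration). It only exercises the assisted-generation API on an encoder-decoder model and asserts on shapes rather than token values:

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Hypothetical smoke test: assisted generation with an encoder-decoder main model
# should run end-to-end, regardless of which tokens get sampled.
checkpoint = "hf-internal-testing/tiny-random-t5"  # assumed tiny checkpoint
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
assistant = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)

inputs = tokenizer(["translate: hello world"], return_tensors="pt")
out = model.generate(**inputs, assistant_model=assistant, max_new_tokens=5)

assert out.shape[0] == 1  # the API worked for the encoder-decoder case; no value comparison needed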

@patrickvonplaten
Contributor

Hey @jiqing-feng,

There is sadly still a bug with speculative decoding. The following doesn't work:

from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
from transformers import AutoModelForCausalLM
from datasets import load_dataset
import time
import torch

device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

model_id = "openai/whisper-large-v2"

model = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
)
model.to(device)

processor = AutoProcessor.from_pretrained(model_id)

assistant_model_id = "distil-whisper/distil-large-v2"
assistant_model = AutoModelForCausalLM.from_pretrained(
    assistant_model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
)
assistant_model.to(device)

dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
sample = dataset[0]["audio"]

input_features = processor(sample["array"], return_tensors="pt").input_features.to(device).to(torch_dtype)

# warm-up
_ = model.generate(input_features, assistant_model=assistant_model)

start_time = time.time()
out = model.generate(input_features, assistant_model=assistant_model)
# out = model.generate(input_features)
print(time.time() - start_time)

@gante
Member

gante commented Nov 16, 2023

@amyeroberts There is something odd here. We have a mixin test that should be catching API issues. I'm looking into it to attempt to figure out what's wrong.

@patrickvonplaten
Contributor

The following code snippet also needs to work:

- assistant_model_id = "distil-whisper/distil-large-v2"
- assistant_model = AutoModelForCausalLM.from_pretrained(
-    assistant_model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
-)
+ assistant_model_id = "openai/whisper-tiny"
+ assistant_model = AutoModelForSpeechSeq2Seq.from_pretrained(
+    assistant_model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
+)

But I think it already does.

@jiqing-feng
Contributor Author

Hi @patrickvonplaten

I ran the following script on my CPU device, and it works well.

from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
from transformers import AutoModelForCausalLM
from datasets import load_dataset
import time
import torch

device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

model_id = "openai/whisper-large-v2"

model = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
)
model.to(device)

processor = AutoProcessor.from_pretrained(model_id)

assistant_model_id = "openai/whisper-tiny"
assistant_model = AutoModelForSpeechSeq2Seq.from_pretrained(
   assistant_model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
)
assistant_model.to(device)

dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
sample = dataset[0]["audio"]

input_features = processor(sample["array"], return_tensors="pt").input_features.to(device).to(torch_dtype)

# warm-up
_ = model.generate(input_features, assistant_model=assistant_model)

start_time = time.time()
out = model.generate(input_features, assistant_model=assistant_model)
# out = model.generate(input_features)
print(time.time() - start_time)

@patrickvonplaten
Contributor

patrickvonplaten commented Nov 16, 2023

Hey @jiqing-feng,

Thanks so much for quickly jumping on fixing the problem here 🙏

Sadly, it still doesn't fix Whisper distillation as per the code snippet above. To make sure Distil-Whisper works again on "main", we have now reverted the PR in #27523 and also added two slow tests that should be run every time we change assisted decoding:

RUN_SLOW=1 pytest tests/models/whisper/test_modeling_whisper.py -k "distil" -sv

It would be amazing if you could open a new PR, rebased onto the current "main", with all your nice changes, in which all fast tests as well as the slow tests pass:

RUN_SLOW=1 pytest tests/models/whisper/test_modeling_whisper.py -k "distil" -sv

Very sorry about the duplicated work here

@jiqing-feng
Contributor Author

jiqing-feng commented Nov 16, 2023

Hi @patrickvonplaten, there is no need to open a new PR; I have fixed the conflicts.

There might be a mistake (here): I see that you use distil-whisper/distil-large-v2 as the assistant model, but you load it with WhisperForCausalLM even though it is a WhisperForConditionalGeneration model. Since distil-whisper/distil-large-v2 is an encoder-decoder model, I load it with WhisperForConditionalGeneration (which is also the original architecture listed in the model card). After this change, I can successfully run RUN_SLOW=1 pytest tests/models/whisper/test_modeling_whisper.py -k "distil" -sv on my current changes.
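
For reference, a minimal sketch of the loading change described above (assuming torch_dtype is defined as in the earlier snippets; the model card for distil-whisper/distil-large-v2 lists WhisperForConditionalGeneration as its architecture):

import torch
from transformers import WhisperForConditionalGeneration

torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

# distil-whisper/distil-large-v2 keeps the full encoder-decoder layout (with a shrunken decoder),
# so it loads with the seq2seq class rather than WhisperForCausalLM.
assistant_model = WhisperForConditionalGeneration.from_pretrained(
    "distil-whisper/distil-large-v2",
    torch_dtype=torch_dtype,
    low_cpu_mem_usage=True,
    use_safetensors=True,
)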

@gante
Member

gante commented Nov 16, 2023

Hi @jiqing-feng 👋

I have strengthened the test suite for assisted generation and did a small post mortem on why we didn't catch the issue in our tests, in this PR.

Let's merge that PR first and then rebase here, to ensure we don't break CI again 🤗

Again, apologies on our end for not having robust enough test coverage!

@gante
Member

gante commented Nov 16, 2023

@jiqing-feng the improved assisted generation tests were merged 🤗

@jiqing-feng
Contributor Author

Hi @gante, I have also updated my code base. Could you please merge this PR? Thanks!

@gante
Member

gante commented Nov 17, 2023

Hi @jiqing-feng 👋

I got it working on my end, without the change you added to the Whisper test (which we must revert). It is a non-trivial set of changes, so I'm going to detail the entire diff :)

  1. Remove the self._extend_attention_mask and self._extend_token_type_ids functions from the GenerationMixin
  2. Replace them with the following stand-alone functions, which can be added at the bottom of the file:
def _prepare_attention_mask(model_kwargs: Dict[str, Any], new_length: int, is_encoder_decoder: bool) -> Dict[str, Any]:
    """Expands or crops the model's mask for decoding purposes, to the defined length"""

    mask_key = "decoder_attention_mask" if is_encoder_decoder else "attention_mask"
    if mask_key not in model_kwargs:
        return model_kwargs

    mask = model_kwargs[mask_key]
    mask_length_diff = new_length - mask.shape[1]

    if mask_length_diff < 0:
        model_kwargs[mask_key] = mask[:, :mask_length_diff]
    elif mask_length_diff > 0:
        model_kwargs[mask_key] = torch.cat([mask, mask.new_ones((mask.shape[0], mask_length_diff))], dim=-1)
    return model_kwargs


def _prepare_token_type_ids(model_kwargs: Dict[str, Any], new_length: int) -> Dict[str, Any]:
    """Expands or crops the model's token_type_ids for decoding purposes, to the defined length"""
    if "token_type_ids" not in model_kwargs or model_kwargs["token_type_ids"] is None:
        return model_kwargs

    token_type_ids = model_kwargs["token_type_ids"]
    final_token_type = token_type_ids[:, -1].unsqueeze(-1)
    type_length_diff = new_length - token_type_ids.shape[1]

    if type_length_diff < 0:
        # assign back to model_kwargs so the crop actually takes effect
        model_kwargs["token_type_ids"] = token_type_ids[:, :type_length_diff]
    elif type_length_diff > 0:
        token_type_copies = final_token_type.repeat(1, type_length_diff)
        model_kwargs["token_type_ids"] = torch.cat([model_kwargs["token_type_ids"], token_type_copies], dim=-1)
    return model_kwargs
  3. Replace the code after # Update assistant_kwargs for the assistant's next round of generations with:
            assistant_kwargs = _prepare_attention_mask(
                assistant_kwargs, new_cur_len, assistant_model.config.is_encoder_decoder
            )
            assistant_kwargs = _prepare_token_type_ids(assistant_kwargs, new_cur_len)
  4. Replace the code after # 2.1. Prepare the model inputs with:
            candidate_kwargs = copy.copy(model_kwargs)
            candidate_kwargs = _prepare_attention_mask(
                candidate_kwargs, candidate_input_ids.shape[1], self.config.is_encoder_decoder
            )
            candidate_kwargs = _prepare_token_type_ids(candidate_kwargs, candidate_input_ids.shape[1])

            model_inputs = self.prepare_inputs_for_generation(candidate_input_ids, **candidate_kwargs)
  5. Replace the code after # prepare assistant model's keys of inputs with:
        assistant_kwargs = copy.copy(model_kwargs)
        if assistant_model.config.is_encoder_decoder:
            # both are encoder-decoder
            input_ids_key = "decoder_input_ids"
            attention_key = "decoder_attention_mask"
            assistant_kwargs["encoder_outputs"] = assistant_kwargs.pop("assistant_encoder_outputs")
        elif "assistant_encoder_outputs" in assistant_kwargs:
            # special case for encoder-decoder with decoder-only assistant (like DistilWhisper)
            input_ids_key = "input_ids"
            attention_key = "attention_mask"
            assistant_kwargs["attention_mask"] = assistant_kwargs.get(
                "decoder_attention_mask",
                torch.ones((input_ids.shape[0], 1), device=input_ids.device, dtype=torch.long),
            )
            assistant_kwargs["encoder_outputs"] = assistant_kwargs.pop("assistant_encoder_outputs")
        else:
            # both are decoder-only
            input_ids_key = "input_ids"
            attention_key = "attention_mask"

All these changes will make assisted_generation compatible with all use cases, even the more complex DistilWhisper 🤗
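
As a quick illustration of how the stand-alone helpers above behave (a minimal sketch, assuming _prepare_attention_mask is defined exactly as in step 2):

import torch

# Hypothetical decoder-only kwargs with a length-5 attention mask.
model_kwargs = {"attention_mask": torch.ones(1, 5, dtype=torch.long)}

# Expanding to length 8 appends ones for the newly generated tokens...
expanded = _prepare_attention_mask(dict(model_kwargs), new_length=8, is_encoder_decoder=False)
print(expanded["attention_mask"].shape)  # torch.Size([1, 8])

# ...while cropping to length 3 discards the mask positions of rejected candidates.
cropped = _prepare_attention_mask(dict(model_kwargs), new_length=3, is_encoder_decoder=False)
print(cropped["attention_mask"].shape)  # torch.Size([1, 3])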

@jiqing-feng
Contributor Author

Hi @gante, thanks for your review. I have applied all the changes you proposed. Could you please check and merge it? Thanks!

Member

@gante gante left a comment


Perfect, thank you for working on the changes 💪

@jiqing-feng if possible, it would be nice to delete the now unused _extend_attention_mask and _extend_token_type_ids functions :)

@amyeroberts I've confirmed on my end that all relevant tests are passing:

  1. RUN_SLOW=1 py.test tests/models/whisper/ -k speculative
  2. py.test tests/ -k test_assisted_decoding_matches_greedy_search
  3. py.test tests/ -k test_assisted_decoding_sample

@jiqing-feng
Contributor Author

Perfect, thank you for working on the changes 💪

@jiqing-feng if possible, it would be nice to delete the now unused _extend_attention_mask and _extend_token_type_ids functions :)

@amyeroberts I've confirmed on my end that all relevant tests are passing:

  1. RUN_SLOW=1 py.test tests/models/whisper/ -k speculative
  2. py.test tests/ -k test_assisted_decoding_matches_greedy_search
  3. py.test tests/ -k test_assisted_decoding_sample

Do you mean delete these 2 functions and replace all _extend_xxx functions with our new _prepare_xxx functions?

@gante
Member

gante commented Nov 21, 2023

@jiqing-feng yes; _extend_attention_mask and _extend_token_type_ids are not used anywhere in the code.

Collaborator

@amyeroberts amyeroberts left a comment


Thanks for iterating!

@gante Thanks for running the tests! Do these also cover the tests that were breaking previously for whisper? Happy to merge once we know it's whisper compatible 🤗

@gante
Member

gante commented Nov 27, 2023

Do these also cover the tests that were breaking previously for whisper? Happy to merge once we know it's whisper compatible 🤗

Yes, it is RUN_SLOW=1 py.test tests/models/whisper/ -k speculative in the list of tests above :) Merging!

@gante gante merged commit 1d7f406 into huggingface:main Nov 27, 2023
20 checks passed
@gante
Member

gante commented Nov 27, 2023

@jiqing-feng thank you for bearing with us 🤗

@amyeroberts
Collaborator

@gante D'oh sorry - PR blindness 🤦 Thanks for merging and thanks again @jiqing-feng for all the work iterating on this PR!

@jiqing-feng jiqing-feng deleted the assisted branch December 13, 2023 07:03