
Correct llava mask & fix missing setter for vocab_size #29389

Merged: 8 commits into main from fix-llava-mask, Mar 22, 2024

Conversation

fxmarty (Contributor) commented Mar 1, 2024

This PR fixes the llava attention mask in the generation case.

torch.cat((attention_mask, extended_mask), dim=-1) was very wrong, and it is a miracle we were getting meaningful outputs for batch generation. The issue is that this disregards the custom attention_mask used during the first forward pass. As a result, the wrong past key values are used.
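To make the failure mode concrete, here is a toy illustration (made-up shapes, not the model code) of why a mask that only tracks the text tokens can never line up with cached keys/values that also cover the image patches:

import torch

# Toy setup: 3 text tokens plus 4 image-patch embeddings are merged before the LLM,
# so the prefill sequence length is 7, but generate() only knows about the 3 text tokens.
text_mask = torch.ones(1, 3, dtype=torch.long)     # the mask generate() keeps track of
prefill_mask = torch.ones(1, 7, dtype=torch.long)  # the mask the first forward actually used

# First decoding step: generate() appends one position to the text-level mask.
step_mask = torch.cat([text_mask, text_mask.new_ones(1, 1)], dim=-1)          # length 4

# The cached keys/values cover all 7 prefill positions, so the mask the attention
# layer needs must be derived from the prefill-level mask instead.
needed_mask = torch.cat([prefill_mask, prefill_mask.new_ones(1, 1)], dim=-1)  # length 8

print(step_mask.shape, needed_mask.shape)  # torch.Size([1, 4]) torch.Size([1, 8])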

It is also not clear to me why this custom masking is handled in the forward and not in GenerationMixin cc @gante (we are doing unnecessary operations at each forward to retrieve the correct attention_mask, it could just be an output of the model & updated in GenerationMixin)

Fixes #28184

@fxmarty requested review from younesbelkada, ArthurZucker and amyeroberts and removed the request for younesbelkada, March 1, 2024 09:41
HuggingFaceDocBuilderDev (bot) commented:

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

ArthurZucker (Collaborator) left a comment

Thanks! Could you add a test / update the test values? Or does this have no impact on the current tests?

gante (Member) commented Mar 5, 2024

It is also not clear to me why this custom masking is handled in the forward and not in GenerationMixin cc @gante (we are doing unnecessary operations at each forward to retrieve the correct attention_mask, it could just be an output of the model & updated in GenerationMixin)

@fxmarty We do that by default (in this function) 🤔

@zucchini-nlp since you're interested in multimodal models: after this PR gets merged, can you inspect why llava (and related models) need this custom attention mask handling at generation time?

zucchini-nlp (Member) commented

@gante, I've also been affected by this while trying to make speculative decoding work with VLMs. I'll briefly outline what I found.

The custom mask is needed in Llava because it concatenates image embeddings and text, and the result is used as inputs_embeds for the LLM part. Thus, after pre-fill, the model requires an attention mask of much greater length than the one we got initially. When we update model_kwargs in generate, we assume one token is added to the previous attention mask. This is not the case for Llava, since the attention mask covers only the textual part, so the forward has to add mask entries for the image part on every call.
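For reference, the default model_kwargs update boils down to something like this simplified sketch (not the exact transformers source), which appends exactly one attended position per generated token:

import torch

def naive_mask_update(model_kwargs: dict) -> dict:
    # Simplified sketch of the default behaviour: grow the mask by one column of ones
    # per generated token, which only works if the mask already matches the cache length.
    attention_mask = model_kwargs["attention_mask"]
    model_kwargs["attention_mask"] = torch.cat(
        [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
    )
    return model_kwargs

model_kwargs = {"attention_mask": torch.ones(2, 5, dtype=torch.long)}
print(naive_mask_update(model_kwargs)["attention_mask"].shape)  # torch.Size([2, 6])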

Compared to other models, I have counted only 3 soft-prompt VLMs in transformers (Kosmos-2, Blip, Llava). The other two do not have this issue because they have a custom generate in which the embeddings are concatenated and new attention masks are prepared manually, and only then is text_model.generate() called for further handling. Llava, however, calls the composite model's generate() directly.

We can either fix Llava to call text_model.generate() for consistency with the others, or update model_kwargs depending on the past_key_values length in generate (see the sketch at the end of this comment). I was also thinking that, in the second case, we could get rid of _prepare_token_type_ids and _prepare_attention_mask in speculative decoding.

It's up to you to decide whether we need any changes; I have no idea how this could play out library-wide 😄
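A rough sketch of what the second option could look like (the cache layout, helper name, and call site here are assumptions for illustration, not existing transformers code):

import torch
import torch.nn.functional as F

def pad_mask_to_cache(attention_mask: torch.Tensor, past_key_values) -> torch.Tensor:
    # Hypothetical helper: extend the text-level mask so it covers every cached
    # position plus the newly generated token, assuming the legacy tuple cache
    # layout (batch, num_heads, seq_len, head_dim).
    past_length = past_key_values[0][0].shape[2]
    missing = past_length + 1 - attention_mask.shape[1]
    return F.pad(attention_mask, (0, missing), value=1) if missing > 0 else attention_mask

# Toy cache with one layer: keys/values of length 7 (3 text + 4 image positions).
dummy_cache = ((torch.zeros(1, 8, 7, 64), torch.zeros(1, 8, 7, 64)),)
text_mask = torch.ones(1, 4, dtype=torch.long)  # 3 text tokens + 1 generated token
print(pad_mask_to_cache(text_mask, dummy_cache).shape)  # torch.Size([1, 8])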

gante (Member) commented Mar 5, 2024

We can either fix Llava to call text_model.generate() for consistency with others or update model_kwargs depending on past_key_values length in generate. I was also thinking that if the second case, we can get rid of using _prepare_token_type_ids and _prepare_attention_mask in speculative decoding.

@zucchini-nlp I'd say to go with "call text_model.generate() for consistency", if that is feasible and results in a clean interface. generate is still quite rigid at the moment, so handling more cases will slow down ongoing projects. After we refactor generate to make it more flexible, we can revisit this decision :)

@@ -307,10 +307,50 @@ def test_small_model_integration_test_llama_batched_regression(self):

output = model.generate(**inputs, max_new_tokens=20)

EXPECTED_DECODED_TEXT = ['USER: \nWhat are the things I should be cautious about when I visit this place? What should I bring with me?\nASSISTANT: When visiting this serene location, one should be cautious about the weather conditions and potential', 'USER: \nWhat is this?\nASSISTANT: Two cats lying on a bed!\nUSER: \nAnd this?\nASSISTANT: A cat sleeping on a bed.'] # fmt: skip
EXPECTED_DECODED_TEXT = ['USER: \nWhat are the things I should be cautious about when I visit this place? What should I bring with me?\nASSISTANT: When visiting this place, which appears to be a dock or pier extending over a body of water', 'USER: \nWhat is this?\nASSISTANT: Two cats lying on a bed!\nUSER: \nAnd this?\nASSISTANT: A cat sleeping on a bed.'] # fmt: skip
fxmarty (Contributor, Author) commented Mar 18, 2024

This test itself was wrong. The input_ids is tensor([[32001, 32001, 32001, 1, 3148, 1001, 29901, 29871, 32000, 29871, 13, 5618, 526, 278, 2712, 306, 881, 367, 274, 1300, 2738, 1048, 746, 306, 6493, 445, 2058, 29973, 1724, 881, 306, 6963, 411, 592, 29973, 13, 22933, 9047, 13566, 29901], [ 1, 3148, 1001, 29901, 29871, 32000, 29871, 13, 5618, 338, 445, 29973, 13, 22933, 9047, 13566, 29901, 7803, 274, 1446, 19214, 373, 263, 6592, 29991, 13, 11889, 29901, 29871, 32000, 29871, 13, 2855, 445, 29973, 13, 22933, 9047, 13566, 29901]], device='cuda:0'), with 32000 being the image tokens. The first sequence visibly sees only the first image, so its output should not be different from the one in this test:

def test_small_model_integration_test_llama_batched(self):
    # Let's make sure we test the preprocessing to replace what is used
    model_id = "llava-hf/llava-1.5-7b-hf"
    model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf", load_in_4bit=True)
    processor = AutoProcessor.from_pretrained(model_id)
    prompts = [
        "USER: <image>\nWhat are the things I should be cautious about when I visit this place? What should I bring with me?\nASSISTANT:",
        "USER: <image>\nWhat is this?\nASSISTANT:",
    ]
    image1 = Image.open(requests.get("https://llava-vl.github.io/static/images/view.jpg", stream=True).raw)
    image2 = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)
    inputs = processor(prompts, images=[image1, image2], return_tensors="pt", padding=True)
    output = model.generate(**inputs, max_new_tokens=20)
    EXPECTED_DECODED_TEXT = ['USER: \nWhat are the things I should be cautious about when I visit this place? What should I bring with me?\nASSISTANT: When visiting this place, which appears to be a dock or pier extending over a body of water', 'USER: \nWhat is this?\nASSISTANT: The image features two cats lying down on a pink couch. One cat is located on'] # fmt: skip
    self.assertEqual(processor.batch_decode(output, skip_special_tokens=True), EXPECTED_DECODED_TEXT)
Should it, @younesbelkada?

https://huggingface.co/llava-hf/llava-1.5-7b-hf/blob/main/config.json#L6


self.assertEqual(processor.batch_decode(output, skip_special_tokens=True), EXPECTED_DECODED_TEXT)

@slow
def test_batched_generation(self):
fxmarty (Contributor, Author) commented

This test does not pass on main (garbage is generated instead).

Comment on lines +343 to +348
# 6. Mask out the embedding at padding positions, as we later use the past_key_value value to determine the non-attended tokens.
batch_indices, pad_indices = torch.where(input_ids == self.pad_token_id)
indices_to_mask = new_token_positions[batch_indices, pad_indices]

final_embedding[batch_indices, indices_to_mask] = 0

fxmarty (Contributor, Author) commented

Since batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0) is used later on, this step is necessary.
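A simplified analogue of why this zeroing matters (toy shapes, with a single bias-free linear layer standing in for the key projection whose cache is inspected):

import torch

torch.manual_seed(0)
hidden, seq = 8, 5
key_proj = torch.nn.Linear(hidden, hidden, bias=False)  # bias-free, like Llama-style attention projections

embeds = torch.randn(1, seq, hidden)
embeds[0, 0] = 0.0  # the padding position has been zeroed out (step 6 in this PR)

keys = key_proj(embeds)  # stand-in for the cached first-layer keys
batch_index, non_attended = torch.where(keys.float().sum(-1) == 0)
print(non_attended)  # tensor([0]) -> only the zeroed padding position is flagged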

Collaborator

I don't see why - could you explain a little bit more? The previous code didn't modify first_layer_past_key_value

@fxmarty requested review from ArthurZucker, amyeroberts and younesbelkada and removed the request for younesbelkada and amyeroberts, March 18, 2024 10:29
fxmarty (Contributor, Author) commented Mar 18, 2024

Llava slow tests all pass (apart from tests/models/llava/test_modeling_llava.py::LlavaForConditionalGenerationModelTest::test_cpu_offload which does not pass on main either).

gante (Member) left a comment

The changes look good to me 👍

younesbelkada (Contributor) left a comment

Thanks so much for the fix and the deep investigation! LGTM since the slow tests pass.

ArthurZucker (Collaborator) left a comment

LGTM, though there might be a simpler way to fix this.

Comment on lines +344 to +347
batch_indices, pad_indices = torch.where(input_ids == self.pad_token_id)
indices_to_mask = new_token_positions[batch_indices, pad_indices]

final_embedding[batch_indices, indices_to_mask] = 0
ArthurZucker (Collaborator) commented Mar 20, 2024

Suggested change
- batch_indices, pad_indices = torch.where(input_ids == self.pad_token_id)
- indices_to_mask = new_token_positions[batch_indices, pad_indices]
- final_embedding[batch_indices, indices_to_mask] = 0
+ new_token_positions = new_token_positions * (input_ids != self.pad_token_id)

Shouldn't this work as well? Given that we create new_token_positions with a cumsum, we could also do it earlier (right after the cumsum).

Collaborator

not sure which one will end up being faster

fxmarty (Contributor, Author) commented

maybe

@fxmarty mentioned this pull request on Mar 22, 2024.
@fxmarty changed the title from "Correct llava mask" to "Correct llava mask & fix missing setter for vocab_size" on Mar 22, 2024.
fxmarty (Contributor, Author) commented Mar 22, 2024

I added a setter to fix the regression in #29586 (comment) (RUN_SLOW=1 CUDA_VISIBLE_DEVICES=1 pytest tests/models/llava -s -vvvvv -k "test_small_model_integration_test" failing on main)

I wonder whether llava is even working in the release version; maybe not.

Comment on lines +150 to +153
@vocab_size.setter
def vocab_size(self, value):
    self._vocab_size = value

fxmarty (Contributor, Author) commented

Is this good, @NielsRogge @amyeroberts?

Contributor

Shouldn't we remove the @property decorator above? That fixed #29789 for me. Isn't a property meant for immutable things (whereas the vocabulary size can change)? cc @amyeroberts

Collaborator

Properties aren't only for immutable things (that's the point of having a setter). The property lets us emit a deprecation warning when vocab_size is accessed as an attribute.

Removing the @property decorator would mean that previous usage of config.vocab_size would break, and people would have to call config.vocab_size() instead.
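A minimal sketch of the pattern being described (illustrative class and warning text, not the actual LlavaConfig code):

import warnings

class ToyConfig:
    def __init__(self, vocab_size: int = 32000):
        self._vocab_size = vocab_size

    @property
    def vocab_size(self) -> int:
        # The getter is the hook for a deprecation warning on attribute access.
        warnings.warn("`vocab_size` is deprecated; illustrative message only.", FutureWarning)
        return self._vocab_size

    @vocab_size.setter
    def vocab_size(self, value: int) -> None:
        # The setter added in this PR keeps `config.vocab_size = n` working.
        self._vocab_size = value

config = ToyConfig()
config.vocab_size = 32064   # works thanks to the setter
print(config.vocab_size)    # warns, then returns 32064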

amyeroberts (Collaborator) left a comment

Thanks for digging into this and fixing!

Overall the changes look OK; I'm just a bit unclear on some of the reasoning behind the logic changes.

tests/models/llava/test_modeling_llava.py: outdated review comments (resolved)
@@ -344,6 +344,12 @@ def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, in
final_attention_mask |= image_to_overwrite
position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)

# 6. Mask out the embedding at padding positions, as we later use the past_key_value value to determine the non-attended tokens.
Collaborator

I don't quite understand this comment: why do we need to mask out here because of how the past_key_values are used?

fxmarty (Contributor, Author) commented

A bit later in the code (specifically at batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)), we inspect the past_key_values to find the non-attended tokens. This was not correct before, because the added step 6 was missing.

Overall, this would be worth a larger refactor that avoids regenerating full masks at every forward step during generation. This PR is just a hotfix.


fxmarty (Contributor, Author) commented Mar 22, 2024

@amyeroberts We need to set final_embedding to zero for

batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)

to do its job! Up to now it did not do what was expected.

What I meant in

It is also not clear to me why this custom masking is handled in the forward and not in GenerationMixin cc @gante (we are doing unnecessary operations at each forward to retrieve the correct attention_mask, it could just be an output of the model & updated in GenerationMixin)

is that this (i.e. the attention_mask reconstruction during decoding) could be avoided altogether if the mask were handled in generate instead, which is not the case in the current implementation.

edit: slow tests for llava pass

@fxmarty merged commit 13b2370 into main Mar 22, 2024
18 checks passed
@fxmarty deleted the fix-llava-mask branch March 22, 2024 11:57
amyeroberts pushed a commit that referenced this pull request Mar 22, 2024
* correct llava mask

* fix vipllava as wlel

* mask out embedding for padding tokens

* add test

* fix style

* add setter

* fix test on suggestion
itazap pushed a commit that referenced this pull request May 14, 2024
* correct llava mask

* fix vipllava as wlel

* mask out embedding for padding tokens

* add test

* fix style

* add setter

* fix test on suggestion

Successfully merging this pull request may close these issues.

LLaVa Left Padding Got Weird Results
8 participants