Correct llava mask & fix missing setter for vocab_size #29389

Merged: 8 commits, Mar 22, 2024
Changes from 4 commits
14 changes: 10 additions & 4 deletions src/transformers/models/llava/modeling_llava.py
@@ -340,6 +340,12 @@ def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, in
final_attention_mask |= image_to_overwrite
position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)

# 6. Mask out the embedding at padding positions, as we later use the past_key_value value to determine the non-attended tokens.
Collaborator:
I don't quite understand this comment: why do we need to mask out here because we use the past_key_values?

Contributor Author (fxmarty):
A bit later in the code (specifically here:

batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)
), we check the past_key_values to find unattended tokens. This was not correct before, because the step 6 added here was missing.

Overall, this would be worth a larger refactor that avoids regenerating the full masks at every forward step during generate. This PR is just a hotfix.

batch_indices, pad_indices = torch.where(input_ids == self.pad_token_id)
indices_to_mask = new_token_positions[batch_indices, pad_indices]

final_embedding[batch_indices, indices_to_mask] = 0
Comment on lines +348 to +351
ArthurZucker (Collaborator), Mar 20, 2024:

Suggested change:
- batch_indices, pad_indices = torch.where(input_ids == self.pad_token_id)
- indices_to_mask = new_token_positions[batch_indices, pad_indices]
- final_embedding[batch_indices, indices_to_mask] = 0
+ new_token_positions = new_token_positions * (input_ids != self.pad_token_id)

Should this not work as well? Given that we create the new_token_positions with a cumsum, we could also do it before (right after the cumsum).

Collaborator:

not sure which one will end up being faster

Contributor Author (fxmarty):

maybe
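To make the two options concrete, here is a minimal toy sketch (made-up shapes, values, and token ids; not the model code) of the fix as merged versus the suggested cumsum-time variant:

```python
import torch

# Toy values only: input_ids, new_token_positions and final_embedding
# are invented for illustration; 32001 stands in for the pad token id.
pad_token_id = 32001
input_ids = torch.tensor([[32001, 32001, 5, 6, 7]])      # left-padded sequence
new_token_positions = torch.tensor([[0, 1, 2, 3, 4]])    # toy cumsum result
final_embedding = torch.ones(1, 5, 4)                     # toy merged embeddings

# Option merged in this PR: zero the embedding rows that were written
# from pad tokens.
batch_indices, pad_indices = torch.where(input_ids == pad_token_id)
indices_to_mask = new_token_positions[batch_indices, pad_indices]
final_embedding[batch_indices, indices_to_mask] = 0

# Suggested alternative: map pad tokens' target positions to 0 right after
# the cumsum, before anything is scattered into final_embedding.
new_token_positions_alt = new_token_positions * (input_ids != pad_token_id)
```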


Comment on lines +347 to +352
Contributor Author (fxmarty):

As batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0) is used later, this is necessary.

Collaborator:

I don't see why - could you explain a little bit more? The previous code didn't modify first_layer_past_key_value
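To make the check being discussed concrete, a minimal sketch with toy shapes (not the model's actual cache layout), assuming the cached entries at padded positions are all zeros once step 6 zeroes their embeddings, which is the premise of the check:

```python
import torch

# Toy cache slice of shape (batch, heads, seq_len) = (2, 4, 6). Assume the
# first two positions of sequence 0 were written from zeroed (padded)
# embeddings, so their cached values are all zeros.
first_layer_past_key_value = torch.randn(2, 4, 6)
first_layer_past_key_value[0, :, :2] = 0

# Same pattern as the check in modeling_llava.py: positions whose values
# sum to zero over dim -2 are treated as non-attended (padding).
batch_index, non_attended_tokens = torch.where(
    first_layer_past_key_value.float().sum(-2) == 0
)
print(batch_index)          # tensor([0, 0])
print(non_attended_tokens)  # tensor([0, 1])
```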

if labels is None:
final_labels = None

@@ -444,11 +450,11 @@ def forward(
# Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)

- # Get the target length
- target_seqlen = first_layer_past_key_value.shape[-1] + 1
+ target_length = input_ids.shape[1]
+ past_length = first_layer_past_key_value.shape[-1]

extended_attention_mask = torch.ones(
- (attention_mask.shape[0], target_seqlen - attention_mask.shape[1]),
+ (attention_mask.shape[0], past_length),
dtype=attention_mask.dtype,
device=attention_mask.device,
)
@@ -463,7 +469,7 @@ def forward(
# Zero-out the places where we don't need to attend
extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0

- attention_mask = torch.cat((attention_mask, extended_attention_mask), dim=1)
+ attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1)
position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1

outputs = self.language_model(
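The shape bookkeeping behind the new mask construction above, sketched with toy sizes (batch of 2, a cache of 7 positions, one newly decoded token; all values are made up, not the model code):

```python
import torch

batch_size, past_length, target_length = 2, 7, 1

# Mask as it arrives during generate: past positions plus the new token.
attention_mask = torch.ones(batch_size, past_length + target_length, dtype=torch.long)

# Cache positions that came from zeroed (padded) embeddings, as found by
# the past_key_value check above.
new_batch_index = torch.tensor([0, 0])
new_non_attended_tokens = torch.tensor([0, 1])

# New logic: build a mask over the cached positions, zero the non-attended
# ones, then append only the last target_length columns of the incoming mask.
extended_attention_mask = torch.ones(batch_size, past_length, dtype=attention_mask.dtype)
extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0
attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1)
position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
print(attention_mask.shape, position_ids.squeeze(-1))  # torch.Size([2, 8]) tensor([5, 7])
```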
14 changes: 10 additions & 4 deletions src/transformers/models/vipllava/modeling_vipllava.py
@@ -348,6 +348,12 @@ def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, in
final_attention_mask |= image_to_overwrite
position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)

# 6. Mask out the embedding at padding positions, as we later use the past_key_value value to determine the non-attended tokens.
batch_indices, pad_indices = torch.where(input_ids == self.pad_token_id)
indices_to_mask = new_token_positions[batch_indices, pad_indices]

final_embedding[batch_indices, indices_to_mask] = 0

if labels is None:
final_labels = None

@@ -443,11 +449,11 @@ def forward(
# Sum all dimensions of head_dim (-1) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-1) == 0)

- # Get the target length
- target_seqlen = first_layer_past_key_value.shape[-2] + 1
+ target_length = input_ids.shape[1]
+ past_length = first_layer_past_key_value.shape[-1]

extended_attention_mask = torch.ones(
- (attention_mask.shape[0], target_seqlen - attention_mask.shape[1]),
+ (attention_mask.shape[0], past_length),
dtype=attention_mask.dtype,
device=attention_mask.device,
)
@@ -462,7 +468,7 @@ def forward(
# Zero-out the places where we don't need to attend
extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0

- attention_mask = torch.cat((attention_mask, extended_attention_mask), dim=1)
+ attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1)
position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1

outputs = self.language_model(
42 changes: 41 additions & 1 deletion tests/models/llava/test_modeling_llava.py
@@ -307,10 +307,50 @@ def test_small_model_integration_test_llama_batched_regression(self):

output = model.generate(**inputs, max_new_tokens=20)

- EXPECTED_DECODED_TEXT = ['USER: \nWhat are the things I should be cautious about when I visit this place? What should I bring with me?\nASSISTANT: When visiting this serene location, one should be cautious about the weather conditions and potential', 'USER: \nWhat is this?\nASSISTANT: Two cats lying on a bed!\nUSER: \nAnd this?\nASSISTANT: A cat sleeping on a bed.'] # fmt: skip
+ EXPECTED_DECODED_TEXT = ['USER: \nWhat are the things I should be cautious about when I visit this place? What should I bring with me?\nASSISTANT: When visiting this place, which appears to be a dock or pier extending over a body of water', 'USER: \nWhat is this?\nASSISTANT: Two cats lying on a bed!\nUSER: \nAnd this?\nASSISTANT: A cat sleeping on a bed.'] # fmt: skip
fxmarty (Contributor Author), Mar 18, 2024:

This test itself was wrong. The input_ids is tensor([[32001, 32001, 32001, 1, 3148, 1001, 29901, 29871, 32000, 29871, 13, 5618, 526, 278, 2712, 306, 881, 367, 274, 1300, 2738, 1048, 746, 306, 6493, 445, 2058, 29973, 1724, 881, 306, 6963, 411, 592, 29973, 13, 22933, 9047, 13566, 29901], [ 1, 3148, 1001, 29901, 29871, 32000, 29871, 13, 5618, 338, 445, 29973, 13, 22933, 9047, 13566, 29901, 7803, 274, 1446, 19214, 373, 263, 6592, 29991, 13, 11889, 29901, 29871, 32000, 29871, 13, 2855, 445, 29973, 13, 22933, 9047, 13566, 29901]], device='cuda:0'), with 32000 being the image token. The first sequence visibly sees only the first image, so the output should not be different from the one in test_small_model_integration_test_llama_batched, quoted below:

def test_small_model_integration_test_llama_batched(self):
# Let' s make sure we test the preprocessing to replace what is used
model_id = "llava-hf/llava-1.5-7b-hf"
model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf", load_in_4bit=True)
processor = AutoProcessor.from_pretrained(model_id)
prompts = [
"USER: <image>\nWhat are the things I should be cautious about when I visit this place? What should I bring with me?\nASSISTANT:",
"USER: <image>\nWhat is this?\nASSISTANT:",
]
image1 = Image.open(requests.get("https://llava-vl.github.io/static/images/view.jpg", stream=True).raw)
image2 = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)
inputs = processor(prompts, images=[image1, image2], return_tensors="pt", padding=True)
output = model.generate(**inputs, max_new_tokens=20)
EXPECTED_DECODED_TEXT = ['USER: \nWhat are the things I should be cautious about when I visit this place? What should I bring with me?\nASSISTANT: When visiting this place, which appears to be a dock or pier extending over a body of water', 'USER: \nWhat is this?\nASSISTANT: The image features two cats lying down on a pink couch. One cat is located on'] # fmt: skip
self.assertEqual(processor.batch_decode(output, skip_special_tokens=True), EXPECTED_DECODED_TEXT)
Should it, @younesbelkada?

https://huggingface.co/llava-hf/llava-1.5-7b-hf/blob/main/config.json#L6


self.assertEqual(processor.batch_decode(output, skip_special_tokens=True), EXPECTED_DECODED_TEXT)

@slow
def test_batched_generation(self):
Contributor Author (fxmarty):

This test does not pass on main (garbage is generated instead).

import requests
import torch
from PIL import Image

from transformers import AutoProcessor, LlavaForConditionalGeneration

with torch.device(torch_device):
model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf")

processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")

prompt1 = "<image>\n<image>\nUSER: What's the the difference of two images?\nASSISTANT:"
prompt2 = "<image>\nUSER: Describe the image.\nASSISTANT:"
prompt3 = "<image>\nUSER: Describe the image.\nASSISTANT:"
url1 = "https://images.unsplash.com/photo-1552053831-71594a27632d?q=80&w=3062&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D"
url2 = "https://images.unsplash.com/photo-1617258683320-61900b281ced?q=80&w=3087&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D"
image1 = Image.open(requests.get(url1, stream=True).raw)
image2 = Image.open(requests.get(url2, stream=True).raw)

inputs = processor(
text=[prompt1, prompt2, prompt3],
images=[image1, image2, image1, image2],
return_tensors="pt",
padding=True,
).to(torch_device)

model = model.eval()

EXPECTED_OUTPUT = [
"\n \nUSER: What's the the difference of two images?\nASSISTANT: In the two images, the primary difference is the presence of a small dog holding a flower in one",
"\nUSER: Describe the image.\nASSISTANT: The image features a small, fluffy dog sitting on a sidewalk. The dog is holding",
"\nUSER: Describe the image.\nASSISTANT: The image features a lone, adult llama standing on a grassy hill. The llama",
]

generate_ids = model.generate(**inputs, max_new_tokens=20)
outputs = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
self.assertEqual(outputs, EXPECTED_OUTPUT)

@slow
@require_bitsandbytes
def test_llava_index_error_bug(self):