-
Notifications
You must be signed in to change notification settings - Fork 26.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Expand inputs in processors for VLMs #30962
Merged
Merged
Changes from 36 commits
Commits
Show all changes
44 commits
Select commit
Hold shift + click to select a range
050657f
let it be
zucchini-nlp a67087e
draft
zucchini-nlp 1e2b873
should not have changed
zucchini-nlp 70145d4
add warnings
zucchini-nlp 16a6787
Merge remote-tracking branch 'upstream/main' into vlm_processors
zucchini-nlp 8472035
fix & add tests
zucchini-nlp 13af9e8
fix tests
zucchini-nlp 41d086f
ipnuts embeds cannot be passed with pixels
zucchini-nlp bf59ed6
more updates
zucchini-nlp 020e7ed
paligemma ready!
zucchini-nlp 3e0455c
minor typos
zucchini-nlp 674f16e
update blip-2
zucchini-nlp 42ae646
fix tests & raise error
zucchini-nlp b5259f2
Merge branch 'main' into vlm_processors
zucchini-nlp a6c50de
docstring
zucchini-nlp 4766e2e
add blip2 test
zucchini-nlp d46df90
Merge branch 'main' into vlm_processors
zucchini-nlp f74297b
tmp
zucchini-nlp 5fc8565
add image seq length to config
zucchini-nlp 1b4674a
update docstring
zucchini-nlp c3c130b
Merge branch 'main' into vlm_processors
zucchini-nlp 8438875
delete
zucchini-nlp bf9e637
fix tests
zucchini-nlp db1fa4f
fix blip
zucchini-nlp 246b06a
fix paligemma
zucchini-nlp 222bf9a
merge `main`
zucchini-nlp 5486215
out-of-place scatter
zucchini-nlp 78c4484
add llava-next-video
zucchini-nlp d60624e
Update src/transformers/models/blip_2/modeling_blip_2.py
zucchini-nlp 1973b39
remove tmp
zucchini-nlp a6e380f
merge `main`
zucchini-nlp 8e88d8b
codestyle
zucchini-nlp 689eed9
nits
zucchini-nlp 28e8054
more nits
zucchini-nlp 637e514
remove overriding in tests
zucchini-nlp be939d8
comprehension when merging video
zucchini-nlp 232eb7c
fix-copies
zucchini-nlp 385a617
revert changes for embeds test
zucchini-nlp 4831a7e
fix tests after making comprehension
zucchini-nlp 85fbff9
Update src/transformers/models/blip_2/processing_blip_2.py
zucchini-nlp 119178f
Update src/transformers/models/blip_2/processing_blip_2.py
zucchini-nlp 2451911
more updates
zucchini-nlp 414031e
fix tests
zucchini-nlp 8cfad20
Merge remote-tracking branch 'upstream/main' into vlm_processors
zucchini-nlp File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1767,12 +1767,25 @@ def forward( | |
language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device | ||
) | ||
inputs_embeds = self.language_model.get_input_embeddings()(input_ids) | ||
inputs_embeds = torch.cat([language_model_inputs, inputs_embeds.to(language_model_inputs.device)], dim=1) | ||
|
||
if attention_mask is None: | ||
attention_mask = torch.ones_like(input_ids) | ||
expected_device = language_model_attention_mask.device | ||
attention_mask = torch.cat([language_model_attention_mask, attention_mask.to(expected_device)], dim=1) | ||
|
||
# if the model already has "image_token_index" then the input is expanded to account for image embeds | ||
# otherwise we expand manually by concating | ||
if hasattr(self.config, "image_token_index"): | ||
special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds) | ||
language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype) | ||
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs) | ||
else: | ||
logger.warning_once( | ||
"Expanding inputs for image tokens in BLIP-2 should be done in processing. " | ||
"Please follow instruction here (https://gist.github.com/zucchini-nlp/e9f20b054fa322f84ac9311d9ab67042) to update your BLIP-2 model. " | ||
"Using processors without these attributes in the config is deprecated and will throw an error in v4.44." | ||
) | ||
inputs_embeds = torch.cat([language_model_inputs, inputs_embeds.to(language_model_inputs.device)], dim=1) | ||
attention_mask = torch.cat( | ||
[language_model_attention_mask, attention_mask.to(language_model_attention_mask.device)], dim=1 | ||
) | ||
|
||
if self.config.use_decoder_only_language_model: | ||
outputs = self.language_model( | ||
|
@@ -1876,13 +1889,25 @@ def generate( | |
.repeat(batch_size, 1) | ||
.to(image_embeds.device) | ||
) | ||
inputs_embeds = self.get_input_embeddings()(input_ids) | ||
if attention_mask is None: | ||
attention_mask = torch.ones_like(input_ids) | ||
attention_mask = torch.cat([language_attention_mask, attention_mask.to(language_attention_mask.device)], dim=1) | ||
|
||
# concatenate query embeddings with prompt embeddings | ||
inputs_embeds = self.get_input_embeddings()(input_ids) | ||
inputs_embeds = torch.cat([language_model_inputs, inputs_embeds.to(language_model_inputs.device)], dim=1) | ||
# if the model already has "image_token_index" then the input is expanded to account for image embeds | ||
# otherwise we expand manually by concatenating | ||
if hasattr(self.config, "image_token_index"): | ||
special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds) | ||
inputs_embeds[special_image_mask] = language_model_inputs.flatten() | ||
else: | ||
logger.warning_once( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We should make sure the official checkpoints have been updated this way |
||
"Expanding inputs for image tokens in BLIP-2 should be done in processing. " | ||
"Please follow instruction here (https://gist.github.com/zucchini-nlp/e9f20b054fa322f84ac9311d9ab67042) to update your BLIP-2 model. " | ||
"Using processors without these attributes in the config is deprecated and will throw an error in v4.44." | ||
) | ||
inputs_embeds = torch.cat([language_model_inputs, inputs_embeds.to(language_model_inputs.device)], dim=1) | ||
attention_mask = torch.cat( | ||
[language_attention_mask, attention_mask.to(language_attention_mask.device)], dim=1 | ||
) | ||
|
||
# add image_embeds length to max_length, so that the final max_length in counted only on token embeds | ||
# -1 is to account for the prepended BOS after `generate.` | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe this, or another version number as we're in 4.44 dev version
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, will update accordingly when we get one step away from merging. I think two-three major versions from current one will work :)