Expand inputs in processors for VLMs #30962
Conversation
        )

        return BatchFeature(data={**text_inputs, **image_inputs})

    def _get_number_of_features(self, height: int, width: int) -> int:
Mostly copied from TGI with minor changes in calculations for unpadding, otherwise it won't work for low resolution images
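For readers following along, here is a rough, self-contained sketch of the kind of feature-count computation being discussed. It is not the PR's exact code: the tiling factor, default crop size, and names are assumptions for illustration of how unpadding affects the number of image tokens.

```python
def get_number_of_features(height: int, width: int, patch_size: int = 14,
                           crop_height: int = 336, crop_width: int = 336) -> int:
    # Illustrative sketch only: count how many placeholder tokens one image expands
    # to when the vision tower produces a base view plus a high-resolution grid
    # (LLaVA-NeXT style). Defaults and the 2x2 tiling below are assumptions.
    patches_h = crop_height // patch_size
    patches_w = crop_width // patch_size

    # The base (low-resolution) view always contributes a full grid of patches.
    base_features = patches_h * patches_w

    # High-resolution grid: the image is tiled, then padding rows/columns added to
    # preserve the aspect ratio are removed again ("unpadding").
    current_h, current_w = patches_h * 2, patches_w * 2  # assume a 2x2 tiling
    if width / height > current_w / current_h:
        new_h = (height * current_w) // width
        current_h -= 2 * ((current_h - new_h) // 2)
    else:
        new_w = (width * current_h) // height
        current_w -= 2 * ((current_w - new_w) // 2)

    unpadded_features = current_h * current_w
    newline_features = current_h  # one newline token per remaining row
    return unpadded_features + newline_features + base_features
```

With integer math like this, very small images can round the unpadded grid down to almost nothing, which is why the calculation needs care for low-resolution inputs.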
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Looking forward to seeing this expanded to other VLMs! Some might be trickier - PaliGemma incorporates causal mask computation in the merge method, for instance (thought about that when reading) - but it makes sense that most of this should belong in the processor, not the modeling.
@amyeroberts I did some clean-up after Arthur's comments. Requesting review, should be ready. If this works I will expand the logic to BLIP and PaliGemma in the next few weeks. What changed:
Thanks for working on this - will be great to have some of this logic unified!
Main comment is about how we set the required arguments for processing in the processor
@amyeroberts addressed the comments and added all VLMs to the PR (excluding Idefics, Fuyu and Kosmos as those already have expansion in processing).
Wow - a big piece of work!
Overall looks good to me, just a few comments here and there. I'd like to have a second review from @molbap and a run on the slow tests for all the models touched here
            special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds)
            inputs_embeds[special_image_mask] = language_model_inputs.flatten()
        else:
            logger.warning_once(
We should make sure the official checkpoints have been updated this way
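As an illustration of the kind of checkpoint update being referred to, a hypothetical sketch is below; the attribute names, token string, and repo id are assumptions, and the gist linked in the warning is the authoritative guide.

```python
from transformers import AutoConfig, AutoProcessor

# Hypothetical example checkpoint; any BLIP-2 style repo would follow a similar pattern.
model_id = "Salesforce/blip2-opt-2.7b"

config = AutoConfig.from_pretrained(model_id)
processor = AutoProcessor.from_pretrained(model_id)

# Register an <image> placeholder token and record its id in the config, so the model
# can scatter vision embeddings into the pre-expanded input ids.
processor.tokenizer.add_special_tokens({"additional_special_tokens": ["<image>"]})
config.image_token_index = processor.tokenizer.convert_tokens_to_ids("<image>")

# In practice the model's token embeddings would also need resizing to cover the new
# token; see the linked gist for the full procedure.
config.save_pretrained("./blip2-updated")
processor.save_pretrained("./blip2-updated")
```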
It's a biiig piece of work, nicely done, tests and all! I left a few comments on some things I didn't understand well + paligemma masking in particular
            logger.warning_once(
                "Expanding inputs for image tokens in BLIP-2 should be done in processing. "
                "Please follow instruction here (https://gist.github.com/zucchini-nlp/e9f20b054fa322f84ac9311d9ab67042) to update your BLIP-2 model. "
                "Using processors without these attributes in the config is deprecated and will throw an error in v4.44."
Maybe this, or another version number, as we're on the 4.44 dev version
"Using processors without these attributes in the config is deprecated and will throw an error in v4.44." | |
"Using processors without these attributes in the config is deprecated and will throw an error in a later version" |
Yes, will update accordingly when we get one step away from merging. I think two or three major versions from the current one will work :)
Co-authored-by: Pablo Montalvo <[email protected]>
This should be done - addressed the comments. For the failing test, I have no idea how to skip it after deprecating a property from the config.
Alright cool, taking a look soon! For the config option, a quick&dirty solution could be to do something like
LGTM! Some minor comments remaining but seems good
        final_labels = torch.full(
            (batch_size, sequence_length), self.config.ignore_index, dtype=input_ids.dtype, device=input_ids.device
        )
        final_labels = torch.where(input_ids != self.pad_token_id, labels, final_labels)
If the labels are not defined in the same way - i.e. not nulled where padding tokens are - it'll break BC for existing FT scripts, right?
Hmm, do we not expect users to ignore padding while preparing the labels? We can bring this back for BC but afaik the general rule is that LLMs don't mask out pad tokens in labels
Brought back the masking, and added a warning that users should mask labels themselves.
Thanks! During training, padding for uneven batches is definitely masked in labels, iiuc.
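For illustration, a minimal sketch of the label masking being discussed, assuming the conventional -100 ignore index and a tokenizer-produced attention mask (the toy values are made up):

```python
import torch

# Toy batch: two sequences padded to length 6; -100 is the conventional ignore index
# that the loss functions in transformers skip.
input_ids = torch.tensor([[5, 6, 7, 0, 0, 0],
                          [5, 6, 7, 8, 9, 2]])
attention_mask = torch.tensor([[1, 1, 1, 0, 0, 0],
                               [1, 1, 1, 1, 1, 1]])

labels = input_ids.clone()
labels[attention_mask == 0] = -100  # no loss is computed on padding positions
```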
    def test_inputs_embeds_matches_input_ids(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            model.to(torch_device)
            model.eval()

            inputs = self._prepare_for_class(inputs_dict, model_class)
            input_ids = inputs["input_ids"]
            del inputs["input_ids"]
            del inputs["pixel_values"]

            inputs_embeds = model.get_input_embeddings()(input_ids)

            with torch.no_grad():
                out_ids = model(input_ids=input_ids, **inputs)[0]
                out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0]
            self.assertTrue(torch.allclose(out_embeds, out_ids))
Interesting, same remark - would be worth having this in common, or an option that checks the needed inputs for a given model to do this del on-demand? (nit)
Yeah, I tried, but it seems like some models require both ids and pixels, while others require only one. Will have to think about unifying Vision2Seq model tests somehow, in the scope of another PR.
Shouldn't be a big problem to repeat the code in tests.
del inputs["input_ids"] | ||
del inputs["pixel_values"] | ||
|
||
wte = model.get_input_embeddings() |
same remark for wte and the transformers version warning, to modify before merge!
Co-authored-by: Pablo Montalvo <[email protected]>
Co-authored-by: Pablo Montalvo <[email protected]>
Looks great - thanks for handling all of this!
        qformer_config=None,
        text_config=None,
        num_query_tokens=32,
        image_token_index=None,
index or token_id? Index would indicate a specific location, but the logic in 1776 looks like it's matching token_ids
It's the token index, same as in all LLaVA models.
I'll run slow tests and check everything is okay, will merge sometime next week.
* let it be
* draft
* should not have changed
* add warnings
* fix & add tests
* fix tests
* ipnuts embeds cannot be passed with pixels
* more updates
* paligemma ready!
* minor typos
* update blip-2
* fix tests & raise error
* docstring
* add blip2 test
* tmp
* add image seq length to config
* update docstring
* delete
* fix tests
* fix blip
* fix paligemma
* out-of-place scatter
* add llava-next-video
* Update src/transformers/models/blip_2/modeling_blip_2.py
  Co-authored-by: Pablo Montalvo <[email protected]>
* remove tmp
* codestyle
* nits
* more nits
* remove overriding in tests
* comprehension when merging video
* fix-copies
* revert changes for embeds test
* fix tests after making comprehension
* Update src/transformers/models/blip_2/processing_blip_2.py
  Co-authored-by: Pablo Montalvo <[email protected]>
* Update src/transformers/models/blip_2/processing_blip_2.py
  Co-authored-by: Pablo Montalvo <[email protected]>
* more updates
* fix tests
---------
Co-authored-by: Pablo Montalvo <[email protected]>
What does this PR do?
Fixes #30809. This PR moves the _merge_inputs_with_vision_embeds logic into processing, making VLMs more versatile in terms of generation strategies. All models were tested locally with different batch sizes and image resolutions; generation is the same as it was before the changes. The main idea is to compute the sequence length for the image features inside the processing files and expand the input ids by repeating the special image token. The same is already done for IDEFICS in transformers.
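To make the idea concrete, here is a minimal sketch of what such an expansion can look like in a processor. The helper name, token string, and sequence length are assumptions for illustration; each model's processor computes its own image sequence length.

```python
def expand_image_tokens(text: str, image_seq_length: int, image_token: str = "<image>") -> str:
    # Replace each single <image> placeholder the user wrote with as many copies as the
    # vision tower will produce embeddings for, so the tokenized input ids already
    # contain one slot per image feature.
    return text.replace(image_token, image_token * image_seq_length)

prompt = "USER: <image>\nWhat is shown in this image? ASSISTANT:"
expanded = expand_image_tokens(prompt, image_seq_length=576)  # e.g. a 24x24 patch grid
# The tokenizer then encodes `expanded`, and the model scatters image embeddings into
# the positions where input_ids == config.image_token_index.
```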