Expand inputs in processors for VLMs #30962

Merged: 44 commits, merged into huggingface:main on Aug 13, 2024

Conversation

@zucchini-nlp (Member) commented May 22, 2024

What does this PR do?

Fixes #30809. This PR moves _merge_inputs_with_vision_embeds into the processing logic, making VLMs more versatile in terms of generation strategies. All models were tested locally with different batch sizes and image resolutions; generation is the same as it was before the changes.

The main idea is to compute the sequence length of the image features inside the processing files and expand the input ids by repeating the special image token. The same is already done for IDEFICS in transformers.
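
For readers new to the approach, here is a minimal sketch of the expansion idea (the prompt, placeholder token, and feature count below are illustrative assumptions, not code from this PR):

# The processor knows how many feature vectors the vision tower will produce
# for a given image (e.g. a 24x24 patch grid for a 336px input), so it repeats
# the placeholder token that many times before tokenization.
prompt = "USER: <image>\nWhat is shown in this image? ASSISTANT:"
num_image_features = 576
expanded_prompt = prompt.replace("<image>", "<image>" * num_image_features)

The model side then only has to scatter the image embeddings into the positions holding the placeholder token, instead of merging and re-padding the sequences itself.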

        )

        return BatchFeature(data={**text_inputs, **image_inputs})

    def _get_number_of_features(self, height: int, width: int) -> int:
@zucchini-nlp (Member Author):

Mostly copied from TGI, with minor changes in the calculations for unpadding; otherwise it won't work for low-resolution images.
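
For context, a simplified sketch of what such a helper computes (this ignores the unpadding adjustments mentioned above, and the default patch size is an assumption):

def get_number_of_image_features(height: int, width: int, patch_size: int = 14) -> int:
    # Each patch_size x patch_size tile of the (resized) image becomes one
    # feature vector from the vision tower; the real method additionally
    # corrects for unpadding of the resized grid, which is the TGI-derived part.
    return (height // patch_size) * (width // patch_size)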

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@molbap (Contributor) commented May 24, 2024

Looking forward to seeing this expanded to other VLMs! Some might be trickier; PaliGemma, for instance, incorporates causal mask computation in the merge method (thought about that while reading), but it makes sense that most of this should belong in the processor, not the modeling code.

@zucchini-nlp (Member Author)

@amyeroberts I did some clean-up after Arthur's comments. Requesting review; this should be ready. If this works, I will expand the logic to BLIP and PaliGemma in the next weeks.

What changed:

  • The model can generate from both old inputs and new expanded inputs. If it receives old inputs, a warning is raised asking the user to upgrade the processor config.
  • The processor can also return both types. If it has all the necessary parts to calculate the image embedding length, the inputs are expanded. Otherwise, a warning is raised and the old behavior is retained.
  • The old behavior is planned to be removed entirely in v4.44 (or better, v4.43?).
  • Added tests to check that generation from old vs. new inputs is identical.
  • To actually have llava-based models work in the new style, I'll later update all hf-llava configs on the Hub. Other models on the Hub will continue to work with the old behavior.

@zucchini-nlp zucchini-nlp marked this pull request as ready for review May 29, 2024 13:34
@amyeroberts (Collaborator) left a comment

Thanks for working on this - will be great to have some of this logic unified!

My main comment is about how we set the required arguments for processing in the processor.

src/transformers/models/llava/processing_llava.py (outdated, resolved)
src/transformers/models/llava/processing_llava.py (outdated, resolved)
src/transformers/models/llava/processing_llava.py (outdated, resolved)
tests/models/llava/test_modeling_llava.py (outdated, resolved)
@zucchini-nlp (Member Author) commented Jun 10, 2024

@amyeroberts I addressed the comments and added all VLMs to the PR (excluding Idefics, Fuyu, and Kosmos, as those already have expansion in processing).

  • The warning text is clearer, and it's easy for users to add new attributes to the Processor class (with processor.patch_size = patch_size; see the sketch after this list).
  • BLIP-2 needed more modifications as it didn't have a special image token; let me know if the way I did it works.
  • PaliGemma worked out of the box but needed changes for the causal mask. There's also something weird with "position_ids" which will be fixed by @molbap.
  • All models have their "old vs. new format equivalence" tests, and they pass locally. I don't know how to make the failing doctest happy; it's still red even after I deprecated the unused attribute.
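
For illustration, the kind of upgrade the warning asks users to do might look like this (the checkpoint name and the vision_feature_select_strategy attribute are assumptions; only patch_size is named above):

from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
# Give the processor what it needs to compute the image feature length itself,
# so it can expand the image token at processing time instead of in the model.
processor.patch_size = 14
processor.vision_feature_select_strategy = "default"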

@amyeroberts (Collaborator) left a comment

Wow - a big piece of work!

Overall looks good to me, just a few comments here and there. I'd like a second review from @molbap and a run of the slow tests for all the models touched here.

src/transformers/models/vipllava/modeling_vipllava.py (outdated, resolved)
            special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds)
            inputs_embeds[special_image_mask] = language_model_inputs.flatten()
        else:
            logger.warning_once(
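
As a self-contained toy version of the branch above (shapes and the placeholder id 99 are made up for illustration):

import torch

hidden_size = 4
input_ids = torch.tensor([[1, 99, 99, 7]])               # 99 plays the role of image_token_index
inputs_embeds = torch.zeros(1, 4, hidden_size)           # text embeddings (zeros for clarity)
language_model_inputs = torch.ones(1, 2, hidden_size)    # two image features from the vision side

# Mark every position that holds the placeholder token, broadcast the mask over
# the hidden dimension, and write the image features into exactly those slots.
special_image_mask = (input_ids == 99).unsqueeze(-1).expand_as(inputs_embeds)
inputs_embeds[special_image_mask] = language_model_inputs.flatten()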
@amyeroberts (Collaborator):

We should make sure the official checkpoints have been updated this way

@molbap molbap self-requested a review July 31, 2024 15:49
@molbap (Contributor) left a comment

It's a biiig piece of work, nicely done, tests and all! I left a few comments on some things I didn't understand well, plus the PaliGemma masking in particular.

            logger.warning_once(
                "Expanding inputs for image tokens in BLIP-2 should be done in processing. "
                "Please follow instruction here (https://gist.github.com/zucchini-nlp/e9f20b054fa322f84ac9311d9ab67042) to update your BLIP-2 model. "
                "Using processors without these attributes in the config is deprecated and will throw an error in v4.44."
@molbap (Contributor):

Maybe this, or another version number, as we're already on the 4.44 dev version.

Suggested change
"Using processors without these attributes in the config is deprecated and will throw an error in v4.44."
"Using processors without these attributes in the config is deprecated and will throw an error in a later version"

@zucchini-nlp (Member Author):

Yes, will update accordingly when we get one step away from merging. I think two to three major versions from the current one will work :)

src/transformers/models/blip_2/modeling_blip_2.py (outdated, resolved)
src/transformers/models/blip_2/processing_blip_2.py (outdated, resolved)
src/transformers/models/llava/processing_llava.py (outdated, resolved)
temp.py (outdated, resolved)
tests/models/blip_2/test_modeling_blip_2.py (outdated, resolved)
tests/models/llava/test_modeling_llava.py (outdated, resolved)
@zucchini-nlp (Member Author)

This should be done; I addressed the comments. For the failing test, I have no idea how to skip it after deprecating a property from the config.

@molbap (Contributor) commented Aug 6, 2024

Alright, cool, taking a look soon! For the config option, a quick & dirty solution could be to do something like _ = config.ignore_index in the modeling?

@molbap (Contributor) left a comment

LGTM! Some minor comments remaining, but it seems good.

src/transformers/models/blip_2/processing_blip_2.py (outdated, resolved)
src/transformers/models/blip_2/processing_blip_2.py (outdated, resolved)
Comment on lines -347 to -350
        final_labels = torch.full(
            (batch_size, sequence_length), self.config.ignore_index, dtype=input_ids.dtype, device=input_ids.device
        )
        final_labels = torch.where(input_ids != self.pad_token_id, labels, final_labels)
@molbap (Contributor):

If the labels are not defined in the same way, i.e. not nulled where the padding tokens are, it'll break BC for existing FT scripts, right?

@zucchini-nlp (Member Author):

Hmm, do we not expect users to ignore padding while preparing the labels? We can bring this back for BC, but AFAIK the general rule is that LLMs don't mask out pad tokens in labels.

@zucchini-nlp (Member Author):

Brought back the masking and added a warning that users should mask the labels themselves.

@molbap (Contributor):

Thanks! During training, padding for uneven batches is definitely masked in the labels, IIUC.
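
For readers following along, a minimal sketch of the convention being discussed (the -100 ignore index is the usual CrossEntropyLoss default; the helper name is made up):

import torch

def mask_padding_in_labels(input_ids: torch.Tensor, pad_token_id: int, ignore_index: int = -100) -> torch.Tensor:
    # Use the input ids as labels, but null out padding positions so the loss
    # ignores them (CrossEntropyLoss skips targets equal to ignore_index).
    labels = input_ids.clone()
    labels[input_ids == pad_token_id] = ignore_index
    return labels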

Comment on lines +212 to +230
def test_inputs_embeds_matches_input_ids(self):
    config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

    for model_class in self.all_model_classes:
        model = model_class(config)
        model.to(torch_device)
        model.eval()

        inputs = self._prepare_for_class(inputs_dict, model_class)
        input_ids = inputs["input_ids"]
        del inputs["input_ids"]
        del inputs["pixel_values"]

        inputs_embeds = model.get_input_embeddings()(input_ids)

        with torch.no_grad():
            out_ids = model(input_ids=input_ids, **inputs)[0]
            out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0]
        self.assertTrue(torch.allclose(out_embeds, out_ids))
@molbap (Contributor):

Interesting, same remark: would it be worth having this in common, or an option that checks the needed inputs for a given model so this del happens on demand? (nit)

@zucchini-nlp (Member Author) Aug 8, 2024

Yeah, I tried, but it seems some models require both ids and pixel values, while others require only one. I'll have to think about unifying the Vision2Seq model tests somehow, in the scope of another PR.

Shouldn't be a big problem to repeat the code in the tests.

del inputs["input_ids"]
del inputs["pixel_values"]

wte = model.get_input_embeddings()
@molbap (Contributor):

Same remark for wte and the transformers version warning, to modify before merge!

@amyeroberts (Collaborator) left a comment

Looks great - thanks for handling all of this!

qformer_config=None,
text_config=None,
num_query_tokens=32,
image_token_index=None,
@amyeroberts (Collaborator):

index or token_id? Index would indicate a specific location, but the logic at line 1776 looks like it's matching token ids.

@zucchini-nlp (Member Author):

It's the token index, the same as in all llava models.

@zucchini-nlp (Member Author)

I'll run the slow tests and check that everything is okay; will merge some time next week.

@zucchini-nlp zucchini-nlp merged commit a29eabd into huggingface:main Aug 13, 2024
21 checks passed
alaskar-10r pushed a commit to alaskar-10r/transformers that referenced this pull request Aug 13, 2024
* let it be

* draft

* should not have changed

* add warnings

* fix & add tests

* fix tests

* ipnuts embeds cannot be passed with pixels

* more updates

* paligemma ready!

* minor typos

* update blip-2

* fix tests & raise error

* docstring

* add blip2 test

* tmp

* add image seq length to config

* update docstring

* delete

* fix tests

* fix blip

* fix paligemma

* out-of-place scatter

* add llava-next-video

* Update src/transformers/models/blip_2/modeling_blip_2.py

Co-authored-by: Pablo Montalvo <[email protected]>

* remove tmp

* codestyle

* nits

* more nits

* remove overriding in tests

* comprehension when merging video

* fix-copies

* revert changes for embeds test

* fix tests after making comprehension

* Update src/transformers/models/blip_2/processing_blip_2.py

Co-authored-by: Pablo Montalvo <[email protected]>

* Update src/transformers/models/blip_2/processing_blip_2.py

Co-authored-by: Pablo Montalvo <[email protected]>

* more updates

* fix tests

---------

Co-authored-by: Pablo Montalvo <[email protected]>
Labels: run-slow, WIP
Projects: None yet
4 participants