Expand inputs in processors for VLMs #30962

Merged: 44 commits, Aug 13, 2024. Changes shown from 36 of the 44 commits.

Commits
050657f  let it be (zucchini-nlp, May 20, 2024)
a67087e  draft (zucchini-nlp, May 22, 2024)
1e2b873  should not have changed (zucchini-nlp, May 22, 2024)
70145d4  add warnings (zucchini-nlp, May 29, 2024)
16a6787  Merge remote-tracking branch 'upstream/main' into vlm_processors (zucchini-nlp, May 29, 2024)
8472035  fix & add tests (zucchini-nlp, May 29, 2024)
13af9e8  fix tests (zucchini-nlp, May 29, 2024)
41d086f  ipnuts embeds cannot be passed with pixels (zucchini-nlp, May 29, 2024)
bf59ed6  more updates (zucchini-nlp, Jun 7, 2024)
020e7ed  paligemma ready! (zucchini-nlp, Jun 10, 2024)
3e0455c  minor typos (zucchini-nlp, Jun 10, 2024)
674f16e  update blip-2 (zucchini-nlp, Jun 10, 2024)
42ae646  fix tests & raise error (zucchini-nlp, Jun 10, 2024)
b5259f2  Merge branch 'main' into vlm_processors (zucchini-nlp, Jun 10, 2024)
a6c50de  docstring (zucchini-nlp, Jun 10, 2024)
4766e2e  add blip2 test (zucchini-nlp, Jun 10, 2024)
d46df90  Merge branch 'main' into vlm_processors (zucchini-nlp, Jun 10, 2024)
f74297b  tmp (zucchini-nlp, Jun 17, 2024)
5fc8565  add image seq length to config (zucchini-nlp, Jun 18, 2024)
1b4674a  update docstring (zucchini-nlp, Jun 18, 2024)
c3c130b  Merge branch 'main' into vlm_processors (zucchini-nlp, Jun 18, 2024)
8438875  delete (zucchini-nlp, Jun 18, 2024)
bf9e637  fix tests (zucchini-nlp, Jun 18, 2024)
db1fa4f  fix blip (zucchini-nlp, Jun 18, 2024)
246b06a  fix paligemma (zucchini-nlp, Jun 21, 2024)
222bf9a  merge `main` (zucchini-nlp, Jul 18, 2024)
5486215  out-of-place scatter (zucchini-nlp, Jul 18, 2024)
78c4484  add llava-next-video (zucchini-nlp, Jul 18, 2024)
d60624e  Update src/transformers/models/blip_2/modeling_blip_2.py (zucchini-nlp, Aug 5, 2024)
1973b39  remove tmp (zucchini-nlp, Aug 5, 2024)
a6e380f  merge `main` (zucchini-nlp, Aug 5, 2024)
8e88d8b  codestyle (zucchini-nlp, Aug 5, 2024)
689eed9  nits (zucchini-nlp, Aug 6, 2024)
28e8054  more nits (zucchini-nlp, Aug 6, 2024)
637e514  remove overriding in tests (zucchini-nlp, Aug 6, 2024)
be939d8  comprehension when merging video (zucchini-nlp, Aug 6, 2024)
232eb7c  fix-copies (zucchini-nlp, Aug 6, 2024)
385a617  revert changes for embeds test (zucchini-nlp, Aug 6, 2024)
4831a7e  fix tests after making comprehension (zucchini-nlp, Aug 6, 2024)
85fbff9  Update src/transformers/models/blip_2/processing_blip_2.py (zucchini-nlp, Aug 8, 2024)
119178f  Update src/transformers/models/blip_2/processing_blip_2.py (zucchini-nlp, Aug 8, 2024)
2451911  more updates (zucchini-nlp, Aug 8, 2024)
414031e  fix tests (zucchini-nlp, Aug 8, 2024)
8cfad20  Merge remote-tracking branch 'upstream/main' into vlm_processors (zucchini-nlp, Aug 9, 2024)
41 changes: 33 additions & 8 deletions src/transformers/models/blip_2/modeling_blip_2.py
@@ -1767,12 +1767,25 @@ def forward(
language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device
)
inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
inputs_embeds = torch.cat([language_model_inputs, inputs_embeds.to(language_model_inputs.device)], dim=1)

if attention_mask is None:
attention_mask = torch.ones_like(input_ids)
expected_device = language_model_attention_mask.device
attention_mask = torch.cat([language_model_attention_mask, attention_mask.to(expected_device)], dim=1)

# if the model already has "image_token_index" then the input is expanded to account for image embeds
# otherwise we expand manually by concatenating
if hasattr(self.config, "image_token_index"):
special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds)
language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs)
else:
logger.warning_once(
"Expanding inputs for image tokens in BLIP-2 should be done in processing. "
"Please follow instruction here (https://gist.github.com/zucchini-nlp/e9f20b054fa322f84ac9311d9ab67042) to update your BLIP-2 model. "
"Using processors without these attributes in the config is deprecated and will throw an error in v4.44."
Review comment (Contributor):
Maybe this, or another version number, as we're in the 4.44 dev version.

Suggested change:
- "Using processors without these attributes in the config is deprecated and will throw an error in v4.44."
+ "Using processors without these attributes in the config is deprecated and will throw an error in a later version"

Reply (Member, author):
Yes, will update accordingly when we get one step away from merging. I think two or three major versions from the current one will work :)

)
inputs_embeds = torch.cat([language_model_inputs, inputs_embeds.to(language_model_inputs.device)], dim=1)
attention_mask = torch.cat(
[language_model_attention_mask, attention_mask.to(language_model_attention_mask.device)], dim=1
)

if self.config.use_decoder_only_language_model:
outputs = self.language_model(
@@ -1876,13 +1889,25 @@ def generate(
.repeat(batch_size, 1)
.to(image_embeds.device)
)
inputs_embeds = self.get_input_embeddings()(input_ids)
if attention_mask is None:
attention_mask = torch.ones_like(input_ids)
attention_mask = torch.cat([language_attention_mask, attention_mask.to(language_attention_mask.device)], dim=1)

# concatenate query embeddings with prompt embeddings
inputs_embeds = self.get_input_embeddings()(input_ids)
inputs_embeds = torch.cat([language_model_inputs, inputs_embeds.to(language_model_inputs.device)], dim=1)
# if the model already has "image_token_index" then the input is expanded to account for image embeds
# otherwise we expand manually by concatenating
if hasattr(self.config, "image_token_index"):
special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds)
inputs_embeds[special_image_mask] = language_model_inputs.flatten()
else:
logger.warning_once(
Review comment (Collaborator):
We should make sure the official checkpoints have been updated this way

"Expanding inputs for image tokens in BLIP-2 should be done in processing. "
"Please follow instruction here (https://gist.github.com/zucchini-nlp/e9f20b054fa322f84ac9311d9ab67042) to update your BLIP-2 model. "
"Using processors without these attributes in the config is deprecated and will throw an error in v4.44."
)
inputs_embeds = torch.cat([language_model_inputs, inputs_embeds.to(language_model_inputs.device)], dim=1)
attention_mask = torch.cat(
[language_attention_mask, attention_mask.to(language_attention_mask.device)], dim=1
)

# add image_embeds length to max_length, so that the final max_length is counted only on token embeds
# -1 is to account for the prepended BOS after `generate`.
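
The heart of the modeling change is the branch on `image_token_index`: when the processor has already expanded the prompt with `<image>` placeholders, the Q-Former outputs are written into exactly those positions instead of being concatenated in front of the text, so the sequence length and attention mask stay consistent with what the processor produced. A minimal, self-contained sketch of the two paths with toy tensors (the token id, shapes, and filler values below are made up for illustration and are not taken from the PR):

import torch

IMAGE_TOKEN_ID = 32000                      # hypothetical placeholder id
hidden_size, num_query_tokens = 4, 3

# "tokenized" prompt: three image placeholders followed by two text tokens
input_ids = torch.tensor([[IMAGE_TOKEN_ID, IMAGE_TOKEN_ID, IMAGE_TOKEN_ID, 5, 6]])
inputs_embeds = torch.zeros(1, 5, hidden_size)                          # stand-in for the embedding lookup
language_model_inputs = torch.ones(1, num_query_tokens, hidden_size)    # stand-in for the Q-Former output

# new path: scatter the image embeddings into the placeholder slots (out of place)
special_image_mask = (input_ids == IMAGE_TOKEN_ID).unsqueeze(-1).expand_as(inputs_embeds)
merged = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs)

# legacy path: prepend the image embeddings to the whole text sequence
legacy = torch.cat([language_model_inputs, inputs_embeds], dim=1)

print(merged.shape)   # torch.Size([1, 5, 4]) - same length as the expanded prompt
print(legacy.shape)   # torch.Size([1, 8, 4]) - image embeds added on top of the text length
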
58 changes: 50 additions & 8 deletions src/transformers/models/blip_2/processing_blip_2.py
@@ -20,8 +20,18 @@

from ...image_utils import ImageInput
from ...processing_utils import ProcessorMixin
from ...tokenization_utils_base import BatchEncoding, PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
from ...utils import TensorType
from ...tokenization_utils_base import (
AddedToken,
BatchEncoding,
PaddingStrategy,
PreTokenizedInput,
TextInput,
TruncationStrategy,
)
from ...utils import TensorType, logging


logger = logging.get_logger(__name__)


class Blip2Processor(ProcessorMixin):
@@ -36,20 +46,25 @@ class Blip2Processor(ProcessorMixin):
An instance of [`BlipImageProcessor`]. The image processor is a required input.
tokenizer (`AutoTokenizer`):
An instance of ['PreTrainedTokenizer`]. The tokenizer is a required input.
num_query_tokens (`int`, *optional*):
Number of tokens used by the Qformer as queries; should be the same as in the model's config.
"""

attributes = ["image_processor", "tokenizer"]
valid_kwargs = []
image_processor_class = "BlipImageProcessor"
tokenizer_class = "AutoTokenizer"

# Copied from transformers.models.blip.processing_blip.BlipProcessor.__init__
def __init__(self, image_processor, tokenizer, **kwargs):
def __init__(self, image_processor, tokenizer, num_query_tokens=None, **kwargs):
tokenizer.return_token_type_ids = False
self.current_processor = image_processor
self.image_token = AddedToken("<image>", normalized=False, special=True)
tokens_to_add = {"additional_special_tokens": [self.image_token]}
tokenizer.add_special_tokens(tokens_to_add)
self.num_query_tokens = num_query_tokens

super().__init__(image_processor, tokenizer)
self.current_processor = self.image_processor

# Copied from transformers.models.blip.processing_blip.BlipProcessor.__call__
def __call__(
self,
images: ImageInput = None,
@@ -106,7 +121,13 @@ def __call__(
encoding_image_processor = self.image_processor(images, return_tensors=return_tensors)

if text is not None:
text_encoding = self.tokenizer(
if isinstance(text, str):
text = [text]
elif not isinstance(text, list) and not isinstance(text[0], str):
raise ValueError("Invalid input text. Please provide a string, or a list of strings")

text_encoding = {}
_text_encoding = self.tokenizer(
text=text,
add_special_tokens=add_special_tokens,
padding=padding,
@@ -121,9 +142,30 @@
return_token_type_ids=return_token_type_ids,
return_length=return_length,
verbose=verbose,
return_tensors=return_tensors,
return_tensors=None, # hardcode "None" here for prepending image tokens
**kwargs,
)

# if we know how many query tokens, expand text inside processor. We need this hacky manipulation
# because BLIP expects image tokens to be at the beginning even before BOS token
if self.num_query_tokens is not None:
image_tokens = self.image_token.content * self.num_query_tokens
image_token_encoding = self.tokenizer([image_tokens], add_special_tokens=False, return_tensors=None)
for k in _text_encoding:
text_encoding[k] = [
img_encoding + txt_encoding
for img_encoding, txt_encoding in zip(image_token_encoding[k], _text_encoding[k])
]
else:
text_encoding = _text_encoding
logger.warning_once(
"Expanding inputs for image tokens in BLIP-2 should be done in processing. "
"Please follow instruction here (https://gist.github.com/zucchini-nlp/e9f20b054fa322f84ac9311d9ab67042) to update your BLIP-2 model. "
"Using processors without these attributes in the config is deprecated and will throw an error in v4.44."
)

# cast to desired return tensors type
text_encoding = BatchEncoding(text_encoding, tensor_type=return_tensors)
else:
text_encoding = None

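
On the processor side, the text is now tokenized with `return_tensors=None`, the `<image>` placeholder is repeated `num_query_tokens` times and tokenized once, and the resulting ids and attention mask are prepended to each prompt before everything is wrapped in a `BatchEncoding`. A simplified sketch of that prepending step on plain Python lists (the helper name and token ids are invented for illustration; the actual code zips the two tokenizer outputs rather than reusing a single row):

def prepend_image_tokens(text_encoding: dict, image_token_encoding: dict) -> dict:
    # both dicts map keys such as "input_ids" / "attention_mask" to lists of per-sample lists
    merged = {}
    for key, per_sample in text_encoding.items():
        image_row = image_token_encoding[key][0]   # single row built from "<image>" * num_query_tokens
        merged[key] = [image_row + sample for sample in per_sample]
    return merged

# toy example: two query tokens with a made-up id 50000, two prompts
text_enc = {"input_ids": [[1, 11, 12], [1, 21]], "attention_mask": [[1, 1, 1], [1, 1]]}
image_enc = {"input_ids": [[50000, 50000]], "attention_mask": [[1, 1]]}
print(prepend_image_tokens(text_enc, image_enc))
# {'input_ids': [[50000, 50000, 1, 11, 12], [50000, 50000, 1, 21]],
#  'attention_mask': [[1, 1, 1, 1, 1], [1, 1, 1, 1]]}
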
4 changes: 4 additions & 0 deletions src/transformers/models/llava/configuration_llava.py
@@ -48,6 +48,8 @@ class LlavaConfig(PretrainedConfig):
Can be one of `"default"` or `"full"`.
vision_feature_layer (`int`, *optional*, defaults to -2):
The index of the layer to select the vision feature.
image_seq_length (`int`, *optional*, defaults to 576):
Sequence length of one image embedding.

Example:

@@ -82,11 +84,13 @@ def __init__(
projector_hidden_act="gelu",
vision_feature_select_strategy="default",
vision_feature_layer=-2,
image_seq_length=576,
**kwargs,
):
self.ignore_index = ignore_index
self.image_token_index = image_token_index
self.projector_hidden_act = projector_hidden_act
self.image_seq_length = image_seq_length

if vision_feature_select_strategy not in ["default", "full"]:
raise ValueError(
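
The new `image_seq_length` entry records how many placeholder positions a single image occupies in the prompt (the default here is 576), so processing code can repeat the `<image>` token before tokenization instead of the model concatenating embeddings at run time. A hypothetical helper, not the LLaVA processor itself, sketching how such a value could be used:

def expand_image_placeholders(prompt: str, image_seq_length: int = 576, image_token: str = "<image>") -> str:
    # repeat every "<image>" placeholder so it covers all positions the vision features will fill
    return prompt.replace(image_token, image_token * image_seq_length)

# shortened to 4 repeats so the output stays readable
print(expand_image_placeholders("USER: <image> What is shown here? ASSISTANT:", image_seq_length=4))
# USER: <image><image><image><image> What is shown here? ASSISTANT:
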