Skip to content

Commit

Permalink
[Blip] Fix blip output name (#24889)
Browse files Browse the repository at this point in the history
* fix blip output name

* add property

* oops

* fix failing test
  • Loading branch information
younesbelkada authored Jul 18, 2023
1 parent a9e067a commit 5c5cb4e
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 6 deletions.
16 changes: 13 additions & 3 deletions src/transformers/models/blip/modeling_blip.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# limitations under the License.
""" PyTorch BLIP model."""

import warnings
from dataclasses import dataclass
from typing import Any, Optional, Tuple, Union

Expand Down Expand Up @@ -74,7 +75,7 @@ class BlipForConditionalGenerationModelOutput(ModelOutput):
Args:
loss (`torch.FloatTensor`, *optional*, returned when `labels` is provided, `torch.FloatTensor` of shape `(1,)`):
Language modeling loss from the text decoder.
decoder_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`, *optional*):
logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`, *optional*):
Prediction scores of the language modeling head of the text decoder model.
image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*):
The image embeddings obtained after applying the Vision Transformer model to the input image.
Expand All @@ -94,12 +95,21 @@ class BlipForConditionalGenerationModelOutput(ModelOutput):
"""

loss: Optional[Tuple[torch.FloatTensor]] = None
decoder_logits: Optional[Tuple[torch.FloatTensor]] = None
logits: Optional[Tuple[torch.FloatTensor]] = None
image_embeds: Optional[torch.FloatTensor] = None
last_hidden_state: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None

@property
def decoder_logits(self):
    """Deprecated alias for ``logits``.

    Kept for backward compatibility after the output field was renamed;
    emits a ``FutureWarning`` and forwards to ``self.logits``.
    """
    deprecation_message = (
        "`decoder_logits` attribute is deprecated and will be removed in version 5 of Transformers."
        " Please use the `logits` attribute to retrieve the final output instead."
    )
    warnings.warn(deprecation_message, FutureWarning)
    return self.logits


@dataclass
class BlipTextVisionModelOutput(ModelOutput):
Expand Down Expand Up @@ -1011,7 +1021,7 @@ def forward(

return BlipForConditionalGenerationModelOutput(
loss=outputs.loss,
decoder_logits=outputs.logits,
logits=outputs.logits,
image_embeds=image_embeds,
last_hidden_state=vision_outputs.last_hidden_state,
hidden_states=vision_outputs.hidden_states,
Expand Down
16 changes: 13 additions & 3 deletions src/transformers/models/blip/modeling_tf_blip.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from __future__ import annotations

import warnings
from dataclasses import dataclass
from typing import Any, Optional, Tuple, Union

Expand Down Expand Up @@ -84,7 +85,7 @@ class TFBlipForConditionalGenerationModelOutput(ModelOutput):
Args:
loss (`tf.Tensor`, *optional*, returned when `labels` is provided, `tf.Tensor` of shape `(1,)`):
Language modeling loss from the text decoder.
decoder_logits (`tf.Tensor` of shape `(batch_size, sequence_length, config.vocab_size)`, *optional*):
logits (`tf.Tensor` of shape `(batch_size, sequence_length, config.vocab_size)`, *optional*):
Prediction scores of the language modeling head of the text decoder model.
image_embeds (`tf.Tensor` of shape `(batch_size, output_dim)`, *optional*):
The image embeddings obtained after applying the Vision Transformer model to the input image.
Expand All @@ -104,12 +105,21 @@ class TFBlipForConditionalGenerationModelOutput(ModelOutput):
"""

loss: Tuple[tf.Tensor] | None = None
decoder_logits: Tuple[tf.Tensor] | None = None
logits: Tuple[tf.Tensor] | None = None
image_embeds: tf.Tensor | None = None
last_hidden_state: tf.Tensor = None
hidden_states: Tuple[tf.Tensor] | None = None
attentions: Tuple[tf.Tensor] | None = None

@property
def decoder_logits(self):
    """Deprecated alias for ``logits``.

    Kept for backward compatibility after the output field was renamed;
    emits a ``FutureWarning`` and forwards to ``self.logits``.
    """
    deprecation_message = (
        "`decoder_logits` attribute is deprecated and will be removed in version 5 of Transformers."
        " Please use the `logits` attribute to retrieve the final output instead."
    )
    warnings.warn(deprecation_message, FutureWarning)
    return self.logits


@dataclass
class TFBlipTextVisionModelOutput(ModelOutput):
Expand Down Expand Up @@ -1078,7 +1088,7 @@ def call(

return TFBlipForConditionalGenerationModelOutput(
loss=outputs.loss,
decoder_logits=outputs.logits,
logits=outputs.logits,
image_embeds=image_embeds,
last_hidden_state=vision_outputs.last_hidden_state,
hidden_states=vision_outputs.hidden_states,
Expand Down

0 comments on commit 5c5cb4e

Please sign in to comment.