Skip to content

Commit

Permalink
Update Swin MIM output class (#22893)
Browse files Browse the repository at this point in the history
Updates Swin MIM output class to match other masked image modeling outputs
  • Loading branch information
alaradirik authored Apr 21, 2023
1 parent 1e1cb6f commit 3db2e40
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 13 deletions.
18 changes: 14 additions & 4 deletions src/transformers/models/swin/modeling_swin.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import collections.abc
import math
import warnings
from dataclasses import dataclass
from typing import Optional, Tuple, Union

Expand Down Expand Up @@ -139,7 +140,7 @@ class SwinMaskedImageModelingOutput(ModelOutput):
Args:
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `bool_masked_pos` is provided):
Masked image modeling (MLM) loss.
logits (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
reconstruction (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
Reconstructed pixel values.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
Expand All @@ -161,11 +162,20 @@ class SwinMaskedImageModelingOutput(ModelOutput):
"""

loss: Optional[torch.FloatTensor] = None
logits: torch.FloatTensor = None
reconstruction: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
reshaped_hidden_states: Optional[Tuple[torch.FloatTensor]] = None

@property
def logits(self):
warnings.warn(
"logits attribute is deprecated and will be removed in version 5 of Transformers."
" Please use the reconstruction attribute to retrieve the final output instead.",
FutureWarning,
)
return self.reconstruction


@dataclass
class SwinImageClassifierOutput(ModelOutput):
Expand Down Expand Up @@ -1094,7 +1104,7 @@ def forward(
>>> bool_masked_pos = torch.randint(low=0, high=2, size=(1, num_patches)).bool()
>>> outputs = model(pixel_values, bool_masked_pos=bool_masked_pos)
>>> loss, reconstructed_pixel_values = outputs.loss, outputs.logits
>>> loss, reconstructed_pixel_values = outputs.loss, outputs.reconstruction
>>> list(reconstructed_pixel_values.shape)
[1, 3, 192, 192]
```"""
Expand Down Expand Up @@ -1138,7 +1148,7 @@ def forward(

return SwinMaskedImageModelingOutput(
loss=masked_im_loss,
logits=reconstructed_pixel_values,
reconstruction=reconstructed_pixel_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
reshaped_hidden_states=outputs.reshaped_hidden_states,
Expand Down
20 changes: 15 additions & 5 deletions src/transformers/models/swin/modeling_tf_swin.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import collections.abc
import math
import warnings
from dataclasses import dataclass
from functools import partial
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
Expand Down Expand Up @@ -143,7 +144,7 @@ class TFSwinMaskedImageModelingOutput(ModelOutput):
Args:
loss (`tf.Tensor` of shape `(1,)`, *optional*, returned when `bool_masked_pos` is provided):
Masked image modeling (MLM) loss.
logits (`tf.Tensor` of shape `(batch_size, num_channels, height, width)`):
reconstruction (`tf.Tensor` of shape `(batch_size, num_channels, height, width)`):
Reconstructed pixel values.
hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each stage) of shape
Expand All @@ -165,11 +166,20 @@ class TFSwinMaskedImageModelingOutput(ModelOutput):
"""

loss: Optional[tf.Tensor] = None
logits: tf.Tensor = None
reconstruction: tf.Tensor = None
hidden_states: Optional[Tuple[tf.Tensor]] = None
attentions: Optional[Tuple[tf.Tensor]] = None
reshaped_hidden_states: Optional[Tuple[tf.Tensor]] = None

@property
def logits(self):
warnings.warn(
"logits attribute is deprecated and will be removed in version 5 of Transformers."
" Please use the reconstruction attribute to retrieve the final output instead.",
FutureWarning,
)
return self.reconstruction


@dataclass
class TFSwinImageClassifierOutput(ModelOutput):
Expand Down Expand Up @@ -1340,7 +1350,7 @@ def call(
>>> bool_masked_pos = tf.random.uniform((1, num_patches)) >= 0.5
>>> outputs = model(pixel_values, bool_masked_pos=bool_masked_pos)
>>> loss, reconstructed_pixel_values = outputs.loss, outputs.logits
>>> loss, reconstructed_pixel_values = outputs.loss, outputs.reconstruction
>>> list(reconstructed_pixel_values.shape)
[1, 3, 224, 224]
```"""
Expand Down Expand Up @@ -1392,7 +1402,7 @@ def call(

return TFSwinMaskedImageModelingOutput(
loss=masked_im_loss,
logits=reconstructed_pixel_values,
reconstruction=reconstructed_pixel_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
reshaped_hidden_states=outputs.reshaped_hidden_states,
Expand All @@ -1401,7 +1411,7 @@ def call(
def serving_output(self, output: TFSwinMaskedImageModelingOutput) -> TFSwinMaskedImageModelingOutput:
# hidden_states and attentions not converted to Tensor with tf.convert_to_tensor as they are all of different dimensions
return TFSwinMaskedImageModelingOutput(
logits=output.logits,
reconstruction=output.reconstruction,
hidden_states=output.hidden_states,
attentions=output.attentions,
reshaped_hidden_states=output.reshaped_hidden_states,
Expand Down
18 changes: 14 additions & 4 deletions src/transformers/models/swinv2/modeling_swinv2.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import collections.abc
import math
import warnings
from dataclasses import dataclass
from typing import Optional, Tuple, Union

Expand Down Expand Up @@ -142,7 +143,7 @@ class Swinv2MaskedImageModelingOutput(ModelOutput):
Args:
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `bool_masked_pos` is provided):
Masked image modeling (MLM) loss.
logits (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
reconstruction (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
Reconstructed pixel values.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
Expand All @@ -164,11 +165,20 @@ class Swinv2MaskedImageModelingOutput(ModelOutput):
"""

loss: Optional[torch.FloatTensor] = None
logits: torch.FloatTensor = None
reconstruction: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
reshaped_hidden_states: Optional[Tuple[torch.FloatTensor]] = None

@property
def logits(self):
warnings.warn(
"logits attribute is deprecated and will be removed in version 5 of Transformers."
" Please use the reconstruction attribute to retrieve the final output instead.",
FutureWarning,
)
return self.reconstruction


@dataclass
# Copied from transformers.models.swin.modeling_swin.SwinImageClassifierOutput with Swin->Swinv2
Expand Down Expand Up @@ -1175,7 +1185,7 @@ def forward(
>>> bool_masked_pos = torch.randint(low=0, high=2, size=(1, num_patches)).bool()
>>> outputs = model(pixel_values, bool_masked_pos=bool_masked_pos)
>>> loss, reconstructed_pixel_values = outputs.loss, outputs.logits
>>> loss, reconstructed_pixel_values = outputs.loss, outputs.reconstruction
>>> list(reconstructed_pixel_values.shape)
[1, 3, 256, 256]
```"""
Expand Down Expand Up @@ -1219,7 +1229,7 @@ def forward(

return Swinv2MaskedImageModelingOutput(
loss=masked_im_loss,
logits=reconstructed_pixel_values,
reconstruction=reconstructed_pixel_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
reshaped_hidden_states=outputs.reshaped_hidden_states,
Expand Down

0 comments on commit 3db2e40

Please sign in to comment.