From 67cea75e59a3441ca63664af223099d02a2bdaab Mon Sep 17 00:00:00 2001 From: alaradirik Date: Thu, 20 Apr 2023 16:07:52 +0100 Subject: [PATCH] update Swin MIM output class --- src/transformers/models/swin/modeling_swin.py | 18 +++++++++++++---- .../models/swin/modeling_tf_swin.py | 20 ++++++++++++++----- .../models/swinv2/modeling_swinv2.py | 18 +++++++++++++---- 3 files changed, 43 insertions(+), 13 deletions(-) diff --git a/src/transformers/models/swin/modeling_swin.py b/src/transformers/models/swin/modeling_swin.py index 520c215d4634bd..2f7cfeb1adbde9 100644 --- a/src/transformers/models/swin/modeling_swin.py +++ b/src/transformers/models/swin/modeling_swin.py @@ -17,6 +17,7 @@ import collections.abc import math +import warnings from dataclasses import dataclass from typing import Optional, Tuple, Union @@ -139,7 +140,7 @@ class SwinMaskedImageModelingOutput(ModelOutput): Args: loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `bool_masked_pos` is provided): Masked image modeling (MLM) loss. - logits (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + reconstruction (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): Reconstructed pixel values. hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of @@ -161,11 +162,20 @@ class SwinMaskedImageModelingOutput(ModelOutput): """ loss: Optional[torch.FloatTensor] = None - logits: torch.FloatTensor = None + reconstruction: torch.FloatTensor = None hidden_states: Optional[Tuple[torch.FloatTensor]] = None attentions: Optional[Tuple[torch.FloatTensor]] = None reshaped_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + @property + def logits(self): + warnings.warn( + "logits attribute is deprecated and will be removed in version 5 of Transformers." + " Please use the reconstruction attribute to retrieve the final output instead.", + FutureWarning, + ) + return self.reconstruction + @dataclass class SwinImageClassifierOutput(ModelOutput): @@ -1094,7 +1104,7 @@ def forward( >>> bool_masked_pos = torch.randint(low=0, high=2, size=(1, num_patches)).bool() >>> outputs = model(pixel_values, bool_masked_pos=bool_masked_pos) - >>> loss, reconstructed_pixel_values = outputs.loss, outputs.logits + >>> loss, reconstructed_pixel_values = outputs.loss, outputs.reconstruction >>> list(reconstructed_pixel_values.shape) [1, 3, 192, 192] ```""" @@ -1138,7 +1148,7 @@ def forward( return SwinMaskedImageModelingOutput( loss=masked_im_loss, - logits=reconstructed_pixel_values, + reconstruction=reconstructed_pixel_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, reshaped_hidden_states=outputs.reshaped_hidden_states, diff --git a/src/transformers/models/swin/modeling_tf_swin.py b/src/transformers/models/swin/modeling_tf_swin.py index e2fdef3813297a..61352843c2f248 100644 --- a/src/transformers/models/swin/modeling_tf_swin.py +++ b/src/transformers/models/swin/modeling_tf_swin.py @@ -17,6 +17,7 @@ import collections.abc import math +import warnings from dataclasses import dataclass from functools import partial from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union @@ -143,7 +144,7 @@ class TFSwinMaskedImageModelingOutput(ModelOutput): Args: loss (`tf.Tensor` of shape `(1,)`, *optional*, returned when `bool_masked_pos` is provided): Masked image modeling (MLM) loss. - logits (`tf.Tensor` of shape `(batch_size, num_channels, height, width)`): + reconstruction (`tf.Tensor` of shape `(batch_size, num_channels, height, width)`): Reconstructed pixel values. hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each stage) of shape @@ -165,11 +166,20 @@ class TFSwinMaskedImageModelingOutput(ModelOutput): """ loss: Optional[tf.Tensor] = None - logits: tf.Tensor = None + reconstruction: tf.Tensor = None hidden_states: Optional[Tuple[tf.Tensor]] = None attentions: Optional[Tuple[tf.Tensor]] = None reshaped_hidden_states: Optional[Tuple[tf.Tensor]] = None + @property + def logits(self): + warnings.warn( + "logits attribute is deprecated and will be removed in version 5 of Transformers." + " Please use the reconstruction attribute to retrieve the final output instead.", + FutureWarning, + ) + return self.reconstruction + @dataclass class TFSwinImageClassifierOutput(ModelOutput): @@ -1340,7 +1350,7 @@ def call( >>> bool_masked_pos = tf.random.uniform((1, num_patches)) >= 0.5 >>> outputs = model(pixel_values, bool_masked_pos=bool_masked_pos) - >>> loss, reconstructed_pixel_values = outputs.loss, outputs.logits + >>> loss, reconstructed_pixel_values = outputs.loss, outputs.reconstruction >>> list(reconstructed_pixel_values.shape) [1, 3, 224, 224] ```""" @@ -1392,7 +1402,7 @@ def call( return TFSwinMaskedImageModelingOutput( loss=masked_im_loss, - logits=reconstructed_pixel_values, + reconstruction=reconstructed_pixel_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, reshaped_hidden_states=outputs.reshaped_hidden_states, @@ -1401,7 +1411,7 @@ def call( def serving_output(self, output: TFSwinMaskedImageModelingOutput) -> TFSwinMaskedImageModelingOutput: # hidden_states and attentions not converted to Tensor with tf.convert_to_tensor as they are all of different dimensions return TFSwinMaskedImageModelingOutput( - logits=output.logits, + reconstruction=output.reconstruction, hidden_states=output.hidden_states, attentions=output.attentions, reshaped_hidden_states=output.reshaped_hidden_states, diff --git a/src/transformers/models/swinv2/modeling_swinv2.py b/src/transformers/models/swinv2/modeling_swinv2.py index eff3c8d7946ffb..97b460479d6d5d 100644 --- a/src/transformers/models/swinv2/modeling_swinv2.py +++ b/src/transformers/models/swinv2/modeling_swinv2.py @@ -17,6 +17,7 @@ import collections.abc import math +import warnings from dataclasses import dataclass from typing import Optional, Tuple, Union @@ -142,7 +143,7 @@ class Swinv2MaskedImageModelingOutput(ModelOutput): Args: loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `bool_masked_pos` is provided): Masked image modeling (MLM) loss. - logits (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + reconstruction (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): Reconstructed pixel values. hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of @@ -164,11 +165,20 @@ class Swinv2MaskedImageModelingOutput(ModelOutput): """ loss: Optional[torch.FloatTensor] = None - logits: torch.FloatTensor = None + reconstruction: torch.FloatTensor = None hidden_states: Optional[Tuple[torch.FloatTensor]] = None attentions: Optional[Tuple[torch.FloatTensor]] = None reshaped_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + @property + def logits(self): + warnings.warn( + "logits attribute is deprecated and will be removed in version 5 of Transformers." + " Please use the reconstruction attribute to retrieve the final output instead.", + FutureWarning, + ) + return self.reconstruction + @dataclass # Copied from transformers.models.swin.modeling_swin.SwinImageClassifierOutput with Swin->Swinv2 @@ -1175,7 +1185,7 @@ def forward( >>> bool_masked_pos = torch.randint(low=0, high=2, size=(1, num_patches)).bool() >>> outputs = model(pixel_values, bool_masked_pos=bool_masked_pos) - >>> loss, reconstructed_pixel_values = outputs.loss, outputs.logits + >>> loss, reconstructed_pixel_values = outputs.loss, outputs.reconstruction >>> list(reconstructed_pixel_values.shape) [1, 3, 256, 256] ```""" @@ -1219,7 +1229,7 @@ def forward( return Swinv2MaskedImageModelingOutput( loss=masked_im_loss, - logits=reconstructed_pixel_values, + reconstruction=reconstructed_pixel_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, reshaped_hidden_states=outputs.reshaped_hidden_states,