Make the TF dummies even smaller (#24071)
* Let's see if we can use the smallest possible dummies

* Make GPT-2's dummies a little longer

* Just use (1,2) as the default shape

* Update other dummies in sync

* Correct imports for Keras 2.13

* Shrink the Wav2Vec2 dummies
Rocketknight1 authored Jun 7, 2023
1 parent 092c14c commit 1fc832b
Showing 5 changed files with 12 additions and 23 deletions.
src/transformers/modeling_tf_utils.py (7 additions, 3 deletions)
@@ -74,7 +74,7 @@
 if parse(tf.__version__).minor >= 13:
     from keras import backend as K
     from keras.__internal__ import KerasTensor
-    from keras.engine.base_layer_utils import call_context
+    from keras.src.engine.base_layer_utils import call_context
 elif parse(tf.__version__).minor >= 11:
     from keras import backend as K
     from keras.engine.base_layer_utils import call_context
@@ -1125,15 +1125,19 @@ def dummy_inputs(self) -> Dict[str, tf.Tensor]:
         sig = self._prune_signature(self.input_signature)
         for key, spec in sig.items():
             # 2 is the most correct arbitrary size. I will not be taking questions
-            dummies[key] = tf.ones(shape=[dim if dim is not None else 2 for dim in spec.shape], dtype=spec.dtype)
+            dummy_shape = [dim if dim is not None else 2 for dim in spec.shape]
+            if spec.shape[0] is None:
+                # But let's make the batch size 1 to save memory anyway
+                dummy_shape[0] = 1
+            dummies[key] = tf.ones(shape=dummy_shape, dtype=spec.dtype)
             if key == "token_type_ids":
                 # Some models have token_type_ids but with a vocab_size of 1
                 dummies[key] = tf.zeros_like(dummies[key])
         if self.config.add_cross_attention and "encoder_hidden_states" in inspect.signature(self.call).parameters:
             if "encoder_hidden_states" not in dummies:
                 if self.main_input_name == "input_ids":
                     dummies["encoder_hidden_states"] = tf.ones(
-                        shape=(2, 2, self.config.hidden_size), dtype=tf.float32, name="encoder_hidden_states"
+                        shape=(1, 2, self.config.hidden_size), dtype=tf.float32, name="encoder_hidden_states"
                     )
                 else:
                     raise NotImplementedError(
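
For illustration, here is a minimal standalone sketch of the new default behaviour (not the library code itself; the `signature` dict below is a made-up example): unknown dimensions fall back to 2, but an unknown batch dimension is pinned to 1.

    import tensorflow as tf

    # Hypothetical input signature: batch and sequence dims are unknown (None)
    signature = {"input_ids": tf.TensorSpec(shape=(None, None), dtype=tf.int32)}

    dummies = {}
    for key, spec in signature.items():
        # Unknown dims default to 2, but an unknown batch dim is pinned to 1 to save memory
        dummy_shape = [dim if dim is not None else 2 for dim in spec.shape]
        if spec.shape[0] is None:
            dummy_shape[0] = 1
        dummies[key] = tf.ones(shape=dummy_shape, dtype=spec.dtype)

    print({key: tuple(value.shape) for key, value in dummies.items()})  # {'input_ids': (1, 2)}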
src/transformers/models/funnel/modeling_tf_funnel.py (1 addition, 1 deletion)
@@ -978,7 +978,7 @@ class TFFunnelPreTrainedModel(TFPreTrainedModel):
     @property
     def dummy_inputs(self):
         # Funnel misbehaves with very small inputs, so we override and make them a bit bigger
-        return {"input_ids": tf.ones((3, 3), dtype=tf.int32)}
+        return {"input_ids": tf.ones((1, 3), dtype=tf.int32)}


 @dataclass
src/transformers/models/sam/modeling_tf_sam.py (0 additions, 15 deletions)
@@ -1147,21 +1147,6 @@ class TFSamPreTrainedModel(TFPreTrainedModel):
     base_model_prefix = "sam"
     main_input_name = "pixel_values"

-    @property
-    def dummy_inputs(self) -> Dict[str, tf.Tensor]:
-        # We override the default dummy inputs here because SAM has some really explosive memory usage in the
-        # attention layers, so we want to pass the smallest possible batches
-        VISION_DUMMY_INPUTS = tf.random.uniform(
-            shape=(
-                1,
-                self.config.vision_config.num_channels,
-                self.config.vision_config.image_size,
-                self.config.vision_config.image_size,
-            ),
-            dtype=tf.float32,
-        )
-        return {"pixel_values": tf.constant(VISION_DUMMY_INPUTS)}
-

 SAM_START_DOCSTRING = r"""
 This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
src/transformers/models/wav2vec2/modeling_tf_wav2vec2.py (2 additions, 2 deletions)
@@ -1194,8 +1194,8 @@ def input_signature(self):
     @property
     def dummy_inputs(self):
         return {
-            "input_values": tf.random.uniform(shape=(1, 16000), dtype=tf.float32),
-            "attention_mask": tf.ones(shape=(1, 16000), dtype=tf.float32),
+            "input_values": tf.random.uniform(shape=(1, 500), dtype=tf.float32),
+            "attention_mask": tf.ones(shape=(1, 500), dtype=tf.float32),
         }

     def __init__(self, config, *inputs, **kwargs):
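
A quick sanity check, assuming the default Wav2Vec2 feature-extractor config (conv kernels (10, 3, 3, 3, 3, 2, 2), strides (5, 2, 2, 2, 2, 2, 2)), that 500 samples still make it through the convolutional feature encoder:

    # Rough sketch: output length of the conv feature encoder for a 500-sample dummy
    length = 500
    for kernel, stride in zip((10, 3, 3, 3, 3, 2, 2), (5, 2, 2, 2, 2, 2, 2)):
        length = (length - kernel) // stride + 1
    print(length)  # 1 frame remains, so the tiny dummy still produces a valid forward pass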
src/transformers/models/whisper/modeling_tf_whisper.py (2 additions, 2 deletions)
@@ -481,9 +481,9 @@ def dummy_inputs(self) -> Dict[str, tf.Tensor]:
"""
return {
self.main_input_name: tf.random.uniform(
[2, self.config.num_mel_bins, self.config.max_source_positions * 2 - 1], dtype=tf.float32
[1, self.config.num_mel_bins, self.config.max_source_positions * 2 - 1], dtype=tf.float32
),
"decoder_input_ids": tf.constant([[2, 3]], dtype=tf.int32),
"decoder_input_ids": tf.constant([[1, 3]], dtype=tf.int32),
}

@property
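
To see the overall effect end to end, one option (a hypothetical usage sketch with illustrative tiny config values, not part of this commit) is to instantiate a randomly initialised model and inspect its dummy inputs:

    from transformers import GPT2Config, TFGPT2Model

    # Tiny, illustrative config; any TF model class would do
    model = TFGPT2Model(GPT2Config(n_layer=1, n_head=2, n_embd=8, vocab_size=100))

    # After this change, the signature-derived dummies should come out as (1, 2) tensors
    for name, tensor in model.dummy_inputs.items():
        print(name, tuple(tensor.shape), tensor.dtype)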
