Skip to content

Commit

Permalink
Remove all the manual overrides for encoder-decoder model signatures
Browse files Browse the repository at this point in the history
  • Loading branch information
Rocketknight1 committed May 18, 2023
1 parent 7d412c8 commit 10e742e
Show file tree
Hide file tree
Showing 11 changed files with 7 additions and 86 deletions.
8 changes: 7 additions & 1 deletion src/transformers/modeling_tf_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1237,7 +1237,13 @@ def input_signature(self) -> Dict[str, tf.TensorSpec]:
text_dims = 3
else:
text_dims = 2
for input_name in ("input_ids", "attention_mask", "token_type_ids", "decoder_input_ids"):
for input_name in (
"input_ids",
"attention_mask",
"token_type_ids",
"decoder_input_ids",
"decoder_attention_mask",
):
if input_name in model_inputs:
sig[input_name] = tf.TensorSpec([None] * text_dims, tf.int32, name=input_name)
if "pixel_values" in model_inputs:
Expand Down
9 changes: 0 additions & 9 deletions src/transformers/models/bart/modeling_tf_bart.py
Original file line number Diff line number Diff line change
Expand Up @@ -482,15 +482,6 @@ class TFBartPretrainedModel(TFPreTrainedModel):
config_class = BartConfig
base_model_prefix = "model"

@property
def input_signature(self):
return {
"input_ids": tf.TensorSpec((None, None), tf.int32, name="input_ids"),
"attention_mask": tf.TensorSpec((None, None), tf.int32, name="attention_mask"),
"decoder_input_ids": tf.TensorSpec((None, None), tf.int32, name="decoder_input_ids"),
"decoder_attention_mask": tf.TensorSpec((None, None), tf.int32, name="decoder_attention_mask"),
}

@property
def dummy_inputs(self):
dummy_inputs = super().dummy_inputs
Expand Down
9 changes: 0 additions & 9 deletions src/transformers/models/blenderbot/modeling_tf_blenderbot.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,15 +461,6 @@ class TFBlenderbotPreTrainedModel(TFPreTrainedModel):
config_class = BlenderbotConfig
base_model_prefix = "model"

@property
def input_signature(self):
return {
"input_ids": tf.TensorSpec((None, None), tf.int32, name="input_ids"),
"attention_mask": tf.TensorSpec((None, None), tf.int32, name="attention_mask"),
"decoder_input_ids": tf.TensorSpec((None, None), tf.int32, name="decoder_input_ids"),
"decoder_attention_mask": tf.TensorSpec((None, None), tf.int32, name="decoder_attention_mask"),
}


BLENDERBOT_START_DOCSTRING = r"""
This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -461,15 +461,6 @@ class TFBlenderbotSmallPreTrainedModel(TFPreTrainedModel):
config_class = BlenderbotSmallConfig
base_model_prefix = "model"

@property
def input_signature(self):
return {
"input_ids": tf.TensorSpec((None, None), tf.int32, name="input_ids"),
"attention_mask": tf.TensorSpec((None, None), tf.int32, name="attention_mask"),
"decoder_input_ids": tf.TensorSpec((None, None), tf.int32, name="decoder_input_ids"),
"decoder_attention_mask": tf.TensorSpec((None, None), tf.int32, name="decoder_attention_mask"),
}


BLENDERBOT_SMALL_START_DOCSTRING = r"""
This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
Expand Down
6 changes: 0 additions & 6 deletions src/transformers/models/blip/modeling_tf_blip.py
Original file line number Diff line number Diff line change
Expand Up @@ -1206,12 +1206,6 @@ def _shift_right(self, input_ids):

return shifted_input_ids

@property
def input_signature(self):
base_sig = super().input_signature
base_sig["decoder_input_ids"] = base_sig["input_ids"]
return base_sig

@unpack_inputs
@add_start_docstrings_to_model_forward(BLIP_VISION_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=TFBlipTextVisionModelOutput, config_class=BlipVisionConfig)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -272,13 +272,6 @@ def __init__(
"following discussion on GitHub: https://github.com/huggingface/transformers/issues/23350"
)

@property
def input_signature(self):
return {
"input_ids": tf.TensorSpec([None, None], dtype=tf.int32),
"decoder_input_ids": tf.TensorSpec([None, None], dtype=tf.int32),
}

def get_encoder(self):
return self.encoder

Expand Down
9 changes: 0 additions & 9 deletions src/transformers/models/marian/modeling_tf_marian.py
Original file line number Diff line number Diff line change
Expand Up @@ -498,15 +498,6 @@ class TFMarianPreTrainedModel(TFPreTrainedModel):
config_class = MarianConfig
base_model_prefix = "model"

@property
def input_signature(self):
return {
"input_ids": tf.TensorSpec((None, None), tf.int32, name="input_ids"),
"attention_mask": tf.TensorSpec((None, None), tf.int32, name="attention_mask"),
"decoder_input_ids": tf.TensorSpec((None, None), tf.int32, name="decoder_input_ids"),
"decoder_attention_mask": tf.TensorSpec((None, None), tf.int32, name="decoder_attention_mask"),
}


MARIAN_START_DOCSTRING = r"""
This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
Expand Down
9 changes: 0 additions & 9 deletions src/transformers/models/mbart/modeling_tf_mbart.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,15 +465,6 @@ class TFMBartPreTrainedModel(TFPreTrainedModel):
config_class = MBartConfig
base_model_prefix = "model"

@property
def input_signature(self):
return {
"input_ids": tf.TensorSpec((None, None), tf.int32, name="input_ids"),
"attention_mask": tf.TensorSpec((None, None), tf.int32, name="attention_mask"),
"decoder_input_ids": tf.TensorSpec((None, None), tf.int32, name="decoder_input_ids"),
"decoder_attention_mask": tf.TensorSpec((None, None), tf.int32, name="decoder_attention_mask"),
}


MBART_START_DOCSTRING = r"""
This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
Expand Down
9 changes: 0 additions & 9 deletions src/transformers/models/pegasus/modeling_tf_pegasus.py
Original file line number Diff line number Diff line change
Expand Up @@ -500,15 +500,6 @@ class TFPegasusPreTrainedModel(TFPreTrainedModel):
config_class = PegasusConfig
base_model_prefix = "model"

@property
def input_signature(self):
return {
"input_ids": tf.TensorSpec((None, None), tf.int32, name="input_ids"),
"attention_mask": tf.TensorSpec((None, None), tf.int32, name="attention_mask"),
"decoder_input_ids": tf.TensorSpec((None, None), tf.int32, name="decoder_input_ids"),
"decoder_attention_mask": tf.TensorSpec((None, None), tf.int32, name="decoder_attention_mask"),
}


PEGASUS_START_DOCSTRING = r"""
This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
Expand Down
9 changes: 0 additions & 9 deletions src/transformers/models/t5/modeling_tf_t5.py
Original file line number Diff line number Diff line change
Expand Up @@ -862,15 +862,6 @@ class TFT5PreTrainedModel(TFPreTrainedModel):
# names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model
_keys_to_ignore_on_load_unexpected = [r"decoder\Wblock[\W_0]+layer[\W_1]+EncDecAttention\Wrelative_attention_bias"]

@property
def input_signature(self):
return {
"input_ids": tf.TensorSpec((None, None), tf.int32, name="input_ids"),
"attention_mask": tf.TensorSpec((None, None), tf.int32, name="attention_mask"),
"decoder_input_ids": tf.TensorSpec((None, None), tf.int32, name="decoder_input_ids"),
"decoder_attention_mask": tf.TensorSpec((None, None), tf.int32, name="decoder_attention_mask"),
}

def get_input_embeddings(self):
return self.shared

Expand Down
9 changes: 0 additions & 9 deletions src/transformers/models/xglm/modeling_tf_xglm.py
Original file line number Diff line number Diff line change
Expand Up @@ -617,15 +617,6 @@ class TFXGLMPreTrainedModel(TFPreTrainedModel):
config_class = XGLMConfig
base_model_prefix = "model"

@property
def input_signature(self):
return {
"input_ids": tf.TensorSpec((None, None), tf.int32, name="input_ids"),
"attention_mask": tf.TensorSpec((None, None), tf.int32, name="attention_mask"),
"decoder_input_ids": tf.TensorSpec((None, None), tf.int32, name="decoder_input_ids"),
"decoder_attention_mask": tf.TensorSpec((None, None), tf.int32, name="decoder_attention_mask"),
}


XGLM_START_DOCSTRING = r"""
This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
Expand Down

0 comments on commit 10e742e

Please sign in to comment.