diff --git a/src/transformers/modeling_tf_utils.py b/src/transformers/modeling_tf_utils.py
index 0ac45643217bcb..c66609e91b50b9 100644
--- a/src/transformers/modeling_tf_utils.py
+++ b/src/transformers/modeling_tf_utils.py
@@ -1287,10 +1287,10 @@ def serving_output(self, output):
                 and getattr(self.config, "add_cross_attention", False)
             ):
                 output[key] = None
 
-            if output[key] is not None:
+            if isinstance(output[key], (tuple, list)):
                 try:
                     output[key] = tf.convert_to_tensor(output[key])
-                except ValueError:
+                except (ValueError, tf.errors.InvalidArgumentError):
                     pass  # Layers may not have the same dimensions
         return output
diff --git a/src/transformers/models/funnel/modeling_tf_funnel.py b/src/transformers/models/funnel/modeling_tf_funnel.py
index f8ce0bd53f77fc..44456f3c6c4261 100644
--- a/src/transformers/models/funnel/modeling_tf_funnel.py
+++ b/src/transformers/models/funnel/modeling_tf_funnel.py
@@ -1129,6 +1129,15 @@ def call(
             training=training,
         )
 
+    def serving_output(self, output):
+        # hidden_states and attentions not converted to Tensor with tf.convert_to_tensor as they are all of
+        # different dimensions
+        return TFBaseModelOutput(
+            last_hidden_state=output.last_hidden_state,
+            hidden_states=output.hidden_states,
+            attentions=output.attentions,
+        )
+
 
 @add_start_docstrings(
     "The bare Funnel Transformer Model transformer outputting raw hidden-states without any specific head on top.",
@@ -1168,6 +1177,15 @@ def call(
             training=training,
         )
 
+    def serving_output(self, output):
+        # hidden_states and attentions not converted to Tensor with tf.convert_to_tensor as they are all of
+        # different dimensions
+        return TFBaseModelOutput(
+            last_hidden_state=output.last_hidden_state,
+            hidden_states=output.hidden_states,
+            attentions=output.attentions,
+        )
+
 
 @add_start_docstrings(
     """
@@ -1234,6 +1252,13 @@ def call(
             attentions=discriminator_hidden_states.attentions,
         )
 
+    def serving_output(self, output):
+        # hidden_states and attentions not converted to Tensor with tf.convert_to_tensor as they are all of
+        # different dimensions
+        return TFFunnelForPreTrainingOutput(
+            logits=output.logits, hidden_states=output.hidden_states, attentions=output.attentions
+        )
+
 
 @add_start_docstrings("""Funnel Model with a `language modeling` head on top.""", FUNNEL_START_DOCSTRING)
 class TFFunnelForMaskedLM(TFFunnelPreTrainedModel, TFMaskedLanguageModelingLoss):
@@ -1301,6 +1326,11 @@ def call(
             attentions=outputs.attentions,
         )
 
+    def serving_output(self, output: TFMaskedLMOutput) -> TFMaskedLMOutput:
+        # hidden_states and attentions not converted to Tensor with tf.convert_to_tensor as they are all of
+        # different dimensions
+        return TFMaskedLMOutput(logits=output.logits, hidden_states=output.hidden_states, attentions=output.attentions)
+
 
 @add_start_docstrings(
     """
@@ -1369,6 +1399,13 @@ def call(
             attentions=outputs.attentions,
         )
 
+    def serving_output(self, output: TFSequenceClassifierOutput) -> TFSequenceClassifierOutput:
+        # hidden_states and attentions not converted to Tensor with tf.convert_to_tensor as they are all of
+        # different dimensions
+        return TFSequenceClassifierOutput(
+            logits=output.logits, hidden_states=output.hidden_states, attentions=output.attentions
+        )
+
 
 @add_start_docstrings(
     """
@@ -1453,6 +1490,13 @@ def call(
             attentions=outputs.attentions,
         )
 
+    def serving_output(self, output: TFMultipleChoiceModelOutput) -> TFMultipleChoiceModelOutput:
+        # hidden_states and attentions not converted to Tensor with tf.convert_to_tensor as they are all of
+        # different dimensions
+        return TFMultipleChoiceModelOutput(
+            logits=output.logits, hidden_states=output.hidden_states, attentions=output.attentions
+        )
+
 
 @add_start_docstrings(
     """
@@ -1523,6 +1567,13 @@ def call(
             attentions=outputs.attentions,
         )
 
+    def serving_output(self, output: TFTokenClassifierOutput) -> TFTokenClassifierOutput:
+        # hidden_states and attentions not converted to Tensor with tf.convert_to_tensor as they are all of
+        # different dimensions
+        return TFTokenClassifierOutput(
+            logits=output.logits, hidden_states=output.hidden_states, attentions=output.attentions
+        )
+
 
 @add_start_docstrings(
     """
@@ -1605,3 +1656,13 @@ def call(
             hidden_states=outputs.hidden_states,
             attentions=outputs.attentions,
         )
+
+    def serving_output(self, output: TFQuestionAnsweringModelOutput) -> TFQuestionAnsweringModelOutput:
+        # hidden_states and attentions not converted to Tensor with tf.convert_to_tensor as they are all of
+        # different dimensions
+        return TFQuestionAnsweringModelOutput(
+            start_logits=output.start_logits,
+            end_logits=output.end_logits,
+            hidden_states=output.hidden_states,
+            attentions=output.attentions,
+        )
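
For context, a minimal sketch (not part of the patch; the shapes are illustrative) of why the generic conversion in modeling_tf_utils.py is guarded. Funnel pools the sequence dimension between blocks, so the per-layer hidden states have mismatched shapes and cannot be packed into a single tensor: eager execution raises tf.errors.InvalidArgumentError, while graph tracing can raise ValueError, hence the widened except clause above.

import tensorflow as tf

# Hypothetical per-layer hidden states: the sequence axis shrinks after pooling,
# so the tuple entries do not share a common shape.
hidden_states = (tf.zeros([1, 8, 16]), tf.zeros([1, 4, 16]))

try:
    # tf.convert_to_tensor attempts to pack the tuple along a new leading axis
    tf.convert_to_tensor(hidden_states)
except (ValueError, tf.errors.InvalidArgumentError) as exc:
    # Mismatched shapes cannot be packed; serving_output leaves the tuple as-is
    print(f"conversion failed as expected: {type(exc).__name__}")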