Commit bb3d9e5: Fix BLIP and LED

Rocketknight1 committed May 17, 2023
1 parent c118d0a commit bb3d9e5

Showing 3 changed files with 21 additions and 1 deletion.

8 changes: 7 additions & 1 deletion src/transformers/models/blip/modeling_tf_blip.py

@@ -1187,6 +1187,12 @@ def _shift_right(self, input_ids):

         return shifted_input_ids
 
+    @property
+    def input_signature(self):
+        base_sig = super().input_signature
+        base_sig["decoder_input_ids"] = base_sig["input_ids"]
+        return base_sig
+
     @unpack_inputs
     @add_start_docstrings_to_model_forward(BLIP_VISION_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=TFBlipTextVisionModelOutput, config_class=BlipVisionConfig)
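
The override merges `decoder_input_ids` into the serving signature by aliasing the spec already used for `input_ids`, so SavedModel export traces the decoder input with the same (batch, sequence) int32 shape. A minimal sketch of the aliasing pattern, with the base dict stubbed out by hand (the real `super().input_signature` supplies it):

    import tensorflow as tf

    # stand-in for super().input_signature; the real base class builds this dict
    base_sig = {"input_ids": tf.TensorSpec((None, None), tf.int32, name="input_ids")}

    # alias the encoder spec for the decoder, as the new property does
    base_sig["decoder_input_ids"] = base_sig["input_ids"]
    print(base_sig["decoder_input_ids"].shape)  # (None, None)
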
@@ -1239,7 +1245,7 @@ def call(
         ```"""
         if labels is None and decoder_input_ids is None:
             raise ValueError(
-                "Either `decoder_input_ids` or `labels` should be passed when calling `forward` with"
+                "Either `decoder_input_ids` or `labels` should be passed when calling"
                 " `TFBlipForQuestionAnswering`. if you are training the model make sure that `labels` is passed, if you"
                 " are using the model for inference make sure that `decoder_input_ids` is passed or call `generate`"
             )
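
The reworded error message drops the stray reference to `forward` (the TF API uses `call`) and states the contract: pass `labels` when training, and pass `decoder_input_ids` or use `generate()` at inference. A hedged usage sketch of the inference path; the checkpoint name and image URL are standard examples, not taken from this commit:

    import requests
    from PIL import Image
    from transformers import BlipProcessor, TFBlipForQuestionAnswering

    processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")
    model = TFBlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base")

    url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    image = Image.open(requests.get(url, stream=True).raw)
    inputs = processor(images=image, text="how many cats are there?", return_tensors="tf")

    # generate() supplies decoder_input_ids internally, so the ValueError is not raised
    answer_ids = model.generate(**inputs)
    print(processor.decode(answer_ids[0], skip_special_tokens=True))
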

13 changes: 13 additions & 0 deletions src/transformers/models/blip/modeling_tf_blip_text.py

@@ -797,6 +797,19 @@ def call(
             cross_attentions=encoder_outputs.cross_attentions,
         )
 
+    def serving_output(
+        self, output: TFBaseModelOutputWithPoolingAndCrossAttentions
+    ) -> TFBaseModelOutputWithPoolingAndCrossAttentions:
+        hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
+        attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
+
+        return TFBaseModelOutputWithPoolingAndCrossAttentions(
+            last_hidden_state=output.last_hidden_state,
+            pooler_output=output.pooler_output,
+            hidden_states=hs,
+            attentions=attns,
+        )
+
 
 # Adapted from https://github.com/salesforce/BLIP/blob/main/models/med.py#L811
 class TFBlipTextLMHeadModel(TFBlipTextPreTrainedModel):
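
`serving_output` matters here because `output.hidden_states` and `output.attentions` arrive as Python tuples of per-layer tensors, which a SavedModel serving signature cannot return; `tf.convert_to_tensor` stacks each tuple into a single tensor. A small self-contained sketch (shapes are illustrative only):

    import tensorflow as tf

    # a tuple of 4 per-layer tensors, each shaped (batch, seq_len, hidden)
    hidden_states = tuple(tf.zeros((2, 5, 8)) for _ in range(4))

    stacked = tf.convert_to_tensor(hidden_states)
    print(stacked.shape)  # (4, 2, 5, 8): layers stacked on a new leading axis
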

1 change: 1 addition & 0 deletions src/transformers/models/led/modeling_tf_led.py

@@ -1327,6 +1327,7 @@ def input_signature(self):
             "attention_mask": tf.TensorSpec((None, None), tf.int32, name="attention_mask"),
             "decoder_input_ids": tf.TensorSpec((None, None), tf.int32, name="decoder_input_ids"),
             "decoder_attention_mask": tf.TensorSpec((None, None), tf.int32, name="decoder_attention_mask"),
+            "global_attention_mask": tf.TensorSpec((None, None), tf.int32, name="global_attention_mask"),
         }
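
With `global_attention_mask` in the signature, exported LED models accept the mask that marks which tokens attend globally. A hedged sketch of building one that matches the (None, None) int32 spec; the token ids are toy values, and placing global attention on the first token follows the usual LED convention:

    import tensorflow as tf

    input_ids = tf.constant([[0, 713, 16, 41, 15162, 2]])  # toy ids, shape (1, 6)

    # 1 marks globally-attending positions; here, only the first token
    global_attention_mask = tf.concat(
        [tf.ones_like(input_ids[:, :1]), tf.zeros_like(input_ids[:, 1:])], axis=1
    )
    print(global_attention_mask.numpy())  # [[1 0 0 0 0 0]]
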


