Skip to content

Commit

Permalink
Add type hints for several pytorch models (batch-4) (huggingface#25749)
Browse files Browse the repository at this point in the history
* Add type hints for MGP STR model

* Add missing type hints for plbart model

* Add type hints for Pix2struct model

* Add missing type hints to Rag model and tweak the docstring

* Add missing type hints to Sam model

* Add missing type hints to Swin2sr model

* Fix a type hint for Pix2StructTextModel

Co-authored-by: Matt <[email protected]>

* Fix typo on Rag model docstring

Co-authored-by: Matt <[email protected]>

* Fix linter

---------

Co-authored-by: Matt <[email protected]>
  • Loading branch information
2 people authored and parambharat committed Sep 26, 2023
1 parent b15d4d6 commit 6745567
Show file tree
Hide file tree
Showing 6 changed files with 38 additions and 36 deletions.
20 changes: 13 additions & 7 deletions src/transformers/models/mgp_str/modeling_mgp_str.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,7 +380,13 @@ def get_input_embeddings(self) -> nn.Module:
return self.embeddings.proj

@add_start_docstrings_to_model_forward(MGP_STR_INPUTS_DOCSTRING)
def forward(self, pixel_values, output_attentions=None, output_hidden_states=None, return_dict=None):
def forward(
self,
pixel_values: torch.FloatTensor,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple[torch.FloatTensor], BaseModelOutput]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
Expand Down Expand Up @@ -437,12 +443,12 @@ def __init__(self, config: MgpstrConfig) -> None:
@replace_return_docstrings(output_type=MgpstrModelOutput, config_class=MgpstrConfig)
def forward(
self,
pixel_values,
output_attentions=None,
output_a3_attentions=None,
output_hidden_states=None,
return_dict=None,
):
pixel_values: torch.FloatTensor,
output_attentions: Optional[bool] = None,
output_a3_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple[torch.FloatTensor], MgpstrModelOutput]:
r"""
output_a3_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of a3 modules. See `a3_attentions` under returned tensors
Expand Down
28 changes: 14 additions & 14 deletions src/transformers/models/pix2struct/modeling_pix2struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -1387,21 +1387,21 @@ def set_output_embeddings(self, new_embeddings):
@replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids=None,
attention_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
inputs_embeds=None,
head_mask=None,
cross_attn_head_mask=None,
past_key_values=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
labels=None,
return_dict=None,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
cross_attn_head_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
labels: Optional[torch.LongTensor] = None,
return_dict: Optional[bool] = None,
**kwargs,
):
) -> Union[Tuple[torch.FloatTensor, ...], CausalLMOutputWithCrossAttentions]:
r"""
Returns:
Expand Down
4 changes: 2 additions & 2 deletions src/transformers/models/plbart/modeling_plbart.py
Original file line number Diff line number Diff line change
Expand Up @@ -1177,7 +1177,7 @@ def forward(
encoder_outputs: Optional[List[torch.FloatTensor]] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
decoder_inputs_embeds=None,
decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
Expand Down Expand Up @@ -1302,7 +1302,7 @@ def forward(
encoder_outputs: Optional[List[torch.FloatTensor]] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
decoder_inputs_embeds=None,
decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.Tensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
Expand Down
18 changes: 7 additions & 11 deletions src/transformers/models/rag/modeling_rag.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,16 +462,12 @@ def from_pretrained_question_encoder_generator(
`question_encoder_last_hidden_state` and `retrieved_doc_embeds`, see examples for more information.
context_input_ids (`torch.LongTensor` of shape `(batch_size * config.n_docs, config.max_combined_length)`, *optional*, returned when *output_retrieved=True*):
Input IDs post-processed from the retrieved documents and the question encoder `input_ids` by the
retriever.
If the model has is not initialized with a `retriever` ``context_input_ids` has to be provided to the
forward pass. `context_input_ids` are returned by [`~RagRetriever.__call__`]. context_attention_mask
(`torch.LongTensor` of shape `(batch_size * config.n_docs, config.max_combined_length)`, *optional*,
returned when *output_retrieved=True*): Attention mask post-processed from the retrieved documents and the
question encoder `input_ids` by the retriever.
If the model has is not initialized with a `retriever` `context_attention_mask` has to be provided to the
forward pass. `context_attention_mask` are returned by [`~RagRetriever.__call__`].
retriever. If the model was not initialized with a `retriever` ``context_input_ids` has to be provided to
the forward pass. `context_input_ids` are returned by [`~RagRetriever.__call__`].
context_attention_mask (`torch.LongTensor` of shape `(batch_size * config.n_docs, config.max_combined_length)`,*optional*, returned when *output_retrieved=True*):
Attention mask post-processed from the retrieved documents and the question encoder `input_ids` by the
retriever. If the model has is not initialized with a `retriever` `context_attention_mask` has to be
provided to the forward pass. `context_attention_mask` are returned by [`~RagRetriever.__call__`].
use_cache (`bool`, *optional*, defaults to `True`):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
`past_key_values`).
Expand Down Expand Up @@ -545,7 +541,7 @@ def forward(
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
doc_scores: Optional[torch.FloatTensor] = None,
context_input_ids: Optional[torch.LongTensor] = None,
context_attention_mask=None,
context_attention_mask: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/sam/modeling_sam.py
Original file line number Diff line number Diff line change
Expand Up @@ -1296,7 +1296,7 @@ def forward(
target_embedding: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict=None,
return_dict: Optional[bool] = None,
**kwargs,
) -> List[Dict[str, torch.Tensor]]:
r"""
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/swin2sr/modeling_swin2sr.py
Original file line number Diff line number Diff line change
Expand Up @@ -903,7 +903,7 @@ def pad_and_normalize(self, pixel_values):
)
def forward(
self,
pixel_values,
pixel_values: torch.FloatTensor,
head_mask: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
Expand Down

0 comments on commit 6745567

Please sign in to comment.