Fix arg in bettertransformer llama attention (#1421)
* fix arg in llama attention

* change to kwargs

* add kwargs everywhere

---------

Co-authored-by: younesbelkada <[email protected]>
SunMarc and younesbelkada authored Oct 3, 2023
1 parent 8a0c11d commit dbe70f9
Showing 1 changed file with 7 additions and 0 deletions.
7 changes: 7 additions & 0 deletions optimum/bettertransformer/models/attention.py
@@ -242,6 +242,7 @@ def opt_forward(
     attention_mask: Optional[torch.Tensor] = None,
     layer_head_mask: Optional[torch.Tensor] = None,
     output_attentions: bool = False,
+    **kwargs,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
     raise_on_head_mask(layer_head_mask)
 
@@ -336,6 +337,7 @@ def t5_forward(
     query_length=None,
     use_cache=False,
     output_attentions=False,
+    **kwargs,
 ):
     raise_on_head_mask(layer_head_mask)
 
@@ -466,6 +468,7 @@ def bart_forward(
     attention_mask: Optional[torch.Tensor] = None,
     layer_head_mask: Optional[torch.Tensor] = None,
     output_attentions: bool = False,
+    **kwargs,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
     """Input shape: Batch x Time x Channel"""
     raise_on_head_mask(layer_head_mask)
@@ -583,6 +586,7 @@ def llama_forward(
     past_key_value: Optional[Tuple[torch.Tensor]] = None,
     output_attentions: bool = False,
     use_cache: bool = False,
+    **kwargs,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
     if output_attentions is True:
         raise ValueError("output_attentions=True can not be supported with BetterTransformer.")
@@ -768,6 +772,7 @@ def gpt_bigcode_forward(
     encoder_attention_mask: Optional[torch.Tensor] = None,
     use_cache: Optional[bool] = False,
     output_attentions: Optional[bool] = False,
+    **kwargs,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
     if output_attentions is True:
         raise ValueError("output_attentions=True can not be supported with BetterTransformer.")
@@ -826,6 +831,7 @@ def bloom_forward(
     head_mask: Optional[torch.Tensor] = None,
     use_cache: bool = False,
     output_attentions: bool = False,
+    **kwargs,
 ):
     raise_on_head_mask(head_mask)
 
@@ -910,6 +916,7 @@ def falcon_forward(
     head_mask: Optional[torch.Tensor] = None,
     use_cache: bool = False,
     output_attentions: bool = False,
+    **kwargs,
 ):
     fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size]
     num_kv_heads = self.num_heads if self.new_decoder_architecture else self.num_kv_heads
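
For context, a minimal sketch of the pattern these additions apply (illustrative only; patched_attention_forward and padding_mask are hypothetical names, not taken from optimum): accepting **kwargs lets a patched forward absorb keyword arguments that newer transformers releases may pass to attention modules, instead of failing with a TypeError for an unexpected keyword argument.

# Minimal sketch (assumption, not the actual optimum code): a patched attention
# forward that tolerates extra keyword arguments from the calling model.
from typing import Optional, Tuple

import torch


def patched_attention_forward(
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    output_attentions: bool = False,
    **kwargs,  # absorbs arguments the wrapper does not use (e.g. a hypothetical padding_mask)
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    # Same guard used throughout attention.py: SDPA cannot return attention weights.
    if output_attentions:
        raise ValueError("output_attentions=True can not be supported with BetterTransformer.")
    # Toy self-attention; the real wrappers project to query/key/value first.
    attn_output = torch.nn.functional.scaled_dot_product_attention(
        hidden_states, hidden_states, hidden_states, attn_mask=attention_mask
    )
    return attn_output, None


# Without **kwargs, forwarding a new argument would raise:
#   TypeError: patched_attention_forward() got an unexpected keyword argument 'padding_mask'
hidden = torch.randn(1, 4, 8)
output, _ = patched_attention_forward(hidden, padding_mask=None)

The trade-off is that unrecognized arguments are silently ignored rather than rejected, which keeps the wrapped forward signatures compatible across transformers versions.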
