Update number of heads for MQA in Falcon (quic#124)
* Update number of heads for MQA

Signed-off-by: Mamta Singh <[email protected]>

* Update number of heads for MQA

Signed-off-by: Mamta Singh <[email protected]>

* Update number of heads for MQA

Signed-off-by: Mamta Singh <[email protected]>

---------

Signed-off-by: Mamta Singh <[email protected]>
quic-mamta authored Sep 24, 2024
1 parent cfb7823 commit 64bb87a
Showing 1 changed file with 10 additions and 5 deletions: QEfficient/utils/_utils.py
@@ -243,12 +243,17 @@ def get_padding_shape_from_config(config, batch_size, seq_len):
     elif hasattr(config, "n_heads"):  # Check for n_heads and d_model in the config (MPT Model)
         n_heads = config.n_heads
         d_head = config.d_model // config.n_heads
-    elif hasattr(config, "multi_query"):  # Check for Falcon
-        multi_query_value = getattr(config, "multi_query")
-        if multi_query_value:
-            n_heads = 1  # MQA
-        else:
+    elif hasattr(config, "new_decoder_architecture"):  # Check for Falcon
+        new_decoder_architecture = getattr(config, "new_decoder_architecture")
+        if new_decoder_architecture:  # multi_query is ignored when new_decoder_architecture is True
             n_heads = config.num_attention_heads
+        else:
+            if hasattr(config, "multi_query"):
+                multi_query_value = getattr(config, "multi_query")
+                if multi_query_value:
+                    n_heads = 1  # MQA , multi query is true
+                else:
+                    n_heads = config.num_attention_heads
         d_head = config.hidden_size // config.num_attention_heads
     else:
         raise ValueError("Invalid model configuration: n_head/d_heads or num_key_value_heads not found.")
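For context, a minimal standalone sketch (not part of the commit) of how the updated Falcon branch resolves the KV-cache head count: new_decoder_architecture takes priority, and multi_query is only consulted when it is False. The helper name and the config attribute values below are illustrative assumptions, not taken from the repository.

from types import SimpleNamespace

def falcon_kv_heads(config):
    # Mirrors the updated branch: new_decoder_architecture wins over multi_query.
    if getattr(config, "new_decoder_architecture", False):
        return config.num_attention_heads  # multi_query is ignored for the new architecture
    if getattr(config, "multi_query", False):
        return 1  # MQA: a single shared key/value head
    return config.num_attention_heads

# Illustrative Falcon-style configs (values are assumptions for this example):
mqa_config = SimpleNamespace(new_decoder_architecture=False, multi_query=True, num_attention_heads=71)
new_arch_config = SimpleNamespace(new_decoder_architecture=True, multi_query=True, num_attention_heads=128)

print(falcon_kv_heads(mqa_config))       # 1
print(falcon_kv_heads(new_arch_config))  # 128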
