diff --git a/QEfficient/utils/_utils.py b/QEfficient/utils/_utils.py index 8a9e3d1c..3cf78029 100644 --- a/QEfficient/utils/_utils.py +++ b/QEfficient/utils/_utils.py @@ -243,12 +243,17 @@ def get_padding_shape_from_config(config, batch_size, seq_len): elif hasattr(config, "n_heads"): # Check for n_heads and d_model in the config (MPT Model) n_heads = config.n_heads d_head = config.d_model // config.n_heads - elif hasattr(config, "multi_query"): # Check for Falcon - multi_query_value = getattr(config, "multi_query") - if multi_query_value: - n_heads = 1 # MQA - else: + elif hasattr(config, "new_decoder_architecture"): # Check for Falcon + new_decoder_architecture = getattr(config, "new_decoder_architecture") + if new_decoder_architecture: # multi_query is ignored when new_decoder_architecture is True n_heads = config.num_attention_heads + else: + if hasattr(config, "multi_query"): + multi_query_value = getattr(config, "multi_query") + if multi_query_value: + n_heads = 1 # MQA , multi query is true + else: + n_heads = config.num_attention_heads d_head = config.hidden_size // config.num_attention_heads else: raise ValueError("Invalid model configuration: n_head/d_heads or num_key_value_heads not found.")