Add GQA config to megatron gpt model (#7096)
* Add GQA config in gpt config file

Signed-off-by: jasonwan <[email protected]>

* Verify mcore is enabled when using GQA

Signed-off-by: jasonwan <[email protected]>

---------

Signed-off-by: jasonwan <[email protected]>
blahBlahhhJ authored and ericharper committed Jul 25, 2023
1 parent 2b6cbe7 commit 2320d50
Showing 2 changed files with 5 additions and 0 deletions.
@@ -87,6 +87,7 @@ model:
  overlap_p2p_comm: False # Overlap p2p communication with computes. This argument is valid only when `virtual_pipeline_model_parallel_size` is larger than 1
  batch_p2p_comm: True # Batch consecutive inter-peer send/recv operations. This argument is valid only when `virtual_pipeline_model_parallel_size` is larger than 1
  seq_len_interpolation_factor: null # RoPE Interpolation factor for sequence length. This is used to build long-context models with RoPE ex: https://arxiv.org/abs/2306.15595.
  num_query_groups: null # Number of query groups for group query attention. If None, normal attention is used.

  tokenizer:
    library: 'megatron'
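
For context, `num_query_groups` sets how many key/value heads grouped-query attention (GQA) keeps: the query heads are split into that many groups and every query head in a group shares one KV head, so `num_query_groups == num_attention_heads` is standard multi-head attention and `num_query_groups == 1` is multi-query attention. The sketch below is a minimal, illustrative PyTorch implementation of that grouping, not NeMo/Megatron-LM code; all names and shapes are assumptions for the example.

# Minimal GQA sketch (illustrative only, not NeMo/Megatron-LM code).
import torch

def grouped_query_attention(x, wq, wk, wv, num_heads, num_query_groups):
    # x: [batch, seq, hidden]; wq: [hidden, hidden];
    # wk, wv: [hidden, num_query_groups * head_dim].
    b, s, h = x.shape
    head_dim = h // num_heads
    q = (x @ wq).view(b, s, num_heads, head_dim)
    k = (x @ wk).view(b, s, num_query_groups, head_dim)
    v = (x @ wv).view(b, s, num_query_groups, head_dim)
    # Each KV head is shared by num_heads // num_query_groups query heads.
    repeat = num_heads // num_query_groups
    k = k.repeat_interleave(repeat, dim=2)
    v = v.repeat_interleave(repeat, dim=2)
    q, k, v = (t.transpose(1, 2) for t in (q, k, v))  # [b, heads, seq, head_dim]
    attn = torch.softmax(q @ k.transpose(-2, -1) / head_dim ** 0.5, dim=-1)
    return (attn @ v).transpose(1, 2).reshape(b, s, h)

# 8 query heads sharing 2 KV groups (4 query heads per group).
b, s, h, heads, groups = 2, 16, 64, 8, 2
head_dim = h // heads
x = torch.randn(b, s, h)
out = grouped_query_attention(
    x,
    torch.randn(h, h),
    torch.randn(h, groups * head_dim),
    torch.randn(h, groups * head_dim),
    heads,
    groups,
)  # -> [2, 16, 64]

Keeping fewer KV heads shrinks the inference-time KV cache by roughly num_heads / num_query_groups while leaving the query-side capacity unchanged.
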
@@ -320,6 +320,10 @@ def model_provider_func(self, pre_process, post_process):
                rotary_percent=self.cfg.get('rotary_percentage', 1.0),
            )
        else:
            assert (
                self.cfg.get('num_query_groups', None) is None
            ), "Group Query Attention is only supported in Megatron Core. Set 'mcore_gpt' to use GQA."

            model = GPTModel(
                config=self.model_parallel_config,
                vocab_size=self.cfg.get('override_vocab_size', self.padded_vocab_size),
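
The assertion above only takes effect on the non-mcore code path: when `mcore_gpt` is disabled, any non-null `num_query_groups` is rejected before the legacy `GPTModel` is built. A small stand-alone sketch of that guard (the `cfg` dicts are hypothetical, not actual NeMo configs):

# Stand-alone sketch of the guard added in this commit (cfg dicts are hypothetical).
def check_gqa_config(cfg):
    if not cfg.get('mcore_gpt', False):
        assert cfg.get('num_query_groups', None) is None, (
            "Group Query Attention is only supported in Megatron Core. "
            "Set 'mcore_gpt' to use GQA."
        )

check_gqa_config({'mcore_gpt': True, 'num_query_groups': 8})      # fine: mcore path supports GQA
check_gqa_config({'mcore_gpt': False, 'num_query_groups': None})  # fine: normal attention
# check_gqa_config({'mcore_gpt': False, 'num_query_groups': 8})   # AssertionError
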
