Skip to content

Commit

Permalink
NeMo-Mistral to HF converter bugfix. (#8353)
Browse files Browse the repository at this point in the history
Signed-off-by: Alexandros Koumparoulis <[email protected]>
  • Loading branch information
akoumpa committed Feb 19, 2024
1 parent df5a395 commit 1da9751
Showing 1 changed file with 2 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ def convert(in_file, precision=None, cpu_only=True) -> None:
embed_weights_base_name = f'model.language_model.embedding.word_embeddings.weight'
state_dict[hf_embed_weight_name] = param_to_weights(ckpt[embed_weights_base_name])

head_num = model.cfg.num_attention_heads
if nemo_config.num_query_groups is None or nemo_config.num_query_groups == head_num:
num_query_groups = head_num
else:
Expand All @@ -123,7 +124,6 @@ def convert(in_file, precision=None, cpu_only=True) -> None:
assert nemo_config.activation.startswith('fast-'), 'mcore only supports fast version of gated linear unit.'

hidden_size = model.cfg.hidden_size
head_num = model.cfg.num_attention_heads
num_layers = model.cfg.num_layers
num_query_groups = model.cfg.get("num_query_groups", head_num) # different num_query_groups for 70B

Expand Down Expand Up @@ -191,7 +191,7 @@ def convert(in_file, precision=None, cpu_only=True) -> None:

hf_post_attn_ln_weight_name = f'model.layers.{l}.post_attention_layernorm.weight'
if mcore_gpt:
post_attn_ln_base_name = f'model.decoder.layers.{l}.self_attention.linear_qkv.layer_norm_weight'
post_attn_ln_base_name = f'model.decoder.layers.{l}.mlp.linear_fc1.layer_norm_weight'
else:
post_attn_ln_base_name = f'model.language_model.encoder.layers.{l}.post_attention_layernorm.weight'
state_dict[hf_post_attn_ln_weight_name] = param_to_weights(ckpt[post_attn_ln_base_name])
Expand Down

0 comments on commit 1da9751

Please sign in to comment.