Describe the bug
I was trying to run SFT based on the Mixtral-8x7B-Instruct model with tensor parallel size=4 (sequence_parallel=True) and LoRA (target_modules=[all]).
It reports that the output dims of the original module and the corresponding LoRA adapter module do not match, so they cannot be added together.
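For context, LoRA adds a low-rank update on top of the frozen base projection, and the failure happens exactly at the point where the two outputs are summed. A minimal numpy sketch of that sum (illustrative only, not NeMo's implementation; all names and shapes here are made up):

```python
import numpy as np

# Minimal LoRA sketch (illustrative; not NeMo's implementation).
rng = np.random.default_rng(0)
d_in, d_out, rank = 16, 12, 4

W = rng.standard_normal((d_out, d_in))        # frozen base weight
A = rng.standard_normal((rank, d_in)) * 0.01  # LoRA down-projection
B = np.zeros((d_out, rank))                   # LoRA up-projection (zero init)

x = rng.standard_normal((5, d_in))
base_out = x @ W.T                            # original module output
lora_out = x @ A.T @ B.T                      # low-rank adapter output
y = base_out + lora_out                       # this add fails if the shapes disagree

assert y.shape == (5, d_out)
```

The bug below is this same addition, except that under sequence parallelism the two operands end up with different sequence lengths.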
Steps/Code to reproduce bug
I used the recommended docker image nvcr.io/nvidia/nemo:24.07, and my script is as follows:
...
And it runs into an error like:
File "/opt/NeMo-Aligner/nemo_aligner/algorithms/supervised.py", line 145, in train_single_step
loss_mean, metrics = self.model.get_loss_and_metrics(batch=batch, forward_only=False)
File "/opt/NeMo-Aligner/nemo_aligner/models/nlp/gpt/gpt_sft_model.py", line 93, in get_loss_and_metrics
losses_reduced = fwd_bwd_function(
File "/opt/megatron-lm/megatron/core/pipeline_parallel/schedules.py", line 439, in forward_backward_no_pipelining
output_tensor, num_tokens = forward_step(
File "/opt/megatron-lm/megatron/core/pipeline_parallel/schedules.py", line 264, in forward_step
output_tensor, loss_func = forward_step_func(data_iterator, model)
File "/opt/NeMo/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py", line 1273, in fwd_output_and_loss_func
output_tensor = model(**forward_args)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/opt/megatron-lm/megatron/core/models/gpt/gpt_model.py", line 191, in forward
hidden_states = self.decoder(
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/opt/megatron-lm/megatron/core/transformer/transformer_block.py", line 411, in forward
hidden_states, context = layer(
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/opt/megatron-lm/megatron/core/transformer/transformer_layer.py", line 178, in forward
attention_output_with_bias = self.self_attention(
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/opt/NeMo/nemo/collections/nlp/modules/common/megatron/adapters/mcore_mixins.py", line 202, in forward
query, key, value = self.get_query_key_value_tensors(hidden_states, key_value_states)
File "/opt/NeMo/nemo/collections/nlp/modules/common/megatron/adapters/mcore_mixins.py", line 117, in get_query_key_value_tensors
mixed_qkv = mixed_qkv + lora_mixed_qkv
RuntimeError: The size of tensor a (400) must match the size of tensor b (1600) at non-singleton dimension 0
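The 400-vs-1600 mismatch is consistent with tensor parallel size 4: with sequence parallelism the sequence dimension is sharded across the 4 TP ranks (1600 / 4 = 400), so one path sees only a per-rank shard while the other sees the full sequence. A toy numpy reproduction of the shape error (assumed shapes; real NeMo tensors also carry batch and hidden-projection dims):

```python
import numpy as np

# Toy reproduction of the shape mismatch (assumed shapes, not real NeMo tensors).
tp, seq_len, hidden = 4, 1600, 8
full = np.random.randn(seq_len, hidden)        # e.g. LoRA path over the full sequence
shard = np.array_split(full, tp, axis=0)[0]    # base path shard on one rank: 1600/4 = 400

mismatch = False
try:
    _ = shard + full                           # mirrors `mixed_qkv + lora_mixed_qkv`
except ValueError:                             # numpy's analog of the RuntimeError
    mismatch = True
assert mismatch

# Once both paths cover the same sequence length, the add works:
gathered = np.concatenate(np.array_split(full, tp, axis=0), axis=0)
assert (gathered + full).shape == (seq_len, hidden)
```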
And I tried to fix this by modifying the following two .py files:
/opt/NeMo/nemo/collections/nlp/modules/common/megatron/adapters/mcore_mixins.py
/opt/NeMo/nemo/collections/nlp/modules/common/megatron/adapters/parallel_adapters.py
The change to parallel_adapters.py:
```diff
--- parallel_adapters.py	2024-08-28 02:52:03.000000000 +0000
+++ /opt/NeMo/nemo/collections/nlp/modules/common/megatron/adapters/parallel_adapters.py	2024-08-28 02:52:04.712871536 +0000
@@ -291,7 +291,7 @@
         if self.norm_position == 'pre':
             x = self.layer_norm(x)

-        if self._sequence_parallel and not self.input_is_parallel and self.norm_position == 'pre':
+        if self._sequence_parallel and not self.input_is_parallel:
             # for attention_qkv and linear_fc1
             # layernorm before lora is impacted by sequence parallel,
             # hence seq dim need to be gathered right before lora linear layers
```
After these modifications, I can run SFT with LoRA plus tensor and sequence parallelism, but I am not sure it runs correctly. I hope you can provide an elegant solution for this.
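If I read the patch correctly, its effect is to always gather the sequence-sharded input before the LoRA linear layers when sequence parallelism is on, rather than only in the norm_position == 'pre' case. A toy stand-in for that gather step (hypothetical names and shapes; the real code uses Megatron's distributed collectives, not local concatenation):

```python
import numpy as np

# Toy stand-in for gathering the sequence dimension across TP ranks.
# Hypothetical names/shapes; NeMo uses torch distributed collectives instead.
tp, seq_per_rank, hidden = 4, 400, 8
shards = [np.random.randn(seq_per_rank, hidden) for _ in range(tp)]

def gather_seq(shards):
    # Concatenate per-rank sequence shards into the full sequence.
    return np.concatenate(shards, axis=0)

full_seq = gather_seq(shards)
assert full_seq.shape == (tp * seq_per_rank, hidden)  # (1600, 8), matching the LoRA path
```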
Expected behavior
LoRA can be used together with tensor and sequence parallelism.
Environment overview (please complete the following information)
Docker image: nvcr.io/nvidia/nemo:24.07
Method of install: docker run
Environment details
I used the default environment of the NeMo docker image.