[Bug Fix] fix paddle multipy_fwd_func warning message #7818

Merged
merged 1 commit into PaddlePaddle:develop on Jan 11, 2024

Conversation

BeingGod
Contributor

PR types

Bug fixes

PR changes

Others

Description

Since PaddlePaddle/Paddle#59518, the main framework logs a warning when the two inputs of multiply have inconsistent dtypes.
Under amp level=O2, in
hidden_states = paddle.rsqrt(variance + self.variance_epsilon) * hidden_states
the result of paddle.rsqrt(variance + self.variance_epsilon) is float32, while hidden_states is fp16/bf16.
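
A minimal snippet that reproduces the mixed-dtype multiply described above in isolation (an illustrative sketch, assuming the RMSNorm-style float32 variance computation; not code taken from this PR):

import paddle

# hidden_states is fp16 under AMP O2; the variance is accumulated in float32.
hidden_states = paddle.rand([2, 16, 64]).astype("float16")
variance = hidden_states.astype("float32").pow(2).mean(-1, keepdim=True)
variance_epsilon = 1e-6
# float32 * float16: since PaddlePaddle/Paddle#59518 this multiply logs a dtype-mismatch warning.
hidden_states = paddle.rsqrt(variance + variance_epsilon) * hidden_states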

Warning log:
(screenshot of the warning message)

Reproduction script:

#!/bin/bash
# Reproduction script for the multiply dtype warning: tiny-random-llama pretraining with fp16 O2.
SCRIPT_HOME=$(cd "$(dirname "$0")"; pwd)

CARDS="0,1"

task_name="llama_pp2dp4"
rm -rf "$SCRIPT_HOME/output/$task_name/"
rm -rf "$SCRIPT_HOME/output/${task_name}_log"

# Parallelism settings: tensor parallel 2, pipeline parallel 1, sharding stage 1.
TP=2
PP=1
SHARDING_STAGE="stage1"


python -u  -m paddle.distributed.launch \
    --devices=$CARDS \
    --log_dir "output/$task_name""_log" \
    run_pretrain.py \
    --model_name_or_path "__internal_testing__/tiny-random-llama" \
    --tokenizer_name_or_path "__internal_testing__/tiny-random-llama" \
    --input_dir "/workspace/dataset/llama_openwebtext2" \
    --output_dir "output/$task_name" \
    --split 949,50,1 \
    --max_seq_length 1024 \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 1 \
    --fuse_attention_qkv 1 \
    --fuse_attention_ffn 1 \
    --use_flash_attention 1 \
    --use_fused_rms_norm 0 \
    --fp16 \
    --fp16_opt_level "O2" \
    --scale_loss 1024 \
    --amp_master_grad 1 \
    --max_grad_norm 1.0 \
    --tensor_parallel_degree $TP \
    --pipeline_parallel_degree $PP \
    --sharding $SHARDING_STAGE \
    --learning_rate 5.0e-5 \
    --min_learning_rate 1.0e-9 \
    --lr_scheduler_type "cosine" \
    --max_steps 6000 \
    --save_steps 6000 \
    --weight_decay 0.01 \
    --adam_beta1 0.9 \
    --adam_beta2 0.95 \
    --warmup_ratio 0.1 \
    --logging_steps 1 \
    --dataloader_num_workers 0 \
    --gradient_accumulation_steps 1 \
    --eval_steps 1000 \
    --report_to "visualdl" \
    --disable_tqdm true \
    --continue_training 0 \
    --recompute 0 \
    --do_train \
    --device "gpu" \
    --overwrite_output_dir True
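
For reference, a hedged sketch of the kind of change that avoids the warning in paddlenlp/transformers/llama/modeling.py: cast the float32 rsqrt factor back to the dtype of hidden_states before multiplying, so both operands share one dtype. The helper name rms_norm_scale is hypothetical, and the actual patch in this PR may differ in detail:

import paddle

def rms_norm_scale(hidden_states, variance_epsilon=1e-6):
    # Accumulate the variance in float32 for numerical stability.
    variance = hidden_states.astype("float32").pow(2).mean(-1, keepdim=True)
    # Cast the float32 factor back to the activation dtype (fp16/bf16 under AMP O2)
    # so the multiply no longer mixes dtypes and no warning is logged.
    scale = paddle.rsqrt(variance + variance_epsilon).astype(hidden_states.dtype)
    return scale * hidden_states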


paddle-bot bot commented Jan 10, 2024

Thanks for your contribution!

Collaborator

@ZHUI ZHUI left a comment

LGTM


codecov bot commented Jan 10, 2024

Codecov Report

Attention: 2 lines in your changes are missing coverage. Please review.

Comparison is base (5c2bf81) 0.00% compared to head (b893463) 57.11%.
Report is 1 commit behind head on develop.

Files | Patch % | Lines
paddlenlp/transformers/llama/modeling.py | 50.00% | 2 Missing ⚠️
Additional details and impacted files
@@             Coverage Diff              @@
##           develop    #7818       +/-   ##
============================================
+ Coverage         0   57.11%   +57.11%     
============================================
  Files            0      587      +587     
  Lines            0    88196    +88196     
============================================
+ Hits             0    50377    +50377     
- Misses           0    37819    +37819     

☔ View full report in Codecov by Sentry.

@ZHUI ZHUI requested a review from wawltor January 11, 2024 02:50
@wawltor wawltor merged commit abb0d3c into PaddlePaddle:develop Jan 11, 2024
8 of 9 checks passed
guoshengCS added 3 commits to guoshengCS/PaddleNLP that referenced this pull request Mar 26, 2024
wawltor pushed a commit that referenced this pull request Jun 13, 2024
…ferenceModel (#7953)

* Add Pipeline Parallel for PPO training.

* Move new_ppo_trainer.py to ppo_trainer.py

* Fix padding among batches of accumulation steps in _prepare_pipeline_inputs_func.

* Fix hcg using in TP generation

* Try to support generation in PP, and allow extra training args passed from main from_pretrained.

* Support PP generation.

* Fix PP eval by unifying prediction_step

* Fix reward value showing error caused by BF16 dtype during eval

* fix all

* Make non-PipelineParallel models use the same loss layer with PipeModel to unify.

* add offload.

* Use create_loss to unify Pipe and non-Pipe usage.

* Add eval mode and offload level.

* merge

* support tp+pp

* fix data split.

* Fix position_ids in generation/eval/train.

* fix data group.

* add tp rank guard

* Support rollout label data both with target length or source+target length.

* Move metric calculation to rl_step to avoid comm.

* fix pad

* fix create group.

* no print

* Support inference model generation.

* fix compatibility when there is no eval model.

* fix pp sync.

* remove debug info

* Refactor PPO training using StepTrainer.

* Open PolicyTrainer loss logging postprocess. More StepTrainer docs.

* more timer.

* fix bugs.

* Add EMA and PPOMetric

* add tests

* add unit test for rank guard.

* Fix reshard zero3 and reshard infer.

* Revert #7818 for llama and remove position_ids for gen/train/eval to align.

* Move reload/clean/data_group to comm_utils and use guard to decorate them.

* Offload sync and other data reuse fix.

* Clean code

* Update README

* Update ppo_trainer

* format code

* Fix make_position_ids by 4d causal mask.

* Fix nested_broadcast_tensor_with_empty import

* Update eval with make_attention_mask

---------

Co-authored-by: Zhong Hui <[email protected]>
Co-authored-by: gongenlei <[email protected]>