Add score scaling/normalization/clipping #560
Conversation
Thanks a lot for working on this and adding this nice new feature.
I am OK with this PR in principle: it is backward compatible with the existing setup and works in the distributed setting as well (judging from the code of `RunningMoments` & `get_global_statistics`).
Would love to hear @lvwerra's & @vwxyzjn's thoughts on this.
Can you also run the styling checks?
`make precommit`
Thanks!
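For context, the idea behind `get_global_statistics` is to reduce the per-process reward statistics across workers before computing a global mean/std that every process uses for scaling. A minimal sketch of that idea (not the actual `RunningMoments`/`get_global_statistics` implementation, and assuming `torch.distributed` is already initialized):

```python
from typing import Tuple

import torch
import torch.distributed as dist


def global_mean_std(scores: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Sketch: global mean/std of scores spread across processes."""
    count = torch.tensor(float(scores.numel()), device=scores.device)
    total = scores.sum()
    dist.all_reduce(count)  # default op is SUM
    dist.all_reduce(total)
    mean = total / count

    sq_dev = ((scores - mean) ** 2).sum()
    dist.all_reduce(sq_dev)
    return mean, torch.sqrt(sq_dev / count)
```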
The documentation is not available anymore as the PR was closed or merged.
Can you also add a few lines to the documentation explaining this feature to users? The details could go in a dedicated section here: https://github.com/lvwerra/trl/blob/main/docs/source/customization.mdx
Also, can you share the behaviour of `env/rewards_mean` and `env/rewards_std`?
Thanks!
```diff
@@ -45,7 +45,6 @@ class ScriptArguments:
         default=1, metadata={"help": "the number of gradient accumulation steps"}
     )
     early_stopping: Optional[bool] = field(default=False, metadata={"help": "whether to early stop"})
-    target_kl: Optional[float] = field(default=6, metadata={"help": "kl target for early stopping"})
```
This field seems to have been removed by mistake?
Hi Younes,
You will find that `target_kl` already exists on L57 with a much smaller value. I dug deeper and found that `PPOConfig` has two configs, `target` and `target_kl`, where `target` has a default value of 6. So I assume the first, duplicate `target_kl` config here was meant to be `target`. However, `target` is NOT used to populate `PPOConfig` at L64, so I just removed it.
Regards,
Felix
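For context, a minimal sketch of how the two parameters differ, using only the fields discussed in this thread (the values shown are the defaults mentioned above):

```python
from trl import PPOConfig

# `target` is the KL target used by the adaptive KL controller (default 6),
# while `target_kl` is the separate threshold used for early stopping.
config = PPOConfig(
    target=6.0,
    early_stopping=True,
    target_kl=0.1,
)
```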
Great point, thank you!
I think this is actually a bug from here: 1620da3. We overloaded the `target_kl` term; we should rename it!
cc @edbeeching
@lvwerra, as much as I love introducing bugs into trl, I think this time it was @younesbelkada, in the Big refactor of examples and documentation (#509), here.
I agree to rename it to `early_stop_kl`, or something similar.
Thanks a lot for this great work, this looks very nice on my side. Let's see what others say!
Very clean PR, thanks! Left a few questions :)
examples/scripts/sentiment_tuning.py (outdated)

```diff
@@ -56,6 +55,9 @@ class ScriptArguments:
     )
     target_kl: Optional[float] = field(default=0.1, metadata={"help": "kl target for early stopping"})
     seed: Optional[int] = field(default=0, metadata={"help": "the random seed"})
+    use_score_scaling: Optional[bool] = field(default=False, metadata={"help": "Use score scaling"})
+    use_score_norm: Optional[bool] = field(default=False, metadata={"help": "Use score normalization"})
```
Maybe we should clarify that this only works if `use_score_scaling` is also `True`, otherwise it's actually ignored, or we change the logic a bit in general.
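In other words, the gating being discussed is roughly the following (a sketch, not the exact trainer code; the helper signature is hypothetical):

```python
import torch


def maybe_scale_scores(scores: torch.Tensor, mean: torch.Tensor, std: torch.Tensor,
                       use_score_scaling: bool, use_score_norm: bool) -> torch.Tensor:
    """Sketch of the behaviour discussed above: `use_score_norm` only has an
    effect when `use_score_scaling` is enabled."""
    if not use_score_scaling:
        return scores                           # use_score_norm is ignored here
    if use_score_norm:
        return (scores - mean) / (std + 1e-8)   # full normalization
    return scores / (std + 1e-8)                # scaling only
```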
Done
Is that really true? Inside `step` we only log [...]
Hi @lvwerra, could you elaborate on the performance degradation? In this training loop:

```python
for epoch, batch in tqdm(enumerate(ppo_trainer.dataloader)):
    query_tensors = batch["input_ids"]

    # Get response from gpt2
    response_tensors = ppo_trainer.generate(query_tensors, return_prompt=False, **generation_kwargs)
    batch["response"] = tokenizer.batch_decode(response_tensors)

    # Compute sentiment score
    texts = [q + r for q, r in zip(batch["query"], batch["response"])]
    pipe_outputs = sentiment_pipe(texts, **sent_kwargs)
    rewards = [torch.tensor(output[1]["score"]) for output in pipe_outputs]

    # Run PPO step
    stats = ppo_trainer.step(query_tensors, response_tensors, rewards)
    ppo_trainer.log_stats(stats, batch, rewards)
```

it's not obvious to me whether score scaling/normalization/clipping improves or degrades the performance. It's meant to improve training stability, but I guess I haven't run the training long enough to observe possible divergences (well, Google Colab would crash on me). In general I observe smoother curves. I do observe better loss curves on the value head/function and assume that can be attributed to the more stable and smooth reward scores. Regardless, the configs are optional and backward compatible.

Regards,
Felix
Hi @zfang, I was mainly referring to this plot that you shared:
Hey @zfang, thanks for the PR! Sometimes random seeds can impact the results a lot, e.g. #462 (comment). Could you run the experiment for 10 random seeds?
Actually I just made a change in [...]
Update: I do observe consistent patterns of difference in [...]
After some investigation, I have the root cause. Based on the following code snippet:

```python
for epoch, batch in tqdm(enumerate(ppo_trainer.dataloader)):
    query_tensors = batch["input_ids"]

    # Get response from gpt2
    response_tensors = ppo_trainer.generate(query_tensors, return_prompt=False, **generation_kwargs)
    batch["response"] = tokenizer.batch_decode(response_tensors)

    # Compute sentiment score
    texts = [q + r for q, r in zip(batch["query"], batch["response"])]
    pipe_outputs = sentiment_pipe(texts, **sent_kwargs)
    rewards = [torch.tensor(output[1]["score"]) for output in pipe_outputs]

    # Run PPO step
    stats = ppo_trainer.step(query_tensors, response_tensors, rewards)
    ppo_trainer.log_stats(stats, batch, rewards)
```

We have the dependencies of [...]. In other words, because [...].

On a high level I think that makes sense: we normalize the sentiment scores so they are less dominant over the KL loss, and thus we observe that with score normalization the model is less eager to optimize for sentiment scores in comparison to the KL loss. This can be adjusted by using a smaller [...].

Hopefully that makes sense to you.

Regards,
Felix
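Roughly, my mental model of how the score and the KL penalty combine into the per-token reward (a sketch of assumed behaviour, not the exact `PPOTrainer` code):

```python
import torch


def combine_score_and_kl(kl_per_token: torch.Tensor, score: torch.Tensor,
                         kl_coef: float) -> torch.Tensor:
    """Sketch: the KL penalty is applied per token and the (possibly scaled)
    score is added at the final token, so shrinking the score shifts the
    balance of the total reward towards the KL term."""
    rewards = -kl_coef * kl_per_token   # per-token KL penalty
    rewards[-1] += score                # scalar score added at the last response token
    return rewards
```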
OK, makes sense. Since it's optional it's not directly a regression, and we can merge.
Summary
Add score (aka reward) scaling/normalization/clipping to improve PPO training stability, based on Section 5.3.1 of *Secrets of RLHF in Large Language Models Part I: PPO* and https://github.com/OpenLMLab/MOSS-RLHF.
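The gist of the feature, as a minimal sketch with a toy running-moments tracker (not the actual implementation; the `score_clip` name and the exact update math are assumptions based on the referenced paper/repo, while `use_score_scaling`/`use_score_norm` mirror the new `ScriptArguments` fields):

```python
from typing import Optional, Tuple

import torch


class RunningMomentsSketch:
    """Toy running mean/std tracker, for illustration only."""

    def __init__(self):
        self.mean, self.var, self.count = 0.0, 1.0, 1e-4

    def update(self, xs: torch.Tensor) -> Tuple[float, float]:
        batch_mean = xs.mean().item()
        batch_var = xs.var(unbiased=False).item()
        batch_count = xs.numel()
        delta = batch_mean - self.mean
        total = self.count + batch_count
        # Chan et al. parallel update of the sum of squared deviations
        m2 = self.var * self.count + batch_var * batch_count + delta**2 * self.count * batch_count / total
        self.mean, self.var, self.count = self.mean + delta * batch_count / total, m2 / total, total
        return self.mean, self.var ** 0.5


def process_scores(
    scores: torch.Tensor,
    moments: RunningMomentsSketch,
    use_score_scaling: bool = False,
    use_score_norm: bool = False,
    score_clip: Optional[float] = None,
) -> torch.Tensor:
    if use_score_scaling:
        mean, std = moments.update(scores)
        if use_score_norm:
            scores = (scores - mean) / (std + 1e-8)  # normalize: zero mean, unit variance
        else:
            scores = scores / (std + 1e-8)           # scale only: unit variance, mean preserved
    if score_clip is not None:
        scores = torch.clamp(scores, -score_clip, score_clip)
    return scores
```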
Tests
The following was tested in a Google Colab notebook with an Nvidia T4 GPU. My notebook got disconnected by itself after a few hours while I was getting 1 iteration per minute, so my runs crashed fairly early.
sentiment_tuning.py
Command for baseline:
Command for score scaling/normalization/clipping:
Screenshots of wandb:
multi_adapter_rl_v2.py
Command for baseline:
Command for score scaling/normalization/clipping:
Screenshots of wandb: