
Add score scaling/normalization/clipping #560

Merged (9 commits into huggingface:main, Aug 10, 2023)

Conversation

zfang (Contributor) commented on Jul 24, 2023

Summary

Add score (aka reward) scaling/normalization/clipping to improve PPO training stability based on Section 5.3.1 of Secrets of RLHF in Large Language Models Part I: PPO and https://github.com/OpenLMLab/MOSS-RLHF:
[Screenshots of the relevant section of the paper and the MOSS-RLHF implementation]
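In short: divide the score by a running standard deviation (scaling), optionally also subtract the running mean (normalization), and then clip it to a fixed range. A minimal sketch of that logic, not the exact code in this PR; the RunningMoments-style tracker passed in as running is assumed to expose an update() returning the running mean and std:

import torch

def process_scores(scores, running, use_score_scaling, use_score_norm, score_clip, eps=1e-8):
    scores = torch.stack(scores) if isinstance(scores, list) else scores
    if use_score_scaling:
        mean, std = running.update(scores)              # update running statistics
        if use_score_norm:
            scores = (scores - mean) / (std + eps)      # normalization: subtract mean, divide by std
        else:
            scores = scores / (std + eps)               # scaling only: divide by running std
    if score_clip is not None:
        scores = scores.clamp(-score_clip, score_clip)  # clipping to [-score_clip, score_clip]
    return scores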

Tests

The following was tested in a Google Colab notebook with an Nvidia T4 GPU. The notebook disconnects on its own after a few hours, and since I was only getting about 1 iteration per minute, my runs stopped fairly early.

sentiment_tuning.py

Command for baseline:

python examples/scripts/sentiment_tuning.py --log_with wandb

Command for score scaling/normalization/clipping:

python examples/scripts/sentiment_tuning.py --log_with wandb --use_score_scaling --use_score_norm --score_clip 0.5
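These flags map onto PPOConfig fields of the same name, so the same behaviour can be enabled directly in a script. A sketch (the gpt2-imdb model name mirrors the sentiment example; values match the command above):

from trl import PPOConfig

config = PPOConfig(
    model_name="lvwerra/gpt2-imdb",
    log_with="wandb",
    use_score_scaling=True,
    use_score_norm=True,   # only takes effect when use_score_scaling is True
    score_clip=0.5,
)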

Screenshots of wandb: [omitted]

multi_adapter_rl_v2.py

Command for baseline:

python examples/scripts/multi_adapter_rl_v2.py --model_name ../llama-7b --log_with wandb --use_safetensors

Command for score scaling/normalization/clipping:

python examples/scripts/multi_adapter_rl_v2.py --model_name ../llama-7b --log_with wandb --use_safetensors --use_score_scaling --use_score_norm --score_clip 0.5

Screenshots of wandb: [omitted]

younesbelkada (Contributor) left a comment


Thanks a lot for working on this and adding this nice new feature.
I am OK with this PR in principle, since it is backward compatible with the existing setup and works in the distributed setting as well (judging from the code of RunningMoments & get_global_statistics).
Would love to hear @lvwerra's & @vwxyzjn's thoughts on this.
Can you also run the styling checks?

make precommit

Thanks!
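(For context, a rough sketch of what a running-moments tracker of this kind typically does; this is a paraphrase, not the actual trl code. In a distributed run, the per-batch statistics would first be all-reduced across processes before the merge, which is presumably what get_global_statistics handles.)

import torch

class RunningMoments:
    # Track a running mean/std over batches of scores (Welford/Chan-style merge).
    def __init__(self):
        self.mean = 0.0
        self.var = 1.0
        self.count = 1e-4  # tiny prior count to avoid division by zero

    def update(self, xs: torch.Tensor):
        batch_count = xs.numel()
        batch_mean = xs.mean().item()
        batch_var = xs.var(unbiased=False).item()

        delta = batch_mean - self.mean
        tot_count = self.count + batch_count

        # Merge the batch statistics into the running statistics.
        self.mean += delta * batch_count / tot_count
        m_a = self.var * self.count
        m_b = batch_var * batch_count
        self.var = (m_a + m_b + delta ** 2 * self.count * batch_count / tot_count) / tot_count
        self.count = tot_count

        return self.mean, self.var ** 0.5  # running mean and running std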

HuggingFaceDocBuilderDev commented on Jul 24, 2023

The documentation is not available anymore as the PR was closed or merged.

younesbelkada (Contributor) left a comment


Can you also add a few lines to the documentation explaining this feature to users? The details could go in a dedicated section here: https://github.com/lvwerra/trl/blob/main/docs/source/customization.mdx
Also, can you share the behaviour of env/rewards_mean and env/rewards_std?
Thanks!

@@ -45,7 +45,6 @@ class ScriptArguments:
default=1, metadata={"help": "the number of gradient accumulation steps"}
)
early_stopping: Optional[bool] = field(default=False, metadata={"help": "whether to early stop"})
target_kl: Optional[float] = field(default=6, metadata={"help": "kl target for early stopping"})
Contributor:

This field seems to have been removed by mistake?

zfang (Author):

Hi Younes,

You will find that target_kl already exists at L57 with a much smaller value.

I dug deeper and found that PPOConfig has two configs, target and target_kl, where target has a default value of 6. So I assume this first, duplicate target_kl field was meant to be target. However, target is NOT used to populate PPOConfig at L64, so I just removed it.

Regards,

Felix

Contributor:

Great point, thank you!

Member:

I think this is actually a bug introduced here: 1620da3
We overloaded the target_kl term; we should rename it!


Collaborator:

@lvwerra, as much as I love introducing bugs into trl, I think this time it was @younesbelkada, in the Big refactor of examples and documentation (#509), here.

I agree on renaming it to early_stop_kl, or something similar.

zfang (Author) commented on Jul 24, 2023

Can you also add a few lines to the documentation explaining this feature to users? The details could go in a dedicated section here: https://github.com/lvwerra/trl/blob/main/docs/source/customization.mdx Also, can you share the behaviour of env/rewards_mean and env/rewards_std? Thanks!

[wandb screenshot of env/reward_mean and env/reward_std]
The rewards here are actually independent of score scaling/normalization/clipping because they are logged independently:

...
    pipe_outputs = sentiment_pipe(texts, **sent_kwargs)
    rewards = [torch.tensor(output[1]["score"]) for output in pipe_outputs]

    # Run PPO step
    stats = ppo_trainer.step(query_tensors, response_tensors, rewards)
    ppo_trainer.log_stats(stats, batch, rewards)

...

    def log_stats(
        self,
        stats: dict,
        batch: dict,
        rewards: List[torch.FloatTensor],
    ):
            ...

            logs["env/reward_mean"] = torch.mean(rewards).cpu().numpy().item()
            logs["env/reward_std"] = torch.std(rewards).cpu().numpy().item()
            logs["env/reward_dist"] = rewards.cpu().numpy()

            ...

younesbelkada (Contributor) left a comment


Thanks a lot for this great work; this looks very nice on my side. Let's see what the others say!

lvwerra (Member) left a comment


Very clean PR, thanks! Left a few questions :)

@@ -56,6 +55,9 @@ class ScriptArguments:
)
target_kl: Optional[float] = field(default=0.1, metadata={"help": "kl target for early stopping"})
seed: Optional[int] = field(default=0, metadata={"help": "the random seed"})
use_score_scaling: Optional[bool] = field(default=False, metadata={"help": "Use score scaling"})
use_score_norm: Optional[bool] = field(default=False, metadata={"help": "Use score normalization"})
Member:

Maybe we should clarify that this only works if use_score_scaling is also True (otherwise it's actually ignored), or we change the logic a bit in general.

zfang (Author):

Done

lvwerra (Member) commented on Jul 26, 2023

The rewards here are actually independent of score scaling/normalization/clipping because they are logged independently:

Is that really true? Inside step we only log scores, which are normalized with this PR. Also, if it weren't true, then it looks like we have a strong performance degradation.

zfang (Author) commented on Jul 26, 2023

The rewards here are actually independent of score scaling/normalization/clipping because they are logged independently:

Is that really true? Inside step we only log scores, which are normalized with this PR. Also, if it weren't true, then it looks like we have a strong performance degradation.

Hi @lvwerra,

Could you elaborate on the performance degradation?

In sentiment_tuning.py (and similarly in multi_adapter_rl_v2.py), we have

for epoch, batch in tqdm(enumerate(ppo_trainer.dataloader)):
    query_tensors = batch["input_ids"]

    # Get response from gpt2
    response_tensors = ppo_trainer.generate(query_tensors, return_prompt=False, **generation_kwargs)
    batch["response"] = tokenizer.batch_decode(response_tensors)

    # Compute sentiment score
    texts = [q + r for q, r in zip(batch["query"], batch["response"])]
    pipe_outputs = sentiment_pipe(texts, **sent_kwargs)
    rewards = [torch.tensor(output[1]["score"]) for output in pipe_outputs]

    # Run PPO step
    stats = ppo_trainer.step(query_tensors, response_tensors, rewards)
    ppo_trainer.log_stats(stats, batch, rewards)

env/reward_mean and env/reward_std are logged inside ppo_trainer.log_stats, which is based on the raw rewards from sentiment_pipe. I think the reason we observe different reward curves is randomness: even though we provide the seed config in ScriptArguments, it is only used for PPOConfig and not for LengthSampler, which can affect the input. I do observe near-identical reward curves with multi_adapter_rl_v2.py:
[wandb screenshot]

ppo/mean_scores and ppo/std_scores are the per-batch (normalized) score stats that are logged inside ppo_trainer.step. From the wandb screenshots you can see that the scores have a mean close to zero and a std of about 0.5 (the clipping value).

It's not obvious to me whether score scaling/normalization/clipping improves or degrades performance. It's meant to improve training stability, but I guess I haven't run the training long enough to observe possible divergences (well, Google Colab kept crashing on me). In general I observe smoother curves.

I do observe better loss curves on the value head/function, and I assume this can be attributed to the more stable and smoother reward scores. In sentiment_tuning.py I also observe better KL divergence, and thus better non-score rewards; I assume this is because normalizing the score rewards makes them less dominant over the non-score rewards. In my opinion this makes it easier to configure the KL coefficient, because we know what range of score rewards to expect.
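To make that last point concrete, here is a toy sketch of how the processed score interacts with the per-token KL penalty (my illustration, not the exact trl code, and all numbers are made up): the KL penalty applies to every response token, while the score is added on the final token, so with score_clip=0.5 the score term is bounded in [-0.5, 0.5].

import torch

init_kl_coef = 0.2
score_clip = 0.5
raw_score = torch.tensor(3.2)                          # hypothetical raw sentiment score
processed_score = raw_score.clamp(-score_clip, score_clip)

per_token_kl = torch.tensor([0.10, 0.30, 0.20, 0.40])  # hypothetical per-token KL values
rewards = -init_kl_coef * per_token_kl                 # KL penalty on every response token
rewards[-1] = rewards[-1] + processed_score            # score contributes only on the last token
print(rewards)                                         # tensor([-0.0200, -0.0600, -0.0400,  0.4200])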

Regardless, the configs are optional and backward compatible.

Regards,

Felix

lvwerra (Member) commented on Aug 8, 2023

Hi @zfang, I was mainly referring to this plot that you shared:
[wandb screenshot]
It appears that the rewards are considerably lower, and I was wondering if that's due to scaling. Curious to hear your thoughts.

vwxyzjn (Contributor) commented on Aug 8, 2023

Hey @zfang, thanks for the PR! Sometimes random seeds can impact the results a lot, e.g. #462 (comment). Could you run the experiment for 10 random seeds?

zfang (Author) commented on Aug 8, 2023

Hey @zfang, thanks for the PR! Sometimes random seeds can impact the results a lot, e.g. #462 (comment). Could you run the experiment for 10 random seeds?

Actually, I just made a change in sentiment_tuning.py to move set_seed before calling build_dataset. Re-running sentiment_tuning.py now.
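The change is just an ordering tweak; roughly (names as used in the example script, shown here only as a fragment):

from trl import PPOConfig, set_seed

config = PPOConfig(seed=0)           # in the script this is built from the CLI arguments

# Seed all RNGs *before* building the dataset, so that LengthSampler draws the
# same input lengths across runs (previously the dataset was built first):
set_seed(config.seed)
# dataset = build_dataset(config)    # build_dataset is the helper defined in the script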

zfang (Author) commented on Aug 8, 2023

Hi @zfang, I was mainly referring to this plot that you shared: [wandb screenshot] It appears that the rewards are considerably lower, and I was wondering if that's due to scaling. Curious to hear your thoughts.

Update: I do observe consistent patterns of difference in env/reward_std and env/reward_mean with sentiment_tuning.py, but not with multi_adapter_rl_v2.py.

zfang (Author) commented on Aug 8, 2023

Hi @zfang, I was mainly referring to this plot that you shared: [wandb screenshot] It appears that the rewards are considerably lower, and I was wondering if that's due to scaling. Curious to hear your thoughts.

Hi @lvwerra and @vwxyzjn,

After some investigation, I have found the root cause.

Based on the following code snippet:

for epoch, batch in tqdm(enumerate(ppo_trainer.dataloader)):
    query_tensors = batch["input_ids"]

    # Get response from gpt2
    response_tensors = ppo_trainer.generate(query_tensors, return_prompt=False, **generation_kwargs)
    batch["response"] = tokenizer.batch_decode(response_tensors)

    # Compute sentiment score
    texts = [q + r for q, r in zip(batch["query"], batch["response"])]
    pipe_outputs = sentiment_pipe(texts, **sent_kwargs)
    rewards = [torch.tensor(output[1]["score"]) for output in pipe_outputs]

    # Run PPO step
    stats = ppo_trainer.step(query_tensors, response_tensors, rewards)
    ppo_trainer.log_stats(stats, batch, rewards)

We have the dependency chain rewards <- pipe_outputs <- texts <- batch["response"] <- response_tensors <- ppo_trainer. Because ppo_trainer.step updates the policy differently between the baseline run and the score-normalization run, we start to get different batch["response"] values and thus different rewards.

In other words, because gpt2-imdb generates slightly different responses (perhaps less overly "positive" at the expense of higher KL loss) after PPO with score normalization, we also start to see different sentiment scores from distilbert-imdb.

At a high level I think that makes sense: we normalize the sentiment scores so they are less dominant over the KL loss, and thus with score normalization the model is less eager to optimize for sentiment scores relative to the KL loss. This can be adjusted by using a smaller init_kl_coef, as sketched below.
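For example (illustrative values only, nothing I have tuned):

from trl import PPOConfig

config = PPOConfig(
    log_with="wandb",
    use_score_scaling=True,
    use_score_norm=True,
    score_clip=0.5,
    init_kl_coef=0.1,  # lower than the 0.2 default, so the KL penalty weighs less
)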

[wandb screenshots]

Hopefully that makes sense to you.

Regards,

Felix

lvwerra (Member) commented on Aug 10, 2023

OK, makes sense. Since it's optional, it's not directly a regression, and we can merge.

lvwerra merged commit 3b2c820 into huggingface:main on Aug 10, 2023
11 checks passed