From b56e8b327733baa81c3ef0d6508f08e1b3e33939 Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Fri, 23 Jun 2023 09:04:24 -0400 Subject: [PATCH] Improve stabiliy: change default hyperparamers --- examples/sentiment/scripts/gpt2-sentiment.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/sentiment/scripts/gpt2-sentiment.py b/examples/sentiment/scripts/gpt2-sentiment.py index d9d6de2b44..52546d29bd 100644 --- a/examples/sentiment/scripts/gpt2-sentiment.py +++ b/examples/sentiment/scripts/gpt2-sentiment.py @@ -60,13 +60,13 @@ class ScriptArguments: model_name: Optional[str] = field(default="lvwerra/gpt2-imdb", metadata={"help": "the model name"}) log_with: Optional[str] = field(default=None, metadata={"help": "use 'wandb' to log with wandb"}) learning_rate: Optional[float] = field(default=1.41e-5, metadata={"help": "the learning rate"}) - mini_batch_size: Optional[int] = field(default=16, metadata={"help": "the PPO minibatch size"}) - batch_size: Optional[int] = field(default=256, metadata={"help": "the batch size"}) + mini_batch_size: Optional[int] = field(default=128, metadata={"help": "the PPO minibatch size"}) + batch_size: Optional[int] = field(default=128, metadata={"help": "the batch size"}) gradient_accumulation_steps: Optional[int] = field( default=1, metadata={"help": "the number of gradient accumulation steps"} ) early_stopping: Optional[bool] = field(default=False, metadata={"help": "whether to early stop"}) - target_kl: Optional[float] = field(default=0.1, metadata={"help": "kl target for early stopping"}) + target_kl: Optional[float] = field(default=6, metadata={"help": "kl target for early stopping"}) parser = HfArgumentParser(ScriptArguments)