Skip to content

Commit

Permalink
Fix rlhf on trl v0.11.0
Browse files Browse the repository at this point in the history
  • Loading branch information
satyaog committed Sep 23, 2024
1 parent dbebc7c commit 6caac29
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
4 changes: 2 additions & 2 deletions benchmarks/rlhf/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

from trl import ModelConfig
from trl.trainer.ppov2_trainer import PPOv2Config, PPOv2Trainer
from trl.trainer.utils import SIMPLE_QUERY_CHAT_TEMPLATE
from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE


class PPOv2TrainerIntrumented(PPOv2Trainer):
Expand Down Expand Up @@ -62,7 +62,7 @@ def main():
)
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
if tokenizer.chat_template is None:
tokenizer.chat_template = SIMPLE_QUERY_CHAT_TEMPLATE
tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE
value_model = AutoModelForSequenceClassification.from_pretrained(
config.reward_model_path, trust_remote_code=model_config.trust_remote_code, num_labels=1
)
Expand Down
4 changes: 2 additions & 2 deletions benchmarks/rlhf/prepare.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from datasets import load_dataset
from trl import ModelConfig
from trl.trainer.ppov2_trainer import PPOv2Config
from trl.trainer.utils import SIMPLE_QUERY_CHAT_TEMPLATE
from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE


if __name__ == "__main__":
Expand All @@ -30,7 +30,7 @@
tokenizer.add_special_tokens({"pad_token": "[PAD]"})

if tokenizer.chat_template is None:
tokenizer.chat_template = SIMPLE_QUERY_CHAT_TEMPLATE
tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE

value_model = AutoModelForSequenceClassification.from_pretrained(
config.reward_model_path,
Expand Down

0 comments on commit 6caac29

Please sign in to comment.