[FEATURE] add PPOTrainer to trl integration ArgillaTrainer #3522

Closed
davidberenstein1957 opened this issue Aug 7, 2023 · 0 comments · Fixed by #3467
Labels
area: trainer (Indicates that an issue or pull request is related to the Argilla Trainer)
type: enhancement (Indicates new feature requests)


davidberenstein1957 commented Aug 7, 2023

Is your feature request related to a problem? Please describe.
We left the PPOTrainer out of our first pass at the trl integration, and it should be added as well.
https://github.com/lvwerra/trl#ppotrainer

Describe the solution you'd like

task = TrainingTask.for_reward_modelling(...)
trainer = ArgillaTrainer(
   dataset=fds_dataset,
   task=task,
   framework="trl",
)
trainer.train()
trainer.save("reward_model")

# And then you can use this "reward_model" with PPO
task = TrainingTask.for_proximal_policy_optimization(...)
trainer = ArgillaTrainer(
   dataset=fds_dataset,
   task=task,
   framework="trl",
)
trainer.train(model=reward_model, generation_args=generation_args)

Describe alternatives you've considered
N.A.

Additional context
N.A.

@davidberenstein1957 davidberenstein1957 added the type: enhancement Indicates new feature requests label Aug 7, 2023
tomaarsen added a commit that referenced this issue Aug 17, 2023

# Description

I added support for the PPOTrainer.

```python
from argilla.feedback import ArgillaTrainer, TrainingTask
from transformers import pipeline
from trl import PPOConfig

# `dataset` is assumed to be an existing FeedbackDataset with a "text" field
task_mapping = TrainingTask.for_proximal_policy_optimization(text=dataset.field_by_name("text"))
trainer = ArgillaTrainer(
    dataset=dataset,
    task=task_mapping,
    framework="trl",
    fetch_records=False,
)
# Assuming we have an arbitrarily trained text classification model to use as the reward model
reward_model = pipeline("sentiment-analysis", model="lvwerra/distilbert-imdb")
# Setting the reward model is always required; if it is not set, a warning is given
trainer.update_config(reward_model=reward_model)
trainer.train(output_dir="my_awesome_model")
```
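
For context on what a classifier-based reward model contributes to PPO, here is a minimal sketch of turning text classification pipeline outputs into one scalar reward per generated text. It assumes a binary classifier, the default pipeline output shape (one `{"label": ..., "score": ...}` dict per input), and the label name `"POSITIVE"`. The helper `to_scalar_rewards` is hypothetical and not part of Argilla or trl. Whether `ArgillaTrainer` performs a conversion like this internally is not stated here.

```python
from typing import Dict, List, Union

PipelineOutput = Dict[str, Union[str, float]]


def to_scalar_rewards(
    outputs: List[PipelineOutput],
    positive_label: str = "POSITIVE",  # assumed label name; check the model config
) -> List[float]:
    """Map top-label classifier outputs to the probability of the positive class.

    Assumes a binary classifier, so a non-positive top label with score s
    implies a positive-class probability of 1 - s.
    """
    rewards = []
    for out in outputs:
        score = float(out["score"])
        if out["label"] == positive_label:
            rewards.append(score)
        else:
            rewards.append(1.0 - score)
    return rewards


# Hand-written example outputs in the shape a sentiment pipeline returns
example_outputs = [
    {"label": "POSITIVE", "score": 0.9},
    {"label": "NEGATIVE", "score": 0.8},
]
rewards = to_scalar_rewards(example_outputs)
```

In this sketch a higher reward means the generated text looks more positive to the classifier, which is the signal PPO would then try to maximize.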

Closes #3522 

**Type of change**

- [X] New feature (non-breaking change which adds functionality)
- [X] Improvement (change adding some improvement to an existing
functionality)

**How Has This Been Tested**

- [ ] `tests/integration/client/feedback/training/test_trl.py`

**Checklist**

- [ ] I added relevant documentation
- [X] I followed the style guidelines of this project
- [X] I did a self-review of my code
- [X] I made corresponding changes to the documentation
- [X] My changes generate no new warnings
- [X] I have added tests that prove my fix is effective or that my
feature works
- [ ] I have added relevant notes to the `CHANGELOG.md` file (See
https://keepachangelog.com/)

---------

Co-authored-by: Tom Aarsen <[email protected]>
@davidberenstein1957 davidberenstein1957 added the area: trainer Indicates that an issue or pull request is related to the Argilla Trainer label Aug 28, 2023