
chore: added initial version of PPOTrainer support #3549

Conversation

davidberenstein1957
Member

Description

This PR adds initial support for TRL's PPOTrainer to the ArgillaTrainer.

from transformers import pipeline
from trl import PPOConfig

from argilla.feedback import ArgillaTrainer, TrainingTask

# assuming `dataset` is an existing FeedbackDataset with a "text" field
task_mapping = TrainingTask.for_proximal_policy_optimization(text=dataset.field_by_name("text"))
trainer = ArgillaTrainer(
    dataset=dataset,
    task=task_mapping,
    framework="trl",
    fetch_records=False,
)
# assuming we have an arbitrarily trained textcat model to act as the reward model
reward_model = pipeline("sentiment-analysis", model="lvwerra/distilbert-imdb")
trainer.update_config(reward_model=reward_model)  # required; a warning is raised if omitted
trainer.train(output_dir="my_awesome_model")
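For context on how a sentiment pipeline can serve as a PPO reward signal: PPO needs one scalar reward per generated text, while a `sentiment-analysis` pipeline returns label/score dicts. Below is a minimal, hypothetical helper (not part of this PR or the Argilla API) that sketches the usual convention of signing the score by label:

```python
def to_rewards(pipe_outputs):
    """Convert sentiment-pipeline outputs into scalar PPO rewards.

    Hypothetical sketch: assumes each output is a dict with "label"
    ("POSITIVE"/"NEGATIVE") and "score" keys, as returned by a
    transformers sentiment-analysis pipeline.
    """
    rewards = []
    for out in pipe_outputs:
        score = out["score"]
        # Positive sentiment earns a positive reward, negative sentiment a penalty.
        rewards.append(score if out["label"] == "POSITIVE" else -score)
    return rewards

print(to_rewards([{"label": "POSITIVE", "score": 0.9},
                  {"label": "NEGATIVE", "score": 0.8}]))
# → [0.9, -0.8]
```

The actual reward handling inside the TRL integration may differ; this only illustrates the shape of the data flowing from the reward model into PPO.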

Closes #3522

Type of change

  • New feature (non-breaking change which adds functionality)
  • Improvement (change adding some improvement to an existing functionality)

How Has This Been Tested?

  • tests/integration/client/feedback/training/test_trl.py

Checklist

  • I added relevant documentation
  • I followed the style guidelines of this project
  • I did a self-review of my code
  • I made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • I have added relevant notes to the CHANGELOG.md file (See https://keepachangelog.com/)

@tomaarsen
Contributor

The test failures also exist in feat/integration_trl and are fixed in develop. I'll propagate the fix down all branches, i.e. merge develop into feat/integration_trl and then update this branch as well.

@tomaarsen
Contributor

  • docs/_source/guides/llms/practical_guides/fine_tune.md
  • docs/_source/_common/tabs/train_update_config.md

@codecov

codecov bot commented Aug 17, 2023

Codecov Report

Patch coverage: 89.43% and project coverage change: +0.04% 🎉

Comparison is base (9eb6e20) 89.96% compared to head (f08ff49) 90.00%.

Additional details and impacted files
@@                   Coverage Diff                    @@
##           feat/integration_trl    #3549      +/-   ##
========================================================
+ Coverage                 89.96%   90.00%   +0.04%     
========================================================
  Files                       256      256              
  Lines                     13777    13865      +88     
========================================================
+ Hits                      12394    12479      +85     
- Misses                     1383     1386       +3     
Files Changed Coverage Δ
src/argilla/client/feedback/__init__.py 100.00% <ø> (ø)
src/argilla/client/feedback/training/__init__.py 100.00% <ø> (ø)
...argilla/client/feedback/training/frameworks/trl.py 92.35% <87.14%> (-4.62%) ⬇️
src/argilla/client/feedback/training/schemas.py 89.22% <92.30%> (-0.26%) ⬇️
src/argilla/client/feedback/dataset/base.py 80.98% <100.00%> (ø)

... and 2 files with indirect coverage changes


@davidberenstein1957 davidberenstein1957 marked this pull request as ready for review August 17, 2023 08:54
@tomaarsen
Contributor

Merged before documentation is complete, to allow @davidberenstein1957 to extend PPO further in feat/integration_trl.

@tomaarsen tomaarsen merged commit d43785b into feat/integration_trl Aug 17, 2023
17 checks passed
@tomaarsen tomaarsen deleted the feat/3522-feature-add-ppotrainer-to-trl-integration-argillatrainer branch August 17, 2023 08:55