feat: Allow exporting data for SFT, Reward Modelling (related to RLHF), DPO, rename TrainingTaskMapping (#3467)

Resolves #3379, resolves #3377

Hello!

## Pull Request overview

* Prepare data for SFT, RM, DPO in TRL.
* Rename `TrainingTaskMapping` to `TrainingTask` and `task_mapping` to `task`.

# Description

## Prepare data

```python
from typing import Any, Dict

from argilla.feedback import TrainingTask

def formatting_func(sample: Dict[str, Any]):
    ...
    yield template.format(
        prompt=sample["prompt"],
        response=sample["response"],
    )

task = TrainingTask.for_supervised_fine_tuning(formatting_func=formatting_func)
# `fds_dataset` is an existing FeedbackDataset
ds = fds_dataset.prepare_for_training(framework="trl", task=task)
# -> ds has "text" and "id" columns
```

Compatible with [SFTTrainer](https://huggingface.co/docs/trl/main/en/sft_trainer).

```python
task = TrainingTask.for_reward_modelling(chosen_rejected_func=chosen_rejected_func)
ds = fds_dataset.prepare_for_training(framework="trl", task=task)
# -> ds has "chosen" and "rejected" columns
```

Nearly compatible with [RewardTrainer](https://huggingface.co/docs/trl/main/en/reward_trainer).

```python
task = TrainingTask.for_direct_preference_optimization(prompt_chosen_rejected_func=prompt_chosen_rejected_func)
ds = fds_dataset.prepare_for_training(framework="trl", task=task)
# -> ds has "prompt", "chosen" and "rejected" columns
```

Compatible with [DPOTrainer](https://huggingface.co/docs/trl/main/en/dpo_trainer).

### Details

I implement this by calling `dataset.format_as("datasets")` and then passing each sample (a simple dictionary) from this dataset to the function that the user provides. This user-provided function can return `None`, one sample, a list of samples, or yield samples. This allows users to export multiple training samples from a single Argilla record, e.g. when there are multiple annotators who provided useful corrections, or if the annotated record justifies 3 "chosen", "rejected" pairs because there's a ranking between 3 texts.

## Rename

`TrainingTaskMapping` is now `TrainingTask` - the "mapping" part is just unintuitive to the user. The same goes for `task_mapping`, which is now `task`.

**Note:** If people used `task_mapping=...` before, that will now fail. I can make this deprecation softer, but then I have to make `task` optional, which I would rather not do.

## TODO:

- [ ] Add TRL to `ArgillaTrainer`, allowing:
  ```python
  task = TrainingTask.for_supervised_fine_tuning(
      formatting_func=formatting_func
  )
  # or any other task from this PR
  trainer = ArgillaTrainer(
      dataset=fds_dataset,
      task=task,
      framework="trl",
  )
  trainer.train()
  ```
- [ ] Consider renaming `FeedbackDataset.prepare_for_training` to `FeedbackDataset.export`.
- [ ] New tests
- [ ] Add documentation

**Type of change**

- [x] New feature

**How Has This Been Tested**

Not finished yet.

**Checklist**

- [ ] I added relevant documentation
- [x] follows the style guidelines of this project
- [x] I did a self-review of my code
- [ ] I made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my feature works
- [ ] I filled out [the contributor form](https://tally.so/r/n9XrxK) (see text above)
- [ ] I have added relevant notes to the CHANGELOG.md file (See https://keepachangelog.com/)

---

- Tom Aarsen

---------

Co-authored-by: Alvaro Bartolome
Co-authored-by: David Berenstein
Co-authored-by: Daniel Vila Suero
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
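
As an illustration of the behaviour described under "Details", here is a minimal pure-Python sketch of how a user-provided function that returns `None`, one sample, a list of samples, or yields samples could be flattened into a list of training samples. It is not the code from this PR: the helper `collect_outputs` and the `"ranking"` field (texts ordered best to worst) are hypothetical names chosen for the example.

```python
import inspect
from itertools import combinations
from typing import Any, Callable, Dict, Iterator, List, Tuple


def collect_outputs(
    func: Callable[[Dict[str, Any]], Any], sample: Dict[str, Any]
) -> List[Any]:
    """Normalize whatever `func` returns for one sample into a flat list.

    Hypothetical helper, not part of Argilla.
    """
    result = func(sample)
    if result is None:
        # The user decided this record yields no training sample.
        return []
    if inspect.isgenerator(result) or isinstance(result, list):
        # Several samples from one record (yielded or returned as a list).
        return list(result)
    # Exactly one sample (e.g. a string or a (chosen, rejected) tuple).
    return [result]


def chosen_rejected_func(sample: Dict[str, Any]) -> Iterator[Tuple[str, str]]:
    """Yield a (chosen, rejected) pair for every pair of ranked texts.

    Assumes a hypothetical "ranking" field listing texts from best to worst.
    """
    ranking = sample.get("ranking")
    if not ranking:
        return  # empty generator: nothing to export for this record
    # combinations() preserves input order, so the first item of each
    # pair is always the better-ranked text.
    for better, worse in combinations(ranking, 2):
        yield better, worse


sample = {"ranking": ["best", "middle", "worst"]}
pairs = collect_outputs(chosen_rejected_func, sample)
```

With a ranking between 3 texts, `pairs` holds the 3 "chosen", "rejected" pairs mentioned above; a record without a ranking contributes no samples.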