feat: Allow exporting data for SFT, Reward Modelling (related to RLHF), DPO, rename TrainingTaskMapping (#3467)

Resolves #3379, resolves #3377

Hello!

## Pull Request overview
* Prepare data for SFT, RM, DPO in TRL.
* Rename `TrainingTaskMapping` to `TrainingTask` and `task_mapping` to
`task`.

# Description
## Prepare data
```python
from argilla.feedback import TrainingTask

def formatting_func(sample: Dict[str, Any]):
    ...
    yield template.format(
        prompt=sample["prompt"],
        response=sample["response"],
    )

task = TrainingTask.for_supervised_fine_tuning(formatting_func=formatting_func)
ds = fds_dataset.prepare_for_training(framework="trl", task=task)
# -> ds has "text" and "id" columns
```
Compatible with
[SFTTrainer](https://huggingface.co/docs/trl/main/en/sft_trainer).
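For illustration, a complete formatting function might look like the following. This is a minimal sketch: the template and the `"prompt"`/`"response"` field names are assumptions based on the snippet above, not part of this PR.

```python
from typing import Any, Dict, Iterator

# A hypothetical instruction template; adjust to your own dataset fields.
template = "### Instruction: {prompt}\n### Response: {response}"

def formatting_func(sample: Dict[str, Any]) -> Iterator[str]:
    # Skip records without an annotated response.
    if not sample.get("response"):
        return
    yield template.format(
        prompt=sample["prompt"],
        response=sample["response"],
    )
```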

```python
task = TrainingTask.for_reward_modelling(chosen_rejected_func=chosen_rejected_func)
ds = fds_dataset.prepare_for_training(framework="trl", task=task)
# -> ds has "chosen" and "rejected" columns
```
Nearly compatible with
[RewardTrainer](https://huggingface.co/docs/trl/main/en/reward_trainer).
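The `chosen_rejected_func` is user-defined. A minimal sketch, assuming each record carries an original `"response"` and a preferred `"corrected-response"` (hypothetical field names):

```python
from typing import Any, Dict, Optional, Tuple

def chosen_rejected_func(sample: Dict[str, Any]) -> Optional[Tuple[str, str]]:
    # The corrected response is preferred ("chosen") over the original ("rejected").
    if sample.get("corrected-response") and sample.get("response"):
        return sample["corrected-response"], sample["response"]
    return None  # records without a correction produce no training sample
```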

```python
task = TrainingTask.for_direct_preference_optimization(prompt_chosen_rejected_func=prompt_chosen_rejected_func)
ds = fds_dataset.prepare_for_training(framework="trl", task=task)
# -> ds has "prompt", "chosen" and "rejected" columns
```
Compatible with
[DPOTrainer](https://huggingface.co/docs/trl/main/en/dpo_trainer).
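Again a sketch under the same hypothetical field names, now also returning the prompt:

```python
from typing import Any, Dict, Optional, Tuple

def prompt_chosen_rejected_func(sample: Dict[str, Any]) -> Optional[Tuple[str, str, str]]:
    # Assumed field names for illustration only.
    if sample.get("corrected-response") and sample.get("response"):
        return sample["prompt"], sample["corrected-response"], sample["response"]
    return None
```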

### Details
I implement this by calling `dataset.format_as("datasets")` and then passing each sample (a simple dictionary) from this dataset to the function that the user provides. This user-provided function can return `None`, one sample, a list of samples, or yield samples. This allows users to export multiple training samples from a single Argilla record, e.g. when multiple annotators provided useful corrections, or when the annotated record justifies three "chosen"/"rejected" pairs because there is a ranking between three texts (see the sketch below).
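As an example of the generator case, a single record with a ranking over three texts can yield all three implied pairwise preferences. The `"ranking"` field name is an assumption for illustration:

```python
from itertools import combinations
from typing import Any, Dict, Iterator, Tuple

def chosen_rejected_from_ranking(sample: Dict[str, Any]) -> Iterator[Tuple[str, str]]:
    # Assume `sample["ranking"]` lists texts from best to worst: every
    # higher-ranked text is "chosen" over every lower-ranked one, so a
    # ranking of 3 texts yields 3 (chosen, rejected) pairs.
    for chosen, rejected in combinations(sample["ranking"], 2):
        yield chosen, rejected
```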

## Rename
`TrainingTaskMapping` is now `TrainingTask`: the "mapping" part is simply unintuitive to users. The same goes for `task_mapping`, which is now `task`. **Note:** code that passed `task_mapping=...` before will now fail. I could make this deprecation softer, but then I would have to make `task` optional, which I would rather not do.

## TODO:
- [ ] Add TRL to `ArgillaTrainer`, allowing:
   ```python
   task = TrainingTask.for_supervised_fine_tuning(
       formatting_func=formatting_func
   ) # or any other task from this PR
   trainer = ArgillaTrainer(
       dataset=fds_dataset,
       task=task,
       framework="trl",
   )
   trainer.train()
   ```
- [ ] Consider renaming `FeedbackDataset.prepare_for_training` to
`FeedbackDataset.export`.
- [ ] New tests
- [ ] Add documentation

**Type of change**

- [x] New feature

**How Has This Been Tested**

Not finished yet.

**Checklist**

- [ ] I added relevant documentation
- [x] My code follows the style guidelines of this project
- [x] I did a self-review of my code
- [ ] I made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my
feature works
- [ ] I filled out [the contributor form](https://tally.so/r/n9XrxK)
(see text above)
- [ ] I have added relevant notes to the CHANGELOG.md file (See
https://keepachangelog.com/)

---

- Tom Aarsen

---------

Co-authored-by: Alvaro Bartolome <[email protected]>
Co-authored-by: Alvaro Bartolome <[email protected]>
Co-authored-by: David Berenstein <[email protected]>
Co-authored-by: Daniel Vila Suero <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
6 people authored Aug 28, 2023
1 parent 605ebb1 commit 76d9a4b
Showing 33 changed files with 2,593 additions and 1,131 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -18,6 +18,8 @@ These are the section headers that we use:

### Added

- Added `ArgillaTrainer` integration with TRL, allowing for easy supervised finetuning, reward modeling, direct preference optimization and proximal policy optimization ([#3467](https://github.com/argilla-io/argilla/pull/3467))
- Added `formatting_func` to `ArgillaTrainer` for `FeedbackDataset` datasets to add custom formatting for the data ([#3599](https://github.com/argilla-io/argilla/pull/3599)).
- Added `login` function in `argilla.client.login` to log in to an Argilla server and store the credentials locally ([#3582](https://github.com/argilla-io/argilla/pull/3582)).
- Added `login` command to log in to an Argilla server ([#3600](https://github.com/argilla-io/argilla/pull/3600)).
- Added `logout` command to log out from an Argilla server ([#3605](https://github.com/argilla-io/argilla/pull/3605)).
46 changes: 46 additions & 0 deletions docs/_source/_common/dolly_dataset_info.md
@@ -0,0 +1,46 @@
The data required for these steps needs to be comparison data that showcases the preference for the generated prompts. A good example is our [curated Dolly dataset](https://huggingface.co/datasets/argilla/databricks-dolly-15k-curated-en), where we assumed that updated responses are preferred over the older ones. Another good example is the [Anthropic RLHF dataset](https://huggingface.co/datasets/Anthropic/hh-rlhf).

```{note}
The original Dolly dataset contained many reference indicators such as "[1]", which cause the model to hallucinate and incorrectly create references.
```

::::{tab-set}

:::{tab-item} Original

```bash
### Instruction
When did Virgin Australia start operating?

### Context
Virgin Australia, the trading name of Virgin Australia Airlines Pty Ltd, is an Australian-based airline.
It is the largest airline by fleet size to use the Virgin brand. [2]
It commenced services on 31 August 2000 as Virgin Blue, with two aircraft on a single route.[3]
It suddenly found itself as a major airline in Australia's domestic market after the collapse of Ansett Australia in September 2001.
The airline has since grown to directly serve 32 cities in Australia, from hubs in Brisbane, Melbourne and Sydney.[4]
### Response:
Virgin Australia commenced services on 31 August 2000 as Virgin Blue, with two aircraft on a single route.
```
:::
:::{tab-item} Corrected
```bash
### Instruction
When did Virgin Australia start operating?
### Context
Virgin Australia, the trading name of Virgin Australia Airlines Pty Ltd, is an Australian-based airline.
It is the largest airline by fleet size to use the Virgin brand.
It commenced services on 31 August 2000 as Virgin Blue, with two aircraft on a single route.
It suddenly found itself as a major airline in Australia's domestic market after the collapse of Ansett Australia in September 2001.
The airline has since grown to directly serve 32 cities in Australia, from hubs in Brisbane, Melbourne and Sydney.

### Response:
Virgin Australia commenced services on 31 August 2000 as Virgin Blue, with two aircraft on a single route.
```

:::

::::
7 changes: 7 additions & 0 deletions docs/_source/_common/dolly_dataset_load.md
@@ -0,0 +1,7 @@
We will use our [curated Dolly dataset](https://huggingface.co/datasets/argilla/databricks-dolly-15k-curated-en), as introduced in the background section above.

```python
import argilla as rg

feedback_dataset = rg.FeedbackDataset.from_huggingface("argilla/databricks-dolly-15k-curated-en")
```
@@ -1,29 +1,29 @@
---
title: Text classification
description: When a RatingQuestion, LabelQuestion or MultiLabelQuestion is present in the datasets, we can define a TrainingTaskMappingForTextClassification to use our ArgillaTrainer integration for fine-tuning with openai”, “setfit”, “peft”, “spacy and transformers.
description: When a RatingQuestion, LabelQuestion or MultiLabelQuestion is present in the datasets, we can define a TrainingTaskForTextClassification to use our ArgillaTrainer integration for fine-tuning with "openai", "setfit", "peft", "spacy" and "transformers".
links:
- linkText: Argilla unification docs
linkLink: https://docs.argilla.io/en/latest/guides/llms/practical_guides/collect_responses.html#solve-disagreements
- linkText: Argilla fine-tuning docs
linkLink: https://docs.argilla.io/en/latest/guides/llms/practical_guides/fine_tune_others.html#text-classification
linkLink: https://docs.argilla.io/en/latest/guides/llms/practical_guides/fine_tune.html#text-classification
- linkText: ArgillaTrainer docs
linkLink: https://docs.argilla.io/en/latest/guides/train_a_model.html#the-argillatrainer
---

```python
from argilla.feedback import ArgillaTrainer, FeedbackDataset, TrainingTaskMapping
from argilla.feedback import ArgillaTrainer, FeedbackDataset, TrainingTask

dataset = FeedbackDataset.from_argilla(
name="<my_dataset_name>",
workspace="<my_workspace_name>"
)
task_mapping = TrainingTaskMapping.for_text_classification(
task = TrainingTask.for_text_classification(
text=dataset.field_by_name("<my_field>"),
label=dataset.question_by_name("<my_question>")
)
trainer = ArgillaTrainer(
dataset=dataset,
task_mapping=task_mapping,
task=task,
framework="<my_framework>",
)
trainer.update_config()
```
10 changes: 10 additions & 0 deletions docs/_source/_common/tabs/train_prepare_for_training.md
@@ -91,4 +91,14 @@ dataset_rg.prepare_for_training(framework="spark-nlp", train_size=1)
```
:::

:::{tab-item} TRL

```python
import argilla as rg

dataset_rg = rg.load("<my_dataset>")
dataset_rg.prepare_for_training(framework="trl", task=..., train_size=1)
```
:::

::::
32 changes: 32 additions & 0 deletions docs/_source/_common/tabs/train_update_config.md
@@ -231,4 +231,36 @@ trainer.update_config(
```
:::

:::{tab-item} TRL

```python
# parameters from `trl.RewardTrainer`, `trl.SFTTrainer`, `trl.PPOTrainer` or `trl.DPOTrainer`.
# `transformers.TrainingArguments`
trainer.update_config(
per_device_train_batch_size = 8,
per_device_eval_batch_size = 8,
gradient_accumulation_steps = 1,
learning_rate = 5e-5,
weight_decay = 0,
adam_beta1 = 0.9,
adam_beta2 = 0.9,
adam_epsilon = 1e-8,
max_grad_norm = 1,
num_train_epochs = 3,
max_steps = 0,
log_level = "passive",
logging_strategy = "steps",
save_strategy = "steps",
save_steps = 500,
seed = 42,
push_to_hub = False,
hub_model_id = "user_name/output_dir_name",
hub_strategy = "every_save",
hub_token = "1234",
hub_private_repo = False
)
```
:::

::::
2 changes: 1 addition & 1 deletion docs/_source/getting_started/cheatsheet.md
@@ -222,7 +222,7 @@ rg.log(records_for_training, name="majority_voter_results")

## Train Models

We love our open-source training libraries as much as you do, so we provide integrations with all of them to limit the time you spend on data preparation and have more fun with actual training. As of now, we support `spacy`, `transformers`, `setfit`, `openai`, `autotrain`, and way more. Want to get to know all support? Train/fine-tune a model from a `FeedbackDataset` as explained [here](/guides/llms/practical_guides/practical_guides.html), or either a `TextClassificationDataset` or a `TokenClassificationDataset`[here](/guides/train_a_model.md).
We love our open-source training libraries as much as you do, so we provide integrations with all of them to limit the time you spend on data preparation and have more fun with actual training. We support `spacy`, `transformers`, `setfit`, `openai`, `autotrain`, and many more. Want to get to know them all? Train/fine-tune a model from a `FeedbackDataset` as explained [here](/guides/llms/practical_guides/practical_guides/fine_tune.html), or from either a `TextClassificationDataset` or a `TokenClassificationDataset` [here](/guides/train_a_model.md).

```python
from argilla.training import ArgillaTrainer
```