From 76d9a4bf6cb4fbb9253a98ee58665cbddadce7cf Mon Sep 17 00:00:00 2001 From: Tom Aarsen <37621491+tomaarsen@users.noreply.github.com> Date: Mon, 28 Aug 2023 17:32:35 +0200 Subject: [PATCH] feat: Allow exporting data for SFT, Reward Modelling (related to RLHF), DPO, rename TrainingTaskMapping (#3467) Resolves #3379, resolves #3377 Hello! ## Pull Request overview * Prepare data for SFT, RM, DPO in TRL. * Rename `TrainingTaskMapping` to `TrainingTask` and `task_mapping` to `task`. # Description ## Prepare data ```python from argilla.feedback import TrainingTask def formatting_func(sample: Dict[str, Any]): ... yield template.format( prompt=sample["prompt"], response=sample["response"], ) task = TrainingTask.for_supervised_fine_tuning(formatting_func=formatting_func) ds = fds_dataset.prepare_for_training(framework="trl", task=task) # -> ds has "text" and "id" columns ``` Compatible with [SFTTrainer](https://huggingface.co/docs/trl/main/en/sft_trainer). ```python task = TrainingTask.for_reward_modelling(chosen_rejected_func=chosen_rejected_func) ds = fds_dataset.prepare_for_training(framework="trl", task=task) # -> ds has "chosen" and "rejected" columns ``` Nearly compatible with [RewardTrainer](https://huggingface.co/docs/trl/main/en/reward_trainer). ```python task = TrainingTask.for_direct_preference_optimization(prompt_chosen_rejected_func=prompt_chosen_rejected_func) ds = fds_dataset.prepare_for_training(framework="trl", task=task) # -> ds has "prompt", "chosen" and "rejected" columns ``` Compatible with [DPOTrainer](https://huggingface.co/docs/trl/main/en/dpo_trainer). ### Details I implement this by calling `dataset.format_as("datasets")` and then passing each sample (a simple dictionary) from this dataset to the function that the user provides. This user provided function can return `None`, one sample, a list of samples, or yield samples. This allows users to export multiple training samples from a single Argilla record, e.g. when there's multiple annotators that provided useful corrections, or if the annotated record justifies 3 "chosen", "rejected" pairs because there's a ranking between 3 texts. ## Rename `TrainingTaskMapping` is now `TrainingTask` - the "mapping" part is just unintuitive to the user. Same for `task_mapping` to `task`. **Note:** If people used `task_mapping=...` before, that will now fail. I can make this deprecation softer, but then I have to make `task` optional, which I would rather not do. ## TODO: - [ ] Add TRL to `ArgillaTrainer`, allowing: ```python task = TrainingTask.for_supervised_fine_tuning( formatting_func=formatting_func ) # or any other task from this PR trainer = ArgillaTrainer( dataset=fds_dataset, task=task, framework="trl", ) trainer.train() ``` - [ ] Consider renaming `FeedbackDataset.prepare_for_training` to `FeedbackDataset.export`. - [ ] New tests - [ ] Add documentation **Type of change** - [x] New feature **How Has This Been Tested** Not finished yet. **Checklist** - [ ] I added relevant documentation - [x] follows the style guidelines of this project - [x] I did a self-review of my code - [ ] I made corresponding changes to the documentation - [x] My changes generate no new warnings - [ ] I have added tests that prove my fix is effective or that my feature works - [ ] I filled out [the contributor form](https://tally.so/r/n9XrxK) (see text above) - [ ] I have added relevant notes to the CHANGELOG.md file (See https://keepachangelog.com/) --- - Tom Aarsen --------- Co-authored-by: Alvaro Bartolome Co-authored-by: Alvaro Bartolome Co-authored-by: David Berenstein Co-authored-by: Daniel Vila Suero Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- CHANGELOG.md | 2 + docs/_source/_common/dolly_dataset_info.md | 46 + docs/_source/_common/dolly_dataset_load.md | 7 + .../feedback-task/text-classification.md | 10 +- .../tabs/train_prepare_for_training.md | 10 + .../_common/tabs/train_update_config.md | 32 + docs/_source/getting_started/cheatsheet.md | 2 +- .../examples/train-reward-model-rlhf.ipynb | 424 ++----- .../guides/llms/practical_guides/fine_tune.md | 1007 ++++++++++++----- .../llms/practical_guides/fine_tune_others.md | 204 ---- .../llms/practical_guides/practical_guides.md | 8 +- .../llms/practical_guides/update_dataset.md | 1 + docs/_source/guides/train_a_model.md | 15 +- .../reference/python/python_training.rst | 6 + environment_dev.yml | 4 +- pyproject.toml | 5 +- src/argilla/client/datasets.py | 2 + src/argilla/client/feedback/__init__.py | 26 +- src/argilla/client/feedback/dataset/base.py | 61 +- .../client/feedback/training/__init__.py | 23 +- src/argilla/client/feedback/training/base.py | 110 +- .../training/frameworks/transformers.py | 8 +- .../feedback/training/frameworks/trl.py | 385 +++++++ .../client/feedback/training/schemas.py | 825 ++++++++++++-- src/argilla/client/models.py | 4 + src/argilla/feedback/__init__.py | 14 +- src/argilla/training/base.py | 2 +- src/argilla/training/transformers.py | 9 +- .../client/feedback/test_dataset.py | 6 +- .../client/feedback/training/test_trainer.py | 114 +- .../client/feedback/training/test_trl.py | 258 +++++ tests/integration/training/test_autotrain.py | 78 -- .../client/feedback/training/test_schemas.py | 16 +- 33 files changed, 2593 insertions(+), 1131 deletions(-) create mode 100644 docs/_source/_common/dolly_dataset_info.md create mode 100644 docs/_source/_common/dolly_dataset_load.md delete mode 100644 docs/_source/guides/llms/practical_guides/fine_tune_others.md create mode 100644 src/argilla/client/feedback/training/frameworks/trl.py create mode 100644 tests/integration/client/feedback/training/test_trl.py delete mode 100644 tests/integration/training/test_autotrain.py diff --git a/CHANGELOG.md b/CHANGELOG.md index ea26c5f325..7daed7db80 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,6 +18,8 @@ These are the section headers that we use: ### Added +- Added `ArgillaTrainer` integration with TRL, allowing for easy supervised finetuning, reward modeling, direct preference optimization and proximal policy optimization ([#3467](https://github.com/argilla-io/argilla/pull/3467)) +- Added `formatting_func` to `ArgillaTrainer` for `FeedbackDataset` datasets add a custom formatting for the data ([#3599](https://github.com/argilla-io/argilla/pull/3599)). - Added `login` function in `argilla.client.login` to login into an Argilla server and store the credentials locally ([#3582](https://github.com/argilla-io/argilla/pull/3582)). - Added `login` command to login into an Argilla server ([#3600](https://github.com/argilla-io/argilla/pull/3600)). - Added `logout` command to logout from an Argilla server ([#3605](https://github.com/argilla-io/argilla/pull/3605)). diff --git a/docs/_source/_common/dolly_dataset_info.md b/docs/_source/_common/dolly_dataset_info.md new file mode 100644 index 0000000000..1430b645d4 --- /dev/null +++ b/docs/_source/_common/dolly_dataset_info.md @@ -0,0 +1,46 @@ +The data required for these steps need to be used as comparison data to showcase the preference for the generated prompts. A good example is our [curated Dolly dataset](https://huggingface.co/datasets/argilla/databricks-dolly-15k-curated-en), where we assumed that updated responses get preference over the older ones. Another good example is the [Anthropic RLHF dataset](https://huggingface.co/datasets/Anthropic/hh-rlhf). + +```{note} +The Dolly original dataset contained a lot of reference indicators such as "[1]", which causes the model to hallucinate and incorrectly create references. +``` + +::::{tab-set} + +:::{tab-item} Original + +```bash +### Instruction +When did Virgin Australia start operating? + +### Context +Virgin Australia, the trading name of Virgin Australia Airlines Pty Ltd, is an Australian-based airline. +It is the largest airline by fleet size to use the Virgin brand. [2] +It commenced services on 31 August 2000 as Virgin Blue, with two aircraft on a single route.[3] +It suddenly found itself as a major airline in Australia's domestic market after the collapse of Ansett Australia in September 2001. +The airline has since grown to directly serve 32 cities in Australia, from hubs in Brisbane, Melbourne and Sydney.[4] + +### Response: +Virgin Australia commenced services on 31 August 2000 as Virgin Blue, with two aircraft on a single route. +``` +::: + +:::{tab-item} Corrected + +```bash +### Instruction +When did Virgin Australia start operating? + +### Context +Virgin Australia, the trading name of Virgin Australia Airlines Pty Ltd, is an Australian-based airline. +It is the largest airline by fleet size to use the Virgin brand. +It commenced services on 31 August 2000 as Virgin Blue, with two aircraft on a single route. +It suddenly found itself as a major airline in Australia's domestic market after the collapse of Ansett Australia in September 2001. +The airline has since grown to directly serve 32 cities in Australia, from hubs in Brisbane, Melbourne and Sydney. + +### Response: +Virgin Australia commenced services on 31 August 2000 as Virgin Blue, with two aircraft on a single route. +``` + +::: + +:::: diff --git a/docs/_source/_common/dolly_dataset_load.md b/docs/_source/_common/dolly_dataset_load.md new file mode 100644 index 0000000000..218320c551 --- /dev/null +++ b/docs/_source/_common/dolly_dataset_load.md @@ -0,0 +1,7 @@ +We will use our [curated Dolly dataset](https://huggingface.co/datasets/argilla/databricks-dolly-15k-curated-en), as introduced in the background-section above.. + +```python +import argilla as rg + +feedback_dataset = rg.FeedbackDataset.from_huggingface("argilla/databricks-dolly-15k-curated-en") +``` diff --git a/docs/_source/_common/snippets/training/feedback-task/text-classification.md b/docs/_source/_common/snippets/training/feedback-task/text-classification.md index 3a3158b572..f14db2e3ec 100644 --- a/docs/_source/_common/snippets/training/feedback-task/text-classification.md +++ b/docs/_source/_common/snippets/training/feedback-task/text-classification.md @@ -1,29 +1,29 @@ --- title: Text classification -description: When a RatingQuestion, LabelQuestion or MultiLabelQuestion is present in the datasets, we can define a TrainingTaskMappingForTextClassification to use our ArgillaTrainer integration for fine-tuning with openai”, “setfit”, “peft”, “spacy” and “transformers”. +description: When a RatingQuestion, LabelQuestion or MultiLabelQuestion is present in the datasets, we can define a TrainingTaskForTextClassification to use our ArgillaTrainer integration for fine-tuning with "openai", "setfit", "peft", "spacy" and "transformers". links: - linkText: Argilla unification docs linkLink: https://docs.argilla.io/en/latest/guides/llms/practical_guides/collect_responses.html#solve-disagreements - linkText: Argilla fine-tuning docs - linkLink: https://docs.argilla.io/en/latest/guides/llms/practical_guides/fine_tune_others.html#text-classification + linkLink: https://docs.argilla.io/en/latest/guides/llms/practical_guides/fine_tune.html#text-classification - linkText: ArgillaTrainer docs linkLink: https://docs.argilla.io/en/latest/guides/train_a_model.html#the-argillatrainer --- ```python -import argilla.feedback import ArgillaTrainer, FeedbackDataset, TrainingTaskMapping +import argilla.feedback import ArgillaTrainer, FeedbackDataset, TrainingTask dataset = FeedbackDataset.from_argilla( name="", workspace="" ) -task_mapping = TrainingTaskMapping.for_text_classification( +task = TrainingTask.for_text_classification( text=dataset.field_by_name(""), label=dataset.question_by_name("") ) trainer = ArgillaTrainer( dataset=dataset, - task_mapping=task_mapping, + task=task, framework="", ) trainer.update_config() diff --git a/docs/_source/_common/tabs/train_prepare_for_training.md b/docs/_source/_common/tabs/train_prepare_for_training.md index 1ddba699ed..e3dc9ee5c8 100644 --- a/docs/_source/_common/tabs/train_prepare_for_training.md +++ b/docs/_source/_common/tabs/train_prepare_for_training.md @@ -91,4 +91,14 @@ dataset_rg.prepare_for_training(framework="spark-nlp", train_size=1) ``` ::: +:::{tab-item} TRL + +```python +import argilla as rg + +dataset_rg = rg.load("") +dataset_rg.prepare_for_training(framework="trl", task=..., train_size=1) +``` +::: + :::: \ No newline at end of file diff --git a/docs/_source/_common/tabs/train_update_config.md b/docs/_source/_common/tabs/train_update_config.md index e97b53f929..66c9111ecb 100644 --- a/docs/_source/_common/tabs/train_update_config.md +++ b/docs/_source/_common/tabs/train_update_config.md @@ -231,4 +231,36 @@ trainer.update_config( ``` ::: +:::{tab-item} TRL + +```python +# parameters from `trl.RewardTrainer`, `trl.SFTTrainer`, `trl.PPOTrainer` or `trl.DPOTrainer`. +# `transformers.TrainingArguments` +trainer.update_config( + per_device_train_batch_size = 8, + per_device_eval_batch_size = 8, + gradient_accumulation_steps = 1, + learning_rate = 5e-5, + weight_decay = 0, + adam_beta1 = 0.9, + adam_beta2 = 0.9, + adam_epsilon = 1e-8, + max_grad_norm = 1, + learning_rate = 5e-5, + num_train_epochs = 3, + max_steps = 0, + log_level = "passive", + logging_strategy = "steps", + save_strategy = "steps", + save_steps = 500, + seed = 42, + push_to_hub = False, + hub_model_id = "user_name/output_dir_name", + hub_strategy = "every_save", + hub_token = "1234", + hub_private_repo = False +) +``` +::: + :::: diff --git a/docs/_source/getting_started/cheatsheet.md b/docs/_source/getting_started/cheatsheet.md index 6d92aeea10..114b323117 100644 --- a/docs/_source/getting_started/cheatsheet.md +++ b/docs/_source/getting_started/cheatsheet.md @@ -222,7 +222,7 @@ rg.log(records_for_training, name="majority_voter_results") ## Train Models -We love our open-source training libraries as much as you do, so we provide integrations with all of them to limit the time you spend on data preparation and have more fun with actual training. As of now, we support `spacy`, `transformers`, `setfit`, `openai`, `autotrain`, and way more. Want to get to know all support? Train/fine-tune a model from a `FeedbackDataset` as explained [here](/guides/llms/practical_guides/practical_guides.html), or either a `TextClassificationDataset` or a `TokenClassificationDataset`[here](/guides/train_a_model.md). +We love our open-source training libraries as much as you do, so we provide integrations with all of them to limit the time you spend on data preparation and have more fun with actual training. We support `spacy`, `transformers`, `setfit`, `openai`, `autotrain`, and way more. Want to get to know all support? Train/fine-tune a model from a `FeedbackDataset` as explained [here](/guides/llms/practical_guides/practical_guides/fine_tune.html), or either a `TextClassificationDataset` or a `TokenClassificationDataset`[here](/guides/train_a_model.md). ```python from argilla.training import ArgillaTrainer diff --git a/docs/_source/guides/llms/examples/train-reward-model-rlhf.ipynb b/docs/_source/guides/llms/examples/train-reward-model-rlhf.ipynb index 9f4477169e..79f12705d1 100644 --- a/docs/_source/guides/llms/examples/train-reward-model-rlhf.ipynb +++ b/docs/_source/guides/llms/examples/train-reward-model-rlhf.ipynb @@ -66,7 +66,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 1, "metadata": { "id": "IbSU2uDhMBYU" }, @@ -108,7 +108,7 @@ "# you can find the Spaces URL under the Embed this space button\n", "# Replace api_key if you configured a custom API key\n", "rg.init(\n", - " api_url=\"http://localhost:6900\", \n", + " api_url=\"http://localhost:6900\",\n", " api_key=\"admin.apikey\"\n", ")" ] @@ -441,7 +441,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" @@ -614,12 +614,12 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "The next step is to prepare the dataset in the standard format for training a reward model. In particular, we want to select the chose and rejected response from the user feedback. We do this by utilizing the value of the RatingQuestion's response:" + "The next step is to prepare the dataset in the standard format for training a reward model. In particular, we want to select the chosen and rejected response from the user feedback. We do this by creating a `TrainingTask` instance for reward modelling using a function that returns chosen-rejected tuples." ] }, { "cell_type": "code", - "execution_count": 33, + "execution_count": 3, "metadata": { "colab": { "base_uri": "https://localhost:8080/", @@ -628,252 +628,82 @@ "id": "cll7j10swNT2", "outputId": "0e4974e6-74b1-41be-c54c-92bca9732ddc" }, + "outputs": [], + "source": [ + "from typing import Any, Dict\n", + "from argilla.feedback import TrainingTask\n", + "from collections import Counter\n", + "\n", + "def formatting_func(sample: Dict[str, Any]):\n", + " # sample[\"choose-best\"] => [{'user_id': None, 'value': 1, 'status': 'submitted'}, ...]\n", + " values = [\n", + " annotation[\"value\"]\n", + " for annotation in sample[\"choose-best\"]\n", + " if annotation[\"status\"] == \"submitted\"\n", + " ]\n", + " # values => [1]\n", + " winning_response = Counter(values).most_common(1)[0][0]\n", + " if winning_response == 1:\n", + " chosen = sample[\"response-1\"]\n", + " rejected = sample[\"response-2\"]\n", + " else:\n", + " chosen = sample[\"response-2\"]\n", + " rejected = sample[\"response-1\"]\n", + " return chosen, rejected\n", + "\n", + "task = TrainingTask.for_reward_modeling(formatting_func=formatting_func)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "If we want, we can observe the resulting dataset by preparing it for training with TRL:" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, "outputs": [ { "data": { - "text/html": [ - "\n", - "
\n", - "
\n", - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
instructionchosen_responserejected_response
0What is DepreciationDepreciation is the drop in value of an asset ...What is Depreciation – 10 Important Facts to K...
1What do you know about the city of Aberdeen in...Aberdeen is a city located in the North East o...As an AI language model, I don't have personal...
2Describe thunderstorm season in the United Sta...Thunderstorm season in the United States and C...Describe thunderstorm season in the United Sta...
3When did Peloton IPO?\\nOn September 26, 2019, ...Peloton became a public company via an initial...When did Peloton IPO?\\nPeloton IPO'd on May 26...
4What is the best way to answer an interview qu...The first recommended step is to ask clarifyin...Some of the best ways to answer an interview q...
............
7396How do i accept the changeEmbrace the change and see the differenceI's a great opportunity to improve. The only t...
7397Extract the teams that the footballer Sócrates...Brazil, Botafogo-SP, Corinthians, FiorentinaExtract the teams that the footballer Sócrates...
7398Without quoting directly from the text give me...Brendon Small is a stand-up comedian, Creator...Without quoting directly from the text give me...
7399Is Killing is Sin ? Is it tureKilling a human being should not be sin becaus...Is Killing is Sin ? Is it ture?\\nKilling can b...
7400Who was Otto von Bismarck?\\nOtto, Prince of Bi...Otto von Bismarck was a Prussian and German so...Who was Otto von Bismarck?\\nOtto von Bismarck ...
\n", - "

7401 rows × 3 columns

\n", - "
\n", - " \n", - " \n", - " \n", - "\n", - " \n", - "
\n", - "
\n", - " " - ], "text/plain": [ - " instruction \\\n", - "0 What is Depreciation \n", - "1 What do you know about the city of Aberdeen in... \n", - "2 Describe thunderstorm season in the United Sta... \n", - "3 When did Peloton IPO?\\nOn September 26, 2019, ... \n", - "4 What is the best way to answer an interview qu... \n", - "... ... \n", - "7396 How do i accept the change \n", - "7397 Extract the teams that the footballer Sócrates... \n", - "7398 Without quoting directly from the text give me... \n", - "7399 Is Killing is Sin ? Is it ture \n", - "7400 Who was Otto von Bismarck?\\nOtto, Prince of Bi... \n", - "\n", - " chosen_response \\\n", - "0 Depreciation is the drop in value of an asset ... \n", - "1 Aberdeen is a city located in the North East o... \n", - "2 Thunderstorm season in the United States and C... \n", - "3 Peloton became a public company via an initial... \n", - "4 The first recommended step is to ask clarifyin... \n", - "... ... \n", - "7396 Embrace the change and see the difference \n", - "7397 Brazil, Botafogo-SP, Corinthians, Fiorentina \n", - "7398 Brendon Small is a stand-up comedian, Creator... \n", - "7399 Killing a human being should not be sin becaus... \n", - "7400 Otto von Bismarck was a Prussian and German so... \n", - "\n", - " rejected_response \n", - "0 What is Depreciation – 10 Important Facts to K... \n", - "1 As an AI language model, I don't have personal... \n", - "2 Describe thunderstorm season in the United Sta... \n", - "3 When did Peloton IPO?\\nPeloton IPO'd on May 26... \n", - "4 Some of the best ways to answer an interview q... \n", - "... ... \n", - "7396 I's a great opportunity to improve. The only t... \n", - "7397 Extract the teams that the footballer Sócrates... \n", - "7398 Without quoting directly from the text give me... \n", - "7399 Is Killing is Sin ? Is it ture?\\nKilling can b... \n", - "7400 Who was Otto von Bismarck?\\nOtto von Bismarck ... \n", - "\n", - "[7401 rows x 3 columns]" + "Dataset({\n", + " features: ['chosen', 'rejected'],\n", + " num_rows: 7401\n", + "})" ] }, - "execution_count": 33, + "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "# build a dataset with chosen and rejected responses\n", - "rows = []\n", - "for record in feedback_dataset.records:\n", - " if record.responses is None or len(record.responses) == 0:\n", - " continue\n", - " # get chosen index from RatingQuestion response\n", - " chosen_id = record.responses[0].values[\"choose-best\"].value\n", - " rejected_id = 2 if chosen_id == 1 else 1\n", - "\n", - " # build rows for rm training\n", - " rows.append({\n", - " \"instruction\": record.fields[\"instruction\"],\n", - " \"chosen_response\": record.fields[f\"response-{chosen_id}\"],\n", - " \"rejected_response\": record.fields[f\"response-{rejected_id}\"]\n", - " })\n", - "\n", - "# build dataset for training\n", - "prepared_dataset = Dataset.from_list(rows)\n", - "prepared_dataset.to_pandas()" + "dataset = feedback_dataset.prepare_for_training(framework=\"trl\", task=task)\n", + "dataset" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'chosen': \"Depreciation is the drop in value of an asset due to wear and tear, age and obsolescence (going out of date) as recorded in an organization's financial records.\",\n", + " 'rejected': 'What is Depreciation – 10 Important Facts to Know?\\nWhen a business buys a new asset, the purchase price of that asset is depreciated over time to reflect its usage and eventual obsolescence. Depreciation expense can be a tax deductible expense and is usually a non-cash expense reported on a company’s income statement and balance sheet. The amount of depreciation expense a company reports each year is the difference between the original purchase price of the asset and what the current value of that asset might be. Here are 10 important facts to know about depreciation:\\n1. Depreciation is a non-cash expense. It is an expense that is reported in a business’s income statement and balance sheet and not a cash flow expense.\\n2. Depreciation is an accounting standard and it is required to be disclosed in a business’s financial statements.\\n3. The amount of depreciation is usually a tax expense and not a cash expense reported on a company’s income statement'}" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "dataset[0]" ] }, { @@ -915,7 +745,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 7, "metadata": { "colab": { "base_uri": "https://localhost:8080/", @@ -941,57 +771,15 @@ { "data": { "text/html": [ - "\n", - "
\n", - " \n", - " \n", - " [ 101/1041 06:05 < 57:52, 0.27 it/s, Epoch 0.29/3]\n", - "
\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
StepTraining LossValidation LossAccuracy
200.6701000.5461840.878984
400.4069000.2669480.910319
600.2289000.1993270.927066
800.2803000.1824980.938952

\n", - "

\n", - " \n", - " \n", - " [ 31/232 00:07 < 00:48, 4.15 it/s]\n", - "
\n", - " " + "
[08/08/23 16:36:51] INFO     INFO:ArgillaTRLTrainer:{'eval_loss': 0.1626577377319336, 'eval_accuracy':   trl.py:226\n",
+              "                             0.937204591492235, 'eval_runtime': 6.5907, 'eval_samples_per_second':                 \n",
+              "                             224.709, 'eval_steps_per_second': 28.221, 'epoch': 1.0}                               \n",
+              "
\n" ], "text/plain": [ - "" + "\u001b[2;36m[08/08/23 16:36:51]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m INFO:ArgillaTRLTrainer:\u001b[1m{\u001b[0m\u001b[32m'eval_loss'\u001b[0m: \u001b[1;36m0.1626577377319336\u001b[0m, \u001b[32m'eval_accuracy'\u001b[0m: \u001b]8;id=234053;file://C:\\code\\argilla\\src\\argilla\\client\\feedback\\training\\frameworks\\trl.py\u001b\\\u001b[2mtrl.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=146316;file://C:\\code\\argilla\\src\\argilla\\client\\feedback\\training\\frameworks\\trl.py#226\u001b\\\u001b[2m226\u001b[0m\u001b]8;;\u001b\\\n", + "\u001b[2;36m \u001b[0m \u001b[1;36m0.937204591492235\u001b[0m, \u001b[32m'eval_runtime'\u001b[0m: \u001b[1;36m6.5907\u001b[0m, \u001b[32m'eval_samples_per_second'\u001b[0m: \u001b[2m \u001b[0m\n", + "\u001b[2;36m \u001b[0m \u001b[1;36m224.709\u001b[0m, \u001b[32m'eval_steps_per_second'\u001b[0m: \u001b[1;36m28.221\u001b[0m, \u001b[32m'epoch'\u001b[0m: \u001b[1;36m1.0\u001b[0m\u001b[1m}\u001b[0m \u001b[2m \u001b[0m\n" ] }, "metadata": {}, @@ -999,50 +787,22 @@ } ], "source": [ - "model_name = \"distilroberta-base\"\n", + "from argilla.feedback import ArgillaTrainer\n", "\n", - "model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=1)\n", - "tokenizer = AutoTokenizer.from_pretrained(model_name)\n", - "\n", - "if tokenizer.pad_token is None:\n", - " tokenizer.pad_token = tokenizer.eos_token\n", - " model.config.pad_token_id = model.config.eos_token_id\n", - "\n", - "def formatting_func(examples):\n", - " kwargs = {\"padding\": \"max_length\", \"truncation\": True, \"max_length\": 512, \"return_tensors\": \"pt\"}\n", - "\n", - " # Prepend the prompt and a line break to the original_response and response-1 fields.\n", - " prompt_plus_chosen_response = examples[\"instruction\"] + \"\\n\" + examples[\"chosen_response\"]\n", - " prompt_plus_rejected_response = examples[\"instruction\"] + \"\\n\" + examples[\"rejected_response\"]\n", - "\n", - " # Then tokenize these modified fields.\n", - " tokens_chosen = tokenizer.encode_plus(prompt_plus_chosen_response, **kwargs)\n", - " tokens_rejected = tokenizer.encode_plus(prompt_plus_rejected_response, **kwargs)\n", - "\n", - " return {\n", - " \"input_ids_chosen\": tokens_chosen[\"input_ids\"][0], \"attention_mask_chosen\": tokens_chosen[\"attention_mask\"][0],\n", - " \"input_ids_rejected\": tokens_rejected[\"input_ids\"][0], \"attention_mask_rejected\": tokens_rejected[\"attention_mask\"][0]\n", - " }\n", - " \n", - "formatted_dataset = prepared_dataset.map(formatting_func) \n", - "formatted_dataset = formatted_dataset.train_test_split()\n", - "\n", - "training_args = TrainingArguments(\n", - " output_dir=\"./reward_model\",\n", - " per_device_train_batch_size=16,\n", - " evaluation_strategy=\"steps\", \n", - " logging_steps=200, \n", + "model_name = \"distilroberta-base\"\n", + "trainer = ArgillaTrainer(\n", + " dataset=feedback_dataset,\n", + " task=task,\n", + " framework=\"trl\",\n", + " model=model_name,\n", + " train_size=0.8,\n", ")\n", - "\n", - "trainer = RewardTrainer(\n", - " model=model,\n", - " args=training_args,\n", - " tokenizer=tokenizer,\n", - " train_dataset=formatted_dataset[\"train\"],\n", - " eval_dataset=formatted_dataset[\"test\"],\n", + "trainer.update_config(\n", + " per_device_train_batch_size=16,\n", + " evaluation_strategy=\"steps\",\n", + " logging_steps=200,\n", ")\n", - "\n", - "trainer.train()\n" + "trainer.train(\"./reward_model\")\n" ] }, { @@ -1182,7 +942,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.3" + "version": "3.8.17" }, "vscode": { "interpreter": { diff --git a/docs/_source/guides/llms/practical_guides/fine_tune.md b/docs/_source/guides/llms/practical_guides/fine_tune.md index 23137afeab..452e8b09a9 100644 --- a/docs/_source/guides/llms/practical_guides/fine_tune.md +++ b/docs/_source/guides/llms/practical_guides/fine_tune.md @@ -1,16 +1,297 @@ -# Fine-tune an LLM +# Fine-tuning a Feedback Dataset -After [collecting the responses](./collect_responses.html) from our `FeedbackDataset` we can start fine-tuning our LLM. Due to the customizability of the task, this might require setting up a custom post-processing workflow but we will provide some good toy examples for the [classic LLM approaches](../conceptual_guides/rlhf.html): pre-training, supervised fine-tuning, reward modeling, and reinforcement learning. +After [collecting the responses](/guides/llms/practical_guides/collect_responses.html) from our `FeedbackDataset`, we can start fine-tuning our LLMs and other models. Due to the customizability of the task, this might require setting up a custom post-processing workflow, but we will provide some good toy examples for the [LLM approaches](/guides/llms/conceptual_guides/rlhf.html): pre-training, supervised fine-tuning, and reinforcement learning through human feedback (RLHF). However, we also still provide for other NLP tasks like text classification. +## The `ArgillaTrainer` -## Supervised finetuning +The `ArgillaTrainer` is a wrapper around many of our favorite NLP libraries. It provides a very intuitive abstract representation to facilitate simple training workflows using decent default pre-set configurations without having to worry about any data transformations from Argilla. -The goal of Supervised Fine Tuning (SFT) is to optimize this pre-trained model to generate the responses that users are looking for. After pre-training a causal language model, it can generate feasible human text, but it will not be able to have proper `answers` to `question` phrases posed by the user in a conversational or instruction set. Therefore, we need to collect and curate data tailored to this use case to teach the model to mimic this data. We have a section in our docs about [collecting data for this task](../conceptual_guides/sft.html) and there are many good [pre-trained causal language models](https://huggingface.co/models?pipeline_tag=text-generation&sort=downloads) available on Hugging Face. +Using the `ArgillaTrainer` is straightforward, but it slightly differs per task. + +1. First, we define a `TrainingTask`. This is done using a custom `formatting_func`. However, tasks like Text Classification can also be defined using default definitions using the `FeedbackDataset` fields and questions. These tasks are then used for retrieving data from a dataset and initializing the training. We also offer some ideas for [unifying data](/guides/llms/practical_guides/collect_responses) out of the box. +2. Next, we initialize the `ArgillaTrainer` and forward the task and training framework. Internally, this uses the `FeedbackData.prepare_for_training`-method to format the data according to the expectations from the framework. Some other interesting methods are: + 1. `ArgillaTrainer.update_config` to change framework specific training parameters. + 2. `ArgillaTrainer.train` to start training. + 3. `ArgillTrainer.predict` to run inference. + +Underneath, you can see the happy flow for using the `ArgillaTrainer`. + +```python +from argilla.feedback import ArgillaTrainer, FeedbackDataset, TrainingTask + +dataset = FeedbackDataset.from_huggingface( + repo_id="argilla/emotion" +) +task = TrainingTask.for_text_classification( + text=dataset.field_by_name("text"), + label=dataset.question_by_name("label"), +) +trainer = ArgillaTrainer( + dataset=dataset, + task=task, + framework="setfit" +) +trainer.update_config(num_iterations=1) +trainer.train(output_dir="my_setfit_model") +trainer.predict("This is awesome!") +``` + +### Supported Frameworks + +We plan on adding more support for other tasks and frameworks so feel free to reach out on our Slack or GitHub to help us prioritize each task. + +| Task/Framework | TRL | OpenAI | AutoTrain | SetFit | spaCy | Transformers | PEFT | +|:--------------------------------|:-----|:-------|:----------|:-------|:------|:-------------|:-----| +| Text Classification | | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | +| Supervised Fine-tuning | ✔️ | | | | | | | +| Reward Modeling | ✔️ | | | | | | | +| Proximal Policy Optimization | ✔️ | | | | | | | +| Direct Preference Optimization | ✔️ | | | | | | | + + +```{note} +We also offer support for Token Classification using our `TokenClassifcationDataset` but this is shown in [a section](/guides/train_a_model) about our older dataset-types. +``` + +#### Training Configs + +The trainer also has an `ArgillaTrainer.update_config()` method, which maps a dict with `**kwargs` to the respective framework. So, these can be derived from the underlying framework that was used to initialize the trainer. Underneath, you can find an overview of these variables for the supported frameworks. + +```{note} +Note that you don't need to pass all of them directly and that the values below are their default configurations. +``` -### Data +```{include} /_common/tabs/train_update_config.md +``` + +### The `TrainingTask` + +A `TrainingTask` is used to define how the data should be processed and formatted according to the associated task and framework. Each task has its own `TrainingTask.for_*`-classmethod and the data formatting can always be defined using a custom `formatting_func`. However, simpler tasks like Text Classification can also be defined using default definitions. These directly use the fields and questions from the FeedbackDataset configuration to infer how to prepare the data. Underneath you can find an overview of the `TrainingTask` requirements. + +| Method | Content | `formatting_func` return type | Default | +|:-----------------------------------|:-----------------|:-----------------------------------------------------------|:------------------| +| for_text_classification | `text-label` | `Union[Tuple[str, str], Tuple[str, List[str]]]` | ✔️ | +| for_supervised_fine_tuning | `text` | `Optional[Union[str, Iterator[str]]]` | ✗ | +| for_reward_modeling | `chosen-rejected`| `Optional[Union[Tuple[str, str], Iterator[Tuple[str, str]]]]` | ✗ | +| for_proximal_policy_optimization | `text` | `Optional[Union[str, Iterator[str]]]` | ✗ | +| for_direct_preference_optimization| `prompt-chosen-rejected` | `Optional[Union[Tuple[str, str, str], Iterator[Tuple[str, str, str]]]]` | ✗ | + + +## Tasks + +### Text Classification + +#### Background + +Text classification is a widely used NLP task where labels are assigned to text. Major companies rely on it for various applications. Sentiment analysis, a popular form of text classification, assigns labels like 🙂 positive, 🙁 negative, or 😐 neutral to text. Additionally, we distinguish between single- and multi-label text classification. + +::::{tab-set} + +:::{tab-item} Single-label +Single-label text classification refers to the task of assigning a single category or label to a given text sample. Each text is associated with only one predefined class or category. For example, in sentiment analysis, a single-label text classification task would involve assigning labels such as "positive," "negative," or "neutral" to texts based on their sentiment. + +```batch +"The help for my application of a new card and mortgage was great", "positive" +``` + +::: + +:::{tab-item} Multi-label +Multi-label text classification is generally more complex than single-label classification due to the challenge of determining and predicting multiple relevant labels for each text. It finds applications in various domains, including document tagging, topic labeling, and content recommendation systems. For example, in customer care, a multi-label text classification task would involve assigning topics such as "new_card," "mortgage," or "opening_hours" to texts based on their content. + +```{tip} +For a multi-label scenario it is recommended to add some examples without any labels to improve model performance. +``` + +```batch +"The help for my application of a new card and mortgage was great", ["new_card", "mortgage"] +``` + +::: + +:::: + +We then use either `text-label`-pair to further fine-tune the model. + +#### Training + +Text classification is one of the most widely supported training tasks tasks within NLP. For example purposes we will use our [emotion demo dataset](https://huggingface.co/datasets/argilla/emotion). + +**Data Preparation** + +```python +from argilla.feedback import FeedbackDataset + +dataset = FeedbackDataset.from_huggingface( + repo_id="argilla/emotion" +) +``` + + +For this task, we assume we need a `text-label`-pair or a `formatting_func` for defining the `TrainingTask.for_text_classification`. + +::::{tab-set} + +:::{tab-item} text-label-pair +We offer the option to use default unification strategies and formatting based on a `text-label`-pair. Here we infer formatting information based on a `TextField` and a `LabelQuestion`, `MultiLabelQuestion`, `RatingQuestion` or , `RankingQuestion` from the dataset. This is the easiest way to define a `TrainingTask` for text classification but if you need a custom workflow, you can use `formatting_func`. + +```{note} +An overview of the unifcation measures can be found [here](/guides/llms/practical_guides/collect_responses). The `RatingQuestion` and `RankingQuestion` can be unified using a "majority"-, "min"-, "max"- or "disagreement"-strategy. Both the `LabelQuestion` and `MultiLabelQuestion` can be resolved using a "majority"-, or "disagreement"-strategy. +``` + +```python +from argilla.feedback import FeedbackDataset, TrainingTask + +dataset = FeedbackDataset.from_huggingface( + repo_id="argilla/emotion" +) +task = TrainingTask.for_text_classification( + text=dataset.field_by_name("text"), + label=dataset.question_by_name("label"), + label_strategy=None # defaults presets +) +``` + +::: + +:::{tab-item} formatting_func +We offer the option to provide a `formatting_func` to the `TrainingTask.for_text_classification`. This function is applied to each sample in the dataset and can be used for more advanced preprocessing and data formatting. The function should return a tuple of `(text, label)` as `Tuple[str, str]` or `Tuple[str, List[str]]`. + +```python +from argilla.feedback import FeedbackDataset, TrainingTask + +dataset = FeedbackDataset.from_huggingface( + repo_id="argilla/emotion" +) + +def formatting_func(sample): + text = sample["text"] + # Choose the most common label + values = [resp["value"] for resp in sample["label"]] + counter = Counter(values) + if counter: + most_common = counter.most_common() + max_frequency = most_common[0][1] + most_common_elements = [ + element for element, frequency in most_common if frequency == max_frequency + ] + label = random.choice(most_common_elements) + return (text, label) + else: + return None + +task = TrainingTask.for_text_classification(formatting_func=formatting_func) +``` + +::: + + +:::: + +We can then define our `ArgillaTrainer` for any of [the supported frameworks](fine_tune.md#training-configs) and [customize the training config](#supported-frameworks) using `ArgillaTrainer.update_config`. + +```python +from argilla.feedback import ArgillaTrainer + +trainer = ArgillaTrainer( + dataset=feedback_dataset, + task=task, + framework="spacy", + train_size=0.8, + model="en_core_web_sm", +) + +trainer.train(output_dir="textcat_model") +``` + +### Pre-training + +#### Background + +When talking about pre-training, we generally talk about a simple `prompt-completion` task, where we need the model to pick up on basic statistics of the language it is learning. Given that you are familiar with Spanish cuisine and the prompt sentence, `The base ingredient of paella is ___`, you know that the word in the `___` is much more likely to be `rice` than `apples`. So, you are basically training a causal language model or text generation model. + +```{note} +This is an unsupervised approach hence we only infer training data from a basic sentence like `The base ingredient of paella is rice.` by starting with the word `The`, and from there unwrapping the sentence step by step. +``` + +#### Training + + +Many training datasets for this task can be found online (e.g., [Hugging Face](https://huggingface.co/datasets?task_categories=task_categories:text-generation&sort=downloads)). You can either upload this in the right Argilla format but it might be needed to collect and fine-tune additional data with Argilla. So we, therefore, provide a basic setup underneath which should help you to start gathering or preparing pre-training data. + +```{note} +When it comes to pre-training an LLM, we generally do not need data of highest quality, but it is always smart to use domain-specfic data and to avoid data that might lead to undesired effects like hallucination and bias. +``` + +First, create a `FeedbackDataset` with records. + +```python +import argilla as rg + +# create prompt-completion dataset +dataset = rg.FeedbackDataset( + guidelines="Please, complete the following prompt fields with a brief text answer.", + fields=[ + rg.TextField(name="prompt"), + ], + questions=[ + rg.TextQuestion(name="completion", title="Add a brief text answer."), + ] +) + +# create a Feedback Records +record = rg.FeedbackRecord( + fields={ + "prompt": "The base ingredient of paella is rice." + } +) + +dataset.add_records([record]) +``` + +Then push it to Argilla via `push_to_argilla`. + +::::{tab-set} + +:::{tab-item} Argilla 1.14.0 or higher +```python +remote_dataset = dataset.push_to_argilla(name="pre-training") +``` +::: + +:::{tab-item} Lower than Argilla 1.14.0 +```python +dataset.push_to_argilla(name="pre-training") +``` +::: +:::: + +And, finally, load the `FeedbackDataset` from Argilla. + +```python +import argilla as rg +from datasets import Dataset + +dataset = rg.FeedbackDataset.from_argilla("pre-training") +prompts = {"prompt": [record.fields.get("prompt") for record in dataset.records]} +dataset = Dataset.from_dict(prompts) +dataset +# Dataset({ +# features: ['prompt'], +# num_rows: 1 +# }) +``` + +There are many ways and great packages to deal with this `pre-training` phase, but generally, NLP training frameworks like [KerasNLP](https://keras.io/keras_nlp/) and [Hugging Face](https://huggingface.co/) offer great out-of-the-box methods for training a causal language model. In our guide, we will refer to the great docs of the Hugging Face `transformers` and `datasets` libraries and prepare our training data in the format they require for [training a causal language model](https://huggingface.co/learn/nlp-course/chapter7/6#training-a-causal-language-model-from-scratch). + +### Supervised finetuning + +#### Background + +The goal of Supervised Fine Tuning (SFT) is to optimize this pre-trained model to generate the responses that users are looking for. After pre-training a causal language model, it can generate feasible human text, but it will not be able to have proper `answers` to `question` phrases posed by the user in a conversational or instruction set. Therefore, we need to collect and curate data tailored to this use case to teach the model to mimic this data. We have a section in our docs about [collecting data for this task](../conceptual_guides/sft.html) and there are many good [pre-trained causal language models](https://huggingface.co/models?pipeline_tag=text-generation&sort=downloads) available on Hugging Face. Data for the training phase is generally divided into two different types generic for domain-like finetuning or chat for fine-tuning an instruction set. -#### Generic +*Generic* In a generic fine-tuning setting, the aim is to make the model more proficient in generating coherent and contextually appropriate text within a particular domain. For example, if we want the model to generate text related to medical research, we would fine-tune it using a dataset consisting of medical literature, research papers, or related documents. By exposing the model to domain-specific data during training, it becomes more knowledgeable about the terminology, concepts, and writing style prevalent in that domain. This enables the model to generate more accurate and contextually appropriate responses when prompted with queries or tasks related to the specific domain. An example of this format is the [PubMed data](https://huggingface.co/datasets/pubmed), but it might be smart to add some nuance by generic instruction phrases that indicate the scope of the data, like `Generate a medical paper abstract: ...`. @@ -18,7 +299,7 @@ In a generic fine-tuning setting, the aim is to make the model more proficient i # Five distinct ester hydrolases (EC 3-1) have been characterized in guinea-pig epidermis. These are carboxylic esterase, acid phosphatase, pyrophosphatase, and arylsulphatase A and B. Their properties are consistent with those of lysosomal enzymes. ``` -#### Chat +*Chat* On the other hand, instruction-based fine-tuning involves training the model to understand and respond to specific instructions or prompts given by the user. This approach allows for greater control and specificity in the generated output. For example, if we want the model to summarize a given text, we can fine-tune it using a dataset that consists of pairs of text passages and their corresponding summaries. The model can then be instructed to generate a summary based on a given input text. By fine-tuning the model in this manner, it becomes more adept at following instructions and producing output that aligns with the desired task or objective. An example of this format used is our [curated Dolly dataset](https://huggingface.co/datasets/argilla/databricks-dolly-15k-curated-en) with `instruction`, `context` and `response` fields. However, we can also have simpler datasets with only `question` and `answer` fields. @@ -56,78 +337,153 @@ Virgin Australia commenced services on 31 August 2000 as Virgin Blue, with two a :::: -Ultimately, the choice between these two approaches depends on the specific requirements of the application and the desired level of control over the model's output. By employing the appropriate fine-tuning strategy, we can enhance the model's performance and make it more suitable for a wide range of applications and use cases. +Ultimately, the choice between these two approaches to be used as `text`-field depends on the specific requirements of the application and the desired level of control over the model's output. By employing the appropriate fine-tuning strategy, we can enhance the model's performance and make it more suitable for a wide range of applications and use cases. -### Training +#### Training -There are many good libraries to help with this step, however, we are a fan of the [Transformer Reinforcement Learning (TRL)](https://huggingface.co/docs/trl) package, and the no-code [Hugging Face AutoTrain](https://huggingface.co/spaces/autotrain-projects/autotrain-advanced) for fine-tuning. In both cases, we need a backbone model, obtained from the [pre-training step](#pre-training) and for example purposes we will use our [curated Dolly dataset](https://huggingface.co/datasets/argilla/databricks-dolly-15k-curated-en). +There are many good libraries to help with this step, however, we are a fan of the [Transformer Reinforcement Learning (TRL)](https://huggingface.co/docs/trl) package, [Transformer Reinforcement Learning X (TRLX)](https://github.com/CarperAI/trlx),and the no-code [Hugging Face AutoTrain](https://huggingface.co/spaces/autotrain-projects/autotrain-advanced) for fine-tuning. In both cases, we need a backbone model, obtained from the [pre-training step](#pre-training) and for example purposes we will use our [curated Dolly dataset](https://huggingface.co/datasets/argilla/databricks-dolly-15k-curated-en). ```{note} -This dataset only contains a single annotator response per record. We gave some sugggestions on dealing with [responses from multiple annotators](/guides/llms/practical_guides/collect_responses). +This dataset only contains a single annotator response per record. We gave some suggestions on dealing with [responses from multiple annotators](/guides/llms/practical_guides/collect_responses). ``` +::::{tab-set} + +:::{tab-item} TRL + +The [Transformer Reinforcement Learning (TRL)](https://huggingface.co/docs/trl) package provides a flexible and customizable framework for fine-tuning models. It allows users to have fine-grained control over the training process, enabling them to define their functions and to further specify the desired behavior of the model. This approach requires a deeper understanding of reinforcement learning concepts and techniques, as well as more careful experimentation. It is best suited for users who have experience in reinforcement learning and want fine-grained control over the training process. Additionally, it directly integrates with [Parameter-Efficient Fine-Tuning](https://huggingface.co/docs/peft/index) (PEFT) decreasing the computational complexity of this step of training an LLM. + +**Data Preparation** + ```python import argilla as rg from datasets import Dataset feedback_dataset = rg.FeedbackDataset.from_huggingface("argilla/databricks-dolly-15k-curated-en") +``` -data = {"instruction": [], "context": [], "response": []} -for entry in feedback_dataset: - if entry.responses: - res = entry.responses[0].values - data["instruction"].append(res["new-instruction"].value) - data["context"].append(res["new-context"].value) - data["response"].append(res["new-response"].value) +We offer the option to provide a `formatting_func` to the `TrainingTask.for_supervised_fine_tuning`. This function is applied to each sample in the dataset and can be used for advanced preprocessing and data formatting. The function should return a `text` as `str`. -dataset = Dataset.from_dict(data) -dataset -# Dataset({ -# features: ['instruction', 'context', 'response'], -# num_rows: 15000 -# }) + +```python +from argilla.feedback import TrainingTask +from typing import Dict, Any + +template = """\ +### Instruction: {instruction}\n +### Context: {context}\n +### Response: {response}""" + +def formatting_func(sample: Dict[str, Any]) -> str: + # What `sample` looks like depends a lot on your FeedbackDataset fields and questions + return template.format( + instruction=sample["new-instruction"][0]["value"], + context=sample["new-context"][0]["value"], + response=sample["new-response"][0]["value"], + ) + +task = TrainingTask.for_supervised_fine_tuning(formatting_func=formatting_func) ``` -#### TRL +You can observe the resulting dataset by calling `FeedbackDataset.prepare_for_training`. We can use `"trl"` as the framework for example: -The [Transformer Reinforcement Learning (TRL)](https://huggingface.co/docs/trl) package provides a flexible and customizable framework for fine-tuning models. It allows users to have fine-grained control over the training process, enabling them to define their functions and to further specify the desired behavior of the model. This approach requires a deeper understanding of reinforcement learning concepts and techniques, as well as more careful experimentation. It is best suited for users who have experience in reinforcement learning and want fine-grained control over the training process. Additionally, it directly integrates with [Performance Efficient Fine Tuning](https://huggingface.co/docs/peft/index) (PEFT) decreasing the computational complexity of this step of training an LLM. +```python +dataset = feedback_dataset.prepare_for_training( + framework="trl", + task=task +) +""" +>>> dataset +Dataset({ + features: ['id', 'text'], + num_rows: 15015 +}) +>>> dataset[0]["text"] +### Instruction: When did Virgin Australia start operating? + +### Context: Virgin Australia, the trading name of Virgin Australia Airlines Pty Ltd, is an Australian-based airline. It is the largest airline by fleet size to use the Virgin brand. It commenced services on 31 August 2000 as Virgin Blue, with two aircraft on a single route. It suddenly found itself as a major airline in Australia's domestic market after the collapse of Ansett Australia in September 2001. The airline has since grown to directly serve 32 cities in Australia, from hubs in Brisbane, Melbourne and Sydney. + +### Response: Virgin Australia commenced services on 31 August 2000 as Virgin Blue, with two aircraft on a single route. +""" +``` + +**ArgillaTrainer** ```python -from transformers import AutoModelForCausalLM -from datasets import load_dataset -from trl import SFTTrainer +from argilla.feedback import ArgillaTrainer + +trainer = ArgillaTrainer( + dataset=feedback_dataset, + task=task, + framework="trl", + train_size=0.8, + model="gpt2", +) +# e.g. using LoRA: +# from peft import LoraConfig +# trainer.update_config(peft_config=LoraConfig()) +trainer.train(output_dir="sft_model") +``` -dataset = ... +**Inference** -model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m") +Let's observe if it worked to train the model to respond within our template. We'll create a quick helper method for this. -def formatting_prompts_func(example): - text = ( - f"### Instruction: {example['instruction']}\n" + - f"### Context: {example['context']}\n" + - f"### Response: {example['response']}" +```python +from transformers import GenerationConfig, AutoTokenizer, GPT2LMHeadModel + + +def generate(model_id: str, instruction: str, context: str = "") -> str: + model = GPT2LMHeadModel.from_pretrained(model_id) + tokenizer = AutoTokenizer.from_pretrained(model_id) + + inputs = template.format( + instruction=instruction, + context=context, + response="", + ).strip() + + encoding = tokenizer([inputs], return_tensors="pt") + outputs = model.generate( + **encoding, + generation_config=GenerationConfig( + max_new_tokens=32, + min_new_tokens=12, + pad_token_id=tokenizer.pad_token_id, + eos_token_id=tokenizer.eos_token_id, + ), ) - return text - -trainer = SFTTrainer( - model, - train_dataset=dataset, - packing=True, - formatting_func=formatting_prompts_func, - # peft_config=LoraConfig() # from peft import LoraConfig -) + return tokenizer.decode(outputs[0]) +``` + +```python +>>> generate("sft_model", "Is a toad a frog?") +### Instruction: Is a toad a frog? + +### Context: -trainer.train() +### Response: A frog is a small, round, black-eyed, frog with a long, black-winged head. It is a member of the family Pter ``` +Much better! This model follows the template like we want. -#### TRLX +::: + +:::{tab-item} TRLX -The other package is [Transformer Reinforcement Learning X (TRLX)](https://github.com/CarperAI/trlx), which has been heavily inspired by TRL but with an increased focus on incorporating Human Feedback into the training loop. However, out of the box, it also provides intuitive support for supervised `prompt-completion` fine-tuning using a relatively simple SDK, that takes tuples as `(prompt, completion)`. Take a look at the [RLHF section](#rlhf) for the other more feedback-oriented use cases of this library. +The [Transformer Reinforcement Learning X (TRLX)](https://github.com/CarperAI/trlx), which has been heavily inspired by TRL but with an increased focus on incorporating Human Feedback into the training loop. However, out of the box, it also provides intuitive support for supervised `prompt-completion` fine-tuning using a relatively simple SDK, that takes tuples as `(prompt, completion)`. Take a look at the [RLHF section](#rlhf) for the other more feedback-oriented use cases of this library. ```python import trlx -# dataset = ... +# Let's create a Dataset for convenience +data = {"instruction": [], "context": [], "response": []} +for entry in feedback_dataset: + if entry.responses: + res = entry.responses[0].values + data["instruction"].append(res["new-instruction"].value) + data["context"].append(res["new-context"].value) + data["response"].append(res["new-response"].value) +dataset = Dataset.from_dict(data) samples = [ [ @@ -139,337 +495,412 @@ samples = [ trainer = trlx.train('gpt2', samples=samples) ``` -#### AutoTrain +::: -AutoTrain offers an option for users who prefer a simpler and more automated approach. It offers a no-code solution for fine-tuning models wrapped and enabled by a nice [streamlit UI](https://huggingface.co/spaces/autotrain-projects/autotrain-advanced), or by a low-code option with the [AutoTrain Advanced package](https://github.com/huggingface/autotrain-advanced). This tool leverages techniques to automatically optimize the model's performance without requiring users to have extensive knowledge of reinforcement learning or coding skills. It streamlines the fine-tuning process by automatically adjusting the model's parameters and optimizing its performance based on user-provided feedback. +:::: -First, export the data into CSV or any other supported format. +### Reward Modeling -```python -dataset = ... +#### Background -dataset.to_csv("databricks-dolly-15k-curated-en.csv", index=False) +A Reward Model (RM) is used to rate responses in alignment with human preferences and afterwards using this RM to fine-tune the LLM with the associated scores. Fine-tuning using a Reward Model can be done in different ways. We can either get the annotator to rate output completely manually, we can use a simple heuristic or we can use a stochastic preference model. Both [TRL](https://huggingface.co/docs/trl) and [TRLX](https://github.com/CarperAI/trlx) provide decent options for incorporating rewards. The [DeepSpeed library of Microsoft](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-chat) is a worthy mention too but will not be covered in our docs. + +```{include} /_common/dolly_dataset.md ``` -Then, go to the AutoTrain UI for training. +In case of training an RM, we then use the `chosen-rejected`-pairs and train a classifier to distinguish between them. + +#### Training - +```{include} /_common/dolly_dataset_info.md +``` -## RLHF +::::{tab-set} -The last part of the fine-tuning process is the part that contains doing Reinforcement Learning with Human Feedback (RLHf). This is generally done by creating a reward model (RM) to rate responses in alignment with human preferences and afterward using this reward model to fine-tune the LLM with the associated scores. +:::{tab-item} TRL +[TRL](https://huggingface.co/docs/trl) implements reward modeling, which can be used via the `ArgillaTrainer` class. We offer the option to provide a `formatting_func` to the `TrainingTask.for_reward_modeling`. This function is applied to each sample in the dataset and can be used for preprocessing and data formatting. The function should return a tuple of `chosen-rejected`-pairs as `Tuple[str, str]`. To determine which response from the FeedbackDataset is superior, we can use the user annotations. ```{note} -First, create a reward model or heuristic. Second, use this as automated procedure during reinforcment learning to align with human preferences. +The formatting function can also return `None` or a list of tuples. The `None` may be used if the annotations indicate that the text is low quality or harmful, and the latter could be used if multiple annotators provide additional written responses, resulting in multiple good `chosen-rejected` pairs. ``` -### Data - -The data required for these steps need to be used as comparison data to showcase the preference for the generated prompts. Therefore, we need to have a classification dataset with a `better_response` and a `poorer_responses`. These are then used to train a preference classifier. There are several public datasets [available](https://huggingface.co/datasets?search=rlhf) but a good baseline can be found in the one that is the one offered by [Anthropic](https://huggingface.co/datasets/Anthropic/hh-rlhf). We will however showcase how to use our [curated Dolly dataset](https://huggingface.co/datasets/argilla/databricks-dolly-15k-curated-en), where we assumed that updated responses get preference over the older ones. +**Data Preparation** +What the parameter to `formatting_func` looks like depends a lot on your FeedbackDataset fields and questions. +However, fields (i.e. the left side of the Argilla annotation view) are provided as their values, e.g. ```python -import argilla as rg -from datasets import Dataset - -feedback_dataset = rg.FeedbackDataset.from_huggingface("argilla/databricks-dolly-15k-curated-en", split="train") - -data = {"instruction": [], "context": [], "poorer_response": [], "better_response": []} -for entry in feedback_dataset: - if entry.responses: - res = entry.responses[0].values - original_input = entry.fields["original-response"] - if original_input != res["new-response"].value: - data["instruction"].append(res["new-instruction"].value) - data["context"].append(res["new-context"].value) - data["poorer_response"].append(original_input) - data["better_response"].append(res["new-response"].value) - -dataset = Dataset.from_dict(data) -dataset -# Dataset({ -# features: ['instruction', 'context', 'poorer_response', 'better_response'], -# num_rows: 475 -# }) +>>> sample +{ + ... + 'original-response': 'Virgin Australia commenced services on 31 August 2000 ' + 'as Virgin Blue, with two aircraft on a single route.', + ... +} +``` +And all questions (i.e. the right side of the Argilla annotation view) are provided like so: +```python +>>> sample +{ + ... + 'new-response': [{'status': 'submitted', + 'value': 'Virgin Australia commenced services on 31 August ' + '2000 as Virgin Blue, with two aircraft on a ' + 'single route.', + 'user-id': ...}], + 'new-response-suggestion': None, + 'new-response-suggestion-metadata': {'agent': None, + 'score': None, + 'type': None}, + ... +} ``` -### Training +We can now define our formatting function, which should return `chosen-rejected`-pairs as tuple. -Fine-tuning using a Reward Model can be done in different ways. We can either get the annotator to rate output completely manually, we can use a simple heuristic or we can use a stochastic preference model. Both TRL and TRLX provide decent options for incorporating rewards. The [DeepSpeed library of Microsoft](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-chat) is a worthy mention too but will not be covered in our docs. +```python +from typing import Any, Dict, Iterator, Tuple +from argilla.feedback import TrainingTask + +template = """\ +### Instruction: {instruction}\n +### Context: {context}\n +### Response: {response}""" + +def formatting_func(sample: Dict[str, Any]) -> Iterator[Tuple[str, str]]: + # Our annotators were asked to provide new responses, which we assume are better than the originals + og_instruction = sample["original-instruction"] + og_context = sample["original-context"] + og_response = sample["original-response"] + rejected = template.format(instruction=og_instruction, context=og_context, response=og_response) + + for instruction, context, response in zip(sample["new-instruction"], sample["new-context"], sample["new-response"]): + if response["status"] == "submitted": + chosen = template.format( + instruction=instruction["value"], + context=context["value"], + response=response["value"], + ) + if chosen != rejected: + yield chosen, rejected + +task = TrainingTask.for_reward_modeling(formatting_func=formatting_func) +``` -#### TRL +You can observe the dataset created using this task by using `FeedbackDataset.prepare_for_training`, for example using the "trl" framework: -[TRL](https://huggingface.co/docs/trl) has a direct reward modeling integration via the `RewardTrainer` class. This trains a classifier to mimic the human evaluation of generated texts. Afterward, we can use the `PPOTrainer` class for the reinforcement learning step in combination with the trained `RewardTrainer`. +```python +dataset = feedback_dataset.prepare_for_training(framework="trl", task=task) +""" +>>> dataset +Dataset({ + features: ['chosen', 'rejected'], + num_rows: 2872 +}) +>>> dataset[2772] +{ + 'chosen': '### Instruction: Answer based on the text: Is Leucascidae a sponge\n\n' + '### Context: Leucascidae is a family of calcareous sponges in the order Clathrinida.\n\n' + '### Response: Yes', + 'rejected': '### Instruction: Is Leucascidae a sponge\n\n' + '### Context: Leucascidae is a family of calcareous sponges in the order Clathrinida.[1]\n\n' + '### Response: Leucascidae is a family of calcareous sponges in the order Clathrinida.'} +""" +``` +Looks great! -::::{tab-set} +**ArgillaTrainer** -:::{tab-item} RewardTrainer -[TRL](https://huggingface.co/docs/trl) has a direct reward modeling integration via the `RewardTrainer` class. This class functions similarly to the SFTTrainer and TransformersTrainer but requires `rejected-accepted` input pairs as training data. These are then used to fine-tune an `AutoModelForSequenceClassification` which we can use as a reward model during the reinforcement learning phase. The entries within the dataset should be `input_ids_chosen`, `attention_mask_chosen`, `input_ids_rejected` and `attention_mask_rejected` so we should first format them. The [roberta-base-reward-model-falcon-dolly reward model](https://huggingface.co/argilla/roberta-base-reward-model-falcon-dolly) was trained using the code below. +Now let's use the `ArgillaTrainer` to train a reward model with this task. ```python -from transformers import ( - AutoModelForSequenceClassification, - AutoTokenizer, - TrainingArguments, -) -​ -from trl import RewardTrainer -​ -from datasets import load_dataset -​ -dataset = load_dataset("argilla/dolly-curated-comparison-falcon-7b-instruct", split="train") -​ -model_name = "distilroberta-base" -​ -model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=1) -tokenizer = AutoTokenizer.from_pretrained(model_name) -​ -if tokenizer.pad_token is None: - tokenizer.pad_token = tokenizer.eos_token - model.config.pad_token_id = model.config.eos_token_id -​ -def formatting_func(examples): - kwargs = {"padding": "max_length", "truncation": True, "max_length": 512, "return_tensors": "pt"} -​ - # Assuming original human response is preferred to Falcon's - chosen_response = examples["original_response"] - rejected_response = examples["response-1"] - prompt = examples["prompt"] -​ - tokens_chosen = tokenizer.encode_plus(prompt, chosen_response, **kwargs) - tokens_rejected = tokenizer.encode_plus(prompt, rejected_response, **kwargs) -​ - return { - "input_ids_chosen": tokens_chosen["input_ids"][0], "attention_mask_chosen": tokens_chosen["attention_mask"][0], - "input_ids_rejected": tokens_rejected["input_ids"][0], "attention_mask_rejected": tokens_rejected["attention_mask"][0] - } +from argilla.feedback import ArgillaTrainer -formatted_dataset = dataset.map(formatting_func) -​ -trainer = RewardTrainer( - model=model, - args=TrainingArguments("output_dir"), - tokenizer=tokenizer, - train_dataset=formatted_dataset +trainer = ArgillaTrainer( + dataset=feedback_dataset, + task=task, + framework="trl", + model="distilroberta-base", ) -​ -trainer.train() +trainer.train(output_dir="reward_model") ``` -::: -:::{tab-item} PPOTrainer -The [TRL](https://huggingface.co/docs/trl) `PPOTrainer` allows updating while plugging in any arbitrary model or heuristic to assign `rewards` to the generated output. In the example below, we use the `reward_model` and `reward_tokenizer` to create a transformers text-classification pipeline. This pipeline is then used to create `rewards` which are then passed during the PPO `.step()` to include in the weigh optimization for the next batch. You can choose to use our [roberta-base-reward-model-falcon-dolly reward model](https://huggingface.co/argilla/roberta-base-reward-model-falcon-dolly). +**Inference** + +Let's try out the trained model in practice. ```python +from transformers import AutoTokenizer, AutoModelForSequenceClassification import torch -from transformers import AutoTokenizer, pipeline -from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer -from trl.core import LengthSampler -reward_model = ... # "argilla/roberta-base-reward-model-falcon-dolly" -reward_tokenizer = ... # "argilla/roberta-base-reward-model-falcon-dolly" +model = AutoModelForSequenceClassification.from_pretrained("reward_model") +tokenizer = AutoTokenizer.from_pretrained("reward_model") -config = PPOConfig(model_name="gpt2", batch_size=2) +def get_score(model, tokenizer, text): + # Tokenize the input sequences + inputs = tokenizer(text, truncation=True, padding="max_length", max_length=512, return_tensors="pt") -model = AutoModelForCausalLMWithValueHead.from_pretrained(config.model_name) -ref_model = AutoModelForCausalLMWithValueHead.from_pretrained(config.model_name) -tokenizer = AutoTokenizer.from_pretrained(config.model_name) -tokenizer.pad_token = tokenizer.eos_token -reward_pipe = pipeline(model=reward_model, tokenizer=reward_tokenizer) + # Perform forward pass + with torch.no_grad(): + outputs = model(**inputs) -def formatting_func(examples): - kwargs = { - "padding": "max_length", "truncation": True, - "max_length": 512, "return_tensors": "pt" - } - input_size = LengthSampler(min_value=2, max_value=8) - input_text = examples["instruction"] + examples["context"] + examples["response"] - examples["input_ids"] = tokenizer.encode(input_text, **kwargs)[0][: input_size()] - examples["query"] = tokenizer.decode(examples["input_ids"][0]) - return examples - -formatted_dataset = dataset.map(formatting_func, batched=False) -formatted_dataset.set_format(type="torch") - -def collator(data): - return dict((key, [d[key] for d in data]) for key in data[0]) - -ppo_trainer = PPOTrainer(config, model, ref_model, tokenizer, dataset=formatted_dataset, data_collator=collator) - -output_min_length = 4 -output_max_length = 16 -output_length_sampler = LengthSampler(output_min_length, output_max_length) - -generation_kwargs = { - "min_length": -1, - "top_k": 0.0, - "top_p": 1.0, - "do_sample": True, - "pad_token_id": tokenizer.eos_token_id, -} + # Extract the logits + return outputs.logits[0, 0].item() -for epoch, batch in enumerate(ppo_trainer.dataloader): - query_tensors = batch["input_ids"] +# Example usage +prompt = "Is a toad a frog?" +context = "Both frogs and toads are amphibians in the order Anura, which means \"without a tail.\" Toads are a sub-classification of frogs, meaning that all toads are frogs, but not all frogs are toads." +good_response = "Yes" +bad_response = "Both frogs and toads are amphibians in the order Anura, which means \"without a tail.\"" +example_good = template.format(instruction=prompt, context=context, response=good_response) +example_bad = template.format(instruction=prompt, context=context, response=bad_response) - #### Get response from gpt2 - response_tensors = [] - for query in query_tensors: - gen_len = output_length_sampler() - generation_kwargs["max_new_tokens"] = gen_len - response = ppo_trainer.generate(query, **generation_kwargs) - response_tensors.append(response.squeeze()[-gen_len:]) - batch["response"] = [tokenizer.decode(r.squeeze()) for r in response_tensors] +score = get_score(model, tokenizer, example_good) +print(score) +# >> 5.478324890136719 - #### Compute sentiment score - texts = [q + r for q, r in zip(batch["query"], batch["response"])] - pipe_outputs = reward_pipe(texts, return_all_scores=True) - rewards = [torch.tensor(output[1]["score"]) for output in pipe_outputs] - - #### Run PPO step - stats = ppo_trainer.step(query_tensors, response_tensors, rewards) - ppo_trainer.log_stats(stats, batch, rewards) +score = get_score(model, tokenizer, example_bad) +print(score) +# >> 2.2948970794677734 ``` - +As expected, the good response has a higher score than the worse response. ::: :::: -#### TRLX +### Proximal Policy Optimization -[TRLX](https://github.com/CarperAI/trlx) gives the option to use a `reward function` or a `reward-labeled` dataset in combination with Proximal Policy Optimization (PPO) for the reinforcement learning step, which can be used by defining a PPO policy configuration. During this step, we infer rewards to mimic the human evaluation of generated texts. Additionally, [Hugging Face Accelerate](https://huggingface.co/docs/accelerate/index) can be used to speed up training or [Ray Tune](https://docs.ray.io/en/latest/tune/index.html) to optimize hyperparameter tuning. +#### Background -```python -from trlx.data.default_configs import default_ppo_config +The [TRL](https://huggingface.co/docs/trl) library implements the last step of RLHF: Proximal Policy Optimization (PPO). It requires prompts, which are then fed through the model being finetuned. Its results are passed through a reward model. Lastly, the prompts, responses and rewards are used to update the model through reinforcement learning. -config = default_ppo_config() -config.model.model_path = 'gpt2' -config.train.batch_size = 16 +```{note} +PPO requires a trained supervised fine-tuned model and reward model to work. Take a look at that task outlines above to train your own models. ``` -::::{tab-set} - -:::{tab-item} reward function +```{include} /_common/dolly_dataset.md +``` -The [TRLX](https://github.com/CarperAI/trlx) `reward_fn` is quite flexible in its set up, however, most commonly you would expect to use a stochastic classification model obtained in a similar manner as the `RewardTrainer` defined above. For demo purposes, we provide an out-of-the-box [roberta-base-reward-model-falcon-dolly reward model](https://huggingface.co/argilla/roberta-base-reward-model-falcon-dolly). +In case of training an PPO, we then use the prompt and context data and correct the generated response from the SFT model by using the reward model. Hence, we will need to format the following `text`. -```python -from transformers import pipeline -import trlx +```bash +### Instruction +When did Virgin Australia start operating? -dataset = ... -config = ... +### Context +Virgin Australia, the trading name of Virgin Australia Airlines Pty Ltd, is an Australian-based airline. +It is the largest airline by fleet size to use the Virgin brand. +It commenced services on 31 August 2000 as Virgin Blue, with two aircraft on a single route. +It suddenly found itself as a major airline in Australia's domestic market after the collapse of Ansett Australia in September 2001. +The airline has since grown to directly serve 32 cities in Australia, from hubs in Brisbane, Melbourne and Sydney. -classifier = pipeline("argilla/roberta-base-reward-model-falcon-dolly") +### Response: +{to be generated by SFT model} +``` -def my_reward_function(entry): - return classifier(entry)[0].get("score") +#### Training -trainer = trlx.train( - config=config, - reward_fn=lambda samples, **kwargs: [my_reward_function(sample) for sample in samples] -) +```{include} /_common/dolly_dataset_load.md ``` -::: +**Data Preparation** -:::{tab-item} reward-labeled dataset +As usual, we start with a task with a formatting function. For PPO, the formatting function only returns prompts as `text`, which are formatted according to a template. -In this case, TRLX relies on reward-labeled data to infer the alignment with human preference. This is a good approach but it is not recommended to only collect these labels via human feedback because this is likely too costly to scale. Therefore, we recommend using an automated reward function or creating a reward-labeled dataset using our [roberta-base-reward-model-falcon-dolly model](https://huggingface.co/argilla/roberta-base-reward-model-falcon-dolly). For demo purposes, we now infer the rewards from the corrected response, but we can also set up [specific ranking](../conceptual_guides/rm.html) using the Argilla UI. +```python +from argilla.feedback import TrainingTask +from typing import Dict, Any, Iterator + +template = """\ +### Instruction: {instruction}\n +### Context: {context}\n +### Response: {response}""" + +def formatting_func(sample: Dict[str, Any]) -> Iterator[str]: + for instruction, context in zip(sample["new-instruction"], sample["new-context"]): + if instruction["status"] == "submitted": + yield template.format( + instruction=instruction["value"], + context=context["value"][:500], + response="" + ).strip() + +task = TrainingTask.for_proximal_policy_optimization(formatting_func=formatting_func) +``` + +Like before, we can observe the resulting dataset: ```python -import trlx +dataset = feedback_dataset.prepare_for_training(framework="trl", task=task) +""" +>>> dataset +Dataset({ + features: ['id', 'query'], + num_rows: 15015 +}) +>>> dataset[922] +{'id': 922, 'query': '### Instruction: Is beauty objective or subjective?\n\n### Context: \n\n### Response:'} +""" +``` -dataset = ... -config = ... +**ArgillaTrainer** -samples, rewards = [], [] -for entry in dataset: - samples.append(entry["poorer_response"]) - rewards.append(1) - samples.append(entry["better_response"]) - rewards.append(2) +Instead of using this dataset, we'll use the task directly with our `FeedbackDataset` in the `ArgillaTrainer`. PPO requires us to specify the `reward_model`, and allows us to specify some other useful values as well: +* `reward_model`: A sentiment analysis pipeline with the reward model. This produces a reward for a prompt + response. +* `length_sampler_kwargs`: A dictionary with `min_value` and `max_value` keys, indicating the lower and upper bound on the number of tokens the finetuning model should generate while finetuning. +* `generation_kwargs`: The keyword arguments passed to the `generate` method of the finetuning model. +* `config`: A `trl.PPOConfig` instance with many useful parameters such as `learning_rate` and `batch_size`. -trainer = trlx.train(config=config, samples=samples, rewards=rewards) +```python +from argilla.feedback import ArgillaTrainer +from transformers import pipeline +from trl import PPOConfig + +trainer = ArgillaTrainer( + dataset=feedback_dataset, + task=task, + framework="trl", + model="gpt2", +) +reward_model = pipeline("sentiment-analysis", model="reward_model") +trainer.update_config( + reward_model=reward_model, + length_sampler_kwargs={"min_value": 32, "max_value": 256}, + generation_kwargs={ + "min_length": -1, + "top_k": 0.0, + "top_p": 1.0, + "do_sample": True, + }, + config=PPOConfig(batch_size=16) +) +trainer.train(output_dir="ppo_model") ``` -::: +**Inference** -:::: +After training, we can load this model and generate with it! -## Pre-training +```python +from transformers import AutoModelForCausalLM, AutoTokenizer -When talking about pre-training, we generally talk about a simple `prompt-completion` task, where we need the model to pick up on basic statistics of the language it is learning. Given that you are familiar with Spanish cuisine and the prompt sentence, `The base ingredient of paella is ___`, you know that the word in the `___` is much more likely to be `rice` than `apples`. So, you are basically training a causal language model or text generation model. +model = AutoModelForCausalLM.from_pretrained("ppo_model") +tokenizer = AutoTokenizer.from_pretrained("ppo_model") +tokenizer.pad_token = tokenizer.eos_token -```{note} -This is an unsupervised approach hence we only infer training data from a basic sentence like `The base ingredient of paella is rice.` by starting with the word `The`, and from there unwrapping the sentence step by step. +inputs = template.format( + instruction="Is a toad a frog?", + context="Both frogs and toads are amphibians in the order Anura, which means \"without a tail.\" Toads are a sub-classification of frogs, meaning that all toads are frogs, but not all frogs are toads.", + response="" +).strip() +encoding = tokenizer([inputs], return_tensors="pt") +outputs = model.generate(**encoding, max_new_tokens=30) +output_text = tokenizer.decode(outputs[0]) +print(output_text) +# Yes it is, toads are a sub-classification of frogs. ``` -### Data +### Direct Preference Optimization -Many training datasets for this task can be found online (e.g., [Hugging Face](https://huggingface.co/datasets?task_categories=task_categories:text-generation&sort=downloads)). You can either upload this in the right Argilla format but it might be needed to collect and fine-tune additional data with Argilla. So we, therefore, provide a basic setup underneath which should help you to start gathering or preparing pre-training data. +#### Background + +The [TRL](https://huggingface.co/docs/trl) library implements and alternative way to incorporate human feedback into an LLM which is called Direct Preference Optimization (DPO). This approach skips the step of training a separate reward model and directly uses the preference data during training as measure for optimization of human feedback. In order to properly use th ```{note} -When it comes to pre-training an LLM, we generally do not need data of highest quality, but it is always smart to use domain-specfic data and to avoid data that might lead to undecired effect like hallucination and bias. +DPO requires a trained supervised fine-tuned model to function. Take a look at that task outline above to train your own model. ``` -First, create a `FeedbackDataset` with records. +```{include} /_common/dolly_dataset_info.md +``` -```python -import argilla as rg +In case of training using PPO, we then use the prompt and context data and correct the generated response from the SFT model by using the reward model. Hence, we will need to format the following `text`. -# create promp-completion dataset -dataset = rg.FeedbackDataset( - guidelines="Please, complete the following prompt fields with a brief text answer.", - fields=[ - rg.TextField(name="prompt"), - ], - questions=[ - rg.TextQuestion(name="completion", title="Add a brief text answer."), - ] -) +```bash +### Instruction +When did Virgin Australia start operating? -# create a Feedback Records -record = rg.FeedbackRecord( - fields={ - "prompt": "The base ingredient of paella is rice." - } -) +### Context +Virgin Australia, the trading name of Virgin Australia Airlines Pty Ltd, is an Australian-based airline. +It is the largest airline by fleet size to use the Virgin brand. +It commenced services on 31 August 2000 as Virgin Blue, with two aircraft on a single route. +It suddenly found itself as a major airline in Australia's domestic market after the collapse of Ansett Australia in September 2001. +The airline has since grown to directly serve 32 cities in Australia, from hubs in Brisbane, Melbourne and Sydney. -dataset.add_records([record]) +### Response: +{to be generated by SFT model} ``` -Then push it to Argilla via `push_to_argilla`. +Within the DPO approach we infer the reward from the formatted prompt and the provided preference data as `prompt-chosen-rejected`-pairs. -::::{tab-set} +#### Training -:::{tab-item} Argilla 1.14.0 or higher -```python -remote_dataset = dataset.push_to_argilla(name="pre-training") +```{include} /_common/dolly_dataset_load.md ``` -::: -:::{tab-item} Lower than Argilla 1.14.0 +**Data Preperation** + +We will start with our a basic example of a formatting function. For DPO it should return `prompt-chosen-rejected`-pairs, where the prompt is formatted according to a template. + ```python -dataset.push_to_argilla(name="pre-training") +from argilla.feedback import TrainingTask +from typing import Dict, Any, Iterator + +template = """\ +### Instruction: {instruction}\n +### Context: {context}\n +### Response: {response}""" + +def formatting_func(sample: Dict[str, Any]) -> Iterator[Tuple[str, str]]: + # Our annotators were asked to provide new responses, which we assume are better than the originals + og_instruction = sample["original-instruction"] + og_context = sample["original-context"] + rejected = sample["original-response"] + prompt = template.format(instruction=og_instruction, context=og_context, response="") + + for instruction, context, response in zip(sample["new-instruction"], sample["new-context"], sample["new-response"]): + if response["status"] == "submitted": + chosen = response["value"] + if chosen != rejected: + yield prompt, chosen, rejected + + +task = TrainingTask.for_direct_preference_optimization(formatting_func=formatting_func) ``` -::: -:::: -And, finally, load the `FeedbackDataset` from Argilla. +**ArgillaTrainer** + +We'll use the task directly with our `FeedbackDataset` in the `ArgillaTrainer`. In contrary to PPO, we do not need to specify any reward model, because this preference modeling is inferred internally by the DPO-algorithm. ```python -import argilla as rg -from datasets import Dataset +from argilla.feedback import ArgillaTrainer -dataset = rg.FeedbackDataset.from_argilla("pre-training") -prompts = {"prompt": [record.fields.get("prompt") for record in dataset.records]} -dataset = Dataset.from_dict(prompts) -dataset -# Dataset({ -# features: ['prompt'], -# num_rows: 1 -# }) +trainer = ArgillaTrainer( + dataset=feedback_dataset, + task=task, + framework="trl", + model="gpt2", +) +trainer.train(output_dir="dpo_model") ``` -### Training +**Inference** + +After training, we can load this model and generate with it! -There are many ways and great packages to deal with this `pre-training` phase, but generally, NLP training frameworks like [KerasNLP](https://keras.io/keras_nlp/) and [Hugging Face](https://huggingface.co/) offer great out-of-the-box methods for training a causal language model. In our guide, we will refer to the great docs off using Hugging Face `transformers` and `datasets` library and prepare our training data in the format they require for [training a causal language model](https://huggingface.co/learn/nlp-course/chapter7/6#training-a-causal-language-model-from-scratch). +```python +from transformers import AutoModelForCausalLM, AutoTokenizer + +model = AutoModelForCausalLM.from_pretrained("dpo_model") +tokenizer = AutoTokenizer.from_pretrained("dpo_model") +tokenizer.pad_token = tokenizer.eos_token + +inputs = template.format( + instruction="Is a toad a frog?", + context="Both frogs and toads are amphibians in the order Anura, which means \"without a tail.\" Toads are a sub-classification of frogs, meaning that all toads are frogs, but not all frogs are toads.", + response="" +).strip() +encoding = tokenizer([inputs], return_tensors="pt") +outputs = model.generate(**encoding, max_new_tokens=30) +output_text = tokenizer.decode(outputs[0]) +print(output_text) +# Yes it is, toads are a sub-classification of frogs. +``` diff --git a/docs/_source/guides/llms/practical_guides/fine_tune_others.md b/docs/_source/guides/llms/practical_guides/fine_tune_others.md deleted file mode 100644 index 03e211158b..0000000000 --- a/docs/_source/guides/llms/practical_guides/fine_tune_others.md +++ /dev/null @@ -1,204 +0,0 @@ -# Fine-tune other models - -After [collecting the responses](./collect_responses.html) from our `FeedbackDataset` we can start fine-tuning our basic models. Due to the customizability of the `FeedbackDataset`, this might require setting up a custom post-processing workflow but we will provide some good toy examples for the text classification task. We will add additional support for other tasks in the future. - -Generally, this is as easy as one-two-three but does slightly differ per task. - -1. First, we define a unification strategy for responses to `questions` we want to use. -2. Next, we then define a task-mapping. This mapping defines which `fields` and `questions` we want to use from our dataset for the downstream training task. These mappings are then used for retrieving data from a dataset and initializing the training. -3. Lastly, we initialize the `ArgillaTrainer` and forward the task mapping, unification strategies and training framework. - -## Text classification - -### Background - -Text classification is a widely used NLP task where labels are assigned to text. Major companies rely on it for various applications. Sentiment analysis, a popular form of text classification, assigns labels like 🙂 positive, 🙁 negative, or 😐 neutral to text. Additionally, we distinguish between single- and multi-label text classification. - -#### Single-label - -Single-label text classification refers to the task of assigning a single category or label to a given text sample. Each text is associated with only one predefined class or category. For example, in sentiment analysis, a single-label text classification task would involve assigning labels such as "positive," "negative," or "neutral" to individual texts based on their sentiment. - -#### Multi-label - -Multi-label text classification is generally more complex than single-label classification due to the challenge of determining and predicting multiple relevant labels for each text. It finds applications in various domains, including document tagging, topic labeling, and content recommendation systems. - -### Training - -Data for the training text classification using our `FeedbackDataset` is defined by following three easy steps. - -1. We need to define a unification strategy `RatingStrategy`, a `LabelStrategy` or a `MultiLabelStrategy`. - -2. For this task, we assume we need a `text-label`-pair for defining a text classification task. We allow mapping for creating a `TrainingTaskMapping.for_text_classification` by mapping `*Field` to a `text`-value and allow for mapping a `RatingStrategy`, `LabelStrategy` or a `MultiLabelStrategy` to a `label`-value. - -3. We then define an `ArgillaTrainer` instance with support for "openai", "setfit", "peft", "spacy" and "transformers". - -#### Unify responses - -Argilla `*Question`s need to be [unified using a strategy](/guides/llms/practical_guides/collect_responses) and so do `RatingQuestions`s, `LabelQuestion`s and `MultiLabelQuestion`s. Therefore, records need to be unified by using a strategy, which takes one of the questions and one of their associated strategies. Luckily this is integrated within the `TrainingTaskMapping`-step underneath, but you can also do this individually as shown [here](/guides/llms/practical_guides/collect_responses). - -````{note} -A brief shortcut that `RatingQuestion`s can be unified using a "majority"-, "min"-, "max"- or "disagreement"-strategy. Both `LabelQuestion`s and `MultiLabelQuestion`s can be resolved using a "majority"-, or "disagreement"-strategy. -```` - -#### Define a task mapping - -Now we know which unification strategy to apply, we can now define our `TrainingTaskMapping.for_text_classification`. - -::::{tab-set} - -:::{tab-item} LabelQuestion -```python -from argilla.feedback import FeedbackDataset, TrainingTaskMapping - -dataset = FeedbackDataset.from_huggingface( - repo_id="argilla/stackoverflow_feedback_demo" -) -task_mapping = TrainingTaskMapping.for_text_classification( - text=dataset.field_by_name("title"), - label=dataset.question_by_name("title_question_fit"), # LabelQuestion - label_strategy=None # default to "majority", or use "disagreement" -) -``` -::: - -:::{tab-item} MultiLabelQuestion -```python -from argilla.feedback import FeedbackDataset, TrainingTaskMapping - -dataset = FeedbackDataset.from_huggingface( - repo_id="argilla/stackoverflow_feedback_demo" -) -task_mapping = TrainingTaskMapping.for_text_classification( - text=dataset.field_by_name("title"), - label=dataset.question_by_name("tags"), # MultiLabelQuestion - label_strategy=None # default to "majority", or use "disagreement" -) -``` -::: - -:::{tab-item} RatingQuestion -```python -from argilla.feedback import FeedbackDataset, TrainingTaskMapping - -dataset = FeedbackDataset.from_huggingface( - repo_id="argilla/stackoverflow_feedback_demo" -) -task_mapping = TrainingTaskMapping.for_text_classification( - text=dataset.field_by_name("title"), - label=dataset.question_by_name("answer_quality"), # RatingQuestion - label_strategy=None # default to "majority", or use "min", "max", "disagreement" -) -``` -::: - -:::: - -#### Use ArgillaTrainer - -Next, we can use our `FeedbackDataset` and `TrainingTaskMappingForTextClassification` to initialize our `argilla.ArgillaTrainer`. We support the frameworks "openai", "setfit", "peft", "spacy" and "transformers". - -````{note} -This is a newer version and can be imported via `from argilla.feedback import ArgillaTrainer`. The old trainer can be imported via `from argilla.training import ArgillaTrainer`. Our docs, contain some [additional information on usage of the ArgillaTrainer](../../train_a_model.html). -```` - -```python -from argilla.feedback import ArgillaTrainer, FeedbackDataset, TrainingTaskMapping - -dataset = FeedbackDataset.from_huggingface( - repo_id="argilla/stackoverflow_feedback_demo" -) -task_mapping = TrainingTaskMapping.for_text_classification( - text=dataset.field_by_name("title"), - label=dataset.question_by_name("tags") -) -trainer = ArgillaTrainer( - dataset=dataset, - task_mapping=task_mapping, - framework="setfit", - fetch_records=False -) -trainer.update_config(num_train_epochs=2) -trainer.train(output_dir="my_awesone_model") -``` - -````{note} -The `FeedbackDataset` also allows for custom workflows via the `prepare_for_training()` method. -```python -task_mapping = ... -dataset = rg.FeedbackDataset.from_huggingface( - repo_id="argilla/stackoverflow_feedback_demo" -) -dataset.prepare_for_training( - framework="setfit", - task_mapping=task_mapping -) -``` -```` - -### An end-to-end example - -Below you can also find an end-to-end example of how to use the `ArgillaTrainer`. - -```python -from argilla.feedback import ( - ArgillaTrainer, - FeedbackDataset, - FeedbackRecord, - LabelQuestion, - TextField, - TrainingTaskMapping, -) - -dataset = FeedbackDataset( - guidelines="Add some guidelines for the annotation team here.", - fields=[ - TextField(name="text", title="Human prompt"), - ], - questions =[ - LabelQuestion( - name="relevant", - title="Is the response relevant for the given prompt?", - labels=["yes","no"], - required=True, - ) - ] -) -dataset.add_records( - records=[ - FeedbackRecord( - fields={"text": "What is your favorite color?"}, - responses=[{"values": {"relevant": {"value": "no"}}}] - ), - FeedbackRecord( - fields={"text": "What do you think about the new iPhone?"}, - responses=[{"values": {"relevant": {"value": "yes"}}}] - ), - FeedbackRecord( - fields={"text": "What is your feeling about the technology?"}, - responses=[{"values": {"relevant": {"value": "yes"}}}, - {"values": {"relevant": {"value": "no"}}}, - {"values": {"relevant": {"value": "yes"}}}] - ), - FeedbackRecord( - fields={"text": "When do you expect to buy a new phone?"}, - responses=[{"values": {"relevant": {"value": "no"}}}, - {"values": {"relevant": {"value": "yes"}}}] - ) - - ] -) - -task_mapping = TrainingTaskMapping.for_text_classification( - text=dataset.field_by_name("text"), - label=dataset.question_by_name("relevant") -) - -trainer = ArgillaTrainer( - dataset=dataset, - task_mapping=task_mapping, - framework="setfit", - fetch_records=False -) -trainer.update_config(num_train_epochs=2) -trainer.train(output_dir="my_awesone_model") -``` diff --git a/docs/_source/guides/llms/practical_guides/practical_guides.md b/docs/_source/guides/llms/practical_guides/practical_guides.md index ee9331587c..98956cbe1c 100644 --- a/docs/_source/guides/llms/practical_guides/practical_guides.md +++ b/docs/_source/guides/llms/practical_guides/practical_guides.md @@ -47,14 +47,9 @@ Use the Argilla LangChain callback for monitoring, evaluation, and fine-tuning. ```{grid-item-card} Fine-tune LLMs :link: fine_tune.html -Fine-tune an LLM with the feedback collected from Argilla. - +Fine-tune an LLM or other models with the feedback collected from Argilla. ``` -```{grid-item-card} Fine-tune other models -:link: fine_tune_others.html -Fine-tune basic models with feedback collected from Argilla. -``` ```` ![Feedback dataset snapshot](../../../_static/images/llms/snapshot-feedback-demo.png) @@ -70,5 +65,4 @@ collect_responses export_dataset use_argilla_callback_in_langchain fine_tune -fine_tune_others ``` \ No newline at end of file diff --git a/docs/_source/guides/llms/practical_guides/update_dataset.md b/docs/_source/guides/llms/practical_guides/update_dataset.md index 7bb377f276..13971b0884 100644 --- a/docs/_source/guides/llms/practical_guides/update_dataset.md +++ b/docs/_source/guides/llms/practical_guides/update_dataset.md @@ -1,4 +1,5 @@ # Update a Feedback dataset + Oftentimes datasets that we have created previously need modifications or updates. In this section, we will explore some of the most common workflows to change an existing Feedback dataset in Argilla. Remember that you will need to connect to Argilla to perform any of the actions below. diff --git a/docs/_source/guides/train_a_model.md b/docs/_source/guides/train_a_model.md index ed177605df..4f9340aaf4 100644 --- a/docs/_source/guides/train_a_model.md +++ b/docs/_source/guides/train_a_model.md @@ -1,6 +1,6 @@ # 🦾 Train a Model -This guide showcases how to train a model on the `Dataset` classes in the Argilla client. +This guide showcases how to train a model on the `TextClassification`, `TokenClassification` and `Text2TextClassification` classes in the Argilla client. The Dataset classes are lightweight containers for Argilla records. These classes facilitate importing from and exporting to different formats (e.g., `pandas.DataFrame`, `datasets.Dataset`) as well as sharing and versioning Argilla datasets using the Hugging Face Hub. For each record type, there's a corresponding Dataset class called `DatasetFor`. @@ -11,16 +11,15 @@ There are two ways to train custom models on top of your annotated data: 1. Train models using the Argilla training module, which is quick and easy but does not offer specific customization. 2. Train with a custom workflow using the prepare for training methods, which requires some configuration but also offers more flexibility to integrate with your existing training workflows. - -````{note} -For training models with the `FeedbackDataset` take a look [here](/guides/llms/practical_guides/practical_guides). -```` - ## Train directly This is, quick and easy but does not offer specific customizations. -The `ArgillaTrainer` is a wrapper around many of our favorite NLP libraries. It provides a very intuitive abstract workflow to facilitate simple training workflows using decent default pre-set configurations without having to worry about any data transformations from Argilla. We plan on adding more support for other tasks and frameworks so feel free to reach out on our Slack or GitHub. +The `ArgillaTrainer` is a wrapper around many of our favorite NLP libraries. It provides a very intuitive abstract workflow to facilitate simple training workflows using decent default pre-set configurations without having to worry about any data transformations from Argilla. We plan on adding more support for other tasks and frameworks so feel free to reach out on our Slack or GitHub + +````{note} +For training models with the `FeedbackDataset` take a look [here](/guides/llms/practical_guides/fine_tune). +```` | Framework/Task | TextClassification | TokenClassification | Text2Text | Feedback | |-------------------|--------------------|---------------------|-----------|-----------| @@ -32,6 +31,8 @@ The `ArgillaTrainer` is a wrapper around many of our favorite NLP libraries. It | PEFT | ✔️ | ✔️ | | | | SpanMarker | | ✔️ | | | + + ### The `ArgillaTrainer` We can use the `ArgillaTrainer` to train directly using `spacy`, `setfit` and `transformers` as framework variables. diff --git a/docs/_source/reference/python/python_training.rst b/docs/_source/reference/python/python_training.rst index 37e9d3d27d..748bac8d41 100644 --- a/docs/_source/reference/python/python_training.rst +++ b/docs/_source/reference/python/python_training.rst @@ -70,3 +70,9 @@ SpanMarker Trainer .. automodule:: argilla.training.span_marker :members: + +TRL Trainer +------------------ + +.. automodule:: argilla.client.feedback.training.frameworks.trl + :members: diff --git a/environment_dev.yml b/environment_dev.yml index 21656a0aa7..cf1106a9d1 100644 --- a/environment_dev.yml +++ b/environment_dev.yml @@ -46,14 +46,14 @@ dependencies: - spacy==3.5.3 - https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.5.0/en_core_web_sm-3.5.0.tar.gz - spacy-transformers>=1.2.5 - - transformers[torch]>=4.19.0 + - transformers[torch]>=4.30.0 # <- required for DPO with TRL - evaluate - seqeval - setfit - span_marker - openai - peft - - autotrain-advanced==0.5.2 + - trl>=0.5.0 - rich!=13.1.0 # install Argilla in editable mode - -e .[server,listeners] diff --git a/pyproject.toml b/pyproject.toml index 2ff8079582..33a74fd431 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -101,13 +101,14 @@ integrations = [ "snorkel >= 0.9.7", "spacy == 3.5.3", "spacy-transformers >= 1.2.5", - "transformers[torch] >= 4.19.0", + "transformers[torch] >= 4.30.0", "evaluate", "seqeval", "setfit", "span_marker", "openai", - "peft" + "peft", + "trl>=0.5.0" ] tests = [ "pytest", diff --git a/src/argilla/client/datasets.py b/src/argilla/client/datasets.py index 4eaf38f65f..ed3cb54691 100644 --- a/src/argilla/client/datasets.py +++ b/src/argilla/client/datasets.py @@ -358,7 +358,9 @@ def prepare_for_training( "transformers" and "spacy" are currently supported. Default: `transformers` lang: The spacy nlp Language pipeline used to process the dataset. (Only for spacy framework) train_size: The size of the training set. If float, should be between 0.0 and 1.0 and represent the + proportion of the dataset to include in the train split. test_size: The size of the test set. If float, should be between 0.0 and 1.0 and represent the + proportion of the dataset to include in the test split. seed: Random state. Returns: diff --git a/src/argilla/client/feedback/__init__.py b/src/argilla/client/feedback/__init__.py index 1e8a81313c..5b7a2b65b8 100644 --- a/src/argilla/client/feedback/__init__.py +++ b/src/argilla/client/feedback/__init__.py @@ -14,13 +14,24 @@ from argilla.client.feedback.training import ( ArgillaTrainer, - TrainingTaskMapping, - TrainingTaskMappingForTextClassification, + TrainingTask, + TrainingTaskForDPO, + TrainingTaskForPPO, + TrainingTaskForRM, + TrainingTaskForSFT, + TrainingTaskForTextClassification, + TrainingTaskMapping, # <- Deprecated + TrainingTaskMappingForTextClassification, # <- Deprecated ) from argilla.client.feedback.unification import ( LabelQuestionStrategy, + LabelQuestionUnification, MultiLabelQuestionStrategy, + MultiLabelQuestionUnification, + RankingQuestionStrategy, + RankingQuestionUnification, RatingQuestionStrategy, + RatingQuestionUnification, UnifiedValueSchema, ) @@ -29,7 +40,18 @@ "LabelQuestionStrategy", "MultiLabelQuestionStrategy", "RatingQuestionStrategy", + "TrainingTask", + "TrainingTaskForTextClassification", + "TrainingTaskForSFT", + "TrainingTaskForRM", + "TrainingTaskForPPO", + "TrainingTaskForDPO", "TrainingTaskMapping", "TrainingTaskMappingForTextClassification", + "RankingQuestionStrategy", "UnifiedValueSchema", + "LabelQuestionUnification", + "MultiLabelQuestionUnification", + "RatingQuestionUnification", + "RankingQuestionUnification", ] diff --git a/src/argilla/client/feedback/dataset/base.py b/src/argilla/client/feedback/dataset/base.py index b4d5b052a7..2d1d529d50 100644 --- a/src/argilla/client/feedback/dataset/base.py +++ b/src/argilla/client/feedback/dataset/base.py @@ -29,7 +29,14 @@ RatingQuestion, ) from argilla.client.feedback.schemas.types import AllowedFieldTypes, AllowedQuestionTypes -from argilla.client.feedback.training.schemas import TrainingTaskMappingForTextClassification +from argilla.client.feedback.training.schemas import ( + TrainingTaskForDPO, + TrainingTaskForPPO, + TrainingTaskForRM, + TrainingTaskForSFT, + TrainingTaskForTextClassification, + TrainingTaskTypes, +) from argilla.client.feedback.unification import ( LabelQuestionStrategy, MultiLabelQuestionStrategy, @@ -306,14 +313,27 @@ def unify_responses( def prepare_for_training( self, framework: Union[Framework, str], - task_mapping: TrainingTaskMappingForTextClassification, + task: TrainingTaskTypes, train_size: Optional[float] = 1, test_size: Optional[float] = None, seed: Optional[int] = None, lang: Optional[str] = None, fetch_records: Optional[bool] = None, ): - # TODO(davidberenstein1957): add missing docstrings and type annotations + """ + Prepares the dataset for training for a specific training framework and NLP task by splitting the dataset into train and test sets. + + Args: + framework: the framework to use for training. Currently supported frameworks are: `transformers`, `peft`, + `setfit`, `spacy`, `spacy-transformers`, `span_marker`, `spark-nlp`, `openai`, `trl`. + task: the NLP task to use for training. Currently supported tasks are: `TrainingTaskForTextClassification`, + `TrainingTaskForSFT`, `TrainingTaskForRM`, `TrainingTaskForPPO`, `TrainingTaskForDPO`. + train_size: the size of the train set. If `None`, the whole dataset will be used for training. + test_size: the size of the test set. If `None`, the whole dataset will be used for testing. + seed: the seed to use for splitting the dataset into train and test sets. + lang: the spaCy language to use for training. If `None`, the language of the dataset will be used. + fetch_records: whether to fetch the records from Argilla or use the local records instead. If `None`, use local. + """ if fetch_records is not None: warnings.warn( "`fetch_records` is deprecated and will be removed in a future version." @@ -349,22 +369,31 @@ def prepare_for_training( " dataset via the `FeedbackDataset.add_records` method first." ) - if isinstance(task_mapping, TrainingTaskMappingForTextClassification): - self.unify_responses(question=task_mapping.label.question, strategy=task_mapping.label.strategy) - else: - raise ValueError(f"Training data {type(task_mapping)} is not supported yet") - - data = task_mapping._format_data([record for record in self.records]) + if isinstance(task, TrainingTaskForTextClassification): + if task.formatting_func is None: + self.unify_responses(question=task.label.question, strategy=task.label.strategy) + elif not isinstance( + task, + ( + TrainingTaskForSFT, + TrainingTaskForRM, + TrainingTaskForPPO, + TrainingTaskForDPO, + ), + ): + raise ValueError(f"Training data {type(task)} is not supported yet") + + data = task._format_data(self) if framework in [ Framework.TRANSFORMERS, Framework.SETFIT, Framework.SPAN_MARKER, Framework.PEFT, ]: - return task_mapping._prepare_for_training_with_transformers( + return task._prepare_for_training_with_transformers( data=data, train_size=train_size, seed=seed, framework=framework ) - elif framework is Framework.SPACY or framework is Framework.SPACY_TRANSFORMERS: + elif framework in [Framework.SPACY, Framework.SPACY_TRANSFORMERS]: require_version("spacy") import spacy @@ -376,11 +405,15 @@ def prepare_for_training( lang = spacy.blank(lang) else: lang = spacy.load(lang) - return task_mapping._prepare_for_training_with_spacy(data=data, train_size=train_size, seed=seed, lang=lang) + return task._prepare_for_training_with_spacy(data=data, train_size=train_size, seed=seed, lang=lang) elif framework is Framework.SPARK_NLP: - return task_mapping._prepare_for_training_with_spark_nlp(data=data, train_size=train_size, seed=seed) + return task._prepare_for_training_with_spark_nlp(data=data, train_size=train_size, seed=seed) elif framework is Framework.OPENAI: - return task_mapping._prepare_for_training_with_openai(data=data, train_size=train_size, seed=seed) + return task._prepare_for_training_with_openai(data=data, train_size=train_size, seed=seed) + elif framework is Framework.TRL: + return task._prepare_for_training_with_trl(data=data, train_size=train_size, seed=seed) + elif framework is Framework.TRLX: + return task._prepare_for_training_with_trlx(data=data, train_size=train_size, seed=seed) else: raise NotImplementedError( f"Framework {framework} is not supported. Choose from: {[e.value for e in Framework]}" diff --git a/src/argilla/client/feedback/training/__init__.py b/src/argilla/client/feedback/training/__init__.py index 0810f480c3..1369cca084 100644 --- a/src/argilla/client/feedback/training/__init__.py +++ b/src/argilla/client/feedback/training/__init__.py @@ -13,6 +13,25 @@ # limitations under the License. from argilla.client.feedback.training.base import ArgillaTrainer -from argilla.client.feedback.training.schemas import TrainingTaskMapping, TrainingTaskMappingForTextClassification +from argilla.client.feedback.training.schemas import ( + TrainingTask, + TrainingTaskForDPO, + TrainingTaskForPPO, + TrainingTaskForRM, + TrainingTaskForSFT, + TrainingTaskForTextClassification, + TrainingTaskMapping, # <- Deprecated + TrainingTaskMappingForTextClassification, # <- Deprecated +) -__all__ = ["ArgillaTrainer", "TrainingTaskMapping", "TrainingTaskMappingForTextClassification"] +__all__ = [ + "ArgillaTrainer", + "TrainingTask", + "TrainingTaskForTextClassification", + "TrainingTaskForSFT", + "TrainingTaskForRM", + "TrainingTaskForPPO", + "TrainingTaskForDPO", + "TrainingTaskMapping", + "TrainingTaskMappingForTextClassification", +] diff --git a/src/argilla/client/feedback/training/base.py b/src/argilla/client/feedback/training/base.py index 49987d04ef..7511ce2267 100644 --- a/src/argilla/client/feedback/training/base.py +++ b/src/argilla/client/feedback/training/base.py @@ -17,7 +17,8 @@ from abc import ABC, abstractmethod from typing import TYPE_CHECKING, List, Optional, Union -from argilla.client.feedback.training.schemas import TrainingTaskMappingForTextClassification +from argilla.client.feedback.schemas.records import FeedbackRecord +from argilla.client.feedback.training.schemas import TrainingTaskForTextClassification, TrainingTaskTypes from argilla.client.models import Framework, TextClassificationRecord from argilla.training import ArgillaTrainer as ArgillaTrainerV1 @@ -33,7 +34,7 @@ class ArgillaTrainer(ArgillaTrainerV1): def __init__( self, dataset: "FeedbackDataset", - task_mapping: TrainingTaskMappingForTextClassification, + task: TrainingTaskTypes, framework: Framework, lang: Optional["spacy.Language"] = None, model: Optional[str] = None, @@ -41,44 +42,37 @@ def __init__( seed: Optional[int] = None, gpu_id: Optional[int] = -1, framework_kwargs: Optional[dict] = {}, - fetch_records: bool = True, ) -> None: """ Initialize an Argilla Trainer. Args: - dataset (FeedbackDataset): the dataset to be used for training. - task_mapping (TrainingTaskMappingForTextClassification): the training data to be used for training. - framework (str): - the framework to use for training. Currently, only "transformers", "setfit", and "spacy" - are supported. - lang (spacy.Language): - the spaCy language model to use for training, just required when `framework="spacy"`. + dataset: the dataset to be used for training. + task: the training data to be used for training. + framework: the framework to use for training. Currently, "transformers", "setfit", "spacy", "peft", + "openai", "span_marker" and "trl" are supported. + lang: the spaCy language model to use for training, just required when `framework="spacy"`. Defaults to None, but it will be set to `spacy.blank("en")` if not specified. - model (str): - name or path to the baseline model to be used. If not specified will set to a good default + model: name or path to the baseline model to be used. If not specified will set to a good default per framework, if applicable. Defaults to None. - train_size (float): - the size of the training set. If not specified, the entire dataset will be used for training, + train_size: the size of the training set. If not specified, the entire dataset will be used for training, which may be an issue if `framework="spacy"` as it requires a validation set. Defaults to None. - seed (int): the random seed to ensure reproducibility. Defaults to None. - gpu_id (int): - the GPU ID to use when training a SpaCy model. Defaults to -1, which means that the CPU + seed: the random seed to ensure reproducibility. Defaults to None. + gpu_id: the GPU ID to use when training a SpaCy model. Defaults to -1, which means that the CPU will be used by default. GPU IDs start in 0, which stands for the default GPU in the system, if available. - framework_kwargs (dict): arguments for the framework's trainer. + framework_kwargs: arguments for the framework's trainer. **load_kwargs: arguments for the rg.load() function. """ self._dataset = dataset self._train_size = train_size - self._task_mapping = task_mapping + self._task = task self._seed = seed # split is used for train-test-split and should therefore be fixed self._model = model self._prepared_data = self._dataset.prepare_for_training( framework=framework, - task_mapping=task_mapping, - fetch_records=fetch_records, + task=task, train_size=train_size, seed=seed, lang=lang, @@ -88,13 +82,13 @@ def __init__( framework = Framework(framework) if framework is Framework.SETFIT: - if not isinstance(task_mapping, TrainingTaskMappingForTextClassification): + if not isinstance(task, TrainingTaskForTextClassification): raise NotImplementedError(f"{Framework.SETFIT} only supports `TextClassification` tasks.") from argilla.client.feedback.training.frameworks.setfit import ArgillaSetFitTrainer self._trainer = ArgillaSetFitTrainer( - feedback_dataset=self._dataset, - task_mapping=self._task_mapping, + dataset=self._dataset, + task=self._task, prepared_data=self._prepared_data, seed=self._seed, model=self._model, @@ -103,8 +97,8 @@ def __init__( from argilla.client.feedback.training.frameworks.transformers import ArgillaTransformersTrainer self._trainer = ArgillaTransformersTrainer( - feedback_dataset=self._dataset, - task_mapping=self._task_mapping, + dataset=self._dataset, + task=self._task, prepared_data=self._prepared_data, seed=self._seed, model=self._model, @@ -113,8 +107,8 @@ def __init__( from argilla.client.feedback.training.frameworks.peft import ArgillaPeftTrainer self._trainer = ArgillaPeftTrainer( - feedback_dataset=self._dataset, - task_mapping=self._task_mapping, + dataset=self._dataset, + task=self._task, prepared_data=self._prepared_data, seed=self._seed, model=self._model, @@ -123,8 +117,8 @@ def __init__( from argilla.client.feedback.training.frameworks.spacy import ArgillaSpaCyTrainer self._trainer = ArgillaSpaCyTrainer( - feedback_dataset=self._dataset, - task_mapping=self._task_mapping, + dataset=self._dataset, + task=self._task, prepared_data=self._prepared_data, seed=self._seed, model=self._model, @@ -135,8 +129,8 @@ def __init__( from argilla.client.feedback.training.frameworks.spacy import ArgillaSpaCyTransformersTrainer self._trainer = ArgillaSpaCyTransformersTrainer( - feedback_dataset=self._dataset, - task_mapping=self._task_mapping, + dataset=self._dataset, + task=self._task, prepared_data=self._prepared_data, seed=self._seed, model=self._model, @@ -147,8 +141,8 @@ def __init__( from argilla.client.feedback.training.frameworks.openai import ArgillaOpenAITrainer self._trainer = ArgillaOpenAITrainer( - feedback_dataset=self._dataset, - task_mapping=self._task_mapping, + dataset=self._dataset, + task=self._task, prepared_data=self._prepared_data, seed=self._seed, model=self._model, @@ -157,8 +151,18 @@ def __init__( from argilla.client.feedback.training.frameworks.span_marker import ArgillaSpanMarkerTrainer self._trainer = ArgillaSpanMarkerTrainer( - feedback_dataset=self._dataset, - task_mapping=self._task_mapping, + dataset=self._dataset, + task=self._task, + prepared_data=self._prepared_data, + seed=self._seed, + model=self._model, + ) + elif framework is Framework.TRL: + from argilla.client.feedback.training.frameworks.trl import ArgillaTRLTrainer + + self._trainer = ArgillaTRLTrainer( + dataset=self._dataset, + task=self._task, prepared_data=self._prepared_data, seed=self._seed, model=self._model, @@ -167,7 +171,7 @@ def __init__( raise NotImplementedError(f"{framework} is not a valid framework.") self._logger.info(self) - self._track_trainer_usage(framework=framework, task=self._task_mapping.__class__.__name__) + self._track_trainer_usage(framework=framework, task=self._task.__class__.__name__) def __repr__(self) -> str: """ @@ -182,7 +186,7 @@ def __repr__(self) -> str: _________________________________________________________________ These baseline params are fixed: dataset: {self._dataset} - task: {self._task_mapping} + task: {self._task} train_size: {self._train_size} seed: {self._seed} @@ -218,40 +222,42 @@ def predict(self, text: Union[List[str], str], as_argilla_records: bool = True, class ArgillaTrainerSkeleton(ABC): def __init__( self, - feedback_dataset: "FeedbackDataset", - task_mapping: TrainingTaskMappingForTextClassification, + dataset: "FeedbackDataset", + task: TrainingTaskTypes, prepared_data=None, model: str = None, seed: int = None, *arg, **kwargs, ): - self._feedback_dataset = feedback_dataset - self._task_mapping = task_mapping + self._dataset = dataset + self._task = task self._dataset = prepared_data self._model = model self._seed = seed - if isinstance(self._task_mapping, TrainingTaskMappingForTextClassification): - self._multi_label = self._task_mapping.__multi_label__ or False - self._label_list = self._task_mapping.__all_labels__ or None - self._label2id = self._task_mapping.__label2id__ - self._id2label = self._task_mapping.__id2label__ + if isinstance(self._task, TrainingTaskForTextClassification): + self._multi_label = self._task.__multi_label__ or False + self._label_list = self._task.__all_labels__ or None + self._label2id = self._task.__label2id__ + self._id2label = self._task.__id2label__ self._record_class = TextClassificationRecord # TODO: dirty hack to inherit from original trainers + else: + self._record_class = FeedbackRecord @abstractmethod - def init_training_args(self): + def init_training_args(self) -> None: """ Initializes the training arguments. """ @abstractmethod - def init_model(self): + def init_model(self) -> None: """ Initializes a model. """ @abstractmethod - def update_config(self, *args, **kwargs): + def update_config(self, *args, **kwargs) -> None: """ Updates the configuration of the trainer, but the parameters depend on the trainer.subclass. """ @@ -263,13 +269,13 @@ def predict(self, text: Union[List[str], str], as_argilla_records: bool = True, """ @abstractmethod - def train(self, output_dir: str = None): + def train(self, output_dir: Optional[str] = None) -> None: """ Trains the model. """ @abstractmethod - def save(self, output_dir: str): + def save(self, output_dir: str) -> None: """ Saves the model to the specified path. """ diff --git a/src/argilla/client/feedback/training/frameworks/transformers.py b/src/argilla/client/feedback/training/frameworks/transformers.py index e4eb6d922d..73f7a271b1 100644 --- a/src/argilla/client/feedback/training/frameworks/transformers.py +++ b/src/argilla/client/feedback/training/frameworks/transformers.py @@ -16,7 +16,7 @@ from datasets import Dataset, DatasetDict from argilla.client.feedback.training.base import ArgillaTrainerSkeleton -from argilla.client.feedback.training.schemas import TrainingTaskMappingForTextClassification +from argilla.client.feedback.training.schemas import TrainingTaskForTextClassification from argilla.training.transformers import ArgillaTransformersTrainer as ArgillaTransformersTrainerV1 @@ -57,9 +57,11 @@ def __init__(self, *args, **kwargs): else: raise NotImplementedError(f"We do not support {type(self._dataset)} yet.") - if isinstance(self._task_mapping, TrainingTaskMappingForTextClassification): + if isinstance(self._task, TrainingTaskForTextClassification): self._model_class = AutoModelForSequenceClassification else: - raise NotImplementedError(f"ArgillaTransformersTrainer does not support {type(self._task_mapping)} yet.") + raise NotImplementedError( + f"ArgillaTransformersTrainer does not support {self._task.__class__.__name__} yet." + ) self.init_training_args() diff --git a/src/argilla/client/feedback/training/frameworks/trl.py b/src/argilla/client/feedback/training/frameworks/trl.py new file mode 100644 index 0000000000..eff8fef5b9 --- /dev/null +++ b/src/argilla/client/feedback/training/frameworks/trl.py @@ -0,0 +1,385 @@ +# Copyright 2021-present, the Recognai S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from typing import TYPE_CHECKING, Dict, List, Union + +from argilla.client.feedback.training.base import ArgillaTrainerSkeleton +from argilla.client.feedback.training.schemas import ( + TrainingTaskForDPO, + TrainingTaskForPPO, + TrainingTaskForRM, + TrainingTaskForSFT, +) +from argilla.training.utils import filter_allowed_args +from argilla.utils.dependency import require_version + +if TYPE_CHECKING: + import transformers + from trl import PPOConfig + + from argilla.client.feedback.dataset import FeedbackDataset + + +class PPOArgs: + def __init__( + self, + config: "PPOConfig", + reward_model: Union[str, "transformers.pipeline"], + length_sampler_kwargs: dict, + generation_kwargs: dict, + ) -> None: + """ + Additional arguments for PPO training process. + + Args: + reward_model (Union[str, "transformers.pipeline"]): Reward model to use for creating PPO rewards for training. + length_sampler_kwargs (dict): Arguments for the length sampler. + min_value: Minimum length for the generated samples. + max_value: Maximum length for the generated samples. + generation_kwargs (dict): Arguments for the generation process. + min_length: Minimum length for the generated samples. + max_length: Maximum length for the generated samples. + num_beams: Number of beams to use for the generation process. + num_return_sequences: Number of sequences to generate. + temperature: Temperature for the generation process. + top_k: Top k for the generation process. + top_p: Top p for the generation process. + """ + if isinstance(reward_model, str): + from transformers import pipeline + + reward_model = pipeline("text-classifcation", model=reward_model) + self.config = config + self.reward_model = reward_model + self.length_sampler_kwargs = length_sampler_kwargs + self.generation_kwargs = generation_kwargs + + +class ArgillaTRLTrainer(ArgillaTrainerSkeleton): + _logger = logging.getLogger("ArgillaTRLTrainer") + _logger.setLevel(logging.INFO) + + require_version("transformers") + require_version("torch") + require_version("trl>=0.5.0") + + def __init__( + self, + dataset: "FeedbackDataset", + task: Union[ + TrainingTaskForSFT, + TrainingTaskForRM, + TrainingTaskForPPO, + TrainingTaskForDPO, + ], + prepared_data=None, + model: str = None, + seed: int = None, + ) -> None: + super().__init__(dataset=dataset, task=task, prepared_data=prepared_data, model=model, seed=seed) + import torch + from datasets import DatasetDict + from transformers import set_seed + + self._transformers_model = None + self._transformers_tokenizer = None + self.device = "cpu" + if torch.backends.mps.is_available(): + self.device = "mps" + elif torch.cuda.is_available(): + self.device = "cuda" + + if self._seed is None: + self._seed = 42 + set_seed(self._seed) + + if self._model is None: + self._model = "gpt2-medium" + + if isinstance(self._dataset, DatasetDict): + self._train_dataset = self._dataset["train"] + self._eval_dataset = self._dataset["test"] + else: + self._train_dataset = self._dataset + self._eval_dataset = None + + if not isinstance( + self._task, + ( + TrainingTaskForSFT, + TrainingTaskForRM, + TrainingTaskForPPO, + TrainingTaskForDPO, + ), + ): + raise NotImplementedError(f"Task {self._task} not supported in TRL.") + + from trl import DPOTrainer, PPOTrainer, RewardTrainer, SFTTrainer + + self.trainer_mapping = { + TrainingTaskForSFT: SFTTrainer, + TrainingTaskForRM: RewardTrainer, + TrainingTaskForPPO: PPOTrainer, + TrainingTaskForDPO: DPOTrainer, + } + self.trainer_cls = self.trainer_mapping[type(self._task)] + + self.init_training_args() + + def init_training_args(self) -> None: + """ + Initializes the training arguments. + """ + self.training_args_kwargs = {} + self.trainer_kwargs = {} + + if isinstance(self._task, TrainingTaskForPPO): + from trl import PPOConfig + + self._logger.warning( + "The PPOTrainer must be initialized by passing `reward_model`, `length_sampler_kwargs`, `generation_kwargs` as kwargs to the `update_config()`-method." + ) + self.trainer_kwargs["config"] = PPOConfig() + self.training_args_kwargs["reward_model"] = None + self.training_args_kwargs["length_sampler_kwargs"] = {"min_value": 1, "max_value": 10} + self.training_args_kwargs["generation_kwargs"] = { + "min_length": -1, + "top_k": 0.0, + "top_p": 1.0, + "do_sample": True, + } + else: + self.training_args_kwargs["evaluation_strategy"] = "no" if self._eval_dataset is None else "epoch" + self.training_args_kwargs["logging_steps"] = 30 + self.training_args_kwargs["logging_steps"] = 1 + self.training_args_kwargs["num_train_epochs"] = 1 + + def init_model(self, new: bool = False) -> None: + """ + Initializes a model. + """ + from transformers import ( + AutoModelForCausalLM, + AutoModelForSequenceClassification, + AutoTokenizer, + PreTrainedModel, + PreTrainedTokenizer, + ) + from trl import AutoModelForCausalLMWithValueHead, create_reference_model + + if isinstance(self._task, (TrainingTaskForSFT, TrainingTaskForDPO)): + auto_model_class = AutoModelForCausalLM + elif isinstance(self._task, TrainingTaskForPPO): + auto_model_class = AutoModelForCausalLMWithValueHead + elif isinstance(self._task, TrainingTaskForRM): + auto_model_class = AutoModelForSequenceClassification + + self._transformers_model: PreTrainedModel = auto_model_class.from_pretrained(self._model) + self._transformers_tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained(self._model) + self._transformers_tokenizer.pad_token = self._transformers_tokenizer.eos_token + self._transformers_model.config.pad_token_id = self._transformers_tokenizer.pad_token_id + + if isinstance(self._task, (TrainingTaskForPPO, TrainingTaskForDPO)): + self._transformers_ref_model: PreTrainedModel = create_reference_model(self._transformers_model) + if new: + self._transformers_model.to(self.device) + + def update_config(self, **kwargs) -> None: + """ + Updates the configuration of the trainer, but the parameters depend on the trainer.subclass. + """ + from transformers import TrainingArguments + + self.training_args_kwargs.update(filter_allowed_args(TrainingArguments.__init__, **kwargs)) + self.training_args_kwargs.update(filter_allowed_args(PPOArgs.__init__, **kwargs)) + self.trainer_kwargs.update(filter_allowed_args(self.trainer_cls.__init__, **kwargs)) + + def predict(self, text: Union[List[str], str], as_argilla_records: bool = True, **kwargs) -> None: + """ + Predicts the label of the text. + """ + raise NotImplementedError("Models trained with TRL cannot be used for label predictions.") + + def train(self, output_dir: str) -> None: + """ + Trains the model. + """ + if isinstance(self._task, TrainingTaskForPPO): + if not all( + x in self.training_args_kwargs for x in ["length_sampler_kwargs", "generation_kwargs", "reward_model"] + ): + raise ValueError( + "To train a PPO model, you need to specify the following arguments via `trainer.update_config`: length_sampler_kwargs, generation_kwargs, reward_model." + ) + + from transformers import TrainingArguments + + # check required path argument + self.training_args_kwargs["output_dir"] = output_dir + + self.init_model(new=True) + + if isinstance(self._task, TrainingTaskForSFT): + self._training_args = TrainingArguments(**self.training_args_kwargs) + self._trainer = self.trainer_cls( + self._transformers_model, + args=self._training_args, + train_dataset=self._train_dataset, + eval_dataset=self._eval_dataset, + dataset_text_field="text", + tokenizer=self._transformers_tokenizer, + **self.trainer_kwargs, + ) + + elif isinstance(self._task, TrainingTaskForRM): + + def preprocess_function(examples) -> Dict[str, List]: + new_examples = { + "input_ids_chosen": [], + "attention_mask_chosen": [], + "input_ids_rejected": [], + "attention_mask_rejected": [], + } + for chosen, rejected in zip(examples["chosen"], examples["rejected"]): + tokenized_j = self._transformers_tokenizer(chosen, truncation=True) + tokenized_k = self._transformers_tokenizer(rejected, truncation=True) + + new_examples["input_ids_chosen"].append(tokenized_j["input_ids"]) + new_examples["attention_mask_chosen"].append(tokenized_j["attention_mask"]) + new_examples["input_ids_rejected"].append(tokenized_k["input_ids"]) + new_examples["attention_mask_rejected"].append(tokenized_k["attention_mask"]) + + return new_examples + + self._training_args = TrainingArguments(**self.training_args_kwargs) + train_dataset = self._train_dataset.map(preprocess_function, batched=True) + eval_dataset = None + if self._eval_dataset: + eval_dataset = self._eval_dataset.map(preprocess_function, batched=True) + + self._trainer = self.trainer_cls( + self._transformers_model, + args=self._training_args, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + tokenizer=self._transformers_tokenizer, + **self.trainer_kwargs, + ) + elif isinstance(self._task, TrainingTaskForPPO): + from datasets import concatenate_datasets + + dataset = concatenate_datasets([x for x in [self._train_dataset, self._eval_dataset] if x is not None]) + + def tokenize(sample): + sample["input_ids"] = self._transformers_tokenizer.encode(sample["query"], truncation=True) + return sample + + def data_collator(data): + return dict((key, [d[key] for d in data]) for key in data[0]) + + def remove_truncated(sample): + return len(sample) < self._transformers_tokenizer.model_max_length + + dataset = dataset.map(tokenize, batched=False) + size_before = len(dataset) + dataset = dataset.filter(remove_truncated, batched=False) + size_after = len(dataset) + if size_after != size_before: + self._logger.info( + f"Removed {size_before - size_after} samples ({1 - (size_after / size_before):%}), " + "as these samples were longer than the maximum model length even before the generation." + ) + dataset.set_format(type="torch") + self._trainer = self.trainer_cls( + model=self._transformers_model, + ref_model=self._transformers_ref_model, + tokenizer=self._transformers_tokenizer, + dataset=dataset, + data_collator=data_collator, + **self.trainer_kwargs, + ) + elif isinstance(self._task, TrainingTaskForDPO): + self._training_args = TrainingArguments(**self.training_args_kwargs) + self._trainer = self.trainer_cls( + model=self._transformers_model, + ref_model=self._transformers_ref_model, + args=self._training_args, + train_dataset=self._train_dataset, + eval_dataset=self._eval_dataset, + tokenizer=self._transformers_tokenizer, + **self.trainer_kwargs, + ) + + # train + if isinstance(self._task, TrainingTaskForPPO): + import torch + from tqdm import tqdm + from trl.core import LengthSampler + + output_length_sampler = LengthSampler(**self.training_args_kwargs["length_sampler_kwargs"]) + generation_kwargs = self.training_args_kwargs["generation_kwargs"] + generation_kwargs["pad_token_id"] = self._transformers_tokenizer.eos_token_id + reward_model = self.training_args_kwargs["reward_model"] + + for batch in tqdm(self._trainer.dataloader): + query_tensors = batch["input_ids"] + + #### Get response from SFT + response_tensors = self._trainer.generate( + query_tensors, + return_prompt=False, + length_sampler=output_length_sampler, + **generation_kwargs, + ) + batch["response"] = self._transformers_tokenizer.batch_decode( + response_tensors, skip_special_tokens=True + ) + + #### Compute rewards scores + texts = [q + r for q, r in zip(batch["query"], batch["response"])] + pipe_outputs = reward_model(texts, top_k=None, truncation=True) + rewards = [torch.tensor(output[-1]["score"] * len(output)) for output in pipe_outputs] + + #### Run PPO step + stats = self._trainer.step(query_tensors, response_tensors, rewards) + self._trainer.log_stats(stats, batch, rewards) + else: + self._trainer.train() + if self._trainer.eval_dataset: + self._metrics = self._trainer.evaluate() + self._logger.info(self._metrics) + else: + self._metrics = None + + self.save(output_dir) + + def save(self, output_dir: str) -> None: + """ + Saves the model to the specified path. + """ + self._transformers_model.save_pretrained(output_dir) + self._transformers_tokenizer.save_pretrained(output_dir) + + def __repr__(self) -> str: + formatted_string = [] + arg_dict = { + repr(self.trainer_cls.__name__): self.trainer_kwargs, + "'TrainingArguments'": self.training_args_kwargs, + } + for arg_dict_key, arg_dict_single in arg_dict.items(): + formatted_string.append(arg_dict_key) + for key, val in arg_dict_single.items(): + formatted_string.append(f"{key}: {val}") + return "\n".join(formatted_string) diff --git a/src/argilla/client/feedback/training/schemas.py b/src/argilla/client/feedback/training/schemas.py index 5776351f66..713f62db99 100644 --- a/src/argilla/client/feedback/training/schemas.py +++ b/src/argilla/client/feedback/training/schemas.py @@ -13,8 +13,9 @@ # limitations under the License. import logging -from abc import ABC, abstractmethod -from typing import List, Tuple, Union +import warnings +from abc import ABC +from typing import TYPE_CHECKING, Any, Callable, Dict, Iterator, List, Optional, Tuple, Union import pandas as pd from pydantic import BaseModel @@ -39,28 +40,71 @@ _LOGGER = logging.getLogger(__name__) +if TYPE_CHECKING: + import datasets + import spacy + + from argilla.client.feedback.dataset import FeedbackDataset + + +TASK_STRUCTURE = { + "text_classification": { + "field": (TextField), + "question": ( + LabelQuestion, + MultiLabelQuestion, + RatingQuestion, + RankingQuestion, + ), + "unification": ( + LabelQuestionUnification, + MultiLabelQuestionUnification, + RatingQuestionUnification, + RankingQuestionUnification, + ), + } +} + class TrainingData(ABC): - def _format_data(self, records): + _formatting_func_return_types = None + + def _test_output_formatting_func(self, sample: Any): + """ + Test if the formatting function returns the expected format. + """ + try: + if not type(sample) == iter: + self._formatting_func_return_types(format=sample) + return True + except Exception: + raise ValueError( + f"formatting_func must return {self._formatting_func_return_types.__annotations__['format']}, not {type(sample)}" + ) + + def _format_data(self, dataset: "FeedbackDataset") -> List[Dict[str, Any]]: formatted_data = [] explode_columns = set() - for record in records: + for record in dataset.records: data = {} for pydantic_field in self: - pydantic_field_name, pydantic_field_value = pydantic_field - if isinstance(pydantic_field_value, (TextField,)): - data[pydantic_field_name] = record.fields[pydantic_field_value.name] - else: - if pydantic_field_value.question.name not in record._unified_responses: - continue + # with default and formatting_func either one can be None + if pydantic_field[-1] is not None: + pydantic_field_name, pydantic_field_value = pydantic_field + if isinstance(pydantic_field_value, (TextField,)): + data[pydantic_field_name] = record.fields[pydantic_field_value.name] else: - data[pydantic_field_name] = [ - resp.value for resp in record._unified_responses[pydantic_field_value.question.name] - ] - explode_columns.add(pydantic_field_name) + if pydantic_field_value.question.name not in record._unified_responses: + continue + else: + data[pydantic_field_name] = [ + resp.value for resp in record._unified_responses[pydantic_field_value.question.name] + ] + explode_columns.add(pydantic_field_name) formatted_data.append(data) df = pd.DataFrame(formatted_data) - df = df.explode(list(explode_columns)) + if explode_columns: + df = df.explode(list(explode_columns)) # In cases of MultiLabel datasets the label column contains a list, # which is unhashable, so for those cases we transform the rows in the # dataframe to tuples to drop duplicates and reconstruct the original @@ -71,6 +115,7 @@ def _format_data(self, records): df = df.drop_duplicates() df = df.dropna(how="any") + return df.to_dict(orient="records") @property @@ -83,134 +128,310 @@ def test_framework_support(self, framework: Union[str, Framework]): if framework not in self.supported_frameworks: raise NotImplementedError(f"Framework {framework} is not supported for this {self.__class__}.") - @abstractmethod def _train_test_split(self, data: List[dict], train_size: float, seed: int) -> Tuple[List[dict], List[dict]]: """Overwritten by subclasses""" - @abstractmethod def _prepare_for_training_with_transformers( - self, data: List[dict], train_size, seed: int + self, data: List[dict], train_size, seed: int, framework: Union[str, Framework] ) -> Union["datasets.Dataset", "datasets.DatasetDict"]: - """Overwritten by subclasses""" + raise ValueError(f"{self.__class__.__name__} does not support the {framework} framework.") - @abstractmethod def _prepare_for_training_with_spacy( self, data: List[dict], train_size, seed: int, lang: str ) -> Union["spacy.token.DocBin", Tuple["spacy.token.DocBin", "spacy.token.DocBin"]]: - """Overwritten by subclasses""" + raise ValueError(f"{self.__class__.__name__} does not support the spaCy framework.") - @abstractmethod def _prepare_for_training_with_spark_nlp( self, data: List[dict], train_size, seed: int ) -> Union["pd.DataFrame", Tuple["pd.DataFrame", "pd.DataFrame"]]: - """Overwritten by subclasses""" + raise ValueError(f"{self.__class__.__name__} does not support the Spark NLP framework.") - @abstractmethod def _prepare_for_training_with_openai( self, data: List[dict], train_size, seed: int ) -> Union[List[dict], Tuple[List[dict], List[dict]]]: - """Overwritten by subclasses""" + raise ValueError(f"{self.__class__.__name__} does not support the OpenAI framework.") + def _prepare_for_training_with_trl( + self, data: List[dict], train_size, seed: int + ) -> Union[List[dict], Tuple[List[dict], List[dict]]]: + raise ValueError(f"{self.__class__.__name__} does not support the TRL framework.") + + def _prepare_for_training_with_trlx( + self, data: List[dict], train_size, seed: int + ) -> Union[List[dict], Tuple[List[dict], List[dict]]]: + raise ValueError(f"{self.__class__.__name__} does not support the TRLX framework.") -class TrainingTaskMapping: + +class TrainingTask: @classmethod def for_text_classification( cls, - text: TextField, - label: Union[ - RatingQuestion, - LabelQuestion, - RankingQuestion, - MultiLabelQuestion, - RatingQuestionUnification, - LabelQuestionUnification, - MultiLabelQuestionUnification, - RankingQuestionUnification, - ], + formatting_func: Callable[[Dict[str, Any]], Union[None, str, List[str], Iterator[str]]] = None, + text: Optional[TextField] = None, + label: Optional[ + Union[ + RatingQuestion, + LabelQuestion, + RankingQuestion, + MultiLabelQuestion, + RatingQuestionUnification, + LabelQuestionUnification, + MultiLabelQuestionUnification, + RankingQuestionUnification, + ] + ] = None, label_strategy: str = None, - ) -> "TrainingTaskMappingForTextClassification": + ) -> "TrainingTaskForTextClassification": """ - _summary_ + Define a task configuration for text classification. It takes default values for `text` and `label` using datasets Fields and Questions or a custom `formatting_func` as Callable. See Examples underneath for more details. Args: - text (TextField): The TextField to use for training. - label (Union[RatingQuestion, LabelQuestion, MultiLabelQuestion, RatingQuestionUnification, LabelQuestionUnification, MultiLabelQuestionUnification]): _description_ - label_strategy (str, optional): A strategy to unify responses. Defaults to None. This means it will initialize the default strategy for the label type. + formatting_func: A formatting function. Defaults to None. + text: The TextField to use for training. Defaults to None. + label: The *Question to use for training. Defaults to None. + label_strategy: A strategy to unify responses. Defaults to None. This means it will initialize the default strategy for the label type. Defaults to None. Raises: ValueError: if label is not a valid type with the question type. - ValueError: if label_strategy is defined and label is alraedy a Unification class. + ValueError: if label_strategy is defined and label is already a Unification class. Returns: - TrainingTaskMappingForTextClassification: _description_ + TrainingTaskForTextClassification: A task mapping instance to be used in `FeedbackDataset.prepare_for_training()` Examples: - >>> from argilla import LabelQuestion, TrainingTaskMapping - >>> dataset = rg.FeedbackDataset.from_argilla(argilla_id="...") - >>> training_data = TrainingTaskMapping.for_text_classification( + >>> from argilla import LabelQuestion, TrainingTask + >>> dataset = rg.FeedbackDataset.from_argilla(name="...") + >>> task = TrainingTask.for_text_classification( ... text=dataset.fields[0], ... label=dataset.questions[0] ... ) - >>> dataset.prepare_training_data(training_data=training_data) + >>> dataset.prepare_for_training(framework="...", task=task) + >>> from argilla import LabelQuestion, TrainingTask + >>> from collections import Counter + >>> import random + >>> def formatting_func(sample: Dict[str, Any]) -> Union[Tuple[str, str], Tuple[str, List[str]]]: + ... text = sample["text"] + ... values = [annotation["value"] for annotation in sample["label"]] + ... counter = Counter(values) + ... if counter: + ... most_common = counter.most_common() + ... max_frequency = most_common[0][1] + ... most_common_elements = [element for element, frequency in most_common if frequency == max_frequency] + ... label = random.choice(most_common_elements) + ... return (text, label) + ... else: + ... return None + >>> task = TrainingTask.for_text_classification(formatting_func=formatting_func) + >>> dataset.prepare_for_training(framework="...", task=task) """ - if isinstance( - label, - ( - LabelQuestionUnification, - MultiLabelQuestionUnification, - RatingQuestionUnification, - RankingQuestionUnification, - ), - ): - if label_strategy is not None: - raise ValueError("label_strategy is already defined via Unification class.") + if (text and label) and formatting_func is not None: + raise ValueError("You must provide either `text` and `label`, or a `formatting_func`, not both.") + + if formatting_func is not None: + if text or label: + raise ValueError("`formatting_func` is already defined, so you cannot define `text` and `label`.") + return TrainingTaskForTextClassification(formatting_func=formatting_func) else: - unification_kwargs = {"question": label} - if label_strategy is not None: - unification_kwargs["strategy"] = label_strategy + if isinstance(label, TASK_STRUCTURE["text_classification"]["unification"]): + if label_strategy is not None: + raise ValueError("label_strategy is already defined via Unification class.") else: - _LOGGER.info(f"No label strategy defined. Using default strategy for {type(label)}.") - if isinstance(label, RatingQuestion): - label = RatingQuestionUnification(**unification_kwargs) - elif isinstance(label, MultiLabelQuestion): - label = MultiLabelQuestionUnification(**unification_kwargs) - elif isinstance(label, LabelQuestion): - label = LabelQuestionUnification(**unification_kwargs) - elif isinstance(label, RankingQuestion): - label = RankingQuestionUnification(**unification_kwargs) - else: - raise ValueError(f"Label type {type(label)} is not supported.") - return TrainingTaskMappingForTextClassification( - text=text, - label=label, - label_strategy=label_strategy, - ) + unification_kwargs = {"question": label} + if label_strategy is not None: + unification_kwargs["strategy"] = label_strategy + else: + _LOGGER.info(f"No label strategy defined. Using default strategy for {type(label)}.") + if isinstance(label, RatingQuestion): + label = RatingQuestionUnification(**unification_kwargs) + elif isinstance(label, MultiLabelQuestion): + label = MultiLabelQuestionUnification(**unification_kwargs) + elif isinstance(label, LabelQuestion): + label = LabelQuestionUnification(**unification_kwargs) + elif isinstance(label, RankingQuestion): + label = RankingQuestionUnification(**unification_kwargs) + else: + raise ValueError(f"Label type {type(label)} is not supported.") + return TrainingTaskForTextClassification(text=text, label=label) + @classmethod + def for_supervised_fine_tuning( + cls, + formatting_func: Callable[[Dict[str, Any]], Union[None, str, List[str], Iterator[str]]], + ) -> "TrainingTaskForSFT": + """ + Return a task that can be used in `FeedbackDataset.prepare_for_training(framework="...", task)` + to extract data from the Feedback Dataset in an immediately useful format. + + Args: + formatting_func: A formatting function converting a dictionary of records into zero, + one or more text strings. -class TrainingTaskMappingForTextClassification(BaseModel, TrainingData): + Returns: + TrainingTaskForSFT: A task mapping instance to be used in `FeedbackDataset.prepare_for_training()` + + Examples: + >>> from argilla import TrainingTask + >>> dataset = rg.FeedbackDataset.from_argilla(name="...") + >>> def formatting_func(sample: Dict[str, Any]): + ... annotations = sample["good] + ... if annotations and annotations[0]["value"] == "Bad": + ... return + ... return template.format(prompt=sample["prompt"][0]["value"], response=sample["response"][0]["value"]) + >>> task = TrainingTask.for_supervised_fine_tuning(formatting_func=formatting_func) + >>> dataset.prepare_for_training(framework="...", task=task) + + """ + return TrainingTaskForSFT(formatting_func=formatting_func) + + @classmethod + def for_reward_modeling( + cls, + formatting_func: Callable[ + [Dict[str, Any]], Union[None, Tuple[str, str], List[Tuple[str, str]], Iterator[Tuple[str, str]]] + ], + ) -> "TrainingTaskForRM": + """ + Return a task that can be used in `FeedbackDataset.prepare_for_training(framework="...", task)` + to extract data from the Feedback Dataset in an immediately useful format. + + Args: + formatting_func: A formatting function converting a dictionary of records into zero, + one or more chosen-rejected text tuples. + + Returns: + TrainingTaskForRM: A task mapping instance to be used in `FeedbackDataset.prepare_for_training()` + + Examples: + >>> from argilla import TrainingTask + >>> dataset = rg.FeedbackDataset.from_argilla(name="...") + >>> def formatting_func(sample: Dict[str, Any]): + ... values = [annotation["value"] for annotation in sample["ranking"]] + ... if values.count("1") >= values.count("2"): + ... chosen = sample["response-1"] + ... rejected = sample["response-2"] + ... else: + ... chosen = sample["response-2"] + ... rejected = sample["response-1"] + ... return chosen, rejected + >>> task = TrainingTask.for_reward_modeling(formatting_func=formatting_func) + >>> dataset.prepare_for_training(framework="...", task=task) + + """ + return TrainingTaskForRM(formatting_func=formatting_func) + + @classmethod + def for_proximal_policy_optimization( + cls, formatting_func: Callable[[Dict[str, Any]], Union[None, str, Iterator[str]]] + ) -> "TrainingTaskForPPO": + """ + Return a task that can be used in `FeedbackDataset.prepare_for_training(text: TextField)` + to extract data from the Feedback Dataset in an immediately useful format. + + Args: + formatting_func: A formatting function converting a dictionary of records into zero, + one or more prompts. + + Returns: + TrainingTaskForPPO: A task mapping instance to be used in `FeedbackDataset.prepare_for_training()` + """ + return TrainingTaskForPPO(formatting_func=formatting_func) + + @classmethod + def for_direct_preference_optimization( + cls, + formatting_func: Callable[[Dict[str, Any]], Union[None, Tuple[str, str, str], Iterator[Tuple[str, str, str]]]], + ) -> "TrainingTaskForDPO": + """ + Provide `TrainingTask.for_direct_preference_optimization(formatting_func: Callable)` + Return a task that can be used in `FeedbackDataset.prepare_for_training(framework="...", task)` + to extract data from the Feedback Dataset in an immediately useful format. + + Args: + formatting_func: A formatting function converting a dictionary of records into zero, + one or more prompt-chosen-rejected text tuples. + + Returns: + TrainingTaskForDPO: A task mapping instance to be used in `FeedbackDataset.prepare_for_training()` + + Examples: + >>> from argilla import TrainingTask + >>> dataset = rg.FeedbackDataset.from_argilla(name="...") + >>> def formatting_func(sample: Dict[str, Any]): + ... values = [annotation["value"] for annotation in sample["ranking"]] + ... if values.count("1") >= values.count("2"): + ... chosen = sample["response-1"] + ... rejected = sample["response-2"] + ... else: + ... chosen = sample["response-2"] + ... rejected = sample["response-1"] + ... return sample["prompt"], chosen, rejected + >>> task = TrainingTask.for_direct_preference_optimization(formatting_func=formatting_func) + >>> dataset.prepare_for_training(framework="...", task=task) + + """ + return TrainingTaskForDPO(formatting_func=formatting_func) + + +class TrainingTaskForTextClassificationFormat(BaseModel): + """ + Union[ + Tuple[str, str], Tuple[str, List[str]], + List[Tuple[str, str]], List[Tuple[str, List[str]]] + ] + """ + + format: Union[Tuple[str, str], Tuple[str, List[str]], List[Tuple[str, str]], List[Tuple[str, List[str]]]] + + +class TrainingTaskForTextClassification(BaseModel, TrainingData): """Training data for text classification Args: - text: TextField - label: Union[RatingUnification, LabelUnification, MultiLabelUnification] + formatting_func: A formatting function returning the text to classify. Either a formatting function or + the text and label parameters are provided. Defaults to None. + text: The text field to take as the text to classify. + label: The question denoting the label of the text to classify. Examples: - >>> from argilla import LabelQuestion, TrainingTaskMappingForTextClassification - >>> dataset = rg.FeedbackDataset.from_argilla(argilla_id="...") - >>> label = RatingQuestionUnification(question=dataset.questions[0], strategy="mean") - >>> training_data = TrainingTaskMappingForTextClassification( + >>> from argilla import LabelQuestion, TrainingTask + >>> dataset = rg.FeedbackDataset.from_argilla(name="...") + >>> task = TrainingTask.for_text_classification( ... text=dataset.fields[0], - ... label=label + ... label=dataset.questions[0] ... ) - >>> dataset.prepare_training_data(training_data=training_data) + >>> dataset.prepare_for_training(framework="...", task=task) + + >>> from argilla import LabelQuestion, TrainingTask + >>> from collections import Counter + >>> def formatting_func(sample: Dict[str, Any]) -> Union[Tuple[str, str], Tuple[str, List[str]]]: + ... text = sample["text"] + ... values = [annotation["value"] for annotation in sample["label"]] + ... counter = Counter(values) + ... if counter: + ... most_common = counter.most_common() + ... max_frequency = most_common[0][1] + ... most_common_elements = [element for element, frequency in most_common if frequency == max_frequency] + ... label = random.choice(most_common_elements) + ... return (text, label) + ... else: + ... return None + >>> task = TrainingTask.for_text_classification(formatting_func=formatting_func) + >>> dataset.prepare_for_training(framework="...", task=task) """ - text: TextField - label: Union[ - RatingQuestionUnification, LabelQuestionUnification, MultiLabelQuestionUnification, RankingQuestionUnification - ] + formatting_func: Optional[Callable[[Dict[str, Any]], Union[None, str, List[str], Iterator[str]]]] = None + _formatting_func_return_types = TrainingTaskForTextClassificationFormat + text: Optional[TextField] = None + label: Optional[ + Union[ + RatingQuestionUnification, + LabelQuestionUnification, + MultiLabelQuestionUnification, + RankingQuestionUnification, + ] + ] = None @property def supported_frameworks(self): @@ -233,6 +454,46 @@ def __label2id__(self): def __id2label__(self): return self.label.question.__id2label__ + def _format_data(self, dataset: "FeedbackDataset") -> List[Dict[str, Any]]: + if self.formatting_func is not None: + output = set() + + for sample in dataset.format_as("datasets"): + text_label = self.formatting_func(sample) + if text_label is None: + continue + + self._test_output_formatting_func(text_label) + + if isinstance(text_label, tuple): + text_label = {text_label} + + output |= set(text_label) + + data = [] + _all_labels = set() + for text, label in output: + data.append({"text": text, "label": label}) + if isinstance(label, list): + _multi_label = True + _all_labels |= set(label) + else: + _all_labels.add(label) + _multi_label = False + + # infer label type from output custom formatting function + if _multi_label: + self.label = MultiLabelQuestionUnification( + question=MultiLabelQuestion(name="custom_func", labels=list(_all_labels)) + ) + else: + self.label = LabelQuestionUnification( + question=LabelQuestion(name="custom_func", labels=list(_all_labels)) + ) + return data + else: + return super()._format_data(dataset) + def unify_responses(self, responses: List[FeedbackRecord]): self.label.strategy.unify_responses(responses=responses, field=self.label.question) @@ -251,7 +512,7 @@ def _train_test_split(self, data: List[dict], train_size: float, seed: int) -> T def __repr__(self) -> str: return ( - "TrainingTaskMappingForTextClassification" + f"{self.__class__.__name__}" f"\n\t text={self.text.name}" f"\n\t label={self.label.question.name}" f"\n\t multi_label={self.__multi_label__}" @@ -265,18 +526,18 @@ def _prepare_for_training_with_transformers( self.test_framework_support(framework) import datasets - multi_label = isinstance(self.label.question, MultiLabelQuestion) + multi_label = self.__multi_label__ datasets_dict = {"id": [], "text": [], "label": []} - for entry in data: - datasets_dict["id"].append("None") + for index, entry in enumerate(data): + datasets_dict["id"].append(index) datasets_dict["text"].append(entry["text"]) datasets_dict["label"].append(entry["label"]) all_labels = self.label.question.__all_labels__ class_label = datasets.ClassLabel(names=all_labels) feature_dict = { - "id": datasets.Value("string"), + "id": datasets.Value(dtype="int32"), "text": datasets.Value("string"), "label": [class_label] if multi_label else class_label, } @@ -311,14 +572,13 @@ def _prepare_for_training_with_spacy( ) -> Union["spacy.token.DocBin", Tuple["spacy.token.DocBin", "spacy.token.DocBin"]]: from spacy.tokens import DocBin - all_labels = self.label.question.__all_labels__ + all_labels = self.__all_labels__ def _prepare(data): db = DocBin(store_user_data=True) # Creating the DocBin object as in https://spacy.io/usage/training#training-data for entry in data: doc = lang.make_doc(entry["text"]) - # doc.user_data["id"] = record.id cats = dict.fromkeys(all_labels, 0) if isinstance(entry["label"], list): @@ -397,3 +657,372 @@ def _prepare(data): return _prepare(train_data), _prepare(test_data) else: return _prepare(data) + + +class TrainingTaskForSFTFormat(BaseModel): + """ + Union[str, List[str]] + """ + + format: Union[str, List[str]] + + +class TrainingTaskForSFT(BaseModel, TrainingData): + """Training data for supervised finetuning + + Args: + formatting_func: A formatting function converting a dictionary of records into zero, + one or more text strings. + + Examples: + >>> from argilla import TrainingTaskForSFT + >>> dataset = rg.FeedbackDataset.from_argilla(name="...") + >>> def formatting_func(sample: Dict[str, Any]): + ... annotations = sample["good] + ... if annotations and annotations[0]["value"] == "Bad": + ... return + ... return template.format(prompt=sample["prompt"][0]["value"], response=sample["response"][0]["value"]) + >>> task = TrainingTaskForSFT(formatting_func=formatting_func) + >>> dataset.prepare_for_training(framework="...", task=task) + + """ + + _formatting_func_return_types = TrainingTaskForSFTFormat + formatting_func: Callable[[Dict[str, Any]], Union[None, str, List[str], Iterator[str]]] + + def _format_data(self, dataset: "FeedbackDataset") -> List[Dict[str, str]]: + formatted_texts = set() + for sample in dataset.format_as("datasets"): + if texts := self.formatting_func(sample): + if texts is None: + continue + + self._test_output_formatting_func(texts) + + if isinstance(texts, str): + texts = {texts} + + formatted_texts |= set(texts) + return [{"text": text} for text in formatted_texts] + + @property + def supported_frameworks(self): + names = ["trl"] + return [Framework(name) for name in names] + + def __repr__(self) -> str: + return f"{self.__class__.__name__}\n\t formatting_func={self.formatting_func}" + + @requires_version("datasets>1.17.0") + def _prepare_for_training_with_trl( + self, data: List[dict], train_size: float, seed: int + ) -> Union["datasets.Dataset", "datasets.DatasetDict"]: + import datasets + + datasets_dict = {"id": [], "text": []} + for index, sample in enumerate(data): + datasets_dict["id"].append(index) + datasets_dict["text"].append(sample["text"]) + + feature_dict = { + "id": datasets.Value(dtype="int32"), + "text": datasets.Value("string"), + } + + ds = datasets.Dataset.from_dict(datasets_dict, features=datasets.Features(feature_dict)) + if train_size != 1: + ds = ds.train_test_split(train_size=train_size, test_size=1 - train_size, seed=seed) + + return ds + + +class TrainingTaskForRMFormat(BaseModel): + """ + Union[ + Tuple[str, str], Tuple[str, List[str]], + List[Tuple[str, str]], List[Tuple[str, List[str]]] + ] + """ + + format: Union[Tuple[str, str], Tuple[str, List[str]], List[Tuple[str, str]], List[Tuple[str, List[str]]]] + + +class TrainingTaskForRM(BaseModel, TrainingData): + """Training data for reward modeling + + Args: + formatting_func: A formatting function converting a dictionary of records into zero, + one or more chosen-rejected text tuples. + + Examples: + >>> from argilla import TrainingTaskForRM + >>> dataset = rg.FeedbackDataset.from_argilla(name="...") + >>> def formatting_func(sample: Dict[str, Any]): + ... values = [annotation["value"] for annotation in sample["ranking"]] + ... if values.count("1") >= values.count("2"): + ... chosen = sample["response-1"] + ... rejected = sample["response-2"] + ... else: + ... chosen = sample["response-2"] + ... rejected = sample["response-1"] + ... return chosen, rejected + >>> task = TrainingTaskForRM(formatting_func=formatting_func) + >>> dataset.prepare_for_training(framework="...", task=task) + """ + + _formatting_func_return_types = TrainingTaskForRMFormat + formatting_func: Callable[ + [Dict[str, Any]], Union[None, Tuple[str, str], List[Tuple[str, str]], Iterator[Tuple[str, str]]] + ] + + def _format_data(self, dataset: "FeedbackDataset") -> List[Dict[str, str]]: + output = set() + for sample in dataset.format_as("datasets"): + chosen_rejecteds = self.formatting_func(sample) + if chosen_rejecteds is None: + continue + + self._test_output_formatting_func(chosen_rejecteds) + + if isinstance(chosen_rejecteds, tuple): + chosen_rejecteds = {chosen_rejecteds} + + output |= set(chosen_rejecteds) + return [{"chosen": chosen, "rejected": rejected} for chosen, rejected in output] + + @property + def supported_frameworks(self): + names = ["trl"] + return [Framework(name) for name in names] + + def __repr__(self) -> str: + return f"{self.__class__.__name__}\n\t formatting_func={self.formatting_func}" + + @requires_version("datasets>1.17.0") + def _prepare_for_training_with_trl( + self, data: List[dict], train_size: float, seed: int + ) -> Union["datasets.Dataset", "datasets.DatasetDict"]: + import datasets + + datasets_dict = {"chosen": [], "rejected": []} + for sample in data: + datasets_dict["chosen"].append(sample["chosen"]) + datasets_dict["rejected"].append(sample["rejected"]) + + feature_dict = { + "rejected": datasets.Value("string"), + "chosen": datasets.Value("string"), + } + + ds = datasets.Dataset.from_dict(datasets_dict, features=datasets.Features(feature_dict)) + if train_size != 1: + ds = ds.train_test_split(train_size=train_size, test_size=1 - train_size, seed=seed) + + return ds + + +class TrainingTaskForPPOFormat(BaseModel): + """ + Union[str, List[str]] + """ + + format: Union[str, List[str]] + + +class TrainingTaskForPPO(BaseModel, TrainingData): + """Training data for proximal policy optimization + + Args: + text: The TextField to use for training. + + Examples: + >>> from argilla import TrainingTaskForPPO + >>> dataset = rg.FeedbackDataset.from_argilla(name="...") + >>> task = TrainingTaskForPPO(text=dataset.fields[0],) + >>> dataset.prepare_for_training(framework="...", task=task) + """ + + _formatting_func_return_types = TrainingTaskForPPOFormat + formatting_func: Callable[[Dict[str, Any]], Union[None, str, Iterator[str]]] + + def _format_data(self, dataset: "FeedbackDataset") -> List[Dict[str, str]]: + formatted_texts = set() + for sample in dataset.format_as("datasets"): + if texts := self.formatting_func(sample): + if texts is None: + continue + + self._test_output_formatting_func(texts) + + if isinstance(texts, str): + texts = {texts} + formatted_texts |= set(texts) + return [{"query": text} for text in formatted_texts] + + @property + def supported_frameworks(self): + names = ["trl"] + return [Framework(name) for name in names] + + def __repr__(self) -> str: + return f"{self.__class__.__name__}\n\t formatting_func={self.formatting_func}" + + @requires_version("datasets>1.17.0") + def _prepare_for_training_with_trl( + self, data: List[dict], train_size: float, seed: int + ) -> Union["datasets.Dataset", "datasets.DatasetDict"]: + import datasets + + datasets_dict = {"id": [], "query": []} + for index, entry in enumerate(data): + datasets_dict["id"].append(index) + datasets_dict["query"].append(entry["query"]) + + feature_dict = { + "id": datasets.Value(dtype="int32"), + "query": datasets.Value("string"), + } + + ds = datasets.Dataset.from_dict(datasets_dict, features=datasets.Features(feature_dict)) + + if train_size != 1: + ds = ds.train_test_split(train_size=train_size, test_size=1 - train_size, seed=seed) + + return ds + + +class TrainingTaskForDPOFormat(BaseModel): + """ + Union[Tuple[str, str, str], List[Tuple[str, str, str]]] + """ + + format: Union[Tuple[str, str, str], List[Tuple[str, str, str]]] + + +class TrainingTaskForDPO(BaseModel, TrainingData): + """Training data for direct preference optimization + + Args: + formatting_func: A formatting function converting a dictionary of records into zero, + one or more prompt-chosen-rejected text tuples. + + Examples: + >>> from argilla import TrainingTaskForDPO + >>> dataset = rg.FeedbackDataset.from_argilla(name="...") + >>> def formatting_func(sample: Dict[str, Any]): + ... values = [annotation["value"] for annotation in sample["ranking"]] + ... if values.count("1") >= values.count("2"): + ... chosen = sample["response-1"] + ... rejected = sample["response-2"] + ... else: + ... chosen = sample["response-2"] + ... rejected = sample["response-1"] + ... return sample["prompt"], chosen, rejected + >>> task = TrainingTaskForDPO(formatting_func=formatting_func) + >>> dataset.prepare_for_training(framework="...", task=task) + """ + + _formatting_func_return_types = TrainingTaskForDPOFormat + formatting_func: Callable[[Dict[str, Any]], Union[None, Tuple[str, str, str], Iterator[Tuple[str, str, str]]]] + + def _format_data(self, dataset: "FeedbackDataset") -> List[Dict[str, str]]: + output = set() + for sample in dataset.format_as("datasets"): + prompt_chosen_rejecteds = self.formatting_func(sample) + if prompt_chosen_rejecteds is None: + continue + + self._test_output_formatting_func(prompt_chosen_rejecteds) + + if isinstance(prompt_chosen_rejecteds, tuple): + prompt_chosen_rejecteds = {prompt_chosen_rejecteds} + + output |= set(prompt_chosen_rejecteds) + return [{"prompt": prompt, "chosen": chosen, "rejected": rejected} for prompt, chosen, rejected in output] + + @property + def supported_frameworks(self): + names = ["trl"] + return [Framework(name) for name in names] + + def __repr__(self) -> str: + return f"{self.__class__.__name__}\n\t formatting_func={self.formatting_func}" + + @requires_version("datasets>1.17.0") + def _prepare_for_training_with_trl( + self, data: List[dict], train_size: float, seed: int + ) -> Union["datasets.Dataset", "datasets.DatasetDict"]: + import datasets + + datasets_dict = {"prompt": [], "chosen": [], "rejected": []} + for sample in data: + datasets_dict["prompt"].append(sample["prompt"]) + datasets_dict["chosen"].append(sample["chosen"]) + datasets_dict["rejected"].append(sample["rejected"]) + + feature_dict = { + "prompt": datasets.Value("string"), + "rejected": datasets.Value("string"), + "chosen": datasets.Value("string"), + } + + ds = datasets.Dataset.from_dict(datasets_dict, features=datasets.Features(feature_dict)) + if train_size != 1: + ds = ds.train_test_split(train_size=train_size, test_size=1 - train_size, seed=seed) + + return ds + + +TrainingTaskTypes = Union[ + TrainingTaskForTextClassification, + TrainingTaskForSFT, + TrainingTaskForRM, + TrainingTaskForPPO, + TrainingTaskForDPO, +] + + +# Old, deprecated variants. +class RenamedDeprecationMixin: + @classmethod + def warn(cls) -> None: + this_class_name = cls.__name__ + first_subclass_name = cls.__mro__[1].__name__ + warnings.warn( + (f"`{this_class_name}` has been renamed to `{first_subclass_name}`, please use the latter."), + DeprecationWarning, + stacklevel=3, + ) + + +class TrainingTaskMapping(TrainingTask, RenamedDeprecationMixin): + @classmethod + def for_text_classification(cls, *args, **kwargs) -> TrainingTaskForTextClassification: + cls.warn() + return super().for_text_classification(*args, **kwargs) + + @classmethod + def for_supervised_fine_tuning(cls, *args, **kwargs) -> TrainingTaskForSFT: + cls.warn() + return super().for_supervised_fine_tuning(*args, **kwargs) + + @classmethod + def for_reward_modeling(cls, *args, **kwargs) -> TrainingTaskForRM: + cls.warn() + return super().for_reward_modeling(*args, **kwargs) + + @classmethod + def for_proximal_policy_optimization(cls, *args, **kwargs) -> TrainingTaskForPPO: + cls.warn() + return super().for_proximal_policy_optimization(cls, *args, **kwargs) + + @classmethod + def for_direct_preference_optimization(cls, *args, **kwargs) -> TrainingTaskForDPO: + cls.warn() + return super().for_direct_preference_optimization(*args, **kwargs) + + +class TrainingTaskMappingForTextClassification(TrainingTaskForTextClassification, RenamedDeprecationMixin): + def __init__(self, *args, **kwargs) -> None: + self.warn() + return super().__init__(*args, **kwargs) diff --git a/src/argilla/client/models.py b/src/argilla/client/models.py index a20d43ae37..003850e6b4 100644 --- a/src/argilla/client/models.py +++ b/src/argilla/client/models.py @@ -48,6 +48,8 @@ class Framework(Enum): span_marker: SpanMarker Tom Aarsen library spark-nlp: Spark NLP John Snow Labs library openai: OpenAI LLMs + trl: Transformer Reinforcement Learning + trlx: Transformer Reinforcement Learning X """ TRANSFORMERS = "transformers" @@ -58,6 +60,8 @@ class Framework(Enum): SPAN_MARKER = "span_marker" SPARK_NLP = "spark-nlp" OPENAI = "openai" + TRL = "trl" + TRLX = "trlx" # AUTOTRAIN = "autotrain" @classmethod diff --git a/src/argilla/feedback/__init__.py b/src/argilla/feedback/__init__.py index a43259b4b7..57f967604b 100644 --- a/src/argilla/feedback/__init__.py +++ b/src/argilla/feedback/__init__.py @@ -15,9 +15,15 @@ from argilla.client.feedback import ( ArgillaTrainer, LabelQuestionStrategy, + LabelQuestionUnification, MultiLabelQuestionStrategy, + MultiLabelQuestionUnification, + RankingQuestionStrategy, + RankingQuestionUnification, RatingQuestionStrategy, - TrainingTaskMapping, + RatingQuestionUnification, + TrainingTask, + TrainingTaskMapping, # <- Deprecated ) from argilla.client.feedback.dataset import FeedbackDataset from argilla.client.feedback.schemas import ( @@ -38,6 +44,12 @@ "LabelQuestionStrategy", "MultiLabelQuestionStrategy", "RatingQuestionStrategy", + "RankingQuestionStrategy", + "LabelQuestionUnification", + "MultiLabelQuestionUnification", + "RatingQuestionUnification", + "RankingQuestionUnification", + "TrainingTask", "TrainingTaskMapping", "FeedbackDataset", "FeedbackRecord", diff --git a/src/argilla/training/base.py b/src/argilla/training/base.py index f501bb7122..5ef1a2d2d0 100644 --- a/src/argilla/training/base.py +++ b/src/argilla/training/base.py @@ -281,7 +281,7 @@ def predict(self, text: Union[List[str], str], as_argilla_records: bool = True, """ return self._trainer.predict(text=text, as_argilla_records=as_argilla_records, **kwargs) - def train(self, output_dir: str = None): + def train(self, output_dir: str): """ `train` takes in a path to a file and trains the model. If a path is provided, the model is saved to that path. diff --git a/src/argilla/training/transformers.py b/src/argilla/training/transformers.py index 71fa324bec..77ed5e1306 100644 --- a/src/argilla/training/transformers.py +++ b/src/argilla/training/transformers.py @@ -326,14 +326,7 @@ def compute_metrics(p): return func def train(self, output_dir: str): - """ - We create a SetFitModel object from a pretrained model, then create a SetFitTrainer object with - the model, and then train the model - """ - from transformers import ( - Trainer, - TrainingArguments, - ) + from transformers import Trainer, TrainingArguments # check required path argument self.trainer_kwargs["output_dir"] = output_dir diff --git a/tests/integration/client/feedback/test_dataset.py b/tests/integration/client/feedback/test_dataset.py index 4d93828e3e..3fc68f6564 100644 --- a/tests/integration/client/feedback/test_dataset.py +++ b/tests/integration/client/feedback/test_dataset.py @@ -27,7 +27,7 @@ TextField, TextQuestion, ) -from argilla.client.feedback.training.schemas import TrainingTaskMapping +from argilla.client.feedback.training.schemas import TrainingTask from argilla.client.models import Framework if TYPE_CHECKING: @@ -690,6 +690,6 @@ def test_prepare_for_training_text_classification( ) dataset.add_records(feedback_dataset_records) label = dataset.question_by_name(question) - task_mapping = TrainingTaskMapping.for_text_classification(text=dataset.fields[0], label=label) + task = TrainingTask.for_text_classification(text=dataset.fields[0], label=label) - dataset.prepare_for_training(framework=framework, task_mapping=task_mapping, fetch_records=False) + dataset.prepare_for_training(framework=framework, task=task) diff --git a/tests/integration/client/feedback/training/test_trainer.py b/tests/integration/client/feedback/training/test_trainer.py index 2bf876ed5a..21657d1948 100644 --- a/tests/integration/client/feedback/training/test_trainer.py +++ b/tests/integration/client/feedback/training/test_trainer.py @@ -12,13 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, List, Union +import random +from collections import Counter +from typing import TYPE_CHECKING, Callable, List, Union import pytest if TYPE_CHECKING: from argilla.client.feedback.schemas.types import AllowedFieldTypes, AllowedQuestionTypes +import re import shutil import sys from pathlib import Path @@ -31,8 +34,12 @@ ) from argilla.client.feedback.training import ArgillaTrainer from argilla.client.feedback.training.schemas import ( + TrainingTask, + TrainingTaskForTextClassification, + TrainingTaskForTextClassificationFormat, TrainingTaskMapping, TrainingTaskMappingForTextClassification, + TrainingTaskTypes, ) from argilla.client.feedback.unification import LabelQuestionUnification from argilla.client.models import Framework @@ -59,7 +66,7 @@ "feedback_dataset_questions", "feedback_dataset_records", ) -def test_prepare_for_training_text_classification( +def test_prepare_for_training_text_classification_with_defaults( framework: Union[Framework, str], feedback_dataset_guidelines: str, feedback_dataset_fields: List["AllowedFieldTypes"], @@ -77,28 +84,22 @@ def test_prepare_for_training_text_classification( question for question in dataset.questions if isinstance(question, (LabelQuestion, MultiLabelQuestion)) ] label = LabelQuestionUnification(question=questions[0]) - task_mapping = TrainingTaskMapping.for_text_classification(text=dataset.fields[0], label=label) + task = TrainingTask.for_text_classification(text=dataset.fields[0], label=label) if framework == Framework("span_marker"): with pytest.raises( NotImplementedError, - match=f"Framework {framework} is not supported for this {TrainingTaskMappingForTextClassification}.", + match=f"Framework {framework} is not supported for this {TrainingTaskForTextClassification}.", ): - trainer = ArgillaTrainer( - dataset=dataset, task_mapping=task_mapping, framework=framework, fetch_records=False - ) + trainer = ArgillaTrainer(dataset=dataset, task=task, framework=framework) elif framework == Framework("spark-nlp"): with pytest.raises(NotImplementedError, match=f"{framework} is not a valid framework."): - trainer = ArgillaTrainer( - dataset=dataset, task_mapping=task_mapping, framework=framework, fetch_records=False - ) + trainer = ArgillaTrainer(dataset=dataset, task=task, framework=framework) else: if framework in [Framework("peft")] and sys.version_info < (3, 9): pass else: - trainer = ArgillaTrainer( - dataset=dataset, task_mapping=task_mapping, framework=framework, fetch_records=False - ) + trainer = ArgillaTrainer(dataset=dataset, task=task, framework=framework) if framework in [Framework("spacy"), Framework("spacy-transformers")]: trainer.update_config(max_steps=1) elif framework in [Framework("transformers"), Framework("setfit")]: @@ -107,3 +108,90 @@ def test_prepare_for_training_text_classification( if Path(__OUTPUT_DIR__).exists(): shutil.rmtree(__OUTPUT_DIR__) + + +@pytest.mark.usefixtures( + "feedback_dataset_guidelines", + "feedback_dataset_fields", + "feedback_dataset_questions", + "feedback_dataset_records", +) +def test_prepare_for_training_text_classification_with_formatting_func( + feedback_dataset_guidelines: str, + feedback_dataset_fields: List["AllowedFieldTypes"], + feedback_dataset_questions: List["AllowedQuestionTypes"], + feedback_dataset_records: List[FeedbackRecord], +): + dataset = FeedbackDataset( + guidelines=feedback_dataset_guidelines, + fields=feedback_dataset_fields, + questions=feedback_dataset_questions, + ) + dataset.add_records(records=feedback_dataset_records * 5) + framework = Framework("setfit") + + def wrong_formatting_func(sample): + text = sample["text"] + values = [resp["value"] for resp in sample["question-3"]] + counter = Counter(values) + if counter: + most_common = counter.most_common() + max_frequency = most_common[0][1] + most_common_elements = [element for element, frequency in most_common if frequency == max_frequency] + label = random.choice(most_common_elements) + return {"text": text, "label": label} + else: + return None + + def correct_formatting_func(sample): + data = wrong_formatting_func(sample) + if data: + return (data["text"], data["label"]) + else: + return None + + with pytest.raises( + ValueError, + match=re.escape( + f"formatting_func must return {TrainingTaskForTextClassificationFormat.__annotations__['format']}, not " + ), + ): + task = TrainingTask.for_text_classification(wrong_formatting_func) + trainer = ArgillaTrainer(dataset=dataset, task=task, framework=framework) + trainer.update_config(num_iterations=1) + trainer.train(__OUTPUT_DIR__) + + task = TrainingTask.for_text_classification(correct_formatting_func) + trainer = ArgillaTrainer(dataset=dataset, task=task, framework=framework) + trainer.update_config(num_iterations=1) + trainer.train(__OUTPUT_DIR__) + + +@pytest.mark.parametrize( + "callable", + ( + lambda: TrainingTaskMapping.for_text_classification(None, None), + lambda: TrainingTaskMapping.for_direct_preference_optimization(None), + lambda: TrainingTaskMapping.for_reward_modeling(None), + lambda: TrainingTaskMapping.for_supervised_fine_tuning(None), + ), +) +def test_deprecations(callable: Callable[[], TrainingTaskTypes]) -> None: + with pytest.warns(DeprecationWarning, match="`TrainingTaskMapping` has been renamed to `TrainingTask`"): + # This'll crash because we're passing None, but we only test the warning + try: + callable() + except Exception: + pass + + +def test_deprecations_for_text_classification(): + with pytest.warns( + DeprecationWarning, + match="`TrainingTaskMappingForTextClassification` has been renamed to `TrainingTaskForTextClassification`", + ): + # This'll crash because we're passing None, but we only test the warning + try: + TrainingTaskMappingForTextClassification(None) + except Exception: + pass diff --git a/tests/integration/client/feedback/training/test_trl.py b/tests/integration/client/feedback/training/test_trl.py new file mode 100644 index 0000000000..5e4bf2b3c2 --- /dev/null +++ b/tests/integration/client/feedback/training/test_trl.py @@ -0,0 +1,258 @@ +# Copyright 2021-present, the Recognai S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re +from collections import Counter +from typing import TYPE_CHECKING, Any, Dict, Iterator, List + +import pytest +from argilla.client.feedback.dataset import FeedbackDataset +from argilla.client.feedback.schemas.records import FeedbackRecord +from argilla.client.feedback.training.base import ArgillaTrainer +from argilla.client.feedback.training.schemas import ( + TrainingTask, + TrainingTaskForDPOFormat, + TrainingTaskForPPOFormat, + TrainingTaskForRMFormat, + TrainingTaskForSFTFormat, +) +from datasets import Dataset, DatasetDict + +from tests.integration.training.helpers import train_with_cleanup + +if TYPE_CHECKING: + from argilla.client.feedback.schemas.types import AllowedFieldTypes, AllowedQuestionTypes + +__OUTPUT_DIR__ = "tmp" +__FRAMWORK__ = "trl" + + +def try_wrong_format(dataset, task, format_func: Any) -> None: + task = task(lambda _: {"test": "test"}) + with pytest.raises( + ValueError, + match=re.escape(f"formatting_func must return {format_func.__annotations__['format']}, not "), + ): + trainer = ArgillaTrainer(dataset=dataset, task=task, framework=__FRAMWORK__) + trainer.train(__OUTPUT_DIR__) + + +def test_prepare_for_training_sft( + feedback_dataset_guidelines: str, + feedback_dataset_fields: List["AllowedFieldTypes"], + feedback_dataset_questions: List["AllowedQuestionTypes"], + feedback_dataset_records: List[FeedbackRecord], +) -> None: + dataset = FeedbackDataset( + guidelines=feedback_dataset_guidelines, + fields=feedback_dataset_fields, + questions=feedback_dataset_questions, + ) + dataset.add_records(records=feedback_dataset_records * 2) + + def formatting_func(sample: Dict[str, Any]) -> Iterator[str]: + # For example, the sample must be most frequently rated as "1" in question-2 and + # label "b" from "question-3" must have not been set by any annotator + ratings = [ + annotation["value"] + for annotation in sample["question-2"] + if annotation["status"] == "submitted" and annotation["value"] is not None + ] + labels = [ + annotation["value"] + for annotation in sample["question-3"] + if annotation["status"] == "submitted" and annotation["value"] is not None + ] + if ratings and Counter(ratings).most_common(1)[0][0] == 1 and "b" not in labels: + return f"### Text\n{sample['text']}" + return None + + try_wrong_format( + dataset=dataset, task=TrainingTask.for_supervised_fine_tuning, format_func=TrainingTaskForSFTFormat + ) + + task = TrainingTask.for_supervised_fine_tuning(formatting_func) + train_dataset = dataset.prepare_for_training(framework=__FRAMWORK__, task=task) + assert isinstance(train_dataset, Dataset) + assert len(train_dataset) == 2 + train_dataset_dict = dataset.prepare_for_training(framework=__FRAMWORK__, task=task, train_size=0.5) + assert isinstance(train_dataset_dict, DatasetDict) + assert tuple(train_dataset_dict.keys()) == ("train", "test") + assert len(train_dataset_dict["train"]) == 1 + + trainer = ArgillaTrainer(dataset, task, framework=__FRAMWORK__, model="sshleifer/tiny-gpt2") + trainer.update_config(max_steps=3) + assert trainer._trainer.training_args_kwargs["max_steps"] == 3 + trainer.update_config(max_steps=1) + assert trainer._trainer.training_args_kwargs["max_steps"] == 1 + train_with_cleanup(trainer, __OUTPUT_DIR__) + + eval_trainer = ArgillaTrainer(dataset, task, framework=__FRAMWORK__, model="sshleifer/tiny-gpt2", train_size=0.5) + eval_trainer.update_config(max_steps=1) + train_with_cleanup(eval_trainer, __OUTPUT_DIR__) + + +def test_prepare_for_training_rm( + feedback_dataset_guidelines: str, + feedback_dataset_fields: List["AllowedFieldTypes"], + feedback_dataset_questions: List["AllowedQuestionTypes"], + feedback_dataset_records: List[FeedbackRecord], +) -> None: + dataset = FeedbackDataset( + guidelines=feedback_dataset_guidelines, + fields=feedback_dataset_fields, + questions=feedback_dataset_questions, + ) + dataset.add_records(records=feedback_dataset_records * 2) + + def formatting_func(sample: Dict[str, Any]): + # The FeedbackDataset isn't really set up for RM, so we'll just use an arbitrary example here + labels = [ + annotation["value"] + for annotation in sample["question-3"] + if annotation["status"] == "submitted" and annotation["value"] is not None + ] + if labels: + # Three cases for the tests: None, one tuple and yielding multiple tuples + if labels[0] == "a": + return None + elif labels[0] == "b": + return sample["text"], sample["text"][:5] + elif labels[0] == "c": + return [(sample["text"], sample["text"][5:10]), (sample["text"], sample["text"][:5])] + + try_wrong_format(dataset=dataset, task=TrainingTask.for_reward_modeling, format_func=TrainingTaskForRMFormat) + + task = TrainingTask.for_reward_modeling(formatting_func) + train_dataset = dataset.prepare_for_training(framework=__FRAMWORK__, task=task) + assert isinstance(train_dataset, Dataset) + assert len(train_dataset) == 2 + train_dataset_dict = dataset.prepare_for_training(framework=__FRAMWORK__, task=task, train_size=0.5) + assert isinstance(train_dataset_dict, DatasetDict) + assert tuple(train_dataset_dict.keys()) == ("train", "test") + assert len(train_dataset_dict["train"]) == 1 + + trainer = ArgillaTrainer(dataset, task, framework=__FRAMWORK__, model="sshleifer/tiny-gpt2") + trainer.update_config(max_steps=3) + assert trainer._trainer.training_args_kwargs["max_steps"] == 3 + trainer.update_config(max_steps=1) + assert trainer._trainer.training_args_kwargs["max_steps"] == 1 + train_with_cleanup(trainer, __OUTPUT_DIR__) + + eval_trainer = ArgillaTrainer(dataset, task, framework=__FRAMWORK__, model="sshleifer/tiny-gpt2", train_size=0.5) + eval_trainer.update_config(max_steps=1) + train_with_cleanup(eval_trainer, __OUTPUT_DIR__) + + +def test_prepare_for_training_ppo( + feedback_dataset_guidelines: str, + feedback_dataset_fields: List["AllowedFieldTypes"], + feedback_dataset_questions: List["AllowedQuestionTypes"], + feedback_dataset_records: List[FeedbackRecord], +) -> None: + from transformers import pipeline + from trl import PPOConfig + + reward_model = pipeline("sentiment-analysis", model="lvwerra/distilbert-imdb") + dataset = FeedbackDataset( + guidelines=feedback_dataset_guidelines, + fields=feedback_dataset_fields, + questions=feedback_dataset_questions, + ) + dataset.add_records(records=feedback_dataset_records * 2) + + def formatting_func(sample: Dict[str, Any]): + return sample["text"] + + try_wrong_format( + dataset=dataset, task=TrainingTask.for_proximal_policy_optimization, format_func=TrainingTaskForPPOFormat + ) + + task = TrainingTask.for_proximal_policy_optimization(formatting_func=formatting_func) + train_dataset = dataset.prepare_for_training(framework=__FRAMWORK__, task=task) + assert isinstance(train_dataset, Dataset) + assert len(train_dataset) == 2 + train_dataset_dict = dataset.prepare_for_training(framework=__FRAMWORK__, task=task, train_size=0.5) + assert isinstance(train_dataset_dict, DatasetDict) + assert tuple(train_dataset_dict.keys()) == ("train", "test") + assert len(train_dataset_dict["train"]) == 1 + + trainer = ArgillaTrainer(dataset, task, framework=__FRAMWORK__, model="sshleifer/tiny-gpt2") + trainer.update_config(config=PPOConfig(batch_size=1, ppo_epochs=1), reward_model=reward_model) + assert trainer._trainer.trainer_kwargs["config"].batch_size == 1 + trainer.update_config(generation_kwargs={"top_k": 0.0, "top_p": 1.0, "do_sample": True}) + assert trainer._trainer.training_args_kwargs["generation_kwargs"]["top_p"] == 1.0 + train_with_cleanup(trainer, __OUTPUT_DIR__) + + eval_trainer = ArgillaTrainer(dataset, task, framework=__FRAMWORK__, model="sshleifer/tiny-gpt2", train_size=0.5) + eval_trainer.update_config(config=PPOConfig(batch_size=1, ppo_epochs=1), reward_model=reward_model) + eval_trainer.update_config(max_steps=1) + train_with_cleanup(eval_trainer, __OUTPUT_DIR__) + + +def test_prepare_for_training_dpo( + feedback_dataset_guidelines: str, + feedback_dataset_fields: List["AllowedFieldTypes"], + feedback_dataset_questions: List["AllowedQuestionTypes"], + feedback_dataset_records: List[FeedbackRecord], +) -> None: + dataset = FeedbackDataset( + guidelines=feedback_dataset_guidelines, + fields=feedback_dataset_fields, + questions=feedback_dataset_questions, + ) + dataset.add_records(records=feedback_dataset_records * 2) + + def formatting_func(sample: Dict[str, Any]): + # The FeedbackDataset isn't really set up for DPO, so we'll just use an arbitrary example here + labels = [ + annotation["value"] + for annotation in sample["question-3"] + if annotation["status"] == "submitted" and annotation["value"] is not None + ] + if labels: + # Three cases for the tests: None, one tuple and yielding multiple tuples + if labels[0] == "a": + return None + elif labels[0] == "b": + return sample["text"][::-1], sample["text"], sample["text"][:5] + elif labels[0] == "c": + return [ + (sample["text"], sample["text"][::-1], sample["text"][:5]), + (sample["text"][::-1], sample["text"], sample["text"][:5]), + ] + + try_wrong_format( + dataset=dataset, task=TrainingTask.for_direct_preference_optimization, format_func=TrainingTaskForDPOFormat + ) + + task = TrainingTask.for_direct_preference_optimization(formatting_func) + train_dataset = dataset.prepare_for_training(framework=__FRAMWORK__, task=task) + assert isinstance(train_dataset, Dataset) + assert len(train_dataset) == 2 + train_dataset_dict = dataset.prepare_for_training(framework=__FRAMWORK__, task=task, train_size=0.5) + assert isinstance(train_dataset_dict, DatasetDict) + assert tuple(train_dataset_dict.keys()) == ("train", "test") + assert len(train_dataset_dict["train"]) == 1 + + trainer = ArgillaTrainer(dataset, task, framework=__FRAMWORK__, model="sshleifer/tiny-gpt2") + trainer.update_config(max_steps=3) + assert trainer._trainer.training_args_kwargs["max_steps"] == 3 + trainer.update_config(max_steps=1) + assert trainer._trainer.training_args_kwargs["max_steps"] == 1 + train_with_cleanup(trainer, __OUTPUT_DIR__) + + eval_trainer = ArgillaTrainer(dataset, task, framework=__FRAMWORK__, model="sshleifer/tiny-gpt2", train_size=0.5) + eval_trainer.update_config(max_steps=1) + train_with_cleanup(eval_trainer, __OUTPUT_DIR__) diff --git a/tests/integration/training/test_autotrain.py b/tests/integration/training/test_autotrain.py deleted file mode 100644 index e4da70c872..0000000000 --- a/tests/integration/training/test_autotrain.py +++ /dev/null @@ -1,78 +0,0 @@ -# Copyright 2021-present, the Recognai S.L. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# import os - -# import pytest -# from argilla.training import ArgillaTrainer - -# FRAMEWORK = "autotrain" -# MODELS = ["prajjwal1/bert-tiny", "autotrain"] -# _HF_HUB_ACCESS_TOKEN = os.environ.get("HF_AUTH_TOKEN") or os.environ.get("HF_HUB_ACCESS_TOKEN") - - -# @pytest.mark.skipif( -# _HF_HUB_ACCESS_TOKEN is None, -# reason="You need a HF Hub access token to test the push_to_hub feature", -# ) -# @pytest.mark.parametrize("model", MODELS) -# def test_update_config(dataset_text_classification, model): -# trainer = ArgillaTrainer( -# name=dataset_text_classification, model=model, train_size=0.8, limit=10, framework=FRAMEWORK -# ) -# trainer.update_config(autotrain=[{"num_models": 1}]) -# assert trainer._trainer.trainer_kwargs["autotrain"][0]["num_models"] == 1 -# trainer.update_config(hub_model=[{"epochs": 1}]) -# assert trainer._trainer.trainer_kwargs["hub_model"][0]["epochs"] == 1 -# trainer.train() - - -# @pytest.mark.skipif( -# _HF_HUB_ACCESS_TOKEN is None, -# reason="You need a HF Hub access token to test the push_to_hub feature", -# ) -# @pytest.mark.parametrize("model", MODELS) -# def test_passed_functions(dataset_text_classification, model): -# trainer = ArgillaTrainer(name=dataset_text_classification, model=model, limit=10, framework=FRAMEWORK) -# trainer._trainer.init_model() -# trainer._trainer.init_pipeline() -# trainer._trainer.predict("useless") -# trainer._trainer.save(output_dir="useless") - - -# @pytest.mark.skipif( -# _HF_HUB_ACCESS_TOKEN is None, -# reason="You need a HF Hub access token to test the push_to_hub feature", -# ) -# def test_autotrain_train_multi_label(dataset_text_classification_multi_label): -# with pytest.raises(NotImplementedError): -# ArgillaTrainer(name=dataset_text_classification_multi_label, model=MODELS[0], limit=10, framework=FRAMEWORK) - - -# @pytest.mark.skipif( -# _HF_HUB_ACCESS_TOKEN is None, -# reason="You need a HF Hub access token to test the push_to_hub feature", -# ) -# def test_autotrain_train_token(dataset_token_classification): -# with pytest.raises(NotImplementedError): -# ArgillaTrainer(name=dataset_token_classification, model=MODELS[0], limit=10, framework=FRAMEWORK) - - -# @pytest.mark.skipif( -# _HF_HUB_ACCESS_TOKEN is None, -# reason="You need a HF Hub access token to test the push_to_hub feature", -# ) -# def test_autotrain_train_text2text(dataset_text2text): -# with pytest.raises(NotImplementedError): -# ArgillaTrainer(name=dataset_text2text, model=MODELS[0], limit=10, framework=FRAMEWORK) diff --git a/tests/unit/client/feedback/training/test_schemas.py b/tests/unit/client/feedback/training/test_schemas.py index e5486825e2..eda9ef4f7a 100644 --- a/tests/unit/client/feedback/training/test_schemas.py +++ b/tests/unit/client/feedback/training/test_schemas.py @@ -22,7 +22,7 @@ RatingQuestion, TextField, ) -from argilla.client.feedback.training.schemas import TrainingTaskMapping +from argilla.client.feedback.training.schemas import TrainingTask from argilla.client.feedback.unification import ( LabelQuestionUnification, MultiLabelQuestionUnification, @@ -307,7 +307,7 @@ ), ], ) -def test_task_mapping_for_text_classification( +def test_task_for_text_classification( framework, label, train_size, @@ -327,19 +327,19 @@ def test_task_mapping_for_text_classification( label = MultiLabelQuestionUnification(question=MultiLabelQuestion(**label_question_payload)) data = [{"text": "This is a text", "label": "1"}, {"text": "This is a text", "label": "2"}] field = TextField(name="text") - task_mapping = TrainingTaskMapping.for_text_classification(text=field, label=label) + task = TrainingTask.for_text_classification(text=field, label=label) if framework == Framework.SPACY or framework == Framework.SPACY_TRANSFORMERS: - data = task_mapping._prepare_for_training_with_spacy( + data = task._prepare_for_training_with_spacy( data=data, train_size=train_size, seed=seed, lang=spacy.blank("en") ) elif framework == Framework.OPENAI: - data = task_mapping._prepare_for_training_with_openai(data=data, train_size=train_size, seed=seed) + data = task._prepare_for_training_with_openai(data=data, train_size=train_size, seed=seed) elif framework == Framework.TRANSFORMERS: - data = task_mapping._prepare_for_training_with_transformers( + data = task._prepare_for_training_with_transformers( data=data, train_size=train_size, seed=seed, framework=Framework.TRANSFORMERS ) elif framework == Framework.SPARK_NLP: - data = task_mapping._prepare_for_training_with_spark_nlp(data=data, train_size=train_size, seed=seed) + data = task._prepare_for_training_with_spark_nlp(data=data, train_size=train_size, seed=seed) else: raise ValueError(f"Framework {framework} not supported") if isinstance(data, tuple): @@ -352,5 +352,5 @@ def test_task_mapping_for_text_classification( def test_training_task_repr(label_question_payload): field = TextField(name="text") label = LabelQuestion(**label_question_payload) - task_mapping = TrainingTaskMapping.for_text_classification(text=field, label=label) + task_mapping = TrainingTask.for_text_classification(text=field, label=label) assert isinstance(repr(task_mapping), str)