From 79fd5fb63fd46357d0af842682969703e6b7d2ae Mon Sep 17 00:00:00 2001 From: Paco Aranda Date: Tue, 17 Oct 2023 14:23:24 +0200 Subject: [PATCH 01/12] chore: Apply changes for merge from develop property --- CHANGELOG.md | 4 + docker/scripts/load_data.py | 2 +- docs/_source/practical_guides/fine_tune.md | 35 + src/argilla/cli/datasets/__main__.py | 2 +- src/argilla/cli/datasets/list.py | 2 +- src/argilla/cli/datasets/push.py | 2 +- src/argilla/client/apis/datasets.py | 18 +- .../client/feedback/dataset/__init__.py | 2 +- src/argilla/client/feedback/dataset/base.py | 163 +---- .../dataset/{local.py => local/dataset.py} | 278 ++++++-- .../feedback/dataset/{ => local}/mixins.py | 54 +- .../client/feedback/dataset/remote/dataset.py | 113 ++- .../integrations/huggingface/card/__init__.py | 5 +- .../huggingface/model_card/__init__.py | 39 + .../model_card/argilla_model_template.md | 170 +++++ .../huggingface/model_card/model_card.py | 664 ++++++++++++++++++ src/argilla/client/feedback/training/base.py | 57 +- .../feedback/training/frameworks/openai.py | 23 + .../feedback/training/frameworks/peft.py | 23 + .../frameworks/sentence_transformers.py | 22 + .../feedback/training/frameworks/setfit.py | 23 + .../feedback/training/frameworks/spacy.py | 54 +- .../training/frameworks/span_marker.py | 10 + .../training/frameworks/transformers.py | 31 + .../feedback/training/frameworks/trl.py | 32 + .../client/feedback/training/schemas.py | 17 +- src/argilla/client/feedback/utils.py | 2 +- src/argilla/client/models.py | 15 + src/argilla/client/sdk/datasets/api.py | 6 +- src/argilla/datasets/__init__.py | 10 +- tests/integration/client/feedback/conftest.py | 556 +++++++++++++++ .../dataset/{ => local}/test_dataset.py | 47 +- .../feedback/dataset/remote/test_dataset.py | 25 + .../dataset/remote/test_filter_and_sorting.py | 2 +- .../huggingface/test_model_card.py | 402 +++++++++++ tests/integration/test_datasets_settings.py | 69 +- tests/unit/cli/datasets/test_delete.py | 4 +- tests/unit/cli/datasets/test_list.py | 8 +- tests/unit/cli/datasets/test_push.py | 4 +- .../unit/client/feedback/dataset/test_base.py | 19 + .../client/feedback/dataset/test_local.py | 8 +- .../integrations/huggingface/__init__.py | 13 + .../huggingface/test_model_card.py | 61 ++ tests/unit/server/api/v1/test_records.py | 2 +- 44 files changed, 2827 insertions(+), 271 deletions(-) rename src/argilla/client/feedback/dataset/{local.py => local/dataset.py} (53%) rename src/argilla/client/feedback/dataset/{ => local}/mixins.py (90%) create mode 100644 src/argilla/client/feedback/integrations/huggingface/model_card/__init__.py create mode 100644 src/argilla/client/feedback/integrations/huggingface/model_card/argilla_model_template.md create mode 100644 src/argilla/client/feedback/integrations/huggingface/model_card/model_card.py rename tests/integration/client/feedback/dataset/{ => local}/test_dataset.py (94%) create mode 100644 tests/integration/client/feedback/integrations/huggingface/test_model_card.py create mode 100644 tests/unit/client/feedback/integrations/huggingface/__init__.py create mode 100644 tests/unit/client/feedback/integrations/huggingface/test_model_card.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 2791f2c024..d62a7ef7d2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -32,6 +32,7 @@ These are the section headers that we use: - Add support for update records (`metadata` and `external_id`) from Python SDK ([#3946](https://github.com/argilla-io/argilla/pull/3946)). 
- Added `delete_metadata_properties` and `Remote{Terms,Integer,Float}MetadataProperty.delete` methods to delete metadata properties ([#3932](https://github.com/argilla-io/argilla/pull/3932)).
- New `PATCH /api/v1/metadata-properties/:metadata_property_id` endpoint allowing the update of a specific metadata property. ([#3952](https://github.com/argilla-io/argilla/pull/3952)).
+- Added automatic model card generation through `ArgillaTrainer.save` ([#3857](https://github.com/argilla-io/argilla/pull/3857)).

 ### Changed

@@ -42,11 +43,14 @@ These are the section headers that we use:
- Updated `FilteredRemoteFeedbackRecords.__len__` method to return the number of records matching the provided filters ([#3916](https://github.com/argilla-io/argilla/pull/3916)).
- Increase the default max result window for Elasticsearch created for Feedback datasets ([#3929](https://github.com/argilla-io/argilla/pull/)).
- Force elastic index refresh after records creation ([#3929](https://github.com/argilla-io/argilla/pull/)).
+- `FeedbackDataset` API methods have been aligned to be accessible across its different implementations ([#3937](https://github.com/argilla-io/argilla/pull/3937)).
+- Added `unify_responses` support for remote datasets ([#3937](https://github.com/argilla-io/argilla/pull/3937)).

 ### Fixed

- Updated active learning for text classification notebooks to pass ids of type int to `TextClassificationRecord` ([#3831](https://github.com/argilla-io/argilla/pull/3831)).
- Fixed record fields validation that was preventing logging records with optional fields (i.e. `required=False`) when the field value was `None` ([#3846](https://github.com/argilla-io/argilla/pull/3846)).
+- Fixed `configure_dataset_settings` when providing the workspace via the arg `workspace` ([#3887](https://github.com/argilla-io/argilla/pull/3887)).
- Fixed response schemas to allow `values` to be `None`, i.e. when a record is discarded the `response.values` are set to `None` ([#3926](https://github.com/argilla-io/argilla/pull/3926)).
- The `inserted_at` and `updated_at` attributes are created using the `utcnow` factory to avoid unexpected race conditions on timestamp creation ([#3945](https://github.com/argilla-io/argilla/pull/3945)).
- Fixed saving of models trained with `ArgillaTrainer` with a `peft_config` parameter ([#3795](https://github.com/argilla-io/argilla/pull/3795)).
diff --git a/docker/scripts/load_data.py b/docker/scripts/load_data.py
index d3980cef5c..a169f269cf 100644
--- a/docker/scripts/load_data.py
+++ b/docker/scripts/load_data.py
@@ -271,7 +271,7 @@ def load_error_analysis_textcat_version():
    load_datasets = LOAD_DATASETS.lower().strip()
    while True:
        try:
-            response = requests.get("http://0.0.0.0:6900/")
+            response = requests.get("http://0.0.0.0:6900")
            if response.status_code == 200:
                ld = LoadDatasets(API_KEY)
                ld.load_error_analysis(with_metadata_property_options=False)
diff --git a/docs/_source/practical_guides/fine_tune.md b/docs/_source/practical_guides/fine_tune.md
index d82ceed945..886a856722 100644
--- a/docs/_source/practical_guides/fine_tune.md
+++ b/docs/_source/practical_guides/fine_tune.md
@@ -82,6 +82,41 @@ A `TrainingTask` is used to define how the data should be processed and formatte
| for_direct_preference_optimization | `prompt-chosen-rejected` | `Union[Tuple[str, str, str], Iterator[Tuple[str, str, str]]]` | ✗ |
| for_chat_completion | `chat-turn-role-content` | `Union[Tuple[str, str, str, str], Iterator[Tuple[str, str, str, str]]]`| ✗ |

+#### Model card generation
+
+The `ArgillaTrainer` automatically generates a [model card](https://huggingface.co/docs/hub/model-cards) when saving the model. After calling `trainer.train(output_dir="my_model")`, you should see the model card under the same output dir you passed to the `train` method: `./my_model/README.md`. Most of the fields in the card are automatically generated when possible, but the following fields can be (optionally) updated via the `framework_kwargs` variable of the `ArgillaTrainer` like so:
+
+```python
+model_card_kwargs = {
+    "language": ["en", "es"],
+    "license": "Apache-2.0",
+    "model_id": "all-MiniLM-L6-v2",
+    "dataset_name": "argilla/emotion",
+    "tags": ["nlp", "few-shot-learning", "argilla", "setfit"],
+    "model_summary": "Small summary of what the model does",
+    "model_description": "An extended explanation of the model",
+    "model_type": "A 1.3B parameter embedding model fine-tuned on an awesome dataset",
+    "finetuned_from": "all-MiniLM-L6-v2",
+    "repo": "https://github.com/...",
+ "developers": "", + "shared_by": "", +} + +trainer = ArgillaTrainer( + dataset=dataset, + task=task, + framework="setfit", + framework_kwargs={"model_card_kwargs": model_card_kwargs} +) +trainer.train(output_dir="my_model") +``` + +Even though its generated internally, you can get the card by calling the `generate_model_card` method: + +```python +argilla_model_card = trainer.generate_model_card("my_model") +``` + ### Tasks #### Text Classification diff --git a/src/argilla/cli/datasets/__main__.py b/src/argilla/cli/datasets/__main__.py index e6b4b60086..c6c022609f 100644 --- a/src/argilla/cli/datasets/__main__.py +++ b/src/argilla/cli/datasets/__main__.py @@ -32,7 +32,7 @@ def callback( init_callback() from argilla.cli.rich import echo_in_panel - from argilla.client.feedback.dataset.local import FeedbackDataset + from argilla.client.feedback.dataset.local.dataset import FeedbackDataset if ctx.invoked_subcommand not in _COMMANDS_REQUIRING_DATASET: return diff --git a/src/argilla/cli/datasets/list.py b/src/argilla/cli/datasets/list.py index 8b9b804138..88bd75c209 100644 --- a/src/argilla/cli/datasets/list.py +++ b/src/argilla/cli/datasets/list.py @@ -31,7 +31,7 @@ def list_datasets( from argilla.cli.rich import echo_in_panel, get_argilla_themed_table from argilla.client.api import list_datasets as list_datasets_api - from argilla.client.feedback.dataset.local import FeedbackDataset + from argilla.client.feedback.dataset.local.dataset import FeedbackDataset from argilla.client.workspaces import Workspace console = Console() diff --git a/src/argilla/cli/datasets/push.py b/src/argilla/cli/datasets/push.py index 93a64fd54e..5b1c386a25 100644 --- a/src/argilla/cli/datasets/push.py +++ b/src/argilla/cli/datasets/push.py @@ -29,7 +29,7 @@ def push_to_huggingface( from rich.spinner import Spinner from argilla.cli.rich import echo_in_panel - from argilla.client.feedback.dataset.local import FeedbackDataset + from argilla.client.feedback.dataset.local.dataset import FeedbackDataset dataset: "FeedbackDataset" = ctx.obj diff --git a/src/argilla/client/apis/datasets.py b/src/argilla/client/apis/datasets.py index cc58513654..93d8682ee9 100644 --- a/src/argilla/client/apis/datasets.py +++ b/src/argilla/client/apis/datasets.py @@ -131,8 +131,8 @@ class _DatasetApiModel(BaseModel): class _SettingsApiModel(BaseModel): label_schema: Dict[str, Any] - def find_by_name(self, name: str) -> _DatasetApiModel: - dataset = get_dataset(self.http_client, name=name).parsed + def find_by_name(self, name: str, workspace: Optional[str] = None) -> _DatasetApiModel: + dataset = get_dataset(self.http_client, name=name, workspace=workspace).parsed return self._DatasetApiModel.parse_obj(dataset) def create(self, name: str, task: TaskType, workspace: str) -> _DatasetApiModel: @@ -163,7 +163,7 @@ def configure(self, name: str, workspace: str, settings: Settings): ) ds = self.create(name=name, task=task, workspace=workspace) except AlreadyExistsApiError: - ds = self.find_by_name(name) + ds = self.find_by_name(name, workspace=workspace) self._save_settings(dataset=ds, settings=settings) def scan( @@ -322,7 +322,7 @@ def _save_settings(self, dataset: _DatasetApiModel, settings: Settings): try: with api_compatibility(self, min_version="1.4"): self.http_client.patch( - f"{self._API_PREFIX}/{dataset.name}/{dataset.task.value}/settings", + f"{self._API_PREFIX}/{dataset.name}/{dataset.task.value}/settings?workspace={dataset.workspace}", json=settings_.dict(), ) except ApiCompatibilityError: @@ -332,20 +332,24 @@ def _save_settings(self, 
dataset: _DatasetApiModel, settings: Settings): json=settings_.dict(), ) - def load_settings(self, name: str) -> Optional[Settings]: + def load_settings(self, name: str, workspace: Optional[str] = None) -> Optional[Settings]: """ Load the dataset settings Args: name: The dataset name + workspace: The workspace name where the dataset belongs to Returns: Settings defined for the dataset """ - dataset = self.find_by_name(name) + dataset = self.find_by_name(name, workspace=workspace) try: with api_compatibility(self, min_version="1.0"): - response = self.http_client.get(f"{self._API_PREFIX}/{dataset.name}/{dataset.task.value}/settings") + params = {"workspace": dataset.workspace} if dataset.workspace else {} + response = self.http_client.get( + f"{self._API_PREFIX}/{dataset.name}/{dataset.task.value}/settings", params=params + ) return __TASK_TO_SETTINGS__.get(dataset.task).from_dict(response) except NotFoundApiError: return None diff --git a/src/argilla/client/feedback/dataset/__init__.py b/src/argilla/client/feedback/dataset/__init__.py index 564e19262f..0f9646d49b 100644 --- a/src/argilla/client/feedback/dataset/__init__.py +++ b/src/argilla/client/feedback/dataset/__init__.py @@ -12,6 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from argilla.client.feedback.dataset.local import FeedbackDataset +from argilla.client.feedback.dataset.local.dataset import FeedbackDataset __all__ = ["FeedbackDataset"] diff --git a/src/argilla/client/feedback/dataset/base.py b/src/argilla/client/feedback/dataset/base.py index 1dd375954a..fc9c684056 100644 --- a/src/argilla/client/feedback/dataset/base.py +++ b/src/argilla/client/feedback/dataset/base.py @@ -23,23 +23,9 @@ FeedbackRecord, SortBy, ) -from argilla.client.feedback.schemas.enums import ResponseStatusFilter -from argilla.client.feedback.schemas.metadata import MetadataFilters from argilla.client.feedback.schemas.types import AllowedFieldTypes, AllowedMetadataPropertyTypes, AllowedQuestionTypes -from argilla.client.feedback.training.schemas import ( - TrainingTaskForChatCompletion, - TrainingTaskForDPO, - TrainingTaskForPPO, - TrainingTaskForQuestionAnswering, - TrainingTaskForRM, - TrainingTaskForSentenceSimilarity, - TrainingTaskForSFT, - TrainingTaskForTextClassification, - TrainingTaskTypes, -) from argilla.client.feedback.utils import generate_pydantic_schema_for_fields, generate_pydantic_schema_for_metadata -from argilla.client.models import Framework -from argilla.utils.dependency import require_dependencies, requires_dependencies +from argilla.utils.dependency import requires_dependencies if TYPE_CHECKING: from datasets import Dataset @@ -274,16 +260,6 @@ def sort_by(self, sort: List[SortBy]) -> "FeedbackDatasetBase": """Sorts the records in the dataset by the given field.""" pass - @abstractmethod - def filter_by( - self, - *, - response_status: Optional[Union[ResponseStatusFilter, List[ResponseStatusFilter]]] = None, - metadata_filters: Optional[Union[MetadataFilters, List[MetadataFilters]]] = None, - ) -> "FeedbackDatasetBase": - """Filters the records in the dataset by the given filters.""" - pass - def _build_fields_schema(self) -> Type[BaseModel]: """Returns the fields schema of the dataset.""" return generate_pydantic_schema_for_fields(self.fields) @@ -420,113 +396,40 @@ def format_as(self, format: Literal["datasets"]) -> "Dataset": return HuggingFaceDatasetMixin._huggingface_format(self) raise ValueError(f"Unsupported format '{format}'.") - # 
TODO(alvarobartt,davidberenstein1957): we should consider having something like - # `export(..., training=True)` to export the dataset records in any format, replacing - # both `format_as` and `prepare_for_training` - def prepare_for_training( - self, - framework: Union[Framework, str], - task: TrainingTaskTypes, - train_size: Optional[float] = 1, - test_size: Optional[float] = None, - seed: Optional[int] = None, - lang: Optional[str] = None, - ) -> Any: - """ - Prepares the dataset for training for a specific training framework and NLP task by splitting the dataset into train and test sets. + @abstractmethod + def add_records(self, *args, **kwargs) -> None: + """Adds the given records to the `FeedbackDataset`.""" + pass - Args: - framework: the framework to use for training. Currently supported frameworks are: `transformers`, `peft`, - `setfit`, `spacy`, `spacy-transformers`, `span_marker`, `spark-nlp`, `openai`, `trl`, `sentence-transformers`. - task: the NLP task to use for training. Currently supported tasks are: `TrainingTaskForTextClassification`, - `TrainingTaskForSFT`, `TrainingTaskForRM`, `TrainingTaskForPPO`, `TrainingTaskForDPO`, `TrainingTaskForSentenceSimilarity`. - train_size: the size of the train set. If `None`, the whole dataset will be used for training. - test_size: the size of the test set. If `None`, the whole dataset will be used for testing. - seed: the seed to use for splitting the dataset into train and test sets. - lang: the spaCy language to use for training. If `None`, the language of the dataset will be used. - """ - if isinstance(framework, str): - framework = Framework(framework) - - # validate train and test sizes - if train_size is None: - train_size = 1 - if test_size is None: - test_size = 1 - train_size - - # check if all numbers are larger than 0 - if not [abs(train_size), abs(test_size)] == [train_size, test_size]: - raise ValueError("`train_size` and `test_size` must be larger than 0.") - # check if train sizes sum up to 1 - if not (train_size + test_size) == 1: - raise ValueError("`train_size` and `test_size` must sum to 1.") - - if test_size == 0: - test_size = None - - if len(self.records) < 1: - raise ValueError( - "No records found in the dataset. Make sure you add records to the" - " dataset via the `FeedbackDataset.add_records` method first." 
- ) + @abstractmethod + def pull(self): + """Pulls the dataset from Argilla and returns a local instance of it.""" + pass - if isinstance(task, (TrainingTaskForTextClassification, TrainingTaskForSentenceSimilarity)): - if task.formatting_func is None: - # in sentence-transformer models we can train without labels - if task.label: - self.unify_responses(question=task.label.question, strategy=task.label.strategy) - elif isinstance(task, TrainingTaskForQuestionAnswering): - if task.formatting_func is None: - self.unify_responses(question=task.answer.name, strategy="disagreement") - elif not isinstance( - task, - ( - TrainingTaskForSFT, - TrainingTaskForRM, - TrainingTaskForPPO, - TrainingTaskForDPO, - TrainingTaskForChatCompletion, - ), - ): - raise ValueError(f"Training data {type(task)} is not supported yet") - - data = task._format_data(self) - if framework in [ - Framework.TRANSFORMERS, - Framework.SETFIT, - Framework.SPAN_MARKER, - Framework.PEFT, - ]: - return task._prepare_for_training_with_transformers( - data=data, train_size=train_size, seed=seed, framework=framework - ) - elif framework in [Framework.SPACY, Framework.SPACY_TRANSFORMERS]: - require_dependencies("spacy") - import spacy - - if lang is None: - _LOGGER.warning("spaCy `lang` is not provided. Using `en`(English) as default language.") - lang = spacy.blank("en") - elif lang.isinstance(str): - if len(lang) == 2: - lang = spacy.blank(lang) - else: - lang = spacy.load(lang) - return task._prepare_for_training_with_spacy(data=data, train_size=train_size, seed=seed, lang=lang) - elif framework is Framework.SPARK_NLP: - return task._prepare_for_training_with_spark_nlp(data=data, train_size=train_size, seed=seed) - elif framework is Framework.OPENAI: - return task._prepare_for_training_with_openai(data=data, train_size=train_size, seed=seed) - elif framework is Framework.TRL: - return task._prepare_for_training_with_trl(data=data, train_size=train_size, seed=seed) - elif framework is Framework.TRLX: - return task._prepare_for_training_with_trlx(data=data, train_size=train_size, seed=seed) - elif framework is Framework.SENTENCE_TRANSFORMERS: - return task._prepare_for_training_with_sentence_transformers(data=data, train_size=train_size, seed=seed) - else: - raise NotImplementedError( - f"Framework {framework} is not supported. 
Choose from: {[e.value for e in Framework]}" - ) + @abstractmethod + def filter_by(self, *args, **kwargs): + """Filters the current `FeedbackDataset`.""" + pass + + @abstractmethod + def delete(self): + """Deletes the `FeedbackDataset` from Argilla.""" + pass + + @abstractmethod + def prepare_for_training(self, *args, **kwargs) -> Any: + """Prepares the `FeedbackDataset` for training by creating the training.""" + pass + + @abstractmethod + def push_to_argilla(self, *args, **kwargs) -> "FeedbackDatasetBase": + """Pushes the `FeedbackDataset` to Argilla.""" + pass + + @abstractmethod + def unify_responses(self, *args, **kwargs): + """Unifies the responses for a given question.""" + pass @abstractmethod def add_metadata_property(self, metadata_property): diff --git a/src/argilla/client/feedback/dataset/local.py b/src/argilla/client/feedback/dataset/local/dataset.py similarity index 53% rename from src/argilla/client/feedback/dataset/local.py rename to src/argilla/client/feedback/dataset/local/dataset.py index 365d8ad8f5..01e2b9c0e8 100644 --- a/src/argilla/client/feedback/dataset/local.py +++ b/src/argilla/client/feedback/dataset/local/dataset.py @@ -11,16 +11,43 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import logging +import textwrap import warnings -from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, Union +from typing import Any, Dict, Iterator, List, Optional, TYPE_CHECKING, Union from argilla.client.feedback.constants import FETCHING_BATCH_SIZE from argilla.client.feedback.dataset.base import FeedbackDatasetBase -from argilla.client.feedback.dataset.mixins import ArgillaMixin, UnificationMixin -from argilla.client.feedback.schemas.enums import RecordSortField, ResponseStatusFilter, SortOrder -from argilla.client.feedback.schemas.metadata import MetadataFilters +from argilla.client.feedback.dataset.local.mixins import ArgillaMixin +from argilla.client.feedback.schemas.enums import RecordSortField, SortOrder +from argilla.client.feedback.schemas.questions import ( + LabelQuestion, + MultiLabelQuestion, + RankingQuestion, + RatingQuestion, + TextQuestion, +) from argilla.client.feedback.schemas.types import AllowedQuestionTypes +from argilla.client.feedback.training.schemas import ( + TrainingTaskForChatCompletion, + TrainingTaskForDPO, + TrainingTaskForPPO, + TrainingTaskForQuestionAnswering, + TrainingTaskForRM, + TrainingTaskForSFT, + TrainingTaskForSentenceSimilarity, + TrainingTaskForTextClassification, + TrainingTaskTypes, +) +from argilla.client.feedback.unification import ( + LabelQuestionStrategy, + MultiLabelQuestionStrategy, + RankingQuestionStrategy, + RatingQuestionStrategy, + TextQuestionStrategy, +) +from argilla.client.models import Framework +from argilla.utils.dependency import require_dependencies if TYPE_CHECKING: from argilla.client.feedback.schemas.records import FeedbackRecord @@ -31,7 +58,10 @@ ) -class FeedbackDataset(FeedbackDatasetBase["FeedbackRecord"], ArgillaMixin, UnificationMixin): +_LOGGER = logging.getLogger(__name__) + + +class FeedbackDataset(ArgillaMixin, FeedbackDatasetBase["FeedbackRecord"]): def __init__( self, *, @@ -128,15 +158,15 @@ def records(self) -> List["FeedbackRecord"]: """Returns the records in the dataset.""" return self._records - def update_records(self, records: Union["FeedbackRecord", List["FeedbackRecord"]]) -> None: - warnings.warn( - "`update_records` method only works for 
`FeedbackDataset` pushed to Argilla. " - "If your are working with local data, you can just iterate over the records and update them." - ) - def __repr__(self) -> str: """Returns a string representation of the dataset.""" - return f"" + return ( + "FeedbackDataset(" + + textwrap.indent( + f"\nfields={self.fields}\nquestions={self.questions}\nguidelines={self.guidelines})", " " + ) + + "\n)" + ) def __len__(self) -> int: """Returns the number of records in the dataset.""" @@ -169,8 +199,7 @@ def iter(self, batch_size: Optional[int] = FETCHING_BATCH_SIZE) -> Iterator[List yield self._records[i : i + batch_size] def add_records( - self, - records: Union["FeedbackRecord", Dict[str, Any], List[Union["FeedbackRecord", Dict[str, Any]]]], + self, records: Union["FeedbackRecord", Dict[str, Any], List[Union["FeedbackRecord", Dict[str, Any]]]] ) -> None: """Adds the given records to the dataset, and stores them locally. If you are planning to push those to Argilla, you will need to call `push_to_argilla` afterwards, @@ -197,31 +226,6 @@ def add_records( else: self._records = records - def sort_by( - self, field: Union[str, RecordSortField], order: Union[str, SortOrder] = SortOrder.asc - ) -> "FeedbackDataset": - warnings.warn( - "`sort_by` method only works for `FeedbackDataset` pushed to Argilla. " - "Use `sorted` with dataset.records instead.", - UserWarning, - stacklevel=1, - ) - return self - - def filter_by( - self, - *, - response_status: Optional[Union[ResponseStatusFilter, List[ResponseStatusFilter]]] = None, - metadata_filters: Optional[Union[MetadataFilters, List[MetadataFilters]]] = None, - ) -> "FeedbackDataset": - warnings.warn( - "`filter_by` method only works for `FeedbackDataset` pushed to Argilla. " - "Use `filter` with dataset.records instead.", - UserWarning, - stacklevel=1, - ) - return self - def add_metadata_property( self, metadata_property: "AllowedMetadataPropertyTypes" ) -> "AllowedMetadataPropertyTypes": @@ -282,3 +286,197 @@ def delete_metadata_properties( deleted_metadata_properties.append(metadata_properties_mapping.pop(metadata_property)) self._metadata_properties = list(metadata_properties_mapping.values()) return deleted_metadata_properties if len(deleted_metadata_properties) > 1 else deleted_metadata_properties[0] + + def unify_responses( + self: "FeedbackDatasetBase", + question: Union[str, LabelQuestion, MultiLabelQuestion, RatingQuestion], + strategy: Union[ + str, LabelQuestionStrategy, MultiLabelQuestionStrategy, RatingQuestionStrategy, RankingQuestionStrategy + ], + ) -> "FeedbackDataset": + """ + The `unify_responses` function takes a question and a strategy as input and applies the strategy + to unify the responses for that question. + + Args: + question The `question` parameter can be either a string representing the name of the + question, or an instance of one of the question classes (`LabelQuestion`, `MultiLabelQuestion`, + `RatingQuestion`, `RankingQuestion`). + strategy The `strategy` parameter is used to specify the strategy to be used for unifying + responses for a given question. It can be either a string or an instance of a strategy class. 
+ """ + if isinstance(question, str): + question = self.question_by_name(question) + + if isinstance(strategy, str): + if isinstance(question, LabelQuestion): + strategy = LabelQuestionStrategy(strategy) + elif isinstance(question, MultiLabelQuestion): + strategy = MultiLabelQuestionStrategy(strategy) + elif isinstance(question, RatingQuestion): + strategy = RatingQuestionStrategy(strategy) + elif isinstance(question, RankingQuestion): + strategy = RankingQuestionStrategy(strategy) + elif isinstance(question, TextQuestion): + strategy = TextQuestionStrategy(strategy) + else: + raise ValueError(f"Question {question} is not supported yet") + + strategy.unify_responses(self.records, question) + return self + + # TODO(alvarobartt,davidberenstein1957): we should consider having something like + # `export(..., training=True)` to export the dataset records in any format, replacing + # both `format_as` and `prepare_for_training` + def prepare_for_training( + self, + framework: Union[Framework, str], + task: TrainingTaskTypes, + train_size: Optional[float] = 1, + test_size: Optional[float] = None, + seed: Optional[int] = None, + lang: Optional[str] = None, + ) -> Any: + """ + Prepares the dataset for training for a specific training framework and NLP task by splitting the dataset into train and test sets. + + Args: + framework: the framework to use for training. Currently supported frameworks are: `transformers`, `peft`, + `setfit`, `spacy`, `spacy-transformers`, `span_marker`, `spark-nlp`, `openai`, `trl`, `sentence-transformers`. + task: the NLP task to use for training. Currently supported tasks are: `TrainingTaskForTextClassification`, + `TrainingTaskForSFT`, `TrainingTaskForRM`, `TrainingTaskForPPO`, `TrainingTaskForDPO`, `TrainingTaskForSentenceSimilarity`. + train_size: the size of the train set. If `None`, the whole dataset will be used for training. + test_size: the size of the test set. If `None`, the whole dataset will be used for testing. + seed: the seed to use for splitting the dataset into train and test sets. + lang: the spaCy language to use for training. If `None`, the language of the dataset will be used. + """ + if isinstance(framework, str): + framework = Framework(framework) + + # validate train and test sizes + if train_size is None: + train_size = 1 + if test_size is None: + test_size = 1 - train_size + + # check if all numbers are larger than 0 + if not [abs(train_size), abs(test_size)] == [train_size, test_size]: + raise ValueError("`train_size` and `test_size` must be larger than 0.") + # check if train sizes sum up to 1 + if not (train_size + test_size) == 1: + raise ValueError("`train_size` and `test_size` must sum to 1.") + + if test_size == 0: + test_size = None + + if len(self.records) < 1: + raise ValueError( + "No records found in the dataset. Make sure you add records to the" + " dataset via the `FeedbackDataset.add_records()` method first." 
+ ) + + local_dataset = self.pull() + if isinstance(task, (TrainingTaskForTextClassification, TrainingTaskForSentenceSimilarity)): + if task.formatting_func is None: + # in sentence-transformer models we can train without labels + if task.label: + local_dataset = local_dataset.unify_responses( + question=task.label.question, strategy=task.label.strategy + ) + elif isinstance(task, TrainingTaskForQuestionAnswering): + if task.formatting_func is None: + local_dataset = self.unify_responses(question=task.answer.name, strategy="disagreement") + elif not isinstance( + task, + ( + TrainingTaskForSFT, + TrainingTaskForRM, + TrainingTaskForPPO, + TrainingTaskForDPO, + TrainingTaskForChatCompletion, + ), + ): + raise ValueError(f"Training data {type(task)} is not supported yet") + + data = task._format_data(local_dataset) + if framework in [ + Framework.TRANSFORMERS, + Framework.SETFIT, + Framework.SPAN_MARKER, + Framework.PEFT, + ]: + return task._prepare_for_training_with_transformers( + data=data, train_size=train_size, seed=seed, framework=framework + ) + elif framework in [Framework.SPACY, Framework.SPACY_TRANSFORMERS]: + require_dependencies("spacy") + import spacy + + if lang is None: + _LOGGER.warning("spaCy `lang` is not provided. Using `en`(English) as default language.") + lang = spacy.blank("en") + elif lang.isinstance(str): + if len(lang) == 2: + lang = spacy.blank(lang) + else: + lang = spacy.load(lang) + return task._prepare_for_training_with_spacy(data=data, train_size=train_size, seed=seed, lang=lang) + elif framework is Framework.SPARK_NLP: + return task._prepare_for_training_with_spark_nlp(data=data, train_size=train_size, seed=seed) + elif framework is Framework.OPENAI: + return task._prepare_for_training_with_openai(data=data, train_size=train_size, seed=seed) + elif framework is Framework.TRL: + return task._prepare_for_training_with_trl(data=data, train_size=train_size, seed=seed) + elif framework is Framework.TRLX: + return task._prepare_for_training_with_trlx(data=data, train_size=train_size, seed=seed) + elif framework is Framework.SENTENCE_TRANSFORMERS: + return task._prepare_for_training_with_sentence_transformers(data=data, train_size=train_size, seed=seed) + else: + raise NotImplementedError( + f"Framework {framework} is not supported. Choose from: {[e.value for e in Framework]}" + ) + + def update_records(self, records: Union["FeedbackRecord", List["FeedbackRecord"]]) -> None: + warnings.warn( + "`update_records` method only works for `FeedbackDataset` pushed to Argilla. " + "If your are working with local data, you can just iterate over the records and update them." + ) + + def sort_by( + self, field: Union[str, RecordSortField], order: Union[str, SortOrder] = SortOrder.asc + ) -> "FeedbackDataset": + warnings.warn( + "`sort_by` method is not supported for local datasets and won't take any effect. " + "First, you need to push the dataset to Argilla with `FeedbackDataset.push_to_argilla()`. " + "After, use `FeedbackDataset.from_argilla(...).sort_by()`.", + UserWarning, + stacklevel=1, + ) + return self + + def pull(self) -> "FeedbackDataset": + warnings.warn( + "`pull` method is not supported for local datasets and won't take any effect." + "First, you need to push the dataset to Argilla with `FeedbackDataset.push_to_argilla()`. 
" + "After, use `FeedbackDataset.from_argilla(...).pull()`.", + UserWarning, + ) + return self + + def filter_by(self, *args, **kwargs) -> "FeedbackDataset": + warnings.warn( + "`filter_by` method is not supported for local datasets and won't take any effect. " + "First, you need to push the dataset to Argilla with `FeedbackDataset.push_to_argilla()`. " + "After, use `FeedbackDataset.from_argilla(...).filter_by()`.", + UserWarning, + ) + return self + + def delete(self): + warnings.warn( + "`delete` method is not supported for local datasets and won't take any effect. " + "First, you need to push the dataset to Argilla with `FeedbackDataset.push_to_argilla`. " + "After, use `FeedbackDataset.from_argilla(...).delete()`", + UserWarning, + ) + return self diff --git a/src/argilla/client/feedback/dataset/mixins.py b/src/argilla/client/feedback/dataset/local/mixins.py similarity index 90% rename from src/argilla/client/feedback/dataset/mixins.py rename to src/argilla/client/feedback/dataset/local/mixins.py index 5eb7bd8c45..154857f1d3 100644 --- a/src/argilla/client/feedback/dataset/mixins.py +++ b/src/argilla/client/feedback/dataset/local/mixins.py @@ -15,8 +15,6 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Type, Union from uuid import UUID -from tqdm import trange - from argilla.client.api import ArgillaSingleton from argilla.client.feedback.constants import PUSHING_BATCH_SIZE from argilla.client.feedback.dataset.remote.dataset import RemoteFeedbackDataset @@ -28,6 +26,7 @@ RatingQuestion, TextQuestion, ) +from argilla.client.feedback.schemas.enums import FieldTypes, QuestionTypes from argilla.client.feedback.schemas.remote.fields import RemoteTextField from argilla.client.feedback.schemas.remote.metadata import ( RemoteFloatMetadataProperty, @@ -41,21 +40,17 @@ RemoteRatingQuestion, RemoteTextQuestion, ) -from argilla.client.feedback.unification import ( - LabelQuestionStrategy, - MultiLabelQuestionStrategy, - RankingQuestionStrategy, - RatingQuestionStrategy, - TextQuestionStrategy, -) from argilla.client.feedback.utils import feedback_dataset_in_argilla from argilla.client.sdk.v1.datasets import api as datasets_api_v1 from argilla.client.workspaces import Workspace +from tqdm import trange if TYPE_CHECKING: import httpx - from argilla.client.client import Argilla as ArgillaClient + from argilla.client.feedback.dataset.local.dataset import FeedbackDataset + from argilla.client.feedback.schemas.types import AllowedRemoteFieldTypes, AllowedRemoteQuestionTypes + from argilla.client.sdk.v1.datasets.models import FeedbackDatasetModel from argilla.client.feedback.dataset.local import FeedbackDataset from argilla.client.feedback.schemas.records import FeedbackRecord from argilla.client.feedback.schemas.types import ( @@ -448,42 +443,3 @@ def list(cls: Type["FeedbackDataset"], workspace: Optional[str] = None) -> List[ ) for dataset in datasets ] - - -class UnificationMixin: - def unify_responses( - self: "FeedbackDataset", - question: Union[str, LabelQuestion, MultiLabelQuestion, RatingQuestion], - strategy: Union[ - str, LabelQuestionStrategy, MultiLabelQuestionStrategy, RatingQuestionStrategy, RankingQuestionStrategy - ], - ) -> None: - """ - The `unify_responses` function takes a question and a strategy as input and applies the strategy - to unify the responses for that question. 
- - Args: - question The `question` parameter can be either a string representing the name of the - question, or an instance of one of the question classes (`LabelQuestion`, `MultiLabelQuestion`, - `RatingQuestion`, `RankingQuestion`). - strategy The `strategy` parameter is used to specify the strategy to be used for unifying - responses for a given question. It can be either a string or an instance of a strategy class. - """ - if isinstance(question, str): - question = self.question_by_name(question) - - if isinstance(strategy, str): - if isinstance(question, LabelQuestion): - strategy = LabelQuestionStrategy(strategy) - elif isinstance(question, MultiLabelQuestion): - strategy = MultiLabelQuestionStrategy(strategy) - elif isinstance(question, RatingQuestion): - strategy = RatingQuestionStrategy(strategy) - elif isinstance(question, RankingQuestion): - strategy = RankingQuestionStrategy(strategy) - elif isinstance(question, TextQuestion): - strategy = TextQuestionStrategy(strategy) - else: - raise ValueError(f"Question {question} is not supported yet") - - strategy.unify_responses(self.records, question) diff --git a/src/argilla/client/feedback/dataset/remote/dataset.py b/src/argilla/client/feedback/dataset/remote/dataset.py index dffb60c5fe..c3534f5d8e 100644 --- a/src/argilla/client/feedback/dataset/remote/dataset.py +++ b/src/argilla/client/feedback/dataset/remote/dataset.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import textwrap import warnings from datetime import datetime from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, Union @@ -23,8 +24,23 @@ from argilla.client.feedback.dataset.remote.mixins import ArgillaRecordsMixin from argilla.client.feedback.mixins import ArgillaMetadataPropertiesMixin from argilla.client.feedback.schemas.enums import ResponseStatusFilter +from argilla.client.feedback.schemas.questions import ( + LabelQuestion, + MultiLabelQuestion, + RatingQuestion, +) from argilla.client.feedback.schemas.records import FeedbackRecord from argilla.client.feedback.schemas.remote.records import RemoteFeedbackRecord +from argilla.client.feedback.training.schemas import ( + TrainingTaskTypes, +) +from argilla.client.feedback.unification import ( + LabelQuestionStrategy, + MultiLabelQuestionStrategy, + RankingQuestionStrategy, + RatingQuestionStrategy, +) +from argilla.client.models import Framework from argilla.client.sdk.users.models import UserRole from argilla.client.sdk.v1.datasets import api as datasets_api_v1 from argilla.client.utils import allowed_for_roles @@ -43,6 +59,10 @@ AllowedRemoteQuestionTypes, ) from argilla.client.sdk.v1.datasets.models import FeedbackRecordsModel, FeedbackResponseStatusFilter + from argilla.client.feedback.dataset.local import FeedbackDataset + from argilla.client.feedback.schemas.enums import ResponseStatusFilter + from argilla.client.feedback.schemas.types import AllowedRemoteFieldTypes, AllowedRemoteQuestionTypes + from argilla.client.sdk.v1.datasets.models import FeedbackRecordsModel from argilla.client.workspaces import Workspace @@ -371,10 +391,17 @@ def _question_name_to_id(self) -> Dict[str, "UUID"]: def __repr__(self) -> str: """Returns a string representation of the dataset.""" + indent = " " return ( - f"" + "RemoteFeedbackDataset(" + + textwrap.indent(f"\nid={self.id}", indent) + + textwrap.indent(f"\nname={self.name}", indent) + + textwrap.indent(f"\nworkspace={self.workspace}", indent) + + textwrap.indent(f"\nurl={self.url}", 
indent) + + textwrap.indent(f"\nfields={self.fields}", indent) + + textwrap.indent(f"\nquestions={self.questions}", indent) + + textwrap.indent(f"\nguidelines={self.guidelines}", indent) + + ")" ) def __len__(self) -> int: @@ -444,7 +471,7 @@ def pull(self) -> "FeedbackDataset": A local instance of the dataset which is a `FeedbackDataset` object. """ # Importing here to avoid circular imports - from argilla.client.feedback.dataset.local import FeedbackDataset + from argilla.client.feedback.dataset.local.dataset import FeedbackDataset instance = FeedbackDataset( fields=self.fields, @@ -509,7 +536,7 @@ def add_metadata_property( ) from e # TODO(alvarobartt): structure better the mixins to be able to easily reuse those, here to avoid circular imports - from argilla.client.feedback.dataset.mixins import ArgillaMixin + from argilla.client.feedback.dataset.local.mixins import ArgillaMixin return ArgillaMixin._parse_to_remote_metadata_property(metadata_property=metadata_property, client=self._client) @@ -628,3 +655,79 @@ def _create_from_dataset(cls, dataset: "RemoteFeedbackDataset") -> "RemoteFeedba new_dataset._records = dataset.records return new_dataset + + def unify_responses( + self, + question: Union[str, LabelQuestion, MultiLabelQuestion, RatingQuestion], + strategy: Union[ + str, LabelQuestionStrategy, MultiLabelQuestionStrategy, RatingQuestionStrategy, RankingQuestionStrategy + ], + ) -> "FeedbackDataset": + """ + The `unify_responses` function takes a question and a strategy as input and applies the strategy + to unify the responses for that question. + + Args: + question The `question` parameter can be either a string representing the name of the + question, or an instance of one of the question classes (`LabelQuestion`, `MultiLabelQuestion`, + `RatingQuestion`, `RankingQuestion`). + strategy The `strategy` parameter is used to specify the strategy to be used for unifying + responses for a given question. It can be either a string or an instance of a strategy class. + """ + warnings.warn( + "A local `FeedbackDataset` returned because " + "`unify_responses` is not supported for `RemoteFeedbackDataset`. " + "`RemoteFeedbackDataset`.pull().unify_responses(*args, **kwargs)` is applied.", + UserWarning, + ) + local = self.pull() + return local.unify_responses(question=question, strategy=strategy) + + def prepare_for_training( + self, + framework: Union[Framework, str], + task: TrainingTaskTypes, + train_size: Optional[float] = 1, + test_size: Optional[float] = None, + seed: Optional[int] = None, + lang: Optional[str] = None, + ) -> Any: + """ + Prepares the dataset for training for a specific training framework and NLP task by splitting the dataset into train and test sets. + + Args: + framework: the framework to use for training. Currently supported frameworks are: `transformers`, `peft`, + `setfit`, `spacy`, `spacy-transformers`, `span_marker`, `spark-nlp`, `openai`, `trl`, `sentence-transformers`. + task: the NLP task to use for training. Currently supported tasks are: `TrainingTaskForTextClassification`, + `TrainingTaskForSFT`, `TrainingTaskForRM`, `TrainingTaskForPPO`, `TrainingTaskForDPO`, `TrainingTaskForSentenceSimilarity`. + train_size: the size of the train set. If `None`, the whole dataset will be used for training. + test_size: the size of the test set. If `None`, the whole dataset will be used for testing. + seed: the seed to use for splitting the dataset into train and test sets. + lang: the spaCy language to use for training. 
If `None`, the language of the dataset will be used. + """ + warnings.warn( + ( + "A local `FeedbackDataset` returned because " + "`prepare_for_training` is not supported for `RemoteFeedbackDataset`. " + "`RemoteFeedbackDataset`.pull().prepare_for_training(*args, **kwargs)` is applied." + ), + UserWarning, + ) + local = self.pull() + return local.prepare_for_training( + framework=framework, + task=task, + train_size=train_size, + test_size=test_size, + seed=seed, + lang=lang, + ) + + def push_to_argilla( + self, name: str, workspace: Optional[Union[str, "Workspace"]] = None, show_progress: bool = False + ) -> "RemoteFeedbackDataset": + warnings.warn( + "Already pushed datasets cannot be pushed to Argilla again because they are synced automatically.", + UserWarning, + ) + return self diff --git a/src/argilla/client/feedback/integrations/huggingface/card/__init__.py b/src/argilla/client/feedback/integrations/huggingface/card/__init__.py index fbd10d4f5f..f50dddbb1d 100644 --- a/src/argilla/client/feedback/integrations/huggingface/card/__init__.py +++ b/src/argilla/client/feedback/integrations/huggingface/card/__init__.py @@ -15,4 +15,7 @@ from ._dataset_card import ArgillaDatasetCard from ._parser import size_categories_parser -__all__ = ["ArgillaDatasetCard", "size_categories_parser"] +__all__ = [ + "ArgillaDatasetCard", + "size_categories_parser", +] diff --git a/src/argilla/client/feedback/integrations/huggingface/model_card/__init__.py b/src/argilla/client/feedback/integrations/huggingface/model_card/__init__.py new file mode 100644 index 0000000000..caebcdc818 --- /dev/null +++ b/src/argilla/client/feedback/integrations/huggingface/model_card/__init__.py @@ -0,0 +1,39 @@ +# Copyright 2021-present, the Recognai S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from .model_card import ( + ArgillaModelCard, + FrameworkCardData, + OpenAIModelCardData, + PeftModelCardData, + SentenceTransformerCardData, + SetFitModelCardData, + SpacyModelCardData, + SpacyTransformersModelCardData, + TransformersModelCardData, + TRLModelCardData, +) + +__all__ = [ + "ArgillaModelCard", + "FrameworkCardData", + "SentenceTransformerCardData", + "TransformersModelCardData", + "SetFitModelCardData", + "PeftModelCardData", + "SpacyModelCardData", + "SpacyTransformersModelCardData", + "OpenAIModelCardData", + "TRLModelCardData", +] diff --git a/src/argilla/client/feedback/integrations/huggingface/model_card/argilla_model_template.md b/src/argilla/client/feedback/integrations/huggingface/model_card/argilla_model_template.md new file mode 100644 index 0000000000..128694f62a --- /dev/null +++ b/src/argilla/client/feedback/integrations/huggingface/model_card/argilla_model_template.md @@ -0,0 +1,170 @@ +--- +# For reference on model card metadata, see the spec: https://github.com/huggingface/hub-docs/blob/main/modelcard.md?plain=1 +# Doc / guide: https://huggingface.co/docs/hub/model-cards +{{ card_data }} +--- + + + +# Model Card for *{{ model_name | default("Model ID", true) }}* + +This model has been created with [Argilla](https://docs.argilla.io), trained with *{{ library_name }}*. + + + +{{ model_summary | default("", true) }} + +## Model training + +Training the model using the `ArgillaTrainer`: + +```python +# Load the dataset: +dataset = FeedbackDataset.from_{% if _is_on_huggingface %}huggingface("{{ dataset_name }}"){% else %}argilla({% if dataset_name %}"{{ dataset_name }}"{% else %}"..."{% endif %}){% endif %} + +# Create the training task: +{{ trainer_task_call }} + +# Create the ArgillaTrainer: +trainer = ArgillaTrainer( + dataset=dataset, + task=task, + framework="{{ framework }}", + {%- if lang %}{{ "\n " }}lang="{{ lang }}",{% endif %} + {%- if model_id %}{{ "\n " }}model="{{ model_id }}",{% endif %} + {%- if tokenizer %}{{ "\n " }}tokenizer={{ tokenizer }},{% endif %} + {%- if train_size %}{{ "\n " }}train_size={{ train_size }},{% endif %} + {%- if seed %}{{ "\n " }}seed={{ seed }},{% endif %} + {%- if gpu_id %}{{ "\n " }}gpu_id={{ gpu_id }},{% endif %} + {%- if framework_kwargs %}{{ "\n " }}framework_kwargs={{ framework_kwargs }},{% endif %} +) +{% if update_config_call %}{{ update_config_call }}{% endif %} +trainer.train(output_dir={{ output_dir }}) +``` + +You can test the type of predictions of this model like so: + +```python +{{ predict_call }} +``` + +## Model Details + +### Model Description + + + +{{ model_description | default("", true) }} + +- **Developed by:** {{ developers | default("[More Information Needed]", true)}} +- **Shared by [optional]:** {{ shared_by | default("[More Information Needed]", true)}} +- **Model type:** {{ model_type | default("[More Information Needed]", true)}} +- **Language(s) (NLP):** {{ language | default("[More Information Needed]", true)}} +- **License:** {{ license | default("[More Information Needed]", true)}} +- **Finetuned from model [optional]:** {{ finetuned_from | default("[More Information Needed]", true)}} + +{%- if repo %} +### Model Sources [optional] + + + +- **Repository:** {{ repo }} +{% endif %} + + + + + + + + + + + + + + + + + + + +## Technical Specifications [optional] + +### Framework Versions + +- Python: {{ version["python"] }} +- Argilla: {{ version["argilla"] }} + + + + + + + + diff --git a/src/argilla/client/feedback/integrations/huggingface/model_card/model_card.py 
b/src/argilla/client/feedback/integrations/huggingface/model_card/model_card.py new file mode 100644 index 0000000000..ce1dc7097c --- /dev/null +++ b/src/argilla/client/feedback/integrations/huggingface/model_card/model_card.py @@ -0,0 +1,664 @@ +# Copyright 2021-present, the Recognai S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +from dataclasses import dataclass, field, fields +from inspect import getsource +from pathlib import Path +from platform import python_version +from textwrap import dedent +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional, Union + +from huggingface_hub import CardData, ModelCard, dataset_info, model_info +from huggingface_hub.utils import yaml_dump + +from argilla._version import version +from argilla.client.feedback.training.schemas import TRAINING_TASK_MAPPING, TrainingTaskTypes +from argilla.client.models import FRAMEWORK_TO_NAME_MAPPING, Framework +from argilla.training.utils import get_default_args + +if TYPE_CHECKING: + import spacy + from transformers import PreTrainedTokenizer + + +TEMPLATE_ARGILLA_MODEL_CARD_PATH = Path(__file__).parent / "argilla_model_template.md" + + +TEMPLATE_TASK_CALL = "task = TrainingTask.{task_type}({training_task_args})" + + +YAML_FIELDS = ["language", "license", "tags", "dataset_name", "library_name"] + + +class ArgillaModelCard(ModelCard): + """`ArgillaModelCard` has been created similarly to `ModelCard` from + `huggingface_hub` but with a different template. The template is located at + `argilla/client/feedback/integrations/huggingface/model_card/argilla_model_template.md`. + """ + + default_template_path = TEMPLATE_ARGILLA_MODEL_CARD_PATH + + +@dataclass +class FrameworkCardData(CardData): + """Parent class to generate the variables to add to the ModelCard. + + Each framework will inherit from here and update accordingly. 
+ """ + + # User provided + language: Optional[Union[str, List[str]]] = None + license: Optional[str] = None + model_name: Optional[str] = None + model_id: Optional[str] = None + dataset_name: Optional[str] = None + dataset_id: Optional[str] = None + tags: Optional[List[str]] = field(default_factory=lambda: ["argilla"]) + model_summary: Optional[str] = None + model_description: Optional[str] = None + developers: Optional[str] = None + shared_by: Optional[str] = None + model_type: Optional[str] = None + finetuned_from: Optional[str] = None + repo: Optional[str] = None + + # Control variables for the templates + _is_on_huggingface: bool = field(default=False) + + # Obtained internally from each trainer + framework: Optional[Framework] = None + train_size: Optional[float] = None + seed: Optional[int] = None + framework_kwargs: Dict[str, Any] = field(default_factory=dict) + task: Optional[TrainingTaskTypes] = None + output_dir: Optional[str] = None + version: Dict[str, str] = field( + default_factory=lambda: { + "python": python_version(), + "argilla": version, + }, + init=False, + ) + library_name: Optional[str] = None + # Used to store the arguments passed through `update_config` method. In the case + # of transformers for example, this corresponds to the `trainer_kwargs`. + update_config_kwargs: Dict[str, Any] = field(default_factory=lambda: {}) + + def __post_init__(self): + # To decide whether the dataset is loaded from from_huggingface or from_argilla + if self.dataset_name: + if is_on_huggingface(self.dataset_name, is_model=False): + self._is_on_huggingface = True + self.library_name = FRAMEWORK_TO_NAME_MAPPING[self.framework.value] + self.task_type = TRAINING_TASK_MAPPING[type(self.task)] + + def _trainer_task__repr__(self) -> str: + """Generates the creation of the `TrainingTask*` call. + + Returns: + Representation of the training task creation as a str. + """ + pass + + def _predict__repr__(self) -> str: + """Generates the call to the `predict` method, for the models that implement it, or + the underlying library implementation. + + Returns: + A sample call to the predict method according to the type of model. + """ + pass + + def _update_config__repr__(self) -> str: + """Generates the call to the `update_config` method, for the models that implement it. + The arguments passed by the user to `update_config` are the difference between what + the model contains in the `trainer_kwargs`, which are different across frameworks, (not + only the internal variables but the the attribute name), and what it's internally + generated by the `init_training_args` method. + + Returns: + A sample call to the predict method according to the type of model. 
+ """ + pass + + def _to_dict(self) -> Dict[str, str]: + """Write this method to insert variables pertaining to a special framework only.""" + return {} + + def to_dict(self) -> Dict[str, Any]: + """Main method to generate the variables that will be written in the model card.""" + default_kwargs = {field.name: getattr(self, field.name) for field in fields(self)} + + kwargs = { + "framework": self.framework.value, + "trainer_task_call": self._trainer_task__repr__(), + "predict_call": self._predict__repr__(), + } + + if self.framework_kwargs: + kwargs["framework_kwargs"] = str(self.framework_kwargs) + + if update_config_kwargs := self._update_config__repr__(): + kwargs["update_config_call"] = update_config_kwargs + + if extra_kwargs := self._to_dict(): + kwargs.update(**extra_kwargs) + + return {**default_kwargs, **kwargs} + + def to_yaml(self, line_break=None) -> str: + return yaml_dump( + {key: value for key, value in self.to_dict().items() if key in YAML_FIELDS and value is not None}, + sort_keys=False, + line_break=line_break, + ).strip() + + +@dataclass +class SpacyModelCardDataBase(FrameworkCardData): + lang: Optional["spacy.Language"] = None + gpu_id: Optional[int] = -1 + optimize: Literal["efficiency", "accuracy"] = "efficiency" + pipeline: List[str] = field(default_factory=lambda: ["ner"]) + + def _trainer_task__repr__(self) -> str: + task_call = "" + if formatting_func := self.task.formatting_func: + task_call += getsource(formatting_func) + "\n" + training_task_args = "formatting_func=formatting_func" + else: + text = f'dataset.field_by_name("{self.task.text.name}")' + training_task_args = f'text={text}, label=dataset.question_by_name("{self.task.label.question.name}")' + return task_call + TEMPLATE_TASK_CALL.format(task_type=self.task_type, training_task_args=training_task_args) + + def _predict__repr__(self) -> str: + return 'trainer.predict("This is awesome!")' + + def _to_dict(self) -> Dict[str, str]: + return {"gpu_id": self.gpu_id, "lang": self.lang, "optimize": self.optimize} + + +@dataclass +class SpacyModelCardData(SpacyModelCardDataBase): + framework: Framework = Framework("spacy") + freeze_tok2vec: bool = False + + def _to_dict(self) -> Dict[str, str]: + kwargs = super()._to_dict() + # Only add this variable if is different from the default + if freeze_tok2vec := self.freeze_tok2vec: + kwargs.update({"freeze_tok2vec": freeze_tok2vec}) + return kwargs + + def _update_config__repr__(self) -> Optional[str]: + return + + +@dataclass +class SpacyTransformersModelCardData(SpacyModelCardDataBase): + framework: Framework = Framework("spacy-transformers") + update_transformer: bool = True + + def _to_dict(self) -> Dict[str, str]: + kwargs = super()._to_dict() + if update_transformer := not self.update_transformer: + kwargs.update({"update_transformer": update_transformer}) + return kwargs + + def _update_config__repr__(self) -> str: + return + + +@dataclass +class TransformersModelCardDataBase(FrameworkCardData): + tokenizer: "PreTrainedTokenizer" = "" + + def _trainer_task__repr__(self) -> str: + task_call = "" + if formatting_func := self.task.formatting_func: + task_call += getsource(formatting_func) + "\n" + training_task_args = "formatting_func=formatting_func" + else: + text = f'dataset.field_by_name("{self.task.text.name}")' + training_task_args = f'text={text}, label=dataset.question_by_name("{self.task.label.question.name}")' + return task_call + TEMPLATE_TASK_CALL.format(task_type=self.task_type, training_task_args=training_task_args) + + def _predict__repr__(self) -> 
str: + return 'trainer.predict("This is awesome!")' + + def _to_dict(self) -> Dict[str, str]: + return {"tokenizer": self.tokenizer} + + +@dataclass +class TransformersModelCardData(TransformersModelCardDataBase): + framework: Framework = Framework("transformers") + + def _trainer_task__repr__(self) -> str: + task_call = "" + if formatting_func := self.task.formatting_func: + task_call += getsource(formatting_func) + "\n" + training_task_args = "formatting_func=formatting_func" + else: + if self.task_type == "for_text_classification": + training_task_args = ( + f'text=dataset.field_by_name("{self.task.text.name}"), ' + f'label=dataset.question_by_name("{self.task.label.question.name}")' + ) + elif self.task_type == "for_question_answering": + training_task_args = ( + f'question=dataset.field_by_name("{self.task.question.name}"), ' + f'context=dataset.field_by_name("{self.task.context.name}"), ' + f'answer=dataset.question_by_name("{self.task.answer.name}")' + ) + else: + raise NotImplementedError(f"Transformer doesn't have this `task_type` implemented: `{self.task_type}`") + + return task_call + TEMPLATE_TASK_CALL.format(task_type=self.task_type, training_task_args=training_task_args) + + def _predict__repr__(self) -> str: + if self.task_type == "for_text_classification": + return super()._predict__repr__() + elif self.task_type == "for_question_answering": + return dedent( + f"""\ + # This type of model has no `predict` method implemented from argilla, but can be done using the underlying library + + from transformers import pipeline + + qa_model = pipeline("question-answering", model="{self.output_dir}") + question = "Where do I live?" + context = "My name is Merve and I live in İstanbul." + qa_model(question = question, context = context)""" + ) + else: + raise NotImplementedError(f"`task_type` not implemented: `{self.task_type}`") + + def _update_config__repr__(self) -> Optional[str]: + from transformers import TrainingArguments + + base_kwargs = get_default_args(TrainingArguments.__init__) + + if updated_args := _updated_arguments(base_kwargs, self.update_config_kwargs): + return _update_config__repr__(updated_args) + + +@dataclass +class SetFitModelCardData(TransformersModelCardDataBase): + framework: Framework = Framework("setfit") + tags: Optional[List[str]] = field(default_factory=lambda: ["text-classification", "setfit", "argilla"]) + + def _update_config__repr__(self) -> Optional[str]: + from setfit import SetFitModel, SetFitTrainer + + setfit_model_kwargs = get_default_args(SetFitModel._from_pretrained) + setfit_model_kwargs.update(get_default_args(SetFitModel.from_pretrained)) + + setfit_trainer_kwargs = get_default_args(SetFitTrainer.__init__) + + # The following arguments are set by default internally, don't need to be shown. 
+ self.update_config_kwargs.pop("model_id", None) + self.update_config_kwargs.pop("revision", None) + self.update_config_kwargs.pop("pretrained_model_name_or_path", None) + self.update_config_kwargs.pop("multi_target_strategy", None) + self.update_config_kwargs.pop("device", None) + self.update_config_kwargs.pop("column_mapping", None) + self.update_config_kwargs.pop("train_dataset", None) + self.update_config_kwargs.pop("eval_dataset", None) + self.update_config_kwargs.pop("model", None) + + base_kwargs = {**setfit_model_kwargs, **setfit_trainer_kwargs} + + if updated_args := _updated_arguments(base_kwargs, self.update_config_kwargs): + return _update_config__repr__(updated_args) + + +@dataclass +class SpanMarkerModelCardData(TransformersModelCardDataBase): + # Not implemented, once we have the FeedbackDataset ready for this model it should be aligned. + framework: Framework = Framework("span_marker") + tags: Optional[List[str]] = field( + default_factory=lambda: [ + "span-marker", + "token-classification", + "ner", + "named-entity-recognition", + ] + ) + + +@dataclass +class PeftModelCardData(TransformersModelCardDataBase): + framework: Framework = Framework("peft") + + def _update_config__repr__(self) -> Optional[str]: + base_kwargs = { + "r": 8, + "target_modules": None, + "lora_alpha": 16, + "lora_dropout": 0.1, + "fan_in_fan_out": False, + "bias": "none", + "inference_mode": False, + "modules_to_save": None, + "init_lora_weights": True, + } + + self.update_config_kwargs.pop("task_type", None) + + if updated_args := _updated_arguments(base_kwargs, self.update_config_kwargs): + return _update_config__repr__(updated_args) + + +@dataclass +class OpenAIModelCardData(FrameworkCardData): + framework: Framework = Framework("openai") + + def _trainer_task__repr__(self) -> str: + return _formatting_func_call(self.task.formatting_func, self.task_type) + + def _predict__repr__(self) -> str: + return dedent( + """\ + # After training we can use the model from the openai framework, you can take a look at their docs in order to use the model + import openai + + completion = openai.ChatCompletion.create( + model="ft:gpt-3.5-turbo:my-org:custom_suffix:id", + messages=[ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Hello!"} + ] + ) + """ + ) + + +@dataclass +class TRLModelCardData(FrameworkCardData): + framework: Framework = Framework("trl") + + def _trainer_task__repr__(self) -> str: + return _formatting_func_call(self.task.formatting_func, self.task_type) + + def _predict__repr__(self) -> str: + predict_call = "# This type of model has no `predict` method implemented from argilla, but can be done using the underlying library\n" + if self.task_type == "for_supervised_fine_tuning": + return predict_call + dedent( + f"""\ + from transformers import GenerationConfig, AutoTokenizer, GPT2LMHeadModel + + def generate(model_id: str, instruction: str, context: str = "") -> str: + model = GPT2LMHeadModel.from_pretrained(model_id) + tokenizer = AutoTokenizer.from_pretrained(model_id) + + inputs = template.format( + instruction=instruction, + context=context, + response="", + ).strip() + + encoding = tokenizer([inputs], return_tensors="pt") + outputs = model.generate( + **encoding, + generation_config=GenerationConfig( + max_new_tokens=32, + min_new_tokens=12, + pad_token_id=tokenizer.pad_token_id, + eos_token_id=tokenizer.eos_token_id, + ), + ) + return tokenizer.decode(outputs[0]) + + generate("{self.output_dir.replace('"', '')}", "Is a toad a frog?")""" + ) + 
elif self.task_type == "for_reward_modeling": + return predict_call + dedent( + f"""\ + from transformers import AutoTokenizer, AutoModelForSequenceClassification + import torch + + model = AutoModelForSequenceClassification.from_pretrained("{self.output_dir.replace('"', "")}") + tokenizer = AutoTokenizer.from_pretrained("{self.output_dir.replace('"', "")}") + + def get_score(model, tokenizer, text): + # Tokenize the input sequences + inputs = tokenizer(text, truncation=True, padding="max_length", max_length=512, return_tensors="pt") + + # Perform forward pass + with torch.no_grad(): + outputs = model(**inputs) + + # Extract the logits + return outputs.logits[0, 0].item() + + # Example usage + example = template.format(instruction="your prompt", context="your context", response="response") + + score = get_score(model, tokenizer, example) + print(score)""" + ) + elif (self.task_type == "for_proximal_policy_optimization") or ( + self.task_type == "for_direct_preference_optimization" + ): + return predict_call + dedent( + f"""\ + from transformers import AutoModelForCausalLM, AutoTokenizer + + model = AutoModelForCausalLM.from_pretrained("{self.output_dir.replace('"', "")}") + tokenizer = AutoTokenizer.from_pretrained("{self.output_dir.replace('"', "")}") + tokenizer.pad_token = tokenizer.eos_token + + inputs = template.format( + instruction="your prompt", + context="your context", + response="" + ).strip() + encoding = tokenizer([inputs], return_tensors="pt") + outputs = model.generate(**encoding, max_new_tokens=30) + output_text = tokenizer.decode(outputs[0]) + print(output_text)""" + ) + else: + raise NotImplementedError(f"Transformer doesn't have this `task_type` implemented: `{self.task_type}`") + + def _update_config__repr__(self) -> Optional[str]: + if self.task_type == "for_proximal_policy_optimization": + # Similar to what happens with spacy, the current implementation + # doesnt' render appropriately, for the moment let for the user to be written by hand. 
+ return + + else: + base_kwargs = { + # Let evaluation_strategy as if the eval dataset was passed by default + "evaluation_strategy": "epoch", + "logging_steps": 1, + "num_train_epochs": 1, + } + + if updated_args := _updated_arguments(base_kwargs, self.update_config_kwargs): + return _update_config__repr__(updated_args) + + +@dataclass +class SentenceTransformerCardData(FrameworkCardData): + framework: Framework = Framework("sentence-transformers") + tags: Optional[List[str]] = field( + default_factory=lambda: ["sentence-similarity", "sentence-transformers", "argilla"] + ) + cross_encoder: bool = False + # Used to gather internally the arguments passed to `update_config` + trainer_cls: Optional[Callable] = None + + def _trainer_task__repr__(self) -> str: + task_call = "" + if formatting_func := self.task.formatting_func: + task_call += getsource(formatting_func) + "\n" + training_task_args = "formatting_func=formatting_func" + else: + texts = ", ".join([f'dataset.field_by_name("{text.name}")' for text in self.task.texts]) + training_task_args = f"texts=[{texts}]{f', label=dataset.question_by_name({self.task.label.question.name})' if self.task.label else ''}" + return task_call + TEMPLATE_TASK_CALL.format(task_type=self.task_type, training_task_args=training_task_args) + + def _predict__repr__(self) -> str: + return dedent( + """\ + trainer.predict( + [ + ["Machine learning is so easy.", "Deep learning is so straightforward."], + ["Machine learning is so easy.", "This is so difficult, like rocket science."], + ["Machine learning is so easy.", "I can't believe how much I struggled with this."] + ] + )""" + ) + + def _to_dict(self) -> Dict[str, str]: + if cross_encoder := self.cross_encoder: + return {"cross_encoder": cross_encoder} + return {} + + def _update_config__repr__(self) -> Optional[str]: + base_kwargs = { + # model_kwargs + **get_default_args(self.trainer_cls.__init__), + # trainer_kwargs + **get_default_args(self.trainer_cls.fit), + # data_kwargs + "batch_size": 32, + "dataset_type": None, + } + + if updated_args := _updated_arguments(base_kwargs, self.update_config_kwargs): + return _update_config__repr__(updated_args) + + +def _formatting_func_call(formatting_func: Callable, task_type: str) -> str: + """Helper function to extract the code for the task call. + + Args: + formatting_func: Function used to prepare the dataset for training. + task_type: Method called to prepare the dataset for training. + + Returns: + formatting_func_call + """ + task_call = getsource(formatting_func) + "\n" + training_task_args = "formatting_func=formatting_func" + return task_call + TEMPLATE_TASK_CALL.format(task_type=task_type, training_task_args=training_task_args) + + +def _updated_arguments(base_kwargs: Dict[str, Any], current_kwargs: Dict[str, Any]) -> Dict[str, Any]: + """Helper function to determine the arguments the user has given through the `update_config` method. + + It does so by obtaining the difference between the `current_kwargs` and the `base_kwargs` + (the one used by default in the model). + + The arguments can contain nested dicts (which are unhashable). Nested dicts (only one level of depth) are first + transformed to tuples of their items to check for differences in the values, and then transformed back. + Instantiated classes and type objects (a class not instantiated for example) are transformed to their name. + + Args: + base_kwargs: default arguments. + current_kwargs: arguments registered in the model after training. + + Returns: + user_kwargs: User provided kwargs. 
+ """ + + base_kwargs_ = _prepare_dict_for_comparison(base_kwargs) + current_kwargs_ = _prepare_dict_for_comparison(current_kwargs) + + set1 = set(base_kwargs_.items()) + set2 = set(current_kwargs_.items()) + user_kwargs = dict(set2.difference(set1)) + + return _prepare_dict_for_return(user_kwargs) + + +def _prepare_dict_for_comparison(d): + # Transforms the nested lists/dicts to tuples and prepares + # them to be recovered after finding possible differences. + # If a "type" is found, it returns its string representation. + new_dict = {} + for k, v in d.items(): + if isinstance(v, dict): + v = list(v.items()) + v.append("__dict") # Placeholder to create it back + v = tuple(v) + elif isinstance(v, list): + v = [v if isinstance(v, (int, float)) else str(v) for v in v] + v.append("__list") + v = tuple(v) + elif isinstance(v, (int, float, bool)): + # Left these values as they are + pass + else: + # The remaining cases are either strings (the general case) + # an instance of a class (in which case we use the str representation) + # or a class (in the best case in the classes' cases there exists a nice __repr__ method as to be + # valid for `update_config`, otherwise a better logic must be implemented). + if isinstance(v, type): + v = v.__name__ + else: + v = str(v) + + new_dict[k] = v + return new_dict + + +def _prepare_dict_for_return(d): + # Prepares the dict with the original lists/dicts + new_dict = {} + for k, v in d.items(): + if isinstance(v, tuple): + if "__dict" in v: + v = list(v) + v.remove("__dict") + v = dict(v) + elif "__list" in v: + v = list(v) + v.remove("__list") + + new_dict[k] = v + return new_dict + + +def _update_config__repr__(keyword_arguments: Dict[str, Any]) -> str: + """Creates the call to `update_config` on the model. + + Args: + keyword_arguments: Arguments given by the user. + + Returns: + trainer.update_config(...) call. + """ + return f"\ntrainer.update_config({json.dumps(keyword_arguments, sort_keys=True, indent=4)})\n" + + +def is_on_huggingface(repo_id: str, is_model: bool = True) -> bool: + # NOTE: kindly copied from https://github.com/tomaarsen/SpanMarkerNER/blob/main/span_marker/model_card.py + # Models with more than two 'sections' certainly are not public models + if len(repo_id.split("/")) > 2: + return False + + try: + if is_model: + model_info(repo_id) + else: + dataset_info(repo_id) + return True + except: + # Fetching models can fail for many reasons: Repository not existing, no internet access, HF down, etc. + return False diff --git a/src/argilla/client/feedback/training/base.py b/src/argilla/client/feedback/training/base.py index 243b3f44f4..981406f179 100644 --- a/src/argilla/client/feedback/training/base.py +++ b/src/argilla/client/feedback/training/base.py @@ -16,7 +16,8 @@ import textwrap import warnings from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, List, Optional, Union +from pathlib import Path +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union from argilla.client.feedback.schemas.records import FeedbackRecord from argilla.client.feedback.training.schemas import TrainingTaskForTextClassification, TrainingTaskTypes @@ -30,6 +31,7 @@ from transformers import PreTrainedModel, PreTrainedTokenizer from argilla.client.feedback.dataset import FeedbackDataset + from argilla.client.feedback.integrations.huggingface.model_card import ArgillaModelCard, FrameworkCardData class ArgillaTrainer(ArgillaTrainerV1): @@ -67,7 +69,8 @@ def __init__( gpu_id: the GPU ID to use when training a SpaCy model. 
Defaults to -1, which means that the CPU will be used by default. GPU IDs start in 0, which stands for the default GPU in the system, if available. - framework_kwargs: arguments for the framework's trainer. + framework_kwargs: arguments for the framework's trainer. A special key (model_card_kwargs) is reserved + for the arguments that can be passed to the model card. **load_kwargs: arguments for the rg.load() function. """ self._dataset = dataset @@ -93,6 +96,12 @@ def __init__( f"Passing a tokenizer is not supported for the {framework} framework.", UserWarning, stacklevel=2 ) + # Save the model_card arguments if given by the user + if model_card_kwargs := framework_kwargs.pop("model_card_kwargs", None): + self.model_card_kwargs = model_card_kwargs + else: + self.model_card_kwargs = {} + if framework is Framework.SETFIT: if not isinstance(task, TrainingTaskForTextClassification): raise NotImplementedError(f"{Framework.SETFIT} only supports `TextClassification` tasks.") @@ -246,6 +255,44 @@ def predict(self, text: Union[List[str], str], as_argilla_records: bool = True, """ return self._trainer.predict(text=text, as_argilla_records=False, **kwargs) + def save(self, output_dir: str, generate_card: bool = True) -> None: + """ + Saves the model to the specified path and optionally generates a `ModelCard` at the same `output_dir`. + + Args: + output_dir: The path to the directory where the model will be saved. + generate_card: Whether to generate a model card of the `ArgillaTrainer` for the HuggingFace Hub. Defaults + to `True`. + """ + super().save(output_dir) + + if generate_card: + self.generate_model_card(output_dir) + + def generate_model_card(self, output_dir: str) -> "ArgillaModelCard": + """Generate and return a model card based on the model card data. + + Args: + output_dir: Folder where the model card will be written. + + Returns: + model_card: The model card. + """ + from argilla.client.feedback.integrations.huggingface.model_card import ArgillaModelCard + + if not self.model_card_kwargs.get("output_dir"): + self.model_card_kwargs.update({"output_dir": f'"{output_dir}"'}) + + model_card = ArgillaModelCard.from_template( + card_data=self._trainer.get_model_card_data(**self.model_card_kwargs), + template_path=ArgillaModelCard.default_template_path, + ) + + model_card_path = Path(output_dir) / "README.md" + model_card.save(model_card_path) + self._logger.info(f"Model card generated at: {model_card_path}") + return model_card + class ArgillaTrainerSkeleton(ABC): def __init__( @@ -307,3 +354,9 @@ def save(self, output_dir: str) -> None: """ Saves the model to the specified path. """ + + @abstractmethod + def get_model_card_data(self, card_data_kwargs: Dict[str, Any]) -> "FrameworkCardData": + """ + Generates a `FrameworkCardData` instance to generate a model card from. + """ diff --git a/src/argilla/client/feedback/training/frameworks/openai.py b/src/argilla/client/feedback/training/frameworks/openai.py index 7543ff8e5d..0bbe7a1d17 100644 --- a/src/argilla/client/feedback/training/frameworks/openai.py +++ b/src/argilla/client/feedback/training/frameworks/openai.py @@ -12,10 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from typing import TYPE_CHECKING + from argilla.client.feedback.training.base import ArgillaTrainerSkeleton from argilla.training.openai import ArgillaOpenAITrainer as ArgillaOpenAITrainerV1 from argilla.utils.dependency import require_dependencies +if TYPE_CHECKING: + from argilla.client.feedback.integrations.huggingface.model_card import OpenAIModelCardData + class ArgillaOpenAITrainer(ArgillaOpenAITrainerV1, ArgillaTrainerSkeleton): def __init__(self, *args, **kwargs) -> None: @@ -44,3 +49,21 @@ def __init__(self, *args, **kwargs) -> None: raise NotImplementedError("Legacy models are not supported for OpenAI with the FeedbackDataset.") self.init_training_args(model=self._model) + + def get_model_card_data(self, **card_data_kwargs) -> "OpenAIModelCardData": + """ + Generate the card data to be used for the `ArgillaModelCard`. + + Args: + card_data_kwargs: Extra arguments provided by the user when creating the `ArgillaTrainer`. + + Returns: + OpenAIModelCardData: Container for the data to be written on the `ArgillaModelCard`. + """ + from argilla.client.feedback.integrations.huggingface.model_card import OpenAIModelCardData + + return OpenAIModelCardData( + model_name=self._model, + task=self._task, + **card_data_kwargs, + ) diff --git a/src/argilla/client/feedback/training/frameworks/peft.py b/src/argilla/client/feedback/training/frameworks/peft.py index 253cbadefc..6a6b76de23 100644 --- a/src/argilla/client/feedback/training/frameworks/peft.py +++ b/src/argilla/client/feedback/training/frameworks/peft.py @@ -12,11 +12,34 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import TYPE_CHECKING from argilla.client.feedback.training.frameworks.transformers import ArgillaTransformersTrainer from argilla.training.peft import ArgillaPeftTrainer as ArgillaPeftTrainerV1 +if TYPE_CHECKING: + from argilla.client.feedback.integrations.huggingface.model_card import PeftModelCardData + class ArgillaPeftTrainer(ArgillaPeftTrainerV1, ArgillaTransformersTrainer): def __init__(self, *args, **kwargs): ArgillaTransformersTrainer.__init__(self, *args, **kwargs) + + def get_model_card_data(self, **card_data_kwargs) -> "PeftModelCardData": + """ + Generate the card data to be used for the `ArgillaModelCard`. + + Args: + card_data_kwargs: Extra arguments provided by the user when creating the `ArgillaTrainer`. + + Returns: + PeftModelCardData: Container for the data to be written on the `ArgillaModelCard`. 
+ """ + from argilla.client.feedback.integrations.huggingface.model_card import PeftModelCardData + + return PeftModelCardData( + model_id=self._model, + task=self._task, + update_config_kwargs=self.lora_kwargs, + **card_data_kwargs, + ) diff --git a/src/argilla/client/feedback/training/frameworks/sentence_transformers.py b/src/argilla/client/feedback/training/frameworks/sentence_transformers.py index cc020478b1..610b15506b 100644 --- a/src/argilla/client/feedback/training/frameworks/sentence_transformers.py +++ b/src/argilla/client/feedback/training/frameworks/sentence_transformers.py @@ -22,6 +22,7 @@ if TYPE_CHECKING: from argilla.client.feedback.dataset import FeedbackDataset + from argilla.client.feedback.integrations.huggingface.model_card import SentenceTransformerCardData class ArgillaSentenceTransformersTrainer(ArgillaTrainerSkeleton): @@ -345,3 +346,24 @@ def save(self, output_dir: str) -> None: # dataset for example should be done taking the extra information from the argilla # dataset instead of the defaults self._trainer.save(output_dir, model_name=None, create_model_card=False, train_datasets=None) + + def get_model_card_data(self, **card_data_kwargs) -> "SentenceTransformerCardData": + """ + Generate the card data to be used for the `ArgillaModelCard`. + + Args: + card_data_kwargs: Extra arguments provided by the user when creating the `ArgillaTrainer`. + + Returns: + SentenceTransformerCardData: Container for the data to be written on the `ArgillaModelCard`. + """ + from argilla.client.feedback.integrations.huggingface.model_card import SentenceTransformerCardData + + return SentenceTransformerCardData( + model_id=self._model, + task=self._task, + framework_kwargs={"cross_encoder": self._cross_encoder}, + update_config_kwargs={**self.trainer_kwargs, **self.model_kwargs, **self.data_kwargs}, + trainer_cls=self._trainer_cls, + **card_data_kwargs, + ) diff --git a/src/argilla/client/feedback/training/frameworks/setfit.py b/src/argilla/client/feedback/training/frameworks/setfit.py index 90bab567f9..265a2ceb13 100644 --- a/src/argilla/client/feedback/training/frameworks/setfit.py +++ b/src/argilla/client/feedback/training/frameworks/setfit.py @@ -13,12 +13,16 @@ # limitations under the License. import logging +from typing import TYPE_CHECKING from argilla.client.feedback.training.frameworks.transformers import ArgillaTransformersTrainer from argilla.client.models import TextClassificationRecord from argilla.training.setfit import ArgillaSetFitTrainer as ArgillaSetFitTrainerV1 from argilla.utils.dependency import require_dependencies +if TYPE_CHECKING: + from argilla.client.feedback.integrations.huggingface.model_card import SetFitModelCardData + class ArgillaSetFitTrainer(ArgillaSetFitTrainerV1, ArgillaTransformersTrainer): _logger = logging.getLogger("ArgillaSetFitTrainer") @@ -43,3 +47,22 @@ def __init__(self, *args, **kwargs): self.multi_target_strategy = None self._column_mapping = {"text": "text", "label": "label"} self.init_training_args() + + def get_model_card_data(self, **card_data_kwargs) -> "SetFitModelCardData": + """ + Generate the card data to be used for the `ArgillaModelCard`. + + Args: + card_data_kwargs: Extra arguments provided by the user when creating the `ArgillaTrainer`. + + Returns: + SetFitModelCardData: Container for the data to be written on the `ArgillaModelCard`. 
+ """ + from argilla.client.feedback.integrations.huggingface.model_card import SetFitModelCardData + + return SetFitModelCardData( + model_id=self._model, + task=self._task, + update_config_kwargs={**self.setfit_model_kwargs, **self.setfit_trainer_kwargs}, + **card_data_kwargs, + ) diff --git a/src/argilla/client/feedback/training/frameworks/spacy.py b/src/argilla/client/feedback/training/frameworks/spacy.py index 1665facdd4..0833a006a4 100644 --- a/src/argilla/client/feedback/training/frameworks/spacy.py +++ b/src/argilla/client/feedback/training/frameworks/spacy.py @@ -13,7 +13,7 @@ # limitations under the License. import logging -from typing import Optional +from typing import TYPE_CHECKING, Optional from typing_extensions import Literal @@ -24,6 +24,12 @@ from argilla.training.spacy import _ArgillaSpaCyTrainerBase as _ArgillaSpaCyTrainerBaseV1 from argilla.utils.dependency import require_dependencies +if TYPE_CHECKING: + from argilla.client.feedback.integrations.huggingface.model_card import ( + SpacyModelCardData, + SpacyTransformersModelCardData, + ) + class _ArgillaSpaCyTrainerBase(_ArgillaSpaCyTrainerBaseV1, ArgillaTrainerSkeleton): _logger = logging.getLogger("ArgillaSpaCyTrainer") @@ -150,6 +156,29 @@ def __init__(self, freeze_tok2vec: bool = False, **kwargs) -> None: self.freeze_tok2vec = freeze_tok2vec _ArgillaSpaCyTrainerBase.__init__(self, **kwargs) + def get_model_card_data(self, **card_data_kwargs) -> "SpacyModelCardData": + """ + Generate the card data to be used for the `ArgillaModelCard`. + + Args: + card_data_kwargs: Extra arguments provided by the user when creating the `ArgillaTrainer`. + + Returns: + SpacyModelCardData: Container for the data to be written on the `ArgillaModelCard`. + """ + from argilla.client.feedback.integrations.huggingface.model_card import SpacyModelCardData + + return SpacyModelCardData( + model_id=self._model, + task=self._task, + lang=self.language, + gpu_id=self.gpu_id, + framework_kwargs={"optimize": self.optimize, "freeze_tok2vec": self.freeze_tok2vec}, + pipeline=self._pipeline, # Used only to keep track for the config arguments + update_config_kwargs=self.config["training"], + **card_data_kwargs, + ) + class ArgillaSpaCyTransformersTrainer(ArgillaSpaCyTransformersTrainerV1, _ArgillaSpaCyTrainerBase): def __init__(self, update_transformer: bool = True, **kwargs) -> None: @@ -162,3 +191,26 @@ def __init__(self, update_transformer: bool = True, **kwargs) -> None: """ self.update_transformer = update_transformer _ArgillaSpaCyTrainerBase.__init__(self, **kwargs) + + def get_model_card_data(self, **card_data_kwargs) -> "SpacyTransformersModelCardData": + """ + Generate the card data to be used for the `ArgillaModelCard`. + + Args: + card_data_kwargs: Extra arguments provided by the user when creating the `ArgillaTrainer`. + + Returns: + SpacyTransformersModelCardData: Container for the data to be written on the `ArgillaModelCard`. 
+ """ + from argilla.client.feedback.integrations.huggingface.model_card import SpacyTransformersModelCardData + + return SpacyTransformersModelCardData( + model_id=self._model, + task=self._task, + lang=self.language, + gpu_id=self.gpu_id, + framework_kwargs={"optimize": self.optimize, "update_transformer": self.update_transformer}, + pipeline=self._pipeline, # Used only to keep track for the config arguments + update_config_kwargs=self.config["training"], + **card_data_kwargs, + ) diff --git a/src/argilla/client/feedback/training/frameworks/span_marker.py b/src/argilla/client/feedback/training/frameworks/span_marker.py index 9513579d1b..46e687706c 100644 --- a/src/argilla/client/feedback/training/frameworks/span_marker.py +++ b/src/argilla/client/feedback/training/frameworks/span_marker.py @@ -12,12 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import TYPE_CHECKING + from datasets import DatasetDict from argilla.client.feedback.training.base import ArgillaTrainerSkeleton from argilla.client.models import TokenClassificationRecord from argilla.training.span_marker import ArgillaSpanMarkerTrainer as ArgillaSpanMarkerTrainerV1 +if TYPE_CHECKING: + from argilla.client.feedback.integrations.huggingface.model_card import FrameworkCardData + class ArgillaSpanMarkerTrainer(ArgillaSpanMarkerTrainerV1, ArgillaTrainerSkeleton): def __init__(self, *args, **kwargs) -> None: @@ -56,3 +61,8 @@ def __init__(self, *args, **kwargs) -> None: raise NotImplementedError("Text2TextRecord and TextClassification are not supported.") self.init_training_args() + + def get_model_card_data(self, **card_data_kwargs) -> "FrameworkCardData": + raise NotImplementedError( + "This method has to be implemented after `FeedbackDataset` allows for token classification." + ) diff --git a/src/argilla/client/feedback/training/frameworks/transformers.py b/src/argilla/client/feedback/training/frameworks/transformers.py index 7918ac6d22..ae26f115c6 100644 --- a/src/argilla/client/feedback/training/frameworks/transformers.py +++ b/src/argilla/client/feedback/training/frameworks/transformers.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import TYPE_CHECKING from datasets import Dataset, DatasetDict @@ -19,6 +20,9 @@ from argilla.client.feedback.training.schemas import TrainingTaskForQuestionAnswering, TrainingTaskForTextClassification from argilla.training.transformers import ArgillaTransformersTrainer as ArgillaTransformersTrainerV1 +if TYPE_CHECKING: + from argilla.client.feedback.integrations.huggingface.model_card import TransformersModelCardData + class ArgillaTransformersTrainer(ArgillaTransformersTrainerV1, ArgillaTrainerSkeleton): def __init__(self, *args, **kwargs): @@ -69,3 +73,30 @@ def __init__(self, *args, **kwargs): ) self.init_training_args() + + def get_model_card_data(self, **card_data_kwargs) -> "TransformersModelCardData": + """ + Generate the card data to be used for the `ArgillaModelCard`. + + Args: + card_data_kwargs: Extra arguments provided by the user when creating the `ArgillaTrainer`. + + Returns: + TransformersModelCardData: Container for the data to be written on the `ArgillaModelCard`. 
+ """ + from argilla.client.feedback.integrations.huggingface.model_card import TransformersModelCardData + + if not card_data_kwargs.get("tags"): + if isinstance(self._task, TrainingTaskForTextClassification): + tags = ["text-classification"] + else: + tags = ["question-answering"] + + card_data_kwargs.update({"tags": tags + ["transformers", "argilla"]}) + + return TransformersModelCardData( + model_id=self._model, + task=self._task, + update_config_kwargs=self.trainer_kwargs, + **card_data_kwargs, + ) diff --git a/src/argilla/client/feedback/training/frameworks/trl.py b/src/argilla/client/feedback/training/frameworks/trl.py index d1a7973d2b..52c6af9ed0 100644 --- a/src/argilla/client/feedback/training/frameworks/trl.py +++ b/src/argilla/client/feedback/training/frameworks/trl.py @@ -31,6 +31,7 @@ from trl import PPOConfig from argilla.client.feedback.dataset import FeedbackDataset + from argilla.client.feedback.integrations.huggingface.model_card import TRLModelCardData class PPOArgs: @@ -387,3 +388,34 @@ def __repr__(self) -> str: for key, val in arg_dict_single.items(): formatted_string.append(f"{key}: {val}") return "\n".join(formatted_string) + + def get_model_card_data(self, **card_data_kwargs) -> "TRLModelCardData": + """ + Generate the card data to be used for the `ArgillaModelCard`. + + Args: + card_data_kwargs: Extra arguments provided by the user when creating the `ArgillaTrainer`. + + Returns: + TRLModelCardData: Container for the data to be written on the `ArgillaModelCard`. + """ + from argilla.client.feedback.integrations.huggingface.model_card import TRLModelCardData + + if not card_data_kwargs.get("tags"): + if isinstance(self._task, TrainingTaskForSFT): + tags = ["supervised-fine-tuning", "sft"] + elif isinstance(self._task, TrainingTaskForRM): + tags = ["reward-modeling", "rm"] + elif isinstance(self._task, TrainingTaskForPPO): + tags = ["proximal-policy-optimization", "ppo"] + elif isinstance(self._task, TrainingTaskForDPO): + tags = ["direct-preference-optimization", "dpo"] + + card_data_kwargs.update({"tags": tags + ["TRL", "argilla"]}) + + return TRLModelCardData( + model_id=self._model, + task=self._task, + update_config_kwargs={**self.training_args_kwargs, **self.trainer_kwargs}, + **card_data_kwargs, + ) diff --git a/src/argilla/client/feedback/training/schemas.py b/src/argilla/client/feedback/training/schemas.py index b9db2be3ef..0aebcaf06f 100644 --- a/src/argilla/client/feedback/training/schemas.py +++ b/src/argilla/client/feedback/training/schemas.py @@ -1485,9 +1485,8 @@ def __id2label__(self): def __repr__(self) -> str: return ( f"{self.__class__.__name__}" - f"\n\t texts={self.text.name}" - f"\n\t label={self.label.question.name}" - f"\n\t multi_label={self.__multi_label__}" + f"\n\t texts={self.texts.name if self.texts else None}" + f"\n\t label={self.label.question.name if self.label else None}" f"\n\t all_labels={self.__all_labels__}" f"\n\t formatting_funct={self.formatting_func}" ) @@ -1633,6 +1632,18 @@ def dataset_fields(sample): TrainingTaskForSentenceSimilarity, ] +# Helper map fr the creation of the model cards. 
+TRAINING_TASK_MAPPING = { + TrainingTaskForTextClassification: "for_text_classification", + TrainingTaskForSFT: "for_supervised_fine_tuning", + TrainingTaskForRM: "for_reward_modeling", + TrainingTaskForPPO: "for_proximal_policy_optimization", + TrainingTaskForDPO: "for_direct_preference_optimization", + TrainingTaskForChatCompletion: "for_chat_completion", + TrainingTaskForQuestionAnswering: "for_question_answering", + TrainingTaskForSentenceSimilarity: "for_sentence_similarity", +} + # Old, deprecated variants. class RenamedDeprecationMixin: diff --git a/src/argilla/client/feedback/utils.py b/src/argilla/client/feedback/utils.py index 47875c6be4..927e33e5b5 100644 --- a/src/argilla/client/feedback/utils.py +++ b/src/argilla/client/feedback/utils.py @@ -109,7 +109,7 @@ def generate_pydantic_schema_for_metadata( """ metadata_fields, metadata_validators = {}, {} - for metadata_property in metadata_properties: + for metadata_property in metadata_properties or []: if metadata_property.type not in MetadataPropertyTypes: raise ValueError( f"Metadata property {metadata_property.name} has an unsupported type: {metadata_property.type}, for the moment only the" diff --git a/src/argilla/client/models.py b/src/argilla/client/models.py index 23ef11c50b..44ba760339 100644 --- a/src/argilla/client/models.py +++ b/src/argilla/client/models.py @@ -36,6 +36,21 @@ Vectors = Dict[str, List[float]] +FRAMEWORK_TO_NAME_MAPPING = { + "transformers": "Transformers", + "peft": "PEFT Transformers library", + "setfit": "SetFit Transformers library", + "spacy": "Spacy Explosion", + "spacy-transformers": "Spacy Transformers Explosion library", + "span_marker": "SpanMarker Tom Aarsen library", + "spark-nlp": "Spark NLP John Snow Labs library", + "openai": "OpenAI LLMs", + "trl": "Transformer Reinforcement Learning", + "trlx": "Transformer Reinforcement Learning X", + "sentence-transformers": "Sentence Transformers library", +} + + class Framework(Enum): """Frameworks supported by Argilla diff --git a/src/argilla/client/sdk/datasets/api.py b/src/argilla/client/sdk/datasets/api.py index 15422e8aa3..ba73d765e8 100644 --- a/src/argilla/client/sdk/datasets/api.py +++ b/src/argilla/client/sdk/datasets/api.py @@ -26,11 +26,14 @@ @lru_cache(maxsize=None) -def get_dataset(client: AuthenticatedClient, name: str) -> Response[Dataset]: +def get_dataset(client: AuthenticatedClient, name: str, workspace: Optional[str] = None) -> Response[Dataset]: url = f"{client.base_url}/api/datasets/{name}" + params = {"workspace": workspace} if workspace else None + response = httpx.get( url=url, + params=params, headers=client.get_headers(), cookies=client.get_cookies(), timeout=client.get_timeout(), @@ -40,6 +43,7 @@ def get_dataset(client: AuthenticatedClient, name: str) -> Response[Dataset]: response_obj = Response.from_httpx_response(response) response_obj.parsed = Dataset(**response.json()) return response_obj + handle_response_error(response) diff --git a/src/argilla/datasets/__init__.py b/src/argilla/datasets/__init__.py index 5ddd40c5af..74af249bf0 100644 --- a/src/argilla/datasets/__init__.py +++ b/src/argilla/datasets/__init__.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import logging +import warnings from typing import Optional from argilla.client import api @@ -22,7 +23,7 @@ _LOGGER = logging.getLogger(__name__) -def load_dataset_settings(name: str, workspace: Optional[str] = None) -> Settings: +def load_dataset_settings(name: str, workspace: Optional[str] = None) -> Optional[Settings]: """ Loads the settings of a dataset @@ -34,10 +35,9 @@ def load_dataset_settings(name: str, workspace: Optional[str] = None) -> Setting The dataset settings """ active_api = api.active_api() - if workspace is not None: - active_api.set_workspace(workspace) datasets = active_api.datasets - settings = datasets.load_settings(name) + + settings = datasets.load_settings(name, workspace=workspace) if settings is None: return None else: @@ -73,5 +73,5 @@ def configure_dataset(name: str, settings: Settings, workspace: Optional[str] = settings: The dataset settings workspace: The workspace name where the dataset will belongs to """ - _LOGGER.warning("This method is deprecated. Use configure_dataset_settings instead.") + warnings.warn("This method is deprecated. Use configure_dataset_settings instead.", DeprecationWarning) return configure_dataset_settings(name, settings, workspace) diff --git a/tests/integration/client/feedback/conftest.py b/tests/integration/client/feedback/conftest.py index 0982bba3f4..4e05641710 100644 --- a/tests/integration/client/feedback/conftest.py +++ b/tests/integration/client/feedback/conftest.py @@ -12,7 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import TYPE_CHECKING, Any + import pytest +from argilla.client.models import Framework +from argilla.feedback import TrainingTask + +if TYPE_CHECKING: + from pytest_mock import MockerFixture @pytest.fixture @@ -43,3 +50,552 @@ def ranking_question_payload(): "required": True, "values": ["1", "2"], } + + +@pytest.fixture +def model_card_pattern() -> str: + def inner(framework: Framework, training_task: Any): + # TODO(plaguss): properly annotate training_task argument + if framework == Framework("transformers"): + if training_task == TrainingTask.for_text_classification: + return TRANSFORMERS_CODE_SNIPPET + elif training_task == TrainingTask.for_question_answering: + return TRANSFORMERS_QA_CODE_SNIPPET + elif framework == Framework("setfit"): + return SETFIT_CODE_SNIPPET + elif framework == Framework("peft"): + return PEFT_CODE_SNIPPET + elif framework == Framework("spacy"): + return SPACY_CODE_SNIPPET + elif framework == Framework("spacy-transformers"): + return SPACY_TRANSFORMERS_CODE_SNIPPET + elif framework == Framework("sentence-transformers"): + return SENTENCE_TRANSFORMERS_CODE_SNIPPET + elif framework == Framework("trl"): + if training_task == TrainingTask.for_supervised_fine_tuning: + return TR_SFT_CODE_SNIPPET + elif training_task == TrainingTask.for_reward_modeling: + return TR_RM_CODE_SNIPPET + elif training_task == TrainingTask.for_proximal_policy_optimization: + return TR_PPO_CODE_SNIPPET + elif training_task == TrainingTask.for_direct_preference_optimization: + return TR_DPO_CODE_SNIPPET + elif framework == Framework("openai"): + return OPENAI_CODE_SNIPPET + else: + raise ValueError(f"Framework undefined: {framework}") + + return inner + + +SENTENCE_TRANSFORMERS_CODE_SNIPPET = """\ +```python +# Load the dataset: +dataset = FeedbackDataset.from_huggingface("argilla/emotion") + +# Create the training task: + def formatting_func(sample): + labels = [ + annotation["value"] + for annotation in sample["question-3"] + if 
annotation["status"] == "submitted" and annotation["value"] is not None + ] + if labels: + # Three cases for the tests: None, one tuple and yielding multiple tuples + if labels[0] == "a": + return None + elif labels[0] == "b": + return {"sentence-1": sample["text"], "sentence-2": sample["text"], "label": 1} + elif labels[0] == "c": + return [ + {"sentence-1": sample["text"], "sentence-2": sample["text"], "label": 1}, + {"sentence-1": sample["text"], "sentence-2": sample["text"], "label": 0}, + ] + +task = TrainingTask.for_sentence_similarity(formatting_func=formatting_func) + +# Create the ArgillaTrainer: +trainer = ArgillaTrainer( + dataset=dataset, + task=task, + framework="sentence-transformers", + model="sentence-transformers/all-MiniLM-L6-v2", + framework_kwargs={'cross_encoder': False}, +) + +trainer.update_config({ + "batch_size": 3 +}) + +trainer.train(output_dir="sentence_similarity_model") +``` + +You can test the type of predictions of this model like so: + +```python +trainer.predict( + [ + ["Machine learning is so easy.", "Deep learning is so straightforward."], + ["Machine learning is so easy.", "This is so difficult, like rocket science."], + ["Machine learning is so easy.", "I can't believe how much I struggled with this."] + ] +) +``` +""" + + +TRANSFORMERS_CODE_SNIPPET = """\ +```python +# Load the dataset: +dataset = FeedbackDataset.from_huggingface("argilla/emotion") + +# Create the training task: +task = TrainingTask.for_text_classification(text=dataset.field_by_name("text"), label=dataset.question_by_name("question-3")) + +# Create the ArgillaTrainer: +trainer = ArgillaTrainer( + dataset=dataset, + task=task, + framework="transformers", + model="prajjwal1/bert-tiny", +) + +trainer.update_config({ + "logging_steps": 1, + "num_train_epochs": 1 +}) + +trainer.train(output_dir="text_classification_model") +``` +""" + + +TRANSFORMERS_QA_CODE_SNIPPET = """\ +```python +# Load the dataset: +dataset = FeedbackDataset.from_huggingface("argilla/emotion") + +# Create the training task: +task = TrainingTask.for_question_answering(question=dataset.field_by_name("label"), context=dataset.field_by_name("text"), answer=dataset.question_by_name("question-1")) + +# Create the ArgillaTrainer: +trainer = ArgillaTrainer( + dataset=dataset, + task=task, + framework="transformers", + model="prajjwal1/bert-tiny", +) + +trainer.update_config({ + "logging_steps": 1, + "num_train_epochs": 1 +}) + +trainer.train(output_dir="question_answering_model") +``` +""" + + +SETFIT_CODE_SNIPPET = """\ +```python +# Load the dataset: +dataset = FeedbackDataset.from_huggingface("argilla/emotion") + +# Create the training task: +task = TrainingTask.for_text_classification(text=dataset.field_by_name("text"), label=dataset.question_by_name("question-3")) + +# Create the ArgillaTrainer: +trainer = ArgillaTrainer( + dataset=dataset, + task=task, + framework="setfit", + model="all-MiniLM-L6-v2", +) + +trainer.update_config({ + "num_iterations": 1 +}) + +trainer.train(output_dir="text_classification_model") +``` +""" + + +PEFT_CODE_SNIPPET = """\ +```python +# Load the dataset: +dataset = FeedbackDataset.from_huggingface("argilla/emotion") + +# Create the training task: +task = TrainingTask.for_text_classification(text=dataset.field_by_name("text"), label=dataset.question_by_name("question-3")) + +# Create the ArgillaTrainer: +trainer = ArgillaTrainer( + dataset=dataset, + task=task, + framework="peft", + model="prajjwal1/bert-tiny", +) + +trainer.train(output_dir="text_classification_model") +""" + + 
+SPACY_CODE_SNIPPET = """\ +```python +# Load the dataset: +dataset = FeedbackDataset.from_huggingface("argilla/emotion") + +# Create the training task: +task = TrainingTask.for_text_classification(text=dataset.field_by_name("text"), label=dataset.question_by_name("question-3")) + +# Create the ArgillaTrainer: +trainer = ArgillaTrainer( + dataset=dataset, + task=task, + framework="spacy", + lang="en", + model="en_core_web_sm", + gpu_id=-1, + framework_kwargs={'optimize': 'efficiency', 'freeze_tok2vec': False}, +) + +trainer.train(output_dir="text_classification_model") +``` +""" + + +SPACY_TRANSFORMERS_CODE_SNIPPET = """\ +```python +# Load the dataset: +dataset = FeedbackDataset.from_huggingface("argilla/emotion") + +# Create the training task: +task = TrainingTask.for_text_classification(text=dataset.field_by_name("text"), label=dataset.question_by_name("question-3")) + +# Create the ArgillaTrainer: +trainer = ArgillaTrainer( + dataset=dataset, + task=task, + framework="spacy-transformers", + lang="en", + model="prajjwal1/bert-tiny", + gpu_id=-1, + framework_kwargs={'optimize': 'efficiency', 'update_transformer': True}, +) + +trainer.train(output_dir="text_classification_model") +``` +""" + + +OPENAI_CODE_SNIPPET = """\ +```python +# Load the dataset: +dataset = FeedbackDataset.from_huggingface("argilla/emotion") + +# Create the training task: + def formatting_func(sample: dict): + from uuid import uuid4 + + if sample["response"]: + chat = str(uuid4()) + user_message = user_message_prompt.format(context_str=sample["context"], query_str=sample["user-message"]) + return [ + (chat, "0", "system", system_prompt), + (chat, "1", "user", user_message), + (chat, "2", "assistant", sample["response"][0]["value"]), + ] + else: + return None + +task = TrainingTask.for_chat_completion(formatting_func=formatting_func) + +# Create the ArgillaTrainer: +trainer = ArgillaTrainer( + dataset=dataset, + task=task, + framework="openai", +) + +trainer.train(output_dir="chat_completion_model") +``` + +You can test the type of predictions of this model like so: + +```python +# After training we can use the model from the openai framework, you can take a look at their docs in order to use the model +import openai + +completion = openai.ChatCompletion.create( + model="ft:gpt-3.5-turbo:my-org:custom_suffix:id", + messages=[ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Hello!"} + ] +) + +``` +""" + + +TR_SFT_CODE_SNIPPET = """\ +```python +# Load the dataset: +dataset = FeedbackDataset.from_huggingface("argilla/emotion") + +# Create the training task: +def formatting_func_sft(sample: Dict[str, Any]) -> Iterator[str]: + # For example, the sample must be most frequently rated as "1" in question-2 and + # label "b" from "question-3" must have not been set by any annotator + ratings = [ + annotation["value"] + for annotation in sample["question-2"] + if annotation["status"] == "submitted" and annotation["value"] is not None + ] + labels = [ + annotation["value"] + for annotation in sample["question-3"] + if annotation["status"] == "submitted" and annotation["value"] is not None + ] + if ratings and Counter(ratings).most_common(1)[0][0] == 1 and "b" not in labels: + return f"### Text\\n{sample['text']}" + return None + +task = TrainingTask.for_supervised_fine_tuning(formatting_func=formatting_func) + +# Create the ArgillaTrainer: +trainer = ArgillaTrainer( + dataset=dataset, + task=task, + framework="trl", + model="sshleifer/tiny-gpt2", +) + +trainer.update_config({ + 
"evaluation_strategy": "no", + "max_steps": 1 +}) + +trainer.train(output_dir="sft_model") +``` + +You can test the type of predictions of this model like so: + +```python +# This type of model has no `predict` method implemented from argilla, but can be done using the underlying library +from transformers import GenerationConfig, AutoTokenizer, GPT2LMHeadModel + +def generate(model_id: str, instruction: str, context: str = "") -> str: + model = GPT2LMHeadModel.from_pretrained(model_id) + tokenizer = AutoTokenizer.from_pretrained(model_id) + + inputs = template.format( + instruction=instruction, + context=context, + response="", + ).strip() + + encoding = tokenizer([inputs], return_tensors="pt") + outputs = model.generate( + **encoding, + generation_config=GenerationConfig( + max_new_tokens=32, + min_new_tokens=12, + pad_token_id=tokenizer.pad_token_id, + eos_token_id=tokenizer.eos_token_id, + ), + ) + return tokenizer.decode(outputs[0]) + +generate("sft_model", "Is a toad a frog?") +``` +""" + + +TR_RM_CODE_SNIPPET = """\ +```python +# Load the dataset: +dataset = FeedbackDataset.from_huggingface("argilla/emotion") + +# Create the training task: +def formatting_func_rm(sample: Dict[str, Any]): + # The FeedbackDataset isn't really set up for RM, so we'll just use an arbitrary example here + labels = [ + annotation["value"] + for annotation in sample["question-3"] + if annotation["status"] == "submitted" and annotation["value"] is not None + ] + if labels: + # Three cases for the tests: None, one tuple and yielding multiple tuples + if labels[0] == "a": + return None + elif labels[0] == "b": + return sample["text"], sample["text"][:5] + elif labels[0] == "c": + return [(sample["text"], sample["text"][5:10]), (sample["text"], sample["text"][:5])] + +task = TrainingTask.for_reward_modeling(formatting_func=formatting_func) + +# Create the ArgillaTrainer: +trainer = ArgillaTrainer( + dataset=dataset, + task=task, + framework="trl", + model="sshleifer/tiny-gpt2", +) + +trainer.update_config({ + "evaluation_strategy": "no", + "max_steps": 1 +}) + +trainer.train(output_dir="rm_model") +``` + +You can test the type of predictions of this model like so: + +```python +# This type of model has no `predict` method implemented from argilla, but can be done using the underlying library +from transformers import AutoTokenizer, AutoModelForSequenceClassification +import torch + +model = AutoModelForSequenceClassification.from_pretrained("rm_model") +tokenizer = AutoTokenizer.from_pretrained("rm_model") + +def get_score(model, tokenizer, text): + # Tokenize the input sequences + inputs = tokenizer(text, truncation=True, padding="max_length", max_length=512, return_tensors="pt") + + # Perform forward pass + with torch.no_grad(): + outputs = model(**inputs) + + # Extract the logits + return outputs.logits[0, 0].item() + +# Example usage +example = template.format(instruction="your prompt", context="your context", response="response") + +score = get_score(model, tokenizer, example) +print(score) +``` +""" + + +TR_PPO_CODE_SNIPPET = """\ +```python +# Load the dataset: +dataset = FeedbackDataset.from_huggingface("argilla/emotion") + +# Create the training task: +def formatting_func_ppo(sample: Dict[str, Any]): + return sample["text"] + +task = TrainingTask.for_proximal_policy_optimization(formatting_func=formatting_func) + +# Create the ArgillaTrainer: +trainer = ArgillaTrainer( + dataset=dataset, + task=task, + framework="trl", + model="sshleifer/tiny-gpt2", +) + +trainer.train(output_dir="ppo_model") +``` + 
+You can test the type of predictions of this model like so: + +```python +# This type of model has no `predict` method implemented from argilla, but can be done using the underlying library +from transformers import AutoModelForCausalLM, AutoTokenizer + +model = AutoModelForCausalLM.from_pretrained("ppo_model") +tokenizer = AutoTokenizer.from_pretrained("ppo_model") +tokenizer.pad_token = tokenizer.eos_token + +inputs = template.format( + instruction="your prompt", + context="your context", + response="" +).strip() +encoding = tokenizer([inputs], return_tensors="pt") +outputs = model.generate(**encoding, max_new_tokens=30) +output_text = tokenizer.decode(outputs[0]) +print(output_text) +``` +""" + + +TR_DPO_CODE_SNIPPET = """\ +```python +# Load the dataset: +dataset = FeedbackDataset.from_huggingface("argilla/emotion") + +# Create the training task: +def formatting_func_dpo(sample: Dict[str, Any]): + # The FeedbackDataset isn't really set up for DPO, so we'll just use an arbitrary example here + labels = [ + annotation["value"] + for annotation in sample["question-3"] + if annotation["status"] == "submitted" and annotation["value"] is not None + ] + if labels: + # Three cases for the tests: None, one tuple and yielding multiple tuples + if labels[0] == "a": + return None + elif labels[0] == "b": + return sample["text"][::-1], sample["text"], sample["text"][:5] + elif labels[0] == "c": + return [ + (sample["text"], sample["text"][::-1], sample["text"][:5]), + (sample["text"][::-1], sample["text"], sample["text"][:5]), + ] + +task = TrainingTask.for_direct_preference_optimization(formatting_func=formatting_func) + +# Create the ArgillaTrainer: +trainer = ArgillaTrainer( + dataset=dataset, + task=task, + framework="trl", + model="sshleifer/tiny-gpt2", +) + +trainer.update_config({ + "evaluation_strategy": "no", + "max_steps": 1 +}) + +trainer.train(output_dir="dpo_model") +``` + +You can test the type of predictions of this model like so: + +```python +# This type of model has no `predict` method implemented from argilla, but can be done using the underlying library +from transformers import AutoModelForCausalLM, AutoTokenizer + +model = AutoModelForCausalLM.from_pretrained("dpo_model") +tokenizer = AutoTokenizer.from_pretrained("dpo_model") +tokenizer.pad_token = tokenizer.eos_token + +inputs = template.format( + instruction="your prompt", + context="your context", + response="" +).strip() +encoding = tokenizer([inputs], return_tensors="pt") +outputs = model.generate(**encoding, max_new_tokens=30) +output_text = tokenizer.decode(outputs[0]) +print(output_text) +``` +""" + + +@pytest.fixture +def mocked_is_on_huggingface(mocker: "MockerFixture") -> bool: + mocker.patch( + "argilla.client.feedback.integrations.huggingface.model_card.model_card.is_on_huggingface", return_value=True + ) diff --git a/tests/integration/client/feedback/dataset/test_dataset.py b/tests/integration/client/feedback/dataset/local/test_dataset.py similarity index 94% rename from tests/integration/client/feedback/dataset/test_dataset.py rename to tests/integration/client/feedback/dataset/local/test_dataset.py index d45e047680..b677c75652 100644 --- a/tests/integration/client/feedback/dataset/test_dataset.py +++ b/tests/integration/client/feedback/dataset/local/test_dataset.py @@ -17,6 +17,7 @@ import datasets import pytest +from argilla import Workspace from argilla.client import api from argilla.client.feedback.config import DatasetConfig from argilla.client.feedback.dataset import FeedbackDataset @@ -596,6 +597,7 @@ 
def test_push_to_huggingface_and_from_huggingface( "feedback_dataset_records", ) def test_prepare_for_training_text_classification( + owner: "ServerUser", framework: Union[Framework, str], question: str, feedback_dataset_guidelines: str, @@ -609,7 +611,48 @@ def test_prepare_for_training_text_classification( questions=feedback_dataset_questions, ) dataset.add_records(feedback_dataset_records) - label = dataset.question_by_name(question) + + api.init(api_key=owner.api_key) + ws = Workspace.create(name="test-workspace") + + remote = dataset.push_to_argilla(name="test-dataset", workspace=ws) + + label = remote.question_by_name(question) task = TrainingTask.for_text_classification(text=dataset.fields[0], label=label) - dataset.prepare_for_training(framework=framework, task=task) + data = remote.prepare_for_training(framework=framework, task=task) + assert data is not None + + +@pytest.mark.usefixtures( + "feedback_dataset_guidelines", + "feedback_dataset_fields", + "feedback_dataset_questions", + "feedback_dataset_records", +) +def test_warning_remote_dataset_methods( + feedback_dataset_guidelines: str, + feedback_dataset_fields: List["AllowedFieldTypes"], + feedback_dataset_questions: List["AllowedQuestionTypes"], + feedback_dataset_records: List[FeedbackRecord], +): + dataset = FeedbackDataset( + guidelines=feedback_dataset_guidelines, + fields=feedback_dataset_fields, + questions=feedback_dataset_questions, + ) + + with pytest.warns( + UserWarning, match="`pull` method is not supported for local datasets and won't take any effect." + ): + dataset.pull() + + with pytest.warns( + UserWarning, match="`filter_by` method is not supported for local datasets and won't take any effect." + ): + dataset.filter_by() + + with pytest.warns( + UserWarning, match="`delete` method is not supported for local datasets and won't take any effect." + ): + dataset.delete() diff --git a/tests/integration/client/feedback/dataset/remote/test_dataset.py b/tests/integration/client/feedback/dataset/remote/test_dataset.py index f57a2ae91c..ef085bf057 100644 --- a/tests/integration/client/feedback/dataset/remote/test_dataset.py +++ b/tests/integration/client/feedback/dataset/remote/test_dataset.py @@ -517,3 +517,28 @@ async def test_pull_without_results( assert local_copy is not None assert local_copy.records == [] + + @pytest.mark.parametrize("role", [UserRole.owner, UserRole.admin]) + async def test_warning_local_methods(self, role: UserRole) -> None: + dataset = await DatasetFactory.create() + await TextFieldFactory.create(dataset=dataset, required=True) + await TextQuestionFactory.create(dataset=dataset, required=True) + await RecordFactory.create_batch(dataset=dataset, size=10) + user = await UserFactory.create(role=role, workspaces=[dataset.workspace]) + + api.init(api_key=user.api_key) + ds = FeedbackDataset.from_argilla(id=dataset.id) + + with pytest.raises(ValueError, match="`FeedbackRecord.fields` does not match the expected schema"): + with pytest.warns( + UserWarning, + match="A local `FeedbackDataset` returned because `unify_responses` is not supported for `RemoteFeedbackDataset`. ", + ): + ds.unify_responses(question=None, strategy=None) + + with pytest.raises(ValueError, match="`FeedbackRecord.fields` does not match the expected schema"): + with pytest.warns( + UserWarning, + match="A local `FeedbackDataset` returned because `prepare_for_training` is not supported for `RemoteFeedbackDataset`. 
", + ): + ds.prepare_for_training(framework=None, task=None) diff --git a/tests/integration/client/feedback/dataset/remote/test_filter_and_sorting.py b/tests/integration/client/feedback/dataset/remote/test_filter_and_sorting.py index ce50938db6..5f2a2e57f0 100644 --- a/tests/integration/client/feedback/dataset/remote/test_filter_and_sorting.py +++ b/tests/integration/client/feedback/dataset/remote/test_filter_and_sorting.py @@ -19,7 +19,7 @@ import pytest from argilla import SortBy, TextField, TextQuestion from argilla.client import api -from argilla.client.feedback.dataset.local import FeedbackDataset +from argilla.client.feedback.dataset.local.dataset import FeedbackDataset from argilla.client.feedback.schemas.enums import ResponseStatusFilter from argilla.client.feedback.schemas.metadata import ( FloatMetadataFilter, diff --git a/tests/integration/client/feedback/integrations/huggingface/test_model_card.py b/tests/integration/client/feedback/integrations/huggingface/test_model_card.py new file mode 100644 index 0000000000..f661d0a7f9 --- /dev/null +++ b/tests/integration/client/feedback/integrations/huggingface/test_model_card.py @@ -0,0 +1,402 @@ +# Copyright 2021-present, the Recognai S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Only a subset of the possibilities is tested for speed. +- spacy, spacy-transformers, transformers (text-classification and QA), setfit, and peft with +default dataset fields. +- sentence-transformers and trl with formatting_func. 
+""" + +import shutil +from collections import Counter +from pathlib import Path +from tempfile import TemporaryDirectory +from typing import TYPE_CHECKING, Any, Callable, Dict, Iterator, List, Union + +import pytest +from argilla.client.feedback.schemas import ( + FeedbackRecord, + LabelQuestion, + MultiLabelQuestion, +) +from argilla.client.feedback.unification import LabelQuestionUnification +from argilla.client.models import Framework +from argilla.feedback import ArgillaTrainer, FeedbackDataset, TrainingTask + +if TYPE_CHECKING: + from argilla.client.feedback.schemas import FeedbackRecord + from argilla.client.feedback.schemas.types import AllowedFieldTypes, AllowedQuestionTypes + + +DATASET_NAME = "argilla/emotion" +MODEL_CARD_NAME = "README.md" + + +@pytest.mark.parametrize( + "framework, training_task", + [ + (Framework("spacy"), TrainingTask.for_text_classification), + (Framework("spacy-transformers"), TrainingTask.for_text_classification), + (Framework("transformers"), TrainingTask.for_text_classification), + (Framework("transformers"), TrainingTask.for_question_answering), + (Framework("setfit"), TrainingTask.for_text_classification), + (Framework("peft"), TrainingTask.for_text_classification), + (Framework("span_marker"), TrainingTask.for_text_classification), + ], +) +@pytest.mark.usefixtures( + "feedback_dataset_guidelines", + "feedback_dataset_fields", + "feedback_dataset_questions", + "feedback_dataset_records", + "model_card_pattern", +) +def test_model_card_with_defaults( + framework: Union[Framework, str], + training_task: str, + feedback_dataset_guidelines: str, + feedback_dataset_fields: List["AllowedFieldTypes"], + feedback_dataset_questions: List["AllowedQuestionTypes"], + feedback_dataset_records: List[FeedbackRecord], + model_card_pattern: str, + mocked_is_on_huggingface, +) -> None: + # This test is almost a copy from the one in `test_trainer.py`, it's separated for + # simplicity, but for speed we should test this at the same trainer. 
+ + dataset = FeedbackDataset( + guidelines=feedback_dataset_guidelines, + fields=feedback_dataset_fields, + questions=feedback_dataset_questions, + ) + dataset.add_records(records=feedback_dataset_records * 2) + + questions = [ + question for question in dataset.questions if isinstance(question, (LabelQuestion, MultiLabelQuestion)) + ] + label = LabelQuestionUnification(question=questions[0]) + + if training_task == TrainingTask.for_question_answering: + task = TrainingTask.for_question_answering( + question=dataset.field_by_name("label"), + context=dataset.field_by_name("text"), + answer=dataset.question_by_name("question-1"), + ) + output_dir = '"question_answering_model"' + elif training_task == TrainingTask.for_text_classification: + task = TrainingTask.for_text_classification(text=dataset.fields[0], label=label) + output_dir = '"text_classification_model"' + + if framework == Framework("spacy"): + model = "en_core_web_sm" + elif framework == Framework("setfit"): + model = "all-MiniLM-L6-v2" + else: + model = "prajjwal1/bert-tiny" + + if framework == Framework("span_marker"): + with pytest.raises(NotImplementedError, match=f"^Framework span_marker is not supported for this"): + trainer = ArgillaTrainer( + dataset=dataset, + task=task, + framework=framework, + model=model, + framework_kwargs={ + "model_card_kwargs": {"license": "mit", "language": ["en", "es"], "dataset_name": DATASET_NAME} + }, + ) + return + else: + trainer = ArgillaTrainer( + dataset=dataset, + task=task, + framework=framework, + model=model, + framework_kwargs={ + "model_card_kwargs": { + "license": "mit", + "language": ["en", "es"], + "dataset_name": DATASET_NAME, + "output_dir": output_dir, + }, + }, + ) + + if framework in [Framework("spacy"), Framework("spacy-transformers")]: + trainer.update_config(max_steps=1) + elif framework in [Framework("transformers"), Framework("setfit")]: + trainer.update_config(num_iterations=1) + + with TemporaryDirectory() as tmpdirname: + model_card = trainer.generate_model_card(tmpdirname) + assert (Path(tmpdirname) / MODEL_CARD_NAME).exists() + pattern = model_card_pattern(framework, training_task) + assert model_card.content.find(pattern) > -1 + + +@pytest.mark.usefixtures( + "feedback_dataset_fields", + "feedback_dataset_questions", + "feedback_dataset_guidelines", + "feedback_dataset_records", + "model_card_pattern", +) +def test_model_card_sentence_transformers( + feedback_dataset_fields: List["AllowedFieldTypes"], + feedback_dataset_questions: List["AllowedQuestionTypes"], + feedback_dataset_guidelines: str, + feedback_dataset_records: List["FeedbackRecord"], + model_card_pattern: str, + mocked_is_on_huggingface, +) -> None: + dataset = FeedbackDataset( + guidelines=feedback_dataset_guidelines, + fields=feedback_dataset_fields, + questions=feedback_dataset_questions, + ) + dataset.add_records(records=feedback_dataset_records * 2) + + def formatting_func(sample): + labels = [ + annotation["value"] + for annotation in sample["question-3"] + if annotation["status"] == "submitted" and annotation["value"] is not None + ] + if labels: + # Three cases for the tests: None, one tuple and yielding multiple tuples + if labels[0] == "a": + return None + elif labels[0] == "b": + return {"sentence-1": sample["text"], "sentence-2": sample["text"], "label": 1} + elif labels[0] == "c": + return [ + {"sentence-1": sample["text"], "sentence-2": sample["text"], "label": 1}, + {"sentence-1": sample["text"], "sentence-2": sample["text"], "label": 0}, + ] + + task = 
TrainingTask.for_sentence_similarity(formatting_func=formatting_func) + + trainer = ArgillaTrainer( + dataset=dataset, + task=task, + framework="sentence-transformers", + framework_kwargs={ + "cross_encoder": False, + "model_card_kwargs": { + "license": "mit", + "language": ["en", "es"], + "dataset_name": DATASET_NAME, + "output_dir": '"sentence_similarity_model"', + }, + }, + ) + trainer.update_config(epochs=1, batch_size=3) + + with TemporaryDirectory() as tmpdirname: + model_card = trainer.generate_model_card(tmpdirname) + assert (Path(tmpdirname) / MODEL_CARD_NAME).exists() + pattern = model_card_pattern(Framework("sentence-transformers"), TrainingTask.for_sentence_similarity) + assert model_card.content.find(pattern) > -1 + + +@pytest.mark.usefixtures( + "model_card_pattern", +) +def test_model_card_openai(model_card_pattern: str, mocked_openai, mocked_is_on_huggingface): + dataset = FeedbackDataset.from_huggingface("argilla/customer_assistant") + # adapation from LlamaIndex's TEXT_QA_PROMPT_TMPL_MSGS[1].content + user_message_prompt = """Context information is below. + --------------------- + {context_str} + --------------------- + Given the context information and not prior knowledge but keeping your Argilla Cloud assistant style, answer the query. + Query: {query_str} + Answer: + """ + # adapation from LlamaIndex's TEXT_QA_SYSTEM_PROMPT + system_prompt = """You are an expert customer service assistant for the Argilla Cloud product that is trusted around the world.""" + + def formatting_func(sample: dict): + from uuid import uuid4 + + if sample["response"]: + chat = str(uuid4()) + user_message = user_message_prompt.format(context_str=sample["context"], query_str=sample["user-message"]) + return [ + (chat, "0", "system", system_prompt), + (chat, "1", "user", user_message), + (chat, "2", "assistant", sample["response"][0]["value"]), + ] + else: + return None + + task = TrainingTask.for_chat_completion(formatting_func=formatting_func) + trainer = ArgillaTrainer( + dataset=dataset, + task=task, + framework="openai", + ) + + trainer = ArgillaTrainer( + dataset=dataset, + task=task, + framework="openai", + model="gpt-3.5-turbo-0613", + framework_kwargs={ + "model_card_kwargs": { + "license": "mit", + "language": ["en", "es"], + "dataset_name": DATASET_NAME, + "output_dir": '"chat_completion_model"', + } + }, + ) + + with TemporaryDirectory() as tmpdirname: + model_card = trainer.generate_model_card(tmpdirname) + assert (Path(tmpdirname) / MODEL_CARD_NAME).exists() + pattern = model_card_pattern(Framework("openai"), TrainingTask.for_chat_completion) + assert model_card.content.find(pattern) > -1 + + +def formatting_func_sft(sample: Dict[str, Any]) -> Iterator[str]: + # For example, the sample must be most frequently rated as "1" in question-2 and + # label "b" from "question-3" must have not been set by any annotator + ratings = [ + annotation["value"] + for annotation in sample["question-2"] + if annotation["status"] == "submitted" and annotation["value"] is not None + ] + labels = [ + annotation["value"] + for annotation in sample["question-3"] + if annotation["status"] == "submitted" and annotation["value"] is not None + ] + if ratings and Counter(ratings).most_common(1)[0][0] == 1 and "b" not in labels: + return f"### Text\n{sample['text']}" + return None + + +def formatting_func_rm(sample: Dict[str, Any]): + # The FeedbackDataset isn't really set up for RM, so we'll just use an arbitrary example here + labels = [ + annotation["value"] + for annotation in sample["question-3"] + if 
annotation["status"] == "submitted" and annotation["value"] is not None + ] + if labels: + # Three cases for the tests: None, one tuple and yielding multiple tuples + if labels[0] == "a": + return None + elif labels[0] == "b": + return sample["text"], sample["text"][:5] + elif labels[0] == "c": + return [(sample["text"], sample["text"][5:10]), (sample["text"], sample["text"][:5])] + + +def formatting_func_ppo(sample: Dict[str, Any]): + return sample["text"] + + +def formatting_func_dpo(sample: Dict[str, Any]): + # The FeedbackDataset isn't really set up for DPO, so we'll just use an arbitrary example here + labels = [ + annotation["value"] + for annotation in sample["question-3"] + if annotation["status"] == "submitted" and annotation["value"] is not None + ] + if labels: + # Three cases for the tests: None, one tuple and yielding multiple tuples + if labels[0] == "a": + return None + elif labels[0] == "b": + return sample["text"][::-1], sample["text"], sample["text"][:5] + elif labels[0] == "c": + return [ + (sample["text"], sample["text"][::-1], sample["text"][:5]), + (sample["text"][::-1], sample["text"], sample["text"][:5]), + ] + + +@pytest.mark.parametrize( + "formatting_func, training_task", + ( + (formatting_func_sft, TrainingTask.for_supervised_fine_tuning), + (formatting_func_rm, TrainingTask.for_reward_modeling), + (formatting_func_ppo, TrainingTask.for_proximal_policy_optimization), + (formatting_func_dpo, TrainingTask.for_direct_preference_optimization), + ), +) +@pytest.mark.usefixtures( + "feedback_dataset_fields", + "feedback_dataset_questions", + "feedback_dataset_guidelines", + "feedback_dataset_records", + "model_card_pattern", +) +def test_model_card_trl( + formatting_func: Callable, + training_task: Callable, + feedback_dataset_guidelines: str, + feedback_dataset_fields: List["AllowedFieldTypes"], + feedback_dataset_questions: List["AllowedQuestionTypes"], + feedback_dataset_records: List[FeedbackRecord], + model_card_pattern: str, + mocked_is_on_huggingface, +) -> None: + dataset = FeedbackDataset( + guidelines=feedback_dataset_guidelines, + fields=feedback_dataset_fields, + questions=feedback_dataset_questions, + ) + dataset.add_records(records=feedback_dataset_records * 2) + task = training_task(formatting_func) + model_id = "sshleifer/tiny-gpt2" + + if training_task == TrainingTask.for_supervised_fine_tuning: + output_dir = '"sft_model"' + elif training_task == TrainingTask.for_reward_modeling: + output_dir = '"rm_model"' + elif training_task == TrainingTask.for_proximal_policy_optimization: + output_dir = '"ppo_model"' + else: + output_dir = '"dpo_model"' + + trainer = ArgillaTrainer( + dataset=dataset, + task=task, + framework="trl", + model=model_id, + framework_kwargs={ + "model_card_kwargs": { + "license": "mit", + "language": ["en", "es"], + "dataset_name": DATASET_NAME, + "output_dir": output_dir, + } + }, + ) + if training_task == TrainingTask.for_proximal_policy_optimization: + from transformers import pipeline + from trl import PPOConfig + + reward_model = pipeline("sentiment-analysis", model="lvwerra/distilbert-imdb") + trainer.update_config(config=PPOConfig(batch_size=1, ppo_epochs=2), reward_model=reward_model) + else: + trainer.update_config(max_steps=1) + + with TemporaryDirectory() as tmpdirname: + model_card = trainer.generate_model_card(tmpdirname) + assert (Path(tmpdirname) / MODEL_CARD_NAME).exists() + pattern = model_card_pattern(Framework("trl"), training_task) + assert model_card.content.find(pattern) > -1 diff --git 
a/tests/integration/test_datasets_settings.py b/tests/integration/test_datasets_settings.py index 654abe257e..6839d5efc4 100644 --- a/tests/integration/test_datasets_settings.py +++ b/tests/integration/test_datasets_settings.py @@ -12,16 +12,22 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional, Union +from uuid import uuid4 -import argilla as rg import pytest +from argilla import Workspace from argilla.client import api from argilla.client.api import delete, get_workspace, init from argilla.client.client import Argilla from argilla.client.sdk.commons.errors import ForbiddenApiError -from argilla.datasets import TextClassificationSettings, TokenClassificationSettings -from argilla.datasets.__init__ import configure_dataset +from argilla.datasets import ( + TextClassificationSettings, + TokenClassificationSettings, + configure_dataset, + configure_dataset_settings, + load_dataset_settings, +) from argilla.server.contexts import accounts from argilla.server.security.model import WorkspaceUserCreate @@ -30,6 +36,8 @@ from argilla.server.models import User from sqlalchemy.ext.asyncio import AsyncSession + from .helpers import SecuredClient + @pytest.mark.parametrize( ("settings_", "wrong_settings"), @@ -75,6 +83,59 @@ def test_settings_workflow( configure_dataset(dataset, wrong_settings, workspace=workspace) +@pytest.mark.parametrize( + "settings, workspace", + [ + (TextClassificationSettings(label_schema={"A", "B"}), None), + (TextClassificationSettings(label_schema={"D", "E"}), "admin"), + (TokenClassificationSettings(label_schema={"PER", "ORG"}), None), + (TokenClassificationSettings(label_schema={"CAT", "DOG"}), "admin"), + ], +) +def test_configure_dataset_settings_twice( + owner: "User", + argilla_user: "User", + settings: Union[TextClassificationSettings, TokenClassificationSettings], + workspace: Optional[str], +) -> None: + if not workspace: + workspace_name = argilla_user.username + else: + init(api_key=owner.api_key) + workspace = Workspace.create(name=workspace) + workspace.add_user(argilla_user.id) + workspace_name = workspace.name + + init(api_key=argilla_user.api_key, workspace=argilla_user.username) + dataset_name = f"test-dataset-{uuid4()}" + # This will create the dataset + configure_dataset_settings(dataset_name, settings=settings, workspace=workspace_name) + # This will update the dataset and what describes the issue https://github.com/argilla-io/argilla/issues/3505 + configure_dataset_settings(dataset_name, settings=settings, workspace=workspace_name) + + found_settings = load_dataset_settings(dataset_name, workspace_name) + assert {label for label in found_settings.label_schema} == {str(label) for label in settings.label_schema} + + +@pytest.mark.parametrize( + "settings", + [ + TextClassificationSettings(label_schema={"A", "B"}), + TokenClassificationSettings(label_schema={"PER", "ORG"}), + ], +) +def test_configure_dataset_deprecation_warning( + argilla_user: "User", settings: Union[TextClassificationSettings, TokenClassificationSettings] +) -> None: + init(api_key=argilla_user.api_key, workspace=argilla_user.username) + + dataset_name = f"test-dataset-{uuid4()}" + workspace_name = get_workspace() + + with pytest.warns(DeprecationWarning, match="This method is deprecated. 
Use configure_dataset_settings instead."): + configure_dataset(dataset_name, settings=settings, workspace=workspace_name) + + def test_list_dataset(mocked_client: "SecuredClient"): from argilla.client.api import active_client diff --git a/tests/unit/cli/datasets/test_delete.py b/tests/unit/cli/datasets/test_delete.py index 0b9b819f87..25b58c367e 100644 --- a/tests/unit/cli/datasets/test_delete.py +++ b/tests/unit/cli/datasets/test_delete.py @@ -33,7 +33,7 @@ def test_delete_dataset( remote_feedback_dataset: "RemoteFeedbackDataset", ) -> None: dataset_from_argilla_mock = mocker.patch( - "argilla.client.feedback.dataset.local.FeedbackDataset.from_argilla", + "argilla.client.feedback.dataset.local.dataset.FeedbackDataset.from_argilla", return_value=remote_feedback_dataset, ) remote_feedback_dataset_delete_mock = mocker.patch( @@ -55,7 +55,7 @@ def test_delete_dataset_runtime_error( remote_feedback_dataset: "RemoteFeedbackDataset", ) -> None: dataset_from_argilla_mock = mocker.patch( - "argilla.client.feedback.dataset.local.FeedbackDataset.from_argilla", + "argilla.client.feedback.dataset.local.dataset.FeedbackDataset.from_argilla", return_value=remote_feedback_dataset, ) remote_feedback_dataset_delete_mock = mocker.patch( diff --git a/tests/unit/cli/datasets/test_list.py b/tests/unit/cli/datasets/test_list.py index 9c0ba99d9d..c2b33680aa 100644 --- a/tests/unit/cli/datasets/test_list.py +++ b/tests/unit/cli/datasets/test_list.py @@ -38,7 +38,7 @@ def test_list_datasets( ) -> None: add_row_spy = mocker.spy(Table, "add_row") feedback_dataset_list_mock = mocker.patch( - "argilla.client.feedback.dataset.local.FeedbackDataset.list", return_value=[remote_feedback_dataset] + "argilla.client.feedback.dataset.local.dataset.FeedbackDataset.list", return_value=[remote_feedback_dataset] ) list_datasets_mock = mocker.patch("argilla.client.api.list_datasets", return_value=[dataset]) @@ -74,7 +74,7 @@ def test_list_datasets( def test_list_datasets_with_workspace(self, cli_runner: "CliRunner", cli: "Typer", mocker: "MockerFixture") -> None: workspace_from_name_mock = mocker.patch("argilla.client.workspaces.Workspace.from_name") - feedback_dataset_list_mock = mocker.patch("argilla.client.feedback.dataset.local.FeedbackDataset.list") + feedback_dataset_list_mock = mocker.patch("argilla.client.feedback.dataset.local.dataset.FeedbackDataset.list") list_datasets_mock = mocker.patch("argilla.client.api.list_datasets") result = cli_runner.invoke(cli, "datasets list --workspace unit-test") @@ -98,7 +98,7 @@ def test_list_datasets_with_non_existing_workspace( def test_list_datasets_using_type_feedback_filter( self, cli_runner: "CliRunner", cli: "Typer", mocker: "MockerFixture" ) -> None: - feedback_dataset_list_mock = mocker.patch("argilla.client.feedback.dataset.local.FeedbackDataset.list") + feedback_dataset_list_mock = mocker.patch("argilla.client.feedback.dataset.local.dataset.FeedbackDataset.list") list_datasets_mock = mocker.patch("argilla.client.api.list_datasets") result = cli_runner.invoke(cli, "datasets list --type feedback") @@ -110,7 +110,7 @@ def test_list_datasets_using_type_feedback_filter( def test_list_datasets_using_type_other_filter( self, cli_runner: "CliRunner", cli: "Typer", mocker: "MockerFixture" ) -> None: - feedback_dataset_list_mock = mocker.patch("argilla.client.feedback.dataset.local.FeedbackDataset.list") + feedback_dataset_list_mock = mocker.patch("argilla.client.feedback.dataset.local.dataset.FeedbackDataset.list") list_datasets_mock = mocker.patch("argilla.client.api.list_datasets") 
result = cli_runner.invoke(cli, "datasets list --type other") diff --git a/tests/unit/cli/datasets/test_push.py b/tests/unit/cli/datasets/test_push.py index 30179b5924..02fdfbe173 100644 --- a/tests/unit/cli/datasets/test_push.py +++ b/tests/unit/cli/datasets/test_push.py @@ -33,7 +33,7 @@ def test_push_to_huggingface( remote_feedback_dataset: "RemoteFeedbackDataset", ) -> None: dataset_from_argilla_mock = mocker.patch( - "argilla.client.feedback.dataset.local.FeedbackDataset.from_argilla", + "argilla.client.feedback.dataset.local.dataset.FeedbackDataset.from_argilla", return_value=remote_feedback_dataset, ) push_to_huggingface_mock = mocker.patch( @@ -58,7 +58,7 @@ def test_push_to_huggingface_missing_repo_id_arg( remote_feedback_dataset: "RemoteFeedbackDataset", ) -> None: mocker.patch( - "argilla.client.feedback.dataset.local.FeedbackDataset.from_argilla", + "argilla.client.feedback.dataset.local.dataset.FeedbackDataset.from_argilla", return_value=remote_feedback_dataset, ) diff --git a/tests/unit/client/feedback/dataset/test_base.py b/tests/unit/client/feedback/dataset/test_base.py index 8c28cd08e3..bbf34bcb86 100644 --- a/tests/unit/client/feedback/dataset/test_base.py +++ b/tests/unit/client/feedback/dataset/test_base.py @@ -14,6 +14,7 @@ from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Union +from typing import TYPE_CHECKING, Any, List import pytest from argilla.client.feedback.dataset.base import FeedbackDatasetBase @@ -37,6 +38,24 @@ class TestFeedbackDataset(FeedbackDatasetBase): + def add_records(self, *args, **kwargs) -> None: + pass + + def pull(self): + pass + + def delete(self): + pass + + def prepare_for_training(self, *args, **kwargs) -> Any: + pass + + def push_to_argilla(self, *args, **kwargs) -> "FeedbackDatasetBase": + pass + + def unify_responses(self, *args, **kwargs): + pass + def update_records(self, **kwargs: Dict[str, Any]) -> None: pass diff --git a/tests/unit/client/feedback/dataset/test_local.py b/tests/unit/client/feedback/dataset/test_local.py index cddb2b641b..de46cf1fd2 100644 --- a/tests/unit/client/feedback/dataset/test_local.py +++ b/tests/unit/client/feedback/dataset/test_local.py @@ -15,7 +15,7 @@ from typing import TYPE_CHECKING, List import pytest -from argilla.client.feedback.dataset.local import FeedbackDataset +from argilla.client.feedback.dataset.local.dataset import FeedbackDataset from argilla.client.feedback.schemas.fields import TextField from argilla.client.feedback.schemas.metadata import ( FloatMetadataProperty, @@ -252,14 +252,12 @@ def test_not_implemented_methods(): with pytest.warns( UserWarning, - match="`sort_by` method only works for `FeedbackDataset` pushed to Argilla." - " Use `sorted` with dataset.records instead.", + match="`sort_by` method is not supported for local datasets and won't take any effect. " ): assert dataset.sort_by("field") == dataset with pytest.warns( UserWarning, - match="`filter_by` method only works for `FeedbackDataset` pushed to Argilla." - " Use `filter` with dataset.records instead.", + match="`filter_by` method is not supported for local datasets and won't take any effect. " ): assert dataset.filter_by() == dataset diff --git a/tests/unit/client/feedback/integrations/huggingface/__init__.py b/tests/unit/client/feedback/integrations/huggingface/__init__.py new file mode 100644 index 0000000000..55be41799b --- /dev/null +++ b/tests/unit/client/feedback/integrations/huggingface/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2021-present, the Recognai S.L. team. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/unit/client/feedback/integrations/huggingface/test_model_card.py b/tests/unit/client/feedback/integrations/huggingface/test_model_card.py new file mode 100644 index 0000000000..54a598b279 --- /dev/null +++ b/tests/unit/client/feedback/integrations/huggingface/test_model_card.py @@ -0,0 +1,61 @@ +# Copyright 2021-present, the Recognai S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Any, Dict + +import pytest +from argilla.client.feedback.integrations.huggingface.model_card.model_card import ( + _prepare_dict_for_comparison, + _updated_arguments, +) +from argilla.training.utils import get_default_args +from transformers import TrainingArguments + +default_transformer_args = get_default_args(TrainingArguments.__init__) +default_transformer_args_1 = default_transformer_args.copy() +default_transformer_args_1.update({"output_dir": None, "warmup_steps": 100}) +default_transformer_args_2 = default_transformer_args.copy() +default_transformer_args_2.update({"output_dir": {"nested_name": "test"}}) +default_transformer_args_3 = default_transformer_args.copy() +default_transformer_args_3.update({"output_dir": [1.2, 3, "value"]}) + + +@dataclass +class Dummy: + # Test a random class, it could be a loss function passed as a callable, or an instance + # of one for example. 
+ pass + + +default_transformer_args_4 = default_transformer_args.copy() +default_transformer_args_4.update({"output_dir": Dummy, "other": Dummy()}) + + +@pytest.mark.parametrize( + "current_kwargs, new_kwargs", + ( + (default_transformer_args_1, {"warmup_steps": 100}), + (default_transformer_args_2, {"output_dir": {"nested_name": "test"}}), + (default_transformer_args_3, {"output_dir": [1.2, 3, "value"]}), + (default_transformer_args_4, {"output_dir": Dummy, "other": Dummy()}), + ), +) +def test_updated_kwargs(current_kwargs: Dict[str, Any], new_kwargs: Dict[str, Any]): + # Using only the Transformer's TrainingArguments as an example, no need to check if the arguments are correct + + new_arguments = _updated_arguments(default_transformer_args, current_kwargs) + assert set(_prepare_dict_for_comparison(new_arguments).items()) == set( + _prepare_dict_for_comparison(new_kwargs).items() + ) diff --git a/tests/unit/server/api/v1/test_records.py b/tests/unit/server/api/v1/test_records.py index 3c43388b9a..78d8408588 100644 --- a/tests/unit/server/api/v1/test_records.py +++ b/tests/unit/server/api/v1/test_records.py @@ -199,7 +199,7 @@ async def test_update_record_with_no_metadata( FloatMetadataPropertyFactory, 13.3, "wrong-float", - "'name' metadata property validation failed because 'wrong-float' is not an float.", + "'name' metadata property validation failed because 'wrong-float' is not a float.", ), ], ) From 1b840de633a5e52b02a88554e88732dbf2dca118 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 17 Oct 2023 12:33:19 +0000 Subject: [PATCH 02/12] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/argilla/client/feedback/dataset/local/dataset.py | 4 ++-- src/argilla/client/feedback/dataset/local/mixins.py | 5 +---- src/argilla/client/feedback/dataset/remote/dataset.py | 6 ++---- tests/unit/client/feedback/dataset/test_base.py | 1 - tests/unit/client/feedback/dataset/test_local.py | 6 ++---- 5 files changed, 7 insertions(+), 15 deletions(-) diff --git a/src/argilla/client/feedback/dataset/local/dataset.py b/src/argilla/client/feedback/dataset/local/dataset.py index 01e2b9c0e8..8e1438ed5a 100644 --- a/src/argilla/client/feedback/dataset/local/dataset.py +++ b/src/argilla/client/feedback/dataset/local/dataset.py @@ -14,7 +14,7 @@ import logging import textwrap import warnings -from typing import Any, Dict, Iterator, List, Optional, TYPE_CHECKING, Union +from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, Union from argilla.client.feedback.constants import FETCHING_BATCH_SIZE from argilla.client.feedback.dataset.base import FeedbackDatasetBase @@ -34,8 +34,8 @@ TrainingTaskForPPO, TrainingTaskForQuestionAnswering, TrainingTaskForRM, - TrainingTaskForSFT, TrainingTaskForSentenceSimilarity, + TrainingTaskForSFT, TrainingTaskForTextClassification, TrainingTaskTypes, ) diff --git a/src/argilla/client/feedback/dataset/local/mixins.py b/src/argilla/client/feedback/dataset/local/mixins.py index 154857f1d3..bb4d8b5ece 100644 --- a/src/argilla/client/feedback/dataset/local/mixins.py +++ b/src/argilla/client/feedback/dataset/local/mixins.py @@ -26,7 +26,6 @@ RatingQuestion, TextQuestion, ) -from argilla.client.feedback.schemas.enums import FieldTypes, QuestionTypes from argilla.client.feedback.schemas.remote.fields import RemoteTextField from argilla.client.feedback.schemas.remote.metadata import ( RemoteFloatMetadataProperty, @@ -48,10 +47,8 @@ if TYPE_CHECKING: 
import httpx from argilla.client.client import Argilla as ArgillaClient - from argilla.client.feedback.dataset.local.dataset import FeedbackDataset - from argilla.client.feedback.schemas.types import AllowedRemoteFieldTypes, AllowedRemoteQuestionTypes - from argilla.client.sdk.v1.datasets.models import FeedbackDatasetModel from argilla.client.feedback.dataset.local import FeedbackDataset + from argilla.client.feedback.dataset.local.dataset import FeedbackDataset from argilla.client.feedback.schemas.records import FeedbackRecord from argilla.client.feedback.schemas.types import ( AllowedFieldTypes, diff --git a/src/argilla/client/feedback/dataset/remote/dataset.py b/src/argilla/client/feedback/dataset/remote/dataset.py index c3534f5d8e..96f624a5e7 100644 --- a/src/argilla/client/feedback/dataset/remote/dataset.py +++ b/src/argilla/client/feedback/dataset/remote/dataset.py @@ -51,6 +51,8 @@ import httpx from argilla.client.feedback.dataset import FeedbackDataset + from argilla.client.feedback.dataset.local import FeedbackDataset + from argilla.client.feedback.schemas.enums import ResponseStatusFilter from argilla.client.feedback.schemas.metadata import MetadataFilters from argilla.client.feedback.schemas.types import ( AllowedMetadataPropertyTypes, @@ -59,10 +61,6 @@ AllowedRemoteQuestionTypes, ) from argilla.client.sdk.v1.datasets.models import FeedbackRecordsModel, FeedbackResponseStatusFilter - from argilla.client.feedback.dataset.local import FeedbackDataset - from argilla.client.feedback.schemas.enums import ResponseStatusFilter - from argilla.client.feedback.schemas.types import AllowedRemoteFieldTypes, AllowedRemoteQuestionTypes - from argilla.client.sdk.v1.datasets.models import FeedbackRecordsModel from argilla.client.workspaces import Workspace diff --git a/tests/unit/client/feedback/dataset/test_base.py b/tests/unit/client/feedback/dataset/test_base.py index bbf34bcb86..1eb3b60376 100644 --- a/tests/unit/client/feedback/dataset/test_base.py +++ b/tests/unit/client/feedback/dataset/test_base.py @@ -14,7 +14,6 @@ from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Union -from typing import TYPE_CHECKING, Any, List import pytest from argilla.client.feedback.dataset.base import FeedbackDatasetBase diff --git a/tests/unit/client/feedback/dataset/test_local.py b/tests/unit/client/feedback/dataset/test_local.py index de46cf1fd2..4902265ff0 100644 --- a/tests/unit/client/feedback/dataset/test_local.py +++ b/tests/unit/client/feedback/dataset/test_local.py @@ -251,13 +251,11 @@ def test_not_implemented_methods(): ) with pytest.warns( - UserWarning, - match="`sort_by` method is not supported for local datasets and won't take any effect. " + UserWarning, match="`sort_by` method is not supported for local datasets and won't take any effect. " ): assert dataset.sort_by("field") == dataset with pytest.warns( - UserWarning, - match="`filter_by` method is not supported for local datasets and won't take any effect. " + UserWarning, match="`filter_by` method is not supported for local datasets and won't take any effect. 
" ): assert dataset.filter_by() == dataset From 04ad5837bdbb560a2eaa86b67527344b90377532 Mon Sep 17 00:00:00 2001 From: davidberenstein1957 Date: Tue, 17 Oct 2023 15:43:12 +0200 Subject: [PATCH 03/12] bug: resolved [BUG]` __repr__` problem for TrainingTask #3971 --- src/argilla/client/feedback/training/schemas.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/argilla/client/feedback/training/schemas.py b/src/argilla/client/feedback/training/schemas.py index 0aebcaf06f..c6170949ff 100644 --- a/src/argilla/client/feedback/training/schemas.py +++ b/src/argilla/client/feedback/training/schemas.py @@ -1266,9 +1266,9 @@ def __repr__(self) -> str: else: return ( f"{self.__class__.__name__}" - f"\n\t question={self.text.name}" + f"\n\t question={self.question.name}" f"\n\t context={self.context.name}" - f"\n\t answer={self.__multi_label__}" + f"\n\t answer={self.answer.name}" ) @requires_dependencies("transformers") From 14c5dc8be22742e7b5f7d3608b20f1eac3cb4bd3 Mon Sep 17 00:00:00 2001 From: Paco Aranda Date: Tue, 17 Oct 2023 15:46:29 +0200 Subject: [PATCH 04/12] ci: Check failing test --- .../client/feedback/dataset/remote/test_dataset.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/integration/client/feedback/dataset/remote/test_dataset.py b/tests/integration/client/feedback/dataset/remote/test_dataset.py index ef085bf057..4b61911ccd 100644 --- a/tests/integration/client/feedback/dataset/remote/test_dataset.py +++ b/tests/integration/client/feedback/dataset/remote/test_dataset.py @@ -122,7 +122,7 @@ async def test_add_records(self, owner: "User", test_dataset: FeedbackDataset, r assert len(remote_dataset.records) == 1 - async def test_update_records(self, owner: "User", test_dataset: FeedbackDataset): + def test_update_records(self, owner: "User", test_dataset: FeedbackDataset): rg.init(api_key=owner.api_key) ws = rg.Workspace.create(name="test-workspace") @@ -140,12 +140,13 @@ async def test_update_records(self, owner: "User", test_dataset: FeedbackDataset first_record.metadata.update({"terms-metadata": "a"}) remote.update_records(first_record) + print(first_record) assert first_record == remote[0] first_record = remote[0] - assert first_record.external_id == "new-external-id" assert first_record.metadata["terms-metadata"] == "a" + assert first_record.external_id == "new-external-id" async def test_update_records_with_suggestions(self, owner: "User", test_dataset: FeedbackDataset): rg.init(api_key=owner.api_key) From 0461efcac59af373d1d4a510167709a2e469d1d0 Mon Sep 17 00:00:00 2001 From: davidberenstein1957 Date: Tue, 17 Oct 2023 15:47:37 +0200 Subject: [PATCH 05/12] chore: updated change log --- CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index d62a7ef7d2..cc1ec54efd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -58,6 +58,8 @@ These are the section headers that we use: - Fixed allow pull datasets without records ([#3851](https://github.com/argilla-io/argilla/pull/3851)) - Updated active learning for text classification notebooks to pass ids of type int to `TextClassificationRecord` ([#3831](https://github.com/argilla-io/argilla/pull/3831)). - Fixed record fields validation that was preventing from logging records with optional fields (i.e. `required=True`) when the field value was `None` ([#3846](https://github.com/argilla-io/argilla/pull/3846)). +- Fixed wrong `__repr__` problem for `TrainingTask` ([#3969](https://github.com/argilla-io/argilla/pull/3969)). 
+ ## [1.16.0](https://github.com/argilla-io/argilla/compare/v1.15.1...v1.16.0) From faf9aa8c92cafa6e064a529346846c04438e0b50 Mon Sep 17 00:00:00 2001 From: David Berenstein Date: Tue, 17 Oct 2023 15:48:41 +0200 Subject: [PATCH 06/12] Apply suggestions from code review @alvarobart Co-authored-by: Alvaro Bartolome --- CHANGELOG.md | 4 ++-- src/argilla/client/feedback/dataset/local/dataset.py | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index cc1ec54efd..e40fe60002 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -43,8 +43,8 @@ These are the section headers that we use: - Updated `FilteredRemoteFeedbackRecords.__len__` method to return the number of records matching the provided filters ([#3916](https://github.com/argilla-io/argilla/pull/3916)). - Increase the default max result window for Elasticsearch created for Feedback datasets ([#3929](https://github.com/argilla-io/argilla/pull/)). - Force elastic index refresh after records creation ([#3929](https://github.com/argilla-io/argilla/pull/)). -- FeedbackDataset API methods have been aligned to be accessible through the several implementations ([#3937](https://github.com/argilla-io/argilla/pull/3937)). -- The `unify_responses` support for remote datasets ([#3937](https://github.com/argilla-io/argilla/pull/3937)). +- `FeedbackDataset` methods have been aligned to be accessible through the several implementations ([#3937](https://github.com/argilla-io/argilla/pull/3937)). +- Support on `RemoteFeedbackDataset` for `unify_responses` method ([#3937](https://github.com/argilla-io/argilla/pull/3937)). ### Fixed diff --git a/src/argilla/client/feedback/dataset/local/dataset.py b/src/argilla/client/feedback/dataset/local/dataset.py index 8e1438ed5a..255ec5b847 100644 --- a/src/argilla/client/feedback/dataset/local/dataset.py +++ b/src/argilla/client/feedback/dataset/local/dataset.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+ import logging import textwrap import warnings From a95c002245445e87c85d09b35050c2daa1373203 Mon Sep 17 00:00:00 2001 From: davidberenstein1957 Date: Tue, 17 Oct 2023 15:58:07 +0200 Subject: [PATCH 07/12] fix: resolved empty value issues within default [BUG] ArgillaTrainer returns key not found error with custom dataset #3970 --- .../client/feedback/training/schemas.py | 31 +++++++++++++++++-- 1 file changed, 28 insertions(+), 3 deletions(-) diff --git a/src/argilla/client/feedback/training/schemas.py b/src/argilla/client/feedback/training/schemas.py index c6170949ff..24ff6061ea 100644 --- a/src/argilla/client/feedback/training/schemas.py +++ b/src/argilla/client/feedback/training/schemas.py @@ -743,6 +743,9 @@ def _prepare_for_training_with_transformers( datasets_dict = {"id": [], "text": [], "label": []} for index, entry in enumerate(data): + if any([entry.get("label") is None, entry.get("text") is None]): + warnings.warn(f"Skipping entry {entry} because it has no label or text.") + continue datasets_dict["id"].append(index) datasets_dict["text"].append(entry["text"]) datasets_dict["label"].append(entry["label"]) @@ -791,6 +794,9 @@ def _prepare(data): db = DocBin(store_user_data=True) # Creating the DocBin object as in https://spacy.io/usage/training#training-data for entry in data: + if any([entry.get("label") is None, entry.get("text") is None]): + warnings.warn(f"Skipping entry {entry} because it has no label or text.") + continue doc = lang.make_doc(entry["text"]) cats = dict.fromkeys(all_labels, 0) @@ -840,6 +846,9 @@ def _prepare_for_training_with_openai( def _prepare(data): jsonl = [] for entry in data: + if any([entry.get("label") is None, entry.get("text") is None]): + warnings.warn(f"Skipping entry {entry} because it has no label or text.") + continue prompt = entry["text"] prompt += separator # needed for better performance @@ -934,6 +943,9 @@ def _prepare_for_training_with_trl( datasets_dict = {"id": [], "text": []} for index, sample in enumerate(data): + if any([sample.get("text") is None]): + warnings.warn(f"Skipping entry {sample} because it has no text.") + continue datasets_dict["id"].append(index) datasets_dict["text"].append(sample["text"]) @@ -1019,6 +1031,9 @@ def _prepare_for_training_with_trl( datasets_dict = {"chosen": [], "rejected": []} for sample in data: + if any([sample.get("chosen") is None, sample.get("rejected") is None]): + warnings.warn(f"Skipping entry {sample} because it has no chosen or rejected.") + continue datasets_dict["chosen"].append(sample["chosen"]) datasets_dict["rejected"].append(sample["rejected"]) @@ -1088,6 +1103,9 @@ def _prepare_for_training_with_trl( datasets_dict = {"id": [], "query": []} for index, entry in enumerate(data): + if entry.get("query") is None: + warnings.warn(f"Skipping entry {entry} because it has no query.") + continue datasets_dict["id"].append(index) datasets_dict["query"].append(entry["query"]) @@ -1169,6 +1187,9 @@ def _prepare_for_training_with_trl( datasets_dict = {"prompt": [], "chosen": [], "rejected": []} for sample in data: + if any([sample.get("prompt") is None, sample.get("chosen") is None, sample.get("rejected") is None]): + warnings.warn(f"Skipping entry {sample} because it has no prompt, chosen or rejected.") + continue datasets_dict["prompt"].append(sample["prompt"]) datasets_dict["chosen"].append(sample["chosen"]) datasets_dict["rejected"].append(sample["rejected"]) @@ -1283,10 +1304,11 @@ def _prepare_for_training_with_transformers( "answer": [], } for entry in data: - if any([entry["question"] is 
None, entry["context"] is None, entry["answer"] is None]): + if any([entry.get("question") is None, entry.get("context") is None, entry.get("answer") is None]): + warnings.warn(f"Skipping entry {entry} because it has no question, context or answer.") continue - if entry["answer"] not in entry["context"]: - warnings.warn("This is extractive QnA but the answer is not in the context.") + if entry.get("answer") not in entry.get("context"): + warnings.warn(f"Skipping entry {entry} because answer is not in context.") continue # get index of answer in context answer_start = entry["context"].index(entry["answer"]) @@ -1392,6 +1414,9 @@ def _dict_to_format(ds: datasets.Dataset) -> List[Dict[str, List[Dict[str, str]] datasets_dict = {"chat": [], "turn": [], "role": [], "content": []} for entry in data: + if any([entry.get("prompt") is None, entry.get("response") is None]): + warnings.warn(f"Skipping entry {entry} because it has no prompt or response.") + continue if entry["role"] not in ["system", "user", "assistant"]: raise ValueError("Role must be one of 'system', 'user', 'assistant'") datasets_dict["chat"].append(entry["chat"]) From dd7fc9b9387da85e35940aefb8ee4db25a029dc1 Mon Sep 17 00:00:00 2001 From: davidberenstein1957 Date: Tue, 17 Oct 2023 15:58:28 +0200 Subject: [PATCH 08/12] chore: updated changelog --- CHANGELOG.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index e40fe60002..902e8247f9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -58,7 +58,8 @@ These are the section headers that we use: - Fixed allow pull datasets without records ([#3851](https://github.com/argilla-io/argilla/pull/3851)) - Updated active learning for text classification notebooks to pass ids of type int to `TextClassificationRecord` ([#3831](https://github.com/argilla-io/argilla/pull/3831)). - Fixed record fields validation that was preventing from logging records with optional fields (i.e. `required=True`) when the field value was `None` ([#3846](https://github.com/argilla-io/argilla/pull/3846)). -- Fixed wrong `__repr__` problem for `TrainingTask` ([#3969](https://github.com/argilla-io/argilla/pull/3969)). +- Fixed wrong `__repr__` problem for `TrainingTask` ([#3969](https://github.com/argilla-io/argilla/pull/3969)). +- Fixed wrong key return error `prepare_for_training_with_*` for `TrainingTask` ([#3969](https://github.com/argilla-io/argilla/pull/3969)). 
## [1.16.0](https://github.com/argilla-io/argilla/compare/v1.15.1...v1.16.0) From 03295bfd41fc1dac55da468219ff3924e880f12d Mon Sep 17 00:00:00 2001 From: davidberenstein1957 Date: Tue, 17 Oct 2023 16:04:21 +0200 Subject: [PATCH 09/12] docs: added task template info --- docs/_source/_common/tabs/task_templates.md | 241 ++++++++++++++++++ .../practical_guides/choose_dataset.md | 11 +- .../practical_guides/create_dataset.md | 21 +- 3 files changed, 266 insertions(+), 7 deletions(-) create mode 100644 docs/_source/_common/tabs/task_templates.md diff --git a/docs/_source/_common/tabs/task_templates.md b/docs/_source/_common/tabs/task_templates.md new file mode 100644 index 0000000000..9aae40026e --- /dev/null +++ b/docs/_source/_common/tabs/task_templates.md @@ -0,0 +1,241 @@ +::::{tab-set} + +:::{tab-item} Text classification +```python +import argilla as rg + +ds = rg.FeedbackDataset.for_text_classification( + labels=["positive", "negative"], + multi_label=False, + use_markdown=True, + guidelines=None, +) +ds +# FeedbackDataset( +# fields=[TextField(name="text", use_markdown=True)], +# questions=[LabelQuestion(name="label", labels=["positive", "negative"])] +# guidelines="", +# ) +``` +::: + +:::{tab-item} Summarization +```python +import argilla as rg + +ds = rg.FeedbackDataset.for_summarization( + use_markdown=True, + guidelines=None, +) +ds +# FeedbackDataset( +# fields=[TextField(name="text", use_markdown=True)], +# questions=[TextQuestion(name="summary", use_markdown=True)] +# guidelines="", +# ) +``` +::: + +:::{tab-item} Translation +```python +import argilla as rg + +ds = rg.FeedbackDataset.for_translation( + use_markdown=True, + guidelines=None, +) +ds +# FeedbackDataset( +# fields=[TextField(name="source", use_markdown=True)], +# questions=[TextQuestion(name="target", use_markdown=True)] +# guidelines="", +# ) +``` +::: + +:::{tab-item} Natural Language Inference (NLI) +```python +import argilla as rg + +ds = rg.FeedbackDataset.for_natural_language_inference( + labels=None + use_markdown=True, + guidelines=None, +) +ds +# FeedbackDataset( +# fields=[ +# TextField(name="premise", use_markdown=True), +# TextField(name="hypothesis", use_markdown=True) +# ], +# questions=[ +# LabelQuestion( +# name="label", labels=["entailment", "neutral", "contradiction"] +# ) +# ] +# guidelines="", +# ) +``` +::: + +:::{tab-item} Sentence Similarity +```python +import argilla as rg + +ds = rg.FeedbackDataset.for_sentence_similarity( + rating_scale=10, + use_markdown=True, + guidelines=None, +) +ds +# FeedbackDataset( +# fields=[ +# TextField(name="sentence-1", use_markdown=True), +# TextField(name="sentence-2", use_markdown=True) +# ], +# questions=[ +# RatingQuestion(name="similarity", values=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) +# ] +# guidelines="", +# ) +``` +::: + +:::{tab-item} Extractive Question Answering +```python +import argilla as rg + +ds = rg.FeedbackDataset.for_question_answering( + use_markdown=True, + guidelines=None, +) +ds +# FeedbackDataset( +# fields=[ +# TextField(name="question", use_markdown=True), +# TextField(name="context", use_markdown=True) +# ], +# questions=[ +# TextQuestion(name="answer", use_markdown=True) +# ] +# guidelines="", +# ) +``` +::: + +:::{tab-item} Supervised Fine-tuning (SFT) +```python +import argilla as rg + +ds = rg.FeedbackDataset.for_supervised_fine_tuning( + context=True, + use_markdown=True, + guidelines=None, +) +ds +# FeedbackDataset( +# fields=[ +# TextField(name="prompt", use_markdown=True), +# TextField(name="context", use_markdown=True) +# ], +# 
questions=[
+#         TextQuestion(name="response", use_markdown=True)
+#     ]
+#     guidelines="",
+# )
+```
+:::
+
+:::{tab-item} Preference Modeling
+```python
+import argilla as rg
+
+ds = rg.FeedbackDataset.for_preference_modeling(
+    use_markdown=True,
+    guidelines=None,
+)
+ds
+# FeedbackDataset(
+#     fields=[
+#         TextField(name="prompt", use_markdown=True),
+#         TextField(name="context", use_markdown=True),
+#         TextField(name="response-1", use_markdown=True),
+#         TextField(name="response-2", use_markdown=True),
+#     ],
+#     questions=[
+#         LabelQuestion(name="preference", values=["response-1", "response-2"])
+#     ]
+#     guidelines="",
+# )
+```
+:::
+
+:::{tab-item} Proximal Policy Optimization (PPO)
+```python
+import argilla as rg
+
+ds = rg.FeedbackDataset.for_proximal_policy_optimization(
+    context=True,
+    use_markdown=True,
+    guidelines=None,
+)
+ds
+# FeedbackDataset(
+#     fields=[
+#         TextField(name="prompt", use_markdown=True),
+#         TextField(name="context", use_markdown=True)
+#     ],
+#     questions=[
+#         TextQuestion(name="response", use_markdown=True)
+#     ]
+#     guidelines="",
+# )
+```
+:::
+
+:::{tab-item} Direct Preference Optimization (DPO)
+```python
+import argilla as rg
+
+ds = rg.FeedbackDataset.for_direct_preference_optimization(
+    context=True,
+    use_markdown=True,
+    guidelines=None,
+)
+ds
+# FeedbackDataset(
+#     fields=[
+#         TextField(name="prompt", use_markdown=True),
+#         TextField(name="context", use_markdown=True),
+#         TextField(name="response-1", use_markdown=True),
+#         TextField(name="response-2", use_markdown=True),
+#     ],
+#     questions=[
+#         LabelQuestion(name="preference", values=["response-1", "response-2"])
+#     ]
+#     guidelines="",
+# )
+```
+:::
+
+:::{tab-item} Retrieval-Augmented Generation (RAG)
+```python
+import argilla as rg
+
+ds = rg.FeedbackDataset.for_retrieval_augmented_generation(
+    number_of_retrievals=1,
+    rating_scale=10,
+    use_markdown=False,
+    guidelines=None,
+)
+ds
+# FeedbackDataset(
+#     fields=[
+#         TextField(name="prompt", use_markdown=True),
+#         TextField(name="retrieved_document_1", use_markdown=True),
+#         TextField(name="response", use_markdown=True),
+#     ],
+#     questions=[
+#         RatingQuestion(name="retrieval_1_rating", values=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
+#     ]
+#     guidelines="",
+# )
+```
+:::
+
+::::
diff --git a/docs/_source/practical_guides/choose_dataset.md b/docs/_source/practical_guides/choose_dataset.md
index 905f210f9c..aa834ea094 100644
--- a/docs/_source/practical_guides/choose_dataset.md
+++ b/docs/_source/practical_guides/choose_dataset.md
@@ -34,11 +34,16 @@ We are working on it! We will be adding new features to the `FeedbackDataset` ov
 |------------------------------- |----------------- |-------------------|
 | Text classification | ✔️ | ✔️ |
 | Token classification | | ✔️ |
-| Text2text | ✔️ | ✔️ |
-| RLHF | ✔️ | |
-| RAG | ✔️ | |
+| Summarization | ✔️ | ✔️ |
+| Translation | ✔️ | ✔️ |
+| NLI | ✔️ | ✔️ |
 | Sentence Similarity | ✔️ | |
 | Question Answering | ✔️ | |
+| RLHF (SFT) | ✔️ | |
+| RLHF (RM) | ✔️ | |
+| RLHF (PPO) | ✔️ | |
+| RLHF (DPO) | ✔️ | |
+| RAG | ✔️ | |
 | Image support | ✔️ | |
 | And many more | ✔️ | |
 
diff --git a/docs/_source/practical_guides/create_dataset.md b/docs/_source/practical_guides/create_dataset.md
index 36b62bb527..536cefc9c2 100644
--- a/docs/_source/practical_guides/create_dataset.md
+++ b/docs/_source/practical_guides/create_dataset.md
@@ -15,7 +15,20 @@ To follow the steps in this guide, you will first need to connect to Argilla. 
Ch ### Configure the dataset -#### Define `fields` +A record in Argilla refers to a data item that requires annotation and can consist of one or multiple `fields` i.e., the pieces of information that will be shown to the user in the UI in order to complete the annotation task. This can be, for example, a prompt and output pair in the case of instruction datasets. Additionally, the record will contain `questions` that the annotators will need to answer and guidelines to help them complete the task. + +The `FeedbackDataset` has a set of predefined task templates that you can use to quickly set up your dataset. These templates include the `fields` and `questions` needed for the task, as well as the `guidelines` to provide to the annotators. Additionally, you can customize the `fields`, `questions`, and `guidelines` to fit your specific needs using a [custom configuration](#custom-configuration). + +#### Task Templates + +```{include} /_common/tabs/task_templates.md +``` + +After having initialized the `FeedbackDataset` templates, we can still alter the `fields`, `questions`, and `guidelines` to fit our specific needs using a [custom configuration](#custom-configuration). + +#### Custom Configuration + +##### Define `fields` A record in Argilla refers to a data item that requires annotation and can consist of one or multiple fields i.e., the pieces of information that will be shown to the user in the UI in order to complete the annotation task. This can be, for example, a prompt and output pair in the case of instruction datasets. @@ -39,7 +52,7 @@ fields = [ The order of the fields in the UI follows the order in which these are added to the `fields` attribute in the Python SDK. ``` -#### Define `questions` +##### Define `questions` To collect feedback for your dataset, you need to formulate questions. The Feedback Task currently supports the following types of questions: @@ -68,7 +81,7 @@ Check out the following tabs to learn how to set up questions according to their ```{include} /_common/tabs/question_settings.md ``` -#### Define `guidelines` +##### Define `guidelines` Once you have decided on the data to show and the questions to ask, it's important to provide clear guidelines to the annotators. These guidelines help them understand the task and answer the questions consistently. You can provide guidelines in two ways: @@ -77,7 +90,7 @@ Once you have decided on the data to show and the questions to ask, it's importa It is good practice to use at least the dataset guidelines if not both methods. Question descriptions should be short and provide context to a specific question. They can be a summary of the guidelines to that question, but often times that is not sufficient to align the whole annotation team. In the guidelines, you can include a description of the project, details on how to answer each question with examples, instructions on when to discard a record, etc. -#### Create the dataset +##### Create the dataset Once the scope of the project is defined, which implies knowing the `fields`, `questions` and `guidelines` (if applicable), you can proceed to create the `FeedbackDataset`. 
To do so, you will need to define the following arguments:

From 880fa974d9893d5563287610f70c2728a5303d59 Mon Sep 17 00:00:00 2001
From: davidberenstein1957
Date: Tue, 17 Oct 2023 16:04:40 +0200
Subject: [PATCH 10/12] docs: ensured readability of integrations

---
 .../tutorials_and_integrations/integrations/integrations.md | 6 +++---
 .../integrations/monitor_endpoints with_fastapi.ipynb | 2 +-
 .../integrations/process_documents_with_unstructured.ipynb | 2 +-
 .../integrations/use_argilla_callback_in_langchain.md | 2 +-
 4 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/docs/_source/tutorials_and_integrations/integrations/integrations.md b/docs/_source/tutorials_and_integrations/integrations/integrations.md
index bb478be6f1..ca20df6f8b 100644
--- a/docs/_source/tutorials_and_integrations/integrations/integrations.md
+++ b/docs/_source/tutorials_and_integrations/integrations/integrations.md
@@ -4,18 +4,18 @@ Here you can find how to integrate Argilla with other libraries and frameworks.
 ````{grid} 1 1 3 3
 :class-container: tuto-section-2
 
-```{grid-item-card} Monitoring LLMs in LangChain apps, chains, and agents and tools
+```{grid-item-card} LangChain: Monitoring LLMs in apps, chains, and agents and tools
 :link: use_argilla_callback_in_langchain.html
 
 Learn how to use Argilla to monitor LLMs in LangChain apps, chains, and agents and tools.
 ```
 
-```{grid-item-card} Large scale document processing for LLMs with Unstructured.io
+```{grid-item-card} Unstructured.io: Large scale document processing for LLMs
 :link: process_documents_with_unstructured.html
 
 Learn how to use Argilla to process large scale documents for LLMs with Unstructured.io.
 ```
 
-```{grid-item-card} Monitor NLP models with FastAPI and ArgillaLogHTTPMiddleware
+```{grid-item-card} FastAPI: Monitor NLP models with ArgillaLogHTTPMiddleware
 :link: monitor_endpoints with_fastapi.html
 
 Learn how to use Argilla to monitor NLP models with FastAPI and ArgillaLogHTTPMiddleware.
diff --git a/docs/_source/tutorials_and_integrations/integrations/monitor_endpoints with_fastapi.ipynb b/docs/_source/tutorials_and_integrations/integrations/monitor_endpoints with_fastapi.ipynb
index fe8bd889c2..ed78e6a3a7 100644
--- a/docs/_source/tutorials_and_integrations/integrations/monitor_endpoints with_fastapi.ipynb
+++ b/docs/_source/tutorials_and_integrations/integrations/monitor_endpoints with_fastapi.ipynb
@@ -7,7 +7,7 @@
     "id": "BbsELFGS7tQS"
    },
    "source": [
-    "# Monitor NLP models with FastAPI and ArgillaLogHTTPMiddleware\n",
+    "# `FastAPI`: Monitor NLP models with ArgillaLogHTTPMiddleware\n",
     "\n",
     "In this tutorial, you'll learn to monitor the predictions of a FastAPI inference endpoint\n",
     "and log model predictions in an Argilla dataset. 
It will walk you through 4 basic MLOps Steps:\n", diff --git a/docs/_source/tutorials_and_integrations/integrations/process_documents_with_unstructured.ipynb b/docs/_source/tutorials_and_integrations/integrations/process_documents_with_unstructured.ipynb index 8f6d08afb9..7ea9bd3dc7 100644 --- a/docs/_source/tutorials_and_integrations/integrations/process_documents_with_unstructured.ipynb +++ b/docs/_source/tutorials_and_integrations/integrations/process_documents_with_unstructured.ipynb @@ -6,7 +6,7 @@ "id": "6bb44a8e", "metadata": {}, "source": [ - "# Large scale document processing for LLMs with Unstructured.io\n", + "# `Unstructured.io`: Large scale document processing for LLMs\n", "\n", "In this notebook, we'll show you how you can use the amazing library [unstructured](https://github.com/Unstructured-IO/unstructured) together with [argilla](https://github.com/argilla-io/argilla), and HuggingFace [transformers](https://github.com/huggingface/transformers) to train a custom summarization model. In this case, we're going to build a summarization model targeted at summarizing the [Institute for the Study of War's](https://www.understandingwar.org/) daily reports on the war in Ukraine. You can see an example of one of the reports [here](https://www.understandingwar.org/backgrounder/russian-offensive-campaign-assessment-december-12), and a screen shot appears below.\n", "\n", diff --git a/docs/_source/tutorials_and_integrations/integrations/use_argilla_callback_in_langchain.md b/docs/_source/tutorials_and_integrations/integrations/use_argilla_callback_in_langchain.md index 5baa8e432a..c4c32bfd6d 100644 --- a/docs/_source/tutorials_and_integrations/integrations/use_argilla_callback_in_langchain.md +++ b/docs/_source/tutorials_and_integrations/integrations/use_argilla_callback_in_langchain.md @@ -1,4 +1,4 @@ -# Monitoring LLMs in LangChain apps, chains, and agents and tools +# `LangChain`: Monitoring LLMs in apps, chains, and agents and tools This guide explains how to use the `ArgillaCallbackHandler` to integrate Argilla with LangChain apps. With this integration, Argilla can be used evaluate and fine-tune LLMs. It works by collecting the interactions with LLMs and pushing them into a `FeedbackDataset` for continuous monitoring and human feedback. You just need to create a Langchain-compatible `FeedbackDataset` in Argilla and then instantiate the `ArgillaCallbackHandler` to be provided to `LangChain` LLMs, Chains, and/or Agents. 
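For reference, the callback flow described above can be sketched in a few lines of Python. This is a minimal, illustrative example rather than part of the patch: the dataset name, workspace, API URL, and API key are placeholders, and it assumes a running Argilla instance plus a LangChain release that ships the `ArgillaCallbackHandler`.

```python
# Minimal sketch: monitor LangChain LLM calls in an Argilla FeedbackDataset.
# All names below (dataset, workspace, URL, key) are illustrative placeholders.
from langchain.callbacks import ArgillaCallbackHandler
from langchain.llms import OpenAI

argilla_callback = ArgillaCallbackHandler(
    dataset_name="langchain-dataset",   # a LangChain-compatible FeedbackDataset in Argilla
    workspace_name="admin",
    api_url="http://localhost:6900",
    api_key="admin.apikey",
)

# Provide the handler to the LLM; the same `callbacks` argument works for Chains and Agents.
llm = OpenAI(temperature=0.3, callbacks=[argilla_callback])
llm.generate(["Tell me a joke", "Tell me a poem"])
```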
From 80bd0367c7885ddef515f45a0592c1cbda493c01 Mon Sep 17 00:00:00 2001 From: davidberenstein1957 Date: Tue, 17 Oct 2023 16:04:56 +0200 Subject: [PATCH 11/12] docs: added tutorial on spacy-llm --- .../feedback/labelling-spacy-llm.ipynb | 437 ++++++++++++++++++ .../tutorials/tutorials.md | 37 +- 2 files changed, 456 insertions(+), 18 deletions(-) create mode 100644 docs/_source/tutorials_and_integrations/tutorials/feedback/labelling-spacy-llm.ipynb diff --git a/docs/_source/tutorials_and_integrations/tutorials/feedback/labelling-spacy-llm.ipynb b/docs/_source/tutorials_and_integrations/tutorials/feedback/labelling-spacy-llm.ipynb new file mode 100644 index 0000000000..bfa12876e7 --- /dev/null +++ b/docs/_source/tutorials_and_integrations/tutorials/feedback/labelling-spacy-llm.ipynb @@ -0,0 +1,437 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 🧸 Using LLMs for Text Classification and Summarization Suggestions with `spacy-llm`\n", + "\n", + "In this tutorial, we'll implement a `spacy-llm` pipeline to obtain model suggestions with GPT3.5 and add them to our `FeedbackDataset` as `suggestions`. The flow of the tutorial will be:\n", + "\n", + "- Run Argilla and load `spacy-llm` along with other libraries\n", + "- Define config for your pipeline and initialize it\n", + "- Create your `FeedbackDataset` instance\n", + "- Generate predictions on data and add them to `records`\n", + "- Push to Argilla\n", + "\n", + "## Introduction\n", + "\n", + "[spacy-llm](https://spacy.io/usage/large-language-models) is a package that integrates the strength of LLMs into regular spaCy pipelines, thus allowing quick and practical prompting for various tasks. Besides, since it requires no training data, the models are ready to use in your pipeline. If you want to train your own model or create your custom task, `spacy-llm` also helps to implement any custom pipeline.\n", + "\n", + "It is quite easy to use this powerful package with Argilla Feedback datasets. We can make inferences with the pipeline we will create and add them to our `FeedbackDataset`." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Running Argilla\n", + "\n", + "For this tutorial, you will need to have an Argilla server running. There are two main options for deploying and running Argilla:\n", + "\n", + "\n", + "**Deploy Argilla on Hugging Face Spaces**: If you want to run tutorials with external notebooks (e.g., Google Colab) and you have an account on Hugging Face, you can deploy Argilla on Spaces with a few clicks:\n", + "\n", + "[![deploy on spaces](https://huggingface.co/datasets/huggingface/badges/raw/main/deploy-to-spaces-lg.svg)](https://huggingface.co/new-space?template=argilla/argilla-template-space)\n", + "\n", + "For details about configuring your deployment, check the [official Hugging Face Hub guide](https://huggingface.co/docs/hub/spaces-sdks-docker-argilla).\n", + "\n", + "\n", + "**Launch Argilla using Argilla's quickstart Docker image**: This is the recommended option if you want [Argilla running on your local machine](../../../getting_started/quickstart.md). Note that this option will only let you run the tutorial locally and not with an external notebook service.\n", + "\n", + "For more information on deployment options, please check the Deployment section of the documentation.\n", + "\n", + "
\n", + "\n", + "Tip\n", + " \n", + "This tutorial is a Jupyter Notebook. There are two options to run it:\n", + "\n", + "- Use the Open in Colab button at the top of this page. This option allows you to run the notebook directly on Google Colab. Don't forget to change the runtime type to GPU for faster model training and inference.\n", + "- Download the .ipynb file by clicking on the View source link at the top of the page. This option allows you to download the notebook and run it on your local machine or on a Jupyter notebook tool of your choice.\n", + "
" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Setup \n", + "\n", + "Let us first install the required libraries for our task," + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "pip install \"spacy-llm[transformers]\" \"transformers[sentencepiece]\" argilla datasets -qqq" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "and import them as well." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import spacy\n", + "from spacy_llm.util import assemble\n", + "import argilla as rg\n", + "from datasets import load_dataset\n", + "import configparser\n", + "from collections import Counter\n", + "from heapq import nlargest" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "You need to initialize the Argilla client with `API_URL` and `API_KEY`: " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Replace api_url with the url to your HF Spaces URL if using Spaces\n", + "# Replace api_key if you configured a custom API key\n", + "rg.init(\n", + " api_url=\"http://localhost:6900\",\n", + " api_key=\"admin.apikey\"\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### `spacy-llm` pipeline\n", + "\n", + "To be able to use GPT3.5 and other models from OpenAI with spacy-llm, we'll need an API key from [openai.com](https://openai.com) and set it as an environment variable." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "os.environ[\"OPENAI_API_KEY\"] = \"\"" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "There are two ways to implement a `spacy-llm` pipeline for your LLM task: running the pipeline in the source code or using a `config.cfg` file to define all settings and hyperparameters of your pipeline. In this tutorial, we'll work with a config file and you can have more info about running directly in Python [here](https://spacy.io/usage/large-language-models#example-3).\n", + "\n", + "Let us first define the settings of our pipeline as parameters in our config file. We'll implement two tasks: text classification and summarization, which we define them in the `pipeline` command. Then, we add our components to our pipeline to specify each task with their respective models and hypermeters." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "config_string = \"\"\"\n", + " [nlp]\n", + " lang = \"en\"\n", + " pipeline = [\"llm_textcat\",\"llm_summarization\",\"sentencizer\"]\n", + "\n", + " [components]\n", + "\n", + " [components.llm_textcat]\n", + " factory = \"llm\"\n", + "\n", + " [components.llm_textcat.task]\n", + " @llm_tasks = \"spacy.TextCat.v2\"\n", + " labels = [\"HISTORY\",\"MUSIC\",\"TECHNOLOGY\",\"SCIENCE\",\"SPORTS\",\"POLITICS\"]\n", + " \n", + " [components.llm_textcat.model]\n", + " @llm_models = \"spacy.GPT-3-5.v1\"\n", + "\n", + " [components.llm_summarization]\n", + " factory = \"llm\"\n", + "\n", + " [components.llm_summarization.task]\n", + " @llm_tasks = \"spacy.Summarization.v1\"\n", + "\n", + " [components.llm_summarization.model]\n", + " @llm_models = \"spacy.GPT-3-5.v1\"\n", + "\n", + " [components.sentencizer]\n", + " factory = \"sentencizer\"\n", + " punct_chars = null\n", + "\"\"\"" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "With these settings, we create an LLM pipeline for text classification and summarization in English with GPT3.5.\n", + "\n", + "`spacy-llm` offers various models to implement in your pipeline. You can have a look at the available [OpenAI models](https://spacy.io/api/large-language-models#models-rest) as well as check the [HuggingFace models](https://spacy.io/api/large-language-models#models-hf) offered if you want to work with open-source models.\n", + "\n", + "Now, with `ConfigParser`, we can create the config file." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "config = configparser.ConfigParser()\n", + "config.read_string(config_string)\n", + "\n", + "with open(\"config.cfg\", 'w') as configfile:\n", + " config.write(configfile)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let us assemble the config file." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "nlp = assemble(\"config.cfg\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We are ready to make inferences with the pipeline we have created." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "doc = nlp(\"No matter how hard they tried, Barcelona lost the match.\")\n", + "doc.cats" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Inference\n", + "\n", + "We need two functions that will ease the inferencing process and give us the text category and summary that we want." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#returns the category with the highest score\n", + "def get_textcat_suggestion(doc):\n", + " model_prediction = doc.cats\n", + " return max(model_prediction, key=model_prediction.get)\n", + "\n", + "#selects the top N sentences with the highest scores and return combined string\n", + "def get_summarization_suggestion(doc):\n", + " sentence_scores = Counter()\n", + " for sentence in doc.sents:\n", + " for word in sentence:\n", + " sentence_scores[sentence] += 1\n", + " summary_sentences = nlargest(2, sentence_scores, key=sentence_scores.get)\n", + " summary = ' '.join(str(sentence) for sentence in summary_sentences)\n", + " return summary" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Load Data" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We will use `squad_v2` from HuggingFace library in this tutorial. `squad_v2` is a dataset consisting of questions and their answers along with the context to search for the answer within. We'll use only the `context` column for our purposes." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "dataset_hf = load_dataset(\"squad_v2\", split=\"train\").shard(num_shards=10000, index=235)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## FeedbackDataset\n", + "\n", + "Now that we have our pipeline for inference and the data, we can create our Argilla `FeedbackDataset` to make and store model suggestions. For this tutorial, we will create both a text classification task and a summarization task. Argilla Feedback lets us implement both tasks with `LabelQuestion` and `TextQuestion`, respectively." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "dataset = rg.FeedbackDataset(\n", + " fields=[\n", + " rg.TextField(name=\"text\")\n", + " ],\n", + " questions=[\n", + " rg.LabelQuestion(\n", + " name=\"label-question\",\n", + " title=\"Classify the text category.\",\n", + " #make sure that the labels are in line with the labels we have defined in config.cfg\n", + " labels=[\"HISTORY\",\"MUSIC\",\"TECHNOLOGY\",\"SCIENCE\",\"SPORTS\",\"POLITICS\"]\n", + " ),\n", + " rg.TextQuestion(\n", + " name=\"text-question\",\n", + " title=\"Provide a summary for the text.\"\n", + " )\n", + " ]\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can create the records for our dataset by iterating over the dataset we loaded. While doing this, we will make inferences and save them in the `suggestions` with `get_textcat_suggestion()` and `get_summarization_suggestion()` functions." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "records = [\n", + " rg.FeedbackRecord(\n", + " fields={\n", + " \"text\": doc.text\n", + " },\n", + " suggestions=[\n", + " {\"question_name\": \"label-question\",\n", + " \"value\": get_textcat_suggestion(doc)},\n", + " {\"question_name\":\"text-question\",\n", + " \"value\": get_summarization_suggestion(doc)}\n", + " ]\n", + " ) for doc in [nlp(item) for item in dataset_hf[\"context\"]]\n", + "]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We have created the records, let us add them to the `FeedbackDataset`." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "dataset.add_records(records)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Push to Argilla\n", + "\n", + "We are now ready to push our dataset to Argilla and can start to collect annotations." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "remote_dataset = dataset.push_to_argilla(name=\"squad_spacy-llm\", workspace=\"admin\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "You should see the Argilla page ready to annotate as below.\n", + "\n", + "![Screenshot of Argilla UI](../../../_static/tutorials/labelling-spacy-llm/feedback-annotation.png)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In this tutorial, we have implemented a spacy-llm pipeline for text classification and summarization tasks. By Argilla Feedback datasets, we have been able to add the model predictions as suggestions to our dataset so that our annotators can utilize them. For more info on spacy-llm, you can go to their LLM [page](https://spacy.io/usage/large-language-models), and for other uses of Argilla Feedback datasets, you can refer to our [guides](../../../practical_guides/practical_guides.md)." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.1 (default, Dec 17 2020, 03:56:09) \n[Clang 11.0.0 (clang-1100.0.33.17)]" + }, + "metadata": { + "interpreter": { + "hash": "0f338a8622467eba0ef87b9a79c52cc260cef0b0d60c3c739596fb787bf801dd" + } + }, + "vscode": { + "interpreter": { + "hash": "8874e298d2bce9702a08b32d5709c9f02d53b2045f1d246836c6e4c8123e6782" + } + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/docs/_source/tutorials_and_integrations/tutorials/tutorials.md b/docs/_source/tutorials_and_integrations/tutorials/tutorials.md index 4289392de8..7b0dd76c99 100644 --- a/docs/_source/tutorials_and_integrations/tutorials/tutorials.md +++ b/docs/_source/tutorials_and_integrations/tutorials/tutorials.md @@ -9,47 +9,48 @@ Here you can find end-to-end examples to help you get started with curanting dat ````{grid} 1 1 3 3 :class-container: tuto-section-2 -```{grid-item-card} 🪄 Fine-tuning and evaluating GPT-3.5 with human feedback for RAG -:link: feedback/fine-tuning-openai-rag-feedback.html - -Learn how to fine-tune and evaluate gpt3.5-turbo models with human feedback for RAG applications with LlamaIndex. - -``` -```{grid-item-card} Ⓜ️ Finetuning LLMs as chat assistants: Supervised Finetuning on Mistral 7B +```{grid-item-card} Ⓜ️ Fine-tuning LLMs as chat assistants: Supervised Finetuning on Mistral 7B :link: feedback/training-llm-mistral-sft.html Learn how to fine-tune Mistral 7B into a chat assistant using supervised finetuning with the ArgillaTrainer and TRL. - ``` -```{grid-item-card} 🖼️ Curate an instruction dataset for supervised fine-tuning -:link: feedback/curating-feedback-instructiondataset.html - -Learn how to set up a project to curate a public dataset that can be used to fine-tune an instruction-following model. 
+```{grid-item-card} 🪄 Fine-tuning and evaluating GPT-3.5 with human feedback for RAG +:link: feedback/fine-tuning-openai-rag-feedback.html +Learn how to fine-tune and evaluate gpt3.5-turbo models with human feedback for RAG applications with LlamaIndex. ``` -```{grid-item-card} 🏆 Train a Reward Model for RLHF +```{grid-item-card} 🏆 Fine-tuning a Reward Model for RLHF :link: feedback/train-reward-model-rlhf.html Learn how to collect comparison or human preference data and train a reward model with the trl library. ``` -```{grid-item-card} ✨ Add zero-shot suggestions using SetFit +```{grid-item-card} 🎛️ Fine-tuning a SetFit model using the ArgillaTrainer +:link: feedback/trainer-feedback-setfit.html + +Learn how to use the ArgillaTrainer to fine-tune your Feedback Dataset using Setfit. +``` +```{grid-item-card} ✨ Add zero-shot text classification suggestions using SetFit :link: feedback/labelling-feedback-setfit.html Learn how to add suggestions to your Feedback Dataset using SetFit. +``` +```{grid-item-card} 🧸 Using LLMs for text classification and summarization with spacy-llm +:link: feedback/labelling-spacy-llm.html +Learn how to add suggestions for text classification and summarization to your Feedback Dataset using spacy-llm. ``` ```{grid-item-card} 🎡 Create and annotate synthetic data with LLMs :link: feedback/labelling-feedback-langchain-syntethic.html Learn how to create synthetic data and annotations with OpenAI, LangChain, Transformers and Outlines. ``` -```{grid-item-card} 🎛️ Fine-tune a SetFit model using the ArgillaTrainer -:link: feedback/trainer-feedback-setfit.html - -Learn how to use the ArgillaTrainer to fine-tune your Feedback Dataset using Setfit. +```{grid-item-card} 🖼️ Curate an instruction dataset for supervised fine-tuning +:link: feedback/curating-feedback-instructiondataset.html +Learn how to set up a project to curate a public dataset that can be used to fine-tune an instruction-following model. ``` + ```` **Other datasets** From c6450c856d8c35a15eea32ae24484cb9443b587b Mon Sep 17 00:00:00 2001 From: davidberenstein1957 Date: Tue, 17 Oct 2023 16:30:41 +0200 Subject: [PATCH 12/12] chore: resolved #3764 --- CHANGELOG.md | 2 + .../client/feedback/dataset/local/dataset.py | 4 +- .../client/feedback/dataset/local/mixins.py | 412 +++++++++++++++++- .../local/test_mixin_task_templates.py | 341 +++++++++++++++ 4 files changed, 756 insertions(+), 3 deletions(-) create mode 100644 tests/integration/client/feedback/dataset/local/test_mixin_task_templates.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 902e8247f9..edd6bcc2ec 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -33,6 +33,8 @@ These are the section headers that we use: - Added `delete_metadata_properties` and `Remote{Terms,Integer,Float}MetadataProperty.delete` methods to delete metadata properties ([#3932](https://github.com/argilla-io/argilla/pull/3932)). - New `PATCH /api/v1/metadata-properties/:metadata_property_id` endpoint allowing the update of a specific metadata property. ([#3952](https://github.com/argilla-io/argilla/pull/3952)). - Added automatic model card generation through `ArgillaTrainer.save` ([#3857](https://github.com/argilla-io/argilla/pull/3857)). +- Added `FeedbackDataset` `TaskTemplateMixin` for pre-defined task templates. ([#3969](https://github.com/argilla-io/argilla/pull/3969)). 
+ ### Changed diff --git a/src/argilla/client/feedback/dataset/local/dataset.py b/src/argilla/client/feedback/dataset/local/dataset.py index 255ec5b847..944a4c74db 100644 --- a/src/argilla/client/feedback/dataset/local/dataset.py +++ b/src/argilla/client/feedback/dataset/local/dataset.py @@ -19,7 +19,7 @@ from argilla.client.feedback.constants import FETCHING_BATCH_SIZE from argilla.client.feedback.dataset.base import FeedbackDatasetBase -from argilla.client.feedback.dataset.local.mixins import ArgillaMixin +from argilla.client.feedback.dataset.local.mixins import ArgillaMixin, TaskTemplateMixin from argilla.client.feedback.schemas.enums import RecordSortField, SortOrder from argilla.client.feedback.schemas.questions import ( LabelQuestion, @@ -62,7 +62,7 @@ _LOGGER = logging.getLogger(__name__) -class FeedbackDataset(ArgillaMixin, FeedbackDatasetBase["FeedbackRecord"]): +class FeedbackDataset(ArgillaMixin, FeedbackDatasetBase["FeedbackRecord"], TaskTemplateMixin): def __init__( self, *, diff --git a/src/argilla/client/feedback/dataset/local/mixins.py b/src/argilla/client/feedback/dataset/local/mixins.py index bb4d8b5ece..bc02b7bad9 100644 --- a/src/argilla/client/feedback/dataset/local/mixins.py +++ b/src/argilla/client/feedback/dataset/local/mixins.py @@ -19,10 +19,10 @@ from argilla.client.feedback.constants import PUSHING_BATCH_SIZE from argilla.client.feedback.dataset.remote.dataset import RemoteFeedbackDataset from argilla.client.feedback.schemas.enums import FieldTypes, MetadataPropertyTypes, QuestionTypes +from argilla.client.feedback.schemas.fields import TextField from argilla.client.feedback.schemas.questions import ( LabelQuestion, MultiLabelQuestion, - RankingQuestion, RatingQuestion, TextQuestion, ) @@ -440,3 +440,413 @@ def list(cls: Type["FeedbackDataset"], workspace: Optional[str] = None) -> List[ ) for dataset in datasets ] + + +class TaskTemplateMixin: + """ + Mixin to add task template functionality to a `FeedbackDataset`. + The NLP tasks covered are: + "text_classification" + "extractive_question_answering" + "summarization" + "translation" + "sentence_similarity" + "natural_language_inference" + "supervised_fine_tuning" + "preference_modeling/reward_modeling" + "proximal_policy_optimization" + "direct_preference_optimization" + "retrieval_augmented_generation" + """ + + @classmethod + def for_text_classification( + cls: Type["FeedbackDataset"], + labels: List[str], + multi_label: bool = False, + use_markdown: bool = False, + guidelines: str = None, + ) -> "FeedbackDataset": + """ + You can use this method to create a basic dataset for text classification tasks. + + Args: + labels: A list of labels for your dataset + multi_label: Set this parameter to True if you want to add multiple labels to your dataset + use_markdown: Set this parameter to True if you want to use markdown in your dataset + + Returns: + A `FeedbackDataset` object for text classification containing "text" field and LabelQuestion or MultiLabelQuestion named "label" + """ + default_guidelines = "This is a text classification dataset that contains texts and labels. Given a set of texts and a predefined set of labels, the goal of text classification is to assign one or more labels to each text based on its content. Please classify the texts by making the correct selection." + + description = "Classify the text by selecting the correct label from the given list of labels." 
+ return cls( + fields=[TextField(name="text", use_markdown=use_markdown)], + questions=[ + LabelQuestion( + name="label", + labels=labels, + description=description, + ) + if not multi_label + else MultiLabelQuestion( + name="label", + labels=labels, + description=description, + ) + ], + guidelines=guidelines + if guidelines is not None + else default_guidelines + if multi_label + else default_guidelines.replace("one or more labels", "one label"), + ) + + @classmethod + def for_question_answering( + cls: Type["FeedbackDataset"], use_markdown: bool = False, guidelines: str = None + ) -> "FeedbackDataset": + """ + You can use this method to create a basic dataset for question answering tasks. + + Args: + use_markdown: Set this parameter to True if you want to use markdown in your dataset + + Returns: + A `FeedbackDataset` object for question answering containing "context" and "question" fields and a TextQuestion named "answer" + """ + default_guidelines = "This is a question answering dataset that contains questions and contexts. Please answer the question by using the context." + return cls( + fields=[ + TextField(name="question", use_markdown=use_markdown), + TextField(name="context", use_markdown=use_markdown), + ], + questions=[ + TextQuestion( + name="answer", + description="Answer the question. Note that the answer must exactly be in the context.", + use_markdown=use_markdown, + required=True, + ) + ], + guidelines=default_guidelines if guidelines is None else guidelines, + ) + + @classmethod + def for_summarization( + cls: Type["FeedbackDataset"], + use_markdown: bool = False, + guidelines: str = None, + ) -> "FeedbackDataset": + """ + You can use this method to create a basic dataset for summarization tasks. + + Args: + use_markdown: Set this parameter to True if you want to use markdown in your dataset + + Returns: + A `FeedbackDataset` object for summarization containing "text" field and a TextQuestion named "summary" + """ + default_guidelines = ( + "This is a summarization dataset that contains texts. Please summarize the text in the text field." + ) + return cls( + fields=[TextField(name="text", use_markdown=use_markdown)], + questions=[ + TextQuestion(name="summary", description="Write a summary of the text.", use_markdown=use_markdown) + ], + guidelines=default_guidelines if guidelines is None else guidelines, + ) + + @classmethod + def for_translation( + cls: Type["FeedbackDataset"], + use_markdown: bool = False, + guidelines: str = None, + ) -> "FeedbackDataset": + """ + You can use this method to create a basic dataset for translation tasks. + + Args: + use_markdown: Set this parameter to True if you want to use markdown in your dataset + + Returns: + A `FeedbackDataset` object for translation containing "source" field and a TextQuestion named "target" + """ + default_guidelines = ( + "This is a translation dataset that contains texts. Please translate the text in the text field." + ) + return cls( + fields=[TextField(name="source", use_markdown=use_markdown)], + questions=[TextQuestion(name="target", description="Translate the text.", use_markdown=use_markdown)], + guidelines=default_guidelines if guidelines is None else guidelines, + ) + + @classmethod + def for_sentence_similarity( + cls: Type["FeedbackDataset"], + rating_scale: int = 10, + use_markdown: bool = False, + guidelines: str = None, + ) -> "FeedbackDataset": + """ + You can use this method to create a basic dataset for sentence similarity tasks. 
+ + Args: + rating_scale: Set this parameter to the number of similarity scale you want to add to your dataset + use_markdown: Set this parameter to True if you want to use markdown in your dataset + + Returns: + A `FeedbackDataset` object for sentence similarity containing "sentence1" and "sentence2" fields and a RatingQuestion named "similarity" + """ + default_guidelines = "This is a sentence similarity dataset that contains two sentences. Please rate the similarity between the two sentences." + return cls( + fields=[ + TextField(name="sentence1", use_markdown=use_markdown), + TextField(name="sentence2", use_markdown=use_markdown), + ], + questions=[ + RatingQuestion( + name="similarity", + values=list(range(1, rating_scale + 1)), + description="Rate the similarity between the two sentences.", + ) + ], + guidelines=default_guidelines if guidelines is None else guidelines, + ) + + @classmethod + def for_natural_language_inference( + cls: Type["FeedbackDataset"], + labels: Optional[List[str]] = None, + use_markdown: bool = False, + guidelines: str = None, + ) -> "FeedbackDataset": + """ + You can use this method to create a basic dataset for natural language inference tasks. + + Args: + labels: A list of labels for your dataset + use_markdown: Set this parameter to True if you want to use markdown in your dataset + + Returns: + A `FeedbackDataset` object for natural language inference containing "premise" and "hypothesis" fields and a LabelQuestion named "label" + """ + default_guidelines = "This is a natural language inference dataset that contains premises and hypotheses. Please choose the correct label for the given premise and hypothesis." + if labels is None: + labels = ["entailment", "neutral", "contradiction"] + return cls( + fields=[ + TextField(name="premise", use_markdown=use_markdown), + TextField(name="hypothesis", use_markdown=use_markdown), + ], + questions=[LabelQuestion(name="label", labels=labels, description="Choose one of the labels.")], + guidelines=default_guidelines if guidelines is None else guidelines, + ) + + @classmethod + def for_supervised_fine_tuning( + cls: Type["FeedbackDataset"], + context: bool = False, + use_markdown: bool = False, + guidelines: str = None, + ) -> "FeedbackDataset": + """ + You can use this method to create a basic dataset for supervised fine-tuning tasks. + + Args: + context: Set this parameter to True if you want to add context to your dataset + use_markdown: Set this parameter to True if you want to use markdown in your dataset + + Returns: + A `FeedbackDataset` object for supervised fine-tuning containing "instruction" and optional "context" field and a TextQuestion named "response" + """ + default_guidelines = "This is a supervised fine-tuning dataset that contains instructions. Please write the response to the instruction in the response field." + fields = [ + TextField(name="prompt", use_markdown=use_markdown), + ] + if context: + fields.append(TextField(name="context", use_markdown=use_markdown, required=False)) + return cls( + fields=fields, + questions=[ + TextQuestion( + name="response", description="Write the response to the instruction.", use_markdown=use_markdown + ) + ], + guidelines=guidelines + if guidelines is not None + else default_guidelines + " Take the context into account when writing the response." 
+ if context + else default_guidelines, + ) + + @classmethod + def for_preference_modeling( + cls: Type["FeedbackDataset"], + context: bool = False, + use_markdown: bool = False, + guidelines: str = None, + ) -> "FeedbackDataset": + """ + You can use this method to create a basic dataset for preference tasks. + + Args: + use_markdown: Set this parameter to True if you want to use markdown in your dataset + + Returns: + A `FeedbackDataset` object for preference containing "prompt", "option1" and "option2" fields and a LabelQuestion named "preference" + """ + default_guidelines = "This is a preference dataset that contains contexts and options. Please choose the option that you would prefer in the given context." + fields = [ + TextField(name="prompt", use_markdown=use_markdown), + TextField(name="response1", title="Response 1", use_markdown=use_markdown), + TextField(name="response2", title="Response 2", use_markdown=use_markdown), + ] + if context: + fields.insert(1, TextField(name="context", use_markdown=use_markdown, required=False)) + return cls( + fields=fields, + questions=[ + LabelQuestion( + name="preference", labels=["Response 1", "Response 2"], description="Choose your preference." + ) + ], + guidelines=default_guidelines if guidelines is None else guidelines, + ) + + @classmethod + def for_reward_modeling( + cls: Type["FeedbackDataset"], + context: bool = False, + use_markdown: bool = False, + guidelines: str = None, + ) -> "FeedbackDataset": + return cls.for_preference_modeling(context=context, use_markdown=use_markdown, guidelines=guidelines) + + @classmethod + def for_proximal_policy_optimization( + cls: Type["FeedbackDataset"], + context: bool = False, + use_markdown: bool = False, + guidelines: str = None, + ) -> "FeedbackDataset": + """ + You can use this method to create a basic dataset for proximal policy optimization tasks. + + Args: + use_markdown: Set this parameter to True if you want to use markdown in your dataset + + Returns: + A `FeedbackDataset` object for proximal policy optimization containing "context" and "action" fields and a LabelQuestion named "label" + """ + default_guidelines = "This is a proximal policy optimization dataset that contains contexts and prompts. Please choose the label that best prompt." + fields = [TextField(name="prompt", use_markdown=use_markdown)] + if context: + fields.append(TextField(name="context", use_markdown=use_markdown, required=False)) + + return cls( + fields=fields, + questions=[ + LabelQuestion( + name="prompt", + labels=["good", "bad"], + description="Choose one of the labels that best describes the prompt.", + ) + ], + guidelines=default_guidelines if guidelines is None else guidelines, + ) + + @classmethod + def for_direct_preference_optimization( + cls: Type["FeedbackDataset"], + context: bool = False, + use_markdown: bool = False, + guidelines: str = None, + ) -> "FeedbackDataset": + """ + You can use this method to create a basic dataset for direct preference optimization tasks. + + Args: + context: Set this parameter to True if you want to add context to your dataset + use_markdown: Set this parameter to True if you want to use markdown in your dataset + + Returns: + A `FeedbackDataset` object for direct preference optimization containing "prompt", "response1", "response2" with the optional "context" fields and a LabelQuestion named "preference" + """ + default_guidelines = "This is a direct preference optimization dataset that contains contexts and options. 
Please choose the option that you would prefer in the given context." + fields = [ + TextField(name="prompt", use_markdown=use_markdown), + TextField(name="response1", title="Response 1", use_markdown=use_markdown), + TextField(name="response2", title="Response 2", use_markdown=use_markdown), + ] + if context: + fields.insert(1, TextField(name="context", use_markdown=use_markdown, required=False)) + return cls( + fields=fields, + questions=[ + LabelQuestion( + name="preference", + labels=["Response 1", "Response 2"], + description="Choose the label that is your preference.", + ) + ], + guidelines=default_guidelines if guidelines is None else guidelines, + ) + + @classmethod + def for_retrieval_augmented_generation( + cls: Type["FeedbackDataset"], + number_of_retrievals: int = 1, + rating_scale: int = 10, + use_markdown: bool = False, + guidelines: str = None, + ) -> "FeedbackDataset": + """ + You can use this method to create a basic dataset for retrieval augmented generation tasks. + + Args: + number_of_retrievals: Set this parameter to the number of documents you want to add to your dataset + use_markdown: Set this parameter to True if you want to use markdown in your dataset + + Returns: + A `FeedbackDataset` object for retrieval augmented generation containing "query" and "retrieved_document" fields and a TextQuestion named "response" + """ + default_guidelines = "This is a retrieval augmented generation dataset that contains queries and retrieved documents. Please rate the relevancy of retrieved document and write the response to the query in the response field." + document_fields = [ + TextField( + name="retrieved_document_" + str(doc + 1), + title="Retrieved Document " + str(doc + 1), + use_markdown=use_markdown, + required=True if doc == 0 else False, + ) + for doc in range(number_of_retrievals) + ] + + rating_questions = [ + RatingQuestion( + name="question_rating_" + str(doc + 1), + title="Rate the relevance of the user question" + str(doc + 1), + values=list(range(1, rating_scale + 1)), + description="Rate the relevance of the retrieved document.", + required=True if doc == 0 else False, + ) + for doc in range(number_of_retrievals) + ] + + total_questions = rating_questions + [ + TextQuestion( + name="response", + title="Write a helpful, harmless, accurate response to the query.", + description="Write the response to the query.", + use_markdown=use_markdown, + required=False, + ) + ] + + return cls( + fields=[TextField(name="query", use_markdown=use_markdown, required=True)] + document_fields, + questions=total_questions, + guidelines=default_guidelines if guidelines is None else guidelines, + ) diff --git a/tests/integration/client/feedback/dataset/local/test_mixin_task_templates.py b/tests/integration/client/feedback/dataset/local/test_mixin_task_templates.py new file mode 100644 index 0000000000..55e9db5a54 --- /dev/null +++ b/tests/integration/client/feedback/dataset/local/test_mixin_task_templates.py @@ -0,0 +1,341 @@ +# Copyright 2021-present, the Recognai S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING + +from argilla.client.feedback.dataset import FeedbackDataset +from argilla.client.feedback.schemas import ( + LabelQuestion, + MultiLabelQuestion, + RatingQuestion, + TextQuestion, +) + +if TYPE_CHECKING: + pass + + +def test_for_question_answering(): + dataset = FeedbackDataset.for_question_answering(use_markdown=True) + assert len(dataset.fields) == 2 + assert len(dataset.questions) == 1 + assert dataset.questions[0].name == "answer" + assert ( + dataset.questions[0].description == "Answer the question. Note that the answer must exactly be in the context." + ) + assert dataset.questions[0].required is True + assert dataset.fields[0].name == "question" + assert dataset.fields[0].use_markdown is True + assert dataset.fields[1].name == "context" + assert dataset.fields[1].use_markdown is True + assert ( + dataset.guidelines + == "This is a question answering dataset that contains questions and contexts. Please answer the question by using the context." + ) + + +def test_for_text_classification(): + # Test case 1: Single label classification + dataset = FeedbackDataset.for_text_classification(labels=["positive", "negative"]) + assert len(dataset) == 0 + assert dataset.questions[0].name == "label" + assert ( + dataset.questions[0].description + == "Classify the text by selecting the correct label from the given list of labels." + ) + assert isinstance(dataset.questions[0], LabelQuestion) + assert dataset.questions[0].labels == ["positive", "negative"] + assert dataset.fields[0].name == "text" + assert dataset.fields[0].use_markdown is False + assert ( + dataset.guidelines + == "This is a text classification dataset that contains texts and labels. Given a set of texts and a predefined set of labels, the goal of text classification is to assign one label to each text based on its content. Please classify the texts by making the correct selection." + ) + + # Test case 2: Multi-label classification + dataset = FeedbackDataset.for_text_classification(labels=["positive", "negative"], multi_label=True) + assert len(dataset) == 0 + assert dataset.questions[0].name == "label" + assert ( + dataset.questions[0].description + == "Classify the text by selecting the correct label from the given list of labels." + ) + assert isinstance(dataset.questions[0], MultiLabelQuestion) + assert dataset.questions[0].labels == ["positive", "negative"] + assert dataset.fields[0].name == "text" + assert dataset.fields[0].use_markdown is False + print(dataset.guidelines) + assert ( + dataset.guidelines + == "This is a text classification dataset that contains texts and labels. Given a set of texts and a predefined set of labels, the goal of text classification is to assign one or more labels to each text based on its content. Please classify the texts by making the correct selection." + ) + + +def test_for_summarization(): + dataset = FeedbackDataset.for_summarization(use_markdown=True) + assert len(dataset) == 0 + assert dataset.questions[0].name == "summary" + assert dataset.questions[0].description == "Write a summary of the text." + assert isinstance(dataset.questions[0], TextQuestion) + assert dataset.fields[0].name == "text" + assert dataset.fields[0].use_markdown is True + assert ( + dataset.guidelines + == "This is a summarization dataset that contains texts. Please summarize the text in the text field." 
+ ) + + +def test_for_supervised_fine_tuning(): + # Test case 1: context=False, use_markdown=False, guidelines=None + dataset = FeedbackDataset.for_supervised_fine_tuning(context=False, use_markdown=False, guidelines=None) + assert len(dataset) == 0 + assert dataset.questions[0].name == "response" + assert dataset.questions[0].description == "Write the response to the instruction." + assert isinstance(dataset.questions[0], TextQuestion) + assert dataset.questions[0].use_markdown is False + assert dataset.fields[0].name == "prompt" + assert dataset.fields[0].use_markdown is False + assert ( + dataset.guidelines + == "This is a supervised fine-tuning dataset that contains instructions. Please write the response to the instruction in the response field." + ) + + # Test case 2: context=True, use_markdown=True, guidelines="Custom guidelines" + dataset = FeedbackDataset.for_supervised_fine_tuning( + context=True, use_markdown=True, guidelines="Custom guidelines" + ) + assert len(dataset) == 0 + assert dataset.questions[0].name == "response" + assert dataset.questions[0].description == "Write the response to the instruction." + assert isinstance(dataset.questions[0], TextQuestion) + assert dataset.questions[0].use_markdown is True + assert dataset.fields[0].name == "prompt" + assert dataset.fields[0].use_markdown is True + assert dataset.fields[1].name == "context" + assert dataset.fields[1].use_markdown is True + assert dataset.guidelines == "Custom guidelines" + + +def test_for_retrieval_augmented_generation(): + # Test case 1: Single document retrieval augmented generation + dataset = FeedbackDataset.for_retrieval_augmented_generation( + number_of_retrievals=1, rating_scale=5, use_markdown=True + ) + assert len(dataset) == 0 + assert dataset.questions[0].name == "question_rating_1" + assert dataset.questions[0].description == "Rate the relevance of the retrieved document." + assert isinstance(dataset.questions[0], RatingQuestion) + assert dataset.questions[0].values == [1, 2, 3, 4, 5] + assert dataset.questions[1].name == "response" + assert dataset.questions[1].description == "Write the response to the query." + assert isinstance(dataset.questions[1], TextQuestion) + assert dataset.fields[0].name == "query" + assert dataset.fields[0].use_markdown is True + assert dataset.fields[1].name == "retrieved_document_1" + assert dataset.fields[1].use_markdown is True + assert ( + dataset.guidelines + == "This is a retrieval augmented generation dataset that contains queries and retrieved documents. Please rate the relevancy of retrieved document and write the response to the query in the response field." + ) + + # Test case 2: Multiple document retrieval augmented generation + dataset = FeedbackDataset.for_retrieval_augmented_generation( + number_of_retrievals=3, rating_scale=10, use_markdown=False, guidelines="Custom guidelines" + ) + assert len(dataset) == 0 + assert dataset.questions[0].name == "question_rating_1" + assert dataset.questions[0].description == "Rate the relevance of the retrieved document." + assert isinstance(dataset.questions[0], RatingQuestion) + assert dataset.questions[0].values == list(range(1, 11)) + assert dataset.questions[1].name == "question_rating_2" + assert dataset.questions[1].description == "Rate the relevance of the retrieved document." 
+ assert isinstance(dataset.questions[1], RatingQuestion) + assert dataset.questions[1].values == list(range(1, 11)) + assert dataset.questions[2].name == "question_rating_3" + assert dataset.questions[2].description == "Rate the relevance of the retrieved document." + assert isinstance(dataset.questions[2], RatingQuestion) + assert dataset.questions[2].values == list(range(1, 11)) + assert dataset.questions[3].name == "response" + assert dataset.questions[3].description == "Write the response to the query." + assert isinstance(dataset.questions[3], TextQuestion) + assert dataset.fields[0].name == "query" + assert dataset.fields[0].use_markdown is False + assert dataset.fields[1].name == "retrieved_document_1" + assert dataset.fields[1].use_markdown is False + assert dataset.fields[2].name == "retrieved_document_2" + assert dataset.fields[2].use_markdown is False + assert dataset.fields[3].name == "retrieved_document_3" + assert dataset.fields[3].use_markdown is False + assert dataset.guidelines == "Custom guidelines" + + +def test_for_sentence_similarity(): + # Test case 1: Default parameters + dataset = FeedbackDataset.for_sentence_similarity() + assert len(dataset) == 0 + assert dataset.questions[0].name == "similarity" + assert dataset.questions[0].description == "Rate the similarity between the two sentences." + assert isinstance(dataset.questions[0], RatingQuestion) + assert dataset.questions[0].values == [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + assert dataset.fields[0].name == "sentence1" + assert dataset.fields[0].use_markdown is False + assert dataset.fields[1].name == "sentence2" + assert dataset.fields[1].use_markdown is False + assert ( + dataset.guidelines + == "This is a sentence similarity dataset that contains two sentences. Please rate the similarity between the two sentences." + ) + + # Test case 2: Custom parameters + dataset = FeedbackDataset.for_sentence_similarity(rating_scale=5, use_markdown=True, guidelines="Custom guidelines") + assert len(dataset) == 0 + assert dataset.questions[0].name == "similarity" + assert dataset.questions[0].description == "Rate the similarity between the two sentences." + assert isinstance(dataset.questions[0], RatingQuestion) + assert dataset.questions[0].values == [1, 2, 3, 4, 5] + assert dataset.fields[0].name == "sentence1" + assert dataset.fields[0].use_markdown is True + assert dataset.fields[1].name == "sentence2" + assert dataset.fields[1].use_markdown is True + assert dataset.guidelines == "Custom guidelines" + + +def test_for_preference_modeling(): + dataset = FeedbackDataset.for_preference_modeling(use_markdown=False, context=True) + assert len(dataset) == 0 + assert dataset.questions[0].name == "preference" + assert dataset.questions[0].description == "Choose your preference." + assert isinstance(dataset.questions[0], LabelQuestion) + assert dataset.questions[0].labels == ["Response 1", "Response 2"] + assert dataset.fields[0].name == "prompt" + assert dataset.fields[0].use_markdown is False + assert dataset.fields[1].name == "context" + assert dataset.fields[1].use_markdown is False + assert dataset.fields[1].required is False + assert dataset.fields[2].name == "response1" + assert dataset.fields[2].title == "Response 1" + assert dataset.fields[2].use_markdown is False + assert dataset.fields[3].name == "response2" + assert dataset.fields[3].title == "Response 2" + assert dataset.fields[3].use_markdown is False + assert ( + dataset.guidelines + == "This is a preference dataset that contains contexts and options. 
Please choose the option that you would prefer in the given context." + ) + + +def test_for_natural_language_inference(): + # Test case 1: Default labels and guidelines + dataset = FeedbackDataset.for_natural_language_inference() + assert len(dataset) == 0 + assert dataset.questions[0].name == "label" + assert dataset.questions[0].description == "Choose one of the labels." + assert isinstance(dataset.questions[0], LabelQuestion) + assert dataset.questions[0].labels == ["entailment", "neutral", "contradiction"] + assert dataset.fields[0].name == "premise" + assert dataset.fields[0].use_markdown is False + assert dataset.fields[1].name == "hypothesis" + assert dataset.fields[1].use_markdown is False + assert ( + dataset.guidelines + == "This is a natural language inference dataset that contains premises and hypotheses. Please choose the correct label for the given premise and hypothesis." + ) + # Test case 2: Custom labels and guidelines + dataset = FeedbackDataset.for_natural_language_inference(labels=["yes", "no"], guidelines="Custom guidelines") + assert len(dataset) == 0 + assert dataset.questions[0].name == "label" + assert dataset.questions[0].description == "Choose one of the labels." + assert isinstance(dataset.questions[0], LabelQuestion) + assert dataset.questions[0].labels == ["yes", "no"] + assert dataset.fields[0].name == "premise" + assert dataset.fields[0].use_markdown is False + assert dataset.fields[1].name == "hypothesis" + assert dataset.fields[1].use_markdown is False + assert dataset.guidelines == "Custom guidelines" + + +def test_for_proximal_policy_optimization(): + # Test case 1: Without context and without markdown + dataset = FeedbackDataset.for_proximal_policy_optimization() + assert len(dataset) == 0 + assert dataset.questions[0].name == "prompt" + assert dataset.questions[0].description == "Choose one of the labels that best describes the prompt." + assert isinstance(dataset.questions[0], LabelQuestion) + assert dataset.questions[0].labels == ["good", "bad"] + assert dataset.fields[0].name == "prompt" + assert dataset.fields[0].use_markdown is False + assert ( + dataset.guidelines + == "This is a proximal policy optimization dataset that contains contexts and prompts. Please choose the label that best prompt." + ) + + # Test case 2: With context and with markdown + dataset = FeedbackDataset.for_proximal_policy_optimization(context=True, use_markdown=True) + assert len(dataset) == 0 + assert dataset.questions[0].name == "prompt" + assert dataset.questions[0].description == "Choose one of the labels that best describes the prompt." + assert isinstance(dataset.questions[0], LabelQuestion) + assert dataset.questions[0].labels == ["good", "bad"] + assert dataset.fields[0].name == "prompt" + assert dataset.fields[0].use_markdown is True + assert dataset.fields[1].name == "context" + assert dataset.fields[1].use_markdown is True + assert ( + dataset.guidelines + == "This is a proximal policy optimization dataset that contains contexts and prompts. Please choose the label that best prompt." + ) + + +def test_for_direct_preference_optimization(): + # Test case 1: Without context and markdown + dataset = FeedbackDataset.for_direct_preference_optimization() + assert len(dataset) == 0 + assert dataset.questions[0].name == "preference" + assert dataset.questions[0].description == "Choose the label that is your preference." 
+ assert isinstance(dataset.questions[0], LabelQuestion) + assert dataset.questions[0].labels == ["Response 1", "Response 2"] + assert dataset.fields[0].name == "prompt" + assert dataset.fields[0].use_markdown is False + assert dataset.fields[1].name == "response1" + assert dataset.fields[1].title == "Response 1" + assert dataset.fields[1].use_markdown is False + assert dataset.fields[2].name == "response2" + assert dataset.fields[2].title == "Response 2" + assert dataset.fields[2].use_markdown is False + assert ( + dataset.guidelines + == "This is a direct preference optimization dataset that contains contexts and options. Please choose the option that you would prefer in the given context." + ) + + # Test case 2: With context and markdown + dataset = FeedbackDataset.for_direct_preference_optimization(context=True, use_markdown=True) + assert len(dataset) == 0 + assert dataset.questions[0].name == "preference" + assert dataset.questions[0].description == "Choose the label that is your preference." + assert isinstance(dataset.questions[0], LabelQuestion) + assert dataset.questions[0].labels == ["Response 1", "Response 2"] + assert dataset.fields[0].name == "prompt" + assert dataset.fields[0].use_markdown is True + assert dataset.fields[1].name == "context" + assert dataset.fields[1].use_markdown is True + assert dataset.fields[2].name == "response1" + assert dataset.fields[2].title == "Response 1" + assert dataset.fields[2].use_markdown is True + assert dataset.fields[3].name == "response2" + assert dataset.fields[3].title == "Response 2" + assert dataset.fields[3].use_markdown is True + assert ( + dataset.guidelines + == "This is a direct preference optimization dataset that contains contexts and options. Please choose the option that you would prefer in the given context." + )
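To make the intent of the new `TaskTemplateMixin` concrete, here is a short usage sketch based on the class methods and tests added in this patch. The labels and record text are illustrative only; it assumes an `argilla` build that includes this change and, for the final step, a reachable Argilla server.

```python
import argilla as rg

# Create a dataset pre-configured for single-label text classification
# via the new task template shortcut.
dataset = rg.FeedbackDataset.for_text_classification(
    labels=["positive", "negative"],  # illustrative labels
    multi_label=False,
    use_markdown=False,
)

# The template already defines a "text" field and a "label" LabelQuestion,
# so records only need to provide the "text" field.
dataset.add_records(
    [rg.FeedbackRecord(fields={"text": "I really enjoyed this movie!"})]
)

# Optionally push the dataset to a running Argilla instance for annotation.
# remote_dataset = dataset.push_to_argilla(name="text-classification-demo", workspace="admin")
```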