fix: follow-up fixes for detached pydantic.BaseModel schemas (#3829)
# Description

This PR contains some bug-fixes and improvements as a follow up for
#3784.

* Loading a `FeedbackDataset` from the Hugging Face Hub fails if it was
exported with an outdated version of Argilla, as the exported configuration
still contains the ID for some columns
* Adding records via `add_records` to a `FeedbackDataset` that already
exists in Argilla, i.e. a `RemoteFeedbackDataset`, fails because
`_fields_schema` is not defined
* Share attribute initialization in `FeedbackDatasetBase`, but move
local validation to `FeedbackDataset`
* Extend supported type-hints in `generate_pydantic_schema` function
* Fix `from_huggingface` to be backwards compatible with previously
uploaded `FeedbackDataset` datasets to the Hugging Face Hub (from
Argilla v1.8.0)
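
The ID-related failure from the first bullet can be illustrated in isolation. The sketch below reuses the regular expression that appears in `DatasetConfig.from_yaml` to strip `id` entries that were dumped as Python `uuid.UUID` objects, which a safe YAML loader would otherwise refuse to construct. The helper name `strip_uuid_ids` and the sample YAML excerpt are made up for illustration, and the sketch stops before parsing so that it only needs the standard library.

```python
import re

# Regular expression as it appears in `DatasetConfig.from_yaml`: it matches an
# `id` key whose value was dumped as a Python `uuid.UUID` object.
UUID_ID_PATTERN = re.compile(r"(\n\s*|)id: !!python/object:uuid\.UUID\s+int: \d+")


def strip_uuid_ids(yaml_str: str) -> str:
    """Hypothetical helper: drop serialized UUID `id` entries from a YAML string."""
    return UUID_ID_PATTERN.sub("", yaml_str)


# Made-up excerpt of an outdated configuration file (assumption, not real data).
outdated_yaml = (
    "fields:\n"
    "- id: !!python/object:uuid.UUID\n"
    "    int: 123456789\n"
    "  name: text\n"
    "  required: true\n"
)

cleaned_yaml = strip_uuid_ids(outdated_yaml)
print(cleaned_yaml)
```

In the commit itself the cleaned string is then parsed with `SafeLoader`, and any remaining `id` and `settings` keys are popped from each field and question.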

Kudos @frascuchon for detecting and reporting the bugs tackled in this
PR!

**Type of change**

- [X] Bug fix (non-breaking change which fixes an issue)
- [X] Improvement (change adding some improvement to an existing
functionality)

**How Has This Been Tested**

- [X] Run
`rg.FeedbackDataset.from_huggingface("argilla/oasst_response_quality",
split="train")`
- [X] Add outdated/deprecated configuration files to check that
`DatasetConfig.from_yaml` works as intended
- [x] Add integration tests to ensure that adding a record to an
existing `FeedbackDataset` in Argilla works as expected

**Checklist**

- [ ] I added relevant documentation
- [X] follows the style guidelines of this project
- [X] I did a self-review of my code
- [ ] I made corresponding changes to the documentation
- [X] My changes generate no new warnings
- [x] I have added tests that prove my fix is effective or that my
feature works
- [ ] I filled out [the contributor form](https://tally.so/r/n9XrxK)
(see text above)
- [X] I have added relevant notes to the CHANGELOG.md file (See
https://keepachangelog.com/)

---------

Co-authored-by: Paco Aranda <[email protected]>
alvarobartt and frascuchon authored Oct 11, 2023
1 parent eeaaddc commit 7243b71
Showing 17 changed files with 871 additions and 76 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.md
@@ -33,7 +33,8 @@ These are the section headers that we use:

### Fixed

- Fixed saving of models trained with `ArgillaTrainer` with a `peft_config` parameter. ([#3795](https://github.com/argilla-io/argilla/pull/3795))
- Fixed saving of models trained with `ArgillaTrainer` with a `peft_config` parameter ([#3795](https://github.com/argilla-io/argilla/pull/3795)).
- Fixed backwards compatibility on `from_huggingface` when loading a `FeedbackDataset` from the Hugging Face Hub that was previously dumped using another version of Argilla, starting at 1.8.0, when it was first introduced ([#3829](https://github.com/argilla-io/argilla/pull/3829)).

## [1.16.0](https://github.com/argilla-io/argilla/compare/v1.15.1...v1.16.0)

52 changes: 44 additions & 8 deletions src/argilla/client/feedback/config.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import re
import warnings
from typing import List, Optional
@@ -43,15 +44,17 @@ def to_yaml(self) -> str:
return dump(self.dict())

@classmethod
def from_yaml(cls, yaml: str) -> "DatasetConfig":
yaml = re.sub(r"(\n\s*|)id: !!python/object:uuid\.UUID\s+int: \d+", "", yaml)
yaml = load(yaml, Loader=SafeLoader)
def from_yaml(cls, yaml_str: str) -> "DatasetConfig":
yaml_str = re.sub(r"(\n\s*|)id: !!python/object:uuid\.UUID\s+int: \d+", "", yaml_str)
yaml_dict = load(yaml_str, Loader=SafeLoader)
# Here for backwards compatibility
for field in yaml["fields"]:
for field in yaml_dict["fields"]:
field.pop("id", None)
field.pop("settings", None)
for question in yaml["questions"]:
for question in yaml_dict["questions"]:
question.pop("id", None)
question.pop("settings", None)
return cls(**yaml)
return cls(**yaml_dict)


# TODO(alvarobartt): here for backwards compatibility, remove in 1.14.0
@@ -70,11 +73,44 @@ def to_json(self) -> str:
return self.json()

@classmethod
def from_json(cls, json: str) -> "DeprecatedDatasetConfig":
def from_json(cls, json_str: str) -> "DeprecatedDatasetConfig":
warnings.warn(
"`DatasetConfig` can just be loaded from YAML, so make sure that you are"
" loading a YAML file instead of a JSON file. `DatasetConfig` will be dumped"
" as YAML from now on, instead of JSON.",
DeprecationWarning,
)
return cls.parse_raw(json)
parsed_json = json.loads(json_str)
# Here for backwards compatibility
for field in parsed_json["fields"]:
# for 1.10.0, 1.9.0, and 1.8.0
field.pop("id", None)
field.pop("inserted_at", None)
field.pop("updated_at", None)
if "settings" not in field:
continue
field["type"] = field["settings"]["type"]
if "use_markdown" in field["settings"]:
field["use_markdown"] = field["settings"]["use_markdown"]
# for 1.12.0 and 1.11.0
field.pop("settings", None)
for question in parsed_json["questions"]:
# for 1.10.0, 1.9.0, and 1.8.0
question.pop("id", None)
question.pop("inserted_at", None)
question.pop("updated_at", None)
if "settings" not in question:
continue
question.update({"type": question["settings"]["type"]})
if question["type"] in ["rating", "ranking"]:
question["values"] = [option["value"] for option in question["settings"]["options"]]
elif question["type"] in ["label_selection", "multi_label_selection"]:
if all(option["value"] == option["text"] for option in question["settings"]["options"]):
question["labels"] = [option["value"] for option in question["settings"]["options"]]
else:
question["labels"] = {option["value"]: option["text"] for option in question["settings"]["options"]}
if "visible_labels" in question["settings"]:
question["visible_labels"] = question["settings"]["visible_labels"]
# for 1.12.0 and 1.11.0
question.pop("settings", None)
return cls(**parsed_json)
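
The question-migration logic in `from_json` is easier to follow on a single dictionary. Below is a minimal sketch, assuming the legacy layout implied by the diff (`settings.type`, `settings.options` with `value`/`text` keys, optional `settings.visible_labels`). The function name `migrate_legacy_question` and the sample payloads are hypothetical, and because indentation was lost in this view of the diff, checking `visible_labels` for every question type is an assumption.

```python
from typing import Any, Dict


def migrate_legacy_question(question: Dict[str, Any]) -> Dict[str, Any]:
    """Hypothetical helper: flatten a legacy question dict into the current layout."""
    question = dict(question)  # shallow copy so the caller's dict is left untouched
    # Server-side metadata found in older exports is dropped.
    for key in ("id", "inserted_at", "updated_at"):
        question.pop(key, None)
    settings = question.pop("settings", None)
    if settings is None:
        return question
    question["type"] = settings["type"]
    if question["type"] in ("rating", "ranking"):
        question["values"] = [option["value"] for option in settings["options"]]
    elif question["type"] in ("label_selection", "multi_label_selection"):
        options = settings["options"]
        if all(option["value"] == option["text"] for option in options):
            # Values and display texts coincide, so a plain list is enough.
            question["labels"] = [option["value"] for option in options]
        else:
            # Otherwise keep the value -> text mapping.
            question["labels"] = {option["value"]: option["text"] for option in options}
    # Assumption: applied regardless of the question type.
    if "visible_labels" in settings:
        question["visible_labels"] = settings["visible_labels"]
    return question


# Made-up legacy payloads for illustration.
legacy_rating = {
    "id": "abc",
    "inserted_at": "2023-06-01",
    "name": "quality",
    "settings": {"type": "rating", "options": [{"value": 1}, {"value": 2}, {"value": 3}]},
}
legacy_label = {
    "name": "sentiment",
    "settings": {
        "type": "label_selection",
        "options": [{"value": "pos", "text": "Positive"}, {"value": "neg", "text": "Negative"}],
        "visible_labels": 2,
    },
}

migrated_rating = migrate_legacy_question(legacy_rating)
migrated_label = migrate_legacy_question(legacy_label)
```

Unlike the loop in the diff, this sketch copies the input instead of mutating it in place, which keeps the example easy to test.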
26 changes: 16 additions & 10 deletions src/argilla/client/feedback/dataset/base.py
@@ -23,7 +23,7 @@
FeedbackRecord,
FieldSchema,
)
from argilla.client.feedback.schemas.types import AllowedQuestionTypes
from argilla.client.feedback.schemas.types import AllowedFieldTypes, AllowedQuestionTypes
from argilla.client.feedback.training.schemas import (
TrainingTaskForChatCompletion,
TrainingTaskForDPO,
@@ -43,7 +43,6 @@
from datasets import Dataset

from argilla.client.feedback.schemas.types import (
AllowedFieldTypes,
AllowedRemoteFieldTypes,
AllowedRemoteQuestionTypes,
)
@@ -58,8 +57,8 @@ class FeedbackDatasetBase(ABC, HuggingFaceDatasetMixin):
def __init__(
self,
*,
fields: Union[List["AllowedFieldTypes"], List["AllowedRemoteFieldTypes"]],
questions: Union[List["AllowedQuestionTypes"], List["AllowedRemoteQuestionTypes"]],
fields: Union[List[AllowedFieldTypes], List["AllowedRemoteFieldTypes"]],
questions: Union[List[AllowedQuestionTypes], List["AllowedRemoteQuestionTypes"]],
guidelines: Optional[str] = None,
) -> None:
"""Initializes a `FeedbackDatasetBase` instance locally.
@@ -84,17 +83,21 @@ def __init__(
any_required = False
unique_names = set()
for field in fields:
if not isinstance(field, FieldSchema):
raise TypeError(f"Expected `fields` to be a list of `FieldSchema`, got {type(field)} instead.")
if not isinstance(field, AllowedFieldTypes):
raise TypeError(
f"Expected `fields` to be a list of `{AllowedFieldTypes.__name__}`, got {type(field)} instead."
)
if field.name in unique_names:
raise ValueError(f"Expected `fields` to have unique names, got {field.name} twice instead.")
unique_names.add(field.name)
if not any_required and field.required:
any_required = True

if not any_required:
raise ValueError("At least one `FieldSchema` in `fields` must be required (`required=True`).")
raise ValueError("At least one field in `fields` must be required (`required=True`).")

self._fields = fields
self._fields_schema = None
self._fields_schema = generate_pydantic_schema(self.fields)

if not isinstance(questions, list):
raise TypeError(f"Expected `questions` to be a list, got {type(questions)} instead.")
@@ -113,8 +116,10 @@ def __init__(
unique_names.add(question.name)
if not any_required and question.required:
any_required = True

if not any_required:
raise ValueError("At least one question in `questions` must be required (`required=True`).")

self._questions = questions

if guidelines is not None:
@@ -126,6 +131,7 @@ def __init__(
raise ValueError(
"Expected `guidelines` to be either None (default) or a non-empty string, minimum length is 1."
)

self._guidelines = guidelines

@property
@@ -140,11 +146,11 @@ def guidelines(self) -> str:
return self._guidelines

@property
def fields(self) -> Union[List["AllowedFieldTypes"], List["AllowedRemoteFieldTypes"]]:
def fields(self) -> Union[List[AllowedFieldTypes], List["AllowedRemoteFieldTypes"]]:
"""Returns the fields that define the schema of the records in the dataset."""
return self._fields

def field_by_name(self, name: str) -> Union["AllowedFieldTypes", "AllowedRemoteFieldTypes"]:
def field_by_name(self, name: str) -> Union[AllowedFieldTypes, "AllowedRemoteFieldTypes"]:
"""Returns the field by name if it exists. Otherwise a `ValueError` is raised.
Args:
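
The field checks in `FeedbackDatasetBase.__init__` (type check, unique names, at least one required field) can be sketched without Argilla installed. `ToyField` is a stand-in dataclass and `validate_fields` a hypothetical function; the real code checks against `AllowedFieldTypes` and runs inside the initializer.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class ToyField:
    """Stand-in for a field schema; only the attributes the checks need."""

    name: str
    required: bool = True


def validate_fields(fields: List[ToyField]) -> None:
    """Hypothetical helper mirroring the checks shown in the diff."""
    if not isinstance(fields, list):
        raise TypeError(f"Expected `fields` to be a list, got {type(fields)} instead.")
    any_required = False
    unique_names = set()
    for field in fields:
        if not isinstance(field, ToyField):
            raise TypeError(f"Expected `fields` to be a list of `ToyField`, got {type(field)} instead.")
        if field.name in unique_names:
            raise ValueError(f"Expected `fields` to have unique names, got {field.name} twice instead.")
        unique_names.add(field.name)
        if not any_required and field.required:
            any_required = True
    if not any_required:
        raise ValueError("At least one field in `fields` must be required (`required=True`).")


# Passes: names are unique and one field is required.
validate_fields([ToyField("prompt"), ToyField("context", required=False)])
```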
6 changes: 4 additions & 2 deletions src/argilla/client/feedback/dataset/local.py
@@ -17,18 +17,20 @@
from argilla.client.feedback.constants import FETCHING_BATCH_SIZE
from argilla.client.feedback.dataset.base import FeedbackDatasetBase
from argilla.client.feedback.dataset.mixins import ArgillaMixin, UnificationMixin
from argilla.client.feedback.schemas.fields import TextField
from argilla.client.feedback.schemas.types import AllowedQuestionTypes

if TYPE_CHECKING:
from argilla.client.feedback.schemas.records import FeedbackRecord
from argilla.client.feedback.schemas.types import AllowedFieldTypes, AllowedQuestionTypes
from argilla.client.feedback.schemas.types import AllowedFieldTypes


class FeedbackDataset(FeedbackDatasetBase, ArgillaMixin, UnificationMixin):
def __init__(
self,
*,
fields: List["AllowedFieldTypes"],
questions: List["AllowedQuestionTypes"],
questions: List[AllowedQuestionTypes],
guidelines: Optional[str] = None,
) -> None:
"""Initializes a `FeedbackDataset` instance locally.
8 changes: 5 additions & 3 deletions src/argilla/client/feedback/dataset/mixins.py
@@ -175,13 +175,15 @@ def push_to_argilla(
except Exception as e:
raise Exception(f"Failed while creating the `FeedbackDataset` in Argilla with exception: {e}") from e

fields = self.__add_fields(client=httpx_client, id=argilla_id)
# TODO(alvarobartt): re-use ArgillaMixin components when applicable
self.__add_fields(client=httpx_client, id=argilla_id)
fields = self.__get_fields(client=httpx_client, id=argilla_id)

questions = self.__add_questions(client=httpx_client, id=argilla_id)
self.__add_questions(client=httpx_client, id=argilla_id)
questions = self.__get_questions(client=httpx_client, id=argilla_id)
question_name_to_id = {question.name: question.id for question in questions}

self.__publish_dataset(client=httpx_client, id=argilla_id)

self.__push_records(
client=httpx_client, id=argilla_id, show_progress=show_progress, question_name_to_id=question_name_to_id
)
4 changes: 1 addition & 3 deletions src/argilla/client/feedback/dataset/remote/base.py
@@ -134,9 +134,7 @@ def __init__(
TypeError: if `guidelines` is not None and not a string.
ValueError: if `guidelines` is an empty string.
"""
self._fields = fields
self._questions = questions
self._guidelines = guidelines
super().__init__(fields=fields, questions=questions, guidelines=guidelines)

self._client = client # Required to be able to use `allowed_for_roles` decorator
self._id = id
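
Why delegating to `super().__init__` matters can be shown with toy classes. This is a guess at the failure mode based on the PR description: a subclass that copies attributes by hand never creates derived attributes such as `_fields_schema`, so any method relying on them fails. All class names here are hypothetical, and a plain dict stands in for the schema that the real code builds with `generate_pydantic_schema`.

```python
from typing import Dict, List


class ToyBase:
    def __init__(self, *, fields: List[str]) -> None:
        self._fields = fields
        # Derived attribute that other methods rely on (stand-in for `_fields_schema`).
        self._fields_schema: Dict[str, type] = {name: str for name in fields}

    def add_record(self, record: Dict[str, str]) -> None:
        unknown = set(record) - set(self._fields_schema)
        if unknown:
            raise ValueError(f"Unexpected fields: {sorted(unknown)}")


class ToyRemoteBefore(ToyBase):
    def __init__(self, *, fields: List[str]) -> None:
        # Copies the attribute by hand and skips the base initializer,
        # so `_fields_schema` is never created.
        self._fields = fields


class ToyRemoteAfter(ToyBase):
    def __init__(self, *, fields: List[str]) -> None:
        # Delegating keeps every base attribute in sync.
        super().__init__(fields=fields)


ToyRemoteAfter(fields=["text"]).add_record({"text": "hello"})  # works

try:
    ToyRemoteBefore(fields=["text"]).add_record({"text": "hello"})
except AttributeError as error:
    print(f"Fails without the base initializer: {error}")
```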
4 changes: 2 additions & 2 deletions src/argilla/client/feedback/dataset/remote/dataset.py
@@ -24,14 +24,14 @@
from argilla.client.feedback.schemas.remote.records import RemoteFeedbackRecord
from argilla.client.sdk.users.models import UserRole
from argilla.client.sdk.v1.datasets import api as datasets_api_v1
from argilla.client.sdk.v1.datasets.models import FeedbackResponseStatusFilter
from argilla.client.utils import allowed_for_roles

if TYPE_CHECKING:
from uuid import UUID

import httpx

from argilla.client.feedback.schemas.enums import ResponseStatusFilter
from argilla.client.feedback.schemas.types import AllowedRemoteFieldTypes, AllowedRemoteQuestionTypes
from argilla.client.sdk.v1.datasets.models import FeedbackRecordsModel
from argilla.client.workspaces import Workspace
@@ -145,7 +145,7 @@ def __init__(
)

def filter_by(
self, response_status: Union[FeedbackResponseStatusFilter, List[FeedbackResponseStatusFilter]]
self, response_status: Union["ResponseStatusFilter", List["ResponseStatusFilter"]]
) -> FilteredRemoteFeedbackDataset:
"""Filters the current `RemoteFeedbackDataset` based on the `response_status` of
the responses of the records in Argilla. This method creates a new class instance
13 changes: 13 additions & 0 deletions src/argilla/client/feedback/schemas/enums.py
@@ -25,3 +25,16 @@ class QuestionTypes(str, Enum):
label_selection = "label_selection"
multi_label_selection = "multi_label_selection"
ranking = "ranking"


class ResponseStatus(str, Enum):
draft = "draft"
submitted = "submitted"
discarded = "discarded"


class ResponseStatusFilter(str, Enum):
draft = "draft"
submitted = "submitted"
discarded = "discarded"
missing = "missing"
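
Since `filter_by` accepts either one `ResponseStatusFilter` or a list of them, a caller-side normalization step might look like the sketch below. The enum is copied from the diff; `normalize_response_status` is a hypothetical helper, not part of Argilla. Because the enum mixes in `str`, plain strings can be coerced through the enum constructor.

```python
from enum import Enum
from typing import List, Union


class ResponseStatusFilter(str, Enum):
    draft = "draft"
    submitted = "submitted"
    discarded = "discarded"
    missing = "missing"


StatusLike = Union[ResponseStatusFilter, str]


def normalize_response_status(
    response_status: Union[StatusLike, List[StatusLike]],
) -> List[ResponseStatusFilter]:
    """Hypothetical helper: coerce one status or a list of them into enum members."""
    if not isinstance(response_status, list):
        response_status = [response_status]
    # `ResponseStatusFilter(...)` raises `ValueError` for unknown values.
    return [ResponseStatusFilter(status) for status in response_status]


single = normalize_response_status("missing")
mixed = normalize_response_status([ResponseStatusFilter.draft, "submitted"])
```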
9 changes: 5 additions & 4 deletions src/argilla/client/feedback/schemas/fields.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import abstractproperty
from abc import ABC, abstractmethod
from typing import Any, Dict, Literal, Optional

from pydantic import BaseModel, Extra, Field, validator
@@ -21,7 +21,7 @@
from argilla.client.feedback.schemas.validators import title_must_have_value


class FieldSchema(BaseModel):
class FieldSchema(BaseModel, ABC):
"""Base schema for the `FeedbackDataset` fields.
Args:
@@ -52,12 +52,13 @@ class Config:
extra = Extra.forbid
exclude = {"type"}

@abstractproperty
@property
@abstractmethod
def server_settings(self) -> Dict[str, Any]:
"""Abstract property that should be implemented by the classes that inherit from
this one, and that will be used to create the `FeedbackDataset` in Argilla.
"""
raise NotImplementedError
...

def to_server_payload(self) -> Dict[str, Any]:
"""Method that will be used to create the payload that will be sent to Argilla
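
The move from `abstractproperty` (deprecated since Python 3.3) to stacking `@property` on `@abstractmethod` can be sketched with plain `ABC` classes. `ToySchema` and `ToyTextField` are hypothetical stand-ins that leave out `pydantic.BaseModel`, and the returned settings dict is illustrative only.

```python
from abc import ABC, abstractmethod
from typing import Any, Dict


class ToySchema(ABC):
    @property
    @abstractmethod
    def server_settings(self) -> Dict[str, Any]:
        """Settings to send to the server; subclasses must provide them."""
        ...


class ToyTextField(ToySchema):
    def __init__(self, use_markdown: bool = False) -> None:
        self.use_markdown = use_markdown

    @property
    def server_settings(self) -> Dict[str, Any]:
        return {"type": "text", "use_markdown": self.use_markdown}


settings = ToyTextField(use_markdown=True).server_settings

try:
    ToySchema()  # abstract property not implemented
except TypeError as error:
    print(f"Cannot instantiate the base schema: {error}")
```

With `ABC` in the bases, a missing implementation is reported at instantiation time rather than when the property is first read, which is when the earlier `raise NotImplementedError` body would have been reached.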
9 changes: 5 additions & 4 deletions src/argilla/client/feedback/schemas/questions.py
@@ -13,7 +13,7 @@
# limitations under the License.

import warnings
from abc import abstractproperty
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Literal, Optional, Union

from pydantic import BaseModel, Extra, Field, conint, conlist, root_validator, validator
@@ -23,7 +23,7 @@
from argilla.client.feedback.schemas.validators import title_must_have_value


class QuestionSchema(BaseModel):
class QuestionSchema(BaseModel, ABC):
"""Base schema for the `FeedbackDataset` questions. Which means that all the questions
in the dataset will have at least these fields.
@@ -58,12 +58,13 @@ class Config:
extra = Extra.forbid
exclude = {"type"}

@abstractproperty
@property
@abstractmethod
def server_settings(self) -> Dict[str, Any]:
"""Abstract property that should be implemented by the classes that inherit from
this one, and that will be used to create the `FeedbackDataset` in Argilla.
"""
raise NotImplementedError
...

def to_server_payload(self) -> Dict[str, Any]:
"""Method that will be used to create the payload that will be sent to Argilla
13 changes: 3 additions & 10 deletions src/argilla/client/feedback/schemas/records.py
@@ -13,12 +13,13 @@
# limitations under the License.

import warnings
from enum import Enum
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Tuple, Union
from uuid import UUID

from pydantic import BaseModel, Extra, Field, PrivateAttr, StrictInt, StrictStr, conint, validator

from argilla.client.feedback.schemas.enums import ResponseStatus

if TYPE_CHECKING:
from argilla.client.feedback.unification import UnifiedValueSchema

@@ -46,12 +47,6 @@ class ValueSchema(BaseModel):
value: Union[StrictStr, StrictInt, List[str], List[RankingValueSchema]]


class ResponseStatus(str, Enum):
draft = "draft"
submitted = "submitted"
discarded = "discarded"


class ResponseSchema(BaseModel):
"""Schema for the `FeedbackRecord` response.
@@ -103,9 +98,7 @@ class SuggestionSchema(BaseModel):
"""Schema for the suggestions for the questions related to the record.
Args:
question_id: ID of the question in Argilla. Defaults to None, and is automatically
fulfilled internally once the question is pushed to Argilla.
question_name: name of the question.
question_name: name of the question in the `FeedbackDataset`.
type: type of the question. Defaults to None. Possible values are `model` or `human`.
score: score of the suggestion. Defaults to None.
value: value of the suggestion, which should match the type of the question.
4 changes: 2 additions & 2 deletions src/argilla/client/feedback/schemas/remote/shared.py
@@ -12,15 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import abstractmethod
from abc import ABC, abstractmethod
from typing import Optional, Type
from uuid import UUID

import httpx
from pydantic import BaseModel


class RemoteSchema(BaseModel):
class RemoteSchema(BaseModel, ABC):
id: Optional[UUID] = None
client: Optional[httpx.Client] = None
