From 0c5577239e3e5205f67f5d0bd9f64bfad20bb538 Mon Sep 17 00:00:00 2001 From: gabrielmbmb Date: Mon, 25 Sep 2023 17:30:29 +0200 Subject: [PATCH] feat: add creating metadata properties --- src/argilla/client/feedback/dataset/mixins.py | 30 ++++++++++++++++++- .../client/feedback/dataset/remote/base.py | 15 +++++++++- .../client/feedback/dataset/remote/dataset.py | 10 ++++++- src/argilla/client/sdk/v1/datasets/api.py | 29 ++++++++++++++++++ src/argilla/client/sdk/v1/datasets/models.py | 9 ++++++ .../client/sdk/v1/test_datasets.py | 20 +++++++++++++ 6 files changed, 110 insertions(+), 3 deletions(-) diff --git a/src/argilla/client/feedback/dataset/mixins.py b/src/argilla/client/feedback/dataset/mixins.py index 1fb6f06443..09152c8351 100644 --- a/src/argilla/client/feedback/dataset/mixins.py +++ b/src/argilla/client/feedback/dataset/mixins.py @@ -52,7 +52,11 @@ from argilla.client.client import Argilla as ArgillaClient from argilla.client.feedback.dataset.local import FeedbackDataset - from argilla.client.feedback.schemas.types import AllowedRemoteFieldTypes, AllowedRemoteQuestionTypes + from argilla.client.feedback.schemas.types import ( + AllowedRemoteFieldTypes, + AllowedRemoteMetadataPropertyTypes, + AllowedRemoteQuestionTypes, + ) from argilla.client.sdk.v1.datasets.models import FeedbackDatasetModel @@ -97,6 +101,27 @@ def __add_questions( ) from e return questions + def __add_metadata_properties( + self: "FeedbackDataset", client: "httpx.Client", id: UUID + ) -> Union[List["AllowedRemoteMetadataPropertyTypes"], None]: + if not self._metadata_properties: + return None + + metadata_properties = [] + for metadata_property in self._metadata_properties: + try: + new_metadata_property = datasets_api_v1.add_metadata_property( + client=client, id=id, metadata_property=metadata_property.to_server_payload() + ).parsed + metadata_properties.append(new_metadata_property) + except Exception as e: + self.__delete_dataset(client=client, id=id) + raise Exception( + f"Failed while adding the metadata property '{metadata_property.name}' to the `FeedbackDataset` in" + f" Argilla with exception: {e}" + ) from e + return metadata_properties + def __publish_dataset(self: "FeedbackDataset", client: "httpx.Client", id: UUID) -> None: try: datasets_api_v1.publish_dataset(client=client, id=id) @@ -180,6 +205,8 @@ def push_to_argilla( questions = self.__add_questions(client=httpx_client, id=argilla_id) question_name_to_id = {question.name: question.id for question in questions} + metadata_properties = self.__add_metadata_properties(client=httpx_client, id=argilla_id) + self.__publish_dataset(client=httpx_client, id=argilla_id) self.__push_records( @@ -195,6 +222,7 @@ def push_to_argilla( updated_at=new_dataset.updated_at, fields=fields, questions=questions, + metadata_properties=metadata_properties, guidelines=self.guidelines, ) diff --git a/src/argilla/client/feedback/dataset/remote/base.py b/src/argilla/client/feedback/dataset/remote/base.py index 52d510d8aa..45ee4d33ae 100644 --- a/src/argilla/client/feedback/dataset/remote/base.py +++ b/src/argilla/client/feedback/dataset/remote/base.py @@ -30,7 +30,11 @@ from argilla.client.feedback.dataset.local import FeedbackDataset from argilla.client.feedback.schemas.records import FeedbackRecord - from argilla.client.feedback.schemas.types import AllowedRemoteFieldTypes, AllowedRemoteQuestionTypes + from argilla.client.feedback.schemas.types import ( + AllowedRemoteFieldTypes, + AllowedRemoteMetadataPropertyTypes, + AllowedRemoteQuestionTypes, + ) from argilla.client.sdk.v1.datasets.models import FeedbackRecordsModel from argilla.client.workspaces import Workspace @@ -105,7 +109,9 @@ def __init__( updated_at: datetime, fields: List["AllowedRemoteFieldTypes"], questions: List["AllowedRemoteQuestionTypes"], + metadata_properties: Optional[List["AllowedRemoteMetadataPropertyTypes"]] = None, guidelines: Optional[str] = None, + allow_extra_metadata: bool = False, **kwargs: Any, ) -> None: """Initializes a `RemoteFeedbackDataset` instance in Argilla. @@ -123,7 +129,12 @@ def __init__( updated_at: contains the datetime when the dataset was last updated in Argilla. fields: contains the fields that will define the schema of the records in the dataset. questions: contains the questions that will be used to annotate the dataset. + metadata_properties: contains the metadata properties that will be indexed + and could be used to filter the dataset. Defaults to `None`. guidelines: contains the guidelines for annotating the dataset. Defaults to `None`. + extra_metadata_allowed: whether to allow to include metadata properties that + have not been defined in the `metadata` argument, and thus will not be + indexed. Defaults to `True`. Raises: TypeError: if `fields` is not a list of `FieldSchema`. @@ -136,7 +147,9 @@ def __init__( """ self._fields = fields self._questions = questions + self._metadata_properties = metadata_properties self._guidelines = guidelines + self._allow_extra_metadata = allow_extra_metadata self._client = client # Required to be able to use `allowed_for_roles` decorator self._id = id diff --git a/src/argilla/client/feedback/dataset/remote/dataset.py b/src/argilla/client/feedback/dataset/remote/dataset.py index 8c0cde06ee..4210ef9042 100644 --- a/src/argilla/client/feedback/dataset/remote/dataset.py +++ b/src/argilla/client/feedback/dataset/remote/dataset.py @@ -32,7 +32,11 @@ import httpx - from argilla.client.feedback.schemas.types import AllowedRemoteFieldTypes, AllowedRemoteQuestionTypes + from argilla.client.feedback.schemas.types import ( + AllowedRemoteFieldTypes, + AllowedRemoteMetadataPropertyTypes, + AllowedRemoteQuestionTypes, + ) from argilla.client.sdk.v1.datasets.models import FeedbackRecordsModel from argilla.client.workspaces import Workspace @@ -130,7 +134,9 @@ def __init__( updated_at: datetime, fields: List["AllowedRemoteFieldTypes"], questions: List["AllowedRemoteQuestionTypes"], + metadata_properties: Optional[List["AllowedRemoteMetadataPropertyTypes"]] = None, guidelines: Optional[str] = None, + allow_extra_metadata: bool = False, ) -> None: super().__init__( client=client, @@ -141,7 +147,9 @@ def __init__( updated_at=updated_at, fields=fields, questions=questions, + metadata_properties=metadata_properties, guidelines=guidelines, + allow_extra_metadata=allow_extra_metadata, ) def filter_by( diff --git a/src/argilla/client/sdk/v1/datasets/api.py b/src/argilla/client/sdk/v1/datasets/api.py index e17a482ec3..9772f6bc74 100644 --- a/src/argilla/client/sdk/v1/datasets/api.py +++ b/src/argilla/client/sdk/v1/datasets/api.py @@ -23,6 +23,7 @@ from argilla.client.sdk.v1.datasets.models import ( FeedbackDatasetModel, FeedbackFieldModel, + FeedbackMetadataPropertyModel, FeedbackMetricsModel, FeedbackQuestionModel, FeedbackRecordsModel, @@ -354,6 +355,8 @@ def add_question( Args: client: the authenticated Argilla client to be used to send the request to the API. + id: the id of the dataset to add the question to. + question: the question to be added to the dataset. Returns: A `Response` object containing a `parsed` attribute with the parsed response if the @@ -370,6 +373,32 @@ def add_question( return handle_response_error(response) +def add_metadata_property( + client: httpx.Client, id: UUID, metadata_property: Dict[str, Any] +) -> Response[Union[FeedbackMetadataPropertyModel, ErrorMessage, HTTPValidationError]]: + """Sends a POST request to `/api/v1/datasets/{id}/metadata-properties` endpoint to + add a metadata property to the `FeedbackDataset`. + + Args: + client: the authenticated Argilla client to be used to send the request to the API. + id: the id of the dataset to add the metadata property to. + metadata_property: the metadata property to be added to the dataset. + + Returns: + A `Response` object containing a `parsed` attribute with the parsed response if the + request was successful, which is a `FeedbackMetadataPropertyModel`. + """ + url = f"/api/v1/datasets/{id}/metadata-properties" + + response = client.post(url=url, json=metadata_property) + + if response.status_code == 201: + response_obj = Response.from_httpx_response(response) + response_obj.parsed = FeedbackMetadataPropertyModel(**response.json()) + return response_obj + return handle_response_error(response) + + def set_suggestion( client: httpx.Client, record_id: UUID, diff --git a/src/argilla/client/sdk/v1/datasets/models.py b/src/argilla/client/sdk/v1/datasets/models.py index d6b3b6127b..ed8acd57c6 100644 --- a/src/argilla/client/sdk/v1/datasets/models.py +++ b/src/argilla/client/sdk/v1/datasets/models.py @@ -107,6 +107,15 @@ class FeedbackQuestionModel(BaseModel): updated_at: datetime +class FeedbackMetadataPropertyModel(BaseModel): + id: UUID + name: str + description: Optional[str] = None + settings: Dict[str, Any] + inserted_at: datetime + updated_at: datetime + + class FeedbackRecordsMetricsModel(BaseModel): count: int diff --git a/tests/integration/client/sdk/v1/test_datasets.py b/tests/integration/client/sdk/v1/test_datasets.py index b2abde784b..6aa5b84db9 100644 --- a/tests/integration/client/sdk/v1/test_datasets.py +++ b/tests/integration/client/sdk/v1/test_datasets.py @@ -16,6 +16,7 @@ from argilla.client.client import Argilla from argilla.client.sdk.v1.datasets.api import ( add_field, + add_metadata_property, add_question, add_records, create_dataset, @@ -217,6 +218,25 @@ async def test_add_question(role: UserRole) -> None: assert response.status_code == 201 +@pytest.mark.parametrize("role", [UserRole.admin, UserRole.owner]) +@pytest.mark.asyncio +async def test_add_metadata_property(role: UserRole) -> None: + dataset = await DatasetFactory.create() + user = await UserFactory.create(role=role, workspaces=[dataset.workspace]) + + api = Argilla(api_key=user.api_key, workspace=dataset.workspace.name) + response = add_metadata_property( + client=api.client.httpx, + id=dataset.id, + metadata_property={ + "name": "test_metadata_property", + "description": "test_description", + "settings": {"type": "terms", "values": ["a", "b", "c"]}, + }, + ) + assert response.status_code == 201 + + @pytest.mark.parametrize("role", [UserRole.admin, UserRole.owner, UserRole.annotator]) @pytest.mark.asyncio async def test_get_questions(role: UserRole) -> None: