Skip to content

Commit

Permalink
feat: add creating metadata properties
Browse files Browse the repository at this point in the history
  • Loading branch information
gabrielmbmb committed Sep 25, 2023
1 parent ac5c987 commit 0c55772
Show file tree
Hide file tree
Showing 6 changed files with 110 additions and 3 deletions.
30 changes: 29 additions & 1 deletion src/argilla/client/feedback/dataset/mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,11 @@

from argilla.client.client import Argilla as ArgillaClient
from argilla.client.feedback.dataset.local import FeedbackDataset
from argilla.client.feedback.schemas.types import AllowedRemoteFieldTypes, AllowedRemoteQuestionTypes
from argilla.client.feedback.schemas.types import (
AllowedRemoteFieldTypes,
AllowedRemoteMetadataPropertyTypes,
AllowedRemoteQuestionTypes,
)
from argilla.client.sdk.v1.datasets.models import FeedbackDatasetModel


Expand Down Expand Up @@ -97,6 +101,27 @@ def __add_questions(
) from e
return questions

def __add_metadata_properties(
self: "FeedbackDataset", client: "httpx.Client", id: UUID
) -> Union[List["AllowedRemoteMetadataPropertyTypes"], None]:
if not self._metadata_properties:
return None

metadata_properties = []
for metadata_property in self._metadata_properties:
try:
new_metadata_property = datasets_api_v1.add_metadata_property(
client=client, id=id, metadata_property=metadata_property.to_server_payload()
).parsed
metadata_properties.append(new_metadata_property)
except Exception as e:
self.__delete_dataset(client=client, id=id)
raise Exception(
f"Failed while adding the metadata property '{metadata_property.name}' to the `FeedbackDataset` in"
f" Argilla with exception: {e}"
) from e
return metadata_properties

def __publish_dataset(self: "FeedbackDataset", client: "httpx.Client", id: UUID) -> None:
try:
datasets_api_v1.publish_dataset(client=client, id=id)
Expand Down Expand Up @@ -180,6 +205,8 @@ def push_to_argilla(
questions = self.__add_questions(client=httpx_client, id=argilla_id)
question_name_to_id = {question.name: question.id for question in questions}

metadata_properties = self.__add_metadata_properties(client=httpx_client, id=argilla_id)

self.__publish_dataset(client=httpx_client, id=argilla_id)

self.__push_records(
Expand All @@ -195,6 +222,7 @@ def push_to_argilla(
updated_at=new_dataset.updated_at,
fields=fields,
questions=questions,
metadata_properties=metadata_properties,
guidelines=self.guidelines,
)

Expand Down
15 changes: 14 additions & 1 deletion src/argilla/client/feedback/dataset/remote/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,11 @@

from argilla.client.feedback.dataset.local import FeedbackDataset
from argilla.client.feedback.schemas.records import FeedbackRecord
from argilla.client.feedback.schemas.types import AllowedRemoteFieldTypes, AllowedRemoteQuestionTypes
from argilla.client.feedback.schemas.types import (
AllowedRemoteFieldTypes,
AllowedRemoteMetadataPropertyTypes,
AllowedRemoteQuestionTypes,
)
from argilla.client.sdk.v1.datasets.models import FeedbackRecordsModel
from argilla.client.workspaces import Workspace

Expand Down Expand Up @@ -105,7 +109,9 @@ def __init__(
updated_at: datetime,
fields: List["AllowedRemoteFieldTypes"],
questions: List["AllowedRemoteQuestionTypes"],
metadata_properties: Optional[List["AllowedRemoteMetadataPropertyTypes"]] = None,
guidelines: Optional[str] = None,
allow_extra_metadata: bool = False,
**kwargs: Any,
) -> None:
"""Initializes a `RemoteFeedbackDataset` instance in Argilla.
Expand All @@ -123,7 +129,12 @@ def __init__(
updated_at: contains the datetime when the dataset was last updated in Argilla.
fields: contains the fields that will define the schema of the records in the dataset.
questions: contains the questions that will be used to annotate the dataset.
metadata_properties: contains the metadata properties that will be indexed
and could be used to filter the dataset. Defaults to `None`.
guidelines: contains the guidelines for annotating the dataset. Defaults to `None`.
extra_metadata_allowed: whether to allow to include metadata properties that
have not been defined in the `metadata` argument, and thus will not be
indexed. Defaults to `True`.
Raises:
TypeError: if `fields` is not a list of `FieldSchema`.
Expand All @@ -136,7 +147,9 @@ def __init__(
"""
self._fields = fields
self._questions = questions
self._metadata_properties = metadata_properties
self._guidelines = guidelines
self._allow_extra_metadata = allow_extra_metadata

self._client = client # Required to be able to use `allowed_for_roles` decorator
self._id = id
Expand Down
10 changes: 9 additions & 1 deletion src/argilla/client/feedback/dataset/remote/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,11 @@

import httpx

from argilla.client.feedback.schemas.types import AllowedRemoteFieldTypes, AllowedRemoteQuestionTypes
from argilla.client.feedback.schemas.types import (
AllowedRemoteFieldTypes,
AllowedRemoteMetadataPropertyTypes,
AllowedRemoteQuestionTypes,
)
from argilla.client.sdk.v1.datasets.models import FeedbackRecordsModel
from argilla.client.workspaces import Workspace

Expand Down Expand Up @@ -130,7 +134,9 @@ def __init__(
updated_at: datetime,
fields: List["AllowedRemoteFieldTypes"],
questions: List["AllowedRemoteQuestionTypes"],
metadata_properties: Optional[List["AllowedRemoteMetadataPropertyTypes"]] = None,
guidelines: Optional[str] = None,
allow_extra_metadata: bool = False,
) -> None:
super().__init__(
client=client,
Expand All @@ -141,7 +147,9 @@ def __init__(
updated_at=updated_at,
fields=fields,
questions=questions,
metadata_properties=metadata_properties,
guidelines=guidelines,
allow_extra_metadata=allow_extra_metadata,
)

def filter_by(
Expand Down
29 changes: 29 additions & 0 deletions src/argilla/client/sdk/v1/datasets/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from argilla.client.sdk.v1.datasets.models import (
FeedbackDatasetModel,
FeedbackFieldModel,
FeedbackMetadataPropertyModel,
FeedbackMetricsModel,
FeedbackQuestionModel,
FeedbackRecordsModel,
Expand Down Expand Up @@ -354,6 +355,8 @@ def add_question(
Args:
client: the authenticated Argilla client to be used to send the request to the API.
id: the id of the dataset to add the question to.
question: the question to be added to the dataset.
Returns:
A `Response` object containing a `parsed` attribute with the parsed response if the
Expand All @@ -370,6 +373,32 @@ def add_question(
return handle_response_error(response)


def add_metadata_property(
client: httpx.Client, id: UUID, metadata_property: Dict[str, Any]
) -> Response[Union[FeedbackMetadataPropertyModel, ErrorMessage, HTTPValidationError]]:
"""Sends a POST request to `/api/v1/datasets/{id}/metadata-properties` endpoint to
add a metadata property to the `FeedbackDataset`.
Args:
client: the authenticated Argilla client to be used to send the request to the API.
id: the id of the dataset to add the metadata property to.
metadata_property: the metadata property to be added to the dataset.
Returns:
A `Response` object containing a `parsed` attribute with the parsed response if the
request was successful, which is a `FeedbackMetadataPropertyModel`.
"""
url = f"/api/v1/datasets/{id}/metadata-properties"

response = client.post(url=url, json=metadata_property)

if response.status_code == 201:
response_obj = Response.from_httpx_response(response)
response_obj.parsed = FeedbackMetadataPropertyModel(**response.json())
return response_obj
return handle_response_error(response)


def set_suggestion(
client: httpx.Client,
record_id: UUID,
Expand Down
9 changes: 9 additions & 0 deletions src/argilla/client/sdk/v1/datasets/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,15 @@ class FeedbackQuestionModel(BaseModel):
updated_at: datetime


class FeedbackMetadataPropertyModel(BaseModel):
id: UUID
name: str
description: Optional[str] = None
settings: Dict[str, Any]
inserted_at: datetime
updated_at: datetime


class FeedbackRecordsMetricsModel(BaseModel):
count: int

Expand Down
20 changes: 20 additions & 0 deletions tests/integration/client/sdk/v1/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from argilla.client.client import Argilla
from argilla.client.sdk.v1.datasets.api import (
add_field,
add_metadata_property,
add_question,
add_records,
create_dataset,
Expand Down Expand Up @@ -217,6 +218,25 @@ async def test_add_question(role: UserRole) -> None:
assert response.status_code == 201


@pytest.mark.parametrize("role", [UserRole.admin, UserRole.owner])
@pytest.mark.asyncio
async def test_add_metadata_property(role: UserRole) -> None:
dataset = await DatasetFactory.create()
user = await UserFactory.create(role=role, workspaces=[dataset.workspace])

api = Argilla(api_key=user.api_key, workspace=dataset.workspace.name)
response = add_metadata_property(
client=api.client.httpx,
id=dataset.id,
metadata_property={
"name": "test_metadata_property",
"description": "test_description",
"settings": {"type": "terms", "values": ["a", "b", "c"]},
},
)
assert response.status_code == 201


@pytest.mark.parametrize("role", [UserRole.admin, UserRole.owner, UserRole.annotator])
@pytest.mark.asyncio
async def test_get_questions(role: UserRole) -> None:
Expand Down

0 comments on commit 0c55772

Please sign in to comment.