From 1ce3feba264ddf8932b5c8f355f9c0fe0eeee2af Mon Sep 17 00:00:00 2001 From: Paco Aranda Date: Fri, 13 Oct 2023 14:41:24 +0200 Subject: [PATCH 01/26] fix: Using utcnow datetime --- src/argilla/server/models/mixins.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/argilla/server/models/mixins.py b/src/argilla/server/models/mixins.py index 539c441c40..92eba5f240 100644 --- a/src/argilla/server/models/mixins.py +++ b/src/argilla/server/models/mixins.py @@ -20,6 +20,7 @@ from sqlalchemy.dialects.mysql import insert as mysql_insert from sqlalchemy.dialects.postgresql import insert as postgres_insert from sqlalchemy.dialects.sqlite import insert as sqlite_insert +from sqlalchemy.engine.default import DefaultExecutionContext from sqlalchemy.orm import Mapped, mapped_column from typing_extensions import Self @@ -158,6 +159,10 @@ async def save(self, db: "AsyncSession", autocommit: bool = True) -> Self: return self +def _default_inserted_at(context: DefaultExecutionContext) -> datetime: + return context.get_current_parameters()["inserted_at"] + + class TimestampMixin: - inserted_at: Mapped[datetime] = mapped_column(default=func.now()) - updated_at: Mapped[datetime] = mapped_column(default=func.now(), onupdate=func.now()) + inserted_at: Mapped[datetime] = mapped_column(default=datetime.utcnow) + updated_at: Mapped[datetime] = mapped_column(default=_default_inserted_at, onupdate=datetime.utcnow) From 86a3ea14fee4e67b78698a1d37c83110ffe0ebc7 Mon Sep 17 00:00:00 2001 From: Paco Aranda Date: Fri, 13 Oct 2023 14:41:40 +0200 Subject: [PATCH 02/26] tests: Remove date creation --- tests/factories.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/factories.py b/tests/factories.py index 103ed7cc95..4cae6a064e 100644 --- a/tests/factories.py +++ b/tests/factories.py @@ -215,8 +215,6 @@ class Meta: external_id = factory.Sequence(lambda n: f"external-id-{n}") dataset = factory.SubFactory(DatasetFactory) - inserted_at = factory.Sequence(lambda n: 
datetime.datetime.utcnow() + datetime.timedelta(seconds=n)) - class ResponseFactory(BaseFactory): class Meta: From 83bac6ffc0ef6b02e9a34dc68162ac10bb0b1cbf Mon Sep 17 00:00:00 2001 From: Paco Aranda Date: Sat, 14 Oct 2023 22:45:04 +0200 Subject: [PATCH 03/26] feat: Define `update_records` for base feedback dataset class Also, this class defines a generic type `R` for records. --- src/argilla/client/feedback/dataset/base.py | 25 ++++++++++++++++----- 1 file changed, 20 insertions(+), 5 deletions(-) diff --git a/src/argilla/client/feedback/dataset/base.py b/src/argilla/client/feedback/dataset/base.py index dc668c2473..477c9e47de 100644 --- a/src/argilla/client/feedback/dataset/base.py +++ b/src/argilla/client/feedback/dataset/base.py @@ -14,7 +14,7 @@ import logging from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union +from typing import Generic, Iterable, TYPE_CHECKING, Any, Dict, List, Literal, Optional, TypeVar, Union from pydantic import ValidationError @@ -55,8 +55,10 @@ _LOGGER = logging.getLogger(__name__) +R = TypeVar("R", bound=FeedbackRecord) -class FeedbackDatasetBase(ABC, HuggingFaceDatasetMixin): + +class FeedbackDatasetBase(ABC, HuggingFaceDatasetMixin, Generic[R]): """Base class with shared functionality for `FeedbackDataset` and `RemoteFeedbackDataset`.""" def __init__( @@ -166,10 +168,22 @@ def __init__( @property @abstractmethod - def records(self) -> Any: + def records(self) -> Iterable[R]: """Returns the records of the dataset.""" pass + @abstractmethod + def update_records(self, records: Union[R, List[R]]) -> None: + """Updates the records of the dataset. + + Args: + records: the records to update the dataset with. + + Raises: + ValueError: if the provided `records` are invalid. 
+ """ + pass + @property def guidelines(self) -> str: """Returns the guidelines for annotating the dataset.""" @@ -364,11 +378,12 @@ def _validate_records(self, records: List[FeedbackRecord]) -> None: def _parse_and_validate_records( self, - records: Union[FeedbackRecord, Dict[str, Any], List[Union[FeedbackRecord, Dict[str, Any]]]], - ) -> List[FeedbackRecord]: + records: Union[R, Dict[str, Any], List[Union[R, Dict[str, Any]]]], + ) -> List[R]: """Convenient method for calling `_parse_records` and `_validate_records` in sequence.""" records = self._parse_records(records) self._validate_records(records) + return records @requires_dependencies("datasets") From b10ea26c59a695e66812b4f1b1792c98d210850e Mon Sep 17 00:00:00 2001 From: Paco Aranda Date: Sat, 14 Oct 2023 22:45:48 +0200 Subject: [PATCH 04/26] refactor: Implement `update_records` method for local datasets The implementation will show a warning with an explicit message --- src/argilla/client/feedback/dataset/local.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/argilla/client/feedback/dataset/local.py b/src/argilla/client/feedback/dataset/local.py index 48c51c8abf..ebe0e61c9f 100644 --- a/src/argilla/client/feedback/dataset/local.py +++ b/src/argilla/client/feedback/dataset/local.py @@ -30,7 +30,7 @@ ) -class FeedbackDataset(FeedbackDatasetBase, ArgillaMixin, UnificationMixin): +class FeedbackDataset(FeedbackDatasetBase["FeedbackRecord"], ArgillaMixin, UnificationMixin): def __init__( self, *, @@ -127,6 +127,12 @@ def records(self) -> List["FeedbackRecord"]: """Returns the records in the dataset.""" return self._records + def update_records(self, records: Union["FeedbackRecord", List["FeedbackRecord"]]) -> None: + warnings.warn( + "`update_records` method only works for `FeedbackDataset` pushed to Argilla. " + "If your are working with local data, you can just iterate over the records and update them." 
+ ) + def __repr__(self) -> str: """Returns a string representation of the dataset.""" return f"" From bbd4236d4c4f2137ed59c7fd75e46a8619fd3318 Mon Sep 17 00:00:00 2001 From: Paco Aranda Date: Sat, 14 Oct 2023 22:50:08 +0200 Subject: [PATCH 05/26] feat: Implement `update_records` method based on `record.update` Also, the question (id -> name) and question (name -> id) maps are computed from the original dataset --- .../client/feedback/dataset/remote/dataset.py | 30 +++++++++++++++---- 1 file changed, 25 insertions(+), 5 deletions(-) diff --git a/src/argilla/client/feedback/dataset/remote/dataset.py b/src/argilla/client/feedback/dataset/remote/dataset.py index f1bc4b0093..49f7ab5c3e 100644 --- a/src/argilla/client/feedback/dataset/remote/dataset.py +++ b/src/argilla/client/feedback/dataset/remote/dataset.py @@ -66,10 +66,6 @@ def __init__( and/or attributes. """ self._dataset = dataset - # TODO: review why this is here ! - self._question_id_to_name = {question.id: question.name for question in self._dataset.questions} - self._question_name_to_id = {value: key for key, value in self._question_id_to_name.items()} - # TODO END if response_status and not isinstance(response_status, list): response_status = [response_status] @@ -106,6 +102,14 @@ def _client(self) -> "httpx.Client": """Returns the `httpx.Client` instance that will be used to send requests to Argilla.""" return self.dataset._client + @property + def _question_id_to_name(self) -> Dict["UUID", str]: + return self.dataset._question_id_to_name_id + + @property + def _question_name_to_id(self) -> Dict[str, "UUID"]: + return self.dataset._question_name_to_id + @allowed_for_roles(roles=[UserRole.owner, UserRole.admin]) def __len__(self) -> int: """Returns the number of records in the current `FeedbackDataset` in Argilla.""" @@ -233,7 +237,7 @@ def _create_from_dataset( ) -class RemoteFeedbackDataset(FeedbackDatasetBase): +class RemoteFeedbackDataset(FeedbackDatasetBase[RemoteFeedbackRecord]): # TODO: Call super 
method once the base init contains only commons init attributes def __init__( self, @@ -304,6 +308,14 @@ def records(self) -> RemoteFeedbackRecords: """ return self._records + def update_records(self, records: Union[RemoteFeedbackRecord, List[RemoteFeedbackRecord]]) -> None: + if not isinstance(records, list): + records = [records] + + # TODO: Use the batch version of endpoint once is implemented + for record in records: + record.update() + @property def id(self) -> "UUID": """Returns the ID of the dataset in Argilla.""" @@ -334,6 +346,14 @@ def updated_at(self) -> datetime: """Returns the datetime when the dataset was last updated in Argilla.""" return self._updated_at + @property + def _question_id_to_name_id(self) -> Dict["UUID", str]: + return {question.id: question.name for question in self._questions} + + @property + def _question_name_to_id(self) -> Dict[str, "UUID"]: + return {question.name: question.id for question in self._questions} + def __repr__(self) -> str: """Returns a string representation of the dataset.""" return ( From bcdbe9c248393d5665c493ef1107e277b632a6c2 Mon Sep 17 00:00:00 2001 From: Paco Aranda Date: Sat, 14 Oct 2023 23:03:31 +0200 Subject: [PATCH 06/26] chore: Fix `ArgillaRecordsMixin` method signatures --- src/argilla/client/feedback/dataset/remote/mixins.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/argilla/client/feedback/dataset/remote/mixins.py b/src/argilla/client/feedback/dataset/remote/mixins.py index bb74bf2819..4de22d99b0 100644 --- a/src/argilla/client/feedback/dataset/remote/mixins.py +++ b/src/argilla/client/feedback/dataset/remote/mixins.py @@ -21,13 +21,13 @@ from argilla.client.utils import allowed_for_roles if TYPE_CHECKING: - from argilla.client.feedback.dataset.remote.base import RemoteFeedbackRecordsBase + from argilla.client.feedback.dataset.remote.dataset import RemoteFeedbackRecords class ArgillaRecordsMixin: @allowed_for_roles(roles=[UserRole.owner, UserRole.admin]) def 
__getitem__( - self: "RemoteFeedbackRecordsBase", key: Union[slice, int] + self: "RemoteFeedbackRecords", key: Union[slice, int] ) -> Union["RemoteFeedbackRecord", List["RemoteFeedbackRecord"]]: """Returns the record(s) at the given index(es) from Argilla. @@ -103,7 +103,7 @@ def __getitem__( @allowed_for_roles(roles=[UserRole.owner, UserRole.admin]) def __iter__( - self: "RemoteFeedbackRecordsBase", + self: "RemoteFeedbackRecords", ) -> Iterator["RemoteFeedbackRecord"]: """Iterates over the `FeedbackRecord`s of the current `FeedbackDataset` in Argilla.""" current_batch = 0 From 1a464f3869262918a8c7cbcfe6e38d10d6ecee5f Mon Sep 17 00:00:00 2001 From: Paco Aranda Date: Sat, 14 Oct 2023 23:08:40 +0200 Subject: [PATCH 07/26] chore: Add some TODO reminders --- src/argilla/client/feedback/schemas/records.py | 1 + src/argilla/client/feedback/schemas/remote/shared.py | 2 ++ 2 files changed, 3 insertions(+) diff --git a/src/argilla/client/feedback/schemas/records.py b/src/argilla/client/feedback/schemas/records.py index f0f5910b1a..1f58002879 100644 --- a/src/argilla/client/feedback/schemas/records.py +++ b/src/argilla/client/feedback/schemas/records.py @@ -88,6 +88,7 @@ def to_server_payload(self) -> Dict[str, Any]: """Method that will be used to create the payload that will be sent to Argilla to create a `ResponseSchema` for a `FeedbackRecord`.""" return { + # UUID is not json serializable!!! 
"user_id": self.user_id, "values": {question_name: value.dict() for question_name, value in self.values.items()} if self.values is not None diff --git a/src/argilla/client/feedback/schemas/remote/shared.py b/src/argilla/client/feedback/schemas/remote/shared.py index 9622127d0e..0eba971988 100644 --- a/src/argilla/client/feedback/schemas/remote/shared.py +++ b/src/argilla/client/feedback/schemas/remote/shared.py @@ -21,6 +21,7 @@ class RemoteSchema(BaseModel, ABC): + # TODO(@alvarobartt): Review optional id configuration for remote schemas id: Optional[UUID] = None client: Optional[httpx.Client] = None @@ -30,6 +31,7 @@ def _client(self) -> Optional[httpx.Client]: return self.client class Config: + # TODO(@alvarobart) Not sure if we need this at this level allow_mutation = False arbitrary_types_allowed = True From 875f88a2d9e87fa74164d744f2a222e4c62e31a0 Mon Sep 17 00:00:00 2001 From: Paco Aranda Date: Sat, 14 Oct 2023 23:14:15 +0200 Subject: [PATCH 08/26] feat: `record.update` support record level update Records can be updated by assigning content and then call the `record.update` method. Suggestions are still supported, so users can update a record by passing the suggestions. But a more general way should be: ```python record.metadata.update({"new": "metadata"}) record.suggestions = (Suggestion....) 
record.update() ``` --- .../client/feedback/schemas/remote/records.py | 61 +++++++++++-------- 1 file changed, 36 insertions(+), 25 deletions(-) diff --git a/src/argilla/client/feedback/schemas/remote/records.py b/src/argilla/client/feedback/schemas/remote/records.py index fdf84a113e..c74cba565d 100644 --- a/src/argilla/client/feedback/schemas/remote/records.py +++ b/src/argilla/client/feedback/schemas/remote/records.py @@ -31,7 +31,7 @@ import httpx from argilla.client.sdk.v1.datasets.models import FeedbackResponseModel, FeedbackSuggestionModel - from argilla.client.sdk.v1.records.models import FeedbackItemModel + from argilla.client.sdk.v1.records.models import FeedbackRecordModel class RemoteSuggestionSchema(SuggestionSchema, RemoteSchema): @@ -95,12 +95,23 @@ def from_api(cls, payload: "FeedbackResponseModel") -> "RemoteResponseSchema": return RemoteResponseSchema( user_id=payload.user_id, values=payload.values, + # TODO: Review type mismatch between API and SDK status=payload.status, inserted_at=payload.inserted_at, updated_at=payload.updated_at, ) +AllowedSuggestionTypes = Union[ + RemoteSuggestionSchema, + SuggestionSchema, + Dict[str, Any], + List[RemoteSuggestionSchema], + List[SuggestionSchema], + List[Dict[str, Any]], +] + + class RemoteFeedbackRecord(FeedbackRecord, RemoteSchema): """Schema for the records of a `RemoteFeedbackDataset`. @@ -117,6 +128,7 @@ class RemoteFeedbackRecord(FeedbackRecord, RemoteSchema): question. Defaults to an empty list. 
""" + # TODO: remote record should receive a dataset instead of this question_name_to_id: Optional[Dict[str, UUID]] = Field(..., exclude=True, repr=False) responses: List[RemoteResponseSchema] = Field(default_factory=list) @@ -125,19 +137,10 @@ class RemoteFeedbackRecord(FeedbackRecord, RemoteSchema): ) class Config: + allow_mutation = True validate_assignment = True - def __update_suggestions( - self, - suggestions: Union[ - RemoteSuggestionSchema, - List[RemoteSuggestionSchema], - SuggestionSchema, - List[SuggestionSchema], - Dict[str, Any], - List[Dict[str, Any]], - ], - ) -> None: + def __update_suggestions(self, suggestions: AllowedSuggestionTypes) -> None: """Updates the suggestions for the record in Argilla. Note that the suggestions must exist in Argilla to be updated. @@ -220,17 +223,7 @@ def __update_suggestions( self.__dict__["suggestions"] = tuple(existing_suggestions.values()) @allowed_for_roles(roles=[UserRole.owner, UserRole.admin]) - def update( - self, - suggestions: Union[ - RemoteSuggestionSchema, - List[RemoteSuggestionSchema], - SuggestionSchema, - List[SuggestionSchema], - Dict[str, Any], - List[Dict[str, Any]], - ], - ) -> None: + def update(self, suggestions: Optional[AllowedSuggestionTypes] = None) -> None: """Update a `RemoteFeedbackRecord`. Currently just `suggestions` are supported. Note that this method will update the record in Argilla directly. @@ -244,7 +237,25 @@ def update( Raises: PermissionError: if the user does not have either `owner` or `admin` role. 
""" - self.__update_suggestions(suggestions=suggestions) + + self.__updated_record_data() + + suggestions = suggestions or self.suggestions + if suggestions: + self.__update_suggestions(suggestions=suggestions) + + def __updated_record_data(self) -> None: + response = records_api_v1.update_record(self.client, self.id, self.to_server_payload()) + + updated_record = self.from_api( + payload=response.parsed, + question_id_to_name={value: key for key, value in self.question_name_to_id.items()} + if self.question_name_to_id + else None, + client=self.client, + ) + + self.__dict__.update(updated_record.__dict__) @allowed_for_roles(roles=[UserRole.owner, UserRole.admin]) def delete_suggestions(self, suggestions: Union[RemoteSuggestionSchema, List[RemoteSuggestionSchema]]) -> None: @@ -316,7 +327,7 @@ def to_local(self) -> "FeedbackRecord": @classmethod def from_api( cls, - payload: "FeedbackItemModel", + payload: "FeedbackRecordModel", question_id_to_name: Optional[Dict[UUID, str]] = None, client: Optional["httpx.Client"] = None, ) -> "RemoteFeedbackRecord": From 1daff38d6eb627c271dd340a831f7cb74c6490d3 Mon Sep 17 00:00:00 2001 From: Paco Aranda Date: Sat, 14 Oct 2023 23:14:54 +0200 Subject: [PATCH 09/26] feat: call record update endpoint --- src/argilla/client/sdk/v1/records/api.py | 24 +++++++++++++++++++++++- 1 file changed, 23 insertions(+), 1 deletion(-) diff --git a/src/argilla/client/sdk/v1/records/api.py b/src/argilla/client/sdk/v1/records/api.py index 6fe13d14e8..855b3e0ead 100644 --- a/src/argilla/client/sdk/v1/records/api.py +++ b/src/argilla/client/sdk/v1/records/api.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import List, Union +from typing import Any, Dict, List, Union from uuid import UUID import httpx @@ -22,6 +22,28 @@ from argilla.client.sdk.v1.records.models import FeedbackItemModel +def update_record( + # TODO: Use the proper sdk API Model instead of the dict + client: httpx.Client, id: Union[str, UUID], data: Dict[str, Any] +) -> Response[Union[FeedbackItemModel, ErrorMessage, HTTPValidationError]]: + url = f"/api/v1/records/{id}" + + body = {} + if "metadata" in data: + body["metadata"] = data["metadata"] + if "external_id" in data: + body["external_id"] = data["external_id"] + + response = client.patch(url=url, json=body) + + if response.status_code == 200: + response_obj = Response.from_httpx_response(response) + response_obj.parsed = FeedbackItemModel.parse_raw(response.content) + return response_obj + + return handle_response_error(response) + + def delete_record( client: httpx.Client, id: UUID ) -> Response[Union[FeedbackItemModel, ErrorMessage, HTTPValidationError]]: From 8e24a368f8449d2c2c9d0f7b20211740640de400 Mon Sep 17 00:00:00 2001 From: Paco Aranda Date: Sat, 14 Oct 2023 23:33:47 +0200 Subject: [PATCH 10/26] refactor: Support updating record suggestions through `update_records` workflow Record suggestions can be modified locally to prepare changes and then call the `ds.update_records` with modified suggestions. 
The `record.update` still support suggestions ```python records = ds.records[:10] for record in records: record.suggestions = [SuggestionSchema(...)] record.metadata.update({"new": "metadata"}) # Apply all local changes to remote records ds.update_records(records) ``` --- .../client/feedback/schemas/records.py | 4 +-- .../client/feedback/schemas/remote/records.py | 25 +++++++++---------- 2 files changed, 13 insertions(+), 16 deletions(-) diff --git a/src/argilla/client/feedback/schemas/records.py b/src/argilla/client/feedback/schemas/records.py index 1f58002879..33fa6a57d4 100644 --- a/src/argilla/client/feedback/schemas/records.py +++ b/src/argilla/client/feedback/schemas/records.py @@ -195,9 +195,7 @@ class FeedbackRecord(BaseModel): fields: Dict[str, Union[str, None]] metadata: Dict[str, Any] = Field(default_factory=dict) responses: List[ResponseSchema] = Field(default_factory=list) - suggestions: Union[Tuple[SuggestionSchema], List[SuggestionSchema]] = Field( - default_factory=tuple, allow_mutation=False - ) + suggestions: Union[Tuple[SuggestionSchema], List[SuggestionSchema]] = Field(default_factory=tuple) external_id: Optional[str] = None _unified_responses: Optional[Dict[str, List["UnifiedValueSchema"]]] = PrivateAttr(default_factory=dict) diff --git a/src/argilla/client/feedback/schemas/remote/records.py b/src/argilla/client/feedback/schemas/remote/records.py index c74cba565d..67bb7802a7 100644 --- a/src/argilla/client/feedback/schemas/remote/records.py +++ b/src/argilla/client/feedback/schemas/remote/records.py @@ -102,14 +102,7 @@ def from_api(cls, payload: "FeedbackResponseModel") -> "RemoteResponseSchema": ) -AllowedSuggestionTypes = Union[ - RemoteSuggestionSchema, - SuggestionSchema, - Dict[str, Any], - List[RemoteSuggestionSchema], - List[SuggestionSchema], - List[Dict[str, Any]], -] +AllowedSuggestionSchema = Union[RemoteSuggestionSchema, SuggestionSchema] class RemoteFeedbackRecord(FeedbackRecord, RemoteSchema): @@ -132,15 +125,21 @@ class 
RemoteFeedbackRecord(FeedbackRecord, RemoteSchema): question_name_to_id: Optional[Dict[str, UUID]] = Field(..., exclude=True, repr=False) responses: List[RemoteResponseSchema] = Field(default_factory=list) - suggestions: Union[Tuple[RemoteSuggestionSchema], List[RemoteSuggestionSchema]] = Field( - default_factory=tuple, allow_mutation=False - ) + suggestions: Union[Tuple[AllowedSuggestionSchema], List[AllowedSuggestionSchema]] = Field(default_factory=tuple) class Config: allow_mutation = True validate_assignment = True - def __update_suggestions(self, suggestions: AllowedSuggestionTypes) -> None: + def __update_suggestions( + self, + suggestions: Union[ + Dict[str, Any], + List[Dict[str, Any]], + AllowedSuggestionSchema, + List[AllowedSuggestionSchema], + ], + ) -> None: """Updates the suggestions for the record in Argilla. Note that the suggestions must exist in Argilla to be updated. @@ -223,7 +222,7 @@ def __update_suggestions(self, suggestions: AllowedSuggestionTypes) -> None: self.__dict__["suggestions"] = tuple(existing_suggestions.values()) @allowed_for_roles(roles=[UserRole.owner, UserRole.admin]) - def update(self, suggestions: Optional[AllowedSuggestionTypes] = None) -> None: + def update(self, suggestions: Optional[AllowedSuggestionSchema] = None) -> None: """Update a `RemoteFeedbackRecord`. Currently just `suggestions` are supported. Note that this method will update the record in Argilla directly. 
From b28c45f2cec5f377cbd829f7039b41c5032b6437 Mon Sep 17 00:00:00 2001 From: Paco Aranda Date: Sat, 14 Oct 2023 23:34:29 +0200 Subject: [PATCH 11/26] tests: Adapt Test base dataset including missing abstract methods to test dataset class --- tests/unit/client/feedback/dataset/test_base.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/unit/client/feedback/dataset/test_base.py b/tests/unit/client/feedback/dataset/test_base.py index c9ca77fbe9..2da96232b3 100644 --- a/tests/unit/client/feedback/dataset/test_base.py +++ b/tests/unit/client/feedback/dataset/test_base.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, Iterable, List, Optional, Union +from typing import Iterable, List, Optional, TYPE_CHECKING, Union import pytest from argilla.client.feedback.dataset.base import FeedbackDatasetBase @@ -26,7 +26,6 @@ ) from argilla.client.feedback.schemas.questions import RatingQuestion, TextQuestion from argilla.client.feedback.schemas.records import FeedbackRecord -from argilla.client.sdk.v1.datasets.models import FeedbackResponseStatusFilter if TYPE_CHECKING: from argilla.client.feedback.schemas.types import ( @@ -37,6 +36,9 @@ class TestFeedbackDataset(FeedbackDatasetBase): + def update_records(self, records: Union[FeedbackRecord, List[FeedbackRecord]]) -> None: + pass + def filter_by( self, *, From c42b963599e1332256543f75c6411f22401388d7 Mon Sep 17 00:00:00 2001 From: Paco Aranda Date: Sat, 14 Oct 2023 23:35:08 +0200 Subject: [PATCH 12/26] tests: Add unit test for local.update_records workflow --- tests/unit/client/feedback/dataset/test_local.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/tests/unit/client/feedback/dataset/test_local.py b/tests/unit/client/feedback/dataset/test_local.py index 2732657f67..d669548c5b 100644 --- a/tests/unit/client/feedback/dataset/test_local.py +++ 
b/tests/unit/client/feedback/dataset/test_local.py @@ -204,3 +204,19 @@ def test_add_metadata_property_errors(metadata_property: "AllowedMetadataPropert ): _ = dataset.add_metadata_property(metadata_property) assert len(dataset.metadata_properties) == 3 + + +def test_update_records_with_warning() -> None: + dataset = FeedbackDataset( + fields=[TextField(name="required-field")], + questions=[TextQuestion(name="question")], + ) + + with pytest.warns( + UserWarning, + match="`update_records` method only works for `FeedbackDataset` pushed to Argilla." + " If your are working with local data, you can just iterate over the records and update them.", + ): + dataset.update_records( + FeedbackRecord(fields={"required-field": "text"}, metadata={"nested-metadata": {"a": 1}}) + ) From 2a94c0e4a07c6fdf14d9f92a3db2749d4e5a6ef0 Mon Sep 17 00:00:00 2001 From: Paco Aranda Date: Sun, 15 Oct 2023 00:49:36 +0200 Subject: [PATCH 13/26] chore: Move `set_suggestion` function to records.py API module --- src/argilla/client/sdk/v1/datasets/api.py | 50 +---------------------- src/argilla/client/sdk/v1/records/api.py | 50 ++++++++++++++++++++++- 2 files changed, 50 insertions(+), 50 deletions(-) diff --git a/src/argilla/client/sdk/v1/datasets/api.py b/src/argilla/client/sdk/v1/datasets/api.py index 47c8ad23f0..90a3d13345 100644 --- a/src/argilla/client/sdk/v1/datasets/api.py +++ b/src/argilla/client/sdk/v1/datasets/api.py @@ -13,7 +13,7 @@ # limitations under the License. 
import warnings -from typing import Any, Dict, List, Literal, Optional, Union +from typing import Any, Dict, List, Optional, Union from uuid import UUID import httpx @@ -28,7 +28,6 @@ FeedbackQuestionModel, FeedbackRecordsModel, FeedbackResponseStatusFilter, - FeedbackSuggestionModel, ) @@ -440,53 +439,6 @@ def add_metadata_property( return handle_response_error(response) -def set_suggestion( - client: httpx.Client, - record_id: UUID, - question_id: UUID, - value: Any, - type: Optional[Literal["model", "human"]] = None, - score: Optional[float] = None, - agent: Optional[str] = None, -) -> Response[Union[FeedbackSuggestionModel, ErrorMessage, HTTPValidationError]]: - """Sends a PUT request to `/api/v1/records/{id}/suggestions` endpoint to add or update - a suggestion for a question in the `FeedbackDataset`. - - Args: - client: the authenticated Argilla client to be used to send the request to the API. - record_id: the id of the record to add the suggestion to. - question_id: the id of the question to add the suggestion to. - value: the value of the suggestion. - type: the type of the suggestion. It can be either `model` or `human`. Defaults to None. - score: the score of the suggestion. Defaults to None. - agent: the agent used to obtain the suggestion. Defaults to None. - - Returns: - A `Response` object containing a `parsed` attribute with the parsed response if the - request was successful, which is a `FeedbackSuggestionModel`. 
- """ - url = f"/api/v1/records/{record_id}/suggestions" - - suggestion = { - "question_id": str(question_id), - "value": value, - } - if type is not None: - suggestion["type"] = type - if score is not None: - suggestion["score"] = score - if agent is not None: - suggestion["agent"] = agent - - response = client.put(url=url, json=suggestion) - - if response.status_code in [200, 201]: - response_obj = Response.from_httpx_response(response) - response_obj.parsed = FeedbackSuggestionModel(**response.json()) - return response_obj - return handle_response_error(response) - - def get_metrics( client: httpx.Client, id: UUID, diff --git a/src/argilla/client/sdk/v1/records/api.py b/src/argilla/client/sdk/v1/records/api.py index 855b3e0ead..9d331d1fa8 100644 --- a/src/argilla/client/sdk/v1/records/api.py +++ b/src/argilla/client/sdk/v1/records/api.py @@ -12,13 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, List, Union +from typing import Any, Dict, List, Literal, Optional, Union from uuid import UUID import httpx from argilla.client.sdk.commons.errors_handler import handle_response_error from argilla.client.sdk.commons.models import ErrorMessage, HTTPValidationError, Response +from argilla.client.sdk.v1.datasets.models import FeedbackSuggestionModel from argilla.client.sdk.v1.records.models import FeedbackItemModel @@ -91,3 +92,50 @@ def delete_suggestions( if response.status_code == 204: return Response.from_httpx_response(response) return handle_response_error(response) + + +def set_suggestion( + client: httpx.Client, + record_id: UUID, + question_id: UUID, + value: Any, + type: Optional[Literal["model", "human"]] = None, + score: Optional[float] = None, + agent: Optional[str] = None, +) -> Response[Union[FeedbackSuggestionModel, ErrorMessage, HTTPValidationError]]: + """Sends a PUT request to `/api/v1/records/{id}/suggestions` endpoint to add or update + a suggestion for a 
question in the `FeedbackDataset`. + + Args: + client: the authenticated Argilla client to be used to send the request to the API. + record_id: the id of the record to add the suggestion to. + question_id: the id of the question to add the suggestion to. + value: the value of the suggestion. + type: the type of the suggestion. It can be either `model` or `human`. Defaults to None. + score: the score of the suggestion. Defaults to None. + agent: the agent used to obtain the suggestion. Defaults to None. + + Returns: + A `Response` object containing a `parsed` attribute with the parsed response if the + request was successful, which is a `FeedbackSuggestionModel`. + """ + url = f"/api/v1/records/{record_id}/suggestions" + + suggestion = { + "question_id": str(question_id), + "value": value, + } + if type is not None: + suggestion["type"] = type + if score is not None: + suggestion["score"] = score + if agent is not None: + suggestion["agent"] = agent + + response = client.put(url=url, json=suggestion) + + if response.status_code in [200, 201]: + response_obj = Response.from_httpx_response(response) + response_obj.parsed = FeedbackSuggestionModel(**response.json()) + return response_obj + return handle_response_error(response) From ab77ecb61cb6d2a64733df88e62c000b9d17d434 Mon Sep 17 00:00:00 2001 From: Paco Aranda Date: Sun, 15 Oct 2023 00:52:33 +0200 Subject: [PATCH 14/26] refactor: Control record suggestions updates from `update(suggestions=...)` and `record.update()` The suggestions will be filtered before updating them if suggestions were provided in the `record.update` method. 
Otherwise, the record suggestions will be sent as new suggestions --- .../client/feedback/schemas/remote/records.py | 64 +++++++++++-------- 1 file changed, 38 insertions(+), 26 deletions(-) diff --git a/src/argilla/client/feedback/schemas/remote/records.py b/src/argilla/client/feedback/schemas/remote/records.py index 67bb7802a7..2ff6ea9315 100644 --- a/src/argilla/client/feedback/schemas/remote/records.py +++ b/src/argilla/client/feedback/schemas/remote/records.py @@ -14,7 +14,7 @@ import warnings from datetime import datetime -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, TYPE_CHECKING, Tuple, Union from uuid import UUID from pydantic import Field @@ -22,7 +22,6 @@ from argilla.client.feedback.schemas.records import FeedbackRecord, ResponseSchema, SuggestionSchema from argilla.client.feedback.schemas.remote.shared import RemoteSchema from argilla.client.sdk.users.models import UserRole -from argilla.client.sdk.v1.datasets import api as datasets_api_v1 from argilla.client.sdk.v1.records import api as records_api_v1 from argilla.client.sdk.v1.suggestions import api as suggestions_api_v1 from argilla.client.utils import allowed_for_roles @@ -131,25 +130,19 @@ class Config: allow_mutation = True validate_assignment = True - def __update_suggestions( + def __normalize_suggestions_to_update( self, suggestions: Union[ - Dict[str, Any], - List[Dict[str, Any]], - AllowedSuggestionSchema, - List[AllowedSuggestionSchema], + Dict[str, Any], List[Dict[str, Any]], AllowedSuggestionSchema, List[AllowedSuggestionSchema] ], - ) -> None: - """Updates the suggestions for the record in Argilla. Note that the suggestions - must exist in Argilla to be updated. - - Note that this method will update the record in Argilla directly. + ) -> List[AllowedSuggestionSchema]: + """ + Normalizes the suggestions to update. 
Args: - suggestions: can be a single `RemoteSuggestionSchema` or `SuggestionSchema`, - a list of `RemoteSuggestionSchema` or `SuggestionSchema`, a single - dictionary, or a list of dictionaries. If a dictionary is provided, - it will be converted to a `RemoteSuggestionSchema` internally. + suggestions: can be a single `RemoteSuggestionSchema` or `SuggestionSchema`, a dictionary, a list of + `RemoteSuggestionSchema` or `SuggestionSchema`, or a list of dictionaries. If a dictionary is provided, + it will be converted to a `SuggestionSchema` internally. """ if isinstance(suggestions, (dict, SuggestionSchema)): suggestions = [suggestions] @@ -205,21 +198,37 @@ def __update_suggestions( else: new_suggestions[suggestion.question_name] = suggestion - for suggestion in new_suggestions.values(): + return list(new_suggestions.values()) + + def __update_suggestions(self, suggestions: List[AllowedSuggestionSchema]) -> None: + """Updates the suggestions for the record in Argilla. + + Note that this method will update the record in Argilla directly. + + Args: + suggestions: can be a list of `RemoteSuggestionSchema` or `SuggestionSchema`. 
+ """ + + pushed_suggestions = [] + + for suggestion in suggestions: if isinstance(suggestion, RemoteSuggestionSchema): suggestion = suggestion.to_local() - pushed_suggestion = datasets_api_v1.set_suggestion( + # TODO: review the existence of bulk endpoint for record suggestions + pushed_suggestion = records_api_v1.set_suggestion( client=self.client, record_id=self.id, **suggestion.to_server_payload(question_name_to_id=self.question_name_to_id), ) - existing_suggestions[suggestion.question_name] = RemoteSuggestionSchema.from_api( - payload=pushed_suggestion.parsed, - question_id_to_name={value: key for key, value in self.question_name_to_id.items()}, - client=self.client, + pushed_suggestions.append( + RemoteSuggestionSchema.from_api( + payload=pushed_suggestion.parsed, + question_id_to_name={value: key for key, value in self.question_name_to_id.items()}, + client=self.client, + ) ) - self.__dict__["suggestions"] = tuple(existing_suggestions.values()) + self.__dict__["suggestions"] = tuple(pushed_suggestions) @allowed_for_roles(roles=[UserRole.owner, UserRole.admin]) def update(self, suggestions: Optional[AllowedSuggestionSchema] = None) -> None: @@ -236,10 +245,12 @@ def update(self, suggestions: Optional[AllowedSuggestionSchema] = None) -> None: Raises: PermissionError: if the user does not have either `owner` or `admin` role. 
""" + if suggestions: + suggestions = self.__normalize_suggestions_to_update(suggestions) + else: + suggestions = suggestions or [s for s in self.suggestions] self.__updated_record_data() - - suggestions = suggestions or self.suggestions if suggestions: self.__update_suggestions(suggestions=suggestions) @@ -292,7 +303,8 @@ def delete_suggestions(self, suggestions: Union[RemoteSuggestionSchema, List[Rem self.__dict__["suggestions"] = tuple(existing_suggestions.values()) except Exception as e: raise RuntimeError( - f"Failed to delete suggestions with IDs `{[suggestion.id for suggestion in delete_suggestions]}` from record with ID `{self.id}` from Argilla." + f"Failed to delete suggestions with IDs `{[suggestion.id for suggestion in delete_suggestions]}` from " + f"record with ID `{self.id}` from Argilla." ) from e @allowed_for_roles(roles=[UserRole.owner, UserRole.admin]) From 4b0ccbdfd11219efa460baa65996e31717063b05 Mon Sep 17 00:00:00 2001 From: Paco Aranda Date: Sun, 15 Oct 2023 00:55:12 +0200 Subject: [PATCH 15/26] refactor: Define the workspace instance creation method private for better integration with unit tests (A code review must be taken in order to not modify a class because the tests) --- src/argilla/client/workspaces.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/argilla/client/workspaces.py b/src/argilla/client/workspaces.py index c8647ba5a6..350c17cd96 100644 --- a/src/argilla/client/workspaces.py +++ b/src/argilla/client/workspaces.py @@ -262,7 +262,7 @@ def __active_client() -> "httpx.Client": raise RuntimeError(f"The `rg.active_client()` is not available or not respoding.") from e @classmethod - def __new_instance( + def _new_instance( cls, client: Optional["httpx.Client"] = None, ws: Optional[Union[WorkspaceModelV0, WorkspaceModelV1]] = None ) -> "Workspace": """Returns a new `Workspace` instance.""" @@ -293,7 +293,7 @@ def create(cls, name: str) -> "Workspace": client = cls.__active_client() try: ws = 
workspaces_api.create_workspace(client, name).parsed - return cls.__new_instance(client, ws) + return cls._new_instance(client, ws) except AlreadyExistsApiError as e: raise ValueError(f"Workspace with name=`{name}` already exists, so please use a different name.") from e except (ValidationApiError, BaseClientError) as e: @@ -321,7 +321,7 @@ def from_id(cls, id: UUID) -> "Workspace": client = cls.__active_client() try: ws = workspaces_api_v1.get_workspace(client, id).parsed - return cls.__new_instance(client, ws) + return cls._new_instance(client, ws) except NotFoundApiError as e: raise ValueError( f"Workspace with id=`{id}` doesn't exist in Argilla, so please" @@ -362,7 +362,7 @@ def from_name(cls, name: str) -> "Workspace": for ws in workspaces: if ws.name == name: - return cls.__new_instance(client, ws) + return cls._new_instance(client, ws) raise ValueError( f"Workspace with name=`{name}` doesn't exist in Argilla, so please" @@ -388,6 +388,6 @@ def list(cls) -> Iterator["Workspace"]: try: workspaces = workspaces_api_v1.list_workspaces_me(client).parsed for ws in workspaces: - yield cls.__new_instance(client, ws) + yield cls._new_instance(client, ws) except Exception as e: raise RuntimeError("Error while retrieving the list of workspaces from Argilla.") from e From 256f797ddc939eadd59faedda717e2e009d2d586 Mon Sep 17 00:00:00 2001 From: Paco Aranda Date: Sun, 15 Oct 2023 01:40:24 +0200 Subject: [PATCH 16/26] fix: Indentation return --- src/argilla/client/feedback/schemas/remote/records.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/argilla/client/feedback/schemas/remote/records.py b/src/argilla/client/feedback/schemas/remote/records.py index 2ff6ea9315..2a3bbf3d74 100644 --- a/src/argilla/client/feedback/schemas/remote/records.py +++ b/src/argilla/client/feedback/schemas/remote/records.py @@ -198,7 +198,7 @@ def __normalize_suggestions_to_update( else: new_suggestions[suggestion.question_name] = suggestion - return 
list(new_suggestions.values()) + return list(new_suggestions.values()) def __update_suggestions(self, suggestions: List[AllowedSuggestionSchema]) -> None: """Updates the suggestions for the record in Argilla. From 547b641c7676b4103d6ba13f7f04e4f6a3443b03 Mon Sep 17 00:00:00 2001 From: Paco Aranda Date: Sun, 15 Oct 2023 01:41:29 +0200 Subject: [PATCH 17/26] tests: Remove raise check for suggestions immutability --- tests/integration/client/feedback/dataset/test_dataset.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/tests/integration/client/feedback/dataset/test_dataset.py b/tests/integration/client/feedback/dataset/test_dataset.py index 2b87d92ae3..f5946bd0df 100644 --- a/tests/integration/client/feedback/dataset/test_dataset.py +++ b/tests/integration/client/feedback/dataset/test_dataset.py @@ -476,13 +476,7 @@ async def test_update_dataset_records_in_argilla( }, ] ) - with pytest.raises(TypeError, match='"RemoteFeedbackRecord" is immutable and does not support item assignment'): - record.suggestions = [ - { - "question_name": "question-1", - "value": "This is a suggestion to question 1", - }, - ] + def test_push_to_huggingface_and_from_huggingface( From 5d416d72dd3a3526829f001f1cfb4f6c3f7c99d8 Mon Sep 17 00:00:00 2001 From: Paco Aranda Date: Sun, 15 Oct 2023 01:41:57 +0200 Subject: [PATCH 18/26] chore: Adapt imports --- tests/integration/client/sdk/v1/test_datasets.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/integration/client/sdk/v1/test_datasets.py b/tests/integration/client/sdk/v1/test_datasets.py index 7d669135f7..4f6ac18b2c 100644 --- a/tests/integration/client/sdk/v1/test_datasets.py +++ b/tests/integration/client/sdk/v1/test_datasets.py @@ -39,8 +39,8 @@ get_records, list_datasets, publish_dataset, - set_suggestion, ) +from argilla.client.sdk.v1.records.api import set_suggestion from argilla.client.sdk.v1.datasets.models import ( FeedbackDatasetModel, FeedbackFieldModel, From 
a85279a173ea4d5ca639f708486390957b52fba5 Mon Sep 17 00:00:00 2001 From: Paco Aranda Date: Sun, 15 Oct 2023 01:42:14 +0200 Subject: [PATCH 19/26] tests: fixture for mock httpx client --- tests/unit/conftest.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index c01502746a..3ef030f37e 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -15,6 +15,7 @@ import asyncio from typing import TYPE_CHECKING, AsyncGenerator, Generator +import httpx import pytest import pytest_asyncio from argilla.cli.server.database.migrate import migrate_db @@ -37,6 +38,11 @@ def event_loop() -> Generator["asyncio.AbstractEventLoop", None, None]: loop.close() +@pytest.fixture(scope="function") +def mock_httpx_client(mocker) -> Generator[httpx.Client, None, None]: + return mocker.Mock(httpx.Client) + + @pytest_asyncio.fixture(scope="session") async def connection() -> AsyncGenerator["AsyncConnection", None]: set_task(asyncio.current_task()) From 2787be30557fc6bdd1b7f33ddc20091fac5cc213 Mon Sep 17 00:00:00 2001 From: Paco Aranda Date: Sun, 15 Oct 2023 01:42:44 +0200 Subject: [PATCH 20/26] tests: Unit tests for update records with and without suggestions --- .../feedback/dataset/remote/__init__.py | 0 .../feedback/dataset/remote/test_dataset.py | 238 ++++++++++++++++++ 2 files changed, 238 insertions(+) create mode 100644 tests/unit/client/feedback/dataset/remote/__init__.py create mode 100644 tests/unit/client/feedback/dataset/remote/test_dataset.py diff --git a/tests/unit/client/feedback/dataset/remote/__init__.py b/tests/unit/client/feedback/dataset/remote/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/unit/client/feedback/dataset/remote/test_dataset.py b/tests/unit/client/feedback/dataset/remote/test_dataset.py new file mode 100644 index 0000000000..108277fb8c --- /dev/null +++ b/tests/unit/client/feedback/dataset/remote/test_dataset.py @@ -0,0 +1,238 @@ +from datetime import datetime 
+from uuid import uuid4 + +import httpx +import pytest + +from argilla import Workspace +from argilla.client.feedback.dataset.remote.dataset import RemoteFeedbackDataset +from argilla.client.feedback.schemas import SuggestionSchema +from argilla.client.feedback.schemas.remote.fields import RemoteTextField +from argilla.client.feedback.schemas.remote.questions import RemoteTextQuestion +from argilla.client.feedback.schemas.remote.records import RemoteFeedbackRecord +from argilla.client.sdk.users.models import UserModel, UserRole +from argilla.client.sdk.v1.datasets.models import FeedbackItemModel, FeedbackSuggestionModel +from argilla.client.sdk.v1.workspaces.models import WorkspaceModel + + +@pytest.fixture() +def test_remote_dataset(mock_httpx_client: httpx.Client) -> RemoteFeedbackDataset: + return RemoteFeedbackDataset( + client=mock_httpx_client, + id=uuid4(), + name="test-remote-dataset", + workspace=Workspace._new_instance( + client=mock_httpx_client, + ws=WorkspaceModel( + id=uuid4(), + name="test-remote-workspace", + inserted_at=datetime.utcnow(), + updated_at=datetime.utcnow(), + ), + ), + fields=[RemoteTextField(id=uuid4(), name="text")], + questions=[RemoteTextQuestion(id=uuid4(), name="text")], + created_at=datetime.utcnow(), + updated_at=datetime.utcnow(), + ) + + +@pytest.fixture() +def test_remote_record( + mock_httpx_client: httpx.Client, test_remote_dataset: RemoteFeedbackDataset +) -> RemoteFeedbackRecord: + return RemoteFeedbackRecord( + id=uuid4(), + client=mock_httpx_client, + fields={"text": "test"}, + metadata={"new": "metadata"}, + question_name_to_id=test_remote_dataset._question_name_to_id, + ) + + +class TestSuiteRemoteDataset: + def test_update_records( + self, + mock_httpx_client: httpx.Client, + test_remote_dataset: RemoteFeedbackDataset, + test_remote_record: RemoteFeedbackRecord, + ) -> None: + """Test updating records.""" + + mock_httpx_client.get.return_value = httpx.Response( + status_code=200, + content=UserModel( + id=uuid4(), + 
first_name="test", + username="test", + role=UserRole.owner, + api_key="api.key", + inserted_at=datetime.utcnow(), + updated_at=datetime.utcnow(), + ).json(), + ) + + mock_httpx_client.patch.return_value = httpx.Response( + status_code=200, + content=FeedbackItemModel( + id=test_remote_record.id, + fields=test_remote_record.fields, + inserted_at=datetime.utcnow(), + updated_at=datetime.utcnow(), + ).json(), + ) + + test_remote_dataset.update_records(records=[test_remote_record]) + + mock_httpx_client.patch.assert_called_once_with( + url=f"/api/v1/records/{test_remote_record.id}", + json={"metadata": {"new": "metadata"}}, + ) + + def test_update_multiple_records( + self, + mock_httpx_client: httpx.Client, + test_remote_dataset: RemoteFeedbackDataset, + test_remote_record: RemoteFeedbackRecord, + ) -> None: + """Test updating records.""" + + mock_httpx_client.get.return_value = httpx.Response( + status_code=200, + content=UserModel( + id=uuid4(), + first_name="test", + username="test", + role=UserRole.owner, + api_key="api.key", + inserted_at=datetime.utcnow(), + updated_at=datetime.utcnow(), + ).json(), + ) + + mock_httpx_client.patch.return_value = httpx.Response( + status_code=200, + content=FeedbackItemModel( + id=test_remote_record.id, + fields=test_remote_record.fields, + inserted_at=datetime.utcnow(), + updated_at=datetime.utcnow(), + ).json(), + ) + + test_remote_dataset.update_records(records=[test_remote_record] * 10) + + assert mock_httpx_client.patch.call_count == 10 + + def test_update_records_with_multiple_suggestions( + self, + mock_httpx_client: httpx.Client, + test_remote_dataset: RemoteFeedbackDataset, + test_remote_record: RemoteFeedbackRecord, + ) -> None: + """Test updating records.""" + + mock_httpx_client.get.return_value = httpx.Response( + status_code=200, + content=UserModel( + id=uuid4(), + first_name="test", + username="test", + role=UserRole.owner, + api_key="api.key", + inserted_at=datetime.utcnow(), + updated_at=datetime.utcnow(), + 
).json(), + ) + + mock_httpx_client.patch.return_value = httpx.Response( + status_code=200, + content=FeedbackItemModel( + id=test_remote_record.id, + fields=test_remote_record.fields, + inserted_at=datetime.utcnow(), + updated_at=datetime.utcnow(), + ).json(), + ) + + expected_suggestion = FeedbackSuggestionModel( + id=uuid4(), + question_id=test_remote_dataset.question_by_name("text").id, + value="Test value", + score=0.5, + agent="test", + ) + mock_httpx_client.put.return_value = httpx.Response(status_code=200, content=expected_suggestion.json()) + + test_remote_record.suggestions = [ + SuggestionSchema(question_name="text", value="Test value", score=0.5, agent="test") + ] * 10 + + test_remote_dataset.update_records(records=[test_remote_record] * 10) + + # TODO: Reduce the number of call -> bulk endpoint at least for suggestions + assert mock_httpx_client.put.call_count == 100 + + def test_update_records_suggestions( + self, + mock_httpx_client: httpx.Client, + test_remote_dataset: RemoteFeedbackDataset, + test_remote_record: RemoteFeedbackRecord, + ) -> None: + mock_httpx_client.get.return_value = httpx.Response( + status_code=200, + content=UserModel( + id=uuid4(), + first_name="test", + username="test", + role=UserRole.owner, + api_key="api.key", + inserted_at=datetime.utcnow(), + updated_at=datetime.utcnow(), + ).json(), + ) + + expected_suggestion = FeedbackSuggestionModel( + id=uuid4(), + question_id=test_remote_dataset.question_by_name("text").id, + value="Test value", + score=0.5, + agent="test", + ) + mock_httpx_client.patch.return_value = httpx.Response( + status_code=200, + content=FeedbackItemModel( + id=test_remote_record.id, + fields=test_remote_record.fields, + inserted_at=datetime.utcnow(), + updated_at=datetime.utcnow(), + ).json(), + ) + mock_httpx_client.put.return_value = httpx.Response(status_code=200, content=expected_suggestion.json()) + + test_remote_record.suggestions = [ + SuggestionSchema(question_name="text", value="Test value", 
score=0.5, agent="test") + ] + + test_remote_dataset.update_records(records=test_remote_record) + + mock_httpx_client.patch.assert_called + mock_httpx_client.put.assert_called_with( + url=f"/api/v1/records/{test_remote_record.id}/suggestions", + # TODO: This should be a list of suggestions + json={ + "agent": expected_suggestion.agent, + "question_id": str(expected_suggestion.question_id), + "score": expected_suggestion.score, + "value": expected_suggestion.value, + }, + ) + + def test_update_records_suggestions_with_already_suggestion( + self, + mock_httpx_client: httpx.Client, + test_remote_dataset: RemoteFeedbackDataset, + test_remote_record: RemoteFeedbackRecord, + ) -> None: + # TODO: Implement + pass From 99e5d00fb6969c9f527b76cb8bdb3162a599df8b Mon Sep 17 00:00:00 2001 From: Paco Aranda Date: Sun, 15 Oct 2023 01:44:32 +0200 Subject: [PATCH 21/26] tests: Integration tests for updating records --- .../feedback/dataset/remote/test_dataset.py | 57 +++++++++++++++++++ 1 file changed, 57 insertions(+) diff --git a/tests/integration/client/feedback/dataset/remote/test_dataset.py b/tests/integration/client/feedback/dataset/remote/test_dataset.py index 09780dbe18..95df67a950 100644 --- a/tests/integration/client/feedback/dataset/remote/test_dataset.py +++ b/tests/integration/client/feedback/dataset/remote/test_dataset.py @@ -23,6 +23,7 @@ from argilla.client import api from argilla.client.feedback.dataset import FeedbackDataset from argilla.client.feedback.dataset.remote.dataset import RemoteFeedbackDataset +from argilla.client.feedback.schemas import SuggestionSchema from argilla.client.feedback.schemas.fields import TextField from argilla.client.feedback.schemas.metadata import ( FloatMetadataProperty, @@ -119,6 +120,62 @@ async def test_add_records(self, owner: "User", test_dataset: FeedbackDataset, r assert len(remote_dataset.records) == 1 + async def test_update_records(self, owner: "User", test_dataset: FeedbackDataset): + import argilla as rg + + 
rg.init(api_key=owner.api_key) + ws = rg.Workspace.create(name="test-workspace") + + test_dataset.add_records( + [ + FeedbackRecord(fields={"text": "Hello world!"}), + FeedbackRecord(fields={"text": "Another record"}), + ] + ) + + remote = test_dataset.push_to_argilla(name="test_dataset", workspace=ws) + + first_record = remote[0] + first_record.external_id = "new-external-id" + first_record.metadata.update({"terms-metadata": "a"}) + + remote.update_records(first_record) + + assert first_record == remote[0] + + first_record = remote[0] + assert first_record.external_id == "new-external-id" + assert first_record.metadata["terms-metadata"] == "a" + + async def test_update_records_with_suggestions(self, owner: "User", test_dataset: FeedbackDataset): + import argilla as rg + + rg.init(api_key=owner.api_key) + ws = rg.Workspace.create(name="test-workspace") + + test_dataset.add_records( + [ + FeedbackRecord(fields={"text": "Hello world!"}), + FeedbackRecord(fields={"text": "Another record"}), + ] + ) + + remote = test_dataset.push_to_argilla(name="test_dataset", workspace=ws) + + records = [] + for record in remote: + record.suggestions = [ + SuggestionSchema(question_name="question", value=f"Hello world! for {record.fields['text']}") + ] + records.append(record) + + remote.update_records(records) + + for record in records: + for suggestion in record.suggestions: + assert suggestion.question_name == "question" + assert suggestion.value == f"Hello world! 
for {record.fields['text']}" + async def test_from_argilla(self, feedback_dataset: FeedbackDataset, owner: "User") -> None: api.init(api_key=owner.api_key) workspace = Workspace.create(name="unit-test") From db1f466966690b48cd162a4c8b04cbfe3c8c4ff8 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 14 Oct 2023 23:51:58 +0000 Subject: [PATCH 22/26] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/argilla/client/feedback/dataset/base.py | 2 +- .../client/feedback/schemas/remote/records.py | 2 +- src/argilla/client/sdk/v1/records/api.py | 4 +++- .../client/feedback/dataset/test_dataset.py | 1 - tests/integration/client/sdk/v1/test_datasets.py | 2 +- .../client/feedback/dataset/remote/__init__.py | 13 +++++++++++++ .../feedback/dataset/remote/test_dataset.py | 15 ++++++++++++++- tests/unit/client/feedback/dataset/test_base.py | 2 +- 8 files changed, 34 insertions(+), 7 deletions(-) diff --git a/src/argilla/client/feedback/dataset/base.py b/src/argilla/client/feedback/dataset/base.py index 477c9e47de..740270ed54 100644 --- a/src/argilla/client/feedback/dataset/base.py +++ b/src/argilla/client/feedback/dataset/base.py @@ -14,7 +14,7 @@ import logging from abc import ABC, abstractmethod -from typing import Generic, Iterable, TYPE_CHECKING, Any, Dict, List, Literal, Optional, TypeVar, Union +from typing import TYPE_CHECKING, Any, Dict, Generic, Iterable, List, Literal, Optional, TypeVar, Union from pydantic import ValidationError diff --git a/src/argilla/client/feedback/schemas/remote/records.py b/src/argilla/client/feedback/schemas/remote/records.py index 2a3bbf3d74..1259604645 100644 --- a/src/argilla/client/feedback/schemas/remote/records.py +++ b/src/argilla/client/feedback/schemas/remote/records.py @@ -14,7 +14,7 @@ import warnings from datetime import datetime -from typing import Any, Dict, List, Optional, TYPE_CHECKING, Tuple, Union +from 
typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union from uuid import UUID from pydantic import Field diff --git a/src/argilla/client/sdk/v1/records/api.py b/src/argilla/client/sdk/v1/records/api.py index 9d331d1fa8..6809c547e1 100644 --- a/src/argilla/client/sdk/v1/records/api.py +++ b/src/argilla/client/sdk/v1/records/api.py @@ -25,7 +25,9 @@ def update_record( # TODO: Use the proper sdk API Model instead of the dict - client: httpx.Client, id: Union[str, UUID], data: Dict[str, Any] + client: httpx.Client, + id: Union[str, UUID], + data: Dict[str, Any], ) -> Response[Union[FeedbackItemModel, ErrorMessage, HTTPValidationError]]: url = f"/api/v1/records/{id}" diff --git a/tests/integration/client/feedback/dataset/test_dataset.py b/tests/integration/client/feedback/dataset/test_dataset.py index f5946bd0df..d45e047680 100644 --- a/tests/integration/client/feedback/dataset/test_dataset.py +++ b/tests/integration/client/feedback/dataset/test_dataset.py @@ -478,7 +478,6 @@ async def test_update_dataset_records_in_argilla( ) - def test_push_to_huggingface_and_from_huggingface( mocked_client: "SecuredClient", monkeypatch: pytest.MonkeyPatch, diff --git a/tests/integration/client/sdk/v1/test_datasets.py b/tests/integration/client/sdk/v1/test_datasets.py index 4f6ac18b2c..52064a0345 100644 --- a/tests/integration/client/sdk/v1/test_datasets.py +++ b/tests/integration/client/sdk/v1/test_datasets.py @@ -40,7 +40,6 @@ list_datasets, publish_dataset, ) -from argilla.client.sdk.v1.records.api import set_suggestion from argilla.client.sdk.v1.datasets.models import ( FeedbackDatasetModel, FeedbackFieldModel, @@ -51,6 +50,7 @@ FeedbackRecordsModel, FeedbackSuggestionModel, ) +from argilla.client.sdk.v1.records.api import set_suggestion from argilla.server.models import DatasetStatus, User, UserRole from tests.factories import ( diff --git a/tests/unit/client/feedback/dataset/remote/__init__.py b/tests/unit/client/feedback/dataset/remote/__init__.py index 
e69de29bb2..55be41799b 100644 --- a/tests/unit/client/feedback/dataset/remote/__init__.py +++ b/tests/unit/client/feedback/dataset/remote/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2021-present, the Recognai S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/unit/client/feedback/dataset/remote/test_dataset.py b/tests/unit/client/feedback/dataset/remote/test_dataset.py index 108277fb8c..ba69f63b47 100644 --- a/tests/unit/client/feedback/dataset/remote/test_dataset.py +++ b/tests/unit/client/feedback/dataset/remote/test_dataset.py @@ -1,9 +1,22 @@ +# Copyright 2021-present, the Recognai S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ from datetime import datetime from uuid import uuid4 import httpx import pytest - from argilla import Workspace from argilla.client.feedback.dataset.remote.dataset import RemoteFeedbackDataset from argilla.client.feedback.schemas import SuggestionSchema diff --git a/tests/unit/client/feedback/dataset/test_base.py b/tests/unit/client/feedback/dataset/test_base.py index 2da96232b3..70c02da234 100644 --- a/tests/unit/client/feedback/dataset/test_base.py +++ b/tests/unit/client/feedback/dataset/test_base.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Iterable, List, Optional, TYPE_CHECKING, Union +from typing import TYPE_CHECKING, Iterable, List, Optional, Union import pytest from argilla.client.feedback.dataset.base import FeedbackDatasetBase From f219d7b498d4f88d134db55e01a4562a5ef5f331 Mon Sep 17 00:00:00 2001 From: Paco Aranda Date: Sun, 15 Oct 2023 02:00:48 +0200 Subject: [PATCH 23/26] chore: Fix method signature --- .../client/feedback/schemas/remote/records.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/src/argilla/client/feedback/schemas/remote/records.py b/src/argilla/client/feedback/schemas/remote/records.py index 2a3bbf3d74..a0059f2448 100644 --- a/src/argilla/client/feedback/schemas/remote/records.py +++ b/src/argilla/client/feedback/schemas/remote/records.py @@ -231,7 +231,17 @@ def __update_suggestions(self, suggestions: List[AllowedSuggestionSchema]) -> No self.__dict__["suggestions"] = tuple(pushed_suggestions) @allowed_for_roles(roles=[UserRole.owner, UserRole.admin]) - def update(self, suggestions: Optional[AllowedSuggestionSchema] = None) -> None: + def update( + self, + suggestions: Optional[ + Union[ + AllowedSuggestionSchema, + Dict[str, Any], + List[AllowedSuggestionSchema], + List[Dict[str, Any]], + ] + ] = None, + ) -> None: """Update a `RemoteFeedbackRecord`. Currently just `suggestions` are supported. 
Note that this method will update the record in Argilla directly. From 02b5c29b7bf1231933486cfa77f3daf994a6104a Mon Sep 17 00:00:00 2001 From: Paco Aranda Date: Sun, 15 Oct 2023 02:02:58 +0200 Subject: [PATCH 24/26] chore: Update changelog --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 70acfc8a2e..97a9efa90d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -29,6 +29,7 @@ These are the section headers that we use: - Added fields `inserted_at` and `updated_at` in `RemoteResponseSchema` ([#3822](https://github.com/argilla-io/argilla/pull/3822)). - New `DELETE /api/v1/metadata-properties/:metadata_property_id` endpoint allowing the deletion of a specific metadata property. ([#3911](https://github.com/argilla-io/argilla/pull/3911)). - Add support for `sort_by` for Argilla feedback datasets ([#3925](https://github.com/argilla-io/argilla/pull/3925)) +- Add support for update records (`metadata` and `external_id`) from Python SDK ([#3946](https://github.com/argilla-io/argilla/pull/3946)). 
### Changed From 2e74cd0841c3d6f31d1254d9c46bb88b5d750b20 Mon Sep 17 00:00:00 2001 From: Paco Aranda Date: Mon, 16 Oct 2023 14:38:29 +0200 Subject: [PATCH 25/26] ci: Show file system description --- .github/workflows/run-python-tests.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/run-python-tests.yml b/.github/workflows/run-python-tests.yml index f783274051..3538046789 100644 --- a/.github/workflows/run-python-tests.yml +++ b/.github/workflows/run-python-tests.yml @@ -82,6 +82,7 @@ jobs: env: ARGILLA_ENABLE_TELEMETRY: 0 run: | + df -h # ulimit to avoid segmentation fault ulimit -c unlimited pip install -e ".[server,listeners]" From c24c65ee7adc808144120d2a019b22d5e24e4273 Mon Sep 17 00:00:00 2001 From: Francisco Aranda Date: Mon, 16 Oct 2023 17:03:59 +0200 Subject: [PATCH 26/26] Apply suggestions from code review Co-authored-by: Alvaro Bartolome --- tests/unit/client/feedback/dataset/test_base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/client/feedback/dataset/test_base.py b/tests/unit/client/feedback/dataset/test_base.py index 70c02da234..0aafbdb3e8 100644 --- a/tests/unit/client/feedback/dataset/test_base.py +++ b/tests/unit/client/feedback/dataset/test_base.py @@ -36,7 +36,7 @@ class TestFeedbackDataset(FeedbackDatasetBase): - def update_records(self, records: Union[FeedbackRecord, List[FeedbackRecord]]) -> None: + def update_records(self, **kwargs: Dict[str, Any]) -> None: pass def filter_by(