Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Support for update records from SDK #3946

Merged
Show file tree
Hide file tree
Changes from 23 commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
1ce3feb
fix: Using utcnow datetime
frascuchon Oct 13, 2023
86a3ea1
tests: Remove date creation
frascuchon Oct 13, 2023
ce580db
Merge branch 'feature/support-for-metadata-filtering-and-sorting' int…
frascuchon Oct 14, 2023
83bac6f
feat: Define `update_records`for base feeedback dataset class
frascuchon Oct 14, 2023
b10ea26
refactor: Implement `update_records` method for local datasets
frascuchon Oct 14, 2023
bbd4236
feat: Implement `update_records` method based on `record.update`
frascuchon Oct 14, 2023
bcdbe9c
chore: Fix `ArgillaRecordsMixin` method signatures
frascuchon Oct 14, 2023
1a464f3
chore: Add some TODO reminders
frascuchon Oct 14, 2023
875f88a
feat: `record.update` support record level update
frascuchon Oct 14, 2023
1daff38
feat: call record update endpoint
frascuchon Oct 14, 2023
8e24a36
refactor: Support updating record suggestions through `update_records…
frascuchon Oct 14, 2023
b28c45f
tests: Adapt Test base dataset including missing abstract methods to …
frascuchon Oct 14, 2023
c42b963
tests: Add unit test for local.update_records workflow
frascuchon Oct 14, 2023
2a94c0e
chore: Move `set_suggestions` fuction to records.py API module
frascuchon Oct 14, 2023
ab77ecb
refactor: Control record suggestions updates from `update(suggestions…
frascuchon Oct 14, 2023
4b0ccbd
refactor: Define the workspace instance creation method private for b…
frascuchon Oct 14, 2023
256f797
fix: Indentation return
frascuchon Oct 14, 2023
547b641
tests: Remove raise check for suggestions immutability
frascuchon Oct 14, 2023
5d416d7
chore: Adapt imports
frascuchon Oct 14, 2023
a85279a
tests: fixture for mock httpx client
frascuchon Oct 14, 2023
2787be3
tests: Unit tests for update records with and without suggestions
frascuchon Oct 14, 2023
99e5d00
tests: Integration tests for updating records
frascuchon Oct 14, 2023
db1f466
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 14, 2023
f219d7b
chore: Fix method signature
frascuchon Oct 15, 2023
04c3829
Merge branch 'feat/support-for-update-records-from-SDK' of github.com…
frascuchon Oct 15, 2023
02b5c29
chore: Update changelog
frascuchon Oct 15, 2023
7537bd7
Merge branch 'feature/support-for-metadata-filtering-and-sorting' int…
frascuchon Oct 16, 2023
15d36df
Merge branch 'feature/support-for-metadata-filtering-and-sorting' int…
frascuchon Oct 16, 2023
2e74cd0
ci: Show file system description
frascuchon Oct 16, 2023
c24c65e
Apply suggestions from code review
frascuchon Oct 16, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 20 additions & 5 deletions src/argilla/client/feedback/dataset/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

import logging
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union
from typing import TYPE_CHECKING, Any, Dict, Generic, Iterable, List, Literal, Optional, TypeVar, Union

from pydantic import ValidationError

Expand Down Expand Up @@ -55,8 +55,10 @@

_LOGGER = logging.getLogger(__name__)

R = TypeVar("R", bound=FeedbackRecord)

class FeedbackDatasetBase(ABC, HuggingFaceDatasetMixin):

class FeedbackDatasetBase(ABC, HuggingFaceDatasetMixin, Generic[R]):
"""Base class with shared functionality for `FeedbackDataset` and `RemoteFeedbackDataset`."""

def __init__(
Expand Down Expand Up @@ -166,10 +168,22 @@ def __init__(

@property
@abstractmethod
def records(self) -> Any:
def records(self) -> Iterable[R]:
"""Returns the records of the dataset."""
pass

@abstractmethod
def update_records(self, records: Union[R, List[R]]) -> None:
"""Updates the records of the dataset.

Args:
records: the records to update the dataset with.

Raises:
ValueError: if the provided `records` are invalid.
"""
pass

@property
def guidelines(self) -> str:
"""Returns the guidelines for annotating the dataset."""
Expand Down Expand Up @@ -364,11 +378,12 @@ def _validate_records(self, records: List[FeedbackRecord]) -> None:

def _parse_and_validate_records(
self,
records: Union[FeedbackRecord, Dict[str, Any], List[Union[FeedbackRecord, Dict[str, Any]]]],
) -> List[FeedbackRecord]:
records: Union[R, Dict[str, Any], List[Union[R, Dict[str, Any]]]],
) -> List[R]:
"""Convenient method for calling `_parse_records` and `_validate_records` in sequence."""
records = self._parse_records(records)
self._validate_records(records)

return records

@requires_dependencies("datasets")
Expand Down
8 changes: 7 additions & 1 deletion src/argilla/client/feedback/dataset/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
)


class FeedbackDataset(FeedbackDatasetBase, ArgillaMixin, UnificationMixin):
class FeedbackDataset(FeedbackDatasetBase["FeedbackRecord"], ArgillaMixin, UnificationMixin):
def __init__(
self,
*,
Expand Down Expand Up @@ -127,6 +127,12 @@ def records(self) -> List["FeedbackRecord"]:
"""Returns the records in the dataset."""
return self._records

def update_records(self, records: Union["FeedbackRecord", List["FeedbackRecord"]]) -> None:
warnings.warn(
"`update_records` method only works for `FeedbackDataset` pushed to Argilla. "
"If your are working with local data, you can just iterate over the records and update them."
)

def __repr__(self) -> str:
"""Returns a string representation of the dataset."""
return f"<FeedbackDataset fields={self.fields} questions={self.questions} guidelines={self.guidelines}>"
Expand Down
30 changes: 25 additions & 5 deletions src/argilla/client/feedback/dataset/remote/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,10 +66,6 @@ def __init__(
and/or attributes.
"""
self._dataset = dataset
# TODO: review why this is here !
self._question_id_to_name = {question.id: question.name for question in self._dataset.questions}
self._question_name_to_id = {value: key for key, value in self._question_id_to_name.items()}
# TODO END

if response_status and not isinstance(response_status, list):
response_status = [response_status]
Expand Down Expand Up @@ -106,6 +102,14 @@ def _client(self) -> "httpx.Client":
"""Returns the `httpx.Client` instance that will be used to send requests to Argilla."""
return self.dataset._client

@property
def _question_id_to_name(self) -> Dict["UUID", str]:
return self.dataset._question_id_to_name_id

@property
def _question_name_to_id(self) -> Dict[str, "UUID"]:
return self.dataset._question_name_to_id

Comment on lines +108 to +115
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we actually need to wrap those properties under the same name? IMO we can just re-use those from self.dataset

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are several places using this variable. The idea should be to remove them and start using the dataset ones, but for this, we need to refactor some remote schemas first. I would like to keep this PR with minimal changes

@allowed_for_roles(roles=[UserRole.owner, UserRole.admin])
def __len__(self) -> int:
"""Returns the number of records in the current `FeedbackDataset` in Argilla."""
Expand Down Expand Up @@ -233,7 +237,7 @@ def _create_from_dataset(
)


class RemoteFeedbackDataset(FeedbackDatasetBase):
class RemoteFeedbackDataset(FeedbackDatasetBase[RemoteFeedbackRecord]):
# TODO: Call super method once the base init contains only commons init attributes
def __init__(
self,
Expand Down Expand Up @@ -304,6 +308,14 @@ def records(self) -> RemoteFeedbackRecords:
"""
return self._records

def update_records(self, records: Union[RemoteFeedbackRecord, List[RemoteFeedbackRecord]]) -> None:
if not isinstance(records, list):
records = [records]

# TODO: Use the batch version of endpoint once is implemented
for record in records:
record.update()

Comment on lines +322 to +329
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Which is the scenario where someone modifies a RemoteFeedbackRecord and then pushes it if not via RemoteFeedbackRecord.update? Are we allowing the assignment there? e.g. record.metadata = {"a": 1}, if so, won't this be conflictive?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Once we have the batch version of records update, the updates should be done using the batch version, since it has a better performance than the per-record update. The record.update is a way to support the current behaviour but i think it should be deprecated and removed

@property
def id(self) -> "UUID":
"""Returns the ID of the dataset in Argilla."""
Expand Down Expand Up @@ -334,6 +346,14 @@ def updated_at(self) -> datetime:
"""Returns the datetime when the dataset was last updated in Argilla."""
return self._updated_at

@property
def _question_id_to_name_id(self) -> Dict["UUID", str]:
return {question.id: question.name for question in self._questions}

@property
def _question_name_to_id(self) -> Dict[str, "UUID"]:
return {question.name: question.id for question in self._questions}

def __repr__(self) -> str:
"""Returns a string representation of the dataset."""
return (
Expand Down
6 changes: 3 additions & 3 deletions src/argilla/client/feedback/dataset/remote/mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,13 @@
from argilla.client.utils import allowed_for_roles

if TYPE_CHECKING:
from argilla.client.feedback.dataset.remote.base import RemoteFeedbackRecordsBase
from argilla.client.feedback.dataset.remote.dataset import RemoteFeedbackRecords


class ArgillaRecordsMixin:
@allowed_for_roles(roles=[UserRole.owner, UserRole.admin])
def __getitem__(
self: "RemoteFeedbackRecordsBase", key: Union[slice, int]
self: "RemoteFeedbackRecords", key: Union[slice, int]
) -> Union["RemoteFeedbackRecord", List["RemoteFeedbackRecord"]]:
"""Returns the record(s) at the given index(es) from Argilla.

Expand Down Expand Up @@ -103,7 +103,7 @@ def __getitem__(

@allowed_for_roles(roles=[UserRole.owner, UserRole.admin])
def __iter__(
self: "RemoteFeedbackRecordsBase",
self: "RemoteFeedbackRecords",
) -> Iterator["RemoteFeedbackRecord"]:
"""Iterates over the `FeedbackRecord`s of the current `FeedbackDataset` in Argilla."""
current_batch = 0
Expand Down
5 changes: 2 additions & 3 deletions src/argilla/client/feedback/schemas/records.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ def to_server_payload(self) -> Dict[str, Any]:
"""Method that will be used to create the payload that will be sent to Argilla
to create a `ResponseSchema` for a `FeedbackRecord`."""
return {
# UUID is not json serializable!!!
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# UUID is not json serializable!!!

Yes, it's not, for the moment we're checking the user_id in the Python SDK in add_records, but we can review this later to just parse it as a str in the to_server_payload method

"user_id": self.user_id,
Comment on lines +91 to 92
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this should be str(self.user_id), right?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It should be, but a lot of tests must be changed to support this. I put the comment to don't forget and tackle in a separate PR

"values": {question_name: value.dict() for question_name, value in self.values.items()}
if self.values is not None
Expand Down Expand Up @@ -194,9 +195,7 @@ class FeedbackRecord(BaseModel):
fields: Dict[str, Union[str, None]]
metadata: Dict[str, Any] = Field(default_factory=dict)
responses: List[ResponseSchema] = Field(default_factory=list)
suggestions: Union[Tuple[SuggestionSchema], List[SuggestionSchema]] = Field(
default_factory=tuple, allow_mutation=False
)
suggestions: Union[Tuple[SuggestionSchema], List[SuggestionSchema]] = Field(default_factory=tuple)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think I already mentioned this, but the tuple may be confusing, I'm more comfortable with the list, in any case, responses is still a list, so we should align that at some point

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree, but we need to do this taking into account that tuples have been used in current releases. My change here just removes the allow_mutation=True. Other extra things should be tackled in separate PRs. Otherwise, a lot of changes could be included here without a need.

external_id: Optional[str] = None

_unified_responses: Optional[Dict[str, List["UnifiedValueSchema"]]] = PrivateAttr(default_factory=dict)
Expand Down
106 changes: 64 additions & 42 deletions src/argilla/client/feedback/schemas/remote/records.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
from argilla.client.feedback.schemas.records import FeedbackRecord, ResponseSchema, SuggestionSchema
from argilla.client.feedback.schemas.remote.shared import RemoteSchema
from argilla.client.sdk.users.models import UserRole
from argilla.client.sdk.v1.datasets import api as datasets_api_v1
from argilla.client.sdk.v1.records import api as records_api_v1
from argilla.client.sdk.v1.suggestions import api as suggestions_api_v1
from argilla.client.utils import allowed_for_roles
Expand All @@ -31,7 +30,7 @@
import httpx

from argilla.client.sdk.v1.datasets.models import FeedbackResponseModel, FeedbackSuggestionModel
from argilla.client.sdk.v1.records.models import FeedbackItemModel
from argilla.client.sdk.v1.records.models import FeedbackRecordModel


class RemoteSuggestionSchema(SuggestionSchema, RemoteSchema):
Expand Down Expand Up @@ -95,12 +94,16 @@ def from_api(cls, payload: "FeedbackResponseModel") -> "RemoteResponseSchema":
return RemoteResponseSchema(
user_id=payload.user_id,
values=payload.values,
# TODO: Review type mismatch between API and SDK
status=payload.status,
inserted_at=payload.inserted_at,
updated_at=payload.updated_at,
)


AllowedSuggestionSchema = Union[RemoteSuggestionSchema, SuggestionSchema]


class RemoteFeedbackRecord(FeedbackRecord, RemoteSchema):
"""Schema for the records of a `RemoteFeedbackDataset`.

Expand All @@ -117,37 +120,29 @@ class RemoteFeedbackRecord(FeedbackRecord, RemoteSchema):
question. Defaults to an empty list.
"""

# TODO: remote record should receive a dataset instead of this
question_name_to_id: Optional[Dict[str, UUID]] = Field(..., exclude=True, repr=False)

responses: List[RemoteResponseSchema] = Field(default_factory=list)
suggestions: Union[Tuple[RemoteSuggestionSchema], List[RemoteSuggestionSchema]] = Field(
default_factory=tuple, allow_mutation=False
)
suggestions: Union[Tuple[AllowedSuggestionSchema], List[AllowedSuggestionSchema]] = Field(default_factory=tuple)

class Config:
allow_mutation = True
validate_assignment = True

def __update_suggestions(
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here, the old __update_suggestions method has been split into 2 steps: 1. validate and normalize/filter suggestions and 2. prepare and call the update endpoints.

def __normalize_suggestions_to_update(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should carefully review the suggestions update/addition workflow, because I think we're adding too much complexity here that can probably be simplified

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, but separate PR. I didn't change the internal logic, just separated the method in 2 different ones.

self,
suggestions: Union[
RemoteSuggestionSchema,
List[RemoteSuggestionSchema],
SuggestionSchema,
List[SuggestionSchema],
Dict[str, Any],
List[Dict[str, Any]],
Dict[str, Any], List[Dict[str, Any]], AllowedSuggestionSchema, List[AllowedSuggestionSchema]
],
) -> None:
"""Updates the suggestions for the record in Argilla. Note that the suggestions
must exist in Argilla to be updated.

Note that this method will update the record in Argilla directly.
) -> List[AllowedSuggestionSchema]:
"""
Normalizes the suggestions to update.

Args:
suggestions: can be a single `RemoteSuggestionSchema` or `SuggestionSchema`,
a list of `RemoteSuggestionSchema` or `SuggestionSchema`, a single
dictionary, or a list of dictionaries. If a dictionary is provided,
it will be converted to a `RemoteSuggestionSchema` internally.
suggestions: can be a single `RemoteSuggestionSchema` or `SuggestionSchema`, a dictionary, a list of
`RemoteSuggestionSchema` or `SuggestionSchema`, or a list of dictionaries. If a dictionary is provided,
it will be converted to a `SuggestionSchema` internally.
"""
if isinstance(suggestions, (dict, SuggestionSchema)):
suggestions = [suggestions]
Expand Down Expand Up @@ -203,34 +198,40 @@ def __update_suggestions(
else:
new_suggestions[suggestion.question_name] = suggestion

for suggestion in new_suggestions.values():
return list(new_suggestions.values())

def __update_suggestions(self, suggestions: List[AllowedSuggestionSchema]) -> None:
"""Updates the suggestions for the record in Argilla.

Note that this method will update the record in Argilla directly.

Args:
suggestions: can be a list of `RemoteSuggestionSchema` or `SuggestionSchema`.
"""

pushed_suggestions = []

for suggestion in suggestions:
if isinstance(suggestion, RemoteSuggestionSchema):
suggestion = suggestion.to_local()
pushed_suggestion = datasets_api_v1.set_suggestion(
# TODO: review the existence of bulk endpoint for record suggestions
pushed_suggestion = records_api_v1.set_suggestion(
client=self.client,
record_id=self.id,
**suggestion.to_server_payload(question_name_to_id=self.question_name_to_id),
)
existing_suggestions[suggestion.question_name] = RemoteSuggestionSchema.from_api(
payload=pushed_suggestion.parsed,
question_id_to_name={value: key for key, value in self.question_name_to_id.items()},
client=self.client,
pushed_suggestions.append(
RemoteSuggestionSchema.from_api(
payload=pushed_suggestion.parsed,
question_id_to_name={value: key for key, value in self.question_name_to_id.items()},
client=self.client,
)
)

self.__dict__["suggestions"] = tuple(existing_suggestions.values())
self.__dict__["suggestions"] = tuple(pushed_suggestions)

@allowed_for_roles(roles=[UserRole.owner, UserRole.admin])
def update(
self,
suggestions: Union[
RemoteSuggestionSchema,
List[RemoteSuggestionSchema],
SuggestionSchema,
List[SuggestionSchema],
Dict[str, Any],
List[Dict[str, Any]],
],
) -> None:
def update(self, suggestions: Optional[AllowedSuggestionSchema] = None) -> None:
"""Update a `RemoteFeedbackRecord`. Currently just `suggestions` are supported.

Note that this method will update the record in Argilla directly.
Expand All @@ -244,7 +245,27 @@ def update(
Raises:
PermissionError: if the user does not have either `owner` or `admin` role.
"""
self.__update_suggestions(suggestions=suggestions)
if suggestions:
suggestions = self.__normalize_suggestions_to_update(suggestions)
else:
suggestions = suggestions or [s for s in self.suggestions]

self.__updated_record_data()
if suggestions:
self.__update_suggestions(suggestions=suggestions)

def __updated_record_data(self) -> None:
response = records_api_v1.update_record(self.client, self.id, self.to_server_payload())

updated_record = self.from_api(
payload=response.parsed,
question_id_to_name={value: key for key, value in self.question_name_to_id.items()}
if self.question_name_to_id
else None,
client=self.client,
)

self.__dict__.update(updated_record.__dict__)

@allowed_for_roles(roles=[UserRole.owner, UserRole.admin])
def delete_suggestions(self, suggestions: Union[RemoteSuggestionSchema, List[RemoteSuggestionSchema]]) -> None:
Expand Down Expand Up @@ -282,7 +303,8 @@ def delete_suggestions(self, suggestions: Union[RemoteSuggestionSchema, List[Rem
self.__dict__["suggestions"] = tuple(existing_suggestions.values())
except Exception as e:
raise RuntimeError(
f"Failed to delete suggestions with IDs `{[suggestion.id for suggestion in delete_suggestions]}` from record with ID `{self.id}` from Argilla."
f"Failed to delete suggestions with IDs `{[suggestion.id for suggestion in delete_suggestions]}` from "
f"record with ID `{self.id}` from Argilla."
) from e

@allowed_for_roles(roles=[UserRole.owner, UserRole.admin])
Expand Down Expand Up @@ -316,7 +338,7 @@ def to_local(self) -> "FeedbackRecord":
@classmethod
def from_api(
cls,
payload: "FeedbackItemModel",
payload: "FeedbackRecordModel",
question_id_to_name: Optional[Dict[UUID, str]] = None,
client: Optional["httpx.Client"] = None,
) -> "RemoteFeedbackRecord":
Expand Down
2 changes: 2 additions & 0 deletions src/argilla/client/feedback/schemas/remote/shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@


class RemoteSchema(BaseModel, ABC):
# TODO(@alvarobartt): Review optional id configuration for remote schemas
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's optional for the moment, as we're not using those neither for the fields nor for the questions, but sure, we should make it mandatory once it's fully covered, as well as the client

id: Optional[UUID] = None
client: Optional[httpx.Client] = None

Expand All @@ -30,6 +31,7 @@ def _client(self) -> Optional[httpx.Client]:
return self.client

class Config:
# TODO(@alvarobart) Not sure if we need this at this level
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Config is inherited AFAIK, but we do need to arbitrary_types_allowed = True here because otherwise we cannot add the httpx.Client type-hint

allow_mutation = False
arbitrary_types_allowed = True

Expand Down
Loading