-
Notifications
You must be signed in to change notification settings - Fork 377
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: Support for update records from SDK #3946
Changes from all commits
1ce3feb
86a3ea1
ce580db
83bac6f
b10ea26
bbd4236
bcdbe9c
1a464f3
875f88a
1daff38
8e24a36
b28c45f
c42b963
2a94c0e
ab77ecb
4b0ccbd
256f797
547b641
5d416d7
a85279a
2787be3
99e5d00
db1f466
f219d7b
04c3829
02b5c29
7537bd7
15d36df
2e74cd0
c24c65e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -105,6 +105,14 @@ def _client(self) -> "httpx.Client": | |
"""Returns the `httpx.Client` instance that will be used to send requests to Argilla.""" | ||
return self.dataset._client | ||
|
||
@property | ||
def _question_id_to_name(self) -> Dict["UUID", str]: | ||
return self.dataset._question_id_to_name_id | ||
|
||
@property | ||
def _question_name_to_id(self) -> Dict[str, "UUID"]: | ||
return self.dataset._question_name_to_id | ||
|
||
@allowed_for_roles(roles=[UserRole.owner, UserRole.admin]) | ||
def __len__(self) -> int: | ||
"""Returns the number of records in the current `FeedbackDataset` in Argilla.""" | ||
|
@@ -240,7 +248,7 @@ def _create_from_dataset( | |
) | ||
|
||
|
||
class RemoteFeedbackDataset(FeedbackDatasetBase): | ||
class RemoteFeedbackDataset(FeedbackDatasetBase[RemoteFeedbackRecord]): | ||
# TODO: Call super method once the base init contains only commons init attributes | ||
def __init__( | ||
self, | ||
|
@@ -311,6 +319,14 @@ def records(self) -> RemoteFeedbackRecords: | |
""" | ||
return self._records | ||
|
||
def update_records(self, records: Union[RemoteFeedbackRecord, List[RemoteFeedbackRecord]]) -> None: | ||
if not isinstance(records, list): | ||
records = [records] | ||
|
||
# TODO: Use the batch version of endpoint once is implemented | ||
for record in records: | ||
record.update() | ||
|
||
Comment on lines
+322
to
+329
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Which is the scenario where someone modifies a There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Once we have the batch version of records update, the updates should be done using the batch version, since it has a better performance than the per-record update. The |
||
@property | ||
def id(self) -> "UUID": | ||
"""Returns the ID of the dataset in Argilla.""" | ||
|
@@ -341,6 +357,14 @@ def updated_at(self) -> datetime: | |
"""Returns the datetime when the dataset was last updated in Argilla.""" | ||
return self._updated_at | ||
|
||
@property | ||
def _question_id_to_name_id(self) -> Dict["UUID", str]: | ||
return {question.id: question.name for question in self._questions} | ||
|
||
@property | ||
def _question_name_to_id(self) -> Dict[str, "UUID"]: | ||
return {question.name: question.id for question in self._questions} | ||
|
||
def __repr__(self) -> str: | ||
"""Returns a string representation of the dataset.""" | ||
return ( | ||
|
Original file line number | Diff line number | Diff line change | ||
---|---|---|---|---|
|
@@ -88,6 +88,7 @@ def to_server_payload(self) -> Dict[str, Any]: | |||
"""Method that will be used to create the payload that will be sent to Argilla | ||||
to create a `ResponseSchema` for a `FeedbackRecord`.""" | ||||
return { | ||||
# UUID is not json serializable!!! | ||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
Yes, it's not, for the moment we're checking the |
||||
"user_id": self.user_id, | ||||
Comment on lines
+91
to
92
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this should be There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It should be, but a lot of tests must be changed to support this. I put the comment to don't forget and tackle in a separate PR |
||||
"values": {question_name: value.dict() for question_name, value in self.values.items()} | ||||
if self.values is not None | ||||
|
@@ -194,9 +195,7 @@ class FeedbackRecord(BaseModel): | |||
fields: Dict[str, Union[str, None]] | ||||
metadata: Dict[str, Any] = Field(default_factory=dict) | ||||
responses: List[ResponseSchema] = Field(default_factory=list) | ||||
suggestions: Union[Tuple[SuggestionSchema], List[SuggestionSchema]] = Field( | ||||
default_factory=tuple, allow_mutation=False | ||||
) | ||||
suggestions: Union[Tuple[SuggestionSchema], List[SuggestionSchema]] = Field(default_factory=tuple) | ||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think I already mentioned this, but the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I agree, but we need to do this taking into account that tuples have been used in current releases. My change here just removes the |
||||
external_id: Optional[str] = None | ||||
|
||||
_unified_responses: Optional[Dict[str, List["UnifiedValueSchema"]]] = PrivateAttr(default_factory=dict) | ||||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -22,7 +22,6 @@ | |
from argilla.client.feedback.schemas.records import FeedbackRecord, ResponseSchema, SuggestionSchema | ||
from argilla.client.feedback.schemas.remote.shared import RemoteSchema | ||
from argilla.client.sdk.users.models import UserRole | ||
from argilla.client.sdk.v1.datasets import api as datasets_api_v1 | ||
from argilla.client.sdk.v1.records import api as records_api_v1 | ||
from argilla.client.sdk.v1.suggestions import api as suggestions_api_v1 | ||
from argilla.client.utils import allowed_for_roles | ||
|
@@ -31,7 +30,7 @@ | |
import httpx | ||
|
||
from argilla.client.sdk.v1.datasets.models import FeedbackResponseModel, FeedbackSuggestionModel | ||
from argilla.client.sdk.v1.records.models import FeedbackItemModel | ||
from argilla.client.sdk.v1.records.models import FeedbackRecordModel | ||
|
||
|
||
class RemoteSuggestionSchema(SuggestionSchema, RemoteSchema): | ||
|
@@ -95,12 +94,16 @@ def from_api(cls, payload: "FeedbackResponseModel") -> "RemoteResponseSchema": | |
return RemoteResponseSchema( | ||
user_id=payload.user_id, | ||
values=payload.values, | ||
# TODO: Review type mismatch between API and SDK | ||
status=payload.status, | ||
inserted_at=payload.inserted_at, | ||
updated_at=payload.updated_at, | ||
) | ||
|
||
|
||
AllowedSuggestionSchema = Union[RemoteSuggestionSchema, SuggestionSchema] | ||
|
||
|
||
class RemoteFeedbackRecord(FeedbackRecord, RemoteSchema): | ||
"""Schema for the records of a `RemoteFeedbackDataset`. | ||
|
||
|
@@ -117,37 +120,29 @@ class RemoteFeedbackRecord(FeedbackRecord, RemoteSchema): | |
question. Defaults to an empty list. | ||
""" | ||
|
||
# TODO: remote record should receive a dataset instead of this | ||
question_name_to_id: Optional[Dict[str, UUID]] = Field(..., exclude=True, repr=False) | ||
|
||
responses: List[RemoteResponseSchema] = Field(default_factory=list) | ||
suggestions: Union[Tuple[RemoteSuggestionSchema], List[RemoteSuggestionSchema]] = Field( | ||
default_factory=tuple, allow_mutation=False | ||
) | ||
suggestions: Union[Tuple[AllowedSuggestionSchema], List[AllowedSuggestionSchema]] = Field(default_factory=tuple) | ||
|
||
class Config: | ||
allow_mutation = True | ||
validate_assignment = True | ||
|
||
def __update_suggestions( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Here, the old |
||
def __normalize_suggestions_to_update( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think we should carefully review the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yes, but separate PR. I didn't change the internal logic, just separated the method in 2 different ones. |
||
self, | ||
suggestions: Union[ | ||
RemoteSuggestionSchema, | ||
List[RemoteSuggestionSchema], | ||
SuggestionSchema, | ||
List[SuggestionSchema], | ||
Dict[str, Any], | ||
List[Dict[str, Any]], | ||
Dict[str, Any], List[Dict[str, Any]], AllowedSuggestionSchema, List[AllowedSuggestionSchema] | ||
], | ||
) -> None: | ||
"""Updates the suggestions for the record in Argilla. Note that the suggestions | ||
must exist in Argilla to be updated. | ||
|
||
Note that this method will update the record in Argilla directly. | ||
) -> List[AllowedSuggestionSchema]: | ||
""" | ||
Normalizes the suggestions to update. | ||
|
||
Args: | ||
suggestions: can be a single `RemoteSuggestionSchema` or `SuggestionSchema`, | ||
a list of `RemoteSuggestionSchema` or `SuggestionSchema`, a single | ||
dictionary, or a list of dictionaries. If a dictionary is provided, | ||
it will be converted to a `RemoteSuggestionSchema` internally. | ||
suggestions: can be a single `RemoteSuggestionSchema` or `SuggestionSchema`, a dictionary, a list of | ||
`RemoteSuggestionSchema` or `SuggestionSchema`, or a list of dictionaries. If a dictionary is provided, | ||
it will be converted to a `SuggestionSchema` internally. | ||
""" | ||
if isinstance(suggestions, (dict, SuggestionSchema)): | ||
suggestions = [suggestions] | ||
|
@@ -203,33 +198,49 @@ def __update_suggestions( | |
else: | ||
new_suggestions[suggestion.question_name] = suggestion | ||
|
||
for suggestion in new_suggestions.values(): | ||
return list(new_suggestions.values()) | ||
|
||
def __update_suggestions(self, suggestions: List[AllowedSuggestionSchema]) -> None: | ||
"""Updates the suggestions for the record in Argilla. | ||
|
||
Note that this method will update the record in Argilla directly. | ||
|
||
Args: | ||
suggestions: can be a list of `RemoteSuggestionSchema` or `SuggestionSchema`. | ||
""" | ||
|
||
pushed_suggestions = [] | ||
|
||
for suggestion in suggestions: | ||
if isinstance(suggestion, RemoteSuggestionSchema): | ||
suggestion = suggestion.to_local() | ||
pushed_suggestion = datasets_api_v1.set_suggestion( | ||
# TODO: review the existence of bulk endpoint for record suggestions | ||
pushed_suggestion = records_api_v1.set_suggestion( | ||
client=self.client, | ||
record_id=self.id, | ||
**suggestion.to_server_payload(question_name_to_id=self.question_name_to_id), | ||
) | ||
existing_suggestions[suggestion.question_name] = RemoteSuggestionSchema.from_api( | ||
payload=pushed_suggestion.parsed, | ||
question_id_to_name={value: key for key, value in self.question_name_to_id.items()}, | ||
client=self.client, | ||
pushed_suggestions.append( | ||
RemoteSuggestionSchema.from_api( | ||
payload=pushed_suggestion.parsed, | ||
question_id_to_name={value: key for key, value in self.question_name_to_id.items()}, | ||
client=self.client, | ||
) | ||
) | ||
|
||
self.__dict__["suggestions"] = tuple(existing_suggestions.values()) | ||
self.__dict__["suggestions"] = tuple(pushed_suggestions) | ||
|
||
@allowed_for_roles(roles=[UserRole.owner, UserRole.admin]) | ||
def update( | ||
self, | ||
suggestions: Union[ | ||
RemoteSuggestionSchema, | ||
List[RemoteSuggestionSchema], | ||
SuggestionSchema, | ||
List[SuggestionSchema], | ||
Dict[str, Any], | ||
List[Dict[str, Any]], | ||
], | ||
suggestions: Optional[ | ||
Union[ | ||
AllowedSuggestionSchema, | ||
Dict[str, Any], | ||
List[AllowedSuggestionSchema], | ||
List[Dict[str, Any]], | ||
] | ||
] = None, | ||
) -> None: | ||
"""Update a `RemoteFeedbackRecord`. Currently just `suggestions` are supported. | ||
|
||
|
@@ -244,7 +255,27 @@ def update( | |
Raises: | ||
PermissionError: if the user does not have either `owner` or `admin` role. | ||
""" | ||
self.__update_suggestions(suggestions=suggestions) | ||
if suggestions: | ||
suggestions = self.__normalize_suggestions_to_update(suggestions) | ||
else: | ||
suggestions = suggestions or [s for s in self.suggestions] | ||
|
||
self.__updated_record_data() | ||
if suggestions: | ||
self.__update_suggestions(suggestions=suggestions) | ||
|
||
def __updated_record_data(self) -> None: | ||
response = records_api_v1.update_record(self.client, self.id, self.to_server_payload()) | ||
|
||
updated_record = self.from_api( | ||
payload=response.parsed, | ||
question_id_to_name={value: key for key, value in self.question_name_to_id.items()} | ||
if self.question_name_to_id | ||
else None, | ||
client=self.client, | ||
) | ||
|
||
self.__dict__.update(updated_record.__dict__) | ||
|
||
@allowed_for_roles(roles=[UserRole.owner, UserRole.admin]) | ||
def delete_suggestions(self, suggestions: Union[RemoteSuggestionSchema, List[RemoteSuggestionSchema]]) -> None: | ||
|
@@ -282,7 +313,8 @@ def delete_suggestions(self, suggestions: Union[RemoteSuggestionSchema, List[Rem | |
self.__dict__["suggestions"] = tuple(existing_suggestions.values()) | ||
except Exception as e: | ||
raise RuntimeError( | ||
f"Failed to delete suggestions with IDs `{[suggestion.id for suggestion in delete_suggestions]}` from record with ID `{self.id}` from Argilla." | ||
f"Failed to delete suggestions with IDs `{[suggestion.id for suggestion in delete_suggestions]}` from " | ||
f"record with ID `{self.id}` from Argilla." | ||
) from e | ||
|
||
@allowed_for_roles(roles=[UserRole.owner, UserRole.admin]) | ||
|
@@ -316,7 +348,7 @@ def to_local(self) -> "FeedbackRecord": | |
@classmethod | ||
def from_api( | ||
cls, | ||
payload: "FeedbackItemModel", | ||
payload: "FeedbackRecordModel", | ||
question_id_to_name: Optional[Dict[UUID, str]] = None, | ||
client: Optional["httpx.Client"] = None, | ||
) -> "RemoteFeedbackRecord": | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we actually need to wrap those properties under the same name? IMO we can just re-use those from
self.dataset
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There are several places using this variable. The idea should be to remove them and start using the dataset ones, but for this, we need to refactor some remote schemas first. I would like to keep this PR with minimal changes