From 3df90a35e46e0aa5d8615b8806563d2fc404a6d8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Francisco=20Calvo?= Date: Thu, 11 Jan 2024 16:13:48 +0100 Subject: [PATCH 1/5] refactor: move Record-related schemas to their own file and always set dataset_id attribute on Record schema --- CHANGELOG.md | 1 + .../apis/v1/handlers/datasets/datasets.py | 4 +- .../apis/v1/handlers/datasets/records.py | 16 +- .../server/apis/v1/handlers/records.py | 12 +- src/argilla/server/contexts/datasets.py | 20 +- src/argilla/server/contexts/search.py | 4 +- src/argilla/server/schemas/v1/datasets.py | 232 +----------------- src/argilla/server/schemas/v1/questions.py | 45 +++- src/argilla/server/schemas/v1/records.py | 142 +++++++++-- src/argilla/server/schemas/v1/responses.py | 43 +++- .../datasets/test_search_dataset_records.py | 5 + tests/unit/server/api/v1/test_datasets.py | 23 +- tests/unit/server/api/v1/test_questions.py | 2 +- tests/unit/server/api/v1/test_records.py | 14 ++ 14 files changed, 288 insertions(+), 275 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 3c20275b34..07802cca10 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -21,6 +21,7 @@ These are the section headers that we use: - Restore filters from feedback dataset settings ([#4461])(https://github.com/argilla-io/argilla/pull/4461) - Warning on feedback dataset settings when leaving page with unsaved changes ([#4461])(https://github.com/argilla-io/argilla/pull/4461) - Added pydantic v2 support using the python SDK ([#4459](https://github.com/argilla-io/argilla/pull/4459)) +- API v1 responses returning `Record` schema now always include `dataset_id` as attribute. 
([]()) ## Changed diff --git a/src/argilla/server/apis/v1/handlers/datasets/datasets.py b/src/argilla/server/apis/v1/handlers/datasets/datasets.py index c76640dd55..9e587700f2 100644 --- a/src/argilla/server/apis/v1/handlers/datasets/datasets.py +++ b/src/argilla/server/apis/v1/handlers/datasets/datasets.py @@ -27,6 +27,7 @@ from argilla.server.schemas.v1.datasets import ( Dataset, DatasetCreate, + DatasetMetrics, Datasets, DatasetUpdate, Field, @@ -35,7 +36,6 @@ MetadataProperties, MetadataProperty, MetadataPropertyCreate, - Metrics, Question, QuestionCreate, Questions, @@ -175,7 +175,7 @@ async def get_dataset( return dataset -@router.get("/me/datasets/{dataset_id}/metrics", response_model=Metrics) +@router.get("/me/datasets/{dataset_id}/metrics", response_model=DatasetMetrics) async def get_current_user_dataset_metrics( *, db: AsyncSession = Depends(get_async_db), dataset_id: UUID, current_user: User = Security(auth.get_current_user) ): diff --git a/src/argilla/server/apis/v1/handlers/datasets/records.py b/src/argilla/server/apis/v1/handlers/datasets/records.py index d3eed8d33f..4f42c4394b 100644 --- a/src/argilla/server/apis/v1/handlers/datasets/records.py +++ b/src/argilla/server/apis/v1/handlers/datasets/records.py @@ -37,12 +37,6 @@ MetadataQueryParams, Order, RangeFilter, - RecordFilterScope, - RecordIncludeParam, - Records, - RecordsCreate, - RecordsUpdate, - ResponseFilterScope, SearchRecord, SearchRecordsQuery, SearchRecordsResult, @@ -53,9 +47,17 @@ TermsFilter, VectorSettings, ) -from argilla.server.schemas.v1.datasets import ( +from argilla.server.schemas.v1.records import ( Record as RecordSchema, ) +from argilla.server.schemas.v1.records import ( + RecordFilterScope, + RecordIncludeParam, + Records, + RecordsCreate, + RecordsUpdate, +) +from argilla.server.schemas.v1.responses import ResponseFilterScope from argilla.server.search_engine import ( AndFilter, FloatMetadataFilter, diff --git a/src/argilla/server/apis/v1/handlers/records.py 
b/src/argilla/server/apis/v1/handlers/records.py index 3c78d55c42..2e6fc36065 100644 --- a/src/argilla/server/apis/v1/handlers/records.py +++ b/src/argilla/server/apis/v1/handlers/records.py @@ -21,18 +21,16 @@ from argilla.server.contexts import datasets from argilla.server.database import get_async_db -from argilla.server.models import User +from argilla.server.models import Record, User from argilla.server.policies import RecordPolicyV1, authorize -from argilla.server.schemas.v1.datasets import Record as RecordSchema -from argilla.server.schemas.v1.records import RecordUpdate, Response, ResponseCreate +from argilla.server.schemas.v1.records import Record as RecordSchema +from argilla.server.schemas.v1.records import RecordUpdate +from argilla.server.schemas.v1.responses import Response, ResponseCreate from argilla.server.schemas.v1.suggestions import Suggestion, SuggestionCreate, Suggestions from argilla.server.search_engine import SearchEngine, get_search_engine from argilla.server.security import auth from argilla.server.utils import parse_uuids -if TYPE_CHECKING: - from argilla.server.models import Record - DELETE_RECORD_SUGGESTIONS_LIMIT = 100 router = APIRouter(tags=["records"]) @@ -44,7 +42,7 @@ async def _get_record( with_dataset: bool = False, with_suggestions: bool = False, with_vectors: bool = False, -) -> "Record": +) -> Record: record = await datasets.get_record_by_id(db, record_id, with_dataset, with_suggestions, with_vectors) if not record: raise HTTPException( diff --git a/src/argilla/server/contexts/datasets.py b/src/argilla/server/contexts/datasets.py index 44aef78ac9..14cbb8f7eb 100644 --- a/src/argilla/server/contexts/datasets.py +++ b/src/argilla/server/contexts/datasets.py @@ -56,18 +56,25 @@ FieldCreate, MetadataPropertyCreate, QuestionCreate, +) +from argilla.server.schemas.v1.datasets import ( + VectorSettings as VectorSettingsSchema, +) +from argilla.server.schemas.v1.metadata_properties import MetadataPropertyUpdate +from 
argilla.server.schemas.v1.records import ( RecordCreate, RecordIncludeParam, RecordsCreate, + RecordsUpdate, RecordUpdateWithId, - ResponseValueCreate, ) -from argilla.server.schemas.v1.datasets import ( - VectorSettings as VectorSettingsSchema, +from argilla.server.schemas.v1.responses import ( + ResponseCreate, + ResponseUpdate, + ResponseUpsert, + ResponseValueCreate, + ResponseValueUpdate, ) -from argilla.server.schemas.v1.metadata_properties import MetadataPropertyUpdate -from argilla.server.schemas.v1.records import ResponseCreate -from argilla.server.schemas.v1.responses import ResponseUpdate, ResponseUpsert, ResponseValueUpdate from argilla.server.schemas.v1.vectors import Vector as VectorSchema from argilla.server.search_engine import SearchEngine from argilla.server.security.model import User @@ -77,7 +84,6 @@ from argilla.server.schemas.v1.datasets import ( DatasetUpdate, - RecordsUpdate, VectorSettingsCreate, ) from argilla.server.schemas.v1.fields import FieldUpdate diff --git a/src/argilla/server/contexts/search.py b/src/argilla/server/contexts/search.py index 3d647fa0b7..72cb8c8948 100644 --- a/src/argilla/server/contexts/search.py +++ b/src/argilla/server/contexts/search.py @@ -26,11 +26,11 @@ from argilla.server.schemas.v1.datasets import ( FilterScope, MetadataFilterScope, - RecordFilterScope, - ResponseFilterScope, SearchRecordsQuery, SuggestionFilterScope, ) +from argilla.server.schemas.v1.records import RecordFilterScope +from argilla.server.schemas.v1.responses import ResponseFilterScope class SearchRecordsQueryValidator: diff --git a/src/argilla/server/schemas/v1/datasets.py b/src/argilla/server/schemas/v1/datasets.py index 4b39a88fe7..7bad1da5f4 100644 --- a/src/argilla/server/schemas/v1/datasets.py +++ b/src/argilla/server/schemas/v1/datasets.py @@ -16,16 +16,16 @@ from typing import Any, Dict, Generic, List, Literal, Optional, TypeVar, Union from uuid import UUID -from fastapi import HTTPException, Query +from fastapi import Query -from 
argilla.server.enums import RecordInclude, RecordSortField, SimilarityOrder, SortOrder +from argilla.server.enums import SimilarityOrder, SortOrder from argilla.server.pydantic_v1 import BaseModel, PositiveInt, conlist, constr, root_validator, validator from argilla.server.pydantic_v1 import Field as PydanticField from argilla.server.pydantic_v1.generics import GenericModel -from argilla.server.pydantic_v1.utils import GetterDict from argilla.server.schemas.base import UpdateSchema -from argilla.server.schemas.v1.records import RecordUpdate -from argilla.server.schemas.v1.suggestions import Suggestion, SuggestionCreate +from argilla.server.schemas.v1.questions import QuestionDescription, QuestionName, QuestionTitle +from argilla.server.schemas.v1.records import Record, RecordFilterScope +from argilla.server.schemas.v1.responses import ResponseFilterScope from argilla.server.search_engine import TextQuery try: @@ -34,7 +34,7 @@ from typing_extensions import Annotated from argilla.server.enums import DatasetStatus, FieldType, MetadataPropertyType -from argilla.server.models import QuestionSettings, QuestionType, ResponseStatus +from argilla.server.models import QuestionSettings, QuestionType DATASET_NAME_REGEX = r"^(?!-|_)[a-zA-Z0-9-_ ]+$" DATASET_NAME_MIN_LENGTH = 1 @@ -48,14 +48,6 @@ FIELD_CREATE_TITLE_MIN_LENGTH = 1 FIELD_CREATE_TITLE_MAX_LENGTH = 500 -QUESTION_CREATE_NAME_REGEX = r"^(?=.*[a-z0-9])[a-z0-9_-]+$" -QUESTION_CREATE_NAME_MIN_LENGTH = 1 -QUESTION_CREATE_NAME_MAX_LENGTH = 200 -QUESTION_CREATE_TITLE_MIN_LENGTH = 1 -QUESTION_CREATE_TITLE_MAX_LENGTH = 500 -QUESTION_CREATE_DESCRIPTION_MIN_LENGTH = 1 -QUESTION_CREATE_DESCRIPTION_MAX_LENGTH = 1000 - METADATA_PROPERTY_CREATE_NAME_REGEX = r"^(?=.*[a-z0-9])[a-z0-9_-]+$" METADATA_PROPERTY_CREATE_NAME_MIN_LENGTH = 1 METADATA_PROPERTY_CREATE_NAME_MAX_LENGTH = 200 @@ -91,12 +83,6 @@ TERMS_METADATA_PROPERTY_VALUES_MIN_ITEMS = 1 TERMS_METADATA_PROPERTY_VALUES_MAX_ITEMS = 250 -RECORDS_CREATE_MIN_ITEMS = 1 
-RECORDS_CREATE_MAX_ITEMS = 1000 - -RECORDS_UPDATE_MIN_ITEMS = 1 -RECORDS_UPDATE_MAX_ITEMS = 1000 - TERMS_FILTER_VALUES_MIN_ITEMS = 1 TERMS_FILTER_VALUES_MAX_ITEMS = 250 @@ -163,7 +149,7 @@ class ResponseMetrics(BaseModel): draft: int -class Metrics(BaseModel): +class DatasetMetrics(BaseModel): records: RecordMetrics responses: ResponseMetrics @@ -339,32 +325,6 @@ class Questions(BaseModel): items: List[Question] -QuestionName = Annotated[ - constr( - regex=QUESTION_CREATE_NAME_REGEX, - min_length=QUESTION_CREATE_NAME_MIN_LENGTH, - max_length=QUESTION_CREATE_NAME_MAX_LENGTH, - ), - PydanticField(..., description="The name of the question"), -] - -QuestionTitle = Annotated[ - constr( - min_length=QUESTION_CREATE_TITLE_MIN_LENGTH, - max_length=QUESTION_CREATE_TITLE_MAX_LENGTH, - ), - PydanticField(..., description="The title of the question"), -] - -QuestionDescription = Annotated[ - constr( - min_length=QUESTION_CREATE_DESCRIPTION_MIN_LENGTH, - max_length=QUESTION_CREATE_DESCRIPTION_MAX_LENGTH, - ), - PydanticField(..., description="The description of the question"), -] - - class QuestionCreate(BaseModel): name: QuestionName title: QuestionTitle @@ -415,173 +375,6 @@ class VectorSettingsCreate(BaseModel): dimensions: PositiveInt -class ResponseValue(BaseModel): - value: Any - - -class ResponseValueCreate(BaseModel): - value: Any - - -class Response(BaseModel): - id: UUID - values: Optional[Dict[str, ResponseValue]] - status: ResponseStatus - user_id: UUID - inserted_at: datetime - updated_at: datetime - - class Config: - orm_mode = True - - -class RecordGetterDict(GetterDict): - def get(self, key: str, default: Any) -> Any: - if key == "metadata": - return getattr(self._obj, "metadata_", None) - - if key == "responses" and not self._obj.is_relationship_loaded("responses"): - return default - - if key == "suggestions" and not self._obj.is_relationship_loaded("suggestions"): - return default - - if key == "vectors": - if self._obj.is_relationship_loaded("vectors"): - 
return {vector.vector_settings.name: vector.value for vector in self._obj.vectors} - else: - return default - - return super().get(key, default) - - -class Record(BaseModel): - id: UUID - fields: Dict[str, Any] - metadata: Optional[Dict[str, Any]] - external_id: Optional[str] - # TODO: move `responses` to `response` since contextualized endpoint will contains only the user response - # response: Optional[Response] - responses: Optional[List[Response]] - suggestions: Optional[List[Suggestion]] - vectors: Optional[Dict[str, List[float]]] - inserted_at: datetime - updated_at: datetime - - class Config: - orm_mode = True - getter_dict = RecordGetterDict - - -class Records(BaseModel): - items: List[Record] - # TODO(@frascuchon): Make it required once fetch records without metadata filter computes also the total - total: Optional[int] = None - - -class UserSubmittedResponseCreate(BaseModel): - user_id: UUID - values: Dict[str, ResponseValueCreate] - status: Literal[ResponseStatus.submitted] - - -class UserDiscardedResponseCreate(BaseModel): - user_id: UUID - values: Optional[Dict[str, ResponseValueCreate]] - status: Literal[ResponseStatus.discarded] - - -class UserDraftResponseCreate(BaseModel): - user_id: UUID - values: Dict[str, ResponseValueCreate] - status: Literal[ResponseStatus.draft] - - -UserResponseCreate = Annotated[ - Union[UserSubmittedResponseCreate, UserDraftResponseCreate, UserDiscardedResponseCreate], - PydanticField(discriminator="status"), -] - - -class RecordCreate(BaseModel): - fields: Dict[str, Any] - metadata: Optional[Dict[str, Any]] - external_id: Optional[str] - responses: Optional[List[UserResponseCreate]] - suggestions: Optional[List[SuggestionCreate]] - vectors: Optional[Dict[str, List[float]]] - - @validator("responses") - def check_user_id_is_unique(cls, values: Optional[List[UserResponseCreate]]) -> Optional[List[UserResponseCreate]]: - if values is None: - return values - - user_ids = [] - for value in values: - if value.user_id in 
user_ids: - raise ValueError(f"'responses' contains several responses for the same user_id={str(value.user_id)!r}") - user_ids.append(value.user_id) - - return values - - -class RecordsCreate(BaseModel): - items: conlist(item_type=RecordCreate, min_items=RECORDS_CREATE_MIN_ITEMS, max_items=RECORDS_CREATE_MAX_ITEMS) - - -class RecordUpdateWithId(RecordUpdate): - id: UUID - - -class RecordsUpdate(BaseModel): - # TODO: review this definition and align to create model - items: List[RecordUpdateWithId] = PydanticField( - ..., min_items=RECORDS_UPDATE_MIN_ITEMS, max_items=RECORDS_UPDATE_MAX_ITEMS - ) - - -class RecordIncludeParam(BaseModel): - relationships: Optional[List[RecordInclude]] = PydanticField(None, alias="keys") - vectors: Optional[List[str]] = PydanticField(None, alias="vectors") - - @root_validator(skip_on_failure=True) - def check(cls, values: Dict[str, Any]) -> Dict[str, Any]: - relationships = values.get("relationships") - if not relationships: - return values - - vectors = values.get("vectors") - if vectors is not None and len(vectors) > 0 and RecordInclude.vectors in relationships: - # TODO: once we have a exception handler for ValueError in v1, remove HTTPException - # raise ValueError("Cannot include both 'vectors' and 'relationships' in the same request") - raise HTTPException( - status_code=422, - detail="'include' query param cannot have both 'vectors' and 'vectors:vector_settings_name_1,vectors_settings_name_2,...'", - ) - - return values - - @property - def with_responses(self) -> bool: - return self._has_relationships and RecordInclude.responses in self.relationships - - @property - def with_suggestions(self) -> bool: - return self._has_relationships and RecordInclude.suggestions in self.relationships - - @property - def with_all_vectors(self) -> bool: - return self._has_relationships and not self.vectors and RecordInclude.vectors in self.relationships - - @property - def with_some_vector(self) -> bool: - return self.vectors is not None and 
len(self.vectors) > 0 - - @property - def _has_relationships(self): - return self.relationships is not None - - NT = TypeVar("NT", int, float) @@ -723,17 +516,6 @@ class Query(BaseModel): vector: Optional[VectorQuery] = None -class RecordFilterScope(BaseModel): - entity: Literal["record"] - property: Union[Literal[RecordSortField.inserted_at], Literal[RecordSortField.updated_at]] - - -class ResponseFilterScope(BaseModel): - entity: Literal["response"] - question: Optional[QuestionName] - property: Optional[Literal["status"]] - - class SuggestionFilterScope(BaseModel): entity: Literal["suggestion"] question: QuestionName diff --git a/src/argilla/server/schemas/v1/questions.py b/src/argilla/server/schemas/v1/questions.py index aea7cea443..0e6d7312f8 100644 --- a/src/argilla/server/schemas/v1/questions.py +++ b/src/argilla/server/schemas/v1/questions.py @@ -13,12 +13,13 @@ # limitations under the License. from datetime import datetime -from typing import Literal, Optional, Union +from typing import Annotated, Literal, Optional, Union from uuid import UUID -from argilla.server.pydantic_v1 import BaseModel, Field, PositiveInt, conlist +from typing_extensions import Annotated + +from argilla.server.pydantic_v1 import BaseModel, Field, PositiveInt, conlist, constr from argilla.server.schemas.base import UpdateSchema -from argilla.server.schemas.v1.datasets import QuestionDescription, QuestionTitle try: from typing import Annotated @@ -27,6 +28,16 @@ from argilla.server.models import QuestionType +QUESTION_CREATE_NAME_REGEX = r"^(?=.*[a-z0-9])[a-z0-9_-]+$" +QUESTION_CREATE_NAME_MIN_LENGTH = 1 +QUESTION_CREATE_NAME_MAX_LENGTH = 200 + +QUESTION_CREATE_TITLE_MIN_LENGTH = 1 +QUESTION_CREATE_TITLE_MAX_LENGTH = 500 + +QUESTION_CREATE_DESCRIPTION_MIN_LENGTH = 1 +QUESTION_CREATE_DESCRIPTION_MAX_LENGTH = 1000 + class TextQuestionSettings(BaseModel): type: Literal[QuestionType.text] @@ -126,6 +137,34 @@ class RankingQuestionSettingsUpdate(UpdateSchema): ] +QuestionName = Annotated[ 
+ constr( + regex=QUESTION_CREATE_NAME_REGEX, + min_length=QUESTION_CREATE_NAME_MIN_LENGTH, + max_length=QUESTION_CREATE_NAME_MAX_LENGTH, + ), + Field(..., description="The name of the question"), +] + + +QuestionTitle = Annotated[ + constr( + min_length=QUESTION_CREATE_TITLE_MIN_LENGTH, + max_length=QUESTION_CREATE_TITLE_MAX_LENGTH, + ), + Field(..., description="The title of the question"), +] + + +QuestionDescription = Annotated[ + constr( + min_length=QUESTION_CREATE_DESCRIPTION_MIN_LENGTH, + max_length=QUESTION_CREATE_DESCRIPTION_MAX_LENGTH, + ), + Field(..., description="The description of the question"), +] + + class QuestionUpdate(UpdateSchema): title: Optional[QuestionTitle] description: Optional[QuestionDescription] diff --git a/src/argilla/server/schemas/v1/records.py b/src/argilla/server/schemas/v1/records.py index d0fc87937b..fbf3676d27 100644 --- a/src/argilla/server/schemas/v1/records.py +++ b/src/argilla/server/schemas/v1/records.py @@ -13,41 +13,153 @@ # limitations under the License. 
from datetime import datetime -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Literal, Optional, Union from uuid import UUID -from argilla.server.models import ResponseStatus -from argilla.server.pydantic_v1 import BaseModel, Field +from fastapi import HTTPException + +from argilla.server.enums import RecordInclude, RecordSortField +from argilla.server.pydantic_v1 import BaseModel, Field, root_validator, validator +from argilla.server.pydantic_v1.utils import GetterDict from argilla.server.schemas.base import UpdateSchema -from argilla.server.schemas.v1.suggestions import SuggestionCreate +from argilla.server.schemas.v1.responses import Response, UserResponseCreate +from argilla.server.schemas.v1.suggestions import Suggestion, SuggestionCreate + +RECORDS_CREATE_MIN_ITEMS = 1 +RECORDS_CREATE_MAX_ITEMS = 1000 + +RECORDS_UPDATE_MIN_ITEMS = 1 +RECORDS_UPDATE_MAX_ITEMS = 1000 + +class RecordGetterDict(GetterDict): + def get(self, key: str, default: Any) -> Any: + if key == "metadata": + return getattr(self._obj, "metadata_", None) -class ResponseValue(BaseModel): - value: Any + if key == "responses" and not self._obj.is_relationship_loaded("responses"): + return default + if key == "suggestions" and not self._obj.is_relationship_loaded("suggestions"): + return default -class ResponseValueCreate(BaseModel): - value: Any + if key == "vectors": + if self._obj.is_relationship_loaded("vectors"): + return {vector.vector_settings.name: vector.value for vector in self._obj.vectors} + else: + return default + return super().get(key, default) -class Response(BaseModel): + +class Record(BaseModel): id: UUID - values: Optional[Dict[str, ResponseValue]] - status: ResponseStatus - user_id: UUID + fields: Dict[str, Any] + metadata: Optional[Dict[str, Any]] + external_id: Optional[str] + # TODO: move `responses` to `response` since contextualized endpoint will contains only the user response + # response: Optional[Response] + responses: 
Optional[List[Response]] + suggestions: Optional[List[Suggestion]] + vectors: Optional[Dict[str, List[float]]] + dataset_id: UUID inserted_at: datetime updated_at: datetime class Config: orm_mode = True + getter_dict = RecordGetterDict + + +class RecordCreate(BaseModel): + fields: Dict[str, Any] + metadata: Optional[Dict[str, Any]] + external_id: Optional[str] + responses: Optional[List[UserResponseCreate]] + suggestions: Optional[List[SuggestionCreate]] + vectors: Optional[Dict[str, List[float]]] + + @validator("responses") + def check_user_id_is_unique(cls, values: Optional[List[UserResponseCreate]]) -> Optional[List[UserResponseCreate]]: + if values is None: + return values + user_ids = [] + for value in values: + if value.user_id in user_ids: + raise ValueError(f"'responses' contains several responses for the same user_id={str(value.user_id)!r}") + user_ids.append(value.user_id) -class ResponseCreate(BaseModel): - values: Optional[Dict[str, ResponseValueCreate]] - status: ResponseStatus + return values class RecordUpdate(UpdateSchema): metadata_: Optional[Dict[str, Any]] = Field(None, alias="metadata") suggestions: Optional[List[SuggestionCreate]] = None vectors: Optional[Dict[str, List[float]]] + + +class RecordUpdateWithId(RecordUpdate): + id: UUID + + +class RecordIncludeParam(BaseModel): + relationships: Optional[List[RecordInclude]] = Field(None, alias="keys") + vectors: Optional[List[str]] = Field(None, alias="vectors") + + @root_validator(skip_on_failure=True) + def check(cls, values: Dict[str, Any]) -> Dict[str, Any]: + relationships = values.get("relationships") + if not relationships: + return values + + vectors = values.get("vectors") + if vectors is not None and len(vectors) > 0 and RecordInclude.vectors in relationships: + # TODO: once we have a exception handler for ValueError in v1, remove HTTPException + # raise ValueError("Cannot include both 'vectors' and 'relationships' in the same request") + raise HTTPException( + status_code=422, + 
detail="'include' query param cannot have both 'vectors' and 'vectors:vector_settings_name_1,vectors_settings_name_2,...'", + ) + + return values + + @property + def with_responses(self) -> bool: + return self._has_relationships and RecordInclude.responses in self.relationships + + @property + def with_suggestions(self) -> bool: + return self._has_relationships and RecordInclude.suggestions in self.relationships + + @property + def with_all_vectors(self) -> bool: + return self._has_relationships and not self.vectors and RecordInclude.vectors in self.relationships + + @property + def with_some_vector(self) -> bool: + return self.vectors is not None and len(self.vectors) > 0 + + @property + def _has_relationships(self): + return self.relationships is not None + + +class RecordFilterScope(BaseModel): + entity: Literal["record"] + property: Union[Literal[RecordSortField.inserted_at], Literal[RecordSortField.updated_at]] + + +class Records(BaseModel): + items: List[Record] + # TODO(@frascuchon): Make it required once fetch records without metadata filter computes also the total + total: Optional[int] = None + + +class RecordsCreate(BaseModel): + items: List[RecordCreate] = Field(..., min_items=RECORDS_CREATE_MIN_ITEMS, max_items=RECORDS_CREATE_MAX_ITEMS) + + +class RecordsUpdate(BaseModel): + # TODO: review this definition and align to create model + items: List[RecordUpdateWithId] = Field(..., min_items=RECORDS_UPDATE_MIN_ITEMS, max_items=RECORDS_UPDATE_MAX_ITEMS) diff --git a/src/argilla/server/schemas/v1/responses.py b/src/argilla/server/schemas/v1/responses.py index 214f39f41a..b64ca46074 100644 --- a/src/argilla/server/schemas/v1/responses.py +++ b/src/argilla/server/schemas/v1/responses.py @@ -13,13 +13,15 @@ # limitations under the License. 
from datetime import datetime -from typing import Any, Dict, List, Literal, Optional, Union +from typing import Annotated, Any, Dict, List, Literal, Optional, Union from uuid import UUID from fastapi import Body +from typing_extensions import Annotated from argilla.server.models import ResponseStatus from argilla.server.pydantic_v1 import BaseModel, Field +from argilla.server.schemas.v1.questions import QuestionName try: from typing import Annotated @@ -34,6 +36,10 @@ class ResponseValue(BaseModel): value: Any +class ResponseValueCreate(BaseModel): + value: Any + + class ResponseValueUpdate(BaseModel): value: Any @@ -51,6 +57,17 @@ class Config: orm_mode = True +class ResponseCreate(BaseModel): + values: Optional[Dict[str, ResponseValueCreate]] + status: ResponseStatus + + +class ResponseFilterScope(BaseModel): + entity: Literal["response"] + question: Optional[QuestionName] + property: Optional[Literal["status"]] + + class SubmittedResponseUpdate(BaseModel): values: Dict[str, ResponseValueUpdate] status: Literal[ResponseStatus.submitted] @@ -115,3 +132,27 @@ class ResponseBulk(BaseModel): class ResponsesBulk(BaseModel): items: List[ResponseBulk] + + +class UserDraftResponseCreate(BaseModel): + user_id: UUID + values: Dict[str, ResponseValueCreate] + status: Literal[ResponseStatus.draft] + + +class UserDiscardedResponseCreate(BaseModel): + user_id: UUID + values: Optional[Dict[str, ResponseValueCreate]] + status: Literal[ResponseStatus.discarded] + + +class UserSubmittedResponseCreate(BaseModel): + user_id: UUID + values: Dict[str, ResponseValueCreate] + status: Literal[ResponseStatus.submitted] + + +UserResponseCreate = Annotated[ + Union[UserSubmittedResponseCreate, UserDraftResponseCreate, UserDiscardedResponseCreate], + Field(discriminator="status"), +] diff --git a/tests/unit/server/api/v1/datasets/test_search_dataset_records.py b/tests/unit/server/api/v1/datasets/test_search_dataset_records.py index 8efb72708e..e9a98ba215 100644 --- 
a/tests/unit/server/api/v1/datasets/test_search_dataset_records.py +++ b/tests/unit/server/api/v1/datasets/test_search_dataset_records.py @@ -128,6 +128,7 @@ async def test_with_include_responses( "id": str(response_a.id), "status": "submitted", "values": {"input_ok": {"value": "yes"}}, + "record_id": str(record_a.id), "user_id": str(response_a.user_id), "inserted_at": response_a.inserted_at.isoformat(), "updated_at": response_a.updated_at.isoformat(), @@ -136,12 +137,14 @@ async def test_with_include_responses( "id": str(response_b.id), "status": "submitted", "values": {"input_ok": {"value": "no"}}, + "record_id": str(record_a.id), "user_id": str(response_b.user_id), "inserted_at": response_b.inserted_at.isoformat(), "updated_at": response_b.updated_at.isoformat(), }, ], "external_id": record_a.external_id, + "dataset_id": str(record_a.dataset_id), "inserted_at": record_a.inserted_at.isoformat(), "updated_at": record_a.updated_at.isoformat(), }, @@ -160,12 +163,14 @@ async def test_with_include_responses( "id": str(response_c.id), "status": "submitted", "values": {"input_ok": {"value": "yes"}}, + "record_id": str(record_b.id), "user_id": str(response_c.user_id), "inserted_at": response_c.inserted_at.isoformat(), "updated_at": response_c.updated_at.isoformat(), }, ], "external_id": record_b.external_id, + "dataset_id": str(record_b.dataset_id), "inserted_at": record_b.inserted_at.isoformat(), "updated_at": record_b.updated_at.isoformat(), }, diff --git a/tests/unit/server/api/v1/test_datasets.py b/tests/unit/server/api/v1/test_datasets.py index c3bf220040..d0dd1b8cc5 100644 --- a/tests/unit/server/api/v1/test_datasets.py +++ b/tests/unit/server/api/v1/test_datasets.py @@ -47,14 +47,9 @@ FIELD_CREATE_TITLE_MAX_LENGTH, METADATA_PROPERTY_CREATE_NAME_MAX_LENGTH, METADATA_PROPERTY_CREATE_TITLE_MAX_LENGTH, - QUESTION_CREATE_DESCRIPTION_MAX_LENGTH, - QUESTION_CREATE_NAME_MAX_LENGTH, - QUESTION_CREATE_TITLE_MAX_LENGTH, RANKING_OPTIONS_MAX_ITEMS, RATING_OPTIONS_MAX_ITEMS, 
RATING_OPTIONS_MIN_ITEMS, - RECORDS_CREATE_MAX_ITEMS, - RECORDS_CREATE_MIN_ITEMS, TERMS_METADATA_PROPERTY_VALUES_MAX_ITEMS, VALUE_TEXT_OPTION_DESCRIPTION_MAX_LENGTH, VALUE_TEXT_OPTION_TEXT_MAX_LENGTH, @@ -62,6 +57,12 @@ VECTOR_SETTINGS_CREATE_NAME_MAX_LENGTH, VECTOR_SETTINGS_CREATE_TITLE_MAX_LENGTH, ) +from argilla.server.schemas.v1.questions import ( + QUESTION_CREATE_DESCRIPTION_MAX_LENGTH, + QUESTION_CREATE_NAME_MAX_LENGTH, + QUESTION_CREATE_TITLE_MAX_LENGTH, +) +from argilla.server.schemas.v1.records import RECORDS_CREATE_MAX_ITEMS, RECORDS_CREATE_MIN_ITEMS from argilla.server.search_engine import ( FloatMetadataFilter, IntegerMetadataFilter, @@ -3982,6 +3983,7 @@ async def test_search_current_user_dataset_records( "fields": {"input": "input_a", "output": "output_a"}, "metadata": None, "external_id": records[0].external_id, + "dataset_id": str(records[0].dataset_id), "inserted_at": records[0].inserted_at.isoformat(), "updated_at": records[0].updated_at.isoformat(), }, @@ -3993,6 +3995,7 @@ async def test_search_current_user_dataset_records( "fields": {"input": "input_b", "output": "output_b"}, "metadata": {"unit": "test"}, "external_id": records[1].external_id, + "dataset_id": str(records[1].dataset_id), "inserted_at": records[1].inserted_at.isoformat(), "updated_at": records[1].updated_at.isoformat(), }, @@ -4335,6 +4338,7 @@ async def test_search_current_user_dataset_records_with_include( }, "metadata": None, "external_id": records[0].external_id, + "dataset_id": str(records[0].dataset_id), "inserted_at": records[0].inserted_at.isoformat(), "updated_at": records[0].updated_at.isoformat(), }, @@ -4349,6 +4353,7 @@ async def test_search_current_user_dataset_records_with_include( }, "metadata": {"unit": "test"}, "external_id": records[1].external_id, + "dataset_id": str(records[1].dataset_id), "inserted_at": records[1].inserted_at.isoformat(), "updated_at": records[1].updated_at.isoformat(), }, @@ -4367,6 +4372,7 @@ async def 
test_search_current_user_dataset_records_with_include( "id": str(first_owner_response.id), "values": None, "status": "discarded", + "record_id": str(records[0].id), "user_id": str(owner.id), "inserted_at": first_owner_response.inserted_at.isoformat(), "updated_at": first_owner_response.updated_at.isoformat(), @@ -4380,6 +4386,7 @@ async def test_search_current_user_dataset_records_with_include( "output_ok": {"value": "no"}, }, "status": "submitted", + "record_id": str(records[1].id), "user_id": str(owner.id), "inserted_at": second_owner_response.inserted_at.isoformat(), "updated_at": second_owner_response.updated_at.isoformat(), @@ -4481,6 +4488,7 @@ async def test_search_current_user_dataset_records_with_include_vectors( "vector-a": [1.0, 2.0, 3.0], "vector-b": [4.0, 5.0], }, + "dataset_id": str(record_a.dataset_id), "inserted_at": record_a.inserted_at.isoformat(), "updated_at": record_a.updated_at.isoformat(), }, @@ -4495,6 +4503,7 @@ async def test_search_current_user_dataset_records_with_include_vectors( "vectors": { "vector-b": [1.0, 2.0], }, + "dataset_id": str(record_b.dataset_id), "inserted_at": record_b.inserted_at.isoformat(), "updated_at": record_b.updated_at.isoformat(), }, @@ -4507,6 +4516,7 @@ async def test_search_current_user_dataset_records_with_include_vectors( "metadata": None, "external_id": record_c.external_id, "vectors": {}, + "dataset_id": str(record_c.dataset_id), "inserted_at": record_c.inserted_at.isoformat(), "updated_at": record_c.updated_at.isoformat(), }, @@ -4572,6 +4582,7 @@ async def test_search_current_user_dataset_records_with_include_specific_vectors "vector-a": [1.0, 2.0, 3.0], "vector-b": [4.0, 5.0], }, + "dataset_id": str(record_a.dataset_id), "inserted_at": record_a.inserted_at.isoformat(), "updated_at": record_a.updated_at.isoformat(), }, @@ -4586,6 +4597,7 @@ async def test_search_current_user_dataset_records_with_include_specific_vectors "vectors": { "vector-b": [1.0, 2.0], }, + "dataset_id": str(record_b.dataset_id), 
"inserted_at": record_b.inserted_at.isoformat(), "updated_at": record_b.updated_at.isoformat(), }, @@ -4598,6 +4610,7 @@ async def test_search_current_user_dataset_records_with_include_specific_vectors "metadata": None, "external_id": record_c.external_id, "vectors": {}, + "dataset_id": str(record_c.dataset_id), "inserted_at": record_c.inserted_at.isoformat(), "updated_at": record_c.updated_at.isoformat(), }, diff --git a/tests/unit/server/api/v1/test_questions.py b/tests/unit/server/api/v1/test_questions.py index e7ba614c0f..5bbc788005 100644 --- a/tests/unit/server/api/v1/test_questions.py +++ b/tests/unit/server/api/v1/test_questions.py @@ -18,7 +18,7 @@ import pytest from argilla._constants import API_KEY_HEADER_NAME from argilla.server.models import DatasetStatus, Question, UserRole -from argilla.server.schemas.v1.datasets import QUESTION_CREATE_DESCRIPTION_MAX_LENGTH, QUESTION_CREATE_TITLE_MAX_LENGTH +from argilla.server.schemas.v1.questions import QUESTION_CREATE_DESCRIPTION_MAX_LENGTH, QUESTION_CREATE_TITLE_MAX_LENGTH from sqlalchemy import func, select from tests.factories import ( diff --git a/tests/unit/server/api/v1/test_records.py b/tests/unit/server/api/v1/test_records.py index 2a031afbcc..cb681f6f45 100644 --- a/tests/unit/server/api/v1/test_records.py +++ b/tests/unit/server/api/v1/test_records.py @@ -98,6 +98,7 @@ async def test_get_record(self, async_client: "AsyncClient", role: UserRole): "responses": None, "suggestions": [], "vectors": None, + "dataset_id": str(dataset.id), "inserted_at": record.inserted_at.isoformat(), "updated_at": record.updated_at.isoformat(), } @@ -213,6 +214,7 @@ async def test_update_record(self, async_client: "AsyncClient", mock_search_engi vector_settings_1.name: [2.0, 2.0, 2.0, 2.0, 2.0], vector_settings_2.name: [3.0, 3.0, 3.0, 3.0, 3.0], }, + "dataset_id": str(dataset.id), "inserted_at": record.inserted_at.isoformat(), "updated_at": record.updated_at.isoformat(), } @@ -245,6 +247,7 @@ async def 
test_update_record_with_null_metadata( "responses": [], "suggestions": [], "vectors": {}, + "dataset_id": str(dataset.id), "inserted_at": record.inserted_at.isoformat(), "updated_at": record.updated_at.isoformat(), } @@ -271,6 +274,7 @@ async def test_update_record_with_no_metadata( "responses": None, "suggestions": [], "vectors": {}, + "dataset_id": str(dataset.id), "inserted_at": record.inserted_at.isoformat(), "updated_at": record.updated_at.isoformat(), } @@ -304,6 +308,7 @@ async def test_update_record_with_list_terms_metadata( "responses": [], "suggestions": [], "vectors": {}, + "dataset_id": str(dataset.id), "inserted_at": record.inserted_at.isoformat(), "updated_at": record.updated_at.isoformat(), } @@ -330,6 +335,7 @@ async def test_update_record_with_no_suggestions( "responses": None, "suggestions": [], "vectors": {}, + "dataset_id": str(record.dataset_id), "inserted_at": record.inserted_at.isoformat(), "updated_at": record.updated_at.isoformat(), } @@ -567,6 +573,7 @@ async def test_create_record_response_with_required_questions( "id": str(UUID(response_body["id"])), "values": responses["values"], "status": response_status, + "record_id": str(record.id), "user_id": str(owner.id), "inserted_at": datetime.fromisoformat(response_body["inserted_at"]).isoformat(), "updated_at": datetime.fromisoformat(response_body["updated_at"]).isoformat(), @@ -622,6 +629,7 @@ async def test_create_record_response_with_missing_required_questions( "id": str(UUID(response_body["id"])), "values": responses["values"], "status": response_status, + "record_id": str(record.id), "user_id": str(owner.id), "inserted_at": datetime.fromisoformat(response_body["inserted_at"]).isoformat(), "updated_at": datetime.fromisoformat(response_body["updated_at"]).isoformat(), @@ -672,6 +680,7 @@ async def test_create_record_response_with_submitted_status( "id": str(UUID(response_body["id"])), "values": {question.name: {"value": response_value}}, "status": ResponseStatus.submitted, + "record_id": 
str(record.id), "user_id": str(owner.id), "inserted_at": datetime.fromisoformat(response_body["inserted_at"]).isoformat(), "updated_at": datetime.fromisoformat(response_body["updated_at"]).isoformat(), @@ -727,6 +736,7 @@ async def test_create_record_response_with_non_submitted_status( "id": str(UUID(response_body["id"])), "values": {question.name: {"value": response_value}}, "status": response_status.value, + "record_id": str(record.id), "user_id": str(owner.id), "inserted_at": datetime.fromisoformat(response_body["inserted_at"]).isoformat(), "updated_at": datetime.fromisoformat(response_body["updated_at"]).isoformat(), @@ -982,6 +992,7 @@ async def test_create_record_response( "output_ok": {"value": "yes"}, }, "status": status, + "record_id": str(record.id), "user_id": str(owner.id), "inserted_at": datetime.fromisoformat(response_body["inserted_at"]).isoformat(), "updated_at": datetime.fromisoformat(response_body["updated_at"]).isoformat(), @@ -1018,6 +1029,7 @@ async def test_create_record_response_without_values( "id": str(UUID(response_body["id"])), "values": None, "status": status, + "record_id": str(record.id), "user_id": str(owner.id), "inserted_at": datetime.fromisoformat(response_body["inserted_at"]).isoformat(), "updated_at": datetime.fromisoformat(response_body["updated_at"]).isoformat(), @@ -1069,6 +1081,7 @@ async def test_create_record_response_for_user_role(self, async_client: "AsyncCl "output_ok": {"value": "yes"}, }, "status": "submitted", + "record_id": str(record.id), "user_id": str(user.id), "inserted_at": datetime.fromisoformat(response_body["inserted_at"]).isoformat(), "updated_at": datetime.fromisoformat(response_body["updated_at"]).isoformat(), @@ -1342,6 +1355,7 @@ async def test_delete_record( "fields": record.fields, "metadata": None, "external_id": record.external_id, + "dataset_id": str(record.dataset_id), "inserted_at": record.inserted_at.isoformat(), "updated_at": record.updated_at.isoformat(), } From 
4042a6e61b1adf3138cd3a0546ac28c24bfb239d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Francisco=20Calvo?= Date: Thu, 11 Jan 2024 16:27:14 +0100 Subject: [PATCH 2/5] chore: Update CHANGELOG.md --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 07802cca10..bbb242c527 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -22,6 +22,7 @@ These are the section headers that we use: - Warning on feedback dataset settings when leaving page with unsaved changes ([#4461])(https://github.com/argilla-io/argilla/pull/4461) - Added pydantic v2 support using the python SDK ([#4459](https://github.com/argilla-io/argilla/pull/4459)) - API v1 responses returning `Record` schema now always include `dataset_id` as attribute. ([]()) +- API v1 responses returning `Response` schema now always include `record_id` as attribute. ([]()) ## Changed From 0e2812654aca6cd8a3f1d3b7750ec458faa0cfac Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Francisco=20Calvo?= Date: Thu, 11 Jan 2024 16:36:02 +0100 Subject: [PATCH 3/5] chore: update CHANGELOG.md --- CHANGELOG.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index bbb242c527..fb040293c6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -21,10 +21,10 @@ These are the section headers that we use: - Restore filters from feedback dataset settings ([#4461])(https://github.com/argilla-io/argilla/pull/4461) - Warning on feedback dataset settings when leaving page with unsaved changes ([#4461])(https://github.com/argilla-io/argilla/pull/4461) - Added pydantic v2 support using the python SDK ([#4459](https://github.com/argilla-io/argilla/pull/4459)) -- API v1 responses returning `Record` schema now always include `dataset_id` as attribute. ([]()) -- API v1 responses returning `Response` schema now always include `record_id` as attribute. ([]()) +- API v1 responses returning `Record` schema now always include `dataset_id` as attribute. 
([#4482](https://github.com/argilla-io/argilla/pull/4482)) +- API v1 responses returning `Response` schema now always include `record_id` as attribute. ([#4482](https://github.com/argilla-io/argilla/pull/4482)) -## Changed +### Changed - Module `argilla.cli.server` definitions have been moved to `argilla.server.cli` module. ([#4472](https://github.com/argilla-io/argilla/pull/4472)) From b98cb9d50fa87f9316aa4ffc212a327a7c8d5091 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Francisco=20Calvo?= Date: Thu, 11 Jan 2024 17:56:58 +0100 Subject: [PATCH 4/5] Use ValueError instead of HTTPException Co-authored-by: Francisco Aranda --- src/argilla/server/schemas/v1/records.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/argilla/server/schemas/v1/records.py b/src/argilla/server/schemas/v1/records.py index fbf3676d27..477edcf3b3 100644 --- a/src/argilla/server/schemas/v1/records.py +++ b/src/argilla/server/schemas/v1/records.py @@ -117,9 +117,7 @@ def check(cls, values: Dict[str, Any]) -> Dict[str, Any]: if vectors is not None and len(vectors) > 0 and RecordInclude.vectors in relationships: # TODO: once we have a exception handler for ValueError in v1, remove HTTPException # raise ValueError("Cannot include both 'vectors' and 'relationships' in the same request") - raise HTTPException( - status_code=422, - detail="'include' query param cannot have both 'vectors' and 'vectors:vector_settings_name_1,vectors_settings_name_2,...'", + raise ValueError("'include' query param cannot have both 'vectors' and 'vectors:vector_settings_name_1,vectors_settings_name_2,...'", ) return values From da5b48d587c22a346da9c6eef74f9f8548b83e55 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 11 Jan 2024 16:57:27 +0000 Subject: [PATCH 5/5] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/argilla/server/schemas/v1/records.py | 3 ++- 1 file 
changed, 2 insertions(+), 1 deletion(-) diff --git a/src/argilla/server/schemas/v1/records.py b/src/argilla/server/schemas/v1/records.py index 477edcf3b3..8215c8e153 100644 --- a/src/argilla/server/schemas/v1/records.py +++ b/src/argilla/server/schemas/v1/records.py @@ -117,7 +117,8 @@ def check(cls, values: Dict[str, Any]) -> Dict[str, Any]: if vectors is not None and len(vectors) > 0 and RecordInclude.vectors in relationships: # TODO: once we have a exception handler for ValueError in v1, remove HTTPException # raise ValueError("Cannot include both 'vectors' and 'relationships' in the same request") - raise ValueError("'include' query param cannot have both 'vectors' and 'vectors:vector_settings_name_1,vectors_settings_name_2,...'", + raise ValueError( + "'include' query param cannot have both 'vectors' and 'vectors:vector_settings_name_1,vectors_settings_name_2,...'", ) return values