From 4d04835012c9289d772233f5a39de3de36f6f7b6 Mon Sep 17 00:00:00 2001 From: Tom Aarsen <37621491+tomaarsen@users.noreply.github.com> Date: Wed, 26 Jul 2023 14:33:09 +0200 Subject: [PATCH 1/6] docs: Resolve typos, missing import (#3443) Closes #3429 Hello! # Description Resolves typos and a missing import as described by #3429 **Type of change** - [x] Documentation update **How Has This Been Tested** I ran the snippet that I modified - it worked. **Checklist** - [x] I added relevant documentation - [ ] follows the style guidelines of this project - [x] I did a self-review of my code - [ ] I made corresponding changes to the documentation - [ ] My changes generate no new warnings - [ ] I have added tests that prove my fix is effective or that my feature works - [ ] I filled out [the contributor form](https://tally.so/r/n9XrxK) (see text above) - [ ] I have added relevant notes to the CHANGELOG.md file (See https://keepachangelog.com/) --- - Tom Aarsen --- docs/_source/guides/train_a_model.md | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/docs/_source/guides/train_a_model.md b/docs/_source/guides/train_a_model.md index a7ca2ca680..ed177605df 100644 --- a/docs/_source/guides/train_a_model.md +++ b/docs/_source/guides/train_a_model.md @@ -89,6 +89,7 @@ Options: ```python import argilla as rg +from argilla.training import ArgillaTrainer from datasets import load_dataset dataset_rg = rg.DatasetForTokenClassification.from_datasets( @@ -126,18 +127,18 @@ It is possible to directly include train-test splits to the `prepare_for_trainin *TextClassification* For text classification tasks, it flattens the inputs into separate columns of the returned dataset and converts the annotations of your records into integers and writes them in a label column: -By passing the `framework` variable as `setfit`, `transformers`, `spark-nlp` or `spacy`. This task requires a `DatastForTextClassification`. +By passing the `framework` variable as `setfit`, `transformers`, `spark-nlp` or `spacy`. This task requires a `DatasetForTextClassification`. *TokenClassification* For token classification tasks, it converts the annotations of a record into integers representing BIO tags and writes them in a `ner_tags` column: -By passing the `framework` variable as `transformers`, `spark-nlp` or `spacy`. This task requires a `DatastForTokenClassification`. +By passing the `framework` variable as `transformers`, `spark-nlp` or `spacy`. This task requires a `DatasetForTokenClassification`. *Text2Text* For text generation tasks like `summarization` and translation tasks, it converts the annotations of a record `text` and `target` columns. -By passing the `framework` variable as `transformers` and `spark-nlp`. This task requires a `DatastForText2Text`. +By passing the `framework` variable as `transformers` and `spark-nlp`. This task requires a `DatasetForText2Text`. *Feedback* For feedback-oriented datasets, we currently rely on a fully customizable workflow, which means automation is limited and yet to be thought out. From e77f4166afa689e3b5a338384cb73fd2c43b8d00 Mon Sep 17 00:00:00 2001 From: Agus <56895847+plaguss@users.noreply.github.com> Date: Thu, 27 Jul 2023 13:45:30 +0200 Subject: [PATCH 2/6] docs: update example of listing users with python client (#3454) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # Description Updates the token classification labelling tutorial so that users are listed with the Python client (`rg.User.list()`) instead of a direct API call, keeping the direct API call as a fallback for versions earlier than 1.11.0.
Closes #3453 **Type of change** - [x] Documentation update **Checklist** - [ ] I added relevant documentation - [ ] follows the style guidelines of this project - [ ] I did a self-review of my code - [x] I made corresponding changes to the documentation - [ ] My changes generate no new warnings - [ ] I have added tests that prove my fix is effective or that my feature works - [ ] I filled out [the contributor form](https://tally.so/r/n9XrxK) (see text above) - [ ] I have added relevant notes to the CHANGELOG.md file (See https://keepachangelog.com/) --------- Co-authored-by: Natalia Elvira <126158523+nataliaElv@users.noreply.github.com> --- ...labelling-tokenclassification-basics.ipynb | 51 +++++++++++-------- 1 file changed, 29 insertions(+), 22 deletions(-) diff --git a/docs/_source/tutorials/notebooks/labelling-tokenclassification-basics.ipynb b/docs/_source/tutorials/notebooks/labelling-tokenclassification-basics.ipynb index 096606f6d3..f9b0dd634f 100644 --- a/docs/_source/tutorials/notebooks/labelling-tokenclassification-basics.ipynb +++ b/docs/_source/tutorials/notebooks/labelling-tokenclassification-basics.ipynb @@ -1,7 +1,6 @@ { "cells": [ { - "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -16,7 +15,6 @@ ] }, { - "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -38,7 +36,6 @@ ] }, { - "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -70,7 +67,6 @@ ] }, { - "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -89,7 +85,6 @@ ] }, { - "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -106,7 +101,6 @@ ] }, { - "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -128,7 +122,6 @@ ] }, { - "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -138,7 +131,6 @@ ] }, { - "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -159,7 +151,6 @@ ] }, { - "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -181,7 +172,6 @@ ] }, { - "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -189,7 +179,6 @@ ] }, { - "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -198,6 +187,32 @@ "As a first step, we want to get the list of the users that will be annotating our dataset."
] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# get workspace where the dataset is (or will be) located\n", + "ws = rg.Workspace.from_name(\"my_workspace\")\n", + "# get the list of users with access to the workspace\n", + "# make sure that all users that will work on the dataset have access to the workspace\n", + "# optional: filter users to get only those with annotator role\n", + "users = [u for u in rg.User.list() if u.role == \"annotator\"]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "
\n", + "\n", + "**Note**\n", + "\n", + "If you are using a version earlier than 1.11.0 you will need to call the API directly to get the list of users as is done in the following cell. Note that, in that case, users will be returned as dictionaries and so `users.username` will be `users['username']` instead.\n", + "
" + ] + }, { "cell_type": "code", "execution_count": null, @@ -218,7 +233,6 @@ ] }, { - "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -238,7 +252,6 @@ ] }, { - "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -261,11 +274,10 @@ "chunked_records = [ds[i:i + n] for i in range(0, len(ds), n)]\n", "for chunk in chunked_records:\n", " for idx, record in enumerate(chunk):\n", - " assignments[users[idx]['username']].append(record)" + " assignments[users[idx].username].append(record)" ] }, { - "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -273,7 +285,6 @@ ] }, { - "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -306,7 +317,6 @@ ] }, { - "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -334,7 +344,6 @@ ] }, { - "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -342,7 +351,6 @@ ] }, { - "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -358,7 +366,6 @@ ] }, { - "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -372,7 +379,7 @@ ], "metadata": { "kernelspec": { - "display_name": "argilla", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, @@ -386,7 +393,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.12" + "version": "3.8.10" }, "vscode": { "interpreter": { From 2d0029a0d090cdef42a6cbdb450ec2510aeefde9 Mon Sep 17 00:00:00 2001 From: gabrielmbmb Date: Thu, 27 Jul 2023 15:47:11 +0200 Subject: [PATCH 3/6] feat: bump version to `0.13.3` --- src/argilla/_version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/argilla/_version.py b/src/argilla/_version.py index df0e41ce30..1509e218a4 100644 --- a/src/argilla/_version.py +++ b/src/argilla/_version.py @@ -13,4 +13,4 @@ # limitations under the License. # coding: utf-8 -version = "1.13.2" +version = "1.13.3" From d37ea7ee1e44d51d526f9153876f01cb0a2fa486 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gabriel=20Mart=C3=ADn=20Bl=C3=A1zquez?= Date: Thu, 27 Jul 2023 16:48:09 +0200 Subject: [PATCH 4/6] fix: import errors when importing from `argilla.feedback` (#3471) # Description This PRs fixes the `ModuleNotFoundError` and `ImportError` that occurred when trying to import something from `argilla.feedback` module. The first error was caused because in #3336 the telemetry was included in the `ArgillaTrainer`, but in the `argilla.utils.telemetry` module some optional dependencies used by the server were being imported. The second one was caused because the module in which `HuggingFaceDatasetMixin` (and from which `FeedbackDataset` is inheriting) class lives was importing classes from the `argilla.client.feedback.config` module, which was importing `pyyaml` in its root causing the `ImportError`. Closes #3468 **Type of change** - [x] Bug fix (non-breaking change which fixes an issue) **How Has This Been Tested** I've created a wheel of this branch, installed in a new virtual environment and I was able to import something `argilla.feedback` module without errors. 
**Checklist** - [ ] I added relevant documentation - [x] follows the style guidelines of this project - [x] I did a self-review of my code - [ ] I made corresponding changes to the documentation - [x] My changes generate no new warnings - [x] I have added tests that prove my fix is effective or that my feature works - [ ] I filled out [the contributor form](https://tally.so/r/n9XrxK) (see text above) - [x] I have added relevant notes to the CHANGELOG.md file (See https://keepachangelog.com/) --------- Co-authored-by: Francisco Aranda --- CHANGELOG.md | 7 ++ .../integrations/huggingface/dataset.py | 10 ++- src/argilla/server/errors/api_errors.py | 24 ++++-- src/argilla/utils/telemetry.py | 27 ++----- tests/server/commons/test_telemetry.py | 54 ------------- tests/server/errors/__init__.py | 13 ++++ tests/server/errors/test_api_errors.py | 75 +++++++++++++++++++ 7 files changed, 129 insertions(+), 81 deletions(-) create mode 100644 tests/server/errors/__init__.py create mode 100644 tests/server/errors/test_api_errors.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 018ffd1618..32cdc12b91 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,13 @@ These are the section headers that we use: ## [Unreleased] +## [1.13.3](https://github.com/argilla-io/argilla/compare/v1.13.2...v1.13.3) + +### Fixed + +- Fixed `ModuleNotFoundError` caused because the `argilla.utils.telemetry` module used in the `ArgillaTrainer` was importing an optional dependency not installed by default ([#3471](https://github.com/argilla-io/argilla/pull/3471)). +- Fixed `ImportError` caused because the `argilla.client.feedback.config` module was importing `pyyaml` optional dependency not installed by default ([#3471](https://github.com/argilla-io/argilla/pull/3471)). + ## [1.13.2](https://github.com/argilla-io/argilla/compare/v1.13.1...v1.13.2) ### Fixed diff --git a/src/argilla/client/feedback/integrations/huggingface/dataset.py b/src/argilla/client/feedback/integrations/huggingface/dataset.py index a612db6af5..319cf8dc6d 100644 --- a/src/argilla/client/feedback/integrations/huggingface/dataset.py +++ b/src/argilla/client/feedback/integrations/huggingface/dataset.py @@ -20,7 +20,6 @@ from packaging.version import parse as parse_version -from argilla.client.feedback.config import DatasetConfig, DeprecatedDatasetConfig from argilla.client.feedback.constants import FIELD_TYPE_TO_PYTHON_TYPE from argilla.client.feedback.schemas import FeedbackRecord from argilla.client.feedback.types import AllowedQuestionTypes @@ -188,6 +187,9 @@ def push_to_huggingface( import huggingface_hub from huggingface_hub import DatasetCardData, HfApi + # https://github.com/argilla-io/argilla/issues/3468 + from argilla.client.feedback.config import DatasetConfig + if parse_version(huggingface_hub.__version__) < parse_version("0.14.0"): _LOGGER.warning( "Recommended `huggingface_hub` version is 0.14.0 or higher, and you have" @@ -261,6 +263,12 @@ def from_huggingface(cls: Type["FeedbackDataset"], repo_id: str, *args: Any, **k from huggingface_hub import hf_hub_download from huggingface_hub.utils import EntryNotFoundError + # https://github.com/argilla-io/argilla/issues/3468 + from argilla.client.feedback.config import ( + DatasetConfig, + DeprecatedDatasetConfig, + ) + if parse_version(huggingface_hub.__version__) < parse_version("0.14.0"): _LOGGER.warning( "Recommended `huggingface_hub` version is 0.14.0 or higher, and you have" diff --git a/src/argilla/server/errors/api_errors.py b/src/argilla/server/errors/api_errors.py index 
005b3680ca..c63a8aac66 100644 --- a/src/argilla/server/errors/api_errors.py +++ b/src/argilla/server/errors/api_errors.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import logging from typing import Any, Dict from fastapi import HTTPException, Request @@ -20,11 +19,14 @@ from pydantic import BaseModel from argilla.server.errors.adapter import exception_to_argilla_error -from argilla.server.errors.base_errors import ServerError +from argilla.server.errors.base_errors import ( + EntityAlreadyExistsError, + EntityNotFoundError, + GenericServerError, + ServerError, +) from argilla.utils import telemetry -_LOGGER = logging.getLogger("argilla") - class ErrorDetail(BaseModel): code: str @@ -41,10 +43,22 @@ def __init__(self, error: ServerError): class APIErrorHandler: + @staticmethod + async def track_error(error: ServerError, request: Request): + data = { + "code": error.code, + "user-agent": request.headers.get("user-agent"), + "accept-language": request.headers.get("accept-language"), + } + if isinstance(error, (GenericServerError, EntityNotFoundError, EntityAlreadyExistsError)): + data["type"] = error.type + + telemetry.get_telemetry_client().track_data(action="ServerErrorFound", data=data) + @staticmethod async def common_exception_handler(request: Request, error: Exception): """Wraps errors as custom generic error""" argilla_error = exception_to_argilla_error(error) - await telemetry.track_error(argilla_error, request=request) + await APIErrorHandler.track_error(argilla_error, request=request) return await http_exception_handler(request, ServerHTTPException(argilla_error)) diff --git a/src/argilla/utils/telemetry.py b/src/argilla/utils/telemetry.py index b31538def4..1dfc317e14 100644 --- a/src/argilla/utils/telemetry.py +++ b/src/argilla/utils/telemetry.py @@ -16,19 +16,14 @@ import logging import platform import uuid -from typing import Any, Dict, Optional - -from fastapi import Request +from typing import TYPE_CHECKING, Any, Dict, Optional from argilla.server.commons.models import TaskType -from argilla.server.errors.base_errors import ( - EntityAlreadyExistsError, - EntityNotFoundError, - GenericServerError, - ServerError, -) from argilla.server.settings import settings +if TYPE_CHECKING: + from fastapi import Request + try: from analytics import Client # This import works only for version 2.2.0 except (ImportError, ModuleNotFoundError): @@ -89,25 +84,15 @@ def track_data(self, action: str, data: Dict[str, Any], include_system_info: boo _CLIENT = TelemetryClient() -def _process_request_info(request: Request): +def _process_request_info(request: "Request"): return {header: request.headers.get(header) for header in ["user-agent", "accept-language"]} -async def track_error(error: ServerError, request: Request): - data = {"code": error.code} - if isinstance(error, (GenericServerError, EntityNotFoundError, EntityAlreadyExistsError)): - data["type"] = error.type - - data.update(_process_request_info(request)) - - _CLIENT.track_data(action="ServerErrorFound", data=data) - - async def track_bulk(task: TaskType, records: int): _CLIENT.track_data(action="LogRecordsRequested", data={"task": task, "records": records}) -async def track_login(request: Request, username: str): +async def track_login(request: "Request", username: str): _CLIENT.track_data( action="UserInfoRequested", data={ diff --git a/tests/server/commons/test_telemetry.py b/tests/server/commons/test_telemetry.py index 1bdbc70e0f..5d41cfb2e2 100644 --- 
a/tests/server/commons/test_telemetry.py +++ b/tests/server/commons/test_telemetry.py @@ -17,13 +17,6 @@ import pytest from argilla.server.commons.models import TaskType -from argilla.server.errors import ( - EntityAlreadyExistsError, - EntityNotFoundError, - GenericServerError, - ServerError, -) -from argilla.server.schemas.datasets import Dataset from argilla.utils import telemetry from argilla.utils.telemetry import TelemetryClient, get_telemetry_client from fastapi import Request @@ -57,50 +50,3 @@ async def test_track_bulk(test_telemetry): await telemetry.track_bulk(task=task, records=records) test_telemetry.assert_called_once_with("LogRecordsRequested", {"task": task, "records": records}) - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - ["error", "expected_event"], - [ - ( - EntityNotFoundError(name="mock-name", type="MockType"), - { - "accept-language": None, - "code": "argilla.api.errors::EntityNotFoundError", - "type": "MockType", - "user-agent": None, - }, - ), - ( - EntityAlreadyExistsError(name="mock-name", type=Dataset, workspace="mock-workspace"), - { - "accept-language": None, - "code": "argilla.api.errors::EntityAlreadyExistsError", - "type": "Dataset", - "user-agent": None, - }, - ), - ( - GenericServerError(RuntimeError("This is a mock error")), - { - "accept-language": None, - "code": "argilla.api.errors::GenericServerError", - "type": "builtins.RuntimeError", - "user-agent": None, - }, - ), - ( - ServerError(), - { - "accept-language": None, - "code": "argilla.api.errors::ServerError", - "user-agent": None, - }, - ), - ], -) -async def test_track_error(test_telemetry, error, expected_event): - await telemetry.track_error(error, request=mock_request) - - test_telemetry.assert_called_once_with("ServerErrorFound", expected_event) diff --git a/tests/server/errors/__init__.py b/tests/server/errors/__init__.py new file mode 100644 index 0000000000..55be41799b --- /dev/null +++ b/tests/server/errors/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2021-present, the Recognai S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/server/errors/test_api_errors.py b/tests/server/errors/test_api_errors.py new file mode 100644 index 0000000000..20f88a516d --- /dev/null +++ b/tests/server/errors/test_api_errors.py @@ -0,0 +1,75 @@ +# Copyright 2021-present, the Recognai S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import pytest +from argilla.server.errors.api_errors import APIErrorHandler +from argilla.server.errors.base_errors import ( + EntityAlreadyExistsError, + EntityNotFoundError, + GenericServerError, + ServerError, +) +from argilla.server.schemas.datasets import Dataset +from fastapi import Request + +mock_request = Request(scope={"type": "http", "headers": {}}) + + +@pytest.mark.asyncio +class TestAPIErrorHandler: + @pytest.mark.asyncio + @pytest.mark.parametrize( + ["error", "expected_event"], + [ + ( + EntityNotFoundError(name="mock-name", type="MockType"), + { + "accept-language": None, + "code": "argilla.api.errors::EntityNotFoundError", + "type": "MockType", + "user-agent": None, + }, + ), + ( + EntityAlreadyExistsError(name="mock-name", type=Dataset, workspace="mock-workspace"), + { + "accept-language": None, + "code": "argilla.api.errors::EntityAlreadyExistsError", + "type": "Dataset", + "user-agent": None, + }, + ), + ( + GenericServerError(RuntimeError("This is a mock error")), + { + "accept-language": None, + "code": "argilla.api.errors::GenericServerError", + "type": "builtins.RuntimeError", + "user-agent": None, + }, + ), + ( + ServerError(), + { + "accept-language": None, + "code": "argilla.api.errors::ServerError", + "user-agent": None, + }, + ), + ], + ) + async def test_track_error(self, test_telemetry, error, expected_event): + await APIErrorHandler.track_error(error, request=mock_request) + + test_telemetry.assert_called_once_with("ServerErrorFound", expected_event) From 656155aa0db322a9f0217e4085962e0b029a9d4a Mon Sep 17 00:00:00 2001 From: Gabriel Martín Blázquez Date: Fri, 28 Jul 2023 12:41:43 +0200 Subject: [PATCH 5/6] feat: add dataset `PATCH` endpoint (#3402) # Description This PR adds a new endpoint `PATCH /api/v1/datasets/{dataset_id}` that allows partially updating a dataset. The attributes that can be updated are `name` and `guidelines`. Closes #3396 **Type of change** - [x] New feature (non-breaking change which adds functionality) **How Has This Been Tested** Tested manually in a local deployment; I've also added unit tests covering this new endpoint. **Checklist** - [ ] I added relevant documentation - [x] follows the style guidelines of this project - [x] I did a self-review of my code - [ ] I made corresponding changes to the documentation - [x] My changes generate no new warnings - [x] I have added tests that prove my fix is effective or that my feature works - [ ] I filled out [the contributor form](https://tally.so/r/n9XrxK) (see text above) - [x] I have added relevant notes to the CHANGELOG.md file (See https://keepachangelog.com/) --------- Co-authored-by: Paco Aranda --- CHANGELOG.md | 4 + .../server/apis/v1/handlers/datasets.py | 18 +- src/argilla/server/contexts/datasets.py | 5 + src/argilla/server/policies.py | 9 + src/argilla/server/schemas/v1/datasets.py | 42 ++++- tests/server/api/v1/test_datasets.py | 175 +++++++++++++++++- 6 files changed, 234 insertions(+), 19 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 211cdfec0f..4904e64098 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,10 @@ These are the section headers that we use: ## [Unreleased] +### Added + +- Added `PATCH /api/v1/datasets/{dataset_id}` endpoint to update dataset name and guidelines (See [#3402](https://github.com/argilla-io/argilla/pull/3402)). + ### Changed - Improved efficiency of weak labeling when dataset contains vectors ([#3444](https://github.com/argilla-io/argilla/pull/3444)).
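As a usage sketch of the new endpoint (the server URL, dataset id, and API key below are placeholders; it assumes the `requests` package and the `X-Argilla-Api-Key` authentication header that the tests reference via `API_KEY_HEADER_NAME`):

```python
import requests

API_URL = "http://localhost:6900"  # placeholder: your Argilla server
API_KEY = "owner.apikey"  # placeholder: an owner or admin API key
DATASET_ID = "00000000-0000-0000-0000-000000000000"  # placeholder dataset id

# PATCH is partial: attributes omitted from the payload keep their current values,
# since the update is built with `exclude_unset=True` (see the contexts diff below)
response = requests.patch(
    f"{API_URL}/api/v1/datasets/{DATASET_ID}",
    headers={"X-Argilla-Api-Key": API_KEY},
    json={"name": "new-dataset-name", "guidelines": "Updated annotation guidelines"},
)
response.raise_for_status()
print(response.json())  # the updated dataset, including its new name and guidelines
```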
diff --git a/src/argilla/server/apis/v1/handlers/datasets.py b/src/argilla/server/apis/v1/handlers/datasets.py index 82e46f8a61..6a548ea887 100644 --- a/src/argilla/server/apis/v1/handlers/datasets.py +++ b/src/argilla/server/apis/v1/handlers/datasets.py @@ -28,6 +28,7 @@ Dataset, DatasetCreate, Datasets, + DatasetUpdate, Field, FieldCreate, Fields, @@ -378,7 +379,7 @@ async def publish_dataset( async def delete_dataset( *, db: AsyncSession = Depends(get_async_db), - search_engine=Depends(get_search_engine), + search_engine: SearchEngine = Depends(get_search_engine), dataset_id: UUID, current_user: User = Security(auth.get_current_user), ): @@ -389,3 +390,18 @@ async def delete_dataset( await datasets.delete_dataset(db, search_engine, dataset=dataset) return dataset + + +@router.patch("/datasets/{dataset_id}", response_model=Dataset) +async def update_dataset( + *, + db: AsyncSession = Depends(get_async_db), + dataset_id: UUID, + dataset_update: DatasetUpdate, + current_user: User = Security(auth.get_current_user), +): + dataset = await _get_dataset(db, dataset_id) + + await authorize(current_user, DatasetPolicyV1.update(dataset)) + + return await datasets.update_dataset(db, dataset=dataset, dataset_update=dataset_update) diff --git a/src/argilla/server/contexts/datasets.py b/src/argilla/server/contexts/datasets.py index 5768d299b6..2835bc1ea1 100644 --- a/src/argilla/server/contexts/datasets.py +++ b/src/argilla/server/contexts/datasets.py @@ -132,6 +132,11 @@ async def delete_dataset(db: "AsyncSession", search_engine: SearchEngine, datase return dataset +async def update_dataset(db: "AsyncSession", dataset: Dataset, dataset_update: DatasetCreate) -> Dataset: + params = dataset_update.dict(exclude_unset=True, exclude_none=True) + return await dataset.update(db, **params) + + async def get_field_by_id(db: "AsyncSession", field_id: UUID) -> Union[Field, None]: result = await db.execute(select(Field).filter_by(id=field_id).options(selectinload(Field.dataset))) return result.scalar_one_or_none() diff --git a/src/argilla/server/policies.py b/src/argilla/server/policies.py index 4875f2fa84..451f44d850 100644 --- a/src/argilla/server/policies.py +++ b/src/argilla/server/policies.py @@ -279,6 +279,15 @@ async def is_allowed(actor: User) -> bool: return is_allowed + @classmethod + def update(cls, dataset: Dataset) -> PolicyAction: + async def is_allowed(actor: User) -> bool: + return actor.is_owner or ( + actor.is_admin and await _exists_workspace_user_by_user_and_workspace_id(actor, dataset.workspace_id) + ) + + return is_allowed + class FieldPolicyV1: @classmethod diff --git a/src/argilla/server/schemas/v1/datasets.py b/src/argilla/server/schemas/v1/datasets.py index a730fc05de..0f59de546f 100644 --- a/src/argilla/server/schemas/v1/datasets.py +++ b/src/argilla/server/schemas/v1/datasets.py @@ -17,7 +17,15 @@ from typing import Any, Dict, List, Literal, Optional, Union from uuid import UUID -from pydantic import BaseModel, PositiveInt, conlist, constr, root_validator, validator +from pydantic import ( + BaseModel, + Field, + PositiveInt, + conlist, + constr, + root_validator, + validator, +) from pydantic import Field as PydanticField from pydantic.utils import GetterDict @@ -31,8 +39,11 @@ from argilla.server.models import DatasetStatus, FieldType, QuestionSettings, QuestionType, ResponseStatus -DATASET_CREATE_GUIDELINES_MIN_LENGTH = 1 -DATASET_CREATE_GUIDELINES_MAX_LENGTH = 10000 +DATASET_NAME_REGEX = r"^(?!-|_)[a-zA-Z0-9-_ ]+$" +DATASET_NAME_MIN_LENGTH = 1 +DATASET_NAME_MAX_LENGTH = 200 
+DATASET_GUIDELINES_MIN_LENGTH = 1 +DATASET_GUIDELINES_MAX_LENGTH = 10000 FIELD_CREATE_NAME_REGEX = r"^(?=.*[a-z0-9])[a-z0-9_-]+$" FIELD_CREATE_NAME_MIN_LENGTH = 1 @@ -89,17 +100,28 @@ class Datasets(BaseModel): items: List[Dataset] +DatasetName = Annotated[ + constr(regex=DATASET_NAME_REGEX, min_length=DATASET_NAME_MIN_LENGTH, max_length=DATASET_NAME_MAX_LENGTH), + PydanticField(..., description="Dataset name"), +] + +DatasetGuidelines = Annotated[ + constr(min_length=DATASET_GUIDELINES_MIN_LENGTH, max_length=DATASET_GUIDELINES_MAX_LENGTH), + PydanticField(..., description="Dataset guidelines"), +] + + class DatasetCreate(BaseModel): - name: str - guidelines: Optional[ - constr( - min_length=DATASET_CREATE_GUIDELINES_MIN_LENGTH, - max_length=DATASET_CREATE_GUIDELINES_MAX_LENGTH, - ) - ] + name: DatasetName + guidelines: Optional[DatasetGuidelines] workspace_id: UUID +class DatasetUpdate(BaseModel): + name: Optional[DatasetName] + guidelines: Optional[DatasetGuidelines] + + class RecordMetrics(BaseModel): count: int diff --git a/tests/server/api/v1/test_datasets.py b/tests/server/api/v1/test_datasets.py index f20f6f5876..d87eb72d55 100644 --- a/tests/server/api/v1/test_datasets.py +++ b/tests/server/api/v1/test_datasets.py @@ -34,7 +34,8 @@ Workspace, ) from argilla.server.schemas.v1.datasets import ( - DATASET_CREATE_GUIDELINES_MAX_LENGTH, + DATASET_GUIDELINES_MAX_LENGTH, + DATASET_NAME_MAX_LENGTH, FIELD_CREATE_NAME_MAX_LENGTH, FIELD_CREATE_TITLE_MAX_LENGTH, QUESTION_CREATE_DESCRIPTION_MAX_LENGTH, @@ -1245,16 +1246,25 @@ async def test_create_dataset(client: TestClient, db: "AsyncSession", owner_auth } +@pytest.mark.parametrize( + "dataset_json", + [ + {"name": ""}, + {"name": "123$abc"}, + {"name": "unit@test"}, + {"name": "-test-dataset"}, + {"name": "_test-dataset"}, + {"name": "a" * (DATASET_NAME_MAX_LENGTH + 1)}, + {"name": "test-dataset", "guidelines": ""}, + {"name": "test-dataset", "guidelines": "a" * (DATASET_GUIDELINES_MAX_LENGTH + 1)}, + ], +) @pytest.mark.asyncio -async def test_create_dataset_with_invalid_length_guidelines( - client: TestClient, db: "AsyncSession", owner_auth_header: dict +async def test_create_dataset_with_invalid_settings( + client: TestClient, db: "AsyncSession", owner_auth_header: dict, dataset_json: dict ): workspace = await WorkspaceFactory.create() - dataset_json = { - "name": "name", - "guidelines": "a" * (DATASET_CREATE_GUIDELINES_MAX_LENGTH + 1), - "workspace_id": str(workspace.id), - } + dataset_json.update({"workspace_id": str(workspace.id)}) response = client.post("/api/v1/datasets", headers=owner_auth_header, json=dataset_json) @@ -3218,6 +3228,155 @@ async def test_publish_dataset_with_nonexistent_dataset_id( assert (await db.execute(select(func.count(Record.id)))).scalar() == 0 +@pytest.mark.parametrize( + "payload", + [ + {"name": "New Name", "guidelines": "New Guidelines"}, + {"name": "New Name"}, + {"guidelines": "New Guidelines"}, + {}, + {"name": None, "guidelines": None}, + {"status": DatasetStatus.draft, "workspace_id": str(uuid4())}, + ], +) +@pytest.mark.parametrize("role", [UserRole.admin, UserRole.owner]) +@pytest.mark.asyncio +async def test_update_dataset(client: TestClient, role: UserRole, payload: dict): + dataset = await DatasetFactory.create( + name="Current Name", guidelines="Current Guidelines", status=DatasetStatus.ready + ) + user = await UserFactory.create(role=role, workspaces=[dataset.workspace]) + + response = client.patch( + f"/api/v1/datasets/{dataset.id}", + headers={API_KEY_HEADER_NAME: user.api_key}, + 
json=payload, + ) + + assert response.status_code == 200 + assert response.json() == { + "id": str(dataset.id), + "name": payload.get("name") or dataset.name, + "guidelines": payload.get("guidelines") or dataset.guidelines, + "status": "ready", + "workspace_id": str(dataset.workspace_id), + "inserted_at": dataset.inserted_at.isoformat(), + "updated_at": dataset.updated_at.isoformat(), + } + + +@pytest.mark.parametrize( + "dataset_json", + [ + {"name": ""}, + {"name": "123$abc"}, + {"name": "unit@test"}, + {"name": "-test-dataset"}, + {"name": "_test-dataset"}, + {"name": "a" * (DATASET_NAME_MAX_LENGTH + 1)}, + {"name": "test-dataset", "guidelines": ""}, + {"name": "test-dataset", "guidelines": "a" * (DATASET_GUIDELINES_MAX_LENGTH + 1)}, + ], +) +@pytest.mark.asyncio +async def test_update_dataset_with_invalid_settings( + client: TestClient, db: "AsyncSession", owner_auth_header: dict, dataset_json: dict +): + dataset = await DatasetFactory.create( + name="Current Name", guidelines="Current Guidelines", status=DatasetStatus.ready + ) + + response = client.patch(f"/api/v1/datasets/{dataset.id}", headers=owner_auth_header, json=dataset_json) + + assert response.status_code == 422 + + +@pytest.mark.asyncio +async def test_update_dataset_with_invalid_payload(client: TestClient, owner_auth_header: dict): + dataset = await DatasetFactory.create() + + response = client.patch( + f"/api/v1/datasets/{dataset.id}", + headers=owner_auth_header, + json={"name": {"this": {"is": "invalid"}}, "guidelines": {"this": {"is": "invalid"}}}, + ) + + assert response.status_code == 422 + + +@pytest.mark.asyncio +async def test_update_dataset_with_none_values(client: TestClient, owner_auth_header: dict): + dataset = await DatasetFactory.create() + + response = client.patch( + f"/api/v1/datasets/{dataset.id}", + headers=owner_auth_header, + json={"name": None, "guidelines": None}, + ) + + assert response.status_code == 200 + assert response.json() == { + "id": str(dataset.id), + "name": dataset.name, + "guidelines": dataset.guidelines, + "status": dataset.status, + "workspace_id": str(dataset.workspace_id), + "inserted_at": dataset.inserted_at.isoformat(), + "updated_at": dataset.updated_at.isoformat(), + } + + +@pytest.mark.asyncio +async def test_update_dataset_non_existent(client: TestClient, owner_auth_header: dict): + response = client.patch( + f"/api/v1/datasets/{uuid4()}", + headers=owner_auth_header, + json={"name": "New Name", "guidelines": "New Guidelines"}, + ) + + assert response.status_code == 404 + + +@pytest.mark.asyncio +async def test_update_dataset_as_admin_from_different_workspace(client: TestClient): + dataset = await DatasetFactory.create() + user = await UserFactory.create(role=UserRole.admin) + + response = client.patch( + f"/api/v1/datasets/{dataset.id}", + headers={API_KEY_HEADER_NAME: user.api_key}, + json={"name": "New Name", "guidelines": "New Guidelines"}, + ) + + assert response.status_code == 403 + + +@pytest.mark.asyncio +async def test_update_dataset_as_annotator(client: TestClient): + dataset = await DatasetFactory.create() + user = await UserFactory.create(role=UserRole.annotator, workspaces=[dataset.workspace]) + + response = client.patch( + f"/api/v1/datasets/{dataset.id}", + headers={API_KEY_HEADER_NAME: user.api_key}, + json={"name": "New Name", "guidelines": "New Guidelines"}, + ) + + assert response.status_code == 403 + + +@pytest.mark.asyncio +async def test_update_dataset_without_authentication(client: TestClient): + dataset = await DatasetFactory.create() + + response = 
client.patch( + f"/api/v1/datasets/{dataset.id}", + json={"name": "New Name", "guidelines": "New Guidelines"}, + ) + + assert response.status_code == 401 + + @pytest.mark.asyncio async def test_delete_dataset( client: TestClient, db: "AsyncSession", mock_search_engine: SearchEngine, owner: User, owner_auth_header: dict From 642f32219a26f5210b03dd08339a616982736d3a Mon Sep 17 00:00:00 2001 From: Gabriel Martín Blázquez Date: Fri, 28 Jul 2023 12:43:58 +0200 Subject: [PATCH 6/6] feat: field `PATCH` endpoint (#3421) # Description This PR adds a new `PATCH /api/v1/fields/{field_id}` endpoint that allows partially updating a field of a `FeedbackDataset` in the API. The attributes that can be updated using this new endpoint are: - `title` - `settings.use_markdown` To be able to update the `dict` column, the `CRUDMixin.fill` method needed to be updated to detect whether the column being filled is a `dict` column. In that case, it iterates over the keys and values of the received dict and sets them on the `dict` column one by one, so that existing keys are merged with the new ones instead of the whole dict being overridden. Closes #3397 **Type of change** - [x] New feature (non-breaking change which adds functionality) **How Has This Been Tested** I've tested it manually in a local environment and I've added unit tests. **Checklist** - [ ] I added relevant documentation - [x] follows the style guidelines of this project - [x] I did a self-review of my code - [ ] I made corresponding changes to the documentation - [x] My changes generate no new warnings - [x] I have added tests that prove my fix is effective or that my feature works - [ ] I filled out [the contributor form](https://tally.so/r/n9XrxK) (see text above) - [x] I have added relevant notes to the CHANGELOG.md file (See https://keepachangelog.com/) --------- Co-authored-by: Paco Aranda --- CHANGELOG.md | 1 + src/argilla/server/apis/v1/handlers/fields.py | 34 ++++-- src/argilla/server/contexts/datasets.py | 6 + src/argilla/server/models/mixins.py | 6 + src/argilla/server/models/models.py | 3 +- src/argilla/server/policies.py | 10 ++ src/argilla/server/schemas/v1/datasets.py | 24 ++-- src/argilla/server/schemas/v1/fields.py | 12 +- tests/factories.py | 2 +- tests/server/api/v1/test_fields.py | 110 +++++++++++++++++- 10 files changed, 187 insertions(+), 21 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 4904e64098..ebdaea2c70 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,6 +18,7 @@ These are the section headers that we use: ### Added +- Added `PATCH /api/v1/fields/{field_id}` endpoint to update the field title and markdown settings (See [#3421](https://github.com/argilla-io/argilla/pull/3421)). - Added `PATCH /api/v1/datasets/{dataset_id}` endpoint to update dataset name and guidelines (See [#3402](https://github.com/argilla-io/argilla/pull/3402)).
### Changed diff --git a/src/argilla/server/apis/v1/handlers/fields.py b/src/argilla/server/apis/v1/handlers/fields.py index 92307f9394..f5848a8a9e 100644 --- a/src/argilla/server/apis/v1/handlers/fields.py +++ b/src/argilla/server/apis/v1/handlers/fields.py @@ -20,25 +20,45 @@ from argilla.server.contexts import datasets from argilla.server.database import get_async_db from argilla.server.policies import FieldPolicyV1, authorize -from argilla.server.schemas.v1.fields import Field +from argilla.server.schemas.v1.fields import Field, FieldUpdate from argilla.server.security import auth from argilla.server.security.model import User router = APIRouter(tags=["fields"]) -@router.delete("/fields/{field_id}", response_model=Field) -async def delete_field( - *, db: AsyncSession = Depends(get_async_db), field_id: UUID, current_user: User = Security(auth.get_current_user) -): +async def _get_field(db: "AsyncSession", field_id: UUID) -> Field: field = await datasets.get_field_by_id(db, field_id) - - await authorize(current_user, FieldPolicyV1.delete(field)) if not field: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail=f"Field with id `{field_id}` not found", ) + return field + + +@router.patch("/fields/{field_id}", response_model=Field) +async def update_field( + *, + db: AsyncSession = Depends(get_async_db), + field_id: UUID, + field_update: FieldUpdate, + current_user: User = Security(auth.get_current_user), +): + field = await _get_field(db, field_id) + + await authorize(current_user, FieldPolicyV1.update(field)) + + return await datasets.update_field(db, field, field_update) + + +@router.delete("/fields/{field_id}", response_model=Field) +async def delete_field( + *, db: AsyncSession = Depends(get_async_db), field_id: UUID, current_user: User = Security(auth.get_current_user) +): + field = await _get_field(db, field_id) + + await authorize(current_user, FieldPolicyV1.delete(field)) # TODO: We should split API v1 into different FastAPI apps so we can customize error management. # After mapping ValueError to 422 errors for API v1 then we can remove this try except. 
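Before the remaining diffs, a small self-contained sketch of the dict-merge behaviour that the updated `CRUDMixin.fill` implements (a plain-Python approximation of the `models/mixins.py` change shown below, with a hypothetical `FieldModel` standing in for the SQLAlchemy model):

```python
def fill(model: object, **kwargs) -> object:
    for key, value in kwargs.items():
        if not hasattr(model, key):
            raise AttributeError(f"Model `{model.__class__.__name__}` has no attribute `{key}`")
        # If the value is a dict, merge it key by key into the existing column
        # value, so only the received keys are updated and the rest survive
        if isinstance(value, dict):
            current = getattr(model, key) or {}
            current.update(value)
            value = current
        setattr(model, key, value)
    return model


class FieldModel:  # hypothetical stand-in for the SQLAlchemy `Field` model
    def __init__(self) -> None:
        self.title = "Original title"
        self.settings = {"type": "text", "use_markdown": False}


field = FieldModel()
fill(field, settings={"use_markdown": True})
# "type" is preserved because only the received keys were updated
assert field.settings == {"type": "text", "use_markdown": True}
```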
diff --git a/src/argilla/server/contexts/datasets.py b/src/argilla/server/contexts/datasets.py index 2835bc1ea1..176e2c2d3f 100644 --- a/src/argilla/server/contexts/datasets.py +++ b/src/argilla/server/contexts/datasets.py @@ -42,6 +42,7 @@ if TYPE_CHECKING: from sqlalchemy.ext.asyncio import AsyncSession + from argilla.server.schemas.v1.fields import FieldUpdate from argilla.server.schemas.v1.suggestions import SuggestionCreate LIST_RECORDS_LIMIT = 20 @@ -161,6 +162,11 @@ async def create_field(db: "AsyncSession", dataset: Dataset, field_create: Field ) +async def update_field(db: "AsyncSession", field: Field, field_update: "FieldUpdate") -> Field: + params = field_update.dict(exclude_unset=True, exclude_none=True) + return await field.update(db, **params) + + async def delete_field(db: "AsyncSession", field: Field) -> Field: if field.dataset.is_ready: raise ValueError("Fields cannot be deleted for a published dataset") diff --git a/src/argilla/server/models/mixins.py b/src/argilla/server/models/mixins.py index 013438c186..82c0ed6b2a 100644 --- a/src/argilla/server/models/mixins.py +++ b/src/argilla/server/models/mixins.py @@ -50,6 +50,12 @@ def fill(self, **kwargs: Any) -> Self: for key, value in kwargs.items(): if not hasattr(self, key): raise AttributeError(f"Model `{self.__class__.__name__}` has no attribute `{key}`") + # If the value is a dict, set value for each key one by one, as we want to update only the keys that are in + # `value` and not override the whole dict. + if isinstance(value, dict): + dict_col = getattr(self, key) or {} + dict_col.update(value) + value = dict_col setattr(self, key, value) return self diff --git a/src/argilla/server/models/models.py b/src/argilla/server/models/models.py index a5e54b6719..c1ffcf7921 100644 --- a/src/argilla/server/models/models.py +++ b/src/argilla/server/models/models.py @@ -20,6 +20,7 @@ from pydantic import parse_obj_as from sqlalchemy import JSON, ForeignKey, Text, UniqueConstraint, and_ from sqlalchemy import Enum as SAEnum +from sqlalchemy.ext.mutable import MutableDict from sqlalchemy.orm import Mapped, mapped_column, relationship from argilla.server.models.base import DatabaseModel @@ -60,7 +61,7 @@ class Field(DatabaseModel): name: Mapped[str] = mapped_column(Text, index=True) title: Mapped[str] = mapped_column(Text) required: Mapped[bool] = mapped_column(default=False) - settings: Mapped[dict] = mapped_column(JSON, default={}) + settings: Mapped[dict] = mapped_column(MutableDict.as_mutable(JSON), default={}) dataset_id: Mapped[UUID] = mapped_column(ForeignKey("datasets.id", ondelete="CASCADE"), index=True) dataset: Mapped["Dataset"] = relationship(back_populates="fields") diff --git a/src/argilla/server/policies.py b/src/argilla/server/policies.py index 451f44d850..031e21bd86 100644 --- a/src/argilla/server/policies.py +++ b/src/argilla/server/policies.py @@ -290,6 +290,16 @@ async def is_allowed(actor: User) -> bool: class FieldPolicyV1: + @classmethod + def update(cls, field: Field) -> PolicyAction: + async def is_allowed(actor: User) -> bool: + return actor.is_owner or ( + actor.is_admin + and await _exists_workspace_user_by_user_and_workspace_id(actor, field.dataset.workspace_id) + ) + + return is_allowed + @classmethod def delete(cls, field: Field) -> PolicyAction: async def is_allowed(actor: User) -> bool: diff --git a/src/argilla/server/schemas/v1/datasets.py b/src/argilla/server/schemas/v1/datasets.py index 0f59de546f..c807dea20f 100644 --- a/src/argilla/server/schemas/v1/datasets.py +++ 
b/src/argilla/server/schemas/v1/datasets.py @@ -160,16 +160,22 @@ class Fields(BaseModel): items: List[Field] +FieldName = Annotated[ + constr( + regex=FIELD_CREATE_NAME_REGEX, min_length=FIELD_CREATE_NAME_MIN_LENGTH, max_length=FIELD_CREATE_NAME_MAX_LENGTH + ), + PydanticField(..., description="The name of the field"), +] + +FieldTitle = Annotated[ + constr(min_length=FIELD_CREATE_TITLE_MIN_LENGTH, max_length=FIELD_CREATE_TITLE_MAX_LENGTH), + PydanticField(..., description="The title of the field"), +] + + class FieldCreate(BaseModel): - name: constr( - regex=FIELD_CREATE_NAME_REGEX, - min_length=FIELD_CREATE_NAME_MIN_LENGTH, - max_length=FIELD_CREATE_NAME_MAX_LENGTH, - ) - title: constr( - min_length=FIELD_CREATE_TITLE_MIN_LENGTH, - max_length=FIELD_CREATE_TITLE_MAX_LENGTH, - ) + name: FieldName + title: FieldTitle required: Optional[bool] settings: TextFieldSettings diff --git a/src/argilla/server/schemas/v1/fields.py b/src/argilla/server/schemas/v1/fields.py index 39489bb5a4..a902e82054 100644 --- a/src/argilla/server/schemas/v1/fields.py +++ b/src/argilla/server/schemas/v1/fields.py @@ -13,12 +13,13 @@ # limitations under the License. from datetime import datetime -from typing import Literal +from typing import Literal, Optional from uuid import UUID from pydantic import BaseModel from argilla.server.models import FieldType +from argilla.server.schemas.v1.datasets import FieldTitle class TextFieldSettings(BaseModel): @@ -38,3 +39,12 @@ class Field(BaseModel): class Config: orm_mode = True + + +class TextFieldSettingsUpdate(BaseModel): + use_markdown: Optional[bool] + + +class FieldUpdate(BaseModel): + title: Optional[FieldTitle] + settings: Optional[TextFieldSettingsUpdate] diff --git a/tests/factories.py b/tests/factories.py index 6b6f1a15bc..2a895af9dc 100644 --- a/tests/factories.py +++ b/tests/factories.py @@ -232,7 +232,7 @@ class Meta: class TextFieldFactory(FieldFactory): - settings = {"type": FieldType.text.value} + settings = {"type": FieldType.text.value, "use_markdown": False} class QuestionFactory(BaseFactory): diff --git a/tests/server/api/v1/test_fields.py b/tests/server/api/v1/test_fields.py index 25d3ab1f1c..5674229f64 100644 --- a/tests/server/api/v1/test_fields.py +++ b/tests/server/api/v1/test_fields.py @@ -18,16 +18,122 @@ import pytest from argilla._constants import API_KEY_HEADER_NAME -from argilla.server.models import DatasetStatus, Field +from argilla.server.models import DatasetStatus, Field, UserRole from fastapi.testclient import TestClient from sqlalchemy import func, select -from tests.factories import AnnotatorFactory, DatasetFactory, TextFieldFactory +from tests.factories import ( + AnnotatorFactory, + DatasetFactory, + TextFieldFactory, + UserFactory, +) if TYPE_CHECKING: from sqlalchemy.ext.asyncio import AsyncSession +@pytest.mark.parametrize( + "payload", + [ + {"title": "New Title", "settings": {"use_markdown": True}}, + {"title": "New Title"}, + {}, + {"title": None, "settings": None}, + {"name": "New Name", "required": True, "settings": {"type": "unit-test"}, "dataset_id": str(uuid4())}, + ], +) +@pytest.mark.parametrize("role", [UserRole.admin, UserRole.owner]) +@pytest.mark.asyncio +async def test_update_field(client: TestClient, role: UserRole, payload: dict): + field = await TextFieldFactory.create() + user = await UserFactory.create(role=role, workspaces=[field.dataset.workspace]) + + response = client.patch(f"/api/v1/fields/{field.id}", headers={API_KEY_HEADER_NAME: user.api_key}, json=payload) + + settings = payload.get("settings") + if 
settings is None: + use_markdown = field.settings["use_markdown"] + else: + use_markdown = settings.get("use_markdown") or field.settings["use_markdown"] + + assert response.status_code == 200 + assert response.json() == { + "id": str(field.id), + "name": field.name, + "title": payload.get("title") or field.title, + "required": field.required, + "settings": {"type": field.settings["type"], "use_markdown": use_markdown}, + "dataset_id": str(field.dataset.id), + "inserted_at": field.inserted_at.isoformat(), + "updated_at": field.updated_at.isoformat(), + } + + +@pytest.mark.asyncio +async def test_update_field_with_invalid_payload(client: TestClient, owner_auth_header: dict): + field = await TextFieldFactory.create() + + response = client.patch( + f"/api/v1/fields/{field.id}", + headers=owner_auth_header, + json={"title": {"this": "is", "not": "valid"}, "settings": {"use_markdown": "no"}}, + ) + + assert response.status_code == 422 + + +@pytest.mark.asyncio +async def test_update_field_non_existent(client: TestClient, owner_auth_header: dict): + response = client.patch( + f"/api/v1/fields/{uuid4()}", + headers=owner_auth_header, + json={"title": "New Title", "settings": {"use_markdown": True}}, + ) + + assert response.status_code == 404 + + +@pytest.mark.asyncio +async def test_update_field_as_admin_from_different_workspace(client: TestClient): + field = await TextFieldFactory.create() + user = await UserFactory.create(role=UserRole.admin) + + response = client.patch( + f"/api/v1/fields/{field.id}", + headers={API_KEY_HEADER_NAME: user.api_key}, + json={"title": "New Title", "settings": {"use_markdown": True}}, + ) + + assert response.status_code == 403 + + +@pytest.mark.asyncio +async def test_update_field_as_annotator(client: TestClient): + field = await TextFieldFactory.create() + user = await UserFactory.create(role=UserRole.annotator, workspaces=[field.dataset.workspace]) + + response = client.patch( + f"/api/v1/fields/{field.id}", + headers={API_KEY_HEADER_NAME: user.api_key}, + json={"title": "New Title", "settings": {"use_markdown": True}}, + ) + + assert response.status_code == 403 + + +@pytest.mark.asyncio +async def test_update_field_without_authentication(client: TestClient): + field = await TextFieldFactory.create() + + response = client.patch( + f"/api/v1/fields/{field.id}", + json={"title": "New Title", "settings": {"use_markdown": True}}, + ) + + assert response.status_code == 401 + + @pytest.mark.asyncio async def test_delete_field(client: TestClient, db: "AsyncSession", owner_auth_header: dict): field = await TextFieldFactory.create(name="name", title="title")