From 83a11e92c73d226b92f387573b4629e0d88588e7 Mon Sep 17 00:00:00 2001 From: Francisco Aranda Date: Mon, 13 Mar 2023 10:04:45 +0100 Subject: [PATCH] Integrate datasets crud endpoints (#2510) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR integrates new DB logic and policies with the current implementation. - The main changes are applied to the dataset service class - The `auth.get_current_user` returns a `models.User` instance - The ONLY change in endpoints is using this new `auth.get_current_user` method In a separate PR the pydantic User class will be cleaned and normalized --------- Co-authored-by: José Francisco Calvo Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: José Francisco Calvo --- .../server/apis/v0/handlers/datasets.py | 59 +++---- .../server/apis/v0/handlers/metrics.py | 4 +- .../server/apis/v0/handlers/records.py | 4 +- .../server/apis/v0/handlers/records_search.py | 2 +- .../server/apis/v0/handlers/records_update.py | 2 +- .../server/apis/v0/handlers/text2text.py | 6 +- .../apis/v0/handlers/text_classification.py | 20 +-- .../text_classification_dataset_settings.py | 16 +- .../apis/v0/handlers/token_classification.py | 6 +- .../token_classification_dataset_settings.py | 20 +-- src/argilla/server/apis/v0/handlers/users.py | 13 +- .../server/apis/v0/handlers/workspaces.py | 12 +- src/argilla/server/daos/datasets.py | 12 +- src/argilla/server/policies.py | 34 ++++- .../security/auth_provider/local/provider.py | 29 +--- src/argilla/server/services/datasets.py | 144 +++++++++++------- .../server/services/storage/service.py | 4 +- tests/server/datasets/test_dao.py | 16 +- tests/server/security/test_provider.py | 10 +- 19 files changed, 220 insertions(+), 193 deletions(-) diff --git a/src/argilla/server/apis/v0/handlers/datasets.py b/src/argilla/server/apis/v0/handlers/datasets.py index 33a2ca1331..cced02c4da 100644 --- a/src/argilla/server/apis/v0/handlers/datasets.py +++ b/src/argilla/server/apis/v0/handlers/datasets.py @@ -17,15 +17,11 @@ from fastapi import APIRouter, Body, Depends, Security from pydantic import parse_obj_as -from sqlalchemy.orm import Session -from argilla.server import database from argilla.server.apis.v0.helpers import deprecate_endpoint from argilla.server.apis.v0.models.commons.params import CommonTaskHandlerDependencies -from argilla.server.contexts import accounts -from argilla.server.daos.datasets import DatasetsDAO -from argilla.server.errors import EntityNotFoundError, ForbiddenOperationError -from argilla.server.policies import DatasetPolicy, is_authorized +from argilla.server.errors import EntityNotFoundError +from argilla.server.models import User from argilla.server.schemas.datasets import ( CopyDatasetRequest, CreateDatasetRequest, @@ -33,7 +29,6 @@ UpdateDatasetRequest, ) from argilla.server.security import auth -from argilla.server.security.model import User, Workspace from argilla.server.services.datasets import DatasetsService router = APIRouter(tags=["datasets"], prefix="/datasets") @@ -50,7 +45,7 @@ async def list_datasets( request_deps: CommonTaskHandlerDependencies = Depends(), service: DatasetsService = Depends(DatasetsService.get_instance), - current_user: User = Security(auth.get_user, scopes=[]), + current_user: User = Security(auth.get_current_user), ) -> List[Dataset]: datasets = service.list( user=current_user, @@ -72,10 +67,10 @@ async def create_dataset( request: CreateDatasetRequest = Body(..., description="The request dataset info"), ws_params: CommonTaskHandlerDependencies = Depends(), datasets: DatasetsService = Depends(DatasetsService.get_instance), - user: User = Security(auth.get_user, scopes=["create:datasets"]), + current_user: User = Security(auth.get_current_user), ) -> Dataset: request.workspace = request.workspace or ws_params.workspace - dataset = datasets.create_dataset(user=user, dataset=request) + dataset = datasets.create_dataset(user=current_user, dataset=request) return Dataset.from_orm(dataset) @@ -90,7 +85,7 @@ def get_dataset( name: str, ds_params: CommonTaskHandlerDependencies = Depends(), service: DatasetsService = Depends(DatasetsService.get_instance), - current_user: User = Security(auth.get_user, scopes=[]), + current_user: User = Security(auth.get_current_user), ) -> Dataset: return Dataset.from_orm( service.find_by_name( @@ -112,7 +107,7 @@ def update_dataset( request: UpdateDatasetRequest, ds_params: CommonTaskHandlerDependencies = Depends(), service: DatasetsService = Depends(DatasetsService.get_instance), - current_user: User = Security(auth.get_user, scopes=[]), + current_user: User = Security(auth.get_current_user), ) -> Dataset: found_ds = service.find_by_name(user=current_user, name=name, workspace=ds_params.workspace) @@ -132,28 +127,22 @@ def update_dataset( ) def delete_dataset( name: str, - request_params: CommonTaskHandlerDependencies = Depends(), - db: Session = Depends(database.get_db), - datasets: DatasetsDAO = Depends(DatasetsDAO.get_instance), - user: User = Security(auth.get_current_user, scopes=[]), + ds_params: CommonTaskHandlerDependencies = Depends(), + service: DatasetsService = Depends(DatasetsService.get_instance), + current_user: User = Security(auth.get_current_user), ): - workspace_name = request_params.workspace - - workspace = accounts.get_workspace_by_name(db, workspace_name=workspace_name) - if not workspace: - raise EntityNotFoundError(name=workspace_name, type=Workspace) - - dataset = datasets.find_by_name_and_workspace(name=name, workspace=workspace.name) - if not dataset: - return - - if not is_authorized(user, DatasetPolicy.delete(dataset)): - raise ForbiddenOperationError( - "You don't have the necessary permissions to delete this dataset. " - "Only dataset creators or administrators can delete datasets" + try: + found_ds = service.find_by_name( + user=current_user, + name=name, + workspace=ds_params.workspace, ) - - datasets.delete_dataset(dataset) + service.delete( + user=current_user, + dataset=found_ds, + ) + except EntityNotFoundError: + pass @router.put( @@ -164,7 +153,7 @@ def close_dataset( name: str, ds_params: CommonTaskHandlerDependencies = Depends(), service: DatasetsService = Depends(DatasetsService.get_instance), - current_user: User = Security(auth.get_user, scopes=[]), + current_user: User = Security(auth.get_current_user), ): found_ds = service.find_by_name(user=current_user, name=name, workspace=ds_params.workspace) service.close(user=current_user, dataset=found_ds) @@ -178,7 +167,7 @@ def open_dataset( name: str, ds_params: CommonTaskHandlerDependencies = Depends(), service: DatasetsService = Depends(DatasetsService.get_instance), - current_user: User = Security(auth.get_user, scopes=[]), + current_user: User = Security(auth.get_current_user), ): found_ds = service.find_by_name(user=current_user, name=name, workspace=ds_params.workspace) service.open(user=current_user, dataset=found_ds) @@ -195,7 +184,7 @@ def copy_dataset( copy_request: CopyDatasetRequest, ds_params: CommonTaskHandlerDependencies = Depends(), service: DatasetsService = Depends(DatasetsService.get_instance), - current_user: User = Security(auth.get_user, scopes=[]), + current_user: User = Security(auth.get_current_user), ) -> Dataset: found = service.find_by_name(user=current_user, name=name, workspace=ds_params.workspace) dataset = service.copy_dataset( diff --git a/src/argilla/server/apis/v0/handlers/metrics.py b/src/argilla/server/apis/v0/handlers/metrics.py index eec6397c7d..df2914a368 100644 --- a/src/argilla/server/apis/v0/handlers/metrics.py +++ b/src/argilla/server/apis/v0/handlers/metrics.py @@ -73,7 +73,7 @@ def configure_router(router: APIRouter, cfg: TaskConfig): def get_dataset_metrics( name: str, request_deps: CommonTaskHandlerDependencies = Depends(), - current_user: User = Security(auth.get_user, scopes=[]), + current_user: User = Security(auth.get_current_user), datasets: DatasetsService = Depends(DatasetsService.get_instance), ) -> List[MetricInfo]: dataset = datasets.find_by_name( @@ -101,7 +101,7 @@ def metric_summary( query: cfg.query, metric_params: MetricSummaryParams = Depends(), request_deps: CommonTaskHandlerDependencies = Depends(), - current_user: User = Security(auth.get_user, scopes=[]), + current_user: User = Security(auth.get_current_user, scopes=[]), datasets: DatasetsService = Depends(DatasetsService.get_instance), metrics: MetricsService = Depends(MetricsService.get_instance), ): diff --git a/src/argilla/server/apis/v0/handlers/records.py b/src/argilla/server/apis/v0/handlers/records.py index 108afb1adf..4e29f1afaf 100644 --- a/src/argilla/server/apis/v0/handlers/records.py +++ b/src/argilla/server/apis/v0/handlers/records.py @@ -53,7 +53,7 @@ async def get_dataset_record( request_deps: CommonTaskHandlerDependencies = Depends(), service: DatasetsService = Depends(DatasetsService.get_instance), search: SearchRecordsService = Depends(SearchRecordsService.get_instance), - current_user: User = Security(auth.get_user, scopes=[]), + current_user: User = Security(auth.get_current_user), ) -> RecordType: found = service.find_by_name( user=current_user, @@ -85,7 +85,7 @@ async def delete_dataset_records( request_deps: CommonTaskHandlerDependencies = Depends(), service: DatasetsService = Depends(DatasetsService.get_instance), storage: RecordsStorageService = Depends(RecordsStorageService.get_instance), - current_user: User = Security(auth.get_user, scopes=[]), + current_user: User = Security(auth.get_current_user), ): found = service.find_by_name( user=current_user, diff --git a/src/argilla/server/apis/v0/handlers/records_search.py b/src/argilla/server/apis/v0/handlers/records_search.py index 9043db3f48..103780c586 100644 --- a/src/argilla/server/apis/v0/handlers/records_search.py +++ b/src/argilla/server/apis/v0/handlers/records_search.py @@ -72,7 +72,7 @@ async def scan_dataset_records( request_deps: CommonTaskHandlerDependencies = Depends(), service: DatasetsService = Depends(DatasetsService.get_instance), engine: GenericElasticEngineBackend = Depends(GenericElasticEngineBackend.get_instance), - current_user: User = Security(auth.get_user, scopes=[]), + current_user: User = Security(auth.get_current_user), ): found = service.find_by_name(user=current_user, name=name, workspace=request_deps.workspace) diff --git a/src/argilla/server/apis/v0/handlers/records_update.py b/src/argilla/server/apis/v0/handlers/records_update.py index b2034265ad..6fa67eb677 100644 --- a/src/argilla/server/apis/v0/handlers/records_update.py +++ b/src/argilla/server/apis/v0/handlers/records_update.py @@ -53,7 +53,7 @@ async def partial_update_dataset_record( service: DatasetsService = Depends(DatasetsService.get_instance), search: SearchRecordsService = Depends(SearchRecordsService.get_instance), storage: RecordsStorageService = Depends(RecordsStorageService.get_instance), - current_user: User = Security(auth.get_user, scopes=[]), + current_user: User = Security(auth.get_current_user, scopes=[]), ) -> RecordType: dataset = service.find_by_name( user=current_user, diff --git a/src/argilla/server/apis/v0/handlers/text2text.py b/src/argilla/server/apis/v0/handlers/text2text.py index 11efcca104..8354a024f4 100644 --- a/src/argilla/server/apis/v0/handlers/text2text.py +++ b/src/argilla/server/apis/v0/handlers/text2text.py @@ -73,10 +73,10 @@ async def bulk_records( common_params: CommonTaskHandlerDependencies = Depends(), service: Text2TextService = Depends(Text2TextService.get_instance), datasets: DatasetsService = Depends(DatasetsService.get_instance), - current_user: User = Security(auth.get_user, scopes=[]), + current_user: User = Security(auth.get_current_user), ) -> BulkResponse: task = task_type - workspace = current_user.check_workspace(common_params.workspace) + workspace = common_params.workspace try: dataset = datasets.find_by_name( current_user, @@ -118,7 +118,7 @@ def search_records( pagination: RequestPagination = Depends(), service: Text2TextService = Depends(Text2TextService.get_instance), datasets: DatasetsService = Depends(DatasetsService.get_instance), - current_user: User = Security(auth.get_user, scopes=[]), + current_user: User = Security(auth.get_current_user), ) -> Text2TextSearchResults: search = search or Text2TextSearchRequest() query = search.query or Text2TextQuery() diff --git a/src/argilla/server/apis/v0/handlers/text_classification.py b/src/argilla/server/apis/v0/handlers/text_classification.py index 8f4144847d..d8e76d863c 100644 --- a/src/argilla/server/apis/v0/handlers/text_classification.py +++ b/src/argilla/server/apis/v0/handlers/text_classification.py @@ -91,10 +91,10 @@ async def bulk_records( service: TextClassificationService = Depends(TextClassificationService.get_instance), datasets: DatasetsService = Depends(DatasetsService.get_instance), validator: DatasetValidator = Depends(DatasetValidator.get_instance), - current_user: User = Security(auth.get_user, scopes=[]), + current_user: User = Security(auth.get_current_user), ) -> BulkResponse: task = task_type - workspace = current_user.check_workspace(common_params.workspace) + workspace = common_params.workspace try: dataset = datasets.find_by_name( current_user, @@ -140,7 +140,7 @@ def search_records( pagination: RequestPagination = Depends(), service: TextClassificationService = Depends(TextClassificationService.get_instance), datasets: DatasetsService = Depends(DatasetsService.get_instance), - current_user: User = Security(auth.get_user, scopes=[]), + current_user: User = Security(auth.get_current_user), ) -> TextClassificationSearchResults: """ Searches data from dataset @@ -208,7 +208,7 @@ async def list_labeling_rules( common_params: CommonTaskHandlerDependencies = Depends(), datasets: DatasetsService = Depends(DatasetsService.get_instance), service: TextClassificationService = Depends(TextClassificationService.get_instance), - current_user: User = Security(auth.get_user, scopes=[]), + current_user: User = Security(auth.get_current_user), ) -> List[LabelingRule]: dataset = datasets.find_by_name( user=current_user, @@ -235,7 +235,7 @@ async def create_rule( common_params: CommonTaskHandlerDependencies = Depends(), service: TextClassificationService = Depends(TextClassificationService.get_instance), datasets: DatasetsService = Depends(DatasetsService.get_instance), - current_user: User = Security(auth.get_user, scopes=[]), + current_user: User = Security(auth.get_current_user), ) -> LabelingRule: dataset = datasets.find_by_name( user=current_user, @@ -271,7 +271,7 @@ async def compute_rule_metrics( common_params: CommonTaskHandlerDependencies = Depends(), service: TextClassificationService = Depends(TextClassificationService.get_instance), datasets: DatasetsService = Depends(DatasetsService.get_instance), - current_user: User = Security(auth.get_user, scopes=[]), + current_user: User = Security(auth.get_current_user), ) -> LabelingRuleMetricsSummary: dataset = datasets.find_by_name( user=current_user, @@ -297,7 +297,7 @@ async def compute_dataset_rules_metrics( common_params: CommonTaskHandlerDependencies = Depends(), service: TextClassificationService = Depends(TextClassificationService.get_instance), datasets: DatasetsService = Depends(DatasetsService.get_instance), - current_user: User = Security(auth.get_user, scopes=[]), + current_user: User = Security(auth.get_current_user), ) -> DatasetLabelingRulesMetricsSummary: dataset = datasets.find_by_name( user=current_user, @@ -322,7 +322,7 @@ async def delete_labeling_rule( common_params: CommonTaskHandlerDependencies = Depends(), service: TextClassificationService = Depends(TextClassificationService.get_instance), datasets: DatasetsService = Depends(DatasetsService.get_instance), - current_user: User = Security(auth.get_user, scopes=[]), + current_user: User = Security(auth.get_current_user), ) -> None: dataset = datasets.find_by_name( user=current_user, @@ -349,7 +349,7 @@ async def get_rule( common_params: CommonTaskHandlerDependencies = Depends(), service: TextClassificationService = Depends(TextClassificationService.get_instance), datasets: DatasetsService = Depends(DatasetsService.get_instance), - current_user: User = Security(auth.get_user, scopes=[]), + current_user: User = Security(auth.get_current_user), ) -> LabelingRule: dataset = datasets.find_by_name( user=current_user, @@ -377,7 +377,7 @@ async def update_rule( common_params: CommonTaskHandlerDependencies = Depends(), service: TextClassificationService = Depends(TextClassificationService.get_instance), datasets: DatasetsService = Depends(DatasetsService.get_instance), - current_user: User = Security(auth.get_user, scopes=[]), + current_user: User = Security(auth.get_current_user), ) -> LabelingRule: dataset = datasets.find_by_name( user=current_user, diff --git a/src/argilla/server/apis/v0/handlers/text_classification_dataset_settings.py b/src/argilla/server/apis/v0/handlers/text_classification_dataset_settings.py index 70c6b327d5..b3a78f2b23 100644 --- a/src/argilla/server/apis/v0/handlers/text_classification_dataset_settings.py +++ b/src/argilla/server/apis/v0/handlers/text_classification_dataset_settings.py @@ -54,16 +54,16 @@ async def get_dataset_settings( name: str = DATASET_NAME_PATH_PARAM, ws_params: CommonTaskHandlerDependencies = Depends(), datasets: DatasetsService = Depends(DatasetsService.get_instance), - user: User = Security(auth.get_user, scopes=["read:dataset.settings"]), + current_user: User = Security(auth.get_current_user), ) -> TextClassificationSettings: found_ds = datasets.find_by_name( - user=user, + user=current_user, name=name, workspace=ws_params.workspace, task=task, ) - settings = await datasets.get_settings(user=user, dataset=found_ds, class_type=__svc_settings_class__) + settings = await datasets.get_settings(user=current_user, dataset=found_ds, class_type=__svc_settings_class__) return TextClassificationSettings.parse_obj(settings) @deprecate_endpoint( @@ -82,17 +82,17 @@ async def save_settings( ws_params: CommonTaskHandlerDependencies = Depends(), datasets: DatasetsService = Depends(DatasetsService.get_instance), validator: DatasetValidator = Depends(DatasetValidator.get_instance), - user: User = Security(auth.get_user, scopes=["write:dataset.settings"]), + current_user: User = Security(auth.get_current_user), ) -> TextClassificationSettings: found_ds = datasets.find_by_name( - user=user, + user=current_user, name=name, task=task, workspace=ws_params.workspace, ) - await validator.validate_dataset_settings(user=user, dataset=found_ds, settings=request) + await validator.validate_dataset_settings(user=current_user, dataset=found_ds, settings=request) settings = await datasets.save_settings( - user=user, + user=current_user, dataset=found_ds, settings=__svc_settings_class__.parse_obj(request.dict()), ) @@ -110,7 +110,7 @@ async def delete_settings( name: str = DATASET_NAME_PATH_PARAM, ws_params: CommonTaskHandlerDependencies = Depends(), datasets: DatasetsService = Depends(DatasetsService.get_instance), - user: User = Security(auth.get_user, scopes=["delete:dataset.settings"]), + user: User = Security(auth.get_current_user), ) -> None: found_ds = datasets.find_by_name( user=user, diff --git a/src/argilla/server/apis/v0/handlers/token_classification.py b/src/argilla/server/apis/v0/handlers/token_classification.py index e4e9e82eef..c62fe77189 100644 --- a/src/argilla/server/apis/v0/handlers/token_classification.py +++ b/src/argilla/server/apis/v0/handlers/token_classification.py @@ -83,10 +83,10 @@ async def bulk_records( service: TokenClassificationService = Depends(TokenClassificationService.get_instance), datasets: DatasetsService = Depends(DatasetsService.get_instance), validator: DatasetValidator = Depends(DatasetValidator.get_instance), - current_user: User = Security(auth.get_user, scopes=[]), + current_user: User = Security(auth.get_current_user, scopes=[]), ) -> BulkResponse: task = task_type - workspace = current_user.check_workspace(common_params.workspace) + workspace = common_params.workspace try: dataset = datasets.find_by_name( current_user, @@ -139,7 +139,7 @@ def search_records( pagination: RequestPagination = Depends(), service: TokenClassificationService = Depends(TokenClassificationService.get_instance), datasets: DatasetsService = Depends(DatasetsService.get_instance), - current_user: User = Security(auth.get_user, scopes=[]), + current_user: User = Security(auth.get_current_user, scopes=[]), ) -> TokenClassificationSearchResults: search = search or TokenClassificationSearchRequest() query = search.query or TokenClassificationQuery() diff --git a/src/argilla/server/apis/v0/handlers/token_classification_dataset_settings.py b/src/argilla/server/apis/v0/handlers/token_classification_dataset_settings.py index a5ac2411f3..cc9ed8500d 100644 --- a/src/argilla/server/apis/v0/handlers/token_classification_dataset_settings.py +++ b/src/argilla/server/apis/v0/handlers/token_classification_dataset_settings.py @@ -54,16 +54,16 @@ async def get_dataset_settings( name: str = DATASET_NAME_PATH_PARAM, ws_params: CommonTaskHandlerDependencies = Depends(), datasets: DatasetsService = Depends(DatasetsService.get_instance), - user: User = Security(auth.get_user, scopes=["read:dataset.settings"]), + current_user: User = Security(auth.get_current_user), ) -> TokenClassificationSettings: found_ds = datasets.find_by_name( - user=user, + user=current_user, name=name, workspace=ws_params.workspace, task=task, ) - settings = await datasets.get_settings(user=user, dataset=found_ds, class_type=__svc_settings_class__) + settings = await datasets.get_settings(user=current_user, dataset=found_ds, class_type=__svc_settings_class__) return TokenClassificationSettings.parse_obj(settings) @deprecate_endpoint( @@ -82,17 +82,17 @@ async def save_settings( ws_params: CommonTaskHandlerDependencies = Depends(), datasets: DatasetsService = Depends(DatasetsService.get_instance), validator: DatasetValidator = Depends(DatasetValidator.get_instance), - user: User = Security(auth.get_user, scopes=["write:dataset.settings"]), + current_user: User = Security(auth.get_current_user), ) -> TokenClassificationSettings: found_ds = datasets.find_by_name( - user=user, + user=current_user, name=name, task=task, workspace=ws_params.workspace, ) - await validator.validate_dataset_settings(user=user, dataset=found_ds, settings=request) + await validator.validate_dataset_settings(user=current_user, dataset=found_ds, settings=request) settings = await datasets.save_settings( - user=user, + user=current_user, dataset=found_ds, settings=__svc_settings_class__.parse_obj(request.dict()), ) @@ -110,16 +110,16 @@ async def delete_settings( name: str = DATASET_NAME_PATH_PARAM, ws_params: CommonTaskHandlerDependencies = Depends(), datasets: DatasetsService = Depends(DatasetsService.get_instance), - user: User = Security(auth.get_user, scopes=["delete:dataset.settings"]), + current_user: User = Security(auth.get_current_user), ) -> None: found_ds = datasets.find_by_name( - user=user, + user=current_user, name=name, task=task, workspace=ws_params.workspace, ) await datasets.delete_settings( - user=user, + user=current_user, dataset=found_ds, ) diff --git a/src/argilla/server/apis/v0/handlers/users.py b/src/argilla/server/apis/v0/handlers/users.py index 713cad43c7..4540b01ec0 100644 --- a/src/argilla/server/apis/v0/handlers/users.py +++ b/src/argilla/server/apis/v0/handlers/users.py @@ -20,6 +20,7 @@ from pydantic import parse_obj_as from sqlalchemy.orm import Session +from argilla.server import models from argilla.server.commons import telemetry from argilla.server.contexts import accounts from argilla.server.database import get_db @@ -32,7 +33,7 @@ @router.get("/me", response_model=User, response_model_exclude_none=True, operation_id="whoami") -async def whoami(request: Request, current_user: User = Security(auth.get_user, scopes=[])): +async def whoami(request: Request, current_user: models.User = Security(auth.get_current_user)): """ User info endpoint @@ -51,11 +52,11 @@ async def whoami(request: Request, current_user: User = Security(auth.get_user, await telemetry.track_login(request, username=current_user.username) - return current_user + return User.from_orm(current_user) @router.get("/users", response_model=List[User], response_model_exclude_none=True) -def list_users(*, db: Session = Depends(get_db), current_user: User = Security(auth.get_user, scopes=[])): +def list_users(*, db: Session = Depends(get_db), current_user: User = Security(auth.get_current_user)): authorize(current_user, UserPolicy.list) users = accounts.list_users(db) @@ -65,7 +66,7 @@ def list_users(*, db: Session = Depends(get_db), current_user: User = Security(a @router.post("/users", response_model=User, response_model_exclude_none=True) def create_user( - *, db: Session = Depends(get_db), user_create: UserCreate, current_user: User = Security(auth.get_user, scopes=[]) + *, db: Session = Depends(get_db), user_create: UserCreate, current_user: User = Security(auth.get_current_user) ): authorize(current_user, UserPolicy.create) @@ -75,9 +76,7 @@ def create_user( @router.delete("/users/{user_id}", response_model=User, response_model_exclude_none=True) -def delete_user( - *, db: Session = Depends(get_db), user_id: UUID, current_user: User = Security(auth.get_user, scopes=[]) -): +def delete_user(*, db: Session = Depends(get_db), user_id: UUID, current_user: User = Security(auth.get_current_user)): user = accounts.get_user_by_id(db, user_id) if not user: # TODO: Forcing here user_id to be an string. diff --git a/src/argilla/server/apis/v0/handlers/workspaces.py b/src/argilla/server/apis/v0/handlers/workspaces.py index 616a452b12..54bd3f7921 100644 --- a/src/argilla/server/apis/v0/handlers/workspaces.py +++ b/src/argilla/server/apis/v0/handlers/workspaces.py @@ -35,7 +35,7 @@ @router.get("/workspaces", response_model=List[Workspace], response_model_exclude_none=True) -def list_workspaces(*, db: Session = Depends(get_db), current_user: User = Security(auth.get_user, scopes=[])): +def list_workspaces(*, db: Session = Depends(get_db), current_user: User = Security(auth.get_current_user)): authorize(current_user, WorkspacePolicy.list) workspaces = accounts.list_workspaces(db) @@ -48,7 +48,7 @@ def create_workspace( *, db: Session = Depends(get_db), workspace_create: WorkspaceCreate, - current_user: User = Security(auth.get_user, scopes=[]), + current_user: User = Security(auth.get_current_user), ): authorize(current_user, WorkspacePolicy.create) @@ -62,7 +62,7 @@ def create_workspace( # any dataset then we can delete them. # @router.delete("/workspaces/{workspace_id}", response_model=Workspace, response_model_exclude_none=True) def delete_workspace( - *, db: Session = Depends(get_db), workspace_id: UUID, current_user: User = Security(auth.get_user, scopes=[]) + *, db: Session = Depends(get_db), workspace_id: UUID, current_user: User = Security(auth.get_current_user) ): workspace = accounts.get_workspace_by_id(db, workspace_id) if not workspace: @@ -77,7 +77,7 @@ def delete_workspace( @router.get("/workspaces/{workspace_id}/users", response_model=List[User], response_model_exclude_none=True) def list_workspace_users( - *, db: Session = Depends(get_db), workspace_id: UUID, current_user: User = Security(auth.get_user, scopes=[]) + *, db: Session = Depends(get_db), workspace_id: UUID, current_user: User = Security(auth.get_current_user) ): authorize(current_user, WorkspaceUserPolicy.list) @@ -94,7 +94,7 @@ def create_workspace_user( db: Session = Depends(get_db), workspace_id: UUID, user_id: UUID, - current_user: User = Security(auth.get_user, scopes=[]), + current_user: User = Security(auth.get_current_user), ): authorize(current_user, WorkspaceUserPolicy.create) @@ -117,7 +117,7 @@ def delete_workspace_user( db: Session = Depends(get_db), workspace_id: UUID, user_id: UUID, - current_user: User = Security(auth.get_user, scopes=[]), + current_user: User = Security(auth.get_current_user), ): workspace_user = accounts.get_workspace_user_by_workspace_id_and_user_id(db, workspace_id, user_id) if not workspace_user: diff --git a/src/argilla/server/daos/datasets.py b/src/argilla/server/daos/datasets.py index 204abc8ad4..1926fdbb67 100644 --- a/src/argilla/server/daos/datasets.py +++ b/src/argilla/server/daos/datasets.py @@ -79,7 +79,9 @@ def list_datasets( task2dataset_map: Dict[str, Type[DatasetDB]] = None, name: Optional[str] = None, ) -> List[DatasetDB]: - workspaces = workspaces or [] + if not workspaces: + return [] + query = BaseDatasetsQuery( workspaces=workspaces, tasks=[task for task in task2dataset_map] if task2dataset_map else None, @@ -181,14 +183,14 @@ def copy(self, source: DatasetDB, target: DatasetDB): ) self._es.copy(id_from=source.id, id_to=target.id) - def close(self, dataset: DatasetDB): - """Close a dataset. It's mean that release all related resources, like elasticsearch index""" - self._es.close(dataset.id) - def open(self, dataset: DatasetDB): """Make available a dataset""" self._es.open(dataset.id) + def close(self, dataset: DatasetDB): + """Close a dataset. It's mean that release all related resources, like elasticsearch index""" + self._es.close(dataset.id) + def save_settings( self, dataset: DatasetDB, diff --git a/src/argilla/server/policies.py b/src/argilla/server/policies.py index 83dac986b4..625b39dbe6 100644 --- a/src/argilla/server/policies.py +++ b/src/argilla/server/policies.py @@ -66,9 +66,41 @@ def delete(cls, user: User) -> PolicyAction: class DatasetPolicy: + @classmethod + def list(cls, user: User) -> bool: + return True + + @classmethod + def get(cls, dataset: Dataset) -> PolicyAction: + return lambda actor: actor.is_admin or dataset.workspace in [ws.name for ws in actor.workspaces] + + @classmethod + def create(cls, user: User) -> bool: + return user.is_admin + + @classmethod + def update(cls, dataset: Dataset) -> PolicyAction: + is_get_allowed = cls.get(dataset) + return lambda actor: actor.is_admin or (is_get_allowed(actor) and actor.username == dataset.created_by) + @classmethod def delete(cls, dataset: Dataset) -> PolicyAction: - return lambda actor: actor.is_admin or actor.username == dataset.created_by + is_get_allowed = cls.get(dataset) + return lambda actor: actor.is_admin or (is_get_allowed(actor) and actor.username == dataset.created_by) + + @classmethod + def open(cls, dataset: Dataset) -> PolicyAction: + is_get_allowed = cls.get(dataset) + return lambda actor: actor.is_admin or (is_get_allowed(actor) and actor.username == dataset.created_by) + + @classmethod + def close(cls, dataset: Dataset) -> PolicyAction: + return lambda actor: actor.is_admin + + @classmethod + def copy(cls, dataset: Dataset) -> PolicyAction: + is_get_allowed = cls.get(dataset) + return lambda actor: actor.is_admin or is_get_allowed(actor) and cls.create(actor) def authorize(actor: User, policy_action: PolicyAction) -> None: diff --git a/src/argilla/server/security/auth_provider/local/provider.py b/src/argilla/server/security/auth_provider/local/provider.py index 84ef0669a4..e00161cabe 100644 --- a/src/argilla/server/security/auth_provider/local/provider.py +++ b/src/argilla/server/security/auth_provider/local/provider.py @@ -28,6 +28,7 @@ from argilla.server.contexts import accounts from argilla.server.database import get_db from argilla.server.errors import UnauthorizedError +from argilla.server.models import User from argilla.server.security.auth_provider.base import ( AuthProvider, api_key_header, @@ -143,7 +144,6 @@ def fetch_token_user(self, db: Session, token: str) -> Optional[User]: ) username: str = payload.get("sub") if username: - # return self.users.get_user(username=username) return accounts.get_user_by_username(db, username) except JWTError: return None @@ -169,33 +169,6 @@ def get_current_user( return user - async def get_user( - self, - security_scopes: SecurityScopes, - db: Session = Depends(get_db), - api_key: Optional[str] = Depends(api_key_header), - old_api_key: Optional[str] = Depends(old_api_key_header), - token: Optional[str] = Depends(_oauth2_scheme), - ) -> User: - """ - Fetches the user for a given token - - Parameters - ---------- - api_key: - The apikey header info if provided - old_api_key: - Same as api key but for old clients - token: - The login token. - fastapi injects this param from request - Returns - ------- - - """ - user = self.get_current_user(security_scopes, db, api_key, old_api_key, token) - return User.from_orm(user) - def create_local_auth_provider(): settings = Settings() diff --git a/src/argilla/server/services/datasets.py b/src/argilla/server/services/datasets.py index da48e21d0d..e863fd45ce 100644 --- a/src/argilla/server/services/datasets.py +++ b/src/argilla/server/services/datasets.py @@ -17,7 +17,10 @@ from typing import Any, Dict, List, Optional, Type, TypeVar, cast from fastapi import Depends +from sqlalchemy.orm import Session +from argilla.server import database +from argilla.server.contexts import accounts from argilla.server.daos.datasets import BaseDatasetSettingsDB, DatasetsDAO from argilla.server.daos.models.datasets import BaseDatasetDB from argilla.server.errors import ( @@ -26,8 +29,9 @@ ForbiddenOperationError, WrongTaskError, ) +from argilla.server.models import User, Workspace +from argilla.server.policies import DatasetPolicy, is_authorized from argilla.server.schemas.datasets import CreateDatasetRequest, Dataset -from argilla.server.security.model import User class ServiceBaseDataset(BaseDatasetDB): @@ -46,16 +50,24 @@ class DatasetsService: _INSTANCE: "DatasetsService" = None @classmethod - def get_instance(cls, dao: DatasetsDAO = Depends(DatasetsDAO.get_instance)) -> "DatasetsService": - if not cls._INSTANCE: - cls._INSTANCE = cls(dao) - return cls._INSTANCE + def get_instance( + cls, db: Session = Depends(database.get_db), dao: DatasetsDAO = Depends(DatasetsDAO.get_instance) + ) -> "DatasetsService": + return cls(db, dao) - def __init__(self, dao: DatasetsDAO): + def __init__(self, db: Session, dao: DatasetsDAO): + self._db = db self.__dao__ = dao def create_dataset(self, user: User, dataset: CreateDatasetRequest) -> BaseDatasetDB: - dataset.workspace = user.check_workspace(dataset.workspace) + if not accounts.get_workspace_by_name(self._db, workspace_name=dataset.workspace): + raise EntityNotFoundError(name=dataset.workspace, type=Workspace) + + if not is_authorized(user, DatasetPolicy.create): + raise ForbiddenOperationError( + "You don't have the necessary permissions to create datasets. Only administrators can create datasets" + ) + try: self.find_by_name(user=user, name=dataset.name, task=dataset.task, workspace=dataset.workspace) raise EntityAlreadyExistsError(name=dataset.name, type=ServiceDataset, workspace=dataset.workspace) @@ -80,25 +92,26 @@ def find_by_name( as_dataset_class: Type[ServiceDataset] = ServiceBaseDataset, task: Optional[str] = None, ) -> ServiceDataset: - workspace = user.check_workspace(workspace) - found_ds = self.__dao__.find_by_name(name=name, workspace=workspace, as_dataset_class=as_dataset_class) - if found_ds is None: + found_dataset = self.__dao__.find_by_name(name=name, workspace=workspace, as_dataset_class=as_dataset_class) + + if found_dataset is None: raise EntityNotFoundError(name=name, type=ServiceDataset) - elif task and found_ds.task != task: + + elif task and found_dataset.task != task: raise WrongTaskError(detail=f"Provided task {task} cannot be applied to dataset") - else: - return cast(ServiceDataset, found_ds) - def delete(self, user: User, dataset: ServiceDataset): - dataset = self.find_by_name(user=user, name=dataset.name, workspace=dataset.workspace, task=dataset.task) + if not is_authorized(user, DatasetPolicy.get(found_dataset)): + raise ForbiddenOperationError("You don't have the necessary permissions to get this dataset.") - if user.is_superuser() or user.username == dataset.created_by: - self.__dao__.delete_dataset(dataset) - else: + return cast(ServiceDataset, found_dataset) + + def delete(self, user: User, dataset: ServiceDataset): + if not is_authorized(user, DatasetPolicy.delete(dataset)): raise ForbiddenOperationError( "You don't have the necessary permissions to delete this dataset. " "Only dataset creators or administrators can delete datasets" ) + self.__dao__.delete_dataset(dataset) def update( self, @@ -109,9 +122,16 @@ def update( ) -> Dataset: found = self.find_by_name(user=user, name=dataset.name, task=dataset.task, workspace=dataset.workspace) + if not is_authorized(user, DatasetPolicy.update(found)): + raise ForbiddenOperationError( + "You don't have the necessary permissions to update this dataset. " + "Only dataset creators or administrators can update datasets" + ) + dataset.tags = {**found.tags, **(tags or {})} dataset.metadata = {**found.metadata, **(metadata or {})} updated = found.copy(update={**dataset.dict(by_alias=True), "last_updated": datetime.utcnow()}) + return self.__dao__.update_dataset(updated) def list( @@ -120,16 +140,38 @@ def list( workspaces: Optional[List[str]], task2dataset_map: Dict[str, Type[ServiceDataset]] = None, ) -> List[ServiceDataset]: - workspaces = user.check_workspaces(workspaces) - return self.__dao__.list_datasets(workspaces=workspaces, task2dataset_map=task2dataset_map) + if not is_authorized(user, DatasetPolicy.list): + raise ForbiddenOperationError("You don't have the necessary permissions to list datasets.") + + accessible_workspace_names = [ + ws.name for ws in (accounts.list_workspaces(self._db) if user.is_admin else user.workspaces) + ] + + if workspaces: + for ws in workspaces: + if ws not in accessible_workspace_names: + raise EntityNotFoundError(name=ws, type=Workspace) + workspace_names = workspaces + else: # no workspaces + workspace_names = accessible_workspace_names + + return self.__dao__.list_datasets(workspaces=workspace_names, task2dataset_map=task2dataset_map) def close(self, user: User, dataset: ServiceDataset): - found = self.find_by_name(user=user, name=dataset.name, workspace=dataset.workspace) - self.__dao__.close(found) + if not is_authorized(user, DatasetPolicy.close(dataset)): + raise ForbiddenOperationError( + "You don't have the necessary permissions to close this dataset. " + "Only dataset creators or administrators can close datasets" + ) + self.__dao__.close(dataset) def open(self, user: User, dataset: ServiceDataset): - found = self.find_by_name(user=user, name=dataset.name, workspace=dataset.workspace) - self.__dao__.open(found) + if not is_authorized(user, DatasetPolicy.open(dataset)): + raise ForbiddenOperationError( + "You don't have the necessary permissions to open this dataset. " + "Only dataset creators or administrators can open datasets" + ) + self.__dao__.open(dataset) def copy_dataset( self, @@ -140,48 +182,40 @@ def copy_dataset( copy_tags: Dict[str, Any] = None, copy_metadata: Dict[str, Any] = None, ) -> ServiceDataset: - dataset_workspace = copy_workspace or dataset.workspace - dataset_workspace = user.check_workspace(dataset_workspace) + target_workspace_name = copy_workspace or dataset.workspace - self._validate_create_dataset( - name=copy_name, - workspace=dataset_workspace, - user=user, - ) + target_workspace = accounts.get_workspace_by_name(self._db, target_workspace_name) + if not target_workspace: + raise EntityNotFoundError(name=target_workspace_name, type=Workspace) - copy_dataset = dataset.copy() - copy_dataset.name = copy_name - copy_dataset.workspace = dataset_workspace + if self.__dao__.find_by_name_and_workspace(name=copy_name, workspace=target_workspace_name): + raise EntityAlreadyExistsError(name=copy_name, workspace=target_workspace_name, type=Dataset) + + if not is_authorized(user, DatasetPolicy.copy(dataset)): + raise ForbiddenOperationError( + "You don't have the necessary permissions to copy this dataset. " + "Only dataset creators or administrators can copy datasets" + ) + + dataset_copy = dataset.copy() + dataset_copy.name = copy_name + dataset_copy.workspace = target_workspace_name date_now = datetime.utcnow() - copy_dataset.created_at = date_now - copy_dataset.last_updated = date_now - copy_dataset.tags = {**copy_dataset.tags, **(copy_tags or {})} - copy_dataset.metadata = { - **copy_dataset.metadata, + dataset_copy.created_at = date_now + dataset_copy.last_updated = date_now + dataset_copy.tags = {**dataset_copy.tags, **(copy_tags or {})} + dataset_copy.metadata = { + **dataset_copy.metadata, **(copy_metadata or {}), "source_workspace": dataset.workspace, "copied_from": dataset.name, } - self.__dao__.copy( - source=dataset, - target=copy_dataset, - ) + self.__dao__.copy(source=dataset, target=dataset_copy) - return copy_dataset - - def _validate_create_dataset(self, name: str, workspace: str, user: User): - try: - found = self.find_by_name(user=user, name=name, workspace=workspace) - raise EntityAlreadyExistsError( - name=found.name, - type=found.__class__, - workspace=workspace, - ) - except (EntityNotFoundError, ForbiddenOperationError): - pass + return dataset_copy async def get_settings( self, diff --git a/src/argilla/server/services/storage/service.py b/src/argilla/server/services/storage/service.py index 9fecae36d6..07c0dbaf97 100644 --- a/src/argilla/server/services/storage/service.py +++ b/src/argilla/server/services/storage/service.py @@ -23,7 +23,7 @@ from argilla.server.daos.backend.base import WrongLogDataError from argilla.server.daos.records import DatasetRecordsDAO from argilla.server.errors import BulkDataError, ForbiddenOperationError -from argilla.server.security.model import User +from argilla.server.models import User from argilla.server.services.datasets import ServiceDataset from argilla.server.services.search.model import ServiceBaseRecordsQuery from argilla.server.services.tasks.commons import ServiceRecord @@ -92,7 +92,7 @@ async def delete_records( status=TaskStatus.discarded, ) else: - if not user.is_superuser() and user.username != dataset.created_by: + if not user.is_admin and user.username != dataset.created_by: raise ForbiddenOperationError( "You don't have the necessary permissions to delete records on this dataset. " "Only dataset creators or administrators can delete datasets" diff --git a/tests/server/datasets/test_dao.py b/tests/server/datasets/test_dao.py index 6f2a734009..4565f17dd8 100644 --- a/tests/server/datasets/test_dao.py +++ b/tests/server/datasets/test_dao.py @@ -37,15 +37,16 @@ def test_retrieve_ownered_dataset_for_no_owner_user(): def test_list_datasets_by_task(): dataset = "test_list_datasets_by_task" + workspace_name = "other" - all_datasets = dao.list_datasets() + all_datasets = dao.list_datasets(workspaces=[workspace_name]) for ds in all_datasets: dao.delete_dataset(ds) created_text = dao.create_dataset( BaseDatasetDB( name=dataset + "_text", - workspace="other", + workspace=workspace_name, task=TaskType.text_classification, ), ) @@ -53,19 +54,20 @@ def test_list_datasets_by_task(): created_token = dao.create_dataset( BaseDatasetDB( name=dataset + "_token", - workspace="other", + workspace=workspace_name, task=TaskType.token_classification, ), ) - datasets = dao.list_datasets( - task2dataset_map={created_text.task: BaseDatasetDB}, - ) + assert len(dao.list_datasets()) == 0 + assert len(dao.list_datasets(workspaces=[workspace_name])) == 2 + + datasets = dao.list_datasets(workspaces=[workspace_name], task2dataset_map={created_text.task: BaseDatasetDB}) assert len(datasets) == 1 assert datasets[0].name == created_text.name - datasets = dao.list_datasets(task2dataset_map={created_token.task: BaseDatasetDB}) + datasets = dao.list_datasets(workspaces=[workspace_name], task2dataset_map={created_token.task: BaseDatasetDB}) assert len(datasets) == 1 assert datasets[0].name == created_token.name diff --git a/tests/server/security/test_provider.py b/tests/server/security/test_provider.py index 69b944861a..8832520b8a 100644 --- a/tests/server/security/test_provider.py +++ b/tests/server/security/test_provider.py @@ -28,19 +28,15 @@ @pytest.mark.asyncio async def test_get_user_via_token(db: Session, argilla_user): access_token = localAuth._create_access_token(username="argilla") - user = await localAuth.get_user( - security_scopes=security_Scopes, - db=db, - token=access_token, - api_key=None, - old_api_key=None, + user = localAuth.get_current_user( + security_scopes=security_Scopes, db=db, token=access_token, api_key=None, old_api_key=None ) assert user.username == "argilla" @pytest.mark.asyncio async def test_get_user_via_api_key(db: Session, argilla_user): - user = await localAuth.get_user(security_scopes=security_Scopes, db=db, api_key=DEFAULT_API_KEY, token=None) + user = localAuth.get_current_user(security_scopes=security_Scopes, db=db, api_key=DEFAULT_API_KEY, token=None) assert user.username == "argilla"