From a0ae0f9236e84823fe617d0d801daef157cefa6d Mon Sep 17 00:00:00 2001 From: frascuchon Date: Thu, 9 Mar 2023 09:25:04 +0100 Subject: [PATCH 01/13] wip --- .../server/apis/v0/handlers/datasets.py | 39 ++++++++++++------- src/argilla/server/policies.py | 4 ++ 2 files changed, 29 insertions(+), 14 deletions(-) diff --git a/src/argilla/server/apis/v0/handlers/datasets.py b/src/argilla/server/apis/v0/handlers/datasets.py index 33a2ca1331..e323fb13f1 100644 --- a/src/argilla/server/apis/v0/handlers/datasets.py +++ b/src/argilla/server/apis/v0/handlers/datasets.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +from datetime import datetime from typing import List from fastapi import APIRouter, Body, Depends, Security @@ -110,20 +110,31 @@ def get_dataset( def update_dataset( name: str, request: UpdateDatasetRequest, - ds_params: CommonTaskHandlerDependencies = Depends(), - service: DatasetsService = Depends(DatasetsService.get_instance), - current_user: User = Security(auth.get_user, scopes=[]), + workspace_request_params: CommonTaskHandlerDependencies = Depends(), + db: Session = Depends(database.get_db), + datasets: DatasetsDAO = Depends(DatasetsDAO.get_instance), + user: User = Security(auth.get_current_user, scopes=[]), ) -> Dataset: - found_ds = service.find_by_name(user=current_user, name=name, workspace=ds_params.workspace) + workspace_name = workspace_request_params.workspace - dataset = service.update( - user=current_user, - dataset=found_ds, - tags=request.tags, - metadata=request.metadata, - ) + workspace = accounts.get_workspace_by_name(db, workspace_name=workspace_name) + if not workspace: + raise EntityNotFoundError(name=workspace_name, type=Workspace) - return Dataset.from_orm(dataset) + dataset = datasets.find_by_name_and_workspace(name=name, workspace=workspace.name) + if not dataset: + raise EntityNotFoundError(name=dataset, type=Dataset) + + if not is_authorized(user, DatasetPolicy.update(dataset)): + raise ForbiddenOperationError( + "You don't have the necessary permissions to update this dataset. " + "Only dataset creators or administrators can delete datasets" + ) + + dataset_update = dataset.copy(update={**request.dict(), "last_updated": datetime.utcnow()}) + updated_dataset = datasets.update_dataset(dataset_update) + + return Dataset.from_orm(updated_dataset) @router.delete( @@ -132,12 +143,12 @@ def update_dataset( ) def delete_dataset( name: str, - request_params: CommonTaskHandlerDependencies = Depends(), + workspace_request_params: CommonTaskHandlerDependencies = Depends(), db: Session = Depends(database.get_db), datasets: DatasetsDAO = Depends(DatasetsDAO.get_instance), user: User = Security(auth.get_current_user, scopes=[]), ): - workspace_name = request_params.workspace + workspace_name = workspace_request_params.workspace workspace = accounts.get_workspace_by_name(db, workspace_name=workspace_name) if not workspace: diff --git a/src/argilla/server/policies.py b/src/argilla/server/policies.py index 83dac986b4..9913743026 100644 --- a/src/argilla/server/policies.py +++ b/src/argilla/server/policies.py @@ -70,6 +70,10 @@ class DatasetPolicy: def delete(cls, dataset: Dataset) -> PolicyAction: return lambda actor: actor.is_admin or actor.username == dataset.created_by + @classmethod + def update(cls, dataset: Dataset) -> PolicyAction: + return lambda actor: actor.is_admin or actor.username == dataset.created_by + def authorize(actor: User, policy_action: PolicyAction) -> None: if not is_authorized(actor, policy_action): From c074d3f0725f5448b0c2aa9f707b5b0de1031555 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Francisco=20Calvo?= Date: Thu, 9 Mar 2023 17:50:28 +0100 Subject: [PATCH 02/13] Refactor the rest of dataset endpoints --- .../server/apis/v0/handlers/datasets.py | 243 ++++++++++-------- src/argilla/server/daos/datasets.py | 8 +- src/argilla/server/policies.py | 26 +- 3 files changed, 165 insertions(+), 112 deletions(-) diff --git a/src/argilla/server/apis/v0/handlers/datasets.py b/src/argilla/server/apis/v0/handlers/datasets.py index e323fb13f1..2fd3d1a3ac 100644 --- a/src/argilla/server/apis/v0/handlers/datasets.py +++ b/src/argilla/server/apis/v0/handlers/datasets.py @@ -24,7 +24,11 @@ from argilla.server.apis.v0.models.commons.params import CommonTaskHandlerDependencies from argilla.server.contexts import accounts from argilla.server.daos.datasets import DatasetsDAO -from argilla.server.errors import EntityNotFoundError, ForbiddenOperationError +from argilla.server.errors import ( + EntityAlreadyExistsError, + EntityNotFoundError, + ForbiddenOperationError, +) from argilla.server.policies import DatasetPolicy, is_authorized from argilla.server.schemas.datasets import ( CopyDatasetRequest, @@ -34,7 +38,6 @@ ) from argilla.server.security import auth from argilla.server.security.model import User, Workspace -from argilla.server.services.datasets import DatasetsService router = APIRouter(tags=["datasets"], prefix="/datasets") @@ -48,16 +51,33 @@ operation_id="list_datasets", ) async def list_datasets( - request_deps: CommonTaskHandlerDependencies = Depends(), - service: DatasetsService = Depends(DatasetsService.get_instance), - current_user: User = Security(auth.get_user, scopes=[]), + workspace_params: CommonTaskHandlerDependencies = Depends(), + datasets: DatasetsDAO = Depends(DatasetsDAO.get_instance), + current_user: User = Security(auth.get_current_user, scopes=[]), ) -> List[Dataset]: - datasets = service.list( - user=current_user, - workspaces=[request_deps.workspace] if request_deps.workspace is not None else None, - ) + if not is_authorized(current_user, DatasetPolicy.list): + raise ForbiddenOperationError("You don't have the necessary permissions to list datasets.") + + workspaces = [workspace_params.workspace] if workspace_params.workspace is not None else None - return parse_obj_as(List[Dataset], datasets) + return parse_obj_as(List[Dataset], datasets.list_datasets(workspaces=workspaces)) + + +@router.get("/{dataset_name}", response_model=Dataset, response_model_exclude_none=True, operation_id="get_dataset") +def get_dataset( + dataset_name: str, + workspace_params: CommonTaskHandlerDependencies = Depends(), + datasets: DatasetsDAO = Depends(DatasetsDAO.get_instance), + current_user: User = Security(auth.get_current_user, scopes=[]), +) -> Dataset: + dataset = datasets.find_by_name_and_workspace(dataset_name, workspace_params.workspace) + if not dataset: + raise EntityNotFoundError(name=dataset, type=Dataset) + + if not is_authorized(current_user, DatasetPolicy.get): + raise ForbiddenOperationError("You don't have the necessary permissions to get datasets.") + + return Dataset.from_orm(dataset) @router.post( @@ -69,63 +89,42 @@ async def list_datasets( description="Create a new dataset", ) async def create_dataset( + db: Session = Depends(database.get_db), request: CreateDatasetRequest = Body(..., description="The request dataset info"), - ws_params: CommonTaskHandlerDependencies = Depends(), - datasets: DatasetsService = Depends(DatasetsService.get_instance), - user: User = Security(auth.get_user, scopes=["create:datasets"]), + workspace_params: CommonTaskHandlerDependencies = Depends(), + datasets: DatasetsDAO = Depends(DatasetsDAO.get_instance), + current_user: User = Security(auth.get_current_user, scopes=[]), ) -> Dataset: - request.workspace = request.workspace or ws_params.workspace - dataset = datasets.create_dataset(user=user, dataset=request) - - return Dataset.from_orm(dataset) + request.workspace = request.workspace or workspace_params.workspace + if not accounts.get_workspace_by_name(db, workspace_name=request.workspace): + raise EntityNotFoundError(name=request.workspace, type=Workspace) -@router.get( - "/{name}", - response_model=Dataset, - response_model_exclude_none=True, - operation_id="get_dataset", -) -def get_dataset( - name: str, - ds_params: CommonTaskHandlerDependencies = Depends(), - service: DatasetsService = Depends(DatasetsService.get_instance), - current_user: User = Security(auth.get_user, scopes=[]), -) -> Dataset: - return Dataset.from_orm( - service.find_by_name( - user=current_user, - name=name, - workspace=ds_params.workspace, + if not is_authorized(current_user, DatasetPolicy.create): + raise ForbiddenOperationError( + "You don't have the necessary permissions to create datasets. Only administrators can create datasets" ) - ) + + dataset = datasets.create_dataset(user=current_user, dataset=request) + + return Dataset.from_orm(dataset) @router.patch( - "/{name}", - operation_id="update_dataset", - response_model=Dataset, - response_model_exclude_none=True, + "/{dataset_name}", operation_id="update_dataset", response_model=Dataset, response_model_exclude_none=True ) def update_dataset( - name: str, + dataset_name: str, request: UpdateDatasetRequest, - workspace_request_params: CommonTaskHandlerDependencies = Depends(), - db: Session = Depends(database.get_db), + workspace_params: CommonTaskHandlerDependencies = Depends(), datasets: DatasetsDAO = Depends(DatasetsDAO.get_instance), - user: User = Security(auth.get_current_user, scopes=[]), + current_user: User = Security(auth.get_current_user, scopes=[]), ) -> Dataset: - workspace_name = workspace_request_params.workspace - - workspace = accounts.get_workspace_by_name(db, workspace_name=workspace_name) - if not workspace: - raise EntityNotFoundError(name=workspace_name, type=Workspace) - - dataset = datasets.find_by_name_and_workspace(name=name, workspace=workspace.name) + dataset = datasets.find_by_name_and_workspace(name=dataset_name, workspace=workspace_params.workspace) if not dataset: raise EntityNotFoundError(name=dataset, type=Dataset) - if not is_authorized(user, DatasetPolicy.update(dataset)): + if not is_authorized(current_user, DatasetPolicy.update(dataset)): raise ForbiddenOperationError( "You don't have the necessary permissions to update this dataset. " "Only dataset creators or administrators can delete datasets" @@ -137,28 +136,20 @@ def update_dataset( return Dataset.from_orm(updated_dataset) -@router.delete( - "/{name}", - operation_id="delete_dataset", -) +@router.delete("/{dataset_name}", operation_id="delete_dataset") def delete_dataset( - name: str, - workspace_request_params: CommonTaskHandlerDependencies = Depends(), - db: Session = Depends(database.get_db), + dataset_name: str, + workspace_params: CommonTaskHandlerDependencies = Depends(), datasets: DatasetsDAO = Depends(DatasetsDAO.get_instance), - user: User = Security(auth.get_current_user, scopes=[]), + current_user: User = Security(auth.get_current_user, scopes=[]), ): - workspace_name = workspace_request_params.workspace - - workspace = accounts.get_workspace_by_name(db, workspace_name=workspace_name) - if not workspace: - raise EntityNotFoundError(name=workspace_name, type=Workspace) - - dataset = datasets.find_by_name_and_workspace(name=name, workspace=workspace.name) + dataset = datasets.find_by_name_and_workspace(name=dataset_name, workspace=workspace_params.workspace) if not dataset: + # We are not raising an EntityNotFoundError because this endpoint + # was not doing it originally so we want to continue doing the same. return - if not is_authorized(user, DatasetPolicy.delete(dataset)): + if not is_authorized(current_user, DatasetPolicy.delete(dataset)): raise ForbiddenOperationError( "You don't have the necessary permissions to delete this dataset. " "Only dataset creators or administrators can delete datasets" @@ -167,55 +158,93 @@ def delete_dataset( datasets.delete_dataset(dataset) -@router.put( - "/{name}:close", - operation_id="close_dataset", -) -def close_dataset( - name: str, - ds_params: CommonTaskHandlerDependencies = Depends(), - service: DatasetsService = Depends(DatasetsService.get_instance), - current_user: User = Security(auth.get_user, scopes=[]), +@router.put("/{dataset_name}:open", operation_id="open_dataset") +def open_dataset( + dataset_name: str, + workspace_params: CommonTaskHandlerDependencies = Depends(), + datasets: DatasetsDAO = Depends(DatasetsDAO.get_instance), + current_user: User = Security(auth.get_current_user, scopes=[]), ): - found_ds = service.find_by_name(user=current_user, name=name, workspace=ds_params.workspace) - service.close(user=current_user, dataset=found_ds) + dataset = datasets.find_by_name_and_workspace(dataset_name, workspace_params.workspace) + if not dataset: + raise EntityNotFoundError(name=dataset_name, type=Dataset) + if not is_authorized(current_user, DatasetPolicy.open(dataset)): + raise ForbiddenOperationError( + "You don't have the necessary permissions to open this dataset. " + "Only dataset creators or administrators can open datasets" + ) -@router.put( - "/{name}:open", - operation_id="open_dataset", -) -def open_dataset( - name: str, - ds_params: CommonTaskHandlerDependencies = Depends(), - service: DatasetsService = Depends(DatasetsService.get_instance), - current_user: User = Security(auth.get_user, scopes=[]), + datasets.open(dataset) + + +@router.put("/{dataset_name}:close", operation_id="close_dataset") +def close_dataset( + dataset_name: str, + workspace_params: CommonTaskHandlerDependencies = Depends(), + datasets: DatasetsDAO = Depends(DatasetsDAO.get_instance), + current_user: User = Security(auth.get_current_user, scopes=[]), ): - found_ds = service.find_by_name(user=current_user, name=name, workspace=ds_params.workspace) - service.open(user=current_user, dataset=found_ds) + dataset = datasets.find_by_name_and_workspace(dataset_name, workspace_params.workspace) + if not dataset: + raise EntityNotFoundError(name=dataset_name, type=Dataset) + + if not is_authorized(current_user, DatasetPolicy.close(dataset)): + raise ForbiddenOperationError( + "You don't have the necessary permissions to close this dataset. " + "Only dataset creators or administrators can close datasets" + ) + + datasets.close(dataset) @router.put( - "/{name}:copy", - operation_id="copy_dataset", - response_model=Dataset, - response_model_exclude_none=True, + "/{dataset_name}:copy", operation_id="copy_dataset", response_model=Dataset, response_model_exclude_none=True ) def copy_dataset( - name: str, + *, + db: Session = Depends(database.get_db), + dataset_name: str, + workspace_params: CommonTaskHandlerDependencies = Depends(), copy_request: CopyDatasetRequest, - ds_params: CommonTaskHandlerDependencies = Depends(), - service: DatasetsService = Depends(DatasetsService.get_instance), - current_user: User = Security(auth.get_user, scopes=[]), + datasets: DatasetsDAO = Depends(DatasetsDAO.get_instance), + current_user: User = Security(auth.get_current_user, scopes=[]), ) -> Dataset: - found = service.find_by_name(user=current_user, name=name, workspace=ds_params.workspace) - dataset = service.copy_dataset( - user=current_user, - dataset=found, - copy_name=copy_request.name, - copy_workspace=copy_request.target_workspace, - copy_tags=copy_request.tags, - copy_metadata=copy_request.metadata, - ) + source_dataset_name = dataset_name + source_workspace_name = workspace_params.workspace - return Dataset.from_orm(dataset) + target_dataset_name = copy_request.name + target_workspace_name = copy_request.target_workspace or source_workspace_name + + source_dataset = datasets.find_by_name_and_workspace(source_dataset_name, source_workspace_name) + if not source_dataset: + raise EntityNotFoundError(name=source_dataset_name, type=Dataset) + + target_workspace = accounts.get_workspace_by_name(db, target_workspace_name) + if not target_workspace: + raise EntityNotFoundError(name=target_workspace_name, type=Workspace) + + if datasets.find_by_name_and_workspace(target_dataset_name, target_workspace_name): + raise EntityAlreadyExistsError(name=target_dataset_name, workspace=target_workspace_name, type=Dataset) + + if not is_authorized(current_user, DatasetPolicy.copy(source_dataset)): + raise ForbiddenOperationError( + "You don't have the necessary permissions to copy this dataset. " + "Only dataset creators or administrators can copy datasets" + ) + + target_dataset = source_dataset.copy() + target_dataset.name = target_dataset_name + target_dataset.workspace = target_workspace_name + target_dataset.created_at = target_dataset.last_updated = datetime.utcnow() + target_dataset.tags = {**target_dataset.tags, **(copy_request.tags or {})} + target_dataset.metadata = { + **target_dataset.metadata, + **(copy_request.metadata or {}), + "source_workspace": source_workspace_name, + "copied_from": source_dataset_name, + } + + datasets.copy(source_dataset, target_dataset) + + return Dataset.from_orm(target_dataset) diff --git a/src/argilla/server/daos/datasets.py b/src/argilla/server/daos/datasets.py index 204abc8ad4..33a2f4b494 100644 --- a/src/argilla/server/daos/datasets.py +++ b/src/argilla/server/daos/datasets.py @@ -181,14 +181,14 @@ def copy(self, source: DatasetDB, target: DatasetDB): ) self._es.copy(id_from=source.id, id_to=target.id) - def close(self, dataset: DatasetDB): - """Close a dataset. It's mean that release all related resources, like elasticsearch index""" - self._es.close(dataset.id) - def open(self, dataset: DatasetDB): """Make available a dataset""" self._es.open(dataset.id) + def close(self, dataset: DatasetDB): + """Close a dataset. It's mean that release all related resources, like elasticsearch index""" + self._es.close(dataset.id) + def save_settings( self, dataset: DatasetDB, diff --git a/src/argilla/server/policies.py b/src/argilla/server/policies.py index 9913743026..85889cff17 100644 --- a/src/argilla/server/policies.py +++ b/src/argilla/server/policies.py @@ -66,14 +66,38 @@ def delete(cls, user: User) -> PolicyAction: class DatasetPolicy: + @classmethod + def list(cls, user: User) -> bool: + return True + + @classmethod + def get(cls, user: User) -> bool: + return True + + @classmethod + def create(cls, user: User) -> bool: + return user.is_admin + + @classmethod + def update(cls, dataset: Dataset) -> PolicyAction: + return lambda actor: actor.is_admin or actor.username == dataset.created_by + @classmethod def delete(cls, dataset: Dataset) -> PolicyAction: return lambda actor: actor.is_admin or actor.username == dataset.created_by @classmethod - def update(cls, dataset: Dataset) -> PolicyAction: + def open(cls, dataset: Dataset) -> PolicyAction: return lambda actor: actor.is_admin or actor.username == dataset.created_by + @classmethod + def close(cls, dataset: Dataset) -> PolicyAction: + return lambda actor: actor.is_admin or actor.username == dataset.created_by + + @classmethod + def copy(cls, dataset: Dataset) -> PolicyAction: + return lambda actor: cls.get(actor) and cls.create(actor) + def authorize(actor: User, policy_action: PolicyAction) -> None: if not is_authorized(actor, policy_action): From d58f71e6cf10def7adc0f50b9dcd76dc3f819228 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Francisco=20Calvo?= Date: Fri, 10 Mar 2023 11:18:24 +0100 Subject: [PATCH 03/13] Add some more refactors --- .../server/apis/v0/handlers/datasets.py | 36 +++++++++++-------- src/argilla/server/daos/datasets.py | 4 ++- 2 files changed, 24 insertions(+), 16 deletions(-) diff --git a/src/argilla/server/apis/v0/handlers/datasets.py b/src/argilla/server/apis/v0/handlers/datasets.py index 2fd3d1a3ac..cd43ee4f04 100644 --- a/src/argilla/server/apis/v0/handlers/datasets.py +++ b/src/argilla/server/apis/v0/handlers/datasets.py @@ -51,16 +51,22 @@ operation_id="list_datasets", ) async def list_datasets( - workspace_params: CommonTaskHandlerDependencies = Depends(), + db: Session = Depends(database.get_db), datasets: DatasetsDAO = Depends(DatasetsDAO.get_instance), current_user: User = Security(auth.get_current_user, scopes=[]), ) -> List[Dataset]: if not is_authorized(current_user, DatasetPolicy.list): raise ForbiddenOperationError("You don't have the necessary permissions to list datasets.") - workspaces = [workspace_params.workspace] if workspace_params.workspace is not None else None + workspaces = [] + if current_user.is_admin: + workspaces = accounts.list_workspaces(db) + else: + workspaces = current_user.workspaces + + workspace_names = [workspace.name for workspace in workspaces] - return parse_obj_as(List[Dataset], datasets.list_datasets(workspaces=workspaces)) + return parse_obj_as(List[Dataset], datasets.list_datasets(workspaces=workspace_names)) @router.get("/{dataset_name}", response_model=Dataset, response_model_exclude_none=True, operation_id="get_dataset") @@ -90,22 +96,22 @@ def get_dataset( ) async def create_dataset( db: Session = Depends(database.get_db), - request: CreateDatasetRequest = Body(..., description="The request dataset info"), + create_dataset: CreateDatasetRequest = Body(..., description="The request dataset info"), workspace_params: CommonTaskHandlerDependencies = Depends(), datasets: DatasetsDAO = Depends(DatasetsDAO.get_instance), current_user: User = Security(auth.get_current_user, scopes=[]), ) -> Dataset: - request.workspace = request.workspace or workspace_params.workspace + create_dataset.workspace = create_dataset.workspace or workspace_params.workspace - if not accounts.get_workspace_by_name(db, workspace_name=request.workspace): - raise EntityNotFoundError(name=request.workspace, type=Workspace) + if not accounts.get_workspace_by_name(db, workspace_name=create_dataset.workspace): + raise EntityNotFoundError(name=create_dataset.workspace, type=Workspace) if not is_authorized(current_user, DatasetPolicy.create): raise ForbiddenOperationError( "You don't have the necessary permissions to create datasets. Only administrators can create datasets" ) - dataset = datasets.create_dataset(user=current_user, dataset=request) + dataset = datasets.create_dataset(user=current_user, dataset=create_dataset) return Dataset.from_orm(dataset) @@ -115,7 +121,7 @@ async def create_dataset( ) def update_dataset( dataset_name: str, - request: UpdateDatasetRequest, + update_dataset: UpdateDatasetRequest, workspace_params: CommonTaskHandlerDependencies = Depends(), datasets: DatasetsDAO = Depends(DatasetsDAO.get_instance), current_user: User = Security(auth.get_current_user, scopes=[]), @@ -130,7 +136,7 @@ def update_dataset( "Only dataset creators or administrators can delete datasets" ) - dataset_update = dataset.copy(update={**request.dict(), "last_updated": datetime.utcnow()}) + dataset_update = dataset.copy(update={**update_dataset.dict(), "last_updated": datetime.utcnow()}) updated_dataset = datasets.update_dataset(dataset_update) return Dataset.from_orm(updated_dataset) @@ -206,15 +212,15 @@ def copy_dataset( db: Session = Depends(database.get_db), dataset_name: str, workspace_params: CommonTaskHandlerDependencies = Depends(), - copy_request: CopyDatasetRequest, + copy_dataset: CopyDatasetRequest, datasets: DatasetsDAO = Depends(DatasetsDAO.get_instance), current_user: User = Security(auth.get_current_user, scopes=[]), ) -> Dataset: source_dataset_name = dataset_name source_workspace_name = workspace_params.workspace - target_dataset_name = copy_request.name - target_workspace_name = copy_request.target_workspace or source_workspace_name + target_dataset_name = copy_dataset.name + target_workspace_name = copy_dataset.target_workspace or source_workspace_name source_dataset = datasets.find_by_name_and_workspace(source_dataset_name, source_workspace_name) if not source_dataset: @@ -237,10 +243,10 @@ def copy_dataset( target_dataset.name = target_dataset_name target_dataset.workspace = target_workspace_name target_dataset.created_at = target_dataset.last_updated = datetime.utcnow() - target_dataset.tags = {**target_dataset.tags, **(copy_request.tags or {})} + target_dataset.tags = {**target_dataset.tags, **(copy_dataset.tags or {})} target_dataset.metadata = { **target_dataset.metadata, - **(copy_request.metadata or {}), + **(copy_dataset.metadata or {}), "source_workspace": source_workspace_name, "copied_from": source_dataset_name, } diff --git a/src/argilla/server/daos/datasets.py b/src/argilla/server/daos/datasets.py index 33a2f4b494..1926fdbb67 100644 --- a/src/argilla/server/daos/datasets.py +++ b/src/argilla/server/daos/datasets.py @@ -79,7 +79,9 @@ def list_datasets( task2dataset_map: Dict[str, Type[DatasetDB]] = None, name: Optional[str] = None, ) -> List[DatasetDB]: - workspaces = workspaces or [] + if not workspaces: + return [] + query = BaseDatasetsQuery( workspaces=workspaces, tasks=[task for task in task2dataset_map] if task2dataset_map else None, From d6bc191b086c86575c8cd9bd02958e84bc3a6cd7 Mon Sep 17 00:00:00 2001 From: frascuchon Date: Fri, 10 Mar 2023 16:44:40 +0100 Subject: [PATCH 04/13] Revert datasets handlers --- .../server/apis/v0/handlers/datasets.py | 285 +++++++----------- 1 file changed, 114 insertions(+), 171 deletions(-) diff --git a/src/argilla/server/apis/v0/handlers/datasets.py b/src/argilla/server/apis/v0/handlers/datasets.py index cd43ee4f04..9f5be790ba 100644 --- a/src/argilla/server/apis/v0/handlers/datasets.py +++ b/src/argilla/server/apis/v0/handlers/datasets.py @@ -12,24 +12,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from datetime import datetime + from typing import List from fastapi import APIRouter, Body, Depends, Security from pydantic import parse_obj_as -from sqlalchemy.orm import Session -from argilla.server import database from argilla.server.apis.v0.helpers import deprecate_endpoint from argilla.server.apis.v0.models.commons.params import CommonTaskHandlerDependencies -from argilla.server.contexts import accounts -from argilla.server.daos.datasets import DatasetsDAO -from argilla.server.errors import ( - EntityAlreadyExistsError, - EntityNotFoundError, - ForbiddenOperationError, -) -from argilla.server.policies import DatasetPolicy, is_authorized +from argilla.server.errors import EntityNotFoundError +from argilla.server.models import User from argilla.server.schemas.datasets import ( CopyDatasetRequest, CreateDatasetRequest, @@ -37,7 +29,7 @@ UpdateDatasetRequest, ) from argilla.server.security import auth -from argilla.server.security.model import User, Workspace +from argilla.server.services.datasets import DatasetsService router = APIRouter(tags=["datasets"], prefix="/datasets") @@ -51,39 +43,16 @@ operation_id="list_datasets", ) async def list_datasets( - db: Session = Depends(database.get_db), - datasets: DatasetsDAO = Depends(DatasetsDAO.get_instance), - current_user: User = Security(auth.get_current_user, scopes=[]), + request_deps: CommonTaskHandlerDependencies = Depends(), + service: DatasetsService = Depends(DatasetsService.get_instance), + current_user: User = Security(auth.get_current_user), ) -> List[Dataset]: - if not is_authorized(current_user, DatasetPolicy.list): - raise ForbiddenOperationError("You don't have the necessary permissions to list datasets.") - - workspaces = [] - if current_user.is_admin: - workspaces = accounts.list_workspaces(db) - else: - workspaces = current_user.workspaces - - workspace_names = [workspace.name for workspace in workspaces] - - return parse_obj_as(List[Dataset], datasets.list_datasets(workspaces=workspace_names)) - - -@router.get("/{dataset_name}", response_model=Dataset, response_model_exclude_none=True, operation_id="get_dataset") -def get_dataset( - dataset_name: str, - workspace_params: CommonTaskHandlerDependencies = Depends(), - datasets: DatasetsDAO = Depends(DatasetsDAO.get_instance), - current_user: User = Security(auth.get_current_user, scopes=[]), -) -> Dataset: - dataset = datasets.find_by_name_and_workspace(dataset_name, workspace_params.workspace) - if not dataset: - raise EntityNotFoundError(name=dataset, type=Dataset) + datasets = service.list( + user=current_user, + workspaces=[request_deps.workspace] if request_deps.workspace is not None else None, + ) - if not is_authorized(current_user, DatasetPolicy.get): - raise ForbiddenOperationError("You don't have the necessary permissions to get datasets.") - - return Dataset.from_orm(dataset) + return parse_obj_as(List[Dataset], datasets) @router.post( @@ -95,162 +64,136 @@ def get_dataset( description="Create a new dataset", ) async def create_dataset( - db: Session = Depends(database.get_db), - create_dataset: CreateDatasetRequest = Body(..., description="The request dataset info"), - workspace_params: CommonTaskHandlerDependencies = Depends(), - datasets: DatasetsDAO = Depends(DatasetsDAO.get_instance), - current_user: User = Security(auth.get_current_user, scopes=[]), + request: CreateDatasetRequest = Body(..., description="The request dataset info"), + ws_params: CommonTaskHandlerDependencies = Depends(), + datasets: DatasetsService = Depends(DatasetsService.get_instance), + user: User = Security(auth.get_current_user), ) -> Dataset: - create_dataset.workspace = create_dataset.workspace or workspace_params.workspace + request.workspace = request.workspace or ws_params.workspace + dataset = datasets.create_dataset(user=user, dataset=request) - if not accounts.get_workspace_by_name(db, workspace_name=create_dataset.workspace): - raise EntityNotFoundError(name=create_dataset.workspace, type=Workspace) - - if not is_authorized(current_user, DatasetPolicy.create): - raise ForbiddenOperationError( - "You don't have the necessary permissions to create datasets. Only administrators can create datasets" - ) + return Dataset.from_orm(dataset) - dataset = datasets.create_dataset(user=current_user, dataset=create_dataset) - return Dataset.from_orm(dataset) +@router.get( + "/{name}", + response_model=Dataset, + response_model_exclude_none=True, + operation_id="get_dataset", +) +def get_dataset( + name: str, + ds_params: CommonTaskHandlerDependencies = Depends(), + service: DatasetsService = Depends(DatasetsService.get_instance), + current_user: User = Security(auth.get_current_user), +) -> Dataset: + return Dataset.from_orm( + service.find_by_name( + user=current_user, + name=name, + workspace=ds_params.workspace, + ) + ) @router.patch( - "/{dataset_name}", operation_id="update_dataset", response_model=Dataset, response_model_exclude_none=True + "/{name}", + operation_id="update_dataset", + response_model=Dataset, + response_model_exclude_none=True, ) def update_dataset( - dataset_name: str, - update_dataset: UpdateDatasetRequest, - workspace_params: CommonTaskHandlerDependencies = Depends(), - datasets: DatasetsDAO = Depends(DatasetsDAO.get_instance), - current_user: User = Security(auth.get_current_user, scopes=[]), + name: str, + request: UpdateDatasetRequest, + ds_params: CommonTaskHandlerDependencies = Depends(), + service: DatasetsService = Depends(DatasetsService.get_instance), + current_user: User = Security(auth.get_current_user), ) -> Dataset: - dataset = datasets.find_by_name_and_workspace(name=dataset_name, workspace=workspace_params.workspace) - if not dataset: - raise EntityNotFoundError(name=dataset, type=Dataset) - - if not is_authorized(current_user, DatasetPolicy.update(dataset)): - raise ForbiddenOperationError( - "You don't have the necessary permissions to update this dataset. " - "Only dataset creators or administrators can delete datasets" - ) + found_ds = service.find_by_name(user=current_user, name=name, workspace=ds_params.workspace) - dataset_update = dataset.copy(update={**update_dataset.dict(), "last_updated": datetime.utcnow()}) - updated_dataset = datasets.update_dataset(dataset_update) + dataset = service.update( + user=current_user, + dataset=found_ds, + tags=request.tags, + metadata=request.metadata, + ) - return Dataset.from_orm(updated_dataset) + return Dataset.from_orm(dataset) -@router.delete("/{dataset_name}", operation_id="delete_dataset") +@router.delete( + "/{name}", + operation_id="delete_dataset", +) def delete_dataset( - dataset_name: str, - workspace_params: CommonTaskHandlerDependencies = Depends(), - datasets: DatasetsDAO = Depends(DatasetsDAO.get_instance), - current_user: User = Security(auth.get_current_user, scopes=[]), + name: str, + ds_params: CommonTaskHandlerDependencies = Depends(), + service: DatasetsService = Depends(DatasetsService.get_instance), + current_user: User = Security(auth.get_current_user), ): - dataset = datasets.find_by_name_and_workspace(name=dataset_name, workspace=workspace_params.workspace) - if not dataset: - # We are not raising an EntityNotFoundError because this endpoint - # was not doing it originally so we want to continue doing the same. - return - - if not is_authorized(current_user, DatasetPolicy.delete(dataset)): - raise ForbiddenOperationError( - "You don't have the necessary permissions to delete this dataset. " - "Only dataset creators or administrators can delete datasets" + try: + found_ds = service.find_by_name( + user=current_user, + name=name, + workspace=ds_params.workspace, ) - - datasets.delete_dataset(dataset) - - -@router.put("/{dataset_name}:open", operation_id="open_dataset") -def open_dataset( - dataset_name: str, - workspace_params: CommonTaskHandlerDependencies = Depends(), - datasets: DatasetsDAO = Depends(DatasetsDAO.get_instance), - current_user: User = Security(auth.get_current_user, scopes=[]), -): - dataset = datasets.find_by_name_and_workspace(dataset_name, workspace_params.workspace) - if not dataset: - raise EntityNotFoundError(name=dataset_name, type=Dataset) - - if not is_authorized(current_user, DatasetPolicy.open(dataset)): - raise ForbiddenOperationError( - "You don't have the necessary permissions to open this dataset. " - "Only dataset creators or administrators can open datasets" + service.delete( + user=current_user, + dataset=found_ds, ) + except EntityNotFoundError: + pass - datasets.open(dataset) - -@router.put("/{dataset_name}:close", operation_id="close_dataset") +@router.put( + "/{name}:close", + operation_id="close_dataset", +) def close_dataset( - dataset_name: str, - workspace_params: CommonTaskHandlerDependencies = Depends(), - datasets: DatasetsDAO = Depends(DatasetsDAO.get_instance), - current_user: User = Security(auth.get_current_user, scopes=[]), + name: str, + ds_params: CommonTaskHandlerDependencies = Depends(), + service: DatasetsService = Depends(DatasetsService.get_instance), + current_user: User = Security(auth.get_current_user), ): - dataset = datasets.find_by_name_and_workspace(dataset_name, workspace_params.workspace) - if not dataset: - raise EntityNotFoundError(name=dataset_name, type=Dataset) + found_ds = service.find_by_name(user=current_user, name=name, workspace=ds_params.workspace) + service.close(user=current_user, dataset=found_ds) - if not is_authorized(current_user, DatasetPolicy.close(dataset)): - raise ForbiddenOperationError( - "You don't have the necessary permissions to close this dataset. " - "Only dataset creators or administrators can close datasets" - ) - datasets.close(dataset) +@router.put( + "/{name}:open", + operation_id="open_dataset", +) +def open_dataset( + name: str, + ds_params: CommonTaskHandlerDependencies = Depends(), + service: DatasetsService = Depends(DatasetsService.get_instance), + current_user: User = Security(auth.get_current_user), +): + found_ds = service.find_by_name(user=current_user, name=name, workspace=ds_params.workspace) + service.open(user=current_user, dataset=found_ds) @router.put( - "/{dataset_name}:copy", operation_id="copy_dataset", response_model=Dataset, response_model_exclude_none=True + "/{name}:copy", + operation_id="copy_dataset", + response_model=Dataset, + response_model_exclude_none=True, ) def copy_dataset( - *, - db: Session = Depends(database.get_db), - dataset_name: str, - workspace_params: CommonTaskHandlerDependencies = Depends(), - copy_dataset: CopyDatasetRequest, - datasets: DatasetsDAO = Depends(DatasetsDAO.get_instance), - current_user: User = Security(auth.get_current_user, scopes=[]), + name: str, + copy_request: CopyDatasetRequest, + ds_params: CommonTaskHandlerDependencies = Depends(), + service: DatasetsService = Depends(DatasetsService.get_instance), + current_user: User = Security(auth.get_current_user), ) -> Dataset: - source_dataset_name = dataset_name - source_workspace_name = workspace_params.workspace - - target_dataset_name = copy_dataset.name - target_workspace_name = copy_dataset.target_workspace or source_workspace_name - - source_dataset = datasets.find_by_name_and_workspace(source_dataset_name, source_workspace_name) - if not source_dataset: - raise EntityNotFoundError(name=source_dataset_name, type=Dataset) - - target_workspace = accounts.get_workspace_by_name(db, target_workspace_name) - if not target_workspace: - raise EntityNotFoundError(name=target_workspace_name, type=Workspace) + found = service.find_by_name(user=current_user, name=name, workspace=ds_params.workspace) + dataset = service.copy_dataset( + user=current_user, + dataset=found, + copy_name=copy_request.name, + copy_workspace=copy_request.target_workspace, + copy_tags=copy_request.tags, + copy_metadata=copy_request.metadata, + ) - if datasets.find_by_name_and_workspace(target_dataset_name, target_workspace_name): - raise EntityAlreadyExistsError(name=target_dataset_name, workspace=target_workspace_name, type=Dataset) - - if not is_authorized(current_user, DatasetPolicy.copy(source_dataset)): - raise ForbiddenOperationError( - "You don't have the necessary permissions to copy this dataset. " - "Only dataset creators or administrators can copy datasets" - ) - - target_dataset = source_dataset.copy() - target_dataset.name = target_dataset_name - target_dataset.workspace = target_workspace_name - target_dataset.created_at = target_dataset.last_updated = datetime.utcnow() - target_dataset.tags = {**target_dataset.tags, **(copy_dataset.tags or {})} - target_dataset.metadata = { - **target_dataset.metadata, - **(copy_dataset.metadata or {}), - "source_workspace": source_workspace_name, - "copied_from": source_dataset_name, - } - - datasets.copy(source_dataset, target_dataset) - - return Dataset.from_orm(target_dataset) + return Dataset.from_orm(dataset) From 7ed3c787be375f2423a4b0b1de17af448f2abdc6 Mon Sep 17 00:00:00 2001 From: frascuchon Date: Fri, 10 Mar 2023 16:45:16 +0100 Subject: [PATCH 05/13] Update dataset policies --- src/argilla/server/policies.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/src/argilla/server/policies.py b/src/argilla/server/policies.py index 85889cff17..625b39dbe6 100644 --- a/src/argilla/server/policies.py +++ b/src/argilla/server/policies.py @@ -71,8 +71,8 @@ def list(cls, user: User) -> bool: return True @classmethod - def get(cls, user: User) -> bool: - return True + def get(cls, dataset: Dataset) -> PolicyAction: + return lambda actor: actor.is_admin or dataset.workspace in [ws.name for ws in actor.workspaces] @classmethod def create(cls, user: User) -> bool: @@ -80,23 +80,27 @@ def create(cls, user: User) -> bool: @classmethod def update(cls, dataset: Dataset) -> PolicyAction: - return lambda actor: actor.is_admin or actor.username == dataset.created_by + is_get_allowed = cls.get(dataset) + return lambda actor: actor.is_admin or (is_get_allowed(actor) and actor.username == dataset.created_by) @classmethod def delete(cls, dataset: Dataset) -> PolicyAction: - return lambda actor: actor.is_admin or actor.username == dataset.created_by + is_get_allowed = cls.get(dataset) + return lambda actor: actor.is_admin or (is_get_allowed(actor) and actor.username == dataset.created_by) @classmethod def open(cls, dataset: Dataset) -> PolicyAction: - return lambda actor: actor.is_admin or actor.username == dataset.created_by + is_get_allowed = cls.get(dataset) + return lambda actor: actor.is_admin or (is_get_allowed(actor) and actor.username == dataset.created_by) @classmethod def close(cls, dataset: Dataset) -> PolicyAction: - return lambda actor: actor.is_admin or actor.username == dataset.created_by + return lambda actor: actor.is_admin @classmethod def copy(cls, dataset: Dataset) -> PolicyAction: - return lambda actor: cls.get(actor) and cls.create(actor) + is_get_allowed = cls.get(dataset) + return lambda actor: actor.is_admin or is_get_allowed(actor) and cls.create(actor) def authorize(actor: User, policy_action: PolicyAction) -> None: From 73513cfe4378d2503bb19c177bf721bd81c5066c Mon Sep 17 00:00:00 2001 From: frascuchon Date: Fri, 10 Mar 2023 16:47:40 +0100 Subject: [PATCH 06/13] Using models.User --- .../server/apis/v0/handlers/metrics.py | 4 ++-- .../server/apis/v0/handlers/records.py | 4 ++-- .../server/apis/v0/handlers/records_search.py | 2 +- .../server/apis/v0/handlers/records_update.py | 2 +- .../server/apis/v0/handlers/text2text.py | 8 +++---- .../apis/v0/handlers/text_classification.py | 22 +++++++++---------- .../text_classification_dataset_settings.py | 6 ++--- .../apis/v0/handlers/token_classification.py | 8 +++---- .../token_classification_dataset_settings.py | 6 ++--- src/argilla/server/apis/v0/handlers/users.py | 13 +++++------ .../server/apis/v0/handlers/workspaces.py | 12 +++++----- .../server/services/storage/service.py | 4 ++-- 12 files changed, 45 insertions(+), 46 deletions(-) diff --git a/src/argilla/server/apis/v0/handlers/metrics.py b/src/argilla/server/apis/v0/handlers/metrics.py index eec6397c7d..df2914a368 100644 --- a/src/argilla/server/apis/v0/handlers/metrics.py +++ b/src/argilla/server/apis/v0/handlers/metrics.py @@ -73,7 +73,7 @@ def configure_router(router: APIRouter, cfg: TaskConfig): def get_dataset_metrics( name: str, request_deps: CommonTaskHandlerDependencies = Depends(), - current_user: User = Security(auth.get_user, scopes=[]), + current_user: User = Security(auth.get_current_user), datasets: DatasetsService = Depends(DatasetsService.get_instance), ) -> List[MetricInfo]: dataset = datasets.find_by_name( @@ -101,7 +101,7 @@ def metric_summary( query: cfg.query, metric_params: MetricSummaryParams = Depends(), request_deps: CommonTaskHandlerDependencies = Depends(), - current_user: User = Security(auth.get_user, scopes=[]), + current_user: User = Security(auth.get_current_user, scopes=[]), datasets: DatasetsService = Depends(DatasetsService.get_instance), metrics: MetricsService = Depends(MetricsService.get_instance), ): diff --git a/src/argilla/server/apis/v0/handlers/records.py b/src/argilla/server/apis/v0/handlers/records.py index 108afb1adf..4e29f1afaf 100644 --- a/src/argilla/server/apis/v0/handlers/records.py +++ b/src/argilla/server/apis/v0/handlers/records.py @@ -53,7 +53,7 @@ async def get_dataset_record( request_deps: CommonTaskHandlerDependencies = Depends(), service: DatasetsService = Depends(DatasetsService.get_instance), search: SearchRecordsService = Depends(SearchRecordsService.get_instance), - current_user: User = Security(auth.get_user, scopes=[]), + current_user: User = Security(auth.get_current_user), ) -> RecordType: found = service.find_by_name( user=current_user, @@ -85,7 +85,7 @@ async def delete_dataset_records( request_deps: CommonTaskHandlerDependencies = Depends(), service: DatasetsService = Depends(DatasetsService.get_instance), storage: RecordsStorageService = Depends(RecordsStorageService.get_instance), - current_user: User = Security(auth.get_user, scopes=[]), + current_user: User = Security(auth.get_current_user), ): found = service.find_by_name( user=current_user, diff --git a/src/argilla/server/apis/v0/handlers/records_search.py b/src/argilla/server/apis/v0/handlers/records_search.py index 9043db3f48..103780c586 100644 --- a/src/argilla/server/apis/v0/handlers/records_search.py +++ b/src/argilla/server/apis/v0/handlers/records_search.py @@ -72,7 +72,7 @@ async def scan_dataset_records( request_deps: CommonTaskHandlerDependencies = Depends(), service: DatasetsService = Depends(DatasetsService.get_instance), engine: GenericElasticEngineBackend = Depends(GenericElasticEngineBackend.get_instance), - current_user: User = Security(auth.get_user, scopes=[]), + current_user: User = Security(auth.get_current_user), ): found = service.find_by_name(user=current_user, name=name, workspace=request_deps.workspace) diff --git a/src/argilla/server/apis/v0/handlers/records_update.py b/src/argilla/server/apis/v0/handlers/records_update.py index b2034265ad..6fa67eb677 100644 --- a/src/argilla/server/apis/v0/handlers/records_update.py +++ b/src/argilla/server/apis/v0/handlers/records_update.py @@ -53,7 +53,7 @@ async def partial_update_dataset_record( service: DatasetsService = Depends(DatasetsService.get_instance), search: SearchRecordsService = Depends(SearchRecordsService.get_instance), storage: RecordsStorageService = Depends(RecordsStorageService.get_instance), - current_user: User = Security(auth.get_user, scopes=[]), + current_user: User = Security(auth.get_current_user, scopes=[]), ) -> RecordType: dataset = service.find_by_name( user=current_user, diff --git a/src/argilla/server/apis/v0/handlers/text2text.py b/src/argilla/server/apis/v0/handlers/text2text.py index 24e7d03493..76a8660bfa 100644 --- a/src/argilla/server/apis/v0/handlers/text2text.py +++ b/src/argilla/server/apis/v0/handlers/text2text.py @@ -75,10 +75,10 @@ async def bulk_records( common_params: CommonTaskHandlerDependencies = Depends(), service: Text2TextService = Depends(Text2TextService.get_instance), datasets: DatasetsService = Depends(DatasetsService.get_instance), - current_user: User = Security(auth.get_user, scopes=[]), + current_user: User = Security(auth.get_current_user), ) -> BulkResponse: task = task_type - workspace = current_user.check_workspace(common_params.workspace) + workspace = common_params.workspace try: dataset = datasets.find_by_name( current_user, @@ -120,7 +120,7 @@ def search_records( pagination: RequestPagination = Depends(), service: Text2TextService = Depends(Text2TextService.get_instance), datasets: DatasetsService = Depends(DatasetsService.get_instance), - current_user: User = Security(auth.get_user, scopes=[]), + current_user: User = Security(auth.get_current_user), ) -> Text2TextSearchResults: search = search or Text2TextSearchRequest() query = search.query or Text2TextQuery() @@ -188,7 +188,7 @@ async def stream_data( limit: Optional[int] = Query(None, description="Limit loaded records", gt=0), service: Text2TextService = Depends(Text2TextService.get_instance), datasets: DatasetsService = Depends(DatasetsService.get_instance), - current_user: User = Security(auth.get_user, scopes=[]), + current_user: User = Security(auth.get_current_user), id_from: Optional[str] = None, ) -> StreamingResponse: """ diff --git a/src/argilla/server/apis/v0/handlers/text_classification.py b/src/argilla/server/apis/v0/handlers/text_classification.py index fb53afb55f..a28d4ba5f9 100644 --- a/src/argilla/server/apis/v0/handlers/text_classification.py +++ b/src/argilla/server/apis/v0/handlers/text_classification.py @@ -93,10 +93,10 @@ async def bulk_records( service: TextClassificationService = Depends(TextClassificationService.get_instance), datasets: DatasetsService = Depends(DatasetsService.get_instance), validator: DatasetValidator = Depends(DatasetValidator.get_instance), - current_user: User = Security(auth.get_user, scopes=[]), + current_user: User = Security(auth.get_current_user), ) -> BulkResponse: task = task_type - workspace = current_user.check_workspace(common_params.workspace) + workspace = common_params.workspace try: dataset = datasets.find_by_name( current_user, @@ -142,7 +142,7 @@ def search_records( pagination: RequestPagination = Depends(), service: TextClassificationService = Depends(TextClassificationService.get_instance), datasets: DatasetsService = Depends(DatasetsService.get_instance), - current_user: User = Security(auth.get_user, scopes=[]), + current_user: User = Security(auth.get_current_user), ) -> TextClassificationSearchResults: """ Searches data from dataset @@ -240,7 +240,7 @@ async def stream_data( limit: Optional[int] = Query(None, description="Limit loaded records", gt=0), service: TextClassificationService = Depends(TextClassificationService.get_instance), datasets: DatasetsService = Depends(DatasetsService.get_instance), - current_user: User = Security(auth.get_user, scopes=[]), + current_user: User = Security(auth.get_current_user), ) -> StreamingResponse: """ Creates a data stream over dataset records @@ -302,7 +302,7 @@ async def list_labeling_rules( common_params: CommonTaskHandlerDependencies = Depends(), datasets: DatasetsService = Depends(DatasetsService.get_instance), service: TextClassificationService = Depends(TextClassificationService.get_instance), - current_user: User = Security(auth.get_user, scopes=[]), + current_user: User = Security(auth.get_current_user), ) -> List[LabelingRule]: dataset = datasets.find_by_name( user=current_user, @@ -329,7 +329,7 @@ async def create_rule( common_params: CommonTaskHandlerDependencies = Depends(), service: TextClassificationService = Depends(TextClassificationService.get_instance), datasets: DatasetsService = Depends(DatasetsService.get_instance), - current_user: User = Security(auth.get_user, scopes=[]), + current_user: User = Security(auth.get_current_user), ) -> LabelingRule: dataset = datasets.find_by_name( user=current_user, @@ -365,7 +365,7 @@ async def compute_rule_metrics( common_params: CommonTaskHandlerDependencies = Depends(), service: TextClassificationService = Depends(TextClassificationService.get_instance), datasets: DatasetsService = Depends(DatasetsService.get_instance), - current_user: User = Security(auth.get_user, scopes=[]), + current_user: User = Security(auth.get_current_user), ) -> LabelingRuleMetricsSummary: dataset = datasets.find_by_name( user=current_user, @@ -391,7 +391,7 @@ async def compute_dataset_rules_metrics( common_params: CommonTaskHandlerDependencies = Depends(), service: TextClassificationService = Depends(TextClassificationService.get_instance), datasets: DatasetsService = Depends(DatasetsService.get_instance), - current_user: User = Security(auth.get_user, scopes=[]), + current_user: User = Security(auth.get_current_user), ) -> DatasetLabelingRulesMetricsSummary: dataset = datasets.find_by_name( user=current_user, @@ -416,7 +416,7 @@ async def delete_labeling_rule( common_params: CommonTaskHandlerDependencies = Depends(), service: TextClassificationService = Depends(TextClassificationService.get_instance), datasets: DatasetsService = Depends(DatasetsService.get_instance), - current_user: User = Security(auth.get_user, scopes=[]), + current_user: User = Security(auth.get_current_user), ) -> None: dataset = datasets.find_by_name( user=current_user, @@ -443,7 +443,7 @@ async def get_rule( common_params: CommonTaskHandlerDependencies = Depends(), service: TextClassificationService = Depends(TextClassificationService.get_instance), datasets: DatasetsService = Depends(DatasetsService.get_instance), - current_user: User = Security(auth.get_user, scopes=[]), + current_user: User = Security(auth.get_current_user), ) -> LabelingRule: dataset = datasets.find_by_name( user=current_user, @@ -471,7 +471,7 @@ async def update_rule( common_params: CommonTaskHandlerDependencies = Depends(), service: TextClassificationService = Depends(TextClassificationService.get_instance), datasets: DatasetsService = Depends(DatasetsService.get_instance), - current_user: User = Security(auth.get_user, scopes=[]), + current_user: User = Security(auth.get_current_user), ) -> LabelingRule: dataset = datasets.find_by_name( user=current_user, diff --git a/src/argilla/server/apis/v0/handlers/text_classification_dataset_settings.py b/src/argilla/server/apis/v0/handlers/text_classification_dataset_settings.py index 70c6b327d5..22ee497650 100644 --- a/src/argilla/server/apis/v0/handlers/text_classification_dataset_settings.py +++ b/src/argilla/server/apis/v0/handlers/text_classification_dataset_settings.py @@ -54,7 +54,7 @@ async def get_dataset_settings( name: str = DATASET_NAME_PATH_PARAM, ws_params: CommonTaskHandlerDependencies = Depends(), datasets: DatasetsService = Depends(DatasetsService.get_instance), - user: User = Security(auth.get_user, scopes=["read:dataset.settings"]), + user: User = Security(auth.get_current_user), ) -> TextClassificationSettings: found_ds = datasets.find_by_name( user=user, @@ -82,7 +82,7 @@ async def save_settings( ws_params: CommonTaskHandlerDependencies = Depends(), datasets: DatasetsService = Depends(DatasetsService.get_instance), validator: DatasetValidator = Depends(DatasetValidator.get_instance), - user: User = Security(auth.get_user, scopes=["write:dataset.settings"]), + user: User = Security(auth.get_current_user), ) -> TextClassificationSettings: found_ds = datasets.find_by_name( user=user, @@ -110,7 +110,7 @@ async def delete_settings( name: str = DATASET_NAME_PATH_PARAM, ws_params: CommonTaskHandlerDependencies = Depends(), datasets: DatasetsService = Depends(DatasetsService.get_instance), - user: User = Security(auth.get_user, scopes=["delete:dataset.settings"]), + user: User = Security(auth.get_current_user), ) -> None: found_ds = datasets.find_by_name( user=user, diff --git a/src/argilla/server/apis/v0/handlers/token_classification.py b/src/argilla/server/apis/v0/handlers/token_classification.py index 5b8a0e6a1f..96f106a481 100644 --- a/src/argilla/server/apis/v0/handlers/token_classification.py +++ b/src/argilla/server/apis/v0/handlers/token_classification.py @@ -85,10 +85,10 @@ async def bulk_records( service: TokenClassificationService = Depends(TokenClassificationService.get_instance), datasets: DatasetsService = Depends(DatasetsService.get_instance), validator: DatasetValidator = Depends(DatasetValidator.get_instance), - current_user: User = Security(auth.get_user, scopes=[]), + current_user: User = Security(auth.get_current_user, scopes=[]), ) -> BulkResponse: task = task_type - workspace = current_user.check_workspace(common_params.workspace) + workspace = common_params.workspace try: dataset = datasets.find_by_name( current_user, @@ -141,7 +141,7 @@ def search_records( pagination: RequestPagination = Depends(), service: TokenClassificationService = Depends(TokenClassificationService.get_instance), datasets: DatasetsService = Depends(DatasetsService.get_instance), - current_user: User = Security(auth.get_user, scopes=[]), + current_user: User = Security(auth.get_current_user, scopes=[]), ) -> TokenClassificationSearchResults: search = search or TokenClassificationSearchRequest() query = search.query or TokenClassificationQuery() @@ -210,7 +210,7 @@ async def stream_data( limit: Optional[int] = Query(None, description="Limit loaded records", gt=0), service: TokenClassificationService = Depends(TokenClassificationService.get_instance), datasets: DatasetsService = Depends(DatasetsService.get_instance), - current_user: User = Security(auth.get_user, scopes=[]), + current_user: User = Security(auth.get_current_user), id_from: Optional[str] = None, ) -> StreamingResponse: """ diff --git a/src/argilla/server/apis/v0/handlers/token_classification_dataset_settings.py b/src/argilla/server/apis/v0/handlers/token_classification_dataset_settings.py index a5ac2411f3..96c9a632ef 100644 --- a/src/argilla/server/apis/v0/handlers/token_classification_dataset_settings.py +++ b/src/argilla/server/apis/v0/handlers/token_classification_dataset_settings.py @@ -54,7 +54,7 @@ async def get_dataset_settings( name: str = DATASET_NAME_PATH_PARAM, ws_params: CommonTaskHandlerDependencies = Depends(), datasets: DatasetsService = Depends(DatasetsService.get_instance), - user: User = Security(auth.get_user, scopes=["read:dataset.settings"]), + user: User = Security(auth.get_current_user), ) -> TokenClassificationSettings: found_ds = datasets.find_by_name( user=user, @@ -82,7 +82,7 @@ async def save_settings( ws_params: CommonTaskHandlerDependencies = Depends(), datasets: DatasetsService = Depends(DatasetsService.get_instance), validator: DatasetValidator = Depends(DatasetValidator.get_instance), - user: User = Security(auth.get_user, scopes=["write:dataset.settings"]), + user: User = Security(auth.get_current_user), ) -> TokenClassificationSettings: found_ds = datasets.find_by_name( user=user, @@ -110,7 +110,7 @@ async def delete_settings( name: str = DATASET_NAME_PATH_PARAM, ws_params: CommonTaskHandlerDependencies = Depends(), datasets: DatasetsService = Depends(DatasetsService.get_instance), - user: User = Security(auth.get_user, scopes=["delete:dataset.settings"]), + user: User = Security(auth.get_current_user), ) -> None: found_ds = datasets.find_by_name( user=user, diff --git a/src/argilla/server/apis/v0/handlers/users.py b/src/argilla/server/apis/v0/handlers/users.py index 713cad43c7..4540b01ec0 100644 --- a/src/argilla/server/apis/v0/handlers/users.py +++ b/src/argilla/server/apis/v0/handlers/users.py @@ -20,6 +20,7 @@ from pydantic import parse_obj_as from sqlalchemy.orm import Session +from argilla.server import models from argilla.server.commons import telemetry from argilla.server.contexts import accounts from argilla.server.database import get_db @@ -32,7 +33,7 @@ @router.get("/me", response_model=User, response_model_exclude_none=True, operation_id="whoami") -async def whoami(request: Request, current_user: User = Security(auth.get_user, scopes=[])): +async def whoami(request: Request, current_user: models.User = Security(auth.get_current_user)): """ User info endpoint @@ -51,11 +52,11 @@ async def whoami(request: Request, current_user: User = Security(auth.get_user, await telemetry.track_login(request, username=current_user.username) - return current_user + return User.from_orm(current_user) @router.get("/users", response_model=List[User], response_model_exclude_none=True) -def list_users(*, db: Session = Depends(get_db), current_user: User = Security(auth.get_user, scopes=[])): +def list_users(*, db: Session = Depends(get_db), current_user: User = Security(auth.get_current_user)): authorize(current_user, UserPolicy.list) users = accounts.list_users(db) @@ -65,7 +66,7 @@ def list_users(*, db: Session = Depends(get_db), current_user: User = Security(a @router.post("/users", response_model=User, response_model_exclude_none=True) def create_user( - *, db: Session = Depends(get_db), user_create: UserCreate, current_user: User = Security(auth.get_user, scopes=[]) + *, db: Session = Depends(get_db), user_create: UserCreate, current_user: User = Security(auth.get_current_user) ): authorize(current_user, UserPolicy.create) @@ -75,9 +76,7 @@ def create_user( @router.delete("/users/{user_id}", response_model=User, response_model_exclude_none=True) -def delete_user( - *, db: Session = Depends(get_db), user_id: UUID, current_user: User = Security(auth.get_user, scopes=[]) -): +def delete_user(*, db: Session = Depends(get_db), user_id: UUID, current_user: User = Security(auth.get_current_user)): user = accounts.get_user_by_id(db, user_id) if not user: # TODO: Forcing here user_id to be an string. diff --git a/src/argilla/server/apis/v0/handlers/workspaces.py b/src/argilla/server/apis/v0/handlers/workspaces.py index 616a452b12..54bd3f7921 100644 --- a/src/argilla/server/apis/v0/handlers/workspaces.py +++ b/src/argilla/server/apis/v0/handlers/workspaces.py @@ -35,7 +35,7 @@ @router.get("/workspaces", response_model=List[Workspace], response_model_exclude_none=True) -def list_workspaces(*, db: Session = Depends(get_db), current_user: User = Security(auth.get_user, scopes=[])): +def list_workspaces(*, db: Session = Depends(get_db), current_user: User = Security(auth.get_current_user)): authorize(current_user, WorkspacePolicy.list) workspaces = accounts.list_workspaces(db) @@ -48,7 +48,7 @@ def create_workspace( *, db: Session = Depends(get_db), workspace_create: WorkspaceCreate, - current_user: User = Security(auth.get_user, scopes=[]), + current_user: User = Security(auth.get_current_user), ): authorize(current_user, WorkspacePolicy.create) @@ -62,7 +62,7 @@ def create_workspace( # any dataset then we can delete them. # @router.delete("/workspaces/{workspace_id}", response_model=Workspace, response_model_exclude_none=True) def delete_workspace( - *, db: Session = Depends(get_db), workspace_id: UUID, current_user: User = Security(auth.get_user, scopes=[]) + *, db: Session = Depends(get_db), workspace_id: UUID, current_user: User = Security(auth.get_current_user) ): workspace = accounts.get_workspace_by_id(db, workspace_id) if not workspace: @@ -77,7 +77,7 @@ def delete_workspace( @router.get("/workspaces/{workspace_id}/users", response_model=List[User], response_model_exclude_none=True) def list_workspace_users( - *, db: Session = Depends(get_db), workspace_id: UUID, current_user: User = Security(auth.get_user, scopes=[]) + *, db: Session = Depends(get_db), workspace_id: UUID, current_user: User = Security(auth.get_current_user) ): authorize(current_user, WorkspaceUserPolicy.list) @@ -94,7 +94,7 @@ def create_workspace_user( db: Session = Depends(get_db), workspace_id: UUID, user_id: UUID, - current_user: User = Security(auth.get_user, scopes=[]), + current_user: User = Security(auth.get_current_user), ): authorize(current_user, WorkspaceUserPolicy.create) @@ -117,7 +117,7 @@ def delete_workspace_user( db: Session = Depends(get_db), workspace_id: UUID, user_id: UUID, - current_user: User = Security(auth.get_user, scopes=[]), + current_user: User = Security(auth.get_current_user), ): workspace_user = accounts.get_workspace_user_by_workspace_id_and_user_id(db, workspace_id, user_id) if not workspace_user: diff --git a/src/argilla/server/services/storage/service.py b/src/argilla/server/services/storage/service.py index 9fecae36d6..07c0dbaf97 100644 --- a/src/argilla/server/services/storage/service.py +++ b/src/argilla/server/services/storage/service.py @@ -23,7 +23,7 @@ from argilla.server.daos.backend.base import WrongLogDataError from argilla.server.daos.records import DatasetRecordsDAO from argilla.server.errors import BulkDataError, ForbiddenOperationError -from argilla.server.security.model import User +from argilla.server.models import User from argilla.server.services.datasets import ServiceDataset from argilla.server.services.search.model import ServiceBaseRecordsQuery from argilla.server.services.tasks.commons import ServiceRecord @@ -92,7 +92,7 @@ async def delete_records( status=TaskStatus.discarded, ) else: - if not user.is_superuser() and user.username != dataset.created_by: + if not user.is_admin and user.username != dataset.created_by: raise ForbiddenOperationError( "You don't have the necessary permissions to delete records on this dataset. " "Only dataset creators or administrators can delete datasets" From 15aea3abefccad4aa151b85671f247daaa699829 Mon Sep 17 00:00:00 2001 From: frascuchon Date: Fri, 10 Mar 2023 16:51:29 +0100 Subject: [PATCH 07/13] Remove auth.get_user method --- .../security/auth_provider/local/provider.py | 29 +------------------ tests/server/security/test_provider.py | 4 +-- 2 files changed, 3 insertions(+), 30 deletions(-) diff --git a/src/argilla/server/security/auth_provider/local/provider.py b/src/argilla/server/security/auth_provider/local/provider.py index 84ef0669a4..e00161cabe 100644 --- a/src/argilla/server/security/auth_provider/local/provider.py +++ b/src/argilla/server/security/auth_provider/local/provider.py @@ -28,6 +28,7 @@ from argilla.server.contexts import accounts from argilla.server.database import get_db from argilla.server.errors import UnauthorizedError +from argilla.server.models import User from argilla.server.security.auth_provider.base import ( AuthProvider, api_key_header, @@ -143,7 +144,6 @@ def fetch_token_user(self, db: Session, token: str) -> Optional[User]: ) username: str = payload.get("sub") if username: - # return self.users.get_user(username=username) return accounts.get_user_by_username(db, username) except JWTError: return None @@ -169,33 +169,6 @@ def get_current_user( return user - async def get_user( - self, - security_scopes: SecurityScopes, - db: Session = Depends(get_db), - api_key: Optional[str] = Depends(api_key_header), - old_api_key: Optional[str] = Depends(old_api_key_header), - token: Optional[str] = Depends(_oauth2_scheme), - ) -> User: - """ - Fetches the user for a given token - - Parameters - ---------- - api_key: - The apikey header info if provided - old_api_key: - Same as api key but for old clients - token: - The login token. - fastapi injects this param from request - Returns - ------- - - """ - user = self.get_current_user(security_scopes, db, api_key, old_api_key, token) - return User.from_orm(user) - def create_local_auth_provider(): settings = Settings() diff --git a/tests/server/security/test_provider.py b/tests/server/security/test_provider.py index 69b944861a..7fe2962c08 100644 --- a/tests/server/security/test_provider.py +++ b/tests/server/security/test_provider.py @@ -28,7 +28,7 @@ @pytest.mark.asyncio async def test_get_user_via_token(db: Session, argilla_user): access_token = localAuth._create_access_token(username="argilla") - user = await localAuth.get_user( + user = await localAuth.get_current_user( security_scopes=security_Scopes, db=db, token=access_token, @@ -40,7 +40,7 @@ async def test_get_user_via_token(db: Session, argilla_user): @pytest.mark.asyncio async def test_get_user_via_api_key(db: Session, argilla_user): - user = await localAuth.get_user(security_scopes=security_Scopes, db=db, api_key=DEFAULT_API_KEY, token=None) + user = await localAuth.get_current_user(security_scopes=security_Scopes, db=db, api_key=DEFAULT_API_KEY, token=None) assert user.username == "argilla" From c00b6ace4ec3a382dfcf48fbc82b5643bd6d2b97 Mon Sep 17 00:00:00 2001 From: frascuchon Date: Fri, 10 Mar 2023 16:56:46 +0100 Subject: [PATCH 08/13] Integrate datasets service with new logic --- src/argilla/server/services/datasets.py | 158 +++++++++++++++--------- 1 file changed, 97 insertions(+), 61 deletions(-) diff --git a/src/argilla/server/services/datasets.py b/src/argilla/server/services/datasets.py index da48e21d0d..5fce1283a9 100644 --- a/src/argilla/server/services/datasets.py +++ b/src/argilla/server/services/datasets.py @@ -17,7 +17,10 @@ from typing import Any, Dict, List, Optional, Type, TypeVar, cast from fastapi import Depends +from sqlalchemy.orm import Session +from argilla.server import database +from argilla.server.contexts import accounts from argilla.server.daos.datasets import BaseDatasetSettingsDB, DatasetsDAO from argilla.server.daos.models.datasets import BaseDatasetDB from argilla.server.errors import ( @@ -26,6 +29,9 @@ ForbiddenOperationError, WrongTaskError, ) +from argilla.server.models import User as UserModel +from argilla.server.models import Workspace +from argilla.server.policies import DatasetPolicy, is_authorized from argilla.server.schemas.datasets import CreateDatasetRequest, Dataset from argilla.server.security.model import User @@ -46,16 +52,24 @@ class DatasetsService: _INSTANCE: "DatasetsService" = None @classmethod - def get_instance(cls, dao: DatasetsDAO = Depends(DatasetsDAO.get_instance)) -> "DatasetsService": - if not cls._INSTANCE: - cls._INSTANCE = cls(dao) - return cls._INSTANCE + def get_instance( + cls, db: Session = Depends(database.get_db), dao: DatasetsDAO = Depends(DatasetsDAO.get_instance) + ) -> "DatasetsService": + return cls(db, dao) - def __init__(self, dao: DatasetsDAO): + def __init__(self, db: Session, dao: DatasetsDAO): + self._db = db self.__dao__ = dao - def create_dataset(self, user: User, dataset: CreateDatasetRequest) -> BaseDatasetDB: - dataset.workspace = user.check_workspace(dataset.workspace) + def create_dataset(self, user: UserModel, dataset: CreateDatasetRequest) -> BaseDatasetDB: + if not accounts.get_workspace_by_name(self._db, workspace_name=dataset.workspace): + raise EntityNotFoundError(name=dataset.workspace, type=Workspace) + + if not is_authorized(user, DatasetPolicy.create): + raise ForbiddenOperationError( + "You don't have the necessary permissions to create datasets. Only administrators can create datasets" + ) + try: self.find_by_name(user=user, name=dataset.name, task=dataset.task, workspace=dataset.workspace) raise EntityAlreadyExistsError(name=dataset.name, type=ServiceDataset, workspace=dataset.workspace) @@ -74,114 +88,136 @@ def create_dataset(self, user: User, dataset: CreateDatasetRequest) -> BaseDatas def find_by_name( self, - user: User, + user: UserModel, name: str, workspace: str, as_dataset_class: Type[ServiceDataset] = ServiceBaseDataset, task: Optional[str] = None, ) -> ServiceDataset: - workspace = user.check_workspace(workspace) - found_ds = self.__dao__.find_by_name(name=name, workspace=workspace, as_dataset_class=as_dataset_class) - if found_ds is None: + found_dataset = self.__dao__.find_by_name(name=name, workspace=workspace, as_dataset_class=as_dataset_class) + + if found_dataset is None: raise EntityNotFoundError(name=name, type=ServiceDataset) - elif task and found_ds.task != task: + + elif task and found_dataset.task != task: raise WrongTaskError(detail=f"Provided task {task} cannot be applied to dataset") - else: - return cast(ServiceDataset, found_ds) - def delete(self, user: User, dataset: ServiceDataset): - dataset = self.find_by_name(user=user, name=dataset.name, workspace=dataset.workspace, task=dataset.task) + if not is_authorized(user, DatasetPolicy.get(found_dataset)): + raise ForbiddenOperationError("You don't have the necessary permissions to get this dataset.") + + return cast(ServiceDataset, found_dataset) - if user.is_superuser() or user.username == dataset.created_by: - self.__dao__.delete_dataset(dataset) - else: + def delete(self, user: UserModel, dataset: ServiceDataset): + if not is_authorized(user, DatasetPolicy.delete(dataset)): raise ForbiddenOperationError( "You don't have the necessary permissions to delete this dataset. " "Only dataset creators or administrators can delete datasets" ) + self.__dao__.delete_dataset(dataset) def update( self, - user: User, + user: UserModel, dataset: ServiceDataset, tags: Dict[str, str], metadata: Dict[str, Any], ) -> Dataset: found = self.find_by_name(user=user, name=dataset.name, task=dataset.task, workspace=dataset.workspace) + if not is_authorized(user, DatasetPolicy.update(found)): + raise ForbiddenOperationError( + "You don't have the necessary permissions to update this dataset. " + "Only dataset creators or administrators can update datasets" + ) + dataset.tags = {**found.tags, **(tags or {})} dataset.metadata = {**found.metadata, **(metadata or {})} updated = found.copy(update={**dataset.dict(by_alias=True), "last_updated": datetime.utcnow()}) + return self.__dao__.update_dataset(updated) def list( self, - user: User, + user: UserModel, workspaces: Optional[List[str]], task2dataset_map: Dict[str, Type[ServiceDataset]] = None, ) -> List[ServiceDataset]: - workspaces = user.check_workspaces(workspaces) - return self.__dao__.list_datasets(workspaces=workspaces, task2dataset_map=task2dataset_map) + if not is_authorized(user, DatasetPolicy.list): + raise ForbiddenOperationError("You don't have the necessary permissions to list datasets.") + + accessible_workspace_names = [ + ws.name for ws in (accounts.list_workspaces(self._db) if user.is_admin else user.workspaces) + ] - def close(self, user: User, dataset: ServiceDataset): - found = self.find_by_name(user=user, name=dataset.name, workspace=dataset.workspace) - self.__dao__.close(found) + if workspaces: + for ws in workspaces: + if ws not in accessible_workspace_names: + raise EntityNotFoundError(name=ws, type=Workspace) + workspace_names = workspaces + else: # no workspaces + workspace_names = accessible_workspace_names - def open(self, user: User, dataset: ServiceDataset): - found = self.find_by_name(user=user, name=dataset.name, workspace=dataset.workspace) - self.__dao__.open(found) + return self.__dao__.list_datasets(workspaces=workspace_names, task2dataset_map=task2dataset_map) + + def close(self, user: UserModel, dataset: ServiceDataset): + if not is_authorized(user, DatasetPolicy.close(dataset)): + raise ForbiddenOperationError( + "You don't have the necessary permissions to close this dataset. " + "Only dataset creators or administrators can close datasets" + ) + self.__dao__.close(dataset) + + def open(self, user: UserModel, dataset: ServiceDataset): + if not is_authorized(user, DatasetPolicy.open(dataset)): + raise ForbiddenOperationError( + "You don't have the necessary permissions to open this dataset. " + "Only dataset creators or administrators can open datasets" + ) + self.__dao__.open(dataset) def copy_dataset( self, - user: User, + user: UserModel, dataset: ServiceDataset, copy_name: str, copy_workspace: Optional[str] = None, copy_tags: Dict[str, Any] = None, copy_metadata: Dict[str, Any] = None, ) -> ServiceDataset: - dataset_workspace = copy_workspace or dataset.workspace - dataset_workspace = user.check_workspace(dataset_workspace) + target_workspace_name = copy_workspace or dataset.workspace - self._validate_create_dataset( - name=copy_name, - workspace=dataset_workspace, - user=user, - ) + target_workspace = accounts.get_workspace_by_name(self._db, target_workspace_name) + if not target_workspace: + raise EntityNotFoundError(name=target_workspace_name, type=Workspace) - copy_dataset = dataset.copy() - copy_dataset.name = copy_name - copy_dataset.workspace = dataset_workspace + if self.__dao__.find_by_name_and_workspace(name=copy_name, workspace=target_workspace_name): + raise EntityAlreadyExistsError(name=copy_name, workspace=target_workspace_name, type=Dataset) + + if not is_authorized(user, DatasetPolicy.copy(dataset)): + raise ForbiddenOperationError( + "You don't have the necessary permissions to copy this dataset. " + "Only dataset creators or administrators can copy datasets" + ) + + dataset_copy = dataset.copy() + dataset_copy.name = copy_name + dataset_copy.workspace = target_workspace_name date_now = datetime.utcnow() - copy_dataset.created_at = date_now - copy_dataset.last_updated = date_now - copy_dataset.tags = {**copy_dataset.tags, **(copy_tags or {})} - copy_dataset.metadata = { - **copy_dataset.metadata, + dataset_copy.created_at = date_now + dataset_copy.last_updated = date_now + dataset_copy.tags = {**dataset_copy.tags, **(copy_tags or {})} + dataset_copy.metadata = { + **dataset_copy.metadata, **(copy_metadata or {}), "source_workspace": dataset.workspace, "copied_from": dataset.name, } - self.__dao__.copy( - source=dataset, - target=copy_dataset, - ) + self.__dao__.copy(source=dataset, target=dataset_copy) - return copy_dataset - - def _validate_create_dataset(self, name: str, workspace: str, user: User): - try: - found = self.find_by_name(user=user, name=name, workspace=workspace) - raise EntityAlreadyExistsError( - name=found.name, - type=found.__class__, - workspace=workspace, - ) - except (EntityNotFoundError, ForbiddenOperationError): - pass + return dataset_copy async def get_settings( self, From 87520b867f40246fc0baad06f41abe8669de614a Mon Sep 17 00:00:00 2001 From: frascuchon Date: Fri, 10 Mar 2023 16:57:10 +0100 Subject: [PATCH 09/13] Fixing tests --- tests/server/datasets/test_dao.py | 16 +++++++++------- tests/server/security/test_provider.py | 10 +++------- 2 files changed, 12 insertions(+), 14 deletions(-) diff --git a/tests/server/datasets/test_dao.py b/tests/server/datasets/test_dao.py index 6f2a734009..4565f17dd8 100644 --- a/tests/server/datasets/test_dao.py +++ b/tests/server/datasets/test_dao.py @@ -37,15 +37,16 @@ def test_retrieve_ownered_dataset_for_no_owner_user(): def test_list_datasets_by_task(): dataset = "test_list_datasets_by_task" + workspace_name = "other" - all_datasets = dao.list_datasets() + all_datasets = dao.list_datasets(workspaces=[workspace_name]) for ds in all_datasets: dao.delete_dataset(ds) created_text = dao.create_dataset( BaseDatasetDB( name=dataset + "_text", - workspace="other", + workspace=workspace_name, task=TaskType.text_classification, ), ) @@ -53,19 +54,20 @@ def test_list_datasets_by_task(): created_token = dao.create_dataset( BaseDatasetDB( name=dataset + "_token", - workspace="other", + workspace=workspace_name, task=TaskType.token_classification, ), ) - datasets = dao.list_datasets( - task2dataset_map={created_text.task: BaseDatasetDB}, - ) + assert len(dao.list_datasets()) == 0 + assert len(dao.list_datasets(workspaces=[workspace_name])) == 2 + + datasets = dao.list_datasets(workspaces=[workspace_name], task2dataset_map={created_text.task: BaseDatasetDB}) assert len(datasets) == 1 assert datasets[0].name == created_text.name - datasets = dao.list_datasets(task2dataset_map={created_token.task: BaseDatasetDB}) + datasets = dao.list_datasets(workspaces=[workspace_name], task2dataset_map={created_token.task: BaseDatasetDB}) assert len(datasets) == 1 assert datasets[0].name == created_token.name diff --git a/tests/server/security/test_provider.py b/tests/server/security/test_provider.py index 7fe2962c08..8832520b8a 100644 --- a/tests/server/security/test_provider.py +++ b/tests/server/security/test_provider.py @@ -28,19 +28,15 @@ @pytest.mark.asyncio async def test_get_user_via_token(db: Session, argilla_user): access_token = localAuth._create_access_token(username="argilla") - user = await localAuth.get_current_user( - security_scopes=security_Scopes, - db=db, - token=access_token, - api_key=None, - old_api_key=None, + user = localAuth.get_current_user( + security_scopes=security_Scopes, db=db, token=access_token, api_key=None, old_api_key=None ) assert user.username == "argilla" @pytest.mark.asyncio async def test_get_user_via_api_key(db: Session, argilla_user): - user = await localAuth.get_current_user(security_scopes=security_Scopes, db=db, api_key=DEFAULT_API_KEY, token=None) + user = localAuth.get_current_user(security_scopes=security_Scopes, db=db, api_key=DEFAULT_API_KEY, token=None) assert user.username == "argilla" From dd1843ce4c4e9c3d6cd9ba29efb3776e742c5ff2 Mon Sep 17 00:00:00 2001 From: frascuchon Date: Fri, 10 Mar 2023 17:15:47 +0100 Subject: [PATCH 10/13] Rename UserModel to User --- src/argilla/server/services/datasets.py | 20 +++++++++----------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/src/argilla/server/services/datasets.py b/src/argilla/server/services/datasets.py index 5fce1283a9..e863fd45ce 100644 --- a/src/argilla/server/services/datasets.py +++ b/src/argilla/server/services/datasets.py @@ -29,11 +29,9 @@ ForbiddenOperationError, WrongTaskError, ) -from argilla.server.models import User as UserModel -from argilla.server.models import Workspace +from argilla.server.models import User, Workspace from argilla.server.policies import DatasetPolicy, is_authorized from argilla.server.schemas.datasets import CreateDatasetRequest, Dataset -from argilla.server.security.model import User class ServiceBaseDataset(BaseDatasetDB): @@ -61,7 +59,7 @@ def __init__(self, db: Session, dao: DatasetsDAO): self._db = db self.__dao__ = dao - def create_dataset(self, user: UserModel, dataset: CreateDatasetRequest) -> BaseDatasetDB: + def create_dataset(self, user: User, dataset: CreateDatasetRequest) -> BaseDatasetDB: if not accounts.get_workspace_by_name(self._db, workspace_name=dataset.workspace): raise EntityNotFoundError(name=dataset.workspace, type=Workspace) @@ -88,7 +86,7 @@ def create_dataset(self, user: UserModel, dataset: CreateDatasetRequest) -> Base def find_by_name( self, - user: UserModel, + user: User, name: str, workspace: str, as_dataset_class: Type[ServiceDataset] = ServiceBaseDataset, @@ -107,7 +105,7 @@ def find_by_name( return cast(ServiceDataset, found_dataset) - def delete(self, user: UserModel, dataset: ServiceDataset): + def delete(self, user: User, dataset: ServiceDataset): if not is_authorized(user, DatasetPolicy.delete(dataset)): raise ForbiddenOperationError( "You don't have the necessary permissions to delete this dataset. " @@ -117,7 +115,7 @@ def delete(self, user: UserModel, dataset: ServiceDataset): def update( self, - user: UserModel, + user: User, dataset: ServiceDataset, tags: Dict[str, str], metadata: Dict[str, Any], @@ -138,7 +136,7 @@ def update( def list( self, - user: UserModel, + user: User, workspaces: Optional[List[str]], task2dataset_map: Dict[str, Type[ServiceDataset]] = None, ) -> List[ServiceDataset]: @@ -159,7 +157,7 @@ def list( return self.__dao__.list_datasets(workspaces=workspace_names, task2dataset_map=task2dataset_map) - def close(self, user: UserModel, dataset: ServiceDataset): + def close(self, user: User, dataset: ServiceDataset): if not is_authorized(user, DatasetPolicy.close(dataset)): raise ForbiddenOperationError( "You don't have the necessary permissions to close this dataset. " @@ -167,7 +165,7 @@ def close(self, user: UserModel, dataset: ServiceDataset): ) self.__dao__.close(dataset) - def open(self, user: UserModel, dataset: ServiceDataset): + def open(self, user: User, dataset: ServiceDataset): if not is_authorized(user, DatasetPolicy.open(dataset)): raise ForbiddenOperationError( "You don't have the necessary permissions to open this dataset. " @@ -177,7 +175,7 @@ def open(self, user: UserModel, dataset: ServiceDataset): def copy_dataset( self, - user: UserModel, + user: User, dataset: ServiceDataset, copy_name: str, copy_workspace: Optional[str] = None, From 248aa096b112e8d7bd6bdd6ec14557d0f32c3dc9 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 10 Mar 2023 16:22:21 +0000 Subject: [PATCH 11/13] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/argilla/server/apis/v0/handlers/text2text.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/argilla/server/apis/v0/handlers/text2text.py b/src/argilla/server/apis/v0/handlers/text2text.py index 29b90234be..8354a024f4 100644 --- a/src/argilla/server/apis/v0/handlers/text2text.py +++ b/src/argilla/server/apis/v0/handlers/text2text.py @@ -143,7 +143,6 @@ def search_records( aggregations=Text2TextSearchAggregations.parse_obj(result.metrics) if result.metrics else None, ) - metrics.configure_router( router, cfg=TasksFactory.get_task_by_task_type(task_type), From 707078f05cac5229ed6b1ad1a9562bf46905d079 Mon Sep 17 00:00:00 2001 From: Francisco Aranda Date: Fri, 10 Mar 2023 18:00:22 +0100 Subject: [PATCH 12/13] Apply suggestions from code review MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: José Francisco Calvo --- src/argilla/server/apis/v0/handlers/datasets.py | 4 ++-- .../v0/handlers/text_classification_dataset_settings.py | 4 ++-- .../v0/handlers/token_classification_dataset_settings.py | 6 +++--- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/argilla/server/apis/v0/handlers/datasets.py b/src/argilla/server/apis/v0/handlers/datasets.py index 9f5be790ba..cced02c4da 100644 --- a/src/argilla/server/apis/v0/handlers/datasets.py +++ b/src/argilla/server/apis/v0/handlers/datasets.py @@ -67,10 +67,10 @@ async def create_dataset( request: CreateDatasetRequest = Body(..., description="The request dataset info"), ws_params: CommonTaskHandlerDependencies = Depends(), datasets: DatasetsService = Depends(DatasetsService.get_instance), - user: User = Security(auth.get_current_user), + current_user: User = Security(auth.get_current_user), ) -> Dataset: request.workspace = request.workspace or ws_params.workspace - dataset = datasets.create_dataset(user=user, dataset=request) + dataset = datasets.create_dataset(user=current_user, dataset=request) return Dataset.from_orm(dataset) diff --git a/src/argilla/server/apis/v0/handlers/text_classification_dataset_settings.py b/src/argilla/server/apis/v0/handlers/text_classification_dataset_settings.py index 22ee497650..b2245792ec 100644 --- a/src/argilla/server/apis/v0/handlers/text_classification_dataset_settings.py +++ b/src/argilla/server/apis/v0/handlers/text_classification_dataset_settings.py @@ -54,7 +54,7 @@ async def get_dataset_settings( name: str = DATASET_NAME_PATH_PARAM, ws_params: CommonTaskHandlerDependencies = Depends(), datasets: DatasetsService = Depends(DatasetsService.get_instance), - user: User = Security(auth.get_current_user), + current_user: User = Security(auth.get_current_user), ) -> TextClassificationSettings: found_ds = datasets.find_by_name( user=user, @@ -82,7 +82,7 @@ async def save_settings( ws_params: CommonTaskHandlerDependencies = Depends(), datasets: DatasetsService = Depends(DatasetsService.get_instance), validator: DatasetValidator = Depends(DatasetValidator.get_instance), - user: User = Security(auth.get_current_user), + current_user: User = Security(auth.get_current_user), ) -> TextClassificationSettings: found_ds = datasets.find_by_name( user=user, diff --git a/src/argilla/server/apis/v0/handlers/token_classification_dataset_settings.py b/src/argilla/server/apis/v0/handlers/token_classification_dataset_settings.py index 96c9a632ef..1bd3fb870d 100644 --- a/src/argilla/server/apis/v0/handlers/token_classification_dataset_settings.py +++ b/src/argilla/server/apis/v0/handlers/token_classification_dataset_settings.py @@ -54,7 +54,7 @@ async def get_dataset_settings( name: str = DATASET_NAME_PATH_PARAM, ws_params: CommonTaskHandlerDependencies = Depends(), datasets: DatasetsService = Depends(DatasetsService.get_instance), - user: User = Security(auth.get_current_user), + current_user: User = Security(auth.get_current_user), ) -> TokenClassificationSettings: found_ds = datasets.find_by_name( user=user, @@ -82,7 +82,7 @@ async def save_settings( ws_params: CommonTaskHandlerDependencies = Depends(), datasets: DatasetsService = Depends(DatasetsService.get_instance), validator: DatasetValidator = Depends(DatasetValidator.get_instance), - user: User = Security(auth.get_current_user), + current_user: User = Security(auth.get_current_user), ) -> TokenClassificationSettings: found_ds = datasets.find_by_name( user=user, @@ -110,7 +110,7 @@ async def delete_settings( name: str = DATASET_NAME_PATH_PARAM, ws_params: CommonTaskHandlerDependencies = Depends(), datasets: DatasetsService = Depends(DatasetsService.get_instance), - user: User = Security(auth.get_current_user), + current_user: User = Security(auth.get_current_user), ) -> None: found_ds = datasets.find_by_name( user=user, From 165852c1de511b0c08d4777d722fa8a60eabbdc7 Mon Sep 17 00:00:00 2001 From: frascuchon Date: Fri, 10 Mar 2023 21:28:57 +0100 Subject: [PATCH 13/13] Resolve missing variable refs --- .../text_classification_dataset_settings.py | 10 +++++----- .../token_classification_dataset_settings.py | 14 +++++++------- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/src/argilla/server/apis/v0/handlers/text_classification_dataset_settings.py b/src/argilla/server/apis/v0/handlers/text_classification_dataset_settings.py index b2245792ec..b3a78f2b23 100644 --- a/src/argilla/server/apis/v0/handlers/text_classification_dataset_settings.py +++ b/src/argilla/server/apis/v0/handlers/text_classification_dataset_settings.py @@ -57,13 +57,13 @@ async def get_dataset_settings( current_user: User = Security(auth.get_current_user), ) -> TextClassificationSettings: found_ds = datasets.find_by_name( - user=user, + user=current_user, name=name, workspace=ws_params.workspace, task=task, ) - settings = await datasets.get_settings(user=user, dataset=found_ds, class_type=__svc_settings_class__) + settings = await datasets.get_settings(user=current_user, dataset=found_ds, class_type=__svc_settings_class__) return TextClassificationSettings.parse_obj(settings) @deprecate_endpoint( @@ -85,14 +85,14 @@ async def save_settings( current_user: User = Security(auth.get_current_user), ) -> TextClassificationSettings: found_ds = datasets.find_by_name( - user=user, + user=current_user, name=name, task=task, workspace=ws_params.workspace, ) - await validator.validate_dataset_settings(user=user, dataset=found_ds, settings=request) + await validator.validate_dataset_settings(user=current_user, dataset=found_ds, settings=request) settings = await datasets.save_settings( - user=user, + user=current_user, dataset=found_ds, settings=__svc_settings_class__.parse_obj(request.dict()), ) diff --git a/src/argilla/server/apis/v0/handlers/token_classification_dataset_settings.py b/src/argilla/server/apis/v0/handlers/token_classification_dataset_settings.py index 1bd3fb870d..cc9ed8500d 100644 --- a/src/argilla/server/apis/v0/handlers/token_classification_dataset_settings.py +++ b/src/argilla/server/apis/v0/handlers/token_classification_dataset_settings.py @@ -57,13 +57,13 @@ async def get_dataset_settings( current_user: User = Security(auth.get_current_user), ) -> TokenClassificationSettings: found_ds = datasets.find_by_name( - user=user, + user=current_user, name=name, workspace=ws_params.workspace, task=task, ) - settings = await datasets.get_settings(user=user, dataset=found_ds, class_type=__svc_settings_class__) + settings = await datasets.get_settings(user=current_user, dataset=found_ds, class_type=__svc_settings_class__) return TokenClassificationSettings.parse_obj(settings) @deprecate_endpoint( @@ -85,14 +85,14 @@ async def save_settings( current_user: User = Security(auth.get_current_user), ) -> TokenClassificationSettings: found_ds = datasets.find_by_name( - user=user, + user=current_user, name=name, task=task, workspace=ws_params.workspace, ) - await validator.validate_dataset_settings(user=user, dataset=found_ds, settings=request) + await validator.validate_dataset_settings(user=current_user, dataset=found_ds, settings=request) settings = await datasets.save_settings( - user=user, + user=current_user, dataset=found_ds, settings=__svc_settings_class__.parse_obj(request.dict()), ) @@ -113,13 +113,13 @@ async def delete_settings( current_user: User = Security(auth.get_current_user), ) -> None: found_ds = datasets.find_by_name( - user=user, + user=current_user, name=name, task=task, workspace=ws_params.workspace, ) await datasets.delete_settings( - user=user, + user=current_user, dataset=found_ds, )