Skip to content

Commit

Permalink
Integrate datasets crud endpoints (#2510)
Browse files Browse the repository at this point in the history
This PR integrates new DB logic and policies with the current
implementation.

- The main changes are applied to the dataset service class
- The `auth.get_current_user` returns a `models.User` instance
- The ONLY change in endpoints is using this new `auth.get_current_user`
method

In a separate PR the pydantic User class will be cleaned and normalized

---------

Co-authored-by: José Francisco Calvo <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: José Francisco Calvo <[email protected]>
  • Loading branch information
4 people authored Mar 13, 2023
1 parent 302dada commit 83a11e9
Show file tree
Hide file tree
Showing 19 changed files with 220 additions and 193 deletions.
59 changes: 24 additions & 35 deletions src/argilla/server/apis/v0/handlers/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,23 +17,18 @@

from fastapi import APIRouter, Body, Depends, Security
from pydantic import parse_obj_as
from sqlalchemy.orm import Session

from argilla.server import database
from argilla.server.apis.v0.helpers import deprecate_endpoint
from argilla.server.apis.v0.models.commons.params import CommonTaskHandlerDependencies
from argilla.server.contexts import accounts
from argilla.server.daos.datasets import DatasetsDAO
from argilla.server.errors import EntityNotFoundError, ForbiddenOperationError
from argilla.server.policies import DatasetPolicy, is_authorized
from argilla.server.errors import EntityNotFoundError
from argilla.server.models import User
from argilla.server.schemas.datasets import (
CopyDatasetRequest,
CreateDatasetRequest,
Dataset,
UpdateDatasetRequest,
)
from argilla.server.security import auth
from argilla.server.security.model import User, Workspace
from argilla.server.services.datasets import DatasetsService

router = APIRouter(tags=["datasets"], prefix="/datasets")
Expand All @@ -50,7 +45,7 @@
async def list_datasets(
request_deps: CommonTaskHandlerDependencies = Depends(),
service: DatasetsService = Depends(DatasetsService.get_instance),
current_user: User = Security(auth.get_user, scopes=[]),
current_user: User = Security(auth.get_current_user),
) -> List[Dataset]:
datasets = service.list(
user=current_user,
Expand All @@ -72,10 +67,10 @@ async def create_dataset(
request: CreateDatasetRequest = Body(..., description="The request dataset info"),
ws_params: CommonTaskHandlerDependencies = Depends(),
datasets: DatasetsService = Depends(DatasetsService.get_instance),
user: User = Security(auth.get_user, scopes=["create:datasets"]),
current_user: User = Security(auth.get_current_user),
) -> Dataset:
request.workspace = request.workspace or ws_params.workspace
dataset = datasets.create_dataset(user=user, dataset=request)
dataset = datasets.create_dataset(user=current_user, dataset=request)

return Dataset.from_orm(dataset)

Expand All @@ -90,7 +85,7 @@ def get_dataset(
name: str,
ds_params: CommonTaskHandlerDependencies = Depends(),
service: DatasetsService = Depends(DatasetsService.get_instance),
current_user: User = Security(auth.get_user, scopes=[]),
current_user: User = Security(auth.get_current_user),
) -> Dataset:
return Dataset.from_orm(
service.find_by_name(
Expand All @@ -112,7 +107,7 @@ def update_dataset(
request: UpdateDatasetRequest,
ds_params: CommonTaskHandlerDependencies = Depends(),
service: DatasetsService = Depends(DatasetsService.get_instance),
current_user: User = Security(auth.get_user, scopes=[]),
current_user: User = Security(auth.get_current_user),
) -> Dataset:
found_ds = service.find_by_name(user=current_user, name=name, workspace=ds_params.workspace)

Expand All @@ -132,28 +127,22 @@ def update_dataset(
)
def delete_dataset(
name: str,
request_params: CommonTaskHandlerDependencies = Depends(),
db: Session = Depends(database.get_db),
datasets: DatasetsDAO = Depends(DatasetsDAO.get_instance),
user: User = Security(auth.get_current_user, scopes=[]),
ds_params: CommonTaskHandlerDependencies = Depends(),
service: DatasetsService = Depends(DatasetsService.get_instance),
current_user: User = Security(auth.get_current_user),
):
workspace_name = request_params.workspace

workspace = accounts.get_workspace_by_name(db, workspace_name=workspace_name)
if not workspace:
raise EntityNotFoundError(name=workspace_name, type=Workspace)

dataset = datasets.find_by_name_and_workspace(name=name, workspace=workspace.name)
if not dataset:
return

if not is_authorized(user, DatasetPolicy.delete(dataset)):
raise ForbiddenOperationError(
"You don't have the necessary permissions to delete this dataset. "
"Only dataset creators or administrators can delete datasets"
try:
found_ds = service.find_by_name(
user=current_user,
name=name,
workspace=ds_params.workspace,
)

datasets.delete_dataset(dataset)
service.delete(
user=current_user,
dataset=found_ds,
)
except EntityNotFoundError:
pass


@router.put(
Expand All @@ -164,7 +153,7 @@ def close_dataset(
name: str,
ds_params: CommonTaskHandlerDependencies = Depends(),
service: DatasetsService = Depends(DatasetsService.get_instance),
current_user: User = Security(auth.get_user, scopes=[]),
current_user: User = Security(auth.get_current_user),
):
found_ds = service.find_by_name(user=current_user, name=name, workspace=ds_params.workspace)
service.close(user=current_user, dataset=found_ds)
Expand All @@ -178,7 +167,7 @@ def open_dataset(
name: str,
ds_params: CommonTaskHandlerDependencies = Depends(),
service: DatasetsService = Depends(DatasetsService.get_instance),
current_user: User = Security(auth.get_user, scopes=[]),
current_user: User = Security(auth.get_current_user),
):
found_ds = service.find_by_name(user=current_user, name=name, workspace=ds_params.workspace)
service.open(user=current_user, dataset=found_ds)
Expand All @@ -195,7 +184,7 @@ def copy_dataset(
copy_request: CopyDatasetRequest,
ds_params: CommonTaskHandlerDependencies = Depends(),
service: DatasetsService = Depends(DatasetsService.get_instance),
current_user: User = Security(auth.get_user, scopes=[]),
current_user: User = Security(auth.get_current_user),
) -> Dataset:
found = service.find_by_name(user=current_user, name=name, workspace=ds_params.workspace)
dataset = service.copy_dataset(
Expand Down
4 changes: 2 additions & 2 deletions src/argilla/server/apis/v0/handlers/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def configure_router(router: APIRouter, cfg: TaskConfig):
def get_dataset_metrics(
name: str,
request_deps: CommonTaskHandlerDependencies = Depends(),
current_user: User = Security(auth.get_user, scopes=[]),
current_user: User = Security(auth.get_current_user),
datasets: DatasetsService = Depends(DatasetsService.get_instance),
) -> List[MetricInfo]:
dataset = datasets.find_by_name(
Expand Down Expand Up @@ -101,7 +101,7 @@ def metric_summary(
query: cfg.query,
metric_params: MetricSummaryParams = Depends(),
request_deps: CommonTaskHandlerDependencies = Depends(),
current_user: User = Security(auth.get_user, scopes=[]),
current_user: User = Security(auth.get_current_user, scopes=[]),
datasets: DatasetsService = Depends(DatasetsService.get_instance),
metrics: MetricsService = Depends(MetricsService.get_instance),
):
Expand Down
4 changes: 2 additions & 2 deletions src/argilla/server/apis/v0/handlers/records.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ async def get_dataset_record(
request_deps: CommonTaskHandlerDependencies = Depends(),
service: DatasetsService = Depends(DatasetsService.get_instance),
search: SearchRecordsService = Depends(SearchRecordsService.get_instance),
current_user: User = Security(auth.get_user, scopes=[]),
current_user: User = Security(auth.get_current_user),
) -> RecordType:
found = service.find_by_name(
user=current_user,
Expand Down Expand Up @@ -85,7 +85,7 @@ async def delete_dataset_records(
request_deps: CommonTaskHandlerDependencies = Depends(),
service: DatasetsService = Depends(DatasetsService.get_instance),
storage: RecordsStorageService = Depends(RecordsStorageService.get_instance),
current_user: User = Security(auth.get_user, scopes=[]),
current_user: User = Security(auth.get_current_user),
):
found = service.find_by_name(
user=current_user,
Expand Down
2 changes: 1 addition & 1 deletion src/argilla/server/apis/v0/handlers/records_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ async def scan_dataset_records(
request_deps: CommonTaskHandlerDependencies = Depends(),
service: DatasetsService = Depends(DatasetsService.get_instance),
engine: GenericElasticEngineBackend = Depends(GenericElasticEngineBackend.get_instance),
current_user: User = Security(auth.get_user, scopes=[]),
current_user: User = Security(auth.get_current_user),
):
found = service.find_by_name(user=current_user, name=name, workspace=request_deps.workspace)

Expand Down
2 changes: 1 addition & 1 deletion src/argilla/server/apis/v0/handlers/records_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ async def partial_update_dataset_record(
service: DatasetsService = Depends(DatasetsService.get_instance),
search: SearchRecordsService = Depends(SearchRecordsService.get_instance),
storage: RecordsStorageService = Depends(RecordsStorageService.get_instance),
current_user: User = Security(auth.get_user, scopes=[]),
current_user: User = Security(auth.get_current_user, scopes=[]),
) -> RecordType:
dataset = service.find_by_name(
user=current_user,
Expand Down
6 changes: 3 additions & 3 deletions src/argilla/server/apis/v0/handlers/text2text.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,10 +73,10 @@ async def bulk_records(
common_params: CommonTaskHandlerDependencies = Depends(),
service: Text2TextService = Depends(Text2TextService.get_instance),
datasets: DatasetsService = Depends(DatasetsService.get_instance),
current_user: User = Security(auth.get_user, scopes=[]),
current_user: User = Security(auth.get_current_user),
) -> BulkResponse:
task = task_type
workspace = current_user.check_workspace(common_params.workspace)
workspace = common_params.workspace
try:
dataset = datasets.find_by_name(
current_user,
Expand Down Expand Up @@ -118,7 +118,7 @@ def search_records(
pagination: RequestPagination = Depends(),
service: Text2TextService = Depends(Text2TextService.get_instance),
datasets: DatasetsService = Depends(DatasetsService.get_instance),
current_user: User = Security(auth.get_user, scopes=[]),
current_user: User = Security(auth.get_current_user),
) -> Text2TextSearchResults:
search = search or Text2TextSearchRequest()
query = search.query or Text2TextQuery()
Expand Down
20 changes: 10 additions & 10 deletions src/argilla/server/apis/v0/handlers/text_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,10 +91,10 @@ async def bulk_records(
service: TextClassificationService = Depends(TextClassificationService.get_instance),
datasets: DatasetsService = Depends(DatasetsService.get_instance),
validator: DatasetValidator = Depends(DatasetValidator.get_instance),
current_user: User = Security(auth.get_user, scopes=[]),
current_user: User = Security(auth.get_current_user),
) -> BulkResponse:
task = task_type
workspace = current_user.check_workspace(common_params.workspace)
workspace = common_params.workspace
try:
dataset = datasets.find_by_name(
current_user,
Expand Down Expand Up @@ -140,7 +140,7 @@ def search_records(
pagination: RequestPagination = Depends(),
service: TextClassificationService = Depends(TextClassificationService.get_instance),
datasets: DatasetsService = Depends(DatasetsService.get_instance),
current_user: User = Security(auth.get_user, scopes=[]),
current_user: User = Security(auth.get_current_user),
) -> TextClassificationSearchResults:
"""
Searches data from dataset
Expand Down Expand Up @@ -208,7 +208,7 @@ async def list_labeling_rules(
common_params: CommonTaskHandlerDependencies = Depends(),
datasets: DatasetsService = Depends(DatasetsService.get_instance),
service: TextClassificationService = Depends(TextClassificationService.get_instance),
current_user: User = Security(auth.get_user, scopes=[]),
current_user: User = Security(auth.get_current_user),
) -> List[LabelingRule]:
dataset = datasets.find_by_name(
user=current_user,
Expand All @@ -235,7 +235,7 @@ async def create_rule(
common_params: CommonTaskHandlerDependencies = Depends(),
service: TextClassificationService = Depends(TextClassificationService.get_instance),
datasets: DatasetsService = Depends(DatasetsService.get_instance),
current_user: User = Security(auth.get_user, scopes=[]),
current_user: User = Security(auth.get_current_user),
) -> LabelingRule:
dataset = datasets.find_by_name(
user=current_user,
Expand Down Expand Up @@ -271,7 +271,7 @@ async def compute_rule_metrics(
common_params: CommonTaskHandlerDependencies = Depends(),
service: TextClassificationService = Depends(TextClassificationService.get_instance),
datasets: DatasetsService = Depends(DatasetsService.get_instance),
current_user: User = Security(auth.get_user, scopes=[]),
current_user: User = Security(auth.get_current_user),
) -> LabelingRuleMetricsSummary:
dataset = datasets.find_by_name(
user=current_user,
Expand All @@ -297,7 +297,7 @@ async def compute_dataset_rules_metrics(
common_params: CommonTaskHandlerDependencies = Depends(),
service: TextClassificationService = Depends(TextClassificationService.get_instance),
datasets: DatasetsService = Depends(DatasetsService.get_instance),
current_user: User = Security(auth.get_user, scopes=[]),
current_user: User = Security(auth.get_current_user),
) -> DatasetLabelingRulesMetricsSummary:
dataset = datasets.find_by_name(
user=current_user,
Expand All @@ -322,7 +322,7 @@ async def delete_labeling_rule(
common_params: CommonTaskHandlerDependencies = Depends(),
service: TextClassificationService = Depends(TextClassificationService.get_instance),
datasets: DatasetsService = Depends(DatasetsService.get_instance),
current_user: User = Security(auth.get_user, scopes=[]),
current_user: User = Security(auth.get_current_user),
) -> None:
dataset = datasets.find_by_name(
user=current_user,
Expand All @@ -349,7 +349,7 @@ async def get_rule(
common_params: CommonTaskHandlerDependencies = Depends(),
service: TextClassificationService = Depends(TextClassificationService.get_instance),
datasets: DatasetsService = Depends(DatasetsService.get_instance),
current_user: User = Security(auth.get_user, scopes=[]),
current_user: User = Security(auth.get_current_user),
) -> LabelingRule:
dataset = datasets.find_by_name(
user=current_user,
Expand Down Expand Up @@ -377,7 +377,7 @@ async def update_rule(
common_params: CommonTaskHandlerDependencies = Depends(),
service: TextClassificationService = Depends(TextClassificationService.get_instance),
datasets: DatasetsService = Depends(DatasetsService.get_instance),
current_user: User = Security(auth.get_user, scopes=[]),
current_user: User = Security(auth.get_current_user),
) -> LabelingRule:
dataset = datasets.find_by_name(
user=current_user,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,16 +54,16 @@ async def get_dataset_settings(
name: str = DATASET_NAME_PATH_PARAM,
ws_params: CommonTaskHandlerDependencies = Depends(),
datasets: DatasetsService = Depends(DatasetsService.get_instance),
user: User = Security(auth.get_user, scopes=["read:dataset.settings"]),
current_user: User = Security(auth.get_current_user),
) -> TextClassificationSettings:
found_ds = datasets.find_by_name(
user=user,
user=current_user,
name=name,
workspace=ws_params.workspace,
task=task,
)

settings = await datasets.get_settings(user=user, dataset=found_ds, class_type=__svc_settings_class__)
settings = await datasets.get_settings(user=current_user, dataset=found_ds, class_type=__svc_settings_class__)
return TextClassificationSettings.parse_obj(settings)

@deprecate_endpoint(
Expand All @@ -82,17 +82,17 @@ async def save_settings(
ws_params: CommonTaskHandlerDependencies = Depends(),
datasets: DatasetsService = Depends(DatasetsService.get_instance),
validator: DatasetValidator = Depends(DatasetValidator.get_instance),
user: User = Security(auth.get_user, scopes=["write:dataset.settings"]),
current_user: User = Security(auth.get_current_user),
) -> TextClassificationSettings:
found_ds = datasets.find_by_name(
user=user,
user=current_user,
name=name,
task=task,
workspace=ws_params.workspace,
)
await validator.validate_dataset_settings(user=user, dataset=found_ds, settings=request)
await validator.validate_dataset_settings(user=current_user, dataset=found_ds, settings=request)
settings = await datasets.save_settings(
user=user,
user=current_user,
dataset=found_ds,
settings=__svc_settings_class__.parse_obj(request.dict()),
)
Expand All @@ -110,7 +110,7 @@ async def delete_settings(
name: str = DATASET_NAME_PATH_PARAM,
ws_params: CommonTaskHandlerDependencies = Depends(),
datasets: DatasetsService = Depends(DatasetsService.get_instance),
user: User = Security(auth.get_user, scopes=["delete:dataset.settings"]),
user: User = Security(auth.get_current_user),
) -> None:
found_ds = datasets.find_by_name(
user=user,
Expand Down
6 changes: 3 additions & 3 deletions src/argilla/server/apis/v0/handlers/token_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,10 @@ async def bulk_records(
service: TokenClassificationService = Depends(TokenClassificationService.get_instance),
datasets: DatasetsService = Depends(DatasetsService.get_instance),
validator: DatasetValidator = Depends(DatasetValidator.get_instance),
current_user: User = Security(auth.get_user, scopes=[]),
current_user: User = Security(auth.get_current_user, scopes=[]),
) -> BulkResponse:
task = task_type
workspace = current_user.check_workspace(common_params.workspace)
workspace = common_params.workspace
try:
dataset = datasets.find_by_name(
current_user,
Expand Down Expand Up @@ -139,7 +139,7 @@ def search_records(
pagination: RequestPagination = Depends(),
service: TokenClassificationService = Depends(TokenClassificationService.get_instance),
datasets: DatasetsService = Depends(DatasetsService.get_instance),
current_user: User = Security(auth.get_user, scopes=[]),
current_user: User = Security(auth.get_current_user, scopes=[]),
) -> TokenClassificationSearchResults:
search = search or TokenClassificationSearchRequest()
query = search.query or TokenClassificationQuery()
Expand Down
Loading

0 comments on commit 83a11e9

Please sign in to comment.