Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Integrate datasets crud endpoints #2510

Merged
Merged
59 changes: 24 additions & 35 deletions src/argilla/server/apis/v0/handlers/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,23 +17,18 @@

from fastapi import APIRouter, Body, Depends, Security
from pydantic import parse_obj_as
from sqlalchemy.orm import Session

from argilla.server import database
from argilla.server.apis.v0.helpers import deprecate_endpoint
from argilla.server.apis.v0.models.commons.params import CommonTaskHandlerDependencies
from argilla.server.contexts import accounts
from argilla.server.daos.datasets import DatasetsDAO
from argilla.server.errors import EntityNotFoundError, ForbiddenOperationError
from argilla.server.policies import DatasetPolicy, is_authorized
from argilla.server.errors import EntityNotFoundError
from argilla.server.models import User
from argilla.server.schemas.datasets import (
CopyDatasetRequest,
CreateDatasetRequest,
Dataset,
UpdateDatasetRequest,
)
from argilla.server.security import auth
from argilla.server.security.model import User, Workspace
from argilla.server.services.datasets import DatasetsService

router = APIRouter(tags=["datasets"], prefix="/datasets")
Expand All @@ -50,7 +45,7 @@
async def list_datasets(
request_deps: CommonTaskHandlerDependencies = Depends(),
service: DatasetsService = Depends(DatasetsService.get_instance),
current_user: User = Security(auth.get_user, scopes=[]),
current_user: User = Security(auth.get_current_user),
) -> List[Dataset]:
datasets = service.list(
user=current_user,
Expand All @@ -72,10 +67,10 @@ async def create_dataset(
request: CreateDatasetRequest = Body(..., description="The request dataset info"),
ws_params: CommonTaskHandlerDependencies = Depends(),
datasets: DatasetsService = Depends(DatasetsService.get_instance),
user: User = Security(auth.get_user, scopes=["create:datasets"]),
current_user: User = Security(auth.get_current_user),
) -> Dataset:
request.workspace = request.workspace or ws_params.workspace
dataset = datasets.create_dataset(user=user, dataset=request)
dataset = datasets.create_dataset(user=current_user, dataset=request)

return Dataset.from_orm(dataset)

Expand All @@ -90,7 +85,7 @@ def get_dataset(
name: str,
ds_params: CommonTaskHandlerDependencies = Depends(),
service: DatasetsService = Depends(DatasetsService.get_instance),
current_user: User = Security(auth.get_user, scopes=[]),
current_user: User = Security(auth.get_current_user),
) -> Dataset:
return Dataset.from_orm(
service.find_by_name(
Expand All @@ -112,7 +107,7 @@ def update_dataset(
request: UpdateDatasetRequest,
ds_params: CommonTaskHandlerDependencies = Depends(),
service: DatasetsService = Depends(DatasetsService.get_instance),
current_user: User = Security(auth.get_user, scopes=[]),
current_user: User = Security(auth.get_current_user),
) -> Dataset:
found_ds = service.find_by_name(user=current_user, name=name, workspace=ds_params.workspace)

Expand All @@ -132,28 +127,22 @@ def update_dataset(
)
def delete_dataset(
name: str,
request_params: CommonTaskHandlerDependencies = Depends(),
db: Session = Depends(database.get_db),
datasets: DatasetsDAO = Depends(DatasetsDAO.get_instance),
user: User = Security(auth.get_current_user, scopes=[]),
ds_params: CommonTaskHandlerDependencies = Depends(),
service: DatasetsService = Depends(DatasetsService.get_instance),
current_user: User = Security(auth.get_current_user),
):
workspace_name = request_params.workspace

workspace = accounts.get_workspace_by_name(db, workspace_name=workspace_name)
if not workspace:
raise EntityNotFoundError(name=workspace_name, type=Workspace)

dataset = datasets.find_by_name_and_workspace(name=name, workspace=workspace.name)
if not dataset:
return

if not is_authorized(user, DatasetPolicy.delete(dataset)):
raise ForbiddenOperationError(
"You don't have the necessary permissions to delete this dataset. "
"Only dataset creators or administrators can delete datasets"
try:
found_ds = service.find_by_name(
user=current_user,
name=name,
workspace=ds_params.workspace,
)

datasets.delete_dataset(dataset)
service.delete(
user=current_user,
dataset=found_ds,
)
except EntityNotFoundError:
pass


@router.put(
Expand All @@ -164,7 +153,7 @@ def close_dataset(
name: str,
ds_params: CommonTaskHandlerDependencies = Depends(),
service: DatasetsService = Depends(DatasetsService.get_instance),
current_user: User = Security(auth.get_user, scopes=[]),
current_user: User = Security(auth.get_current_user),
):
found_ds = service.find_by_name(user=current_user, name=name, workspace=ds_params.workspace)
service.close(user=current_user, dataset=found_ds)
Expand All @@ -178,7 +167,7 @@ def open_dataset(
name: str,
ds_params: CommonTaskHandlerDependencies = Depends(),
service: DatasetsService = Depends(DatasetsService.get_instance),
current_user: User = Security(auth.get_user, scopes=[]),
current_user: User = Security(auth.get_current_user),
):
found_ds = service.find_by_name(user=current_user, name=name, workspace=ds_params.workspace)
service.open(user=current_user, dataset=found_ds)
Expand All @@ -195,7 +184,7 @@ def copy_dataset(
copy_request: CopyDatasetRequest,
ds_params: CommonTaskHandlerDependencies = Depends(),
service: DatasetsService = Depends(DatasetsService.get_instance),
current_user: User = Security(auth.get_user, scopes=[]),
current_user: User = Security(auth.get_current_user),
) -> Dataset:
found = service.find_by_name(user=current_user, name=name, workspace=ds_params.workspace)
dataset = service.copy_dataset(
Expand Down
4 changes: 2 additions & 2 deletions src/argilla/server/apis/v0/handlers/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def configure_router(router: APIRouter, cfg: TaskConfig):
def get_dataset_metrics(
name: str,
request_deps: CommonTaskHandlerDependencies = Depends(),
current_user: User = Security(auth.get_user, scopes=[]),
current_user: User = Security(auth.get_current_user),
datasets: DatasetsService = Depends(DatasetsService.get_instance),
) -> List[MetricInfo]:
dataset = datasets.find_by_name(
Expand Down Expand Up @@ -101,7 +101,7 @@ def metric_summary(
query: cfg.query,
metric_params: MetricSummaryParams = Depends(),
request_deps: CommonTaskHandlerDependencies = Depends(),
current_user: User = Security(auth.get_user, scopes=[]),
current_user: User = Security(auth.get_current_user, scopes=[]),
datasets: DatasetsService = Depends(DatasetsService.get_instance),
metrics: MetricsService = Depends(MetricsService.get_instance),
):
Expand Down
4 changes: 2 additions & 2 deletions src/argilla/server/apis/v0/handlers/records.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ async def get_dataset_record(
request_deps: CommonTaskHandlerDependencies = Depends(),
service: DatasetsService = Depends(DatasetsService.get_instance),
search: SearchRecordsService = Depends(SearchRecordsService.get_instance),
current_user: User = Security(auth.get_user, scopes=[]),
current_user: User = Security(auth.get_current_user),
) -> RecordType:
found = service.find_by_name(
user=current_user,
Expand Down Expand Up @@ -85,7 +85,7 @@ async def delete_dataset_records(
request_deps: CommonTaskHandlerDependencies = Depends(),
service: DatasetsService = Depends(DatasetsService.get_instance),
storage: RecordsStorageService = Depends(RecordsStorageService.get_instance),
current_user: User = Security(auth.get_user, scopes=[]),
current_user: User = Security(auth.get_current_user),
):
found = service.find_by_name(
user=current_user,
Expand Down
2 changes: 1 addition & 1 deletion src/argilla/server/apis/v0/handlers/records_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ async def scan_dataset_records(
request_deps: CommonTaskHandlerDependencies = Depends(),
service: DatasetsService = Depends(DatasetsService.get_instance),
engine: GenericElasticEngineBackend = Depends(GenericElasticEngineBackend.get_instance),
current_user: User = Security(auth.get_user, scopes=[]),
current_user: User = Security(auth.get_current_user),
):
found = service.find_by_name(user=current_user, name=name, workspace=request_deps.workspace)

Expand Down
2 changes: 1 addition & 1 deletion src/argilla/server/apis/v0/handlers/records_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ async def partial_update_dataset_record(
service: DatasetsService = Depends(DatasetsService.get_instance),
search: SearchRecordsService = Depends(SearchRecordsService.get_instance),
storage: RecordsStorageService = Depends(RecordsStorageService.get_instance),
current_user: User = Security(auth.get_user, scopes=[]),
current_user: User = Security(auth.get_current_user, scopes=[]),
) -> RecordType:
dataset = service.find_by_name(
user=current_user,
Expand Down
6 changes: 3 additions & 3 deletions src/argilla/server/apis/v0/handlers/text2text.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,10 +73,10 @@ async def bulk_records(
common_params: CommonTaskHandlerDependencies = Depends(),
service: Text2TextService = Depends(Text2TextService.get_instance),
datasets: DatasetsService = Depends(DatasetsService.get_instance),
current_user: User = Security(auth.get_user, scopes=[]),
current_user: User = Security(auth.get_current_user),
) -> BulkResponse:
task = task_type
workspace = current_user.check_workspace(common_params.workspace)
workspace = common_params.workspace
try:
dataset = datasets.find_by_name(
current_user,
Expand Down Expand Up @@ -118,7 +118,7 @@ def search_records(
pagination: RequestPagination = Depends(),
service: Text2TextService = Depends(Text2TextService.get_instance),
datasets: DatasetsService = Depends(DatasetsService.get_instance),
current_user: User = Security(auth.get_user, scopes=[]),
current_user: User = Security(auth.get_current_user),
) -> Text2TextSearchResults:
search = search or Text2TextSearchRequest()
query = search.query or Text2TextQuery()
Expand Down
20 changes: 10 additions & 10 deletions src/argilla/server/apis/v0/handlers/text_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,10 +91,10 @@ async def bulk_records(
service: TextClassificationService = Depends(TextClassificationService.get_instance),
datasets: DatasetsService = Depends(DatasetsService.get_instance),
validator: DatasetValidator = Depends(DatasetValidator.get_instance),
current_user: User = Security(auth.get_user, scopes=[]),
current_user: User = Security(auth.get_current_user),
) -> BulkResponse:
task = task_type
workspace = current_user.check_workspace(common_params.workspace)
workspace = common_params.workspace
try:
dataset = datasets.find_by_name(
current_user,
Expand Down Expand Up @@ -140,7 +140,7 @@ def search_records(
pagination: RequestPagination = Depends(),
service: TextClassificationService = Depends(TextClassificationService.get_instance),
datasets: DatasetsService = Depends(DatasetsService.get_instance),
current_user: User = Security(auth.get_user, scopes=[]),
current_user: User = Security(auth.get_current_user),
) -> TextClassificationSearchResults:
"""
Searches data from dataset
Expand Down Expand Up @@ -208,7 +208,7 @@ async def list_labeling_rules(
common_params: CommonTaskHandlerDependencies = Depends(),
datasets: DatasetsService = Depends(DatasetsService.get_instance),
service: TextClassificationService = Depends(TextClassificationService.get_instance),
current_user: User = Security(auth.get_user, scopes=[]),
current_user: User = Security(auth.get_current_user),
) -> List[LabelingRule]:
dataset = datasets.find_by_name(
user=current_user,
Expand All @@ -235,7 +235,7 @@ async def create_rule(
common_params: CommonTaskHandlerDependencies = Depends(),
service: TextClassificationService = Depends(TextClassificationService.get_instance),
datasets: DatasetsService = Depends(DatasetsService.get_instance),
current_user: User = Security(auth.get_user, scopes=[]),
current_user: User = Security(auth.get_current_user),
) -> LabelingRule:
dataset = datasets.find_by_name(
user=current_user,
Expand Down Expand Up @@ -271,7 +271,7 @@ async def compute_rule_metrics(
common_params: CommonTaskHandlerDependencies = Depends(),
service: TextClassificationService = Depends(TextClassificationService.get_instance),
datasets: DatasetsService = Depends(DatasetsService.get_instance),
current_user: User = Security(auth.get_user, scopes=[]),
current_user: User = Security(auth.get_current_user),
) -> LabelingRuleMetricsSummary:
dataset = datasets.find_by_name(
user=current_user,
Expand All @@ -297,7 +297,7 @@ async def compute_dataset_rules_metrics(
common_params: CommonTaskHandlerDependencies = Depends(),
service: TextClassificationService = Depends(TextClassificationService.get_instance),
datasets: DatasetsService = Depends(DatasetsService.get_instance),
current_user: User = Security(auth.get_user, scopes=[]),
current_user: User = Security(auth.get_current_user),
) -> DatasetLabelingRulesMetricsSummary:
dataset = datasets.find_by_name(
user=current_user,
Expand All @@ -322,7 +322,7 @@ async def delete_labeling_rule(
common_params: CommonTaskHandlerDependencies = Depends(),
service: TextClassificationService = Depends(TextClassificationService.get_instance),
datasets: DatasetsService = Depends(DatasetsService.get_instance),
current_user: User = Security(auth.get_user, scopes=[]),
current_user: User = Security(auth.get_current_user),
) -> None:
dataset = datasets.find_by_name(
user=current_user,
Expand All @@ -349,7 +349,7 @@ async def get_rule(
common_params: CommonTaskHandlerDependencies = Depends(),
service: TextClassificationService = Depends(TextClassificationService.get_instance),
datasets: DatasetsService = Depends(DatasetsService.get_instance),
current_user: User = Security(auth.get_user, scopes=[]),
current_user: User = Security(auth.get_current_user),
) -> LabelingRule:
dataset = datasets.find_by_name(
user=current_user,
Expand Down Expand Up @@ -377,7 +377,7 @@ async def update_rule(
common_params: CommonTaskHandlerDependencies = Depends(),
service: TextClassificationService = Depends(TextClassificationService.get_instance),
datasets: DatasetsService = Depends(DatasetsService.get_instance),
current_user: User = Security(auth.get_user, scopes=[]),
current_user: User = Security(auth.get_current_user),
) -> LabelingRule:
dataset = datasets.find_by_name(
user=current_user,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,16 +54,16 @@ async def get_dataset_settings(
name: str = DATASET_NAME_PATH_PARAM,
ws_params: CommonTaskHandlerDependencies = Depends(),
datasets: DatasetsService = Depends(DatasetsService.get_instance),
user: User = Security(auth.get_user, scopes=["read:dataset.settings"]),
current_user: User = Security(auth.get_current_user),
) -> TextClassificationSettings:
found_ds = datasets.find_by_name(
user=user,
user=current_user,
name=name,
workspace=ws_params.workspace,
task=task,
)

settings = await datasets.get_settings(user=user, dataset=found_ds, class_type=__svc_settings_class__)
settings = await datasets.get_settings(user=current_user, dataset=found_ds, class_type=__svc_settings_class__)
return TextClassificationSettings.parse_obj(settings)

@deprecate_endpoint(
Expand All @@ -82,17 +82,17 @@ async def save_settings(
ws_params: CommonTaskHandlerDependencies = Depends(),
datasets: DatasetsService = Depends(DatasetsService.get_instance),
validator: DatasetValidator = Depends(DatasetValidator.get_instance),
user: User = Security(auth.get_user, scopes=["write:dataset.settings"]),
current_user: User = Security(auth.get_current_user),
) -> TextClassificationSettings:
found_ds = datasets.find_by_name(
user=user,
user=current_user,
name=name,
task=task,
workspace=ws_params.workspace,
)
await validator.validate_dataset_settings(user=user, dataset=found_ds, settings=request)
await validator.validate_dataset_settings(user=current_user, dataset=found_ds, settings=request)
settings = await datasets.save_settings(
user=user,
user=current_user,
dataset=found_ds,
settings=__svc_settings_class__.parse_obj(request.dict()),
)
Expand All @@ -110,7 +110,7 @@ async def delete_settings(
name: str = DATASET_NAME_PATH_PARAM,
ws_params: CommonTaskHandlerDependencies = Depends(),
datasets: DatasetsService = Depends(DatasetsService.get_instance),
user: User = Security(auth.get_user, scopes=["delete:dataset.settings"]),
user: User = Security(auth.get_current_user),
) -> None:
found_ds = datasets.find_by_name(
user=user,
Expand Down
6 changes: 3 additions & 3 deletions src/argilla/server/apis/v0/handlers/token_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,10 @@ async def bulk_records(
service: TokenClassificationService = Depends(TokenClassificationService.get_instance),
datasets: DatasetsService = Depends(DatasetsService.get_instance),
validator: DatasetValidator = Depends(DatasetValidator.get_instance),
current_user: User = Security(auth.get_user, scopes=[]),
current_user: User = Security(auth.get_current_user, scopes=[]),
) -> BulkResponse:
task = task_type
workspace = current_user.check_workspace(common_params.workspace)
workspace = common_params.workspace
try:
dataset = datasets.find_by_name(
current_user,
Expand Down Expand Up @@ -139,7 +139,7 @@ def search_records(
pagination: RequestPagination = Depends(),
service: TokenClassificationService = Depends(TokenClassificationService.get_instance),
datasets: DatasetsService = Depends(DatasetsService.get_instance),
current_user: User = Security(auth.get_user, scopes=[]),
current_user: User = Security(auth.get_current_user, scopes=[]),
) -> TokenClassificationSearchResults:
search = search or TokenClassificationSearchRequest()
query = search.query or TokenClassificationQuery()
Expand Down
Loading