diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index f67cae6ff1..5ba8658505 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -25,10 +25,14 @@ repos: - id: black additional_dependencies: ["click==8.0.4"] - - repo: https://github.com/pycqa/isort - rev: 5.12.0 + - repo: https://github.com/charliermarsh/ruff-pre-commit + rev: v0.0.244 hooks: - - id: isort + # Simulate isort via (the much faster) ruff + - id: ruff + args: + - --select=I + - --fix - repo: https://github.com/alessandrojcm/commitlint-pre-commit-hook rev: v9.4.0 diff --git a/pyproject.toml b/pyproject.toml index 5052f1d547..f28608d42e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -108,3 +108,34 @@ exclude_lines = [ [tool.isort] profile = "black" + +[tool.ruff] +# Ignore line length violations +ignore = ["E501"] + +# Exclude a variety of commonly ignored directories. +exclude = [ + ".bzr", + ".direnv", + ".eggs", + ".git", + ".hg", + ".mypy_cache", + ".nox", + ".pants.d", + ".ruff_cache", + ".svn", + ".tox", + ".venv", + "__pypackages__", + "_build", + "buck-out", + "build", + "dist", + "node_modules", + "venv", +] + +[tool.ruff.per-file-ignores] +# Ignore imported but unused; +"__init__.py" = ["F401"] diff --git a/scripts/load_data.py b/scripts/load_data.py index 0295872157..06405d7650 100644 --- a/scripts/load_data.py +++ b/scripts/load_data.py @@ -15,12 +15,11 @@ import sys import time +import argilla as rg import pandas as pd import requests -from datasets import load_dataset - -import argilla as rg from argilla.labeling.text_classification import Rule, add_rules +from datasets import load_dataset class LoadDatasets: diff --git a/src/argilla/__init__.py b/src/argilla/__init__.py index c9fb4f6787..1dbd654c9f 100644 --- a/src/argilla/__init__.py +++ b/src/argilla/__init__.py @@ -47,12 +47,10 @@ read_datasets, read_pandas, ) - from argilla.client.models import ( - TextGenerationRecord, # TODO Remove TextGenerationRecord - ) from argilla.client.models import ( Text2TextRecord, TextClassificationRecord, + TextGenerationRecord, # TODO Remove TextGenerationRecord TokenAttributions, TokenClassificationRecord, ) diff --git a/src/argilla/client/apis/datasets.py b/src/argilla/client/apis/datasets.py index 55079c7a95..5fff021362 100644 --- a/src/argilla/client/apis/datasets.py +++ b/src/argilla/client/apis/datasets.py @@ -15,7 +15,7 @@ import warnings from dataclasses import dataclass from datetime import datetime -from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union +from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union from pydantic import BaseModel, Field diff --git a/src/argilla/client/apis/search.py b/src/argilla/client/apis/search.py index f8a1b66a98..07de6bb71e 100644 --- a/src/argilla/client/apis/search.py +++ b/src/argilla/client/apis/search.py @@ -13,7 +13,7 @@ # limitations under the License. import dataclasses -from typing import List, Optional, Union +from typing import List, Optional from argilla.client.apis import AbstractApi from argilla.client.models import Record diff --git a/src/argilla/client/datasets.py b/src/argilla/client/datasets.py index 944e7d93f2..952a3a96ef 100644 --- a/src/argilla/client/datasets.py +++ b/src/argilla/client/datasets.py @@ -16,7 +16,7 @@ import logging import random import uuid -from typing import Any, Dict, List, Optional, Tuple, Type, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union import pandas as pd from pkg_resources import parse_version @@ -32,6 +32,11 @@ from argilla.client.sdk.datasets.models import TaskType from argilla.utils.span_utils import SpanUtils +if TYPE_CHECKING: + import datasets + import pandas + import spacy + _LOGGER = logging.getLogger(__name__) @@ -60,7 +65,7 @@ def _requires_spacy(func): @functools.wraps(func) def check_if_spacy_installed(*args, **kwargs): try: - import spacy + import spacy # noqa: F401 except ModuleNotFoundError: raise ModuleNotFoundError( f"'spacy' must be installed to use `{func.__name__}`" @@ -1007,7 +1012,7 @@ def from_datasets( for row in dataset: # TODO: fails with a KeyError if no tokens column is present and no mapping is indicated if not row["tokens"]: - _LOGGER.warning(f"Ignoring row with no tokens.") + _LOGGER.warning("Ignoring row with no tokens.") continue if row.get("tags"): diff --git a/src/argilla/client/sdk/client.py b/src/argilla/client/sdk/client.py index ee244c7697..be4d2983b9 100644 --- a/src/argilla/client/sdk/client.py +++ b/src/argilla/client/sdk/client.py @@ -119,7 +119,7 @@ async def inner_async(self, *args, **kwargs): try: result = await func(self, *args, **kwargs) return result - except httpx.ConnectError as err: + except httpx.ConnectError as err: # noqa: F841 return wrap_error(self.base_url) @functools.wraps(func) @@ -127,7 +127,7 @@ def inner(self, *args, **kwargs): try: result = func(self, *args, **kwargs) return result - except httpx.ConnectError as err: + except httpx.ConnectError as err: # noqa: F841 return wrap_error(self.base_url) is_coroutine = inspect.iscoroutinefunction(func) diff --git a/src/argilla/client/sdk/commons/api.py b/src/argilla/client/sdk/commons/api.py index 9ccc3eb3bb..3810aebdf1 100644 --- a/src/argilla/client/sdk/commons/api.py +++ b/src/argilla/client/sdk/commons/api.py @@ -126,7 +126,7 @@ def build_data_response( parsed_record = json.loads(r) try: parsed_response = data_type(**parsed_record) - except Exception as err: + except Exception as err: # noqa: F841 raise GenericApiError(**parsed_record) from None parsed_responses.append(parsed_response) return Response( diff --git a/src/argilla/client/sdk/text2text/api.py b/src/argilla/client/sdk/text2text/api.py index 2baab48725..5b5aca9568 100644 --- a/src/argilla/client/sdk/text2text/api.py +++ b/src/argilla/client/sdk/text2text/api.py @@ -33,7 +33,6 @@ def data( limit: Optional[int] = None, id_from: Optional[str] = None, ) -> Response[Union[List[Text2TextRecord], HTTPValidationError, ErrorMessage]]: - path = f"/api/datasets/{name}/Text2Text/data" params = build_param_dict(id_from, limit) diff --git a/src/argilla/client/sdk/text_classification/api.py b/src/argilla/client/sdk/text_classification/api.py index 8a5161fe4a..39c603411e 100644 --- a/src/argilla/client/sdk/text_classification/api.py +++ b/src/argilla/client/sdk/text_classification/api.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, List, Optional, Union +from typing import List, Optional, Union import httpx diff --git a/src/argilla/client/sdk/token_classification/api.py b/src/argilla/client/sdk/token_classification/api.py index f5a95ba36d..10a91d0f59 100644 --- a/src/argilla/client/sdk/token_classification/api.py +++ b/src/argilla/client/sdk/token_classification/api.py @@ -39,7 +39,6 @@ def data( ) -> Response[ Union[List[TokenClassificationRecord], HTTPValidationError, ErrorMessage] ]: - path = f"/api/datasets/{name}/TokenClassification/data" params = build_param_dict(id_from, limit) diff --git a/src/argilla/client/sdk/users/api.py b/src/argilla/client/sdk/users/api.py index 964b6f4480..c22d704b13 100644 --- a/src/argilla/client/sdk/users/api.py +++ b/src/argilla/client/sdk/users/api.py @@ -12,10 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -import httpx - from argilla.client.sdk.client import AuthenticatedClient -from argilla.client.sdk.commons.errors_handler import handle_response_error from argilla.client.sdk.users.models import User diff --git a/src/argilla/labeling/text_classification/label_models.py b/src/argilla/labeling/text_classification/label_models.py index 27f2e742ea..7186c46b9e 100644 --- a/src/argilla/labeling/text_classification/label_models.py +++ b/src/argilla/labeling/text_classification/label_models.py @@ -20,7 +20,6 @@ import numpy as np from argilla import DatasetForTextClassification, TextClassificationRecord -from argilla.client.datasets import Dataset from argilla.labeling.text_classification.weak_labels import WeakLabels, WeakMultiLabels _LOGGER = logging.getLogger(__name__) @@ -368,7 +367,7 @@ def score( MissingAnnotationError: If the ``weak_labels`` do not contain annotated records. """ try: - import sklearn + import sklearn # noqa: F401 except ModuleNotFoundError: raise ModuleNotFoundError( "'sklearn' must be installed to compute the metrics! " @@ -501,7 +500,7 @@ def __init__( self, weak_labels: WeakLabels, verbose: bool = True, device: str = "cpu" ): try: - import snorkel + import snorkel # noqa: F401 except ModuleNotFoundError: raise ModuleNotFoundError( "'snorkel' must be installed to use the `Snorkel` label model! " @@ -764,8 +763,8 @@ class FlyingSquid(LabelModel): def __init__(self, weak_labels: WeakLabels, **kwargs): try: - import flyingsquid - import pgmpy + import flyingsquid # noqa: F401 + import pgmpy # noqa: F401 except ModuleNotFoundError: raise ModuleNotFoundError( "'flyingsquid' must be installed to use the `FlyingSquid` label model!" @@ -1024,7 +1023,7 @@ def score( MissingAnnotationError: If the ``weak_labels`` do not contain annotated records. """ try: - import sklearn + import sklearn # noqa: F401 except ModuleNotFoundError: raise ModuleNotFoundError( "'sklearn' must be installed to compute the metrics! " diff --git a/src/argilla/listeners/listener.py b/src/argilla/listeners/listener.py index 77e48e7896..d836bbf5e6 100644 --- a/src/argilla/listeners/listener.py +++ b/src/argilla/listeners/listener.py @@ -98,7 +98,7 @@ def catch_exceptions_decorator(job_func): def wrapper(*args, **kwargs): try: return job_func(*args, **kwargs) - except: + except: # noqa: E722 import traceback print(traceback.format_exc()) @@ -208,7 +208,7 @@ def __listener_iteration_job__(self, *args, **kwargs): self._LOGGER.debug(f"Evaluate condition with arguments: {condition_args}") if self.condition(*condition_args): - self._LOGGER.debug(f"Condition passed! Running action...") + self._LOGGER.debug("Condition passed! Running action...") return self.__run_action__(ctx, *args, **kwargs) def __compute_metrics__(self, current_api, dataset, query: str) -> Metrics: @@ -235,7 +235,7 @@ def __run_action__(self, ctx: Optional[RGListenerContext] = None, *args, **kwarg ) self._LOGGER.debug(f"Running action with arguments: {action_args}") return self.action(*args, *action_args, **kwargs) - except: + except: # noqa: E722 import traceback print(traceback.format_exc()) diff --git a/src/argilla/listeners/models.py b/src/argilla/listeners/models.py index c61f603177..5c276b9a3f 100644 --- a/src/argilla/listeners/models.py +++ b/src/argilla/listeners/models.py @@ -13,12 +13,15 @@ # limitations under the License. import dataclasses -from typing import Any, Callable, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union from prodict import Prodict from argilla.client.models import Record +if TYPE_CHECKING: + from argilla.listeners import RGDatasetListener + @dataclasses.dataclass class Search: diff --git a/src/argilla/metrics/__init__.py b/src/argilla/metrics/__init__.py index 53ef179b65..df7a911c61 100644 --- a/src/argilla/metrics/__init__.py +++ b/src/argilla/metrics/__init__.py @@ -19,15 +19,13 @@ entity_consistency, entity_density, entity_labels, -) -from .token_classification import f1 as ner_f1 -from .token_classification import ( mention_length, token_capitalness, token_frequency, token_length, tokens_length, ) +from .token_classification import f1 as ner_f1 __all__ = [ text_length, diff --git a/src/argilla/monitoring/asgi.py b/src/argilla/monitoring/asgi.py index 4b2119be43..f861eee958 100644 --- a/src/argilla/monitoring/asgi.py +++ b/src/argilla/monitoring/asgi.py @@ -21,7 +21,7 @@ from argilla.monitoring.base import BaseMonitor try: - import starlette + import starlette # noqa: F401 except ModuleNotFoundError: raise ModuleNotFoundError( "'starlette' must be installed to use the middleware feature! " diff --git a/src/argilla/monitoring/base.py b/src/argilla/monitoring/base.py index 2e8002de46..59fc8976d4 100644 --- a/src/argilla/monitoring/base.py +++ b/src/argilla/monitoring/base.py @@ -13,7 +13,6 @@ # limitations under the License. import atexit -import dataclasses import logging import random import threading diff --git a/src/argilla/server/apis/v0/handlers/datasets.py b/src/argilla/server/apis/v0/handlers/datasets.py index 01d7d75478..0d1470f0f9 100644 --- a/src/argilla/server/apis/v0/handlers/datasets.py +++ b/src/argilla/server/apis/v0/handlers/datasets.py @@ -64,7 +64,7 @@ async def list_datasets( description="Create a new dataset", ) async def create_dataset( - request: CreateDatasetRequest = Body(..., description=f"The request dataset info"), + request: CreateDatasetRequest = Body(..., description="The request dataset info"), ws_params: CommonTaskHandlerDependencies = Depends(), datasets: DatasetsService = Depends(DatasetsService.get_instance), user: User = Security(auth.get_user, scopes=["create:datasets"]), diff --git a/src/argilla/server/apis/v0/handlers/metrics.py b/src/argilla/server/apis/v0/handlers/metrics.py index 96e0364697..68b92bbb37 100644 --- a/src/argilla/server/apis/v0/handlers/metrics.py +++ b/src/argilla/server/apis/v0/handlers/metrics.py @@ -58,7 +58,7 @@ def configure_router(router: APIRouter, cfg: TaskConfig): path=base_metrics_endpoint, new_path=new_base_metrics_endpoint, router_method=router.get, - operation_id=f"get_dataset_metrics", + operation_id="get_dataset_metrics", name="get_dataset_metrics", ) def get_dataset_metrics( @@ -84,7 +84,7 @@ def get_dataset_metrics( path=base_metrics_endpoint + "/{metric}:summary", new_path=new_base_metrics_endpoint + "/{metric}:summary", router_method=router.post, - operation_id=f"metric_summary", + operation_id="metric_summary", name="metric_summary", ) def metric_summary( diff --git a/src/argilla/server/apis/v0/handlers/records_update.py b/src/argilla/server/apis/v0/handlers/records_update.py index f77d391e53..b2034265ad 100644 --- a/src/argilla/server/apis/v0/handlers/records_update.py +++ b/src/argilla/server/apis/v0/handlers/records_update.py @@ -14,17 +14,12 @@ from typing import Any, Dict, Optional, Union -from fastapi import APIRouter, Depends, Query, Security +from fastapi import APIRouter, Depends, Security from pydantic import BaseModel -from argilla.client.sdk.token_classification.models import TokenClassificationQuery -from argilla.server.apis.v0.helpers import deprecate_endpoint from argilla.server.apis.v0.models.commons.params import CommonTaskHandlerDependencies -from argilla.server.apis.v0.models.text2text import Text2TextQuery, Text2TextRecord -from argilla.server.apis.v0.models.text_classification import ( - TextClassificationQuery, - TextClassificationRecord, -) +from argilla.server.apis.v0.models.text2text import Text2TextRecord +from argilla.server.apis.v0.models.text_classification import TextClassificationRecord from argilla.server.apis.v0.models.token_classification import TokenClassificationRecord from argilla.server.commons.config import TasksFactory from argilla.server.commons.models import TaskStatus diff --git a/src/argilla/server/apis/v0/handlers/token_classification.py b/src/argilla/server/apis/v0/handlers/token_classification.py index 940775a132..7a21310e1d 100644 --- a/src/argilla/server/apis/v0/handlers/token_classification.py +++ b/src/argilla/server/apis/v0/handlers/token_classification.py @@ -23,7 +23,6 @@ metrics, token_classification_dataset_settings, ) -from argilla.server.apis.v0.helpers import deprecate_endpoint from argilla.server.apis.v0.models.commons.model import BulkResponse from argilla.server.apis.v0.models.commons.params import ( CommonTaskHandlerDependencies, diff --git a/src/argilla/server/apis/v0/models/text2text.py b/src/argilla/server/apis/v0/models/text2text.py index e2730f46e3..5345f95746 100644 --- a/src/argilla/server/apis/v0/models/text2text.py +++ b/src/argilla/server/apis/v0/models/text2text.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from datetime import datetime from typing import Dict, List, Optional from pydantic import BaseModel, Field, validator @@ -27,7 +26,7 @@ SortableField, ) from argilla.server.apis.v0.models.datasets import UpdateDatasetRequest -from argilla.server.commons.models import PredictionStatus, TaskType +from argilla.server.commons.models import PredictionStatus from argilla.server.services.metrics.models import CommonTasksMetrics from argilla.server.services.search.model import ( ServiceBaseRecordsQuery, diff --git a/src/argilla/server/apis/v0/models/text_classification.py b/src/argilla/server/apis/v0/models/text_classification.py index 84f1868f78..b2886fe56b 100644 --- a/src/argilla/server/apis/v0/models/text_classification.py +++ b/src/argilla/server/apis/v0/models/text_classification.py @@ -14,7 +14,7 @@ # limitations under the License. from datetime import datetime -from typing import Any, Dict, List, Optional, Union +from typing import Dict, List, Optional, Union from pydantic import BaseModel, Field, root_validator, validator @@ -39,14 +39,11 @@ ) from argilla.server.services.tasks.text_classification.model import ( ServiceTextClassificationDataset, -) -from argilla.server.services.tasks.text_classification.model import ( - ServiceTextClassificationQuery as _TextClassificationQuery, + TokenAttributions, ) from argilla.server.services.tasks.text_classification.model import ( TextClassificationAnnotation as _TextClassificationAnnotation, ) -from argilla.server.services.tasks.text_classification.model import TokenAttributions class UpdateLabelingRule(BaseModel): diff --git a/src/argilla/server/daos/backend/metrics/base.py b/src/argilla/server/daos/backend/metrics/base.py index 80122add68..256ffcdbf1 100644 --- a/src/argilla/server/daos/backend/metrics/base.py +++ b/src/argilla/server/daos/backend/metrics/base.py @@ -13,11 +13,14 @@ # limitations under the License. import dataclasses -from typing import Any, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union from argilla.server.daos.backend.query_helpers import aggregations from argilla.server.helpers import unflatten_dict +if TYPE_CHECKING: + from argilla.server.daos.backend.client_adapters.base import IClientAdapter + @dataclasses.dataclass class ElasticsearchMetric: diff --git a/src/argilla/server/daos/backend/search/query_builder.py b/src/argilla/server/daos/backend/search/query_builder.py index a05ed872b9..76d9e989ff 100644 --- a/src/argilla/server/daos/backend/search/query_builder.py +++ b/src/argilla/server/daos/backend/search/query_builder.py @@ -277,7 +277,7 @@ def map_2_es_sort_configuration( es_sort = [] for sortable_field in sort.sort_by or [SortableField(id="id")]: if valid_fields: - if not sortable_field.id.split(".")[0] in valid_fields: + if sortable_field.id.split(".")[0] not in valid_fields: raise AssertionError( f"Wrong sort id {sortable_field.id}. Valid values are: " f"{[str(v) for v in valid_fields]}" diff --git a/src/argilla/server/daos/models/datasets.py b/src/argilla/server/daos/models/datasets.py index 22a57e2019..027e4a7011 100644 --- a/src/argilla/server/daos/models/datasets.py +++ b/src/argilla/server/daos/models/datasets.py @@ -46,7 +46,7 @@ def id(self) -> str: """The dataset id. Compounded by owner and name""" return self.build_dataset_id(self.name, self.owner) - def dict(self, *args, **kwargs) -> "DictStrAny": + def dict(self, *args, **kwargs) -> Dict[str, Any]: """ Extends base component dict extending object properties and user defined extended fields diff --git a/src/argilla/server/daos/models/records.py b/src/argilla/server/daos/models/records.py index 43776057e6..117d39581e 100644 --- a/src/argilla/server/daos/models/records.py +++ b/src/argilla/server/daos/models/records.py @@ -225,7 +225,7 @@ def extended_fields(self) -> Dict[str, Any]: "score": self.scores, } - def dict(self, *args, **kwargs) -> "DictStrAny": + def dict(self, *args, **kwargs) -> Dict[str, Any]: """ Extends base component dict extending object properties and user defined extended fields diff --git a/src/argilla/server/errors/api_errors.py b/src/argilla/server/errors/api_errors.py index 01fb75c81f..1cc0f6eb50 100644 --- a/src/argilla/server/errors/api_errors.py +++ b/src/argilla/server/errors/api_errors.py @@ -15,7 +15,7 @@ import logging from typing import Any, Dict -from fastapi import HTTPException, Request, status +from fastapi import HTTPException, Request from fastapi.exception_handlers import http_exception_handler from pydantic import BaseModel diff --git a/src/argilla/server/errors/base_errors.py b/src/argilla/server/errors/base_errors.py index ab69d7ff7b..36ed5bd700 100644 --- a/src/argilla/server/errors/base_errors.py +++ b/src/argilla/server/errors/base_errors.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, Optional, Type, Union +from typing import Any, Optional, Type, Union import pydantic from starlette import status diff --git a/src/argilla/server/server.py b/src/argilla/server/server.py index 7e1ece9e26..67e146d670 100644 --- a/src/argilla/server/server.py +++ b/src/argilla/server/server.py @@ -16,7 +16,6 @@ """ This module configures the global fastapi application """ -import fileinput import glob import inspect import logging diff --git a/src/argilla/server/services/datasets.py b/src/argilla/server/services/datasets.py index 2a3bce6f9a..d67914d354 100644 --- a/src/argilla/server/services/datasets.py +++ b/src/argilla/server/services/datasets.py @@ -146,7 +146,7 @@ def delete(self, user: User, dataset: ServiceDataset): self.__dao__.delete_dataset(dataset) else: raise ForbiddenOperationError( - f"You don't have the necessary permissions to delete this dataset. " + "You don't have the necessary permissions to delete this dataset. " "Only dataset creators or administrators can delete datasets" ) diff --git a/src/argilla/server/services/storage/service.py b/src/argilla/server/services/storage/service.py index 27897efe11..703adaab55 100644 --- a/src/argilla/server/services/storage/service.py +++ b/src/argilla/server/services/storage/service.py @@ -13,7 +13,7 @@ # limitations under the License. import dataclasses -from typing import Any, Dict, List, Optional, Type +from typing import List, Optional, Type from fastapi import Depends @@ -94,7 +94,7 @@ async def delete_records( else: if not user.is_superuser() and user.username != dataset.created_by: raise ForbiddenOperationError( - f"You don't have the necessary permissions to delete records on this dataset. " + "You don't have the necessary permissions to delete records on this dataset. " "Only dataset creators or administrators can delete datasets" ) diff --git a/src/argilla/server/services/tasks/text2text/models.py b/src/argilla/server/services/tasks/text2text/models.py index bccc2b912d..f8d1e8438c 100644 --- a/src/argilla/server/services/tasks/text2text/models.py +++ b/src/argilla/server/services/tasks/text2text/models.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from datetime import datetime from typing import Any, Dict, List, Optional from pydantic import BaseModel, Field @@ -21,9 +20,7 @@ from argilla.server.services.datasets import ServiceBaseDataset from argilla.server.services.search.model import ( ServiceBaseRecordsQuery, - ServiceBaseSearchResultsAggregations, ServiceScoreRange, - ServiceSearchResults, ) from argilla.server.services.tasks.commons import ( ServiceBaseAnnotation, @@ -91,4 +88,3 @@ class ServiceText2TextQuery(ServiceBaseRecordsQuery): class ServiceText2TextDataset(ServiceBaseDataset): task: TaskType = Field(default=TaskType.text2text, const=True) - pass diff --git a/src/argilla/server/services/tasks/text2text/service.py b/src/argilla/server/services/tasks/text2text/service.py index 48f5af61b3..a83e219153 100644 --- a/src/argilla/server/services/tasks/text2text/service.py +++ b/src/argilla/server/services/tasks/text2text/service.py @@ -13,12 +13,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Iterable, List, Optional, Type +from typing import Iterable, List, Optional from fastapi import Depends from argilla.server.commons.config import TasksFactory -from argilla.server.services.metrics.models import ServiceBaseTaskMetrics from argilla.server.services.search.model import ( ServiceSearchResults, ServiceSortableField, diff --git a/src/argilla/server/services/tasks/text_classification/model.py b/src/argilla/server/services/tasks/text_classification/model.py index 2450d9731d..a94a2632a4 100644 --- a/src/argilla/server/services/tasks/text_classification/model.py +++ b/src/argilla/server/services/tasks/text_classification/model.py @@ -99,7 +99,7 @@ def check_label_length(cls, class_label): assert 1 <= len(class_label) <= DEFAULT_MAX_KEYWORD_LENGTH, ( f"Class name '{class_label}' exceeds max length of {DEFAULT_MAX_KEYWORD_LENGTH}" if len(class_label) > DEFAULT_MAX_KEYWORD_LENGTH - else f"Class name must not be empty" + else "Class name must not be empty" ) return class_label diff --git a/src/argilla/utils/span_utils.py b/src/argilla/utils/span_utils.py index 201039cb4d..4a236c6083 100644 --- a/src/argilla/utils/span_utils.py +++ b/src/argilla/utils/span_utils.py @@ -111,7 +111,7 @@ def validate(self, spans: List[Tuple[str, int, int]]): if misaligned_spans_errors: spans = "\n".join(misaligned_spans_errors) - message += f"Following entity spans are not aligned with provided tokenization\n" + message += "Following entity spans are not aligned with provided tokenization\n" message += f"Spans:\n{spans}\n" message += f"Tokens:\n{self.tokens}" diff --git a/tests/client/apis/test_base.py b/tests/client/apis/test_base.py index 189969dbb6..1101d9849a 100644 --- a/tests/client/apis/test_base.py +++ b/tests/client/apis/test_base.py @@ -13,7 +13,6 @@ # limitations under the License. import pytest - from argilla.client import api from argilla.client.apis import AbstractApi, api_compatibility from argilla.client.sdk._helpers import handle_response_error diff --git a/tests/client/conftest.py b/tests/client/conftest.py index ae01679272..ab95abdd6f 100644 --- a/tests/client/conftest.py +++ b/tests/client/conftest.py @@ -15,10 +15,9 @@ import datetime from typing import List -import pytest - import argilla import argilla as ar +import pytest from argilla.client.sdk.datasets.models import TaskType from argilla.client.sdk.text2text.models import ( CreationText2TextRecord, diff --git a/tests/client/functional_tests/test_record_update.py b/tests/client/functional_tests/test_record_update.py index 345149ca69..8c52cb4058 100644 --- a/tests/client/functional_tests/test_record_update.py +++ b/tests/client/functional_tests/test_record_update.py @@ -13,7 +13,6 @@ # limitations under the License. import pytest - from argilla.client.api import active_api from argilla.client.sdk.commons.errors import NotFoundApiError diff --git a/tests/client/functional_tests/test_scan_raw_records.py b/tests/client/functional_tests/test_scan_raw_records.py index f58b18006a..7af6af9b6f 100644 --- a/tests/client/functional_tests/test_scan_raw_records.py +++ b/tests/client/functional_tests/test_scan_raw_records.py @@ -13,8 +13,6 @@ # limitations under the License. import pytest - -import argilla from argilla.client.api import active_api from argilla.client.sdk.token_classification.models import TokenClassificationRecord @@ -28,9 +26,8 @@ def test_scan_records( gutenberg_spacy_ner, fields, ): - import pandas as pd - import argilla as rg + import pandas as pd data = active_api().datasets.scan( name=gutenberg_spacy_ner, diff --git a/tests/client/sdk/commons/api.py b/tests/client/sdk/commons/api.py index 92eeaa9ef6..b7ca635486 100644 --- a/tests/client/sdk/commons/api.py +++ b/tests/client/sdk/commons/api.py @@ -14,8 +14,6 @@ # limitations under the License. import httpx import pytest -from httpx import Response as HttpxResponse - from argilla.client.sdk.commons.api import ( build_bulk_response, build_data_response, @@ -29,6 +27,7 @@ ValidationError, ) from argilla.client.sdk.text_classification.models import TextClassificationRecord +from httpx import Response as HttpxResponse def test_text2text_bulk(sdk_client, mocked_client, bulk_text2text_data, monkeypatch): diff --git a/tests/client/sdk/commons/test_client.py b/tests/client/sdk/commons/test_client.py index 5b1dd8bedd..cb801f12ef 100644 --- a/tests/client/sdk/commons/test_client.py +++ b/tests/client/sdk/commons/test_client.py @@ -13,7 +13,6 @@ # limitations under the License. import pytest - from argilla.client.api import active_api from argilla.client.sdk.client import Client diff --git a/tests/client/sdk/conftest.py b/tests/client/sdk/conftest.py index 7186d02816..75ef26042b 100644 --- a/tests/client/sdk/conftest.py +++ b/tests/client/sdk/conftest.py @@ -16,9 +16,8 @@ from datetime import datetime from typing import Any, Dict, List -import pytest - import argilla as ar +import pytest from argilla._constants import DEFAULT_API_KEY from argilla.client.sdk.client import AuthenticatedClient from argilla.client.sdk.text2text.models import ( diff --git a/tests/client/sdk/datasets/test_api.py b/tests/client/sdk/datasets/test_api.py index ec46434227..7a5fba08d0 100644 --- a/tests/client/sdk/datasets/test_api.py +++ b/tests/client/sdk/datasets/test_api.py @@ -14,7 +14,6 @@ # limitations under the License. import httpx import pytest - from argilla._constants import DEFAULT_API_KEY from argilla.client.sdk.client import AuthenticatedClient from argilla.client.sdk.commons.errors import ( @@ -22,14 +21,8 @@ NotFoundApiError, ValidationApiError, ) -from argilla.client.sdk.commons.models import ( - ErrorMessage, - HTTPValidationError, - Response, - ValidationError, -) from argilla.client.sdk.datasets.api import _build_response, get_dataset -from argilla.client.sdk.datasets.models import Dataset, TaskType +from argilla.client.sdk.datasets.models import Dataset from argilla.client.sdk.text_classification.models import TextClassificationBulkData diff --git a/tests/client/sdk/datasets/test_models.py b/tests/client/sdk/datasets/test_models.py index a10ea4acd6..c490db0e9b 100644 --- a/tests/client/sdk/datasets/test_models.py +++ b/tests/client/sdk/datasets/test_models.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. import pytest - from argilla.client.sdk.datasets.models import Dataset, TaskType from argilla.server.apis.v0.models.datasets import Dataset as ServerDataset diff --git a/tests/client/sdk/text2text/test_models.py b/tests/client/sdk/text2text/test_models.py index 811d56f61a..57f2e70178 100644 --- a/tests/client/sdk/text2text/test_models.py +++ b/tests/client/sdk/text2text/test_models.py @@ -16,7 +16,6 @@ from datetime import datetime import pytest - from argilla.client.models import Text2TextRecord from argilla.client.sdk.text2text.models import ( CreationText2TextRecord, diff --git a/tests/client/sdk/text_classification/test_models.py b/tests/client/sdk/text_classification/test_models.py index 24bee1eed1..9bebb747c5 100644 --- a/tests/client/sdk/text_classification/test_models.py +++ b/tests/client/sdk/text_classification/test_models.py @@ -16,7 +16,6 @@ from datetime import datetime import pytest - from argilla.client.models import TextClassificationRecord, TokenAttributions from argilla.client.sdk.text_classification.models import ( ClassPrediction, diff --git a/tests/client/sdk/token_classification/test_models.py b/tests/client/sdk/token_classification/test_models.py index bf0af8e502..4e419b07d2 100644 --- a/tests/client/sdk/token_classification/test_models.py +++ b/tests/client/sdk/token_classification/test_models.py @@ -16,7 +16,6 @@ from datetime import datetime import pytest - from argilla.client.models import TokenClassificationRecord from argilla.client.sdk.token_classification.models import ( CreationTokenClassificationRecord, diff --git a/tests/client/sdk/users/test_api.py b/tests/client/sdk/users/test_api.py index 34c053ff8a..8eb8f2b3c7 100644 --- a/tests/client/sdk/users/test_api.py +++ b/tests/client/sdk/users/test_api.py @@ -13,7 +13,6 @@ # limitations under the License. import pytest - from argilla.client.sdk.client import AuthenticatedClient from argilla.client.sdk.commons.errors import BaseClientError, UnauthorizedApiError from argilla.client.sdk.users.api import whoami diff --git a/tests/client/test_api.py b/tests/client/test_api.py index 4f80a66deb..0c52482c27 100644 --- a/tests/client/test_api.py +++ b/tests/client/test_api.py @@ -15,14 +15,13 @@ import concurrent.futures import datetime from time import sleep -from typing import Any, Iterable, List +from typing import Any, Iterable +import argilla as ar import datasets import httpx import pandas as pd import pytest - -import argilla as ar from argilla._constants import ( _OLD_WORKSPACE_HEADER_NAME, DEFAULT_API_KEY, @@ -46,6 +45,7 @@ from argilla.server.apis.v0.models.text_classification import ( TextClassificationSearchResults, ) + from tests.helpers import SecuredClient from tests.server.test_api import create_some_data_for_text_classification diff --git a/tests/client/test_asgi.py b/tests/client/test_asgi.py index 41565b7aa8..465e20136d 100644 --- a/tests/client/test_asgi.py +++ b/tests/client/test_asgi.py @@ -16,17 +16,16 @@ import time from typing import Any, Dict -from fastapi import FastAPI -from starlette.applications import Starlette -from starlette.responses import JSONResponse, PlainTextResponse -from starlette.testclient import TestClient - import argilla from argilla.monitoring.asgi import ( ArgillaLogHTTPMiddleware, text_classification_mapper, token_classification_mapper, ) +from fastapi import FastAPI +from starlette.applications import Starlette +from starlette.responses import JSONResponse, PlainTextResponse +from starlette.testclient import TestClient def test_argilla_middleware_for_text_classification( diff --git a/tests/client/test_client_errors.py b/tests/client/test_client_errors.py index a19ff3dd12..f9f0cd4703 100644 --- a/tests/client/test_client_errors.py +++ b/tests/client/test_client_errors.py @@ -13,7 +13,6 @@ # limitations under the License. import pytest - from argilla.client.sdk.commons.errors import UnauthorizedApiError diff --git a/tests/client/test_dataset.py b/tests/client/test_dataset.py index e335e419b3..83ae429af0 100644 --- a/tests/client/test_dataset.py +++ b/tests/client/test_dataset.py @@ -17,12 +17,11 @@ import sys from time import sleep +import argilla as ar import datasets import pandas as pd import pytest import spacy - -import argilla as ar from argilla.client.datasets import ( DatasetBase, DatasetForTokenClassification, diff --git a/tests/client/test_models.py b/tests/client/test_models.py index 6d5d4b052b..75d37bde11 100644 --- a/tests/client/test_models.py +++ b/tests/client/test_models.py @@ -19,15 +19,13 @@ import numpy import pandas as pd import pytest -from pydantic import ValidationError - -from argilla._constants import DEFAULT_MAX_KEYWORD_LENGTH from argilla.client.models import ( Text2TextRecord, TextClassificationRecord, TokenClassificationRecord, _Validators, ) +from pydantic import ValidationError @pytest.mark.parametrize( diff --git a/tests/conftest.py b/tests/conftest.py index a9b28296e0..33738a6db6 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -15,7 +15,6 @@ import httpx import pytest from _pytest.logging import LogCaptureFixture - from argilla.client.sdk.users import api as users_api from argilla.server.commons import telemetry @@ -23,10 +22,9 @@ from loguru import logger except ModuleNotFoundError: logger = None -from starlette.testclient import TestClient - from argilla import app from argilla.client.api import active_api +from starlette.testclient import TestClient from .helpers import SecuredClient diff --git a/tests/datasets/test_datasets.py b/tests/datasets/test_datasets.py index 3edfdccf73..4499d25964 100644 --- a/tests/datasets/test_datasets.py +++ b/tests/datasets/test_datasets.py @@ -12,9 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -import pytest - import argilla as ar +import pytest from argilla import TextClassificationSettings, TokenClassificationSettings from argilla.client import api from argilla.client.sdk.commons.errors import ForbiddenApiError diff --git a/tests/functional_tests/datasets/test_delete_records_from_datasets.py b/tests/functional_tests/datasets/test_delete_records_from_datasets.py index 5279ebb567..238c358338 100644 --- a/tests/functional_tests/datasets/test_delete_records_from_datasets.py +++ b/tests/functional_tests/datasets/test_delete_records_from_datasets.py @@ -15,7 +15,6 @@ import time import pytest - from argilla.client.sdk.commons.errors import ForbiddenApiError diff --git a/tests/functional_tests/datasets/test_update_record.py b/tests/functional_tests/datasets/test_update_record.py index 8ffd74fabc..7eb050a4c5 100644 --- a/tests/functional_tests/datasets/test_update_record.py +++ b/tests/functional_tests/datasets/test_update_record.py @@ -13,7 +13,6 @@ # limitations under the License. import pytest - from argilla.server.apis.v0.models.text2text import Text2TextRecord from argilla.server.apis.v0.models.text_classification import TextClassificationRecord from argilla.server.apis.v0.models.token_classification import TokenClassificationRecord diff --git a/tests/functional_tests/search/test_search_service.py b/tests/functional_tests/search/test_search_service.py index 92eb8062bb..9fc0b9b439 100644 --- a/tests/functional_tests/search/test_search_service.py +++ b/tests/functional_tests/search/test_search_service.py @@ -12,9 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -import pytest - import argilla +import pytest from argilla.server.apis.v0.models.commons.model import ScoreRange from argilla.server.apis.v0.models.datasets import Dataset from argilla.server.apis.v0.models.text_classification import ( diff --git a/tests/functional_tests/test_log_for_text_classification.py b/tests/functional_tests/test_log_for_text_classification.py index 0a85ec53e8..5e2a16f8bb 100644 --- a/tests/functional_tests/test_log_for_text_classification.py +++ b/tests/functional_tests/test_log_for_text_classification.py @@ -12,16 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. -import pytest - import argilla as ar +import pytest from argilla.client.sdk.commons.errors import ( BadRequestApiError, GenericApiError, ValidationApiError, ) from argilla.server.settings import settings -from tests.client.conftest import SUPPORTED_VECTOR_SEARCH, supported_vector_search + +from tests.client.conftest import SUPPORTED_VECTOR_SEARCH from tests.helpers import SecuredClient diff --git a/tests/functional_tests/test_log_for_token_classification.py b/tests/functional_tests/test_log_for_token_classification.py index f4bea51e2c..2df8c37c5c 100644 --- a/tests/functional_tests/test_log_for_token_classification.py +++ b/tests/functional_tests/test_log_for_token_classification.py @@ -12,14 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -import pytest - import argilla +import pytest from argilla import TokenClassificationRecord from argilla.client import api from argilla.client.sdk.commons.errors import NotFoundApiError from argilla.metrics import __all__ as ALL_METRICS from argilla.metrics import entity_consistency + from tests.client.conftest import SUPPORTED_VECTOR_SEARCH from tests.helpers import SecuredClient diff --git a/tests/helpers.py b/tests/helpers.py index 3999376e6f..4246a8fd45 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -14,14 +14,13 @@ from typing import List -from fastapi import FastAPI -from starlette.testclient import TestClient - from argilla._constants import API_KEY_HEADER_NAME, WORKSPACE_HEADER_NAME from argilla.client.api import active_api from argilla.server.security import auth from argilla.server.security.auth_provider.local.settings import settings from argilla.server.security.auth_provider.local.users.model import UserInDB +from fastapi import FastAPI +from starlette.testclient import TestClient class SecuredClient: diff --git a/tests/labeling/text_classification/test_label_errors.py b/tests/labeling/text_classification/test_label_errors.py index 71933a9828..f33f9d8ab6 100644 --- a/tests/labeling/text_classification/test_label_errors.py +++ b/tests/labeling/text_classification/test_label_errors.py @@ -14,11 +14,9 @@ # limitations under the License. import sys +import argilla as ar import cleanlab import pytest -from pkg_resources import parse_version - -import argilla as ar from argilla.labeling.text_classification import find_label_errors from argilla.labeling.text_classification.label_errors import ( MissingPredictionError, @@ -26,6 +24,7 @@ SortBy, _construct_s_and_psx, ) +from pkg_resources import parse_version @pytest.fixture( diff --git a/tests/labeling/text_classification/test_label_models.py b/tests/labeling/text_classification/test_label_models.py index c21bac70a2..0cde14796f 100644 --- a/tests/labeling/text_classification/test_label_models.py +++ b/tests/labeling/text_classification/test_label_models.py @@ -13,11 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. import sys -from types import SimpleNamespace import numpy as np import pytest - from argilla import TextClassificationRecord from argilla.labeling.text_classification import ( FlyingSquid, diff --git a/tests/labeling/text_classification/test_rule.py b/tests/labeling/text_classification/test_rule.py index 3c1d29a326..c46c7cdf3e 100644 --- a/tests/labeling/text_classification/test_rule.py +++ b/tests/labeling/text_classification/test_rule.py @@ -14,7 +14,6 @@ # limitations under the License. import httpx import pytest - from argilla import load from argilla.client.models import TextClassificationRecord from argilla.client.sdk.text_classification.models import ( diff --git a/tests/labeling/text_classification/test_weak_labels.py b/tests/labeling/text_classification/test_weak_labels.py index 0ad4cf6e97..8a42263881 100644 --- a/tests/labeling/text_classification/test_weak_labels.py +++ b/tests/labeling/text_classification/test_weak_labels.py @@ -18,8 +18,6 @@ import numpy as np import pandas as pd import pytest -from pandas.testing import assert_frame_equal - from argilla import TextClassificationRecord from argilla.client.sdk.text_classification.models import ( CreationTextClassificationRecord, @@ -35,6 +33,7 @@ NoRulesFoundError, WeakLabelsBase, ) +from pandas.testing import assert_frame_equal @pytest.fixture diff --git a/tests/listeners/test_listener.py b/tests/listeners/test_listener.py index 1aa6bca55a..8549036869 100644 --- a/tests/listeners/test_listener.py +++ b/tests/listeners/test_listener.py @@ -15,9 +15,8 @@ import time from typing import List -import pytest - import argilla as ar +import pytest from argilla import RGListenerContext, listener from argilla.client.models import Record diff --git a/tests/metrics/test_common_metrics.py b/tests/metrics/test_common_metrics.py index b1ac68084c..4cbb6f80f6 100644 --- a/tests/metrics/test_common_metrics.py +++ b/tests/metrics/test_common_metrics.py @@ -12,10 +12,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import pytest - import argilla import argilla as ar +import pytest from argilla.metrics.commons import keywords, records_status, text_length diff --git a/tests/metrics/test_token_classification.py b/tests/metrics/test_token_classification.py index 443b7a17c8..89ee11c961 100644 --- a/tests/metrics/test_token_classification.py +++ b/tests/metrics/test_token_classification.py @@ -12,10 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -import pytest - import argilla import argilla as ar +import pytest from argilla.metrics import entity_consistency from argilla.metrics.token_classification import ( Annotations, diff --git a/tests/monitoring/test_base_monitor.py b/tests/monitoring/test_base_monitor.py index 96dc9c349d..f3baa32730 100644 --- a/tests/monitoring/test_base_monitor.py +++ b/tests/monitoring/test_base_monitor.py @@ -13,7 +13,7 @@ # limitations under the License. from time import sleep -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List from argilla import TextClassificationRecord from argilla.client.api import Api, active_api diff --git a/tests/monitoring/test_flair_monitoring.py b/tests/monitoring/test_flair_monitoring.py index e8cb51b6ad..18aa38201e 100644 --- a/tests/monitoring/test_flair_monitoring.py +++ b/tests/monitoring/test_flair_monitoring.py @@ -15,11 +15,10 @@ def test_flair_monitoring(mocked_client, monkeypatch): + import argilla as ar from flair.data import Sentence from flair.models import SequenceTagger - import argilla as ar - dataset = "test_flair_monitoring" model = "flair/ner-english" diff --git a/tests/monitoring/test_transformers_monitoring.py b/tests/monitoring/test_transformers_monitoring.py index ee0487c98f..225f606518 100644 --- a/tests/monitoring/test_transformers_monitoring.py +++ b/tests/monitoring/test_transformers_monitoring.py @@ -14,9 +14,8 @@ from time import sleep from typing import List, Union -import pytest - import argilla +import pytest from argilla import TextClassificationRecord diff --git a/tests/server/backend/test_query_builder.py b/tests/server/backend/test_query_builder.py index 2ceef02733..0276da15eb 100644 --- a/tests/server/backend/test_query_builder.py +++ b/tests/server/backend/test_query_builder.py @@ -13,7 +13,6 @@ # limitations under the License. import pytest - from argilla.server.daos.backend.search.model import ( SortableField, SortConfig, diff --git a/tests/server/commons/test_records_dao.py b/tests/server/commons/test_records_dao.py index 2dd227267a..f86d68d084 100644 --- a/tests/server/commons/test_records_dao.py +++ b/tests/server/commons/test_records_dao.py @@ -13,7 +13,6 @@ # limitations under the License. import pytest - from argilla.server.commons.models import TaskType from argilla.server.daos.backend import GenericElasticEngineBackend from argilla.server.daos.models.datasets import BaseDatasetDB diff --git a/tests/server/commons/test_settings.py b/tests/server/commons/test_settings.py index 563ebdd3ce..9380f00eba 100644 --- a/tests/server/commons/test_settings.py +++ b/tests/server/commons/test_settings.py @@ -15,9 +15,8 @@ import os import pytest -from pydantic import ValidationError - from argilla.server.settings import ApiSettings +from pydantic import ValidationError @pytest.mark.parametrize("bad_namespace", ["Badns", "bad-ns", "12-bad-ns", "@bad"]) diff --git a/tests/server/commons/test_telemetry.py b/tests/server/commons/test_telemetry.py index e892042649..f3ce903918 100644 --- a/tests/server/commons/test_telemetry.py +++ b/tests/server/commons/test_telemetry.py @@ -15,11 +15,10 @@ import uuid import pytest -from fastapi import Request - from argilla.server.commons import telemetry from argilla.server.commons.models import TaskType from argilla.server.errors import ServerError +from fastapi import Request mock_request = Request(scope={"type": "http", "headers": {}}) diff --git a/tests/server/daos/models/test_records.py b/tests/server/daos/models/test_records.py index e2542c804f..5ebac261fc 100644 --- a/tests/server/daos/models/test_records.py +++ b/tests/server/daos/models/test_records.py @@ -15,7 +15,6 @@ import warnings import pytest - from argilla.server.daos.models.records import BaseRecordInDB from argilla.server.settings import settings diff --git a/tests/server/datasets/test_api.py b/tests/server/datasets/test_api.py index 416cf16275..2d0ebebf0b 100644 --- a/tests/server/datasets/test_api.py +++ b/tests/server/datasets/test_api.py @@ -19,6 +19,7 @@ TextClassificationBulkRequest, ) from argilla.server.commons.models import TaskType + from tests.helpers import SecuredClient @@ -97,7 +98,7 @@ def test_fetch_dataset_using_workspaces(mocked_client: SecuredClient): assert response.status_code == 409, response.json() response = mocked_client.post( - f"/api/datasets", + "/api/datasets", json=request, ) diff --git a/tests/server/datasets/test_dao.py b/tests/server/datasets/test_dao.py index 2b208c73fd..962df52564 100644 --- a/tests/server/datasets/test_dao.py +++ b/tests/server/datasets/test_dao.py @@ -14,7 +14,6 @@ # limitations under the License. import pytest - from argilla.server.commons.models import TaskType from argilla.server.daos.backend import GenericElasticEngineBackend from argilla.server.daos.datasets import DatasetsDAO diff --git a/tests/server/datasets/test_model.py b/tests/server/datasets/test_model.py index b664522d78..165203cd7f 100644 --- a/tests/server/datasets/test_model.py +++ b/tests/server/datasets/test_model.py @@ -14,10 +14,9 @@ # limitations under the License. import pytest -from pydantic import ValidationError - from argilla.server.apis.v0.models.datasets import CreateDatasetRequest from argilla.server.commons.models import TaskType +from pydantic import ValidationError @pytest.mark.parametrize( diff --git a/tests/server/functional_tests/datasets/test_get_record.py b/tests/server/functional_tests/datasets/test_get_record.py index 5bc7deba30..ee95849bef 100644 --- a/tests/server/functional_tests/datasets/test_get_record.py +++ b/tests/server/functional_tests/datasets/test_get_record.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. import pytest - from argilla.server.apis.v0.models.text2text import ( Text2TextBulkRequest, Text2TextRecord, diff --git a/tests/server/info/test_api.py b/tests/server/info/test_api.py index d25b143555..41b6c23518 100644 --- a/tests/server/info/test_api.py +++ b/tests/server/info/test_api.py @@ -36,7 +36,7 @@ def test_api_status(mocked_client): assert info.version == argilla_version # Checking to not get the error dictionary service.py includes whenever something goes wrong - assert not "error" in info.elasticsearch + assert "error" not in info.elasticsearch # Checking that the first key into mem_info dictionary has a nont-none value assert "rss" in info.mem_info is not None diff --git a/tests/server/security/test_dao.py b/tests/server/security/test_dao.py index cda8ea5d1b..41e03f5bc3 100644 --- a/tests/server/security/test_dao.py +++ b/tests/server/security/test_dao.py @@ -13,7 +13,6 @@ # limitations under the License. import pytest - from argilla._constants import DEFAULT_API_KEY from argilla.server.security.auth_provider.local.users.service import create_users_dao diff --git a/tests/server/security/test_model.py b/tests/server/security/test_model.py index 4c4428ed06..757349d6a0 100644 --- a/tests/server/security/test_model.py +++ b/tests/server/security/test_model.py @@ -13,10 +13,9 @@ # limitations under the License. import pytest -from pydantic import ValidationError - from argilla.server.errors import EntityNotFoundError from argilla.server.security.model import User +from pydantic import ValidationError @pytest.mark.parametrize("email", ["my@email.com", "infra@recogn.ai"]) diff --git a/tests/server/security/test_provider.py b/tests/server/security/test_provider.py index f06b64dcf7..6d169e74dc 100644 --- a/tests/server/security/test_provider.py +++ b/tests/server/security/test_provider.py @@ -13,12 +13,11 @@ # limitations under the License. import pytest -from fastapi.security import SecurityScopes - from argilla._constants import DEFAULT_API_KEY from argilla.server.security.auth_provider.local.provider import ( create_local_auth_provider, ) +from fastapi.security import SecurityScopes localAuth = create_local_auth_provider() security_Scopes = SecurityScopes diff --git a/tests/server/security/test_service.py b/tests/server/security/test_service.py index 9bc5f315a3..8985b556a2 100644 --- a/tests/server/security/test_service.py +++ b/tests/server/security/test_service.py @@ -13,7 +13,6 @@ # limitations under the License. import pytest - from argilla._constants import DEFAULT_API_KEY from argilla.server.security.auth_provider.local.users.dao import create_users_dao from argilla.server.security.auth_provider.local.users.service import UsersService diff --git a/tests/server/test_app.py b/tests/server/test_app.py index 21d336d081..06e6ca9f05 100644 --- a/tests/server/test_app.py +++ b/tests/server/test_app.py @@ -16,7 +16,6 @@ from importlib import reload import pytest - from argilla.server import app diff --git a/tests/server/text2text/test_api.py b/tests/server/text2text/test_api.py index a5f8c36e73..c5db714f9d 100644 --- a/tests/server/text2text/test_api.py +++ b/tests/server/text2text/test_api.py @@ -14,14 +14,13 @@ from typing import List, Optional import pytest - from argilla.server.apis.v0.models.commons.model import BulkResponse from argilla.server.apis.v0.models.text2text import ( Text2TextBulkRequest, - Text2TextRecord, Text2TextRecordInputs, Text2TextSearchResults, ) + from tests.client.conftest import SUPPORTED_VECTOR_SEARCH diff --git a/tests/server/text_classification/test_api.py b/tests/server/text_classification/test_api.py index 5ffb8c5300..6791614471 100644 --- a/tests/server/text_classification/test_api.py +++ b/tests/server/text_classification/test_api.py @@ -16,7 +16,6 @@ from datetime import datetime import pytest - from argilla.server.apis.v0.models.commons.model import BulkResponse from argilla.server.apis.v0.models.datasets import Dataset from argilla.server.apis.v0.models.text_classification import ( @@ -28,6 +27,7 @@ TextClassificationSearchResults, ) from argilla.server.commons.models import PredictionStatus + from tests.client.conftest import SUPPORTED_VECTOR_SEARCH diff --git a/tests/server/text_classification/test_api_rules.py b/tests/server/text_classification/test_api_rules.py index 6c95ed80c0..464ab427af 100644 --- a/tests/server/text_classification/test_api_rules.py +++ b/tests/server/text_classification/test_api_rules.py @@ -13,7 +13,6 @@ # limitations under the License. import pytest - from argilla.server.apis.v0.models.text_classification import ( CreateLabelingRule, LabelingRule, diff --git a/tests/server/text_classification/test_model.py b/tests/server/text_classification/test_model.py index 18a967a637..87e344cb02 100644 --- a/tests/server/text_classification/test_model.py +++ b/tests/server/text_classification/test_model.py @@ -13,8 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. import pytest -from pydantic import ValidationError - from argilla._constants import DEFAULT_MAX_KEYWORD_LENGTH from argilla.server.apis.v0.models.text_classification import ( TextClassificationAnnotation, @@ -27,6 +25,7 @@ ClassPrediction, ServiceTextClassificationRecord, ) +from pydantic import ValidationError def test_flatten_metadata(): diff --git a/tests/server/token_classification/test_api.py b/tests/server/token_classification/test_api.py index 29ee17ec24..8260300012 100644 --- a/tests/server/token_classification/test_api.py +++ b/tests/server/token_classification/test_api.py @@ -15,7 +15,6 @@ from typing import Callable import pytest - from argilla.server.apis.v0.models.commons.model import BulkResponse, SortableField from argilla.server.apis.v0.models.token_classification import ( TokenClassificationBulkRequest, @@ -24,6 +23,7 @@ TokenClassificationSearchRequest, TokenClassificationSearchResults, ) + from tests.client.conftest import SUPPORTED_VECTOR_SEARCH diff --git a/tests/server/token_classification/test_model.py b/tests/server/token_classification/test_model.py index 11701da917..8d0c12bbff 100644 --- a/tests/server/token_classification/test_model.py +++ b/tests/server/token_classification/test_model.py @@ -14,8 +14,6 @@ # limitations under the License. import pytest -from pydantic import ValidationError - from argilla._constants import DEFAULT_MAX_KEYWORD_LENGTH from argilla.server.apis.v0.models.token_classification import ( TokenClassificationAnnotation, @@ -28,6 +26,7 @@ EntitySpan, ServiceTokenClassificationRecord, ) +from pydantic import ValidationError def test_char_position(): diff --git a/tests/utils/test_span_utils.py b/tests/utils/test_span_utils.py index eb53a0b038..e3ddbdc6cf 100644 --- a/tests/utils/test_span_utils.py +++ b/tests/utils/test_span_utils.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. import pytest - from argilla.utils.span_utils import SpanUtils diff --git a/tests/utils/test_utils.py b/tests/utils/test_utils.py index 47804bd72e..85fe24e9f5 100644 --- a/tests/utils/test_utils.py +++ b/tests/utils/test_utils.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. import pytest - from argilla.utils import LazyargillaModule