Skip to content

Commit

Permalink
Remove unused imports, apply some noqa's, add TYPE_CHECKING
Browse files Browse the repository at this point in the history
I.e. manually going through ruff errors to improve code quality somewhat
  • Loading branch information
tomaarsen committed Feb 9, 2023
1 parent 497420e commit f219acb
Show file tree
Hide file tree
Showing 41 changed files with 59 additions and 83 deletions.
2 changes: 1 addition & 1 deletion src/argilla/client/apis/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import warnings
from dataclasses import dataclass
from datetime import datetime
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union

from pydantic import BaseModel, Field

Expand Down
2 changes: 1 addition & 1 deletion src/argilla/client/apis/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

import dataclasses
from typing import List, Optional, Union
from typing import List, Optional

from argilla.client.apis import AbstractApi
from argilla.client.models import Record
Expand Down
11 changes: 8 additions & 3 deletions src/argilla/client/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import functools
import logging
import random
from typing import Any, Dict, List, Optional, Tuple, Type, Union
from typing import Any, Dict, List, Optional, Tuple, Type, Union, TYPE_CHECKING

import pandas as pd
from pkg_resources import parse_version
Expand All @@ -31,6 +31,11 @@
from argilla.client.sdk.datasets.models import TaskType
from argilla.utils.span_utils import SpanUtils

if TYPE_CHECKING:
import datasets
import spacy
import pandas

_LOGGER = logging.getLogger(__name__)


Expand Down Expand Up @@ -59,7 +64,7 @@ def _requires_spacy(func):
@functools.wraps(func)
def check_if_spacy_installed(*args, **kwargs):
try:
import spacy
import spacy # noqa: F401
except ModuleNotFoundError:
raise ModuleNotFoundError(
f"'spacy' must be installed to use `{func.__name__}`"
Expand Down Expand Up @@ -992,7 +997,7 @@ def from_datasets(
for row in dataset:
# TODO: fails with a KeyError if no tokens column is present and no mapping is indicated
if not row["tokens"]:
_LOGGER.warning(f"Ignoring row with no tokens.")
_LOGGER.warning("Ignoring row with no tokens.")
continue

if row.get("tags"):
Expand Down
4 changes: 2 additions & 2 deletions src/argilla/client/sdk/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,15 +119,15 @@ async def inner_async(self, *args, **kwargs):
try:
result = await func(self, *args, **kwargs)
return result
except httpx.ConnectError as err:
except httpx.ConnectError as err: # noqa: F841
return wrap_error(self.base_url)

@functools.wraps(func)
def inner(self, *args, **kwargs):
try:
result = func(self, *args, **kwargs)
return result
except httpx.ConnectError as err:
except httpx.ConnectError as err: # noqa: F841
return wrap_error(self.base_url)

is_coroutine = inspect.iscoroutinefunction(func)
Expand Down
2 changes: 1 addition & 1 deletion src/argilla/client/sdk/commons/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def build_data_response(
parsed_record = json.loads(r)
try:
parsed_response = data_type(**parsed_record)
except Exception as err:
except Exception as err: # noqa: F841
raise GenericApiError(**parsed_record) from None
parsed_responses.append(parsed_response)
return Response(
Expand Down
10 changes: 2 additions & 8 deletions src/argilla/client/sdk/text_classification/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, List, Optional, Union
from typing import List, Union

import httpx

from argilla.client.sdk._helpers import build_typed_response
from argilla.client.sdk.client import AuthenticatedClient
from argilla.client.sdk.commons.api import (
build_data_response,
build_list_response,
build_param_dict,
)
from argilla.client.sdk.commons.api import build_list_response
from argilla.client.sdk.commons.models import (
ErrorMessage,
HTTPValidationError,
Expand All @@ -31,8 +27,6 @@
from argilla.client.sdk.text_classification.models import (
LabelingRule,
LabelingRuleMetricsSummary,
TextClassificationQuery,
TextClassificationRecord,
)


Expand Down
3 changes: 0 additions & 3 deletions src/argilla/client/sdk/users/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import httpx

from argilla.client.sdk.client import AuthenticatedClient
from argilla.client.sdk.commons.errors_handler import handle_response_error
from argilla.client.sdk.users.models import User


Expand Down
11 changes: 5 additions & 6 deletions src/argilla/labeling/text_classification/label_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import numpy as np

from argilla import DatasetForTextClassification, TextClassificationRecord
from argilla.client.datasets import Dataset
from argilla.labeling.text_classification.weak_labels import WeakLabels, WeakMultiLabels

_LOGGER = logging.getLogger(__name__)
Expand Down Expand Up @@ -368,7 +367,7 @@ def score(
MissingAnnotationError: If the ``weak_labels`` do not contain annotated records.
"""
try:
import sklearn
import sklearn # noqa: F401
except ModuleNotFoundError:
raise ModuleNotFoundError(
"'sklearn' must be installed to compute the metrics! "
Expand Down Expand Up @@ -501,7 +500,7 @@ def __init__(
self, weak_labels: WeakLabels, verbose: bool = True, device: str = "cpu"
):
try:
import snorkel
import snorkel # noqa: F401
except ModuleNotFoundError:
raise ModuleNotFoundError(
"'snorkel' must be installed to use the `Snorkel` label model! "
Expand Down Expand Up @@ -764,8 +763,8 @@ class FlyingSquid(LabelModel):

def __init__(self, weak_labels: WeakLabels, **kwargs):
try:
import flyingsquid
import pgmpy
import flyingsquid # noqa: F401
import pgmpy # noqa: F401
except ModuleNotFoundError:
raise ModuleNotFoundError(
"'flyingsquid' must be installed to use the `FlyingSquid` label model!"
Expand Down Expand Up @@ -1024,7 +1023,7 @@ def score(
MissingAnnotationError: If the ``weak_labels`` do not contain annotated records.
"""
try:
import sklearn
import sklearn # noqa: F401
except ModuleNotFoundError:
raise ModuleNotFoundError(
"'sklearn' must be installed to compute the metrics! "
Expand Down
6 changes: 3 additions & 3 deletions src/argilla/listeners/listener.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def catch_exceptions_decorator(job_func):
def wrapper(*args, **kwargs):
try:
return job_func(*args, **kwargs)
except:
except: # noqa: E722
import traceback

print(traceback.format_exc())
Expand Down Expand Up @@ -208,7 +208,7 @@ def __listener_iteration_job__(self, *args, **kwargs):

self._LOGGER.debug(f"Evaluate condition with arguments: {condition_args}")
if self.condition(*condition_args):
self._LOGGER.debug(f"Condition passed! Running action...")
self._LOGGER.debug("Condition passed! Running action...")
return self.__run_action__(ctx, *args, **kwargs)

def __compute_metrics__(self, current_api, dataset, query: str) -> Metrics:
Expand All @@ -235,7 +235,7 @@ def __run_action__(self, ctx: Optional[RGListenerContext] = None, *args, **kwarg
)
self._LOGGER.debug(f"Running action with arguments: {action_args}")
return self.action(*args, *action_args, **kwargs)
except:
except: # noqa: E722
import traceback

print(traceback.format_exc())
Expand Down
6 changes: 5 additions & 1 deletion src/argilla/listeners/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,17 @@
# limitations under the License.

import dataclasses
from typing import Any, Callable, Dict, List, Optional, Union
from typing import Any, Callable, Dict, List, Optional, Union, TYPE_CHECKING

from prodict import Prodict

from argilla.client.models import Record


if TYPE_CHECKING:
from argilla.listeners import RGDatasetListener


@dataclasses.dataclass
class Search:
"""
Expand Down
2 changes: 1 addition & 1 deletion src/argilla/monitoring/asgi.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from argilla.monitoring.base import BaseMonitor

try:
import starlette
import starlette # noqa: F401
except ModuleNotFoundError:
raise ModuleNotFoundError(
"'starlette' must be installed to use the middleware feature! "
Expand Down
1 change: 0 additions & 1 deletion src/argilla/monitoring/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
# limitations under the License.

import atexit
import dataclasses
import logging
import random
import threading
Expand Down
2 changes: 1 addition & 1 deletion src/argilla/server/apis/v0/handlers/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ async def list_datasets(
description="Create a new dataset",
)
async def create_dataset(
request: CreateDatasetRequest = Body(..., description=f"The request dataset info"),
request: CreateDatasetRequest = Body(..., description="The request dataset info"),
ws_params: CommonTaskHandlerDependencies = Depends(),
datasets: DatasetsService = Depends(DatasetsService.get_instance),
user: User = Security(auth.get_user, scopes=["create:datasets"]),
Expand Down
4 changes: 2 additions & 2 deletions src/argilla/server/apis/v0/handlers/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def configure_router(router: APIRouter, cfg: TaskConfig):
path=base_metrics_endpoint,
new_path=new_base_metrics_endpoint,
router_method=router.get,
operation_id=f"get_dataset_metrics",
operation_id="get_dataset_metrics",
name="get_dataset_metrics",
)
def get_dataset_metrics(
Expand All @@ -84,7 +84,7 @@ def get_dataset_metrics(
path=base_metrics_endpoint + "/{metric}:summary",
new_path=new_base_metrics_endpoint + "/{metric}:summary",
router_method=router.post,
operation_id=f"metric_summary",
operation_id="metric_summary",
name="metric_summary",
)
def metric_summary(
Expand Down
11 changes: 3 additions & 8 deletions src/argilla/server/apis/v0/handlers/records_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,12 @@

from typing import Any, Dict, Optional, Union

from fastapi import APIRouter, Depends, Query, Security
from fastapi import APIRouter, Depends, Security
from pydantic import BaseModel

from argilla.client.sdk.token_classification.models import TokenClassificationQuery
from argilla.server.apis.v0.helpers import deprecate_endpoint
from argilla.server.apis.v0.models.commons.params import CommonTaskHandlerDependencies
from argilla.server.apis.v0.models.text2text import Text2TextQuery, Text2TextRecord
from argilla.server.apis.v0.models.text_classification import (
TextClassificationQuery,
TextClassificationRecord,
)
from argilla.server.apis.v0.models.text2text import Text2TextRecord
from argilla.server.apis.v0.models.text_classification import TextClassificationRecord
from argilla.server.apis.v0.models.token_classification import TokenClassificationRecord
from argilla.server.commons.config import TasksFactory
from argilla.server.commons.models import TaskStatus
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
metrics,
token_classification_dataset_settings,
)
from argilla.server.apis.v0.helpers import deprecate_endpoint
from argilla.server.apis.v0.models.commons.model import BulkResponse
from argilla.server.apis.v0.models.commons.params import (
CommonTaskHandlerDependencies,
Expand Down
3 changes: 1 addition & 2 deletions src/argilla/server/apis/v0/models/text2text.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from datetime import datetime
from typing import Dict, List, Optional

from pydantic import BaseModel, Field, validator
Expand All @@ -27,7 +26,7 @@
SortableField,
)
from argilla.server.apis.v0.models.datasets import UpdateDatasetRequest
from argilla.server.commons.models import PredictionStatus, TaskType
from argilla.server.commons.models import PredictionStatus
from argilla.server.services.metrics.models import CommonTasksMetrics
from argilla.server.services.search.model import (
ServiceBaseRecordsQuery,
Expand Down
5 changes: 1 addition & 4 deletions src/argilla/server/apis/v0/models/text_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# limitations under the License.

from datetime import datetime
from typing import Any, Dict, List, Optional, Union
from typing import Dict, List, Optional, Union

from pydantic import BaseModel, Field, root_validator, validator

Expand All @@ -40,9 +40,6 @@
from argilla.server.services.tasks.text_classification.model import (
ServiceTextClassificationDataset,
)
from argilla.server.services.tasks.text_classification.model import (
ServiceTextClassificationQuery as _TextClassificationQuery,
)
from argilla.server.services.tasks.text_classification.model import (
TextClassificationAnnotation as _TextClassificationAnnotation,
)
Expand Down
6 changes: 5 additions & 1 deletion src/argilla/server/daos/backend/metrics/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,16 @@
# limitations under the License.

import dataclasses
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Union, TYPE_CHECKING

from argilla.server.daos.backend.query_helpers import aggregations
from argilla.server.helpers import unflatten_dict


if TYPE_CHECKING:
from argilla.server.daos.backend.client_adapters.base import IClientAdapter


@dataclasses.dataclass
class ElasticsearchMetric:
id: str
Expand Down
2 changes: 1 addition & 1 deletion src/argilla/server/daos/backend/search/query_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ def map_2_es_sort_configuration(
es_sort = []
for sortable_field in sort.sort_by or [SortableField(id="id")]:
if valid_fields:
if not sortable_field.id.split(".")[0] in valid_fields:
if sortable_field.id.split(".")[0] not in valid_fields:
raise AssertionError(
f"Wrong sort id {sortable_field.id}. Valid values are: "
f"{[str(v) for v in valid_fields]}"
Expand Down
2 changes: 1 addition & 1 deletion src/argilla/server/daos/models/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def id(self) -> str:
"""The dataset id. Compounded by owner and name"""
return self.build_dataset_id(self.name, self.owner)

def dict(self, *args, **kwargs) -> "DictStrAny":
def dict(self, *args, **kwargs) -> Dict[str, Any]:
"""
Extends base component dict extending object properties
and user defined extended fields
Expand Down
2 changes: 1 addition & 1 deletion src/argilla/server/daos/models/records.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ def extended_fields(self) -> Dict[str, Any]:
"score": self.scores,
}

def dict(self, *args, **kwargs) -> "DictStrAny":
def dict(self, *args, **kwargs) -> Dict[str, Any]:
"""
Extends base component dict extending object properties
and user defined extended fields
Expand Down
2 changes: 1 addition & 1 deletion src/argilla/server/errors/api_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import logging
from typing import Any, Dict

from fastapi import HTTPException, Request, status
from fastapi import HTTPException, Request
from fastapi.exception_handlers import http_exception_handler
from pydantic import BaseModel

Expand Down
2 changes: 1 addition & 1 deletion src/argilla/server/errors/base_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Dict, Optional, Type, Union
from typing import Any, Optional, Type, Union

import pydantic
from starlette import status
Expand Down
1 change: 0 additions & 1 deletion src/argilla/server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
"""
This module configures the global fastapi application
"""
import fileinput
import glob
import inspect
import logging
Expand Down
Loading

0 comments on commit f219acb

Please sign in to comment.