ci: Replace isort by ruff in pre-commit (#2325)
Hello!

# Pull Request overview
* Replaced `isort` with `ruff` in `pre-commit`
  * This fixed the import sorting throughout the repo.
* Manually went through `ruff` errors and fixed a bunch:
  * Unused imports
  * Added `if TYPE_CHECKING:` statements
* Placed `# noqa: <errcode>` where the warned-about statement is the
desired behaviour
* Added basic `ruff` configuration to `pyproject.toml`

## Details
This PR focuses on replacing `isort` with `ruff` in `pre-commit`. The
motivation for this is that:
* `isort` frequently breaks. Twice in the last few months alone, the
latest `isort` release has broken my CI runs in NLTK and SetFit.
* `isort` no longer supports Python 3.7, whereas Argilla still
supports 3.7 for now.
* `ruff` is absurdly fast; I actually can't believe how quick it is.

This PR consists of 3 commits at this time, and I would advise looking
at them commit-by-commit rather than at the PR as a whole. I'll also
explain each commit individually.

## [Add ruff basic configuration](497420e)
I've added basic configuration for
[`ruff`](https://github.com/charliermarsh/ruff), a very efficient
linter. I recommend the following commands:
```
# Report all Pyflakes (F) and pycodestyle (E) violations, ruff's default rules
ruff .

# See all import sort errors
ruff . --select I

# Fix all import sort errors
ruff . --select I --fix
```

## [Remove unused imports, apply some noqa's, add TYPE_CHECKING](f219acb)
Removing the unused imports speaks for itself.

As for the `noqa`s, `ruff` (like most linters) respects the `# noqa` (no
quality assurance) comment. I've used it in various locations where
linters would warn, but the behaviour is actually intended. As a
result, the output of `ruff .` now only points to questionable code.
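
As a minimal sketch of the `# noqa` usage described above: an import that
exists only to probe whether a package is available is never used, so it
can be exempted from the unused-import rule `F401` on that line alone. The
function name is hypothetical, and the standard-library `json` module
stands in for an optional dependency such as `spacy`.

```python
def optional_dependency_available() -> bool:
    """Return True if the (stand-in) optional dependency can be imported."""
    try:
        # The import only checks availability, so the name is never used.
        # The trailing comment silences F401 for this line and nothing else.
        import json  # noqa: F401  (stand-in for e.g. spacy; an assumption)
    except ModuleNotFoundError:
        return False
    return True


print(optional_dependency_available())
```

Scoping the comment to a single error code keeps every other rule active on
that line, which is why `# noqa: <errcode>` is preferred over a bare `# noqa`.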

Lastly, I added `TYPE_CHECKING` in some locations. If type hints refer
to objects that do not need to be imported at run-time, it's common to
write the hint as a string, like `arr: "numpy.ndarray"`. However, IDEs
won't understand what `arr` is unless the name is imported somewhere.
The `typing` module provides `TYPE_CHECKING`, a constant that is `False`
at run-time and treated as `True` by static type checkers, so it can be
used to import code *only* when type checking. As a result, the import
is not actually executed in practice, but including it allows IDEs to
better support development.
See an example here:
```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    import numpy

def func(arr: "numpy.ndarray") -> None:
    ...
```
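
To make the run-time behaviour concrete, here is a small sketch under stated
assumptions: `heavy_optional_lib` is a hypothetical module name that is not
installed, and `first_item` is an illustrative function. The script still
runs, because the guarded import is skipped and the string annotation is
never evaluated.

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Hypothetical module, used only for illustration. This branch is
    # skipped at run-time, so the module does not need to be installed.
    import heavy_optional_lib


def first_item(values: "heavy_optional_lib.Array") -> object:
    """Return the first element; the string annotation is not evaluated."""
    return values[0]


print(TYPE_CHECKING)  # False at run-time
print(first_item([10, 20, 30]))  # 10
print(first_item.__annotations__["values"])  # the unevaluated string
```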

## [Replace isort with ruff in CI](e05f30e)
I've replaced `isort` (which both [was broken in
5.11.*](PyCQA/isort#2077) and [does not work
on Python 3.7 in
5.12.*](https://github.com/PyCQA/isort/releases/tag/5.12.0)) with `ruff`
in the CI, using `--select I` to select only the import-sorting rules
and `--fix` to immediately fix the warnings.

Then I ran `pre-commit run --all` to fix the ~67 outstanding issues in
the repository.

---

**Type of change**

- [x] Refactor (change restructuring the codebase without changing
functionality)

**How Has This Been Tested**

I verified that the behaviour did not change using `pytest tests`.

**Checklist**

- [x] I have merged the original branch into my forked branch
- [ ] I added relevant documentation
- [x] My code follows the style guidelines of this project
- [x] I did a self-review of my code
- [ ] I added comments to my code
- [ ] I made corresponding changes to the documentation
- [ ] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my
feature works

- Tom Aarsen

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
tomaarsen and pre-commit-ci[bot] authored Feb 14, 2023
1 parent ff4e79f commit 5f0627c
Showing 97 changed files with 139 additions and 182 deletions.
10 changes: 7 additions & 3 deletions .pre-commit-config.yaml
@@ -25,10 +25,14 @@ repos:
- id: black
additional_dependencies: ["click==8.0.4"]

- repo: https://github.com/pycqa/isort
rev: 5.12.0
- repo: https://github.com/charliermarsh/ruff-pre-commit
rev: v0.0.244
hooks:
- id: isort
# Simulate isort via (the much faster) ruff
- id: ruff
args:
- --select=I
- --fix

- repo: https://github.com/alessandrojcm/commitlint-pre-commit-hook
rev: v9.4.0
31 changes: 31 additions & 0 deletions pyproject.toml
@@ -108,3 +108,34 @@ exclude_lines = [

[tool.isort]
profile = "black"

[tool.ruff]
# Ignore line length violations
ignore = ["E501"]

# Exclude a variety of commonly ignored directories.
exclude = [
".bzr",
".direnv",
".eggs",
".git",
".hg",
".mypy_cache",
".nox",
".pants.d",
".ruff_cache",
".svn",
".tox",
".venv",
"__pypackages__",
"_build",
"buck-out",
"build",
"dist",
"node_modules",
"venv",
]

[tool.ruff.per-file-ignores]
# Ignore imported but unused;
"__init__.py" = ["F401"]
5 changes: 2 additions & 3 deletions scripts/load_data.py
@@ -15,12 +15,11 @@
import sys
import time

import argilla as rg
import pandas as pd
import requests
from datasets import load_dataset

import argilla as rg
from argilla.labeling.text_classification import Rule, add_rules
from datasets import load_dataset


class LoadDatasets:
4 changes: 1 addition & 3 deletions src/argilla/__init__.py
@@ -47,12 +47,10 @@
read_datasets,
read_pandas,
)
from argilla.client.models import (
TextGenerationRecord, # TODO Remove TextGenerationRecord
)
from argilla.client.models import (
Text2TextRecord,
TextClassificationRecord,
TextGenerationRecord, # TODO Remove TextGenerationRecord
TokenAttributions,
TokenClassificationRecord,
)
2 changes: 1 addition & 1 deletion src/argilla/client/apis/datasets.py
@@ -15,7 +15,7 @@
import warnings
from dataclasses import dataclass
from datetime import datetime
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union

from pydantic import BaseModel, Field

2 changes: 1 addition & 1 deletion src/argilla/client/apis/search.py
@@ -13,7 +13,7 @@
# limitations under the License.

import dataclasses
from typing import List, Optional, Union
from typing import List, Optional

from argilla.client.apis import AbstractApi
from argilla.client.models import Record
11 changes: 8 additions & 3 deletions src/argilla/client/datasets.py
@@ -16,7 +16,7 @@
import logging
import random
import uuid
from typing import Any, Dict, List, Optional, Tuple, Type, Union
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union

import pandas as pd
from pkg_resources import parse_version
@@ -32,6 +32,11 @@
from argilla.client.sdk.datasets.models import TaskType
from argilla.utils.span_utils import SpanUtils

if TYPE_CHECKING:
import datasets
import pandas
import spacy

_LOGGER = logging.getLogger(__name__)


@@ -60,7 +65,7 @@ def _requires_spacy(func):
@functools.wraps(func)
def check_if_spacy_installed(*args, **kwargs):
try:
import spacy
import spacy # noqa: F401
except ModuleNotFoundError:
raise ModuleNotFoundError(
f"'spacy' must be installed to use `{func.__name__}`"
@@ -1007,7 +1012,7 @@ def from_datasets(
for row in dataset:
# TODO: fails with a KeyError if no tokens column is present and no mapping is indicated
if not row["tokens"]:
_LOGGER.warning(f"Ignoring row with no tokens.")
_LOGGER.warning("Ignoring row with no tokens.")
continue

if row.get("tags"):
4 changes: 2 additions & 2 deletions src/argilla/client/sdk/client.py
@@ -119,15 +119,15 @@ async def inner_async(self, *args, **kwargs):
try:
result = await func(self, *args, **kwargs)
return result
except httpx.ConnectError as err:
except httpx.ConnectError as err: # noqa: F841
return wrap_error(self.base_url)

@functools.wraps(func)
def inner(self, *args, **kwargs):
try:
result = func(self, *args, **kwargs)
return result
except httpx.ConnectError as err:
except httpx.ConnectError as err: # noqa: F841
return wrap_error(self.base_url)

is_coroutine = inspect.iscoroutinefunction(func)
2 changes: 1 addition & 1 deletion src/argilla/client/sdk/commons/api.py
@@ -126,7 +126,7 @@ def build_data_response(
parsed_record = json.loads(r)
try:
parsed_response = data_type(**parsed_record)
except Exception as err:
except Exception as err: # noqa: F841
raise GenericApiError(**parsed_record) from None
parsed_responses.append(parsed_response)
return Response(
1 change: 0 additions & 1 deletion src/argilla/client/sdk/text2text/api.py
@@ -33,7 +33,6 @@ def data(
limit: Optional[int] = None,
id_from: Optional[str] = None,
) -> Response[Union[List[Text2TextRecord], HTTPValidationError, ErrorMessage]]:

path = f"/api/datasets/{name}/Text2Text/data"
params = build_param_dict(id_from, limit)

2 changes: 1 addition & 1 deletion src/argilla/client/sdk/text_classification/api.py
@@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, List, Optional, Union
from typing import List, Optional, Union

import httpx

1 change: 0 additions & 1 deletion src/argilla/client/sdk/token_classification/api.py
@@ -39,7 +39,6 @@ def data(
) -> Response[
Union[List[TokenClassificationRecord], HTTPValidationError, ErrorMessage]
]:

path = f"/api/datasets/{name}/TokenClassification/data"
params = build_param_dict(id_from, limit)

3 changes: 0 additions & 3 deletions src/argilla/client/sdk/users/api.py
@@ -12,10 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import httpx

from argilla.client.sdk.client import AuthenticatedClient
from argilla.client.sdk.commons.errors_handler import handle_response_error
from argilla.client.sdk.users.models import User


11 changes: 5 additions & 6 deletions src/argilla/labeling/text_classification/label_models.py
@@ -20,7 +20,6 @@
import numpy as np

from argilla import DatasetForTextClassification, TextClassificationRecord
from argilla.client.datasets import Dataset
from argilla.labeling.text_classification.weak_labels import WeakLabels, WeakMultiLabels

_LOGGER = logging.getLogger(__name__)
@@ -368,7 +367,7 @@ def score(
MissingAnnotationError: If the ``weak_labels`` do not contain annotated records.
"""
try:
import sklearn
import sklearn # noqa: F401
except ModuleNotFoundError:
raise ModuleNotFoundError(
"'sklearn' must be installed to compute the metrics! "
@@ -501,7 +500,7 @@ def __init__(
self, weak_labels: WeakLabels, verbose: bool = True, device: str = "cpu"
):
try:
import snorkel
import snorkel # noqa: F401
except ModuleNotFoundError:
raise ModuleNotFoundError(
"'snorkel' must be installed to use the `Snorkel` label model! "
@@ -764,8 +763,8 @@ class FlyingSquid(LabelModel):

def __init__(self, weak_labels: WeakLabels, **kwargs):
try:
import flyingsquid
import pgmpy
import flyingsquid # noqa: F401
import pgmpy # noqa: F401
except ModuleNotFoundError:
raise ModuleNotFoundError(
"'flyingsquid' must be installed to use the `FlyingSquid` label model!"
@@ -1024,7 +1023,7 @@ def score(
MissingAnnotationError: If the ``weak_labels`` do not contain annotated records.
"""
try:
import sklearn
import sklearn # noqa: F401
except ModuleNotFoundError:
raise ModuleNotFoundError(
"'sklearn' must be installed to compute the metrics! "
6 changes: 3 additions & 3 deletions src/argilla/listeners/listener.py
@@ -98,7 +98,7 @@ def catch_exceptions_decorator(job_func):
def wrapper(*args, **kwargs):
try:
return job_func(*args, **kwargs)
except:
except: # noqa: E722
import traceback

print(traceback.format_exc())
@@ -208,7 +208,7 @@ def __listener_iteration_job__(self, *args, **kwargs):

self._LOGGER.debug(f"Evaluate condition with arguments: {condition_args}")
if self.condition(*condition_args):
self._LOGGER.debug(f"Condition passed! Running action...")
self._LOGGER.debug("Condition passed! Running action...")
return self.__run_action__(ctx, *args, **kwargs)

def __compute_metrics__(self, current_api, dataset, query: str) -> Metrics:
@@ -235,7 +235,7 @@ def __run_action__(self, ctx: Optional[RGListenerContext] = None, *args, **kwarg
)
self._LOGGER.debug(f"Running action with arguments: {action_args}")
return self.action(*args, *action_args, **kwargs)
except:
except: # noqa: E722
import traceback

print(traceback.format_exc())
5 changes: 4 additions & 1 deletion src/argilla/listeners/models.py
@@ -13,12 +13,15 @@
# limitations under the License.

import dataclasses
from typing import Any, Callable, Dict, List, Optional, Union
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union

from prodict import Prodict

from argilla.client.models import Record

if TYPE_CHECKING:
from argilla.listeners import RGDatasetListener


@dataclasses.dataclass
class Search:
4 changes: 1 addition & 3 deletions src/argilla/metrics/__init__.py
@@ -19,15 +19,13 @@
entity_consistency,
entity_density,
entity_labels,
)
from .token_classification import f1 as ner_f1
from .token_classification import (
mention_length,
token_capitalness,
token_frequency,
token_length,
tokens_length,
)
from .token_classification import f1 as ner_f1

__all__ = [
text_length,
2 changes: 1 addition & 1 deletion src/argilla/monitoring/asgi.py
@@ -21,7 +21,7 @@
from argilla.monitoring.base import BaseMonitor

try:
import starlette
import starlette # noqa: F401
except ModuleNotFoundError:
raise ModuleNotFoundError(
"'starlette' must be installed to use the middleware feature! "
1 change: 0 additions & 1 deletion src/argilla/monitoring/base.py
@@ -13,7 +13,6 @@
# limitations under the License.

import atexit
import dataclasses
import logging
import random
import threading
2 changes: 1 addition & 1 deletion src/argilla/server/apis/v0/handlers/datasets.py
@@ -64,7 +64,7 @@ async def list_datasets(
description="Create a new dataset",
)
async def create_dataset(
request: CreateDatasetRequest = Body(..., description=f"The request dataset info"),
request: CreateDatasetRequest = Body(..., description="The request dataset info"),
ws_params: CommonTaskHandlerDependencies = Depends(),
datasets: DatasetsService = Depends(DatasetsService.get_instance),
user: User = Security(auth.get_user, scopes=["create:datasets"]),
4 changes: 2 additions & 2 deletions src/argilla/server/apis/v0/handlers/metrics.py
@@ -58,7 +58,7 @@ def configure_router(router: APIRouter, cfg: TaskConfig):
path=base_metrics_endpoint,
new_path=new_base_metrics_endpoint,
router_method=router.get,
operation_id=f"get_dataset_metrics",
operation_id="get_dataset_metrics",
name="get_dataset_metrics",
)
def get_dataset_metrics(
@@ -84,7 +84,7 @@ def get_dataset_metrics(
path=base_metrics_endpoint + "/{metric}:summary",
new_path=new_base_metrics_endpoint + "/{metric}:summary",
router_method=router.post,
operation_id=f"metric_summary",
operation_id="metric_summary",
name="metric_summary",
)
def metric_summary(
11 changes: 3 additions & 8 deletions src/argilla/server/apis/v0/handlers/records_update.py
@@ -14,17 +14,12 @@

from typing import Any, Dict, Optional, Union

from fastapi import APIRouter, Depends, Query, Security
from fastapi import APIRouter, Depends, Security
from pydantic import BaseModel

from argilla.client.sdk.token_classification.models import TokenClassificationQuery
from argilla.server.apis.v0.helpers import deprecate_endpoint
from argilla.server.apis.v0.models.commons.params import CommonTaskHandlerDependencies
from argilla.server.apis.v0.models.text2text import Text2TextQuery, Text2TextRecord
from argilla.server.apis.v0.models.text_classification import (
TextClassificationQuery,
TextClassificationRecord,
)
from argilla.server.apis.v0.models.text2text import Text2TextRecord
from argilla.server.apis.v0.models.text_classification import TextClassificationRecord
from argilla.server.apis.v0.models.token_classification import TokenClassificationRecord
from argilla.server.commons.config import TasksFactory
from argilla.server.commons.models import TaskStatus
@@ -23,7 +23,6 @@
metrics,
token_classification_dataset_settings,
)
from argilla.server.apis.v0.helpers import deprecate_endpoint
from argilla.server.apis.v0.models.commons.model import BulkResponse
from argilla.server.apis.v0.models.commons.params import (
CommonTaskHandlerDependencies,
3 changes: 1 addition & 2 deletions src/argilla/server/apis/v0/models/text2text.py
@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from datetime import datetime
from typing import Dict, List, Optional

from pydantic import BaseModel, Field, validator
@@ -27,7 +26,7 @@
SortableField,
)
from argilla.server.apis.v0.models.datasets import UpdateDatasetRequest
from argilla.server.commons.models import PredictionStatus, TaskType
from argilla.server.commons.models import PredictionStatus
from argilla.server.services.metrics.models import CommonTasksMetrics
from argilla.server.services.search.model import (
ServiceBaseRecordsQuery,