fix: Chipping away at ruff lints (#1303)
* fix: Chipping away at ruff lints

* fix: return lockfile path

* Update openml/config.py

* Update openml/runs/functions.py

* Update openml/tasks/functions.py

* Update openml/tasks/split.py

* Update openml/utils.py

* Update openml/utils.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update openml/config.py

* Update openml/testing.py

* Update openml/utils.py

* Update openml/config.py

* Update openml/utils.py

* Update openml/utils.py

* add concurrency to workflow calls

* adjust docstring

* adjust docstring

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: Lennart Purucker <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Lennart Purucker <[email protected]>
4 people committed Jan 18, 2024
1 parent ee1231b commit b217a3f
Showing 14 changed files with 446 additions and 345 deletions.
4 changes: 4 additions & 0 deletions .github/workflows/test.yml
@@ -2,6 +2,10 @@ name: Tests

on: [push, pull_request]

concurrency:
  group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
  cancel-in-progress: true

jobs:
  test:
    name: (${{ matrix.os }}, Py${{ matrix.python-version }}, sk${{ matrix.scikit-learn }}, sk-only:${{ matrix.sklearn-only }})
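The concurrency block above makes a new push cancel any workflow run still in flight for the same pull request (or, for direct pushes, the same ref). A minimal Python sketch of how that group key resolves; the concurrency_group helper and the example values are purely illustrative, not part of the workflow:

def concurrency_group(workflow: str, pr_number: int | None, ref: str) -> str:
    # Mirrors `${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}`:
    # GitHub's `||` falls back to the ref when the event carries no PR number (e.g. a push).
    return f"{workflow}-{pr_number if pr_number is not None else ref}"

print(concurrency_group("Tests", 1303, "refs/pull/1303/merge"))  # Tests-1303
print(concurrency_group("Tests", None, "refs/heads/main"))       # Tests-refs/heads/main

With cancel-in-progress: true, an older run in the same group is cancelled as soon as a newer one starts.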
3 changes: 2 additions & 1 deletion openml/__init__.py
@@ -117,4 +117,5 @@ def populate_cache(task_ids=None, dataset_ids=None, flow_ids=None, run_ids=None)
]

# Load the scikit-learn extension by default
import openml.extensions.sklearn # noqa: F401
# TODO(eddiebergman): Not sure why this is at the bottom of the file
import openml.extensions.sklearn # noqa: E402, F401
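For context on the added E402 suppression: the import sits at the bottom of the file (hence the module-level-import-not-at-top warning), and importing openml.extensions.sklearn registers the extension as a side effect via register_extension(SklearnExtension), so a plain import openml is enough for scikit-learn support. A minimal sketch, assuming scikit-learn is installed:

import openml  # the bottom-of-file import above runs here and registers SklearnExtension
from openml.extensions.sklearn import SklearnExtension

# Nothing else is needed before converting scikit-learn estimators to OpenML flows.
print(SklearnExtension.__name__)  # SklearnExtension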
103 changes: 52 additions & 51 deletions openml/config.py
@@ -12,17 +12,18 @@
from io import StringIO
from pathlib import Path
from typing import Dict, Union, cast
from typing_extensions import Literal
from urllib.parse import urlparse

logger = logging.getLogger(__name__)
openml_logger = logging.getLogger("openml")
console_handler = None
file_handler = None # type: Optional[logging.Handler]
console_handler: logging.StreamHandler | None = None
file_handler: logging.handlers.RotatingFileHandler | None = None


def _create_log_handlers(create_file_handler: bool = True) -> None:
def _create_log_handlers(create_file_handler: bool = True) -> None: # noqa: FBT
"""Creates but does not attach the log handlers."""
global console_handler, file_handler
global console_handler, file_handler # noqa: PLW0603
if console_handler is not None or file_handler is not None:
logger.debug("Requested to create log handlers, but they are already created.")
return
@@ -35,7 +36,7 @@ def _create_log_handlers(create_file_handler: bool = True) -> None:

if create_file_handler:
one_mb = 2**20
log_path = os.path.join(_root_cache_directory, "openml_python.log")
log_path = _root_cache_directory / "openml_python.log"
file_handler = logging.handlers.RotatingFileHandler(
log_path,
maxBytes=one_mb,
@@ -64,7 +65,7 @@ def _convert_log_levels(log_level: int) -> tuple[int, int]:

def _set_level_register_and_store(handler: logging.Handler, log_level: int) -> None:
"""Set handler log level, register it if needed, save setting to config file if specified."""
oml_level, py_level = _convert_log_levels(log_level)
_oml_level, py_level = _convert_log_levels(log_level)
handler.setLevel(py_level)

if openml_logger.level > py_level or openml_logger.level == logging.NOTSET:
@@ -76,31 +77,27 @@ def _set_level_register_and_store(handler: logging.Handler, log_level: int) -> None:

def set_console_log_level(console_output_level: int) -> None:
"""Set console output to the desired level and register it with openml logger if needed."""
global console_handler
_set_level_register_and_store(cast(logging.Handler, console_handler), console_output_level)
global console_handler # noqa: PLW0602
assert console_handler is not None
_set_level_register_and_store(console_handler, console_output_level)


def set_file_log_level(file_output_level: int) -> None:
"""Set file output to the desired level and register it with openml logger if needed."""
global file_handler
_set_level_register_and_store(cast(logging.Handler, file_handler), file_output_level)
global file_handler # noqa: PLW0602
assert file_handler is not None
_set_level_register_and_store(file_handler, file_output_level)


# Default values (see also https://github.com/openml/OpenML/wiki/Client-API-Standards)
_user_path = Path("~").expanduser().absolute()
_defaults = {
"apikey": "",
"server": "https://www.openml.org/api/v1/xml",
"cachedir": (
os.environ.get(
"XDG_CACHE_HOME",
os.path.join(
"~",
".cache",
"openml",
),
)
os.environ.get("XDG_CACHE_HOME", _user_path / ".cache" / "openml")
if platform.system() == "Linux"
else os.path.join("~", ".openml")
else _user_path / ".openml"
),
"avoid_duplicate_runs": "True",
"retry_policy": "human",
@@ -124,18 +121,18 @@ def get_server_base_url() -> str:
return server.split("/api")[0]


apikey = _defaults["apikey"]
apikey: str = _defaults["apikey"]
# The current cache directory (without the server name)
_root_cache_directory = str(_defaults["cachedir"]) # so mypy knows it is a string
avoid_duplicate_runs = _defaults["avoid_duplicate_runs"] == "True"
_root_cache_directory = Path(_defaults["cachedir"])
avoid_duplicate_runs: bool = _defaults["avoid_duplicate_runs"] == "True"

retry_policy = _defaults["retry_policy"]
connection_n_retries = int(_defaults["connection_n_retries"])


def set_retry_policy(value: str, n_retries: int | None = None) -> None:
global retry_policy
global connection_n_retries
def set_retry_policy(value: Literal["human", "robot"], n_retries: int | None = None) -> None:
global retry_policy # noqa: PLW0603
global connection_n_retries # noqa: PLW0603
default_retries_by_policy = {"human": 5, "robot": 50}

if value not in default_retries_by_policy:
@@ -145,6 +142,7 @@ def set_retry_policy(value: str, n_retries: int | None = None) -> None:
)
if n_retries is not None and not isinstance(n_retries, int):
raise TypeError(f"`n_retries` must be of type `int` or `None` but is `{type(n_retries)}`.")

if isinstance(n_retries, int) and n_retries < 1:
raise ValueError(f"`n_retries` is '{n_retries}' but must be positive.")

@@ -168,8 +166,8 @@ def start_using_configuration_for_example(cls) -> None:
To configuration as was before this call is stored, and can be recovered
by using the `stop_use_example_configuration` method.
"""
global server
global apikey
global server # noqa: PLW0603
global apikey # noqa: PLW0603

if cls._start_last_called and server == cls._test_server and apikey == cls._test_apikey:
# Method is called more than once in a row without modifying the server or apikey.
@@ -186,6 +184,7 @@ def start_using_configuration_for_example(cls) -> None:
warnings.warn(
f"Switching to the test server {server} to not upload results to the live server. "
"Using the test server may result in reduced performance of the API!",
stacklevel=2,
)

@classmethod
@@ -199,8 +198,8 @@ def stop_using_configuration_for_example(cls) -> None:
"`start_use_example_configuration` must be called first.",
)

global server
global apikey
global server # noqa: PLW0603
global apikey # noqa: PLW0603

server = cast(str, cls._last_used_server)
apikey = cast(str, cls._last_used_key)
@@ -213,7 +212,7 @@ def determine_config_file_path() -> Path:
else:
config_dir = Path("~") / ".openml"
# Still use os.path.expanduser to trigger the mock in the unit test
config_dir = Path(os.path.expanduser(config_dir))
config_dir = Path(config_dir).expanduser().resolve()
return config_dir / "config"


@@ -226,18 +225,18 @@ def _setup(config: dict[str, str | int | bool] | None = None) -> None:
openml.config.server = SOMESERVER
We could also make it a property but that's less clear.
"""
global apikey
global server
global _root_cache_directory
global avoid_duplicate_runs
global apikey # noqa: PLW0603
global server # noqa: PLW0603
global _root_cache_directory # noqa: PLW0603
global avoid_duplicate_runs # noqa: PLW0603

config_file = determine_config_file_path()
config_dir = config_file.parent

# read config file, create directory for config file
if not os.path.exists(config_dir):
if not config_dir.exists():
try:
os.makedirs(config_dir, exist_ok=True)
config_dir.mkdir(exist_ok=True, parents=True)
cache_exists = True
except PermissionError:
cache_exists = False
@@ -250,20 +249,20 @@ def _setup(config: dict[str, str | int | bool] | None = None) -> None:

avoid_duplicate_runs = bool(config.get("avoid_duplicate_runs"))

apikey = cast(str, config["apikey"])
server = cast(str, config["server"])
short_cache_dir = cast(str, config["cachedir"])
apikey = str(config["apikey"])
server = str(config["server"])
short_cache_dir = Path(config["cachedir"])

tmp_n_retries = config["connection_n_retries"]
n_retries = int(tmp_n_retries) if tmp_n_retries is not None else None

set_retry_policy(cast(str, config["retry_policy"]), n_retries)
set_retry_policy(config["retry_policy"], n_retries)

_root_cache_directory = os.path.expanduser(short_cache_dir)
_root_cache_directory = short_cache_dir.expanduser().resolve()
# create the cache subdirectory
if not os.path.exists(_root_cache_directory):
if not _root_cache_directory.exists():
try:
os.makedirs(_root_cache_directory, exist_ok=True)
_root_cache_directory.mkdir(exist_ok=True, parents=True)
except PermissionError:
openml_logger.warning(
"No permission to create openml cache directory at %s! This can result in "
@@ -288,7 +287,7 @@ def set_field_in_config_file(field: str, value: str) -> None:
globals()[field] = value
config_file = determine_config_file_path()
config = _parse_config(str(config_file))
with open(config_file, "w") as fh:
with config_file.open("w") as fh:
for f in _defaults:
# We can't blindly set all values based on globals() because when the user
# sets it through config.FIELD it should not be stored to file.
@@ -303,14 +302,15 @@ def set_field_in_config_file(field: str, value: str) -> None:

def _parse_config(config_file: str | Path) -> dict[str, str]:
"""Parse the config file, set up defaults."""
config_file = Path(config_file)
config = configparser.RawConfigParser(defaults=_defaults)

# The ConfigParser requires a [SECTION_HEADER], which we do not expect in our config file.
# Cheat the ConfigParser module by adding a fake section header
config_file_ = StringIO()
config_file_.write("[FAKE_SECTION]\n")
try:
with open(config_file) as fh:
with config_file.open("w") as fh:
for line in fh:
config_file_.write(line)
except FileNotFoundError:
@@ -326,13 +326,14 @@ def get_config_as_dict() -> dict[str, str | int | bool]:
config = {} # type: Dict[str, Union[str, int, bool]]
config["apikey"] = apikey
config["server"] = server
config["cachedir"] = _root_cache_directory
config["cachedir"] = str(_root_cache_directory)
config["avoid_duplicate_runs"] = avoid_duplicate_runs
config["connection_n_retries"] = connection_n_retries
config["retry_policy"] = retry_policy
return config


# NOTE: For backwards compatibility, we keep the `str`
def get_cache_directory() -> str:
"""Get the current cache directory.
@@ -354,11 +355,11 @@ def get_cache_directory() -> str:
"""
url_suffix = urlparse(server).netloc
reversed_url_suffix = os.sep.join(url_suffix.split(".")[::-1])
return os.path.join(_root_cache_directory, reversed_url_suffix)
reversed_url_suffix = os.sep.join(url_suffix.split(".")[::-1]) # noqa: PTH118
return os.path.join(_root_cache_directory, reversed_url_suffix) # noqa: PTH118


def set_root_cache_directory(root_cache_directory: str) -> None:
def set_root_cache_directory(root_cache_directory: str | Path) -> None:
"""Set module-wide base cache directory.
Sets the root cache directory, wherin the cache directories are
@@ -377,8 +378,8 @@ def set_root_cache_directory(root_cache_directory: str) -> None:
--------
get_cache_directory
"""
global _root_cache_directory
_root_cache_directory = root_cache_directory
global _root_cache_directory # noqa: PLW0603
_root_cache_directory = Path(root_cache_directory)


start_using_configuration_for_example = (
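Taken together, the config.py changes move internal path handling onto pathlib while keeping the public functions usable with plain strings. A usage sketch of the touched entry points; the cache path and retry values below are illustrative only, not defaults:

from pathlib import Path

import openml

# set_root_cache_directory() now coerces its argument to Path, so both forms work.
openml.config.set_root_cache_directory("/tmp/openml_cache")
openml.config.set_root_cache_directory(Path.home() / ".cache" / "openml")

# get_cache_directory() still returns a str for backwards compatibility.
print(openml.config.get_cache_directory())

# set_retry_policy() accepts "human" or "robot"; other values raise a ValueError.
openml.config.set_retry_policy("robot", n_retries=10)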
10 changes: 8 additions & 2 deletions openml/extensions/sklearn/__init__.py
@@ -1,15 +1,21 @@
# License: BSD 3-Clause
from __future__ import annotations

from typing import TYPE_CHECKING

from openml.extensions import register_extension

from .extension import SklearnExtension

if TYPE_CHECKING:
import pandas as pd

__all__ = ["SklearnExtension"]

register_extension(SklearnExtension)


def cont(X):
def cont(X: pd.DataFrame) -> pd.Series:
"""Returns True for all non-categorical columns, False for the rest.
This is a helper function for OpenML datasets encoded as DataFrames simplifying the handling
@@ -23,7 +29,7 @@ def cont(X):
return X.dtypes != "category"


def cat(X):
def cat(X: pd.DataFrame) -> pd.Series:
"""Returns True for all categorical columns, False for the rest.
This is a helper function for OpenML datasets encoded as DataFrames simplifying the handling
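The newly annotated cont/cat helpers are column selectors for OpenML datasets returned as DataFrames: cont flags non-categorical columns, cat the categorical ones. A small usage sketch (the toy frame below is made up for illustration):

import pandas as pd

from openml.extensions.sklearn import cat, cont

X = pd.DataFrame(
    {
        "age": [23, 45, 31],
        "workclass": pd.Categorical(["private", "gov", "private"]),
    }
)

print(cont(X))  # age: True, workclass: False
print(cat(X))   # age: False, workclass: True

Because both helpers are callables that take X and return a boolean mask, they can be passed directly as column selectors to scikit-learn's ColumnTransformer.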