diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 0fd92e4..500aa66 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -14,13 +14,25 @@ defaults: jobs: linux-unittests: - name: "Unit tests - Python ${{ matrix.PYTHON_VERSION }}" + name: >- + Unit tests - Python ${{ matrix.PYTHON_VERSION }} + ${{ matrix.SKLEARN_VERSION }}${{ matrix.LGBM_VERSION }} timeout-minutes: 15 runs-on: ubuntu-latest strategy: fail-fast: false matrix: - PYTHON_VERSION: ['3.8', '3.9', '3.10', '3.11'] + include: + - { PYTHON_VERSION: '3.8', SKLEARN_VERSION: 'scikit-learn=1.1', LGBM_VERSION: '' } + - { PYTHON_VERSION: '3.8', SKLEARN_VERSION: 'scikit-learn=1.2', LGBM_VERSION: '' } +# - { PYTHON_VERSION: '3.8', SKLEARN_VERSION: 'scikit-learn=1.3', LGBM_VERSION: '' } + - { PYTHON_VERSION: '3.8', SKLEARN_VERSION: '', LGBM_VERSION: 'lightgbm=3.2' } + - { PYTHON_VERSION: '3.8', SKLEARN_VERSION: '', LGBM_VERSION: 'lightgbm=3.3' } + - { PYTHON_VERSION: '3.8', SKLEARN_VERSION: '', LGBM_VERSION: 'lightgbm=4.0' } + - { PYTHON_VERSION: '3.8', SKLEARN_VERSION: '', LGBM_VERSION: '' } + - { PYTHON_VERSION: '3.9', SKLEARN_VERSION: '', LGBM_VERSION: '' } + - { PYTHON_VERSION: '3.10', SKLEARN_VERSION: '', LGBM_VERSION: '' } + - { PYTHON_VERSION: '3.11', SKLEARN_VERSION: '', LGBM_VERSION: '' } steps: - uses: actions/checkout@v3 - name: Set up conda env @@ -32,12 +44,16 @@ jobs: python=${{ matrix.PYTHON_VERSION }} pytest-md pytest-emoji + ${{ matrix.SKLEARN_VERSION }}${{ matrix.LGBM_VERSION }} - name: Install repository run: python -m pip install --no-build-isolation --no-deps --disable-pip-version-check -e . - name: Run unittests uses: pavelzw/pytest-action@v2 with: - report-title: "Unit tests Linux - Python ${{ matrix.PYTHON_VERSION }}" + custom-arguments: ${{ matrix.SKLEARN_VERSION != '' && '-k sklearn' || '' }}${{ matrix.LGBM_VERSION != '' && ' -k lgbm' || '' }} + report-title: >- + Unit tests - Python ${{ matrix.PYTHON_VERSION }} + ${{ matrix.SKLEARN_VERSION }}${{ matrix.LGBM_VERSION }} linux-unittests-pixi: name: "Unit tests Pixi - Python ${{ matrix.PYTHON_VERSION }}" diff --git a/environment-deprecated.yml b/environment-deprecated.yml index d8f685a..b6d283c 100644 --- a/environment-deprecated.yml +++ b/environment-deprecated.yml @@ -4,7 +4,7 @@ channels: - conda-forge - nodefaults dependencies: - - lightgbm <4.0 + - lightgbm >=3.2,<4.1 - numpy - python>=3.8 - pre-commit diff --git a/pixi.toml b/pixi.toml index 74e0cff..f0b0f3d 100644 --- a/pixi.toml +++ b/pixi.toml @@ -13,7 +13,7 @@ lint = "pre-commit run --all" [dependencies] python = ">=3.8" pip = "*" -lightgbm = "<4.0" +lightgbm = ">=3.2,<4.1" numpy = "*" pre-commit = "*" pandas = "*" diff --git a/pyproject.toml b/pyproject.toml index eac76b8..ba6014f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,7 +25,7 @@ dependencies = [ [project.optional-dependencies] lightgbm = [ - "lightgbm <4.0", + "lightgbm >=3.2,<4.1", ] scikit-learn = [ "scikit-learn <1.3.0", diff --git a/slim_trees/lgbm_booster.py b/slim_trees/lgbm_booster.py index bf1ab74..3147983 100644 --- a/slim_trees/lgbm_booster.py +++ b/slim_trees/lgbm_booster.py @@ -7,6 +7,7 @@ from typing import Any, BinaryIO, List, Tuple import numpy as np +from packaging.version import Version from slim_trees import __version__ as slim_trees_version from slim_trees.compression_utils import ( @@ -17,7 +18,10 @@ from slim_trees.utils import check_version try: + from lightgbm import __version__ as _lightgbm_version from lightgbm.basic import Booster + + lightgbm_version = Version(_lightgbm_version) except ImportError: print("LightGBM does not seem to be installed.") sys.exit(os.EX_CONFIG) @@ -57,21 +61,30 @@ def _booster_unpickle(reconstructor, args, compressed_state): return booster +_handle_key_name = ( + "_handle" if lightgbm_version.major == 4 else "handle" # noqa[PLR2004] +) + + def _compress_booster_state(state: dict): """ For a given state dictionary, store data in a structured format that can then be saved to disk in a way that can be compressed. """ assert type(state) == dict - compressed_state = {k: v for k, v in state.items() if k != "handle"} - compressed_state["compressed_handle"] = _compress_booster_handle(state["handle"]) + compressed_state = {k: v for k, v in state.items() if k != _handle_key_name} + compressed_state["compressed_handle"] = _compress_booster_handle( + state[_handle_key_name] + ) return compressed_state def _decompress_booster_state(compressed_state: dict): assert type(compressed_state) == dict state = {k: v for k, v in compressed_state.items() if k != "compressed_handle"} - state["handle"] = _decompress_booster_handle(compressed_state["compressed_handle"]) + state[_handle_key_name] = _decompress_booster_handle( + compressed_state["compressed_handle"] + ) return state @@ -121,8 +134,10 @@ def parse(str_list, dtype): def _compress_booster_handle(model_string: str) -> Tuple[str, List[dict], str]: - if not model_string.startswith("tree\nversion=v3"): - raise ValueError("Only v3 is supported for the booster string format.") + if not model_string.startswith(f"tree\nversion=v{lightgbm_version.major}"): + raise ValueError( + f"Only v{lightgbm_version.major} is supported for the booster string format." + ) front_str_match = re.search(FRONT_STRING_REGEX, model_string) if front_str_match is None: