Skip to content

Commit

Permalink
Add support for lightgbm 4 (#75)
Browse files Browse the repository at this point in the history
  • Loading branch information
pavelzw authored Jul 27, 2023
1 parent 010b4de commit 56108c5
Show file tree
Hide file tree
Showing 5 changed files with 42 additions and 11 deletions.
22 changes: 19 additions & 3 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,25 @@ defaults:

jobs:
linux-unittests:
name: "Unit tests - Python ${{ matrix.PYTHON_VERSION }}"
name: >-
Unit tests - Python ${{ matrix.PYTHON_VERSION }}
${{ matrix.SKLEARN_VERSION }}${{ matrix.LGBM_VERSION }}
timeout-minutes: 15
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
PYTHON_VERSION: ['3.8', '3.9', '3.10', '3.11']
include:
- { PYTHON_VERSION: '3.8', SKLEARN_VERSION: 'scikit-learn=1.1', LGBM_VERSION: '' }
- { PYTHON_VERSION: '3.8', SKLEARN_VERSION: 'scikit-learn=1.2', LGBM_VERSION: '' }
# - { PYTHON_VERSION: '3.8', SKLEARN_VERSION: 'scikit-learn=1.3', LGBM_VERSION: '' }
- { PYTHON_VERSION: '3.8', SKLEARN_VERSION: '', LGBM_VERSION: 'lightgbm=3.2' }
- { PYTHON_VERSION: '3.8', SKLEARN_VERSION: '', LGBM_VERSION: 'lightgbm=3.3' }
- { PYTHON_VERSION: '3.8', SKLEARN_VERSION: '', LGBM_VERSION: 'lightgbm=4.0' }
- { PYTHON_VERSION: '3.8', SKLEARN_VERSION: '', LGBM_VERSION: '' }
- { PYTHON_VERSION: '3.9', SKLEARN_VERSION: '', LGBM_VERSION: '' }
- { PYTHON_VERSION: '3.10', SKLEARN_VERSION: '', LGBM_VERSION: '' }
- { PYTHON_VERSION: '3.11', SKLEARN_VERSION: '', LGBM_VERSION: '' }
steps:
- uses: actions/checkout@v3
- name: Set up conda env
Expand All @@ -32,12 +44,16 @@ jobs:
python=${{ matrix.PYTHON_VERSION }}
pytest-md
pytest-emoji
${{ matrix.SKLEARN_VERSION }}${{ matrix.LGBM_VERSION }}
- name: Install repository
run: python -m pip install --no-build-isolation --no-deps --disable-pip-version-check -e .
- name: Run unittests
uses: pavelzw/pytest-action@v2
with:
report-title: "Unit tests Linux - Python ${{ matrix.PYTHON_VERSION }}"
custom-arguments: ${{ matrix.SKLEARN_VERSION != '' && '-k sklearn' || '' }}${{ matrix.LGBM_VERSION != '' && ' -k lgbm' || '' }}
report-title: >-
Unit tests - Python ${{ matrix.PYTHON_VERSION }}
${{ matrix.SKLEARN_VERSION }}${{ matrix.LGBM_VERSION }}
linux-unittests-pixi:
name: "Unit tests Pixi - Python ${{ matrix.PYTHON_VERSION }}"
Expand Down
2 changes: 1 addition & 1 deletion environment-deprecated.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ channels:
- conda-forge
- nodefaults
dependencies:
- lightgbm <4.0
- lightgbm >=3.2,<4.1
- numpy
- python>=3.8
- pre-commit
Expand Down
2 changes: 1 addition & 1 deletion pixi.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ lint = "pre-commit run --all"
[dependencies]
python = ">=3.8"
pip = "*"
lightgbm = "<4.0"
lightgbm = ">=3.2,<4.1"
numpy = "*"
pre-commit = "*"
pandas = "*"
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ dependencies = [

[project.optional-dependencies]
lightgbm = [
"lightgbm <4.0",
"lightgbm >=3.2,<4.1",
]
scikit-learn = [
"scikit-learn <1.3.0",
Expand Down
25 changes: 20 additions & 5 deletions slim_trees/lgbm_booster.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from typing import Any, BinaryIO, List, Tuple

import numpy as np
from packaging.version import Version

from slim_trees import __version__ as slim_trees_version
from slim_trees.compression_utils import (
Expand All @@ -17,7 +18,10 @@
from slim_trees.utils import check_version

# LightGBM is an optional dependency: import it up front and keep its parsed
# version around for the version-dependent logic further down in this module.
try:
    from lightgbm import __version__ as _lightgbm_version
    from lightgbm.basic import Booster

    lightgbm_version = Version(_lightgbm_version)
except ImportError:
    print("LightGBM does not seem to be installed.")
    # os.EX_CONFIG is only defined on Unix; fall back to a generic failure
    # code elsewhere so that exiting does not itself raise AttributeError.
    sys.exit(getattr(os, "EX_CONFIG", 1))
Expand Down Expand Up @@ -57,21 +61,30 @@ def _booster_unpickle(reconstructor, args, compressed_state):
return booster


# Name of the Booster state entry that holds the raw model string: "_handle"
# for LightGBM 4, "handle" for earlier releases. Compared with >= so that a
# later major version does not silently fall back to the legacy name
# (assumes the key is not renamed again -- TODO confirm on the next major).
_handle_key_name = (
    "_handle" if lightgbm_version.major >= 4 else "handle"  # noqa: PLR2004
)


def _compress_booster_state(state: dict):
    """
    For a given state dictionary, store data in a structured format that can then
    be saved to disk in a way that can be compressed.

    Every entry is copied unchanged except the one stored under
    ``_handle_key_name``: that value is passed through
    ``_compress_booster_handle`` and the result is stored under the key
    ``"compressed_handle"`` instead.

    :param state: state dictionary containing a ``_handle_key_name`` entry.
    :return: a new dictionary; ``state`` itself is not modified.
    :raises KeyError: if ``state`` has no ``_handle_key_name`` entry.
    """
    assert isinstance(state, dict)
    # Use the version-dependent key name only; a hard-coded "handle" lookup
    # would raise KeyError whenever the state uses the "_handle" key.
    compressed_state = {k: v for k, v in state.items() if k != _handle_key_name}
    compressed_state["compressed_handle"] = _compress_booster_handle(
        state[_handle_key_name]
    )
    return compressed_state


def _decompress_booster_state(compressed_state: dict):
    """
    Rebuild a state dictionary from the output of ``_compress_booster_state``.

    Every entry is copied unchanged except ``"compressed_handle"``: that value
    is passed through ``_decompress_booster_handle`` and the result is stored
    under ``_handle_key_name``.

    :param compressed_state: dictionary containing a ``"compressed_handle"``
        entry.
    :return: a new dictionary; ``compressed_state`` itself is not modified.
    :raises KeyError: if ``compressed_state`` has no ``"compressed_handle"``
        entry.
    """
    assert isinstance(compressed_state, dict)
    state = {k: v for k, v in compressed_state.items() if k != "compressed_handle"}
    # Write the handle back under the version-dependent key only, so no stray
    # legacy "handle" entry is left in the restored state.
    state[_handle_key_name] = _decompress_booster_handle(
        compressed_state["compressed_handle"]
    )
    return state


Expand Down Expand Up @@ -121,8 +134,10 @@ def parse(str_list, dtype):


def _compress_booster_handle(model_string: str) -> Tuple[str, List[dict], str]:
if not model_string.startswith("tree\nversion=v3"):
raise ValueError("Only v3 is supported for the booster string format.")
if not model_string.startswith(f"tree\nversion=v{lightgbm_version.major}"):
raise ValueError(
f"Only v{lightgbm_version.major} is supported for the booster string format."
)

front_str_match = re.search(FRONT_STRING_REGEX, model_string)
if front_str_match is None:
Expand Down

0 comments on commit 56108c5

Please sign in to comment.