Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for sklearn 1.3.0 #74

Merged
merged 13 commits into from
Jul 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ jobs:
include:
- { PYTHON_VERSION: '3.8', SKLEARN_VERSION: 'scikit-learn=1.1', LGBM_VERSION: '' }
- { PYTHON_VERSION: '3.8', SKLEARN_VERSION: 'scikit-learn=1.2', LGBM_VERSION: '' }
# - { PYTHON_VERSION: '3.8', SKLEARN_VERSION: 'scikit-learn=1.3', LGBM_VERSION: '' }
- { PYTHON_VERSION: '3.8', SKLEARN_VERSION: 'scikit-learn=1.3', LGBM_VERSION: '' }
- { PYTHON_VERSION: '3.8', SKLEARN_VERSION: '', LGBM_VERSION: 'lightgbm=3.2' }
- { PYTHON_VERSION: '3.8', SKLEARN_VERSION: '', LGBM_VERSION: 'lightgbm=3.3' }
- { PYTHON_VERSION: '3.8', SKLEARN_VERSION: '', LGBM_VERSION: 'lightgbm=4.0' }
Expand Down
6 changes: 3 additions & 3 deletions environment-deprecated.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@ channels:
dependencies:
- lightgbm >=3.2,<4.1
- numpy
- python>=3.8
- python >=3.8
- pre-commit
- pandas
- scikit-learn <1.3.0
- pytest>=7.0
- scikit-learn >=1.1.0,<1.4
- pytest >=7.0
- hatchling
4 changes: 2 additions & 2 deletions pixi.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project] # TODO: move to pyproject.toml once pixi supports it
name = "slim-trees"
version = "0.2.2"
version = "0.2.3"
description = "A python package for efficient pickling of ML models."
authors = ["Pavel Zwerschke <[email protected]>"]
channels = ["conda-forge"]
Expand All @@ -17,7 +17,7 @@ lightgbm = ">=3.2,<4.1"
numpy = "*"
pre-commit = "*"
pandas = "*"
scikit-learn = "<1.3.0"
scikit-learn = ">=1.1.0,<1.4"
hatch = "*"
pytest = ">=7.0"
pytest-md = "*"
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ build-backend = "hatchling.build"
[project]
name = "slim-trees"
description = "A python package for efficient pickling of ML models."
version = "0.2.2"
version = "0.2.3"
readme = "README.md"
license = "MIT"
requires-python = ">=3.8"
Expand All @@ -28,7 +28,7 @@ lightgbm = [
"lightgbm >=3.2,<4.1",
]
scikit-learn = [
"scikit-learn <1.3.0",
"scikit-learn >=1.1.0,<1.4",
]

[project.urls]
Expand Down
64 changes: 48 additions & 16 deletions slim_trees/sklearn_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import os
import sys

from packaging.version import Version

from slim_trees import __version__ as slim_trees_version
from slim_trees.compression_utils import (
compress_half_int_float_array,
Expand All @@ -11,7 +13,11 @@
from slim_trees.utils import check_version

try:
from sklearn import __version__ as _sklearn_version
from sklearn.tree._tree import Tree

sklearn_version = Version(_sklearn_version)
sklearn_version_ge_130 = sklearn_version >= Version("1.3")
except ImportError:
print("scikit-learn does not seem to be installed.")
sys.exit(os.EX_CONFIG)
Expand Down Expand Up @@ -88,35 +94,53 @@ def _compress_tree_state(state: dict):
thresholds = nodes["threshold"][~is_leaf].astype(dtype_threshold)
thresholds = compress_half_int_float_array(thresholds)

if sklearn_version_ge_130:
missing_go_to_left = nodes["missing_go_to_left"][~is_leaf].astype("bool")
else:
missing_go_to_left = None

# TODO: make prettier once python 3.8 is not supported anymore
return {
"max_depth": state["max_depth"],
"node_count": state["node_count"],
"is_leaf": np.packbits(is_leaf),
"children_left": children_left,
"children_right": children_right,
"features": features,
"thresholds": thresholds,
"values": values,
**{
"max_depth": state["max_depth"],
"node_count": state["node_count"],
"is_leaf": np.packbits(is_leaf),
"children_left": children_left,
"children_right": children_right,
"features": features,
"thresholds": thresholds,
"values": values,
},
**(
{"missing_go_to_left": np.packbits(missing_go_to_left)}
if sklearn_version_ge_130
else {}
),
}


def _decompress_tree_state(state: dict):
"""
Decompresses a Tree state.
:param state: 'children_left', 'children_right', 'features', 'thresholds', 'values' as keys.
If the sklearn version is >=1.3.0, also 'missing_go_to_left' is a key.
'max_depth' and 'node_count' are passed through.
:return: dictionary with decompressed tree state.
"""
assert isinstance(state, dict)
# TODO: make prettier once python 3.8 is not supported anymore
assert state.keys() == {
"max_depth",
"node_count",
"is_leaf",
"children_left",
"children_right",
"features",
"thresholds",
"values",
*{
"max_depth",
"node_count",
"is_leaf",
"children_left",
"children_right",
"features",
"thresholds",
"values",
},
*({"missing_go_to_left"} if sklearn_version >= Version("1.3") else set()),
pavelzw marked this conversation as resolved.
Show resolved Hide resolved
}
n_nodes = state["node_count"]

Expand All @@ -126,6 +150,7 @@ def _decompress_tree_state(state: dict):
thresholds = np.zeros(n_nodes, dtype=np.float64)
# same shape as values but with all nodes instead of only the leaves
values = np.zeros((n_nodes, *state["values"].shape[1:]), dtype=np.float64)
missing_go_to_left = np.zeros(n_nodes, dtype="uint8")

is_leaf = np.unpackbits(state["is_leaf"], count=n_nodes).astype("bool")
children_left[~is_leaf] = state["children_left"]
Expand All @@ -137,6 +162,10 @@ def _decompress_tree_state(state: dict):
thresholds[~is_leaf] = decompress_half_int_float_array(state["thresholds"])
thresholds[is_leaf] = -2.0 # threshold of leaves is -2
values[is_leaf] = state["values"]
if sklearn_version_ge_130:
missing_go_to_left[~is_leaf] = np.unpackbits(
state["missing_go_to_left"], count=(~is_leaf).sum()
)

dtype = np.dtype(
[
Expand All @@ -148,12 +177,15 @@ def _decompress_tree_state(state: dict):
("n_node_samples", "<i8"),
("weighted_n_node_samples", "<f8"),
]
+ ([("missing_go_to_left", "<u1")] if sklearn_version_ge_130 else [])
)
nodes = np.zeros(n_nodes, dtype=dtype)
nodes["left_child"] = children_left
nodes["right_child"] = children_right
nodes["feature"] = features
nodes["threshold"] = thresholds
if sklearn_version_ge_130:
nodes["missing_go_to_left"] = missing_go_to_left

return {
"max_depth": state["max_depth"],
Expand Down
12 changes: 11 additions & 1 deletion tests/test_sklearn_compression.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import numpy as np
import pytest
from packaging.version import Version
from sklearn.ensemble import RandomForestRegressor
from sklearn.tree import DecisionTreeRegressor
from util import (
Expand Down Expand Up @@ -74,6 +75,15 @@ def test_compressed_internal_structure(
np.testing.assert_allclose(
tree_dtype_reduction.value[is_leaf], tree_no_reduction.value[is_leaf]
)
from sklearn import __version__ as _sklearn_version

sklearn_version = Version(_sklearn_version)
sklearn_version_ge_130 = sklearn_version >= Version("1.3")
if sklearn_version_ge_130:
np.testing.assert_allclose(
tree_dtype_reduction.missing_go_to_left[~is_leaf],
tree_no_reduction.missing_go_to_left[~is_leaf],
)


def test_compression_size(diabetes_toy_df, random_forest_regressor, tmp_path):
Expand Down Expand Up @@ -113,7 +123,7 @@ def test_load_times(
load_time_compressed, load_time_uncompressed = get_load_times(
random_forest_regressor, dump_sklearn_compressed, tmp_path, compression_method
)
factor = 4 if compression_method == "no" else 1.5
factor = 8 if compression_method == "no" else 1.5
assert load_time_compressed < factor * load_time_uncompressed


Expand Down