diff --git a/README.md b/README.md index 8dd4098..b9ed7ed 100644 --- a/README.md +++ b/README.md @@ -25,7 +25,8 @@ pixi add slim-trees Using `slim-trees` does not affect your training pipeline. Simply call `dump_sklearn_compressed` or `dump_lgbm_compressed` to save your model. -> ⚠️ `slim-trees` does not save all the data that would be saved by `sklearn`: +> [!WARNING] +> `slim-trees` does not save all the data that would be saved by `sklearn`: > only the parameters that are relevant for inference are saved. If you want to save the full model including > `impurity` etc. for analytic purposes, we suggest saving both the original using `pickle.dump` for analytics > and the slimmed down version using `slim-trees` for production. diff --git a/pixi.toml b/pixi.toml index dd8c5f7..8a72341 100644 --- a/pixi.toml +++ b/pixi.toml @@ -1,6 +1,6 @@ [project] # TODO: move to pyproject.toml once pixi supports it name = "slim-trees" -version = "0.2.3" +version = "0.2.4" description = "A python package for efficient pickling of ML models." authors = ["Pavel Zwerschke "] channels = ["conda-forge"] diff --git a/pyproject.toml b/pyproject.toml index 2f9f317..f7989d8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,7 +5,7 @@ build-backend = "hatchling.build" [project] name = "slim-trees" description = "A python package for efficient pickling of ML models." -version = "0.2.3" +version = "0.2.4" readme = "README.md" license = "MIT" requires-python = ">=3.8" diff --git a/slim_trees/lgbm_booster.py b/slim_trees/lgbm_booster.py index 30d3ec4..9da232c 100644 --- a/slim_trees/lgbm_booster.py +++ b/slim_trees/lgbm_booster.py @@ -198,8 +198,9 @@ def _compress_booster_handle(model_string: str) -> Tuple[str, List[dict], str]: return front_str, trees, back_str -def _validate_tree_structure(tree: dict) -> bool: - return type(tree) == dict and tree.keys() == { +def _validate_tree_structure(tree: dict): + assert isinstance(tree, dict) + if tree.keys() != { "num_leaves", "num_cat", "split_feature", @@ -210,7 +211,11 @@ def _validate_tree_structure(tree: dict) -> bool: "leaf_value", "is_linear", "shrinkage", - } + }: + raise ValueError( + "Invalid tree structure. Do you use an unsupported LightGBM version or try to load a " + "model that was pickled with a different version of LightGBM?" + ) def _decompress_booster_handle(compressed_state: Tuple[str, List[dict], str]) -> str: diff --git a/slim_trees/sklearn_tree.py b/slim_trees/sklearn_tree.py index 8bb456f..3c643aa 100644 --- a/slim_trees/sklearn_tree.py +++ b/slim_trees/sklearn_tree.py @@ -129,7 +129,7 @@ def _decompress_tree_state(state: dict): """ assert isinstance(state, dict) # TODO: make prettier once python 3.8 is not supported anymore - assert state.keys() == { + if state.keys() != { *{ "max_depth", "node_count", @@ -141,7 +141,11 @@ def _decompress_tree_state(state: dict): "values", }, *({"missing_go_to_left"} if sklearn_version >= Version("1.3") else set()), - } + }: + raise ValueError( + "Invalid tree structure. Do you use an unsupported scikit-learn version " + "or try to load a model that was pickled with a different version of scikit-learn?" + ) n_nodes = state["node_count"] children_left = np.zeros(n_nodes, dtype=np.int64)