Skip to content

Commit

Permalink
improve error messages (#77)
Browse files Browse the repository at this point in the history
  • Loading branch information
pavelzw authored Aug 17, 2023
1 parent d5bdad4 commit 9932b30
Show file tree
Hide file tree
Showing 5 changed files with 18 additions and 8 deletions.
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ pixi add slim-trees
Using `slim-trees` does not affect your training pipeline.
Simply call `dump_sklearn_compressed` or `dump_lgbm_compressed` to save your model.

> ⚠️ `slim-trees` does not save all the data that would be saved by `sklearn`:
> [!WARNING]
> `slim-trees` does not save all the data that would be saved by `sklearn`:
> only the parameters that are relevant for inference are saved. If you want to save the full model including
> `impurity` etc. for analytic purposes, we suggest saving both the original using `pickle.dump` for analytics
> and the slimmed down version using `slim-trees` for production.
Expand Down
2 changes: 1 addition & 1 deletion pixi.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project] # TODO: move to pyproject.toml once pixi supports it
name = "slim-trees"
version = "0.2.3"
version = "0.2.4"
description = "A python package for efficient pickling of ML models."
authors = ["Pavel Zwerschke <[email protected]>"]
channels = ["conda-forge"]
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ build-backend = "hatchling.build"
[project]
name = "slim-trees"
description = "A python package for efficient pickling of ML models."
version = "0.2.3"
version = "0.2.4"
readme = "README.md"
license = "MIT"
requires-python = ">=3.8"
Expand Down
11 changes: 8 additions & 3 deletions slim_trees/lgbm_booster.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,8 +198,9 @@ def _compress_booster_handle(model_string: str) -> Tuple[str, List[dict], str]:
return front_str, trees, back_str


def _validate_tree_structure(tree: dict) -> bool:
return type(tree) == dict and tree.keys() == {
def _validate_tree_structure(tree: dict):
assert isinstance(tree, dict)
if tree.keys() != {
"num_leaves",
"num_cat",
"split_feature",
Expand All @@ -210,7 +211,11 @@ def _validate_tree_structure(tree: dict) -> bool:
"leaf_value",
"is_linear",
"shrinkage",
}
}:
raise ValueError(
"Invalid tree structure. Do you use an unsupported LightGBM version or try to load a "
"model that was pickled with a different version of LightGBM?"
)


def _decompress_booster_handle(compressed_state: Tuple[str, List[dict], str]) -> str:
Expand Down
8 changes: 6 additions & 2 deletions slim_trees/sklearn_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def _decompress_tree_state(state: dict):
"""
assert isinstance(state, dict)
# TODO: make prettier once python 3.8 is not supported anymore
assert state.keys() == {
if state.keys() != {
*{
"max_depth",
"node_count",
Expand All @@ -141,7 +141,11 @@ def _decompress_tree_state(state: dict):
"values",
},
*({"missing_go_to_left"} if sklearn_version >= Version("1.3") else set()),
}
}:
raise ValueError(
"Invalid tree structure. Do you use an unsupported scikit-learn version "
"or try to load a model that was pickled with a different version of scikit-learn?"
)
n_nodes = state["node_count"]

children_left = np.zeros(n_nodes, dtype=np.int64)
Expand Down

0 comments on commit 9932b30

Please sign in to comment.