From 5694c7846db001436d72d95a15bbdba7a8ad3c17 Mon Sep 17 00:00:00 2001 From: Pavel Zwerschke Date: Fri, 17 Mar 2023 17:20:14 +0100 Subject: [PATCH] Add version warnings (#24) --- environment.yml | 2 +- slim_trees/lgbm_booster.py | 13 +++++++++++-- slim_trees/sklearn_tree.py | 9 +++++++-- slim_trees/utils.py | 11 +++++++++++ tests/test_lgbm_compression.py | 27 ++++++++++++++++++++------- tests/test_sklearn_compression.py | 20 ++++++++++++++++++-- tests/{test_util.py => util.py} | 24 ++++++++++++++++++++++++ 7 files changed, 92 insertions(+), 14 deletions(-) create mode 100644 slim_trees/utils.py rename tests/{test_util.py => util.py} (59%) diff --git a/environment.yml b/environment.yml index f3aa4e3..9b60e33 100644 --- a/environment.yml +++ b/environment.yml @@ -9,6 +9,6 @@ dependencies: - pre-commit - pandas - scikit-learn - - pytest + - pytest>=7.0 - setuptools_scm - tbump diff --git a/slim_trees/lgbm_booster.py b/slim_trees/lgbm_booster.py index 9efa15b..9cb3101 100644 --- a/slim_trees/lgbm_booster.py +++ b/slim_trees/lgbm_booster.py @@ -7,10 +7,12 @@ import numpy as np +from slim_trees import __version__ as slim_trees_version from slim_trees.compression_utils import ( compress_half_int_float_array, decompress_half_int_float_array, ) +from slim_trees.utils import check_version try: from lightgbm.basic import Booster @@ -30,12 +32,19 @@ def _booster_pickle(booster: Booster): assert isinstance(booster, Booster) reconstructor, args, state = booster.__reduce__() compressed_state = _compress_booster_state(state) - return _booster_unpickle, (reconstructor, args, compressed_state) + return _booster_unpickle, ( + reconstructor, + args, + (slim_trees_version, compressed_state), + ) def _booster_unpickle(reconstructor, args, compressed_state): + version, state = compressed_state + check_version(version) + booster = reconstructor(*args) - decompressed_state = _decompress_booster_state(compressed_state) + decompressed_state = _decompress_booster_state(state) booster.__setstate__(decompressed_state) return booster diff --git a/slim_trees/sklearn_tree.py b/slim_trees/sklearn_tree.py index a6604c8..b561eea 100644 --- a/slim_trees/sklearn_tree.py +++ b/slim_trees/sklearn_tree.py @@ -1,10 +1,12 @@ import os import sys +from slim_trees import __version__ as slim_trees_version from slim_trees.compression_utils import ( compress_half_int_float_array, decompress_half_int_float_array, ) +from slim_trees.utils import check_version try: from sklearn.tree._tree import Tree @@ -30,10 +32,13 @@ def _tree_pickle(tree): assert isinstance(tree, Tree) reconstructor, args, state = tree.__reduce__() compressed_state = _compress_tree_state(state) - return _tree_unpickle, (reconstructor, args, compressed_state) + return _tree_unpickle, (reconstructor, args, (slim_trees_version, compressed_state)) -def _tree_unpickle(reconstructor, args, state): +def _tree_unpickle(reconstructor, args, compressed_state): + version, state = compressed_state + check_version(version) + tree = reconstructor(*args) decompressed_state = _decompress_tree_state(state) tree.__setstate__(decompressed_state) diff --git a/slim_trees/utils.py b/slim_trees/utils.py new file mode 100644 index 0000000..9c4965d --- /dev/null +++ b/slim_trees/utils.py @@ -0,0 +1,11 @@ +import warnings + +from slim_trees import __version__ as slim_trees_version + + +def check_version(version: str): + if version != slim_trees_version: + warnings.warn( + f"Version mismatch: slim_trees version {slim_trees_version} " + f"does not match version {version} of the model." + ) diff --git a/tests/test_lgbm_compression.py b/tests/test_lgbm_compression.py index 5b7e339..bc5915e 100644 --- a/tests/test_lgbm_compression.py +++ b/tests/test_lgbm_compression.py @@ -3,9 +3,15 @@ import numpy as np import pytest from lightgbm import LGBMRegressor -from test_util import get_dump_times, get_load_times +from util import ( + assert_version_pickle, + assert_version_unpickle, + get_dump_times, + get_load_times, +) from slim_trees import dump_lgbm_compressed +from slim_trees.lgbm_booster import _booster_pickle from slim_trees.pickling import dump_compressed, load_compressed @@ -27,8 +33,7 @@ def test_compresed_predictions(diabetes_toy_df, lgbm_regressor, tmp_path): def test_compressed_size(diabetes_toy_df, lgbm_regressor, tmp_path): - X, y = diabetes_toy_df - lgbm_regressor.fit(X, y) + lgbm_regressor.fit(*diabetes_toy_df) model_path_compressed = tmp_path / "model_compressed.pickle.lzma" model_path = tmp_path / "model.pickle.lzma" @@ -41,8 +46,7 @@ def test_compressed_size(diabetes_toy_df, lgbm_regressor, tmp_path): @pytest.mark.parametrize("compression_method", ["no", "lzma", "gzip", "bz2"]) def test_dump_times(diabetes_toy_df, lgbm_regressor, tmp_path, compression_method): - X, y = diabetes_toy_df - lgbm_regressor.fit(X, y) + lgbm_regressor.fit(*diabetes_toy_df) factor = 7 if compression_method == "no" else 4 dump_time_compressed, dump_time_uncompressed = get_dump_times( @@ -53,8 +57,7 @@ def test_dump_times(diabetes_toy_df, lgbm_regressor, tmp_path, compression_metho @pytest.mark.parametrize("compression_method", ["no", "lzma", "gzip", "bz2"]) def test_load_times(diabetes_toy_df, lgbm_regressor, tmp_path, compression_method): - X, y = diabetes_toy_df - lgbm_regressor.fit(X, y) + lgbm_regressor.fit(*diabetes_toy_df) load_time_compressed, load_time_uncompressed = get_load_times( lgbm_regressor, dump_lgbm_compressed, tmp_path, compression_method @@ -63,4 +66,14 @@ def test_load_times(diabetes_toy_df, lgbm_regressor, tmp_path, compression_metho assert load_time_compressed < factor * load_time_uncompressed +def test_tree_version_pickle(diabetes_toy_df, lgbm_regressor): + lgbm_regressor.fit(*diabetes_toy_df) + assert_version_pickle(_booster_pickle, lgbm_regressor.booster_) + + +def test_tree_version_unpickle(diabetes_toy_df, lgbm_regressor): + lgbm_regressor.fit(*diabetes_toy_df) + assert_version_unpickle(_booster_pickle, lgbm_regressor.booster_) + + # todo add tests for large models diff --git a/tests/test_sklearn_compression.py b/tests/test_sklearn_compression.py index 32100ca..c3fce04 100644 --- a/tests/test_sklearn_compression.py +++ b/tests/test_sklearn_compression.py @@ -4,10 +4,16 @@ import pytest from sklearn.ensemble import RandomForestRegressor from sklearn.tree import DecisionTreeRegressor -from test_util import get_dump_times, get_load_times +from util import ( + assert_version_pickle, + assert_version_unpickle, + get_dump_times, + get_load_times, +) from slim_trees import dump_sklearn_compressed from slim_trees.pickling import dump_compressed, load_compressed +from slim_trees.sklearn_tree import _tree_pickle @pytest.fixture @@ -27,7 +33,7 @@ def decision_tree_regressor(rng): def test_compressed_predictions(diabetes_toy_df, random_forest_regressor, tmp_path): X, y = diabetes_toy_df - random_forest_regressor.fit(X, y) + random_forest_regressor.fit(*diabetes_toy_df) model_path = tmp_path / "model_compressed.pickle.lzma" dump_sklearn_compressed(random_forest_regressor, model_path) @@ -110,4 +116,14 @@ def test_load_times( assert load_time_compressed < factor * load_time_uncompressed +def test_tree_version_pickle(diabetes_toy_df, decision_tree_regressor): + decision_tree_regressor.fit(*diabetes_toy_df) + assert_version_pickle(_tree_pickle, decision_tree_regressor.tree_) + + +def test_tree_version_unpickle(diabetes_toy_df, decision_tree_regressor): + decision_tree_regressor.fit(*diabetes_toy_df) + assert_version_unpickle(_tree_pickle, decision_tree_regressor.tree_) + + # todo add tests for large models diff --git a/tests/test_util.py b/tests/util.py similarity index 59% rename from tests/test_util.py rename to tests/util.py index f37478f..4274c67 100644 --- a/tests/test_util.py +++ b/tests/util.py @@ -1,5 +1,9 @@ import timeit +import warnings +import pytest + +from slim_trees import __version__ as slim_trees_version from slim_trees import dump_compressed from slim_trees.pickling import load_compressed @@ -30,3 +34,23 @@ def get_load_times(model, dump_lib_compressed, tmp_path, method): lambda: load_compressed(model_path, method), number=5 ) return load_time_compressed, load_time_uncompressed + + +def assert_version_pickle(pickle_function, element): + _, (_, _, (version, _)) = pickle_function(element) + assert slim_trees_version == version + + +def assert_version_unpickle(pickle_function, element): + _unpickle_function, ( + reconstructor, + args, + (version, compressed_state), + ) = pickle_function(element) + with warnings.catch_warnings(): + warnings.simplefilter("error") + _unpickle_function(reconstructor, args, (version, compressed_state)) + with pytest.warns() as record: + _unpickle_function(reconstructor, args, ("0.0.0", compressed_state)) + assert len(record) == 1 + assert "version mismatch" in str(record[0].message).lower()