Skip to content

Commit

Permalink
Add version warnings (#24)
Browse files Browse the repository at this point in the history
  • Loading branch information
pavelzw authored Mar 17, 2023
1 parent 9dac053 commit 5694c78
Show file tree
Hide file tree
Showing 7 changed files with 92 additions and 14 deletions.
2 changes: 1 addition & 1 deletion environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,6 @@ dependencies:
- pre-commit
- pandas
- scikit-learn
- pytest
- pytest>=7.0
- setuptools_scm
- tbump
13 changes: 11 additions & 2 deletions slim_trees/lgbm_booster.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,12 @@

import numpy as np

from slim_trees import __version__ as slim_trees_version
from slim_trees.compression_utils import (
compress_half_int_float_array,
decompress_half_int_float_array,
)
from slim_trees.utils import check_version

try:
from lightgbm.basic import Booster
Expand All @@ -30,12 +32,19 @@ def _booster_pickle(booster: Booster):
assert isinstance(booster, Booster)
reconstructor, args, state = booster.__reduce__()
compressed_state = _compress_booster_state(state)
return _booster_unpickle, (reconstructor, args, compressed_state)
return _booster_unpickle, (
reconstructor,
args,
(slim_trees_version, compressed_state),
)


def _booster_unpickle(reconstructor, args, compressed_state):
version, state = compressed_state
check_version(version)

booster = reconstructor(*args)
decompressed_state = _decompress_booster_state(compressed_state)
decompressed_state = _decompress_booster_state(state)
booster.__setstate__(decompressed_state)
return booster

Expand Down
9 changes: 7 additions & 2 deletions slim_trees/sklearn_tree.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import os
import sys

from slim_trees import __version__ as slim_trees_version
from slim_trees.compression_utils import (
compress_half_int_float_array,
decompress_half_int_float_array,
)
from slim_trees.utils import check_version

try:
from sklearn.tree._tree import Tree
Expand All @@ -30,10 +32,13 @@ def _tree_pickle(tree):
assert isinstance(tree, Tree)
reconstructor, args, state = tree.__reduce__()
compressed_state = _compress_tree_state(state)
return _tree_unpickle, (reconstructor, args, compressed_state)
return _tree_unpickle, (reconstructor, args, (slim_trees_version, compressed_state))


def _tree_unpickle(reconstructor, args, state):
def _tree_unpickle(reconstructor, args, compressed_state):
version, state = compressed_state
check_version(version)

tree = reconstructor(*args)
decompressed_state = _decompress_tree_state(state)
tree.__setstate__(decompressed_state)
Expand Down
11 changes: 11 additions & 0 deletions slim_trees/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
import warnings

from slim_trees import __version__ as slim_trees_version


def check_version(version: str):
if version != slim_trees_version:
warnings.warn(
f"Version mismatch: slim_trees version {slim_trees_version} "
f"does not match version {version} of the model."
)
27 changes: 20 additions & 7 deletions tests/test_lgbm_compression.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,15 @@
import numpy as np
import pytest
from lightgbm import LGBMRegressor
from test_util import get_dump_times, get_load_times
from util import (
assert_version_pickle,
assert_version_unpickle,
get_dump_times,
get_load_times,
)

from slim_trees import dump_lgbm_compressed
from slim_trees.lgbm_booster import _booster_pickle
from slim_trees.pickling import dump_compressed, load_compressed


Expand All @@ -27,8 +33,7 @@ def test_compresed_predictions(diabetes_toy_df, lgbm_regressor, tmp_path):


def test_compressed_size(diabetes_toy_df, lgbm_regressor, tmp_path):
X, y = diabetes_toy_df
lgbm_regressor.fit(X, y)
lgbm_regressor.fit(*diabetes_toy_df)

model_path_compressed = tmp_path / "model_compressed.pickle.lzma"
model_path = tmp_path / "model.pickle.lzma"
Expand All @@ -41,8 +46,7 @@ def test_compressed_size(diabetes_toy_df, lgbm_regressor, tmp_path):

@pytest.mark.parametrize("compression_method", ["no", "lzma", "gzip", "bz2"])
def test_dump_times(diabetes_toy_df, lgbm_regressor, tmp_path, compression_method):
X, y = diabetes_toy_df
lgbm_regressor.fit(X, y)
lgbm_regressor.fit(*diabetes_toy_df)
factor = 7 if compression_method == "no" else 4

dump_time_compressed, dump_time_uncompressed = get_dump_times(
Expand All @@ -53,8 +57,7 @@ def test_dump_times(diabetes_toy_df, lgbm_regressor, tmp_path, compression_metho

@pytest.mark.parametrize("compression_method", ["no", "lzma", "gzip", "bz2"])
def test_load_times(diabetes_toy_df, lgbm_regressor, tmp_path, compression_method):
X, y = diabetes_toy_df
lgbm_regressor.fit(X, y)
lgbm_regressor.fit(*diabetes_toy_df)

load_time_compressed, load_time_uncompressed = get_load_times(
lgbm_regressor, dump_lgbm_compressed, tmp_path, compression_method
Expand All @@ -63,4 +66,14 @@ def test_load_times(diabetes_toy_df, lgbm_regressor, tmp_path, compression_metho
assert load_time_compressed < factor * load_time_uncompressed


def test_tree_version_pickle(diabetes_toy_df, lgbm_regressor):
lgbm_regressor.fit(*diabetes_toy_df)
assert_version_pickle(_booster_pickle, lgbm_regressor.booster_)


def test_tree_version_unpickle(diabetes_toy_df, lgbm_regressor):
lgbm_regressor.fit(*diabetes_toy_df)
assert_version_unpickle(_booster_pickle, lgbm_regressor.booster_)


# todo add tests for large models
20 changes: 18 additions & 2 deletions tests/test_sklearn_compression.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,16 @@
import pytest
from sklearn.ensemble import RandomForestRegressor
from sklearn.tree import DecisionTreeRegressor
from test_util import get_dump_times, get_load_times
from util import (
assert_version_pickle,
assert_version_unpickle,
get_dump_times,
get_load_times,
)

from slim_trees import dump_sklearn_compressed
from slim_trees.pickling import dump_compressed, load_compressed
from slim_trees.sklearn_tree import _tree_pickle


@pytest.fixture
Expand All @@ -27,7 +33,7 @@ def decision_tree_regressor(rng):

def test_compressed_predictions(diabetes_toy_df, random_forest_regressor, tmp_path):
X, y = diabetes_toy_df
random_forest_regressor.fit(X, y)
random_forest_regressor.fit(*diabetes_toy_df)

model_path = tmp_path / "model_compressed.pickle.lzma"
dump_sklearn_compressed(random_forest_regressor, model_path)
Expand Down Expand Up @@ -110,4 +116,14 @@ def test_load_times(
assert load_time_compressed < factor * load_time_uncompressed


def test_tree_version_pickle(diabetes_toy_df, decision_tree_regressor):
decision_tree_regressor.fit(*diabetes_toy_df)
assert_version_pickle(_tree_pickle, decision_tree_regressor.tree_)


def test_tree_version_unpickle(diabetes_toy_df, decision_tree_regressor):
decision_tree_regressor.fit(*diabetes_toy_df)
assert_version_unpickle(_tree_pickle, decision_tree_regressor.tree_)


# todo add tests for large models
24 changes: 24 additions & 0 deletions tests/test_util.py → tests/util.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
import timeit
import warnings

import pytest

from slim_trees import __version__ as slim_trees_version
from slim_trees import dump_compressed
from slim_trees.pickling import load_compressed

Expand Down Expand Up @@ -30,3 +34,23 @@ def get_load_times(model, dump_lib_compressed, tmp_path, method):
lambda: load_compressed(model_path, method), number=5
)
return load_time_compressed, load_time_uncompressed


def assert_version_pickle(pickle_function, element):
_, (_, _, (version, _)) = pickle_function(element)
assert slim_trees_version == version


def assert_version_unpickle(pickle_function, element):
_unpickle_function, (
reconstructor,
args,
(version, compressed_state),
) = pickle_function(element)
with warnings.catch_warnings():
warnings.simplefilter("error")
_unpickle_function(reconstructor, args, (version, compressed_state))
with pytest.warns() as record:
_unpickle_function(reconstructor, args, ("0.0.0", compressed_state))
assert len(record) == 1
assert "version mismatch" in str(record[0].message).lower()

0 comments on commit 5694c78

Please sign in to comment.