diff --git a/environment-deprecated.yml b/environment-deprecated.yml index 24119a2..d8f685a 100644 --- a/environment-deprecated.yml +++ b/environment-deprecated.yml @@ -4,7 +4,7 @@ channels: - conda-forge - nodefaults dependencies: - - lightgbm + - lightgbm <4.0 - numpy - python>=3.8 - pre-commit diff --git a/pixi.toml b/pixi.toml index 03324a4..74e0cff 100644 --- a/pixi.toml +++ b/pixi.toml @@ -1,6 +1,6 @@ [project] # TODO: move to pyproject.toml once pixi supports it name = "slim-trees" -version = "0.2.1" +version = "0.2.2" description = "A python package for efficient pickling of ML models." authors = ["Pavel Zwerschke "] channels = ["conda-forge"] @@ -13,7 +13,7 @@ lint = "pre-commit run --all" [dependencies] python = ">=3.8" pip = "*" -lightgbm = "*" +lightgbm = "<4.0" numpy = "*" pre-commit = "*" pandas = "*" diff --git a/pyproject.toml b/pyproject.toml index e591655..eac76b8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,7 +5,7 @@ build-backend = "hatchling.build" [project] name = "slim-trees" description = "A python package for efficient pickling of ML models." -version = "0.2.1" +version = "0.2.2" readme = "README.md" license = "MIT" requires-python = ">=3.8" @@ -25,7 +25,7 @@ dependencies = [ [project.optional-dependencies] lightgbm = [ - "lightgbm", + "lightgbm <4.0", ] scikit-learn = [ "scikit-learn <1.3.0", diff --git a/slim_trees/pickling.py b/slim_trees/pickling.py index dfa3e91..f3af990 100644 --- a/slim_trees/pickling.py +++ b/slim_trees/pickling.py @@ -13,6 +13,10 @@ def __init__(self): self.open = open self.compress = lambda data: data + @staticmethod + def decompress(data): + return data + def _get_compression_from_path(path: Union[str, pathlib.Path]) -> str: compressions = { @@ -120,7 +124,9 @@ def dumps_compressed( def load_compressed( - path: Union[str, pathlib.Path], compression: Optional[Union[str, dict]] = None + path: Union[str, pathlib.Path], + compression: Optional[Union[str, dict]] = None, + unpickler_class: type = pickle.Unpickler, ) -> Any: """ Loads a compressed model. @@ -129,15 +135,22 @@ def load_compressed( set to the compression method and other key-value pairs which are forwarded to open() of the compression library. Inspired by the pandas.to_csv interface. + :param unpickler_class: custom unpickler class derived from pickle.Unpickler. + This is useful to restrict possible imports or to allow unpickling + when required module or function names have been refactored. """ compression_method, kwargs = _unpack_compression_args(compression, path) with _get_compression_library(compression_method).open( path, mode="rb", **kwargs ) as file: - return pickle.load(file) + return unpickler_class(file).load() -def loads_compressed(data: bytes, compression: Optional[Union[str, dict]] = None): +def loads_compressed( + data: bytes, + compression: Optional[Union[str, dict]] = None, + unpickler_class: type = pickle.Unpickler, +) -> Any: """ Loads a compressed model. :param data: bytes containing the pickled object. @@ -145,13 +158,16 @@ def loads_compressed(data: bytes, compression: Optional[Union[str, dict]] = None set to the compression method and other key-value pairs which are forwarded to open() of the compression library. Defaults to 'no' compression. Inspired by the pandas.to_csv interface. + :param unpickler_class: custom unpickler class derived from pickle.Unpickler. + This is useful to restrict possible imports or to allow unpickling + when required module or function names have been refactored. """ if compression is None: compression = "no" compression_method, kwargs = _unpack_compression_args(compression, None) data_uncompressed = _get_compression_library(compression_method).decompress(data) - return pickle.loads(data_uncompressed) + return unpickler_class(io.BytesIO(data_uncompressed)).load() def get_pickled_size( diff --git a/tests/test_lgbm_compression.py b/tests/test_lgbm_compression.py index e2d4f44..eaaf5ec 100644 --- a/tests/test_lgbm_compression.py +++ b/tests/test_lgbm_compression.py @@ -1,4 +1,5 @@ import os +import pickle import numpy as np import pytest @@ -10,9 +11,9 @@ get_load_times, ) -from slim_trees import dump_lgbm_compressed +from slim_trees import dump_lgbm_compressed, dumps_lgbm_compressed from slim_trees.lgbm_booster import _booster_pickle -from slim_trees.pickling import dump_compressed, load_compressed +from slim_trees.pickling import dump_compressed, load_compressed, loads_compressed @pytest.fixture @@ -84,4 +85,24 @@ def test_tree_version_unpickle(diabetes_toy_df, lgbm_regressor): assert_version_unpickle(_booster_pickle, lgbm_regressor.booster_) +class _TestUnpickler(pickle.Unpickler): + def find_class(self, module, name): + if module.startswith("lightgbm"): + raise ImportError(f"Module '{module}' not allowed in this test") + return super().find_class(module, name) + + +def test_load_compressed_custom_unpickler(tmp_path, lgbm_regressor): + model_path = tmp_path / "model_compressed.pickle.lzma" + dump_lgbm_compressed(lgbm_regressor, model_path) + with pytest.raises(ImportError, match="lightgbm.*not allowed"): + load_compressed(model_path, unpickler_class=_TestUnpickler) + + +def test_loads_compressed_custom_unpickler(lgbm_regressor): + compressed = dumps_lgbm_compressed(lgbm_regressor) + with pytest.raises(ImportError, match="lightgbm.*not allowed"): + loads_compressed(compressed, unpickler_class=_TestUnpickler) + + # todo add tests for large models diff --git a/tests/test_sklearn_compression.py b/tests/test_sklearn_compression.py index c3fce04..7b1d0b7 100644 --- a/tests/test_sklearn_compression.py +++ b/tests/test_sklearn_compression.py @@ -1,4 +1,5 @@ import os +import pickle import numpy as np import pytest @@ -11,8 +12,8 @@ get_load_times, ) -from slim_trees import dump_sklearn_compressed -from slim_trees.pickling import dump_compressed, load_compressed +from slim_trees import dump_sklearn_compressed, dumps_sklearn_compressed +from slim_trees.pickling import dump_compressed, load_compressed, loads_compressed from slim_trees.sklearn_tree import _tree_pickle @@ -126,4 +127,24 @@ def test_tree_version_unpickle(diabetes_toy_df, decision_tree_regressor): assert_version_unpickle(_tree_pickle, decision_tree_regressor.tree_) +class _TestUnpickler(pickle.Unpickler): + def find_class(self, module, name): + if module.startswith("sklearn"): + raise ImportError(f"Module '{module}' not allowed in this test") + return super().find_class(module, name) + + +def test_load_compressed_custom_unpickler(tmp_path, random_forest_regressor): + model_path = tmp_path / "model_compressed.pickle.lzma" + dump_sklearn_compressed(random_forest_regressor, model_path) + with pytest.raises(ImportError, match="sklearn.*not allowed"): + load_compressed(model_path, unpickler_class=_TestUnpickler) + + +def test_loads_compressed_custom_unpickler(random_forest_regressor): + compressed = dumps_sklearn_compressed(random_forest_regressor) + with pytest.raises(ImportError, match="sklearn.*not allowed"): + loads_compressed(compressed, unpickler_class=_TestUnpickler) + + # todo add tests for large models