Skip to content

Commit

Permalink
Allow for custom unpickler in load(s)_compressed (#69)
Browse files Browse the repository at this point in the history
  • Loading branch information
StefanSorgQC authored Jul 18, 2023
1 parent 7b19c57 commit 056ea90
Show file tree
Hide file tree
Showing 6 changed files with 71 additions and 13 deletions.
2 changes: 1 addition & 1 deletion environment-deprecated.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ channels:
- conda-forge
- nodefaults
dependencies:
- lightgbm
- lightgbm <4.0
- numpy
- python>=3.8
- pre-commit
Expand Down
4 changes: 2 additions & 2 deletions pixi.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project] # TODO: move to pyproject.toml once pixi supports it
name = "slim-trees"
version = "0.2.1"
version = "0.2.2"
description = "A python package for efficient pickling of ML models."
authors = ["Pavel Zwerschke <[email protected]>"]
channels = ["conda-forge"]
Expand All @@ -13,7 +13,7 @@ lint = "pre-commit run --all"
[dependencies]
python = ">=3.8"
pip = "*"
lightgbm = "*"
lightgbm = "<4.0"
numpy = "*"
pre-commit = "*"
pandas = "*"
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ build-backend = "hatchling.build"
[project]
name = "slim-trees"
description = "A python package for efficient pickling of ML models."
version = "0.2.1"
version = "0.2.2"
readme = "README.md"
license = "MIT"
requires-python = ">=3.8"
Expand All @@ -25,7 +25,7 @@ dependencies = [

[project.optional-dependencies]
lightgbm = [
"lightgbm",
"lightgbm <4.0",
]
scikit-learn = [
"scikit-learn <1.3.0",
Expand Down
24 changes: 20 additions & 4 deletions slim_trees/pickling.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@ def __init__(self):
self.open = open
self.compress = lambda data: data

@staticmethod
def decompress(data):
    """Identity pass-through: the no-compression backend returns *data* unchanged."""
    return data


def _get_compression_from_path(path: Union[str, pathlib.Path]) -> str:
compressions = {
Expand Down Expand Up @@ -120,7 +124,9 @@ def dumps_compressed(


def load_compressed(
    path: Union[str, pathlib.Path],
    compression: Optional[Union[str, dict]] = None,
    unpickler_class: type = pickle.Unpickler,
) -> Any:
    """
    Loads a compressed model.

    :param path: the path to load the compressed model from.
    :param compression: the compression method used. Either a string or a dict with key 'method'
        set to the compression method and other key-value pairs which are forwarded
        to open() of the compression library.
        Inspired by the pandas.to_csv interface.
    :param unpickler_class: custom unpickler class derived from pickle.Unpickler.
        This is useful to restrict possible imports or to allow unpickling
        when required module or function names have been refactored.
    """
    # Infer compression settings (falling back to the file extension via `path`).
    compression_method, kwargs = _unpack_compression_args(compression, path)
    with _get_compression_library(compression_method).open(
        path, mode="rb", **kwargs
    ) as file:
        # Instantiate the (possibly custom) unpickler on the open file object.
        return unpickler_class(file).load()


def loads_compressed(
    data: bytes,
    compression: Optional[Union[str, dict]] = None,
    unpickler_class: type = pickle.Unpickler,
) -> Any:
    """
    Loads a compressed model.

    :param data: bytes containing the pickled object.
    :param compression: the compression method used. Either a string or a dict with key 'method'
        set to the compression method and other key-value pairs which are forwarded
        to open() of the compression library. Defaults to 'no' compression.
        Inspired by the pandas.to_csv interface.
    :param unpickler_class: custom unpickler class derived from pickle.Unpickler.
        This is useful to restrict possible imports or to allow unpickling
        when required module or function names have been refactored.
    """
    # Unlike load_compressed there is no path to infer the method from,
    # so default explicitly to the no-compression backend.
    if compression is None:
        compression = "no"

    compression_method, kwargs = _unpack_compression_args(compression, None)
    data_uncompressed = _get_compression_library(compression_method).decompress(data)
    # Wrap the raw bytes in a file-like object for the (possibly custom) unpickler.
    return unpickler_class(io.BytesIO(data_uncompressed)).load()


def get_pickled_size(
Expand Down
25 changes: 23 additions & 2 deletions tests/test_lgbm_compression.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import pickle

import numpy as np
import pytest
Expand All @@ -10,9 +11,9 @@
get_load_times,
)

from slim_trees import dump_lgbm_compressed
from slim_trees import dump_lgbm_compressed, dumps_lgbm_compressed
from slim_trees.lgbm_booster import _booster_pickle
from slim_trees.pickling import dump_compressed, load_compressed
from slim_trees.pickling import dump_compressed, load_compressed, loads_compressed


@pytest.fixture
Expand Down Expand Up @@ -84,4 +85,24 @@ def test_tree_version_unpickle(diabetes_toy_df, lgbm_regressor):
assert_version_unpickle(_booster_pickle, lgbm_regressor.booster_)


class _TestUnpickler(pickle.Unpickler):
def find_class(self, module, name):
if module.startswith("lightgbm"):
raise ImportError(f"Module '{module}' not allowed in this test")
return super().find_class(module, name)


def test_load_compressed_custom_unpickler(tmp_path, lgbm_regressor):
    """load_compressed must delegate class resolution to unpickler_class."""
    target = tmp_path / "model_compressed.pickle.lzma"
    dump_lgbm_compressed(lgbm_regressor, target)
    # The restricted unpickler rejects lightgbm imports, so loading must fail.
    with pytest.raises(ImportError, match="lightgbm.*not allowed"):
        load_compressed(target, unpickler_class=_TestUnpickler)


def test_loads_compressed_custom_unpickler(lgbm_regressor):
    """loads_compressed must delegate class resolution to unpickler_class."""
    payload = dumps_lgbm_compressed(lgbm_regressor)
    # The restricted unpickler rejects lightgbm imports, so loading must fail.
    with pytest.raises(ImportError, match="lightgbm.*not allowed"):
        loads_compressed(payload, unpickler_class=_TestUnpickler)


# todo add tests for large models
25 changes: 23 additions & 2 deletions tests/test_sklearn_compression.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import pickle

import numpy as np
import pytest
Expand All @@ -11,8 +12,8 @@
get_load_times,
)

from slim_trees import dump_sklearn_compressed
from slim_trees.pickling import dump_compressed, load_compressed
from slim_trees import dump_sklearn_compressed, dumps_sklearn_compressed
from slim_trees.pickling import dump_compressed, load_compressed, loads_compressed
from slim_trees.sklearn_tree import _tree_pickle


Expand Down Expand Up @@ -126,4 +127,24 @@ def test_tree_version_unpickle(diabetes_toy_df, decision_tree_regressor):
assert_version_unpickle(_tree_pickle, decision_tree_regressor.tree_)


class _TestUnpickler(pickle.Unpickler):
def find_class(self, module, name):
if module.startswith("sklearn"):
raise ImportError(f"Module '{module}' not allowed in this test")
return super().find_class(module, name)


def test_load_compressed_custom_unpickler(tmp_path, random_forest_regressor):
    """load_compressed must delegate class resolution to unpickler_class."""
    target = tmp_path / "model_compressed.pickle.lzma"
    dump_sklearn_compressed(random_forest_regressor, target)
    # The restricted unpickler rejects sklearn imports, so loading must fail.
    with pytest.raises(ImportError, match="sklearn.*not allowed"):
        load_compressed(target, unpickler_class=_TestUnpickler)


def test_loads_compressed_custom_unpickler(random_forest_regressor):
    """loads_compressed must delegate class resolution to unpickler_class."""
    payload = dumps_sklearn_compressed(random_forest_regressor)
    # The restricted unpickler rejects sklearn imports, so loading must fail.
    with pytest.raises(ImportError, match="sklearn.*not allowed"):
        loads_compressed(payload, unpickler_class=_TestUnpickler)


# todo add tests for large models

0 comments on commit 056ea90

Please sign in to comment.