diff --git a/slim_trees/compression_utils.py b/slim_trees/compression_utils.py
index 89f7042..8499760 100644
--- a/slim_trees/compression_utils.py
+++ b/slim_trees/compression_utils.py
@@ -1,6 +1,45 @@
 import numpy as np
 
 
+def safe_cast(arr: np.ndarray, dtype):
+    """
+    Casts the array to ``dtype`` after checking that every value fits.
+
+    Empty arrays are cast directly because they hold no value to check.
+    The bounds come from ``np.iinfo``/``np.finfo`` rather than from
+    ``np.can_cast`` on scalars, which stopped inspecting values in NumPy 2.
+    Float-to-float casts may lose precision; NaN and infinities pass through.
+    Raises a ``ValueError`` if a value cannot be represented in ``dtype``.
+    """
+    arr = np.asarray(arr)
+    target = np.dtype(dtype)
+    # empty input or widening cast: nothing can overflow
+    if arr.size == 0 or np.can_cast(arr.dtype, target):
+        return arr.astype(target)
+
+    if arr.dtype.kind in "iu" and target.kind in "iuf":
+        if target.kind == "f":
+            # integers up to 2**(mantissa bits + 1) are represented exactly
+            upper = 2 ** (int(np.finfo(target).nmant) + 1)
+            lower = -upper
+        else:
+            bounds = np.iinfo(target)
+            lower, upper = int(bounds.min), int(bounds.max)
+        fits = lower <= int(arr.min()) and int(arr.max()) <= upper
+    elif arr.dtype.kind == "f" and target.kind == "f":
+        bounds = np.finfo(target)
+        finite = arr[np.isfinite(arr)]
+        fits = finite.size == 0 or (
+            bounds.min <= finite.min() and finite.max() <= bounds.max
+        )
+    else:
+        fits = False
+
+    if fits:
+        return arr.astype(target)
+    raise ValueError(f"Cannot cast array to {dtype}.")
+
+
 def _is_in_neighborhood_of_int(arr, iinfo, eps=1e-12):
     """
     Checks if the numbers are around an integer.
diff --git a/slim_trees/lgbm_booster.py b/slim_trees/lgbm_booster.py
index 9cb3101..8847c83 100644
--- a/slim_trees/lgbm_booster.py
+++ b/slim_trees/lgbm_booster.py
@@ -11,6 +11,7 @@
 from slim_trees.compression_utils import (
     compress_half_int_float_array,
     decompress_half_int_float_array,
+    safe_cast,
 )
 from slim_trees.utils import check_version
 
@@ -102,6 +103,10 @@ def _extract_feature(feature_line):
         feats_map = dict(_extract_feature(fl) for fl in features)
 
         def parse(str_list, dtype):
+            if np.can_cast(dtype, np.int64):
+                int64_array = np.array(str_list, dtype=np.int64)
+                return safe_cast(int64_array, dtype)
+            assert np.can_cast(dtype, np.float64)
             return np.array(str_list, dtype=dtype)
 
         split_feature_dtype = np.int16
diff --git a/slim_trees/sklearn_tree.py b/slim_trees/sklearn_tree.py
index 21c7a1b..54bf0f1 100644
--- a/slim_trees/sklearn_tree.py
+++ b/slim_trees/sklearn_tree.py
@@ -5,6 +5,7 @@
 from slim_trees.compression_utils import (
     compress_half_int_float_array,
     decompress_half_int_float_array,
+    safe_cast,
 )
 from slim_trees.utils import check_version
 
@@ -70,11 +71,12 @@ def _compress_tree_state(state: dict):
     assert np.array_equal(is_leaf, children_right == -1)
 
     # feature, threshold and children are irrelevant when leaf
-    children_left = children_left[~is_leaf].astype(dtype_child)
-    children_right = children_right[~is_leaf].astype(dtype_child)
-    features = nodes["feature"][~is_leaf].astype(dtype_feature)
+
+    children_left = safe_cast(children_left[~is_leaf], dtype_child)
+    children_right = safe_cast(children_right[~is_leaf], dtype_child)
+    features = safe_cast(nodes["feature"][~is_leaf], dtype_feature)
     # value is irrelevant when node not a leaf
-    values = state["values"][is_leaf].astype(dtype_value)
+    values = safe_cast(state["values"][is_leaf], dtype_value)
     # do lossless compression for thresholds by downcasting half ints (e.g. 5.5, 10.5, ...) to int8
     thresholds = nodes["threshold"][~is_leaf].astype(dtype_threshold)
     thresholds = compress_half_int_float_array(thresholds)
diff --git a/tests/test_compression_utils.py b/tests/test_compression_utils.py
index fac5da2..82bf26d 100644
--- a/tests/test_compression_utils.py
+++ b/tests/test_compression_utils.py
@@ -1,12 +1,40 @@
 import numpy as np
+import pytest
 
 from slim_trees.compression_utils import (
     _is_in_neighborhood_of_int,
     compress_half_int_float_array,
     decompress_half_int_float_array,
+    safe_cast,
 )
 
 
+@pytest.mark.parametrize(
+    "arr,dtype",
+    [
+        (np.array([1, 2, 3, 4, 5]), np.int8),
+        (np.array([1, 2, 3, 4, 5]), np.uint8),
+        (np.array([1, 2, 3, 4, 5]), np.int16),
+        (np.array([200]), np.uint16),
+        (np.array([], dtype=np.int64), np.int16),
+    ],
+)
+def test_safe_cast(arr, dtype):
+    safe_cast(arr, dtype)
+
+
+@pytest.mark.parametrize(
+    "arr,dtype",
+    [
+        (np.array([1, 2, 3, 555555]), np.int16),
+        (np.array([-1, 4, 6]), np.uint32),
+    ],
+)
+def test_safe_cast_error(arr, dtype):
+    with pytest.raises(ValueError):
+        safe_cast(arr, dtype)
+
+
 def test_compress_half_int_float_array():
     a1 = np.array([0, 1, 2.5, np.pi, -np.pi, 1e5, 35.5, 2.50000000001])
     state = compress_half_int_float_array(a1)