Skip to content

Commit

Permalink
Make dtype casts safe (#38)
Browse files Browse the repository at this point in the history
  • Loading branch information
pavelzw authored Mar 20, 2023
1 parent 7967b9e commit e85e0fd
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 4 deletions.
6 changes: 6 additions & 0 deletions slim_trees/compression_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
import numpy as np


def safe_cast(arr: np.array, dtype):
if np.can_cast(arr.max(), dtype) and np.can_cast(arr.min(), dtype):
return arr.astype(dtype)
raise ValueError(f"Cannot cast array to {dtype}.")


def _is_in_neighborhood_of_int(arr, iinfo, eps=1e-12):
"""
Checks if the numbers are around an integer.
Expand Down
5 changes: 5 additions & 0 deletions slim_trees/lgbm_booster.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from slim_trees.compression_utils import (
compress_half_int_float_array,
decompress_half_int_float_array,
safe_cast,
)
from slim_trees.utils import check_version

Expand Down Expand Up @@ -102,6 +103,10 @@ def _extract_feature(feature_line):
feats_map = dict(_extract_feature(fl) for fl in features)

def parse(str_list, dtype):
if np.can_cast(dtype, np.int64):
int64_array = np.array(str_list, dtype=np.int64)
return safe_cast(int64_array, dtype)
assert np.can_cast(dtype, np.float64)
return np.array(str_list, dtype=dtype)

split_feature_dtype = np.int16
Expand Down
10 changes: 6 additions & 4 deletions slim_trees/sklearn_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from slim_trees.compression_utils import (
compress_half_int_float_array,
decompress_half_int_float_array,
safe_cast,
)
from slim_trees.utils import check_version

Expand Down Expand Up @@ -70,11 +71,12 @@ def _compress_tree_state(state: dict):
assert np.array_equal(is_leaf, children_right == -1)

# feature, threshold and children are irrelevant when leaf
children_left = children_left[~is_leaf].astype(dtype_child)
children_right = children_right[~is_leaf].astype(dtype_child)
features = nodes["feature"][~is_leaf].astype(dtype_feature)

children_left = safe_cast(children_left[~is_leaf], dtype_child)
children_right = safe_cast(children_right[~is_leaf], dtype_child)
features = safe_cast(nodes["feature"][~is_leaf], dtype_feature)
# value is irrelevant when node not a leaf
values = state["values"][is_leaf].astype(dtype_value)
values = safe_cast(state["values"][is_leaf], dtype_value)
# do lossless compression for thresholds by downcasting half ints (e.g. 5.5, 10.5, ...) to int8
thresholds = nodes["threshold"][~is_leaf].astype(dtype_threshold)
thresholds = compress_half_int_float_array(thresholds)
Expand Down
27 changes: 27 additions & 0 deletions tests/test_compression_utils.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,39 @@
import numpy as np
import pytest

from slim_trees.compression_utils import (
_is_in_neighborhood_of_int,
compress_half_int_float_array,
decompress_half_int_float_array,
safe_cast,
)


@pytest.mark.parametrize(
"arr,dtype",
[
(np.array([1, 2, 3, 4, 5]), np.int8),
(np.array([1, 2, 3, 4, 5]), np.uint8),
(np.array([1, 2, 3, 4, 5]), np.int16),
(np.array([200]), np.uint16),
],
)
def test_safe_cast(arr, dtype):
safe_cast(arr, dtype)


@pytest.mark.parametrize(
"arr,dtype",
[
(np.array([1, 2, 3, 555555]), np.int16),
(np.array([-1, 4, 6]), np.uint32),
],
)
def test_safe_cast_error(arr, dtype):
with pytest.raises(ValueError):
safe_cast(arr, dtype)


def test_compress_half_int_float_array():
a1 = np.array([0, 1, 2.5, np.pi, -np.pi, 1e5, 35.5, 2.50000000001])
state = compress_half_int_float_array(a1)
Expand Down

0 comments on commit e85e0fd

Please sign in to comment.