From e0e096687cae0faccf0f0d1a31344abc0177c4ba Mon Sep 17 00:00:00 2001
From: Rory Mitchell
Date: Wed, 19 Jul 2023 12:20:54 +0200
Subject: [PATCH] Add property based testing (#10)

---
 .github/workflows/github-actions.yml     |   5 +-
 .gitignore                               |   2 +-
 legateboost/input_validation.py          |   2 +-
 legateboost/legateboost.py               | 206 +++++++++++------------
 legateboost/test/test_estimator.py       |  58 +------
 legateboost/test/test_with_hypothesis.py | 190 +++++++++++++++++++++
 legateboost/test/utils.py                |  26 +++
 7 files changed, 330 insertions(+), 159 deletions(-)
 create mode 100644 legateboost/test/test_with_hypothesis.py
 create mode 100644 legateboost/test/utils.py

diff --git a/.github/workflows/github-actions.yml b/.github/workflows/github-actions.yml
index 33f9e011..2483169c 100644
--- a/.github/workflows/github-actions.yml
+++ b/.github/workflows/github-actions.yml
@@ -47,7 +47,7 @@ jobs:
           python scripts/generate-conda-envs.py --python 3.10 --ctk 11.8 --os linux --compilers --openmpi;
           mamba env create -n legate -f environment-test-linux-py3.10-cuda11.8-compilers-openmpi-ucx.yaml;
           mamba activate legate;
-          mamba install -y -c conda-forge openmpi ucx rust scikit-learn build;
+          mamba install -y 'gcc<=10.0.0' build scikit-learn;  # workaround issue https://github.com/nv-legate/legate.core/issues/789
           mkdir -p _skbuild/linux-x86_64-3.10/cmake-build
           ./install.py --network ucx --cuda --arch RAPIDS --verbose;
@@ -119,6 +119,7 @@ jobs:
           path: docs/build
       - name: Run cpu tests
         run: |
+          pip install hypothesis
           legate --module pytest legateboost/test -sv --durations=0
       - name: Run cpu multi-node tests
         run: |
@@ -127,7 +128,7 @@ jobs:
         if: ${{ false }} # disabled due to issue #5
       - name: Run gpu tests
         run: |
-          legate --gpus 1 --module pytest legateboost/test -sv --durations=0 -k 'not sklearn'
+          legate --gpus 1 --fbmem 16000 --sysmem 16000 --module pytest legateboost/test -sv --durations=0 -k 'not sklearn'
 
   deploy:

diff --git a/.gitignore b/.gitignore
index a37a0389..8e897ece 100644
--- a/.gitignore
+++ b/.gitignore
@@ -7,4 +7,4 @@ legateboost/install_info.py
 src/legate_library.cc
 src/legate_library.h
 legate_prof*
-build
+.hypothesis

diff --git a/legateboost/input_validation.py b/legateboost/input_validation.py
index a04780ed..7aeef0fc 100644
--- a/legateboost/input_validation.py
+++ b/legateboost/input_validation.py
@@ -47,7 +47,7 @@ def check_array(x: Any) -> cn.ndarray:
     if cn.iscomplexobj(x):
         raise ValueError("Complex data not supported.")
 
-    if not cn.isfinite(x).all():
+    if np.issubdtype(x.dtype, np.floating) and not cn.isfinite(x).all():
        raise ValueError("Input contains NaN or inf")
 
     x = cn.array(x, copy=False)
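Note on the input_validation change above: the NaN/inf scan now runs only for floating-point inputs. Integer arrays (such as the integer class labels the property-based tests below generate) cannot hold NaN or inf, so scanning them is redundant. A minimal standalone sketch of the guard, with NumPy standing in for cunumeric:

    import numpy as np

    def validate_finite(x: np.ndarray) -> np.ndarray:
        # Integer arrays cannot hold NaN or inf, so only scan floating dtypes.
        if np.issubdtype(x.dtype, np.floating) and not np.isfinite(x).all():
            raise ValueError("Input contains NaN or inf")
        return x

    validate_finite(np.arange(4))                  # int64 labels: scan skipped
    try:
        validate_finite(np.array([1.0, np.nan]))   # float with NaN: rejected
    except ValueError as e:
        print(e)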
diff --git a/legateboost/legateboost.py b/legateboost/legateboost.py
index 29144752..b3842519 100644
--- a/legateboost/legateboost.py
+++ b/legateboost/legateboost.py
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import math
 import warnings
 from enum import IntEnum
 from typing import Any, Tuple, Union
@@ -71,38 +72,107 @@ def left_child(self, id: int) -> int:
     def right_child(self, id: int) -> int:
         return id * 2 + 2
 
+    def num_procs_to_use(self, num_rows: int) -> int:
+        min_rows_per_worker = 10
+        available_procs = len(get_legate_runtime().machine)
+        return min(available_procs, int(math.ceil(num_rows / min_rows_per_worker)))
+
     def __init__(
         self,
-        leaf_value: cn.ndarray,
-        feature: cn.ndarray,
-        split_value: cn.ndarray,
-        gain: cn.ndarray,
-        hessian: cn.ndarray,
+        X: cn.ndarray,
+        g: cn.ndarray,
+        h: cn.ndarray,
+        learning_rate: float,
+        max_depth: int,
+        random_state: np.random.RandomState,
     ) -> None:
-        """Initialise from existing storage."""
-        self.leaf_value = leaf_value
-        assert leaf_value.dtype == cn.float64
-        self.feature = feature
-        assert feature.dtype == cn.int32
-        self.split_value = split_value
-        assert split_value.dtype == cn.float64
-        self.hessian = hessian
-        assert hessian.dtype == cn.float64
-        self.gain = gain
-        assert gain.dtype == cn.float64
+        # choose possible splits
+        split_proposals = X[
+            random_state.randint(0, X.shape[0], max_depth)
+        ]  # may not be efficient, maybe write new task
+        num_features = X.shape[1]
+        num_outputs = g.shape[1]
+        n_rows = X.shape[0]
+        num_procs = self.num_procs_to_use(n_rows)
+        use_gpu = get_legate_runtime().machine.preferred_kind == 1
+        rows_per_tile = int(cn.ceil(n_rows / num_procs))
+
+        task = user_context.create_manual_task(
+            LegateBoostOpCode.BUILD_TREE, launch_domain=Rect((num_procs, 1))
+        )
+
+        # Defining a projection function (even the identity) prevents legate
+        # from trying to assign empty tiles to workers
+        # in the case where the number of tiles is less than the launch grid
+        def proj(x: Tuple[int, int]) -> Tuple[int, int]:
+            return (x[0], 0)  # everything crashes if this is lambda x: x ????
+
+        # inputs
+        task.add_scalar_arg(learning_rate, types.float64)
+        task.add_scalar_arg(max_depth, types.int32)
+        task.add_scalar_arg(random_state.randint(0, 2**32), types.uint64)
+
+        task.add_input(
+            partition_if_not_future(X, (rows_per_tile, num_features)), proj=proj
+        )
+        task.add_input(
+            partition_if_not_future(g, (rows_per_tile, num_outputs)), proj=proj
+        )
+        task.add_input(
+            partition_if_not_future(h, (rows_per_tile, num_outputs)), proj=proj
+        )
+        task.add_input(_get_store(split_proposals))
+
+        # outputs
+        # force 1d arrays to be 2d otherwise we get the dreaded assert proj_id == 0
+        max_nodes = 2 ** (max_depth + 1)
+        leaf_value = user_context.create_store(types.float64, (max_nodes, num_outputs))
+        feature = user_context.create_store(types.int32, (max_nodes, 1))
+        split_value = user_context.create_store(types.float64, (max_nodes, 1))
+        gain = user_context.create_store(types.float64, (max_nodes, 1))
+        hessian = user_context.create_store(types.float64, (max_nodes, num_outputs))
+
+        # All outputs belong to a single tile on worker 0
+        task.add_output(
+            leaf_value.partition_by_tiling((max_nodes, num_outputs)), proj=proj
+        )
+        task.add_output(feature.partition_by_tiling((max_nodes, 1)), proj=proj)
+        task.add_output(split_value.partition_by_tiling((max_nodes, 1)), proj=proj)
+        task.add_output(gain.partition_by_tiling((max_nodes, 1)), proj=proj)
+        task.add_output(
+            hessian.partition_by_tiling((max_nodes, num_outputs)), proj=proj
+        )
+
+        if num_procs > 1:
+            if use_gpu:
+                task.add_nccl_communicator()
+            else:
+                task.add_cpu_communicator()
+
+        task.execute()
+
+        self.leaf_value = cn.array(leaf_value, copy=False)
+        self.feature = cn.array(feature, copy=False).squeeze()
+        self.split_value = cn.array(split_value, copy=False).squeeze()
+        self.gain = cn.array(gain, copy=False).squeeze()
+        self.hessian = cn.array(hessian, copy=False)
 
     def predict(self, X: cn.ndarray) -> cn.ndarray:
         n_rows = X.shape[0]
         n_features = X.shape[1]
         n_outputs = self.leaf_value.shape[1]
-        num_procs = len(get_legate_runtime().machine)
-        # dont launch more tasks than rows
-        num_procs = min(num_procs, n_rows)
+        num_procs = self.num_procs_to_use(n_rows)
         rows_per_tile = int(cn.ceil(n_rows / num_procs))
         task = user_context.create_manual_task(
             LegateBoostOpCode.PREDICT, Rect((num_procs, 1))
         )
-        task.add_input(partition_if_not_future(X, (rows_per_tile, n_features)))
+
+        def proj(x: Tuple[int, int]) -> Tuple[int, int]:
+            return (x[0], 0)
+
+        task.add_input(
+            partition_if_not_future(X, (rows_per_tile, n_features)), proj=proj
+        )
 
         # broadcast the tree structure
         task.add_input(_get_store(self.leaf_value))
@@ -110,7 +180,9 @@ def predict(self, X: cn.ndarray) -> cn.ndarray:
         task.add_input(_get_store(self.split_value))
 
         pred = user_context.create_store(types.float64, (n_rows, n_outputs))
-        task.add_output(partition_if_not_future(pred, (rows_per_tile, n_outputs)))
+        task.add_output(
+            partition_if_not_future(pred, (rows_per_tile, n_outputs)), proj=proj
+        )
 
         task.execute()
         return cn.array(pred, copy=False)
@@ -166,81 +238,6 @@ def _get_store(input: Any) -> Store:
     return store
 
 
-def build_tree_native(
-    X: cn.ndarray,
-    g: cn.ndarray,
-    h: cn.ndarray,
-    learning_rate: float,
-    max_depth: int,
-    random_state: np.random.RandomState,
-) -> TreeStructure:
-
-    # choose possible splits
-    split_proposals = X[
-        random_state.randint(0, X.shape[0], max_depth)
-    ]  # may not be efficient, maybe write new task
-    num_features = X.shape[1]
-    num_outputs = g.shape[1]
-    n_rows = X.shape[0]
-    num_procs = len(get_legate_runtime().machine)
-    use_gpu = get_legate_runtime().machine.preferred_kind == 1
-    # dont launch more tasks than rows
-    num_procs = min(num_procs, n_rows)
-    rows_per_tile = int(cn.ceil(n_rows / num_procs))
-
-    task = user_context.create_manual_task(
-        LegateBoostOpCode.BUILD_TREE, launch_domain=Rect((num_procs, 1))
-    )
-
-    # inputs
-    task.add_scalar_arg(learning_rate, types.float64)
-    task.add_scalar_arg(max_depth, types.int32)
-    task.add_scalar_arg(random_state.randint(0, 2**32), types.uint64)
-
-    task.add_input(partition_if_not_future(X, (rows_per_tile, num_features)))
-    task.add_input(partition_if_not_future(g, (rows_per_tile, num_outputs)))
-    task.add_input(partition_if_not_future(h, (rows_per_tile, num_outputs)))
-    task.add_input(_get_store(split_proposals))
-
-    # outputs
-    # force 1d arrays to be 2d otherwise we get the dreaded assert proj_id == 0
-    max_nodes = 2 ** (max_depth + 1)
-    leaf_value = user_context.create_store(types.float64, (max_nodes, num_outputs))
-    feature = user_context.create_store(types.int32, (max_nodes, 1))
-    split_value = user_context.create_store(types.float64, (max_nodes, 1))
-    gain = user_context.create_store(types.float64, (max_nodes, 1))
-    hessian = user_context.create_store(types.float64, (max_nodes, num_outputs))
-
-    # All outputs belong to a single tile on worker 0
-    # Defining a projection function (even the identity) prevents legate
-    # from trying to assign empty tiles to workers
-    # in the case where the number of tiles is less than the launch grid
-    def proj(x: Tuple[int, int]) -> Tuple[int, int]:
-        return (x[0], 0)  # everything crashes if this is lambda x: x ????
-
-    task.add_output(leaf_value.partition_by_tiling((max_nodes, num_outputs)), proj=proj)
-    task.add_output(feature.partition_by_tiling((max_nodes, 1)), proj=proj)
-    task.add_output(split_value.partition_by_tiling((max_nodes, 1)), proj=proj)
-    task.add_output(gain.partition_by_tiling((max_nodes, 1)), proj=proj)
-    task.add_output(hessian.partition_by_tiling((max_nodes, num_outputs)), proj=proj)
-
-    if num_procs > 1:
-        if use_gpu:
-            task.add_nccl_communicator()
-        else:
-            task.add_cpu_communicator()
-
-    task.execute()
-
-    return TreeStructure(
-        cn.array(leaf_value, copy=False),
-        cn.array(feature, copy=False).squeeze(),
-        cn.array(split_value, copy=False).squeeze(),
-        cn.array(gain, copy=False).squeeze(),
-        cn.array(hessian, copy=False),
-    )
-
-
 class LBBase(BaseEstimator, _PickleCunumericMixin):
     def __init__(
         self,
@@ -324,15 +321,16 @@ def fit(
             h = h * sample_weight[:, None]
 
             # build new tree
-            tree = build_tree_native(
-                X,
-                g,
-                h,
-                self.learning_rate,
-                self.max_depth,
-                check_random_state(self.random_state),
+            self.models_.append(
+                TreeStructure(
+                    X,
+                    g,
+                    h,
+                    self.learning_rate,
+                    self.max_depth,
+                    check_random_state(self.random_state),
+                )
             )
-            self.models_.append(tree)
 
             # update current predictions
             pred += self.models_[-1].predict(X)
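Note on num_procs_to_use above: the old code only capped the launch grid at one task per row, whereas the new heuristic also requires at least 10 rows per worker, so the tiny datasets Hypothesis generates do not fan out across the whole machine. A standalone sketch of the arithmetic (the 8-processor machine is an arbitrary example, not something the patch assumes):

    import math

    def num_procs_to_use(num_rows, available_procs, min_rows_per_worker=10):
        # Mirrors TreeStructure.num_procs_to_use: never launch more tasks than
        # ceil(num_rows / 10), so a one-row dataset gets exactly one worker.
        return min(available_procs, math.ceil(num_rows / min_rows_per_worker))

    for rows in (1, 9, 95, 10_000):
        procs = num_procs_to_use(rows, available_procs=8)
        rows_per_tile = math.ceil(rows / procs)      # tile height per worker
        print(rows, procs, rows_per_tile)
    # 1 -> 1 proc, 9 -> 1 proc, 95 -> 8 procs (12 rows/tile), 10000 -> 8 procs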
diff --git a/legateboost/test/test_estimator.py b/legateboost/test/test_estimator.py
index dead3834..14ff2d9d 100644
--- a/legateboost/test/test_estimator.py
+++ b/legateboost/test/test_estimator.py
@@ -1,34 +1,12 @@
 import numpy as np
 import pytest
+import utils
 from sklearn.utils.estimator_checks import parametrize_with_checks
 
 import cunumeric as cn
 
 import legateboost as lb
 
 
-def non_increasing(x):
-    return all(x >= y for x, y in zip(x, x[1:]))
-
-
-def non_decreasing(x):
-    return all(x <= y for x, y in zip(x, x[1:]))
-
-
-def sanity_check_tree_stats(trees):
-    for tree in trees:
-        # Check that we have no 0 hessian splits
-        split_nodes = tree.feature != -1
-        assert cn.all(tree.hessian[split_nodes] > 0.0)
-
-        # Check gain is positive
-        assert cn.all(tree.gain[split_nodes] > 0.0)
-
-        # Check that hessians of leaves add up to root.
-        leaves = (tree.feature == -1) & (tree.hessian[:, 0] > 0.0)
-        leaf_sum = tree.hessian[leaves].sum(axis=0)
-        assert np.isclose(leaf_sum, tree.hessian[0]).all()
-
-
 @pytest.mark.parametrize("num_outputs", [1, 5])
 def test_regressor(num_outputs):
     np.random.seed(2)
@@ -39,11 +17,11 @@ def test_regressor(num_outputs):
     ).fit(X, y)
     mse = lb.MSEMetric().metric(y, model.predict(X), cn.ones(y.shape[0]))
     assert np.isclose(model.train_metric_[-1], mse)
-    assert non_increasing(model.train_metric_)
+    assert utils.non_increasing(model.train_metric_)
 
     # test print
     model.dump_trees()
-    sanity_check_tree_stats(model.models_)
+    utils.sanity_check_tree_stats(model.models_)
 
 
 @pytest.mark.parametrize("num_outputs", [1, 5])
@@ -58,7 +36,7 @@ def test_regressor_improving_with_depth(num_outputs):
         )
         metrics.append(model.train_metric_[-1])
 
-    assert non_increasing(metrics)
+    assert utils.non_increasing(metrics)
 
 
 @pytest.mark.parametrize("num_outputs", [1, 5])
@@ -140,9 +118,9 @@ def test_classifier(num_class):
     model = lb.LBClassifier(n_estimators=10).fit(X, y)
     loss = lb.LogLossMetric().metric(y, model.predict_proba(X), cn.ones(y.shape[0]))
     assert np.isclose(model.train_metric_[-1], loss)
-    assert non_increasing(model.train_metric_)
+    assert utils.non_increasing(model.train_metric_)
     assert model.score(X, y) > 0.7
-    sanity_check_tree_stats(model.models_)
+    utils.sanity_check_tree_stats(model.models_)
 
 
 @pytest.mark.parametrize("num_class", [2, 5])
@@ -168,26 +146,4 @@ def test_classifier_improving_with_depth(num_class):
             n_estimators=2, random_state=0, max_depth=max_depth
         ).fit(X, y)
         metrics.append(model.train_metric_[-1])
-    assert non_increasing(metrics)
-
-
-def test_prediction():
-    tree = lb.TreeStructure(
-        cn.array(
-            [0.0, 0.0, -0.04619769, 0.01845179, -0.01151532, 0.0, 0.0, 0.0, 0.0, 0.0]
-        ).reshape(10, 1),
-        cn.array([0, 1, -1, -1, -1, -1, -1, -1, -1, -1]).astype(cn.int32),
-        cn.array([0.79172504, 0.71518937, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]),
-        cn.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]),
-        cn.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]).reshape(
-            10,
-            1,
-        ),
-    )
-    """0:[f0<=0.79172504] yes=1 no=2 1:[f1<=0.71518937] yes=3 no=4
-    3:leaf=0.01845179 4:leaf=-0.01151532 2:leaf=-0.04619769."""
-    pred = tree.predict(cn.array([[1.0, 0.0], [0.5, 0.5]]))
-    assert pred[0] == -0.04619769
-    assert pred[1] == 0.01845179
-    assert tree.predict(cn.array([[0.5, 1.0]])) == -0.01151532
-    assert tree.predict(cn.array([[0.79172504, 0.71518937]])) == 0.01845179
+    assert utils.non_increasing(metrics)
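Note on the removed test_prediction: it had to be dropped because TreeStructure's constructor now builds a tree from data rather than from existing storage, but its hard-coded tree documents the flat array layout (children of node i at 2*i + 1 and 2*i + 2, feature == -1 marking a leaf, the yes-branch taken when x[feature] <= split_value, per the dump string in the test). A pure-NumPy reference traversal under those conventions; the real predict runs as a distributed Legate task, so this is only an illustration of the layout:

    import numpy as np

    def predict_reference(leaf_value, feature, split_value, X):
        # Walk each row down the implicit binary tree until a leaf is hit.
        out = np.empty((X.shape[0], leaf_value.shape[1]))
        for r, x in enumerate(X):
            i = 0
            while feature[i] != -1:
                i = 2 * i + 1 if x[feature[i]] <= split_value[i] else 2 * i + 2
            out[r] = leaf_value[i]
        return out

    leaf = np.array(
        [0, 0, -0.04619769, 0.01845179, -0.01151532, 0, 0, 0, 0, 0]
    ).reshape(10, 1)
    feat = np.array([0, 1, -1, -1, -1, -1, -1, -1, -1, -1])
    split = np.array([0.79172504, 0.71518937, 0, 0, 0, 0, 0, 0, 0, 0])
    print(predict_reference(leaf, feat, split, np.array([[1.0, 0.0], [0.5, 0.5]])))
    # [[-0.04619769], [0.01845179]] -- the same answers the removed test asserted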
diff --git a/legateboost/test/test_with_hypothesis.py b/legateboost/test/test_with_hypothesis.py
new file mode 100644
index 00000000..4b63d308
--- /dev/null
+++ b/legateboost/test/test_with_hypothesis.py
@@ -0,0 +1,190 @@
+import numpy as np
+import pytest
+import utils
+from hypothesis import HealthCheck, Verbosity, given, settings, strategies as st
+
+import cunumeric as cn
+
+import legateboost as lb
+from legate.core import get_legate_runtime
+
+np.set_printoptions(threshold=10, edgeitems=1)
+
+# adjust max_examples to control runtime
+settings.register_profile(
+    "local",
+    max_examples=50,
+    deadline=None,
+    verbosity=Verbosity.verbose,
+    suppress_health_check=(HealthCheck.too_slow,),
+    print_blob=True,
+)
+
+settings.load_profile("local")
+
+
+general_model_param_strategy = st.fixed_dictionaries(
+    {
+        "n_estimators": st.integers(1, 20),
+        "max_depth": st.integers(1, 12),
+        "init": st.sampled_from([None, "average"]),
+        "random_state": st.integers(0, 10000),
+    }
+)
+
+regression_param_strategy = st.fixed_dictionaries(
+    {
+        "objective": st.sampled_from(["squared_error"]),
+        "learning_rate": st.floats(0.01, 1.0),
+    }
+)
+
+
+@st.composite
+def regression_real_dataset_strategy(draw):
+    from sklearn.datasets import fetch_california_housing, fetch_openml, load_diabetes
+
+    name = draw(st.sampled_from(["california_housing", "million_songs", "diabetes"]))
+    if name == "california_housing":
+        return fetch_california_housing(return_X_y=True)
+    elif name == "million_songs":
+        return fetch_openml(name="year", version=1, return_X_y=True, as_frame=False)
+    elif name == "diabetes":
+        return load_diabetes(return_X_y=True)
+
+
+@st.composite
+def regression_generated_dataset_strategy(draw):
+    num_outputs = draw(st.integers(1, 5))
+    num_features = draw(st.integers(1, 150))
+    num_rows = draw(st.integers(1, 10000))
+    np.random.seed(2)
+    X = cn.random.random((num_rows, num_features))
+    y = cn.random.random((X.shape[0], num_outputs))
+
+    dtype = draw(st.sampled_from([np.float16, np.float32, np.float64]))
+    return X.astype(dtype), y.astype(dtype)
+
+
+@st.composite
+def regression_dataset_strategy(draw):
+    X, y = draw(
+        st.one_of(
+            [
+                regression_generated_dataset_strategy(),
+                regression_real_dataset_strategy(),
+            ]
+        )
+    )
+    if draw(st.booleans()):
+        w = cn.random.random(y.shape[0])
+    else:
+        w = None
+
+    return X, y, w
+
+
+@given(
+    general_model_param_strategy,
+    regression_param_strategy,
+    regression_dataset_strategy(),
+)
+def test_regressor(model_params, regression_params, regression_dataset):
+    X, y, w = regression_dataset
+    model = lb.LBRegressor(**model_params, **regression_params).fit(
+        X, y, sample_weight=w
+    )
+    model.predict(X)
+    assert utils.non_increasing(model.train_metric_)
+
+    utils.sanity_check_tree_stats(model.models_)
+
+
+classification_param_strategy = st.fixed_dictionaries(
+    {
+        "objective": st.sampled_from(["log_loss"]),
+        # we can technically have up to learning rate 1.0, however
+        # some problems may not converge (e.g. multiclass classification
+        # with many classes) unless the learning rate is sufficiently small
+        "learning_rate": st.floats(0.01, 0.3),
+    }
+)
+
+
+@st.composite
+def classification_real_dataset_strategy(draw):
+    from sklearn.datasets import fetch_covtype, load_breast_cancer
+
+    name = draw(st.sampled_from(["covtype", "breast_cancer"]))
+    if name == "covtype":
+        X, y = fetch_covtype(return_X_y=True, as_frame=False)
+        return (X, y - 1, name)
+    elif name == "breast_cancer":
+        return (*load_breast_cancer(return_X_y=True, as_frame=False), name)
+
+
+@st.composite
+def classification_generated_dataset_strategy(draw):
+    num_classes = draw(st.integers(2, 5))
+    num_features = draw(st.integers(1, 150))
+    num_rows = draw(st.integers(num_classes, 10000))
+    np.random.seed(3)
+    X = cn.random.random((num_rows, num_features))
+    y = cn.random.randint(0, num_classes, size=X.shape[0])
+
+    # ensure we have at least one of each class
+    y[:num_classes] = np.arange(num_classes)
+
+    X_dtype = draw(st.sampled_from([np.float16, np.float32, np.float64]))
+    y_dtype = draw(
+        st.sampled_from(
+            [np.int8, np.uint16, np.int32, np.int64, np.float32, np.float64]
+        )
+    )
+
+    return (
+        X.astype(X_dtype),
+        y.astype(y_dtype),
+        "Generated: num_classes: {}, num_features: {}, num_rows: {}".format(
+            num_classes, num_features, num_rows
+        ),
+    )
+
+
+@st.composite
+def classification_dataset_strategy(draw):
+    X, y, name = draw(
+        st.one_of(
+            [
+                classification_generated_dataset_strategy(),
+                classification_real_dataset_strategy(),
+            ]
+        )
+    )
+    if draw(st.booleans()):
+        w = cn.random.random(y.shape[0])
+    else:
+        w = None
+
+    return X, y, w, name
+
+
+@given(
+    general_model_param_strategy,
+    classification_param_strategy,
+    classification_dataset_strategy(),
+)
+@pytest.mark.skipif(
+    get_legate_runtime().machine.preferred_kind == 1,
+    reason="Fails with V100 GPU, see issue #14",
+)
+def test_classifier(model_params, classification_params, classification_dataset):
+    X, y, w, name = classification_dataset
+    model = lb.LBClassifier(**model_params, **classification_params).fit(
+        X, y, sample_weight=w
+    )
+    model.predict(X)
+    model.predict_proba(X)
+    model.predict_raw(X)
+    assert utils.non_increasing(model.train_metric_)
+
+    utils.sanity_check_tree_stats(model.models_)
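Note on the Hypothesis configuration above: the module registers and loads a single "local" profile, and the comment suggests tuning max_examples to control runtime. A sketch of how an alternative, quicker profile could sit alongside it and be selected per run; the "smoke" name and its settings are hypothetical, not part of this patch:

    from hypothesis import settings

    # A trimmed-down companion profile for quick local smoke runs.
    settings.register_profile("smoke", max_examples=5, deadline=None)

    # Hypothesis's pytest plugin lets a run select a registered profile:
    #   legate --module pytest legateboost/test/test_with_hypothesis.py \
    #       --hypothesis-profile=smoke
    # With print_blob=True in the active profile, a failing run prints a
    # @reproduce_failure(...) decorator that replays the exact failing example.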
diff --git a/legateboost/test/utils.py b/legateboost/test/utils.py
new file mode 100644
index 00000000..ceaa471c
--- /dev/null
+++ b/legateboost/test/utils.py
@@ -0,0 +1,26 @@
+import numpy as np
+
+import cunumeric as cn
+
+
+def non_increasing(x, tol=1e-5):
+    return all(x - y > -tol for x, y in zip(x, x[1:]))
+
+
+def non_decreasing(x):
+    return all(x <= y for x, y in zip(x, x[1:]))
+
+
+def sanity_check_tree_stats(trees):
+    for tree in trees:
+        # Check that we have no 0 hessian splits
+        split_nodes = tree.feature != -1
+        assert cn.all(tree.hessian[split_nodes] > 0.0)
+
+        # Check gain is positive
+        assert cn.all(tree.gain[split_nodes] > 0.0)
+
+        # Check that hessians of leaves add up to root.
+        leaves = (tree.feature == -1) & (tree.hessian[:, 0] > 0.0)
+        leaf_sum = tree.hessian[leaves].sum(axis=0)
+        assert np.isclose(leaf_sum, tree.hessian[0]).all()
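Note on utils.non_increasing: the strict pairwise check it replaces in test_estimator.py (all(x >= y ...)) rejects training curves that tick up by a rounding error, which Hypothesis finds quickly once float32 and float16 data are in play, so the shared helper takes a tolerance. A self-contained comparison with illustrative values, not taken from a real run:

    def strict(xs):
        return all(a >= b for a, b in zip(xs, xs[1:]))

    def tolerant(xs, tol=1e-5):
        return all(a - b > -tol for a, b in zip(xs, xs[1:]))

    metric = [0.5, 0.4, 0.40000001, 0.3]  # rounding-level uptick mid-run
    print(strict(metric))    # False: the old check rejects a healthy run
    print(tolerant(metric))  # True: the uptick is well inside the 1e-5 tolerance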