Commit b8201bd

Merge branch 'main' of github.com:rapidsai/legateboost into exp

RAMitchell committed Jul 19, 2023
2 parents: 6cc534e + e0e0966

Showing 7 changed files with 330 additions and 159 deletions.
.github/workflows/github-actions.yml (5 changes: 3 additions & 2 deletions)

@@ -47,7 +47,7 @@ jobs:
       python scripts/generate-conda-envs.py --python 3.10 --ctk 11.8 --os linux --compilers --openmpi;
       mamba env create -n legate -f environment-test-linux-py3.10-cuda11.8-compilers-openmpi-ucx.yaml;
       mamba activate legate;
-      mamba install -y -c conda-forge openmpi ucx rust scikit-learn build;
+      mamba install -y 'gcc<=10.0.0' build scikit-learn;
       # workaround issue https://github.com/nv-legate/legate.core/issues/789
       mkdir -p _skbuild/linux-x86_64-3.10/cmake-build
       ./install.py --network ucx --cuda --arch RAPIDS --verbose;
Expand Down Expand Up @@ -119,6 +119,7 @@ jobs:
path: docs/build
- name: Run cpu tests
run: |
pip install hypothesis
legate --module pytest legateboost/test -sv --durations=0
- name: Run cpu multi-node tests
run: |
@@ -127,7 +128,7 @@ jobs:
     if: ${{ false }} # disabled due to issue #5
   - name: Run gpu tests
     run: |
-      legate --gpus 1 --module pytest legateboost/test -sv --durations=0 -k 'not sklearn'
+      legate --gpus 1 --fbmem 16000 --sysmem 16000 --module pytest legateboost/test -sv --durations=0 -k 'not sklearn'
   deploy:
.gitignore (2 changes: 1 addition & 1 deletion)

@@ -7,4 +7,4 @@ legateboost/install_info.py
 src/legate_library.cc
 src/legate_library.h
 legate_prof*
-build
+.hypothesis
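Note on the hypothesis changes: the workflow above now installs `hypothesis` before running the CPU tests, and .gitignore picks up the `.hypothesis/` example-database directory the library writes, which points to property-based tests arriving with this merge. A minimal, hypothetical example of such a test (the test name and property are illustrative, not taken from this commit):

    import numpy as np
    from hypothesis import given, strategies as st

    # Property: any finite float input satisfies the NaN/inf validation rule
    # used by check_array (logic inlined here for illustration).
    @given(st.lists(st.floats(allow_nan=False, allow_infinity=False), min_size=1))
    def test_finite_floats_are_accepted(values):
        x = np.array(values, dtype=np.float64)
        assert np.isfinite(x).all()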
legateboost/input_validation.py (2 changes: 1 addition & 1 deletion)

@@ -47,7 +47,7 @@ def check_array(x: Any) -> cn.ndarray:
     if cn.iscomplexobj(x):
         raise ValueError("Complex data not supported.")
 
-    if not cn.isfinite(x).all():
+    if np.issubdtype(x.dtype, np.floating) and not cn.isfinite(x).all():
         raise ValueError("Input contains NaN or inf")
 
     x = cn.array(x, copy=False)
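The guard added above restricts the finiteness check to floating-point inputs: integer arrays cannot represent NaN or inf, so checking them is redundant, and the check can fail outright for non-numeric dtypes. A minimal sketch of the behaviour, using plain NumPy in place of cunumeric:

    import numpy as np

    def contains_nan_or_inf(x: np.ndarray) -> bool:
        # Only floating dtypes can hold NaN/inf, so test the dtype first.
        return bool(np.issubdtype(x.dtype, np.floating) and not np.isfinite(x).all())

    assert not contains_nan_or_inf(np.array([1, 2, 3]))   # integer input: skipped
    assert contains_nan_or_inf(np.array([1.0, np.nan]))   # caught
    assert not contains_nan_or_inf(np.array([1.0, 2.0]))  # finite floats pass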
legateboost/legateboost.py (206 changes: 102 additions & 104 deletions)

@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import math
 import warnings
 from enum import IntEnum
 from typing import Any, Tuple, Union

@@ -72,46 +73,117 @@ def left_child(self, id: int) -> int:
     def right_child(self, id: int) -> int:
         return id * 2 + 2
 
+    def num_procs_to_use(self, num_rows: int) -> int:
+        min_rows_per_worker = 10
+        available_procs = len(get_legate_runtime().machine)
+        return min(available_procs, int(math.ceil(num_rows / min_rows_per_worker)))
+
     def __init__(
         self,
-        leaf_value: cn.ndarray,
-        feature: cn.ndarray,
-        split_value: cn.ndarray,
-        gain: cn.ndarray,
-        hessian: cn.ndarray,
+        X: cn.ndarray,
+        g: cn.ndarray,
+        h: cn.ndarray,
+        learning_rate: float,
+        max_depth: int,
+        random_state: np.random.RandomState,
     ) -> None:
-        """Initialise from existing storage."""
-        self.leaf_value = leaf_value
-        assert leaf_value.dtype == cn.float64
-        self.feature = feature
-        assert feature.dtype == cn.int32
-        self.split_value = split_value
-        assert split_value.dtype == cn.float64
-        self.hessian = hessian
-        assert hessian.dtype == cn.float64
-        self.gain = gain
-        assert gain.dtype == cn.float64
+        # choose possible splits
+        split_proposals = X[
+            random_state.randint(0, X.shape[0], max_depth)
+        ]  # may not be efficient, maybe write new task
+        num_features = X.shape[1]
+        num_outputs = g.shape[1]
+        n_rows = X.shape[0]
+        num_procs = self.num_procs_to_use(n_rows)
+        use_gpu = get_legate_runtime().machine.preferred_kind == 1
+        rows_per_tile = int(cn.ceil(n_rows / num_procs))
+
+        task = user_context.create_manual_task(
+            LegateBoostOpCode.BUILD_TREE, launch_domain=Rect((num_procs, 1))
+        )
+
+        # Defining a projection function (even the identity) prevents legate
+        # from trying to assign empty tiles to workers
+        # in the case where the number of tiles is less than the launch grid
+        def proj(x: Tuple[int, int]) -> Tuple[int, int]:
+            return (x[0], 0)  # everything crashes if this is lambda x: x ????
+
+        # inputs
+        task.add_scalar_arg(learning_rate, types.float64)
+        task.add_scalar_arg(max_depth, types.int32)
+        task.add_scalar_arg(random_state.randint(0, 2**32), types.uint64)
+
+        task.add_input(
+            partition_if_not_future(X, (rows_per_tile, num_features)), proj=proj
+        )
+        task.add_input(
+            partition_if_not_future(g, (rows_per_tile, num_outputs)), proj=proj
+        )
+        task.add_input(
+            partition_if_not_future(h, (rows_per_tile, num_outputs)), proj=proj
+        )
+        task.add_input(_get_store(split_proposals))
+
+        # outputs
+        # force 1d arrays to be 2d otherwise we get the dreaded assert proj_id == 0
+        max_nodes = 2 ** (max_depth + 1)
+        leaf_value = user_context.create_store(types.float64, (max_nodes, num_outputs))
+        feature = user_context.create_store(types.int32, (max_nodes, 1))
+        split_value = user_context.create_store(types.float64, (max_nodes, 1))
+        gain = user_context.create_store(types.float64, (max_nodes, 1))
+        hessian = user_context.create_store(types.float64, (max_nodes, num_outputs))
+
+        # All outputs belong to a single tile on worker 0
+        task.add_output(
+            leaf_value.partition_by_tiling((max_nodes, num_outputs)), proj=proj
+        )
+        task.add_output(feature.partition_by_tiling((max_nodes, 1)), proj=proj)
+        task.add_output(split_value.partition_by_tiling((max_nodes, 1)), proj=proj)
+        task.add_output(gain.partition_by_tiling((max_nodes, 1)), proj=proj)
+        task.add_output(
+            hessian.partition_by_tiling((max_nodes, num_outputs)), proj=proj
+        )
+
+        if num_procs > 1:
+            if use_gpu:
+                task.add_nccl_communicator()
+            else:
+                task.add_cpu_communicator()
+
+        task.execute()
+
+        self.leaf_value = cn.array(leaf_value, copy=False)
+        self.feature = cn.array(feature, copy=False).squeeze()
+        self.split_value = cn.array(split_value, copy=False).squeeze()
+        self.gain = cn.array(gain, copy=False).squeeze()
+        self.hessian = cn.array(hessian, copy=False)

     def predict(self, X: cn.ndarray) -> cn.ndarray:
         n_rows = X.shape[0]
         n_features = X.shape[1]
         n_outputs = self.leaf_value.shape[1]
-        num_procs = len(get_legate_runtime().machine)
-        # dont launch more tasks than rows
-        num_procs = min(num_procs, n_rows)
+        num_procs = self.num_procs_to_use(n_rows)
         rows_per_tile = int(cn.ceil(n_rows / num_procs))
         task = user_context.create_manual_task(
             LegateBoostOpCode.PREDICT, Rect((num_procs, 1))
         )
-        task.add_input(partition_if_not_future(X, (rows_per_tile, n_features)))
+
+        def proj(x: Tuple[int, int]) -> Tuple[int, int]:
+            return (x[0], 0)
+
+        task.add_input(
+            partition_if_not_future(X, (rows_per_tile, n_features)), proj=proj
+        )
 
         # broadcast the tree structure
         task.add_input(_get_store(self.leaf_value))
         task.add_input(_get_store(self.feature))
         task.add_input(_get_store(self.split_value))
 
         pred = user_context.create_store(types.float64, (n_rows, n_outputs))
-        task.add_output(partition_if_not_future(pred, (rows_per_tile, n_outputs)))
+        task.add_output(
+            partition_if_not_future(pred, (rows_per_tile, n_outputs)), proj=proj
+        )
         task.execute()
         return cn.array(pred, copy=False)
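The new `num_procs_to_use` method replaces the old `min(num_procs, n_rows)` cap in both tree construction and `predict`: rather than merely not launching more tasks than rows, it guarantees each worker at least ten rows. A standalone sketch of the same arithmetic:

    import math

    def num_procs_to_use(num_rows: int, available_procs: int,
                         min_rows_per_worker: int = 10) -> int:
        # Launch at most one worker per min_rows_per_worker rows of data.
        return min(available_procs, math.ceil(num_rows / min_rows_per_worker))

    assert num_procs_to_use(25, available_procs=8) == 3       # ceil(25 / 10)
    assert num_procs_to_use(10_000, available_procs=8) == 8   # capped by machine

The same hunk also threads an explicit `proj` projection through every partitioned input and output; per the in-code comment, even an identity-like projection stops Legate from assigning empty tiles to workers when there are fewer tiles than launch points.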

@@ -167,81 +239,6 @@ def _get_store(input: Any) -> Store:
     return store
 
 
-def build_tree_native(
-    X: cn.ndarray,
-    g: cn.ndarray,
-    h: cn.ndarray,
-    learning_rate: float,
-    max_depth: int,
-    random_state: np.random.RandomState,
-) -> TreeStructure:
-
-    # choose possible splits
-    split_proposals = X[
-        random_state.randint(0, X.shape[0], max_depth)
-    ]  # may not be efficient, maybe write new task
-    num_features = X.shape[1]
-    num_outputs = g.shape[1]
-    n_rows = X.shape[0]
-    num_procs = len(get_legate_runtime().machine)
-    use_gpu = get_legate_runtime().machine.preferred_kind == 1
-    # dont launch more tasks than rows
-    num_procs = min(num_procs, n_rows)
-    rows_per_tile = int(cn.ceil(n_rows / num_procs))
-
-    task = user_context.create_manual_task(
-        LegateBoostOpCode.BUILD_TREE, launch_domain=Rect((num_procs, 1))
-    )
-
-    # inputs
-    task.add_scalar_arg(learning_rate, types.float64)
-    task.add_scalar_arg(max_depth, types.int32)
-    task.add_scalar_arg(random_state.randint(0, 2**32), types.uint64)
-
-    task.add_input(partition_if_not_future(X, (rows_per_tile, num_features)))
-    task.add_input(partition_if_not_future(g, (rows_per_tile, num_outputs)))
-    task.add_input(partition_if_not_future(h, (rows_per_tile, num_outputs)))
-    task.add_input(_get_store(split_proposals))
-
-    # outputs
-    # force 1d arrays to be 2d otherwise we get the dreaded assert proj_id == 0
-    max_nodes = 2 ** (max_depth + 1)
-    leaf_value = user_context.create_store(types.float64, (max_nodes, num_outputs))
-    feature = user_context.create_store(types.int32, (max_nodes, 1))
-    split_value = user_context.create_store(types.float64, (max_nodes, 1))
-    gain = user_context.create_store(types.float64, (max_nodes, 1))
-    hessian = user_context.create_store(types.float64, (max_nodes, num_outputs))
-
-    # All outputs belong to a single tile on worker 0
-    # Defining a projection function (even the identity) prevents legate
-    # from trying to assign empty tiles to workers
-    # in the case where the number of tiles is less than the launch grid
-    def proj(x: Tuple[int, int]) -> Tuple[int, int]:
-        return (x[0], 0)  # everything crashes if this is lambda x: x ????
-
-    task.add_output(leaf_value.partition_by_tiling((max_nodes, num_outputs)), proj=proj)
-    task.add_output(feature.partition_by_tiling((max_nodes, 1)), proj=proj)
-    task.add_output(split_value.partition_by_tiling((max_nodes, 1)), proj=proj)
-    task.add_output(gain.partition_by_tiling((max_nodes, 1)), proj=proj)
-    task.add_output(hessian.partition_by_tiling((max_nodes, num_outputs)), proj=proj)
-
-    if num_procs > 1:
-        if use_gpu:
-            task.add_nccl_communicator()
-        else:
-            task.add_cpu_communicator()
-
-    task.execute()
-
-    return TreeStructure(
-        cn.array(leaf_value, copy=False),
-        cn.array(feature, copy=False).squeeze(),
-        cn.array(split_value, copy=False).squeeze(),
-        cn.array(gain, copy=False).squeeze(),
-        cn.array(hessian, copy=False),
-    )
-
-
 class LBBase(BaseEstimator, _PickleCunumericMixin):
     def __init__(
         self,
@@ -410,15 +407,16 @@ def fit(
             h = h * sample_weight[:, None]
 
             # build new tree
-            tree = build_tree_native(
-                X,
-                g,
-                h,
-                self.learning_rate,
-                self.max_depth,
-                check_random_state(self.random_state),
+            self.models_.append(
+                TreeStructure(
+                    X,
+                    g,
+                    h,
+                    self.learning_rate,
+                    self.max_depth,
+                    check_random_state(self.random_state),
+                )
             )
-            self.models_.append(tree)
 
             # update current predictions
             pred += self.models_[-1].predict(X)
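With `build_tree_native` deleted above, tree construction moves into the `TreeStructure` constructor itself, so `fit` now instantiates the class directly. A hedged sketch of the resulting call pattern (NumPy arrays and shapes stand in for the cunumeric arrays the real code passes):

    import numpy as np
    import legateboost as lb

    X = np.random.rand(1000, 10)   # training data (shapes are illustrative)
    g = np.random.rand(1000, 1)    # first-order gradients
    h = np.ones((1000, 1))         # second-order gradients (hessians)

    tree = lb.TreeStructure(       # the constructor now launches BUILD_TREE
        X, g, h,
        0.1,                       # learning_rate
        6,                         # max_depth
        np.random.RandomState(0),
    )
    pred = tree.predict(X)         # (1000, 1) contribution of this tree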
legateboost/test/test_estimator.py (58 changes: 7 additions & 51 deletions)

@@ -1,34 +1,12 @@
 import numpy as np
 import pytest
+import utils
 from sklearn.utils.estimator_checks import parametrize_with_checks
 
 import cunumeric as cn
 import legateboost as lb
 
 
-def non_increasing(x):
-    return all(x >= y for x, y in zip(x, x[1:]))
-
-
-def non_decreasing(x):
-    return all(x <= y for x, y in zip(x, x[1:]))
-
-
-def sanity_check_tree_stats(trees):
-    for tree in trees:
-        # Check that we have no 0 hessian splits
-        split_nodes = tree.feature != -1
-        assert cn.all(tree.hessian[split_nodes] > 0.0)
-
-        # Check gain is positive
-        assert cn.all(tree.gain[split_nodes] > 0.0)
-
-        # Check that hessians of leaves add up to root.
-        leaves = (tree.feature == -1) & (tree.hessian[:, 0] > 0.0)
-        leaf_sum = tree.hessian[leaves].sum(axis=0)
-        assert np.isclose(leaf_sum, tree.hessian[0]).all()
-
-
 @pytest.mark.parametrize("num_outputs", [1, 5])
 def test_regressor(num_outputs):
     np.random.seed(2)

@@ -40,11 +18,11 @@ def test_regressor(num_outputs):
     mse = lb.MSEMetric().metric(y, model.predict(X), cn.ones(y.shape[0]))
     loss = next(iter(model.train_metric_.values()))
     assert np.isclose(loss[-1], mse)
-    assert non_increasing(loss)
+    assert utils.non_increasing(loss)
 
     # test print
     model.dump_trees()
-    sanity_check_tree_stats(model.models_)
+    utils.sanity_check_tree_stats(model.models_)
 
 
 @pytest.mark.parametrize("num_outputs", [1, 5])

@@ -60,7 +38,7 @@ def test_regressor_improving_with_depth(num_outputs):
 
     loss = next(iter(model.train_metric_.values()))
     metrics.append(loss[-1])
-    assert non_increasing(metrics)
+    assert utils.non_increasing(metrics)
 
 
 @pytest.mark.parametrize("num_outputs", [1, 5])

@@ -151,9 +129,9 @@ def test_classifier(num_class, objective):
     loss = metric.metric(y, pred, cn.ones(y.shape[0]))
     train_loss = next(iter(model.train_metric_.values()))
     assert np.isclose(train_loss[-1], loss)
-    assert non_increasing(train_loss)
+    assert utils.non_increasing(train_loss)
     assert model.score(X, y) > 0.7
-    sanity_check_tree_stats(model.models_)
+    utils.sanity_check_tree_stats(model.models_)
 
 
 @pytest.mark.parametrize("num_class", [2, 5])

@@ -183,26 +161,4 @@ def test_classifier_improving_with_depth(num_class, objective):
     ).fit(X, y)
     train_loss = next(iter(model.train_metric_.values()))
     metrics.append(train_loss[-1])
-    assert non_increasing(metrics)
-
-
-def test_prediction():
-    tree = lb.TreeStructure(
-        cn.array(
-            [0.0, 0.0, -0.04619769, 0.01845179, -0.01151532, 0.0, 0.0, 0.0, 0.0, 0.0]
-        ).reshape(10, 1),
-        cn.array([0, 1, -1, -1, -1, -1, -1, -1, -1, -1]).astype(cn.int32),
-        cn.array([0.79172504, 0.71518937, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]),
-        cn.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]),
-        cn.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]).reshape(
-            10,
-            1,
-        ),
-    )
-    """0:[f0<=0.79172504] yes=1 no=2 1:[f1<=0.71518937] yes=3 no=4
-    3:leaf=0.01845179 4:leaf=-0.01151532 2:leaf=-0.04619769."""
-    pred = tree.predict(cn.array([[1.0, 0.0], [0.5, 0.5]]))
-    assert pred[0] == -0.04619769
-    assert pred[1] == 0.01845179
-    assert tree.predict(cn.array([[0.5, 1.0]])) == -0.01151532
-    assert tree.predict(cn.array([[0.79172504, 0.71518937]])) == 0.01845179
+    assert utils.non_increasing(metrics)
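The helpers dropped from this file now come in via `import utils`; a sketch of what that shared test module presumably contains, reconstructed from the deleted definitions (the module body itself is not shown in this commit):

    # utils.py (sketch)
    import numpy as np
    import cunumeric as cn

    def non_increasing(x):
        return all(a >= b for a, b in zip(x, x[1:]))

    def sanity_check_tree_stats(trees):
        for tree in trees:
            split_nodes = tree.feature != -1
            assert cn.all(tree.hessian[split_nodes] > 0.0)  # no 0-hessian splits
            assert cn.all(tree.gain[split_nodes] > 0.0)     # split gain is positive
            # hessians of the leaves must sum back to the root hessian
            leaves = (tree.feature == -1) & (tree.hessian[:, 0] > 0.0)
            assert np.isclose(tree.hessian[leaves].sum(axis=0), tree.hessian[0]).all()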
(The remaining two changed files in this commit did not load in this view.)