Skip to content

Commit

Permalink
LinearRegression: add support for multiple targets (rapidsai#4988)
Browse files Browse the repository at this point in the history
LinearRegression did not have support for target vectors with multiple columns previously. This PR adds support.

Authors:
  - Allard Hendriksen (https://github.com/ahendriksen)

Approvers:
  - Tamas Bela Feher (https://github.com/tfeher)
  - Corey J. Nolet (https://github.com/cjnolet)

URL: rapidsai#4988
  • Loading branch information
Allard Hendriksen authored Nov 16, 2022
1 parent 1592c16 commit 99e427f
Show file tree
Hide file tree
Showing 3 changed files with 160 additions and 8 deletions.
22 changes: 22 additions & 0 deletions python/cuml/linear_model/base.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import ctypes
import cuml.internals
import numpy as np
import cupy as cp
import warnings

from numba import cuda
Expand All @@ -35,6 +36,7 @@ from cuml.common.mixins import RegressorMixin
from cuml.common.doc_utils import generate_docstring
from pylibraft.common.handle cimport handle_t
from cuml.common import input_to_cuml_array
from cuml.common.input_utils import input_to_cupy_array

cdef extern from "cuml/linear_model/glm.hpp" namespace "ML::GLM":

Expand Down Expand Up @@ -66,6 +68,26 @@ class LinearPredictMixin:
Predicts `y` values for `X`.

"""
if self.coef_ is None:
raise ValueError(
"LinearModel.predict() cannot be called before fit(). "
"Please fit the model first."
)
if len(self.coef_.shape) == 2 and self.coef_.shape[1] > 1:
# Handle multi-target prediction in Python.
coef_cp = input_to_cupy_array(self.coef_).array
X_cp = input_to_cupy_array(
X,
check_dtype=self.dtype,
convert_to_dtype=(self.dtype if convert_dtype else None),
check_cols=self.n_cols
).array
intercept_cp = input_to_cupy_array(self.intercept_).array
preds_cp = X_cp @ coef_cp + intercept_cp
# preds = input_to_cuml_array(preds_cp).array # TODO:remove
return preds_cp

# Handle single-target prediction in C++
X_m, n_rows, n_cols, dtype = \
input_to_cuml_array(X, check_dtype=self.dtype,
convert_to_dtype=(self.dtype if convert_dtype
Expand Down
122 changes: 120 additions & 2 deletions python/cuml/linear_model/linear_regression.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import ctypes
import numpy as np
import cupy as cp
import warnings

from numba import cuda
Expand All @@ -37,6 +38,7 @@ from pylibraft.common.handle cimport handle_t
from pylibraft.common.handle import Handle
from cuml.common import input_to_cuml_array
from cuml.common.mixins import FMajorInputTagMixin
from cuml.common.input_utils import input_to_cupy_array

cdef extern from "cuml/linear_model/glm.hpp" namespace "ML::GLM":

Expand Down Expand Up @@ -65,6 +67,55 @@ cdef extern from "cuml/linear_model/glm.hpp" namespace "ML::GLM":
double *sample_weight) except +


def divide_non_zero(x1, x2):
# Value chosen to be consistent with the RAFT implementation in
# linalg/detail/lstsq.cuh
eps = 1e-10

# Do not divide by values of x2 that are smaller than eps
mask = abs(x2) < eps
x2[mask] = 1.

return x1 / x2


def fit_multi_target(X, y, fit_intercept=True, sample_weight=None):
assert X.ndim == 2
assert y.ndim == 2

x_rows, x_cols = X.shape
if x_cols == 0:
raise ValueError(
"Number of columns cannot be less than one"
)
if x_rows < 2:
raise ValueError(
"Number of rows cannot be less than two"
)

if fit_intercept:
# Add column containg ones to fit intercept.
nrow, ncol = X.shape
X_wide = cp.empty_like(X, shape=(nrow, ncol + 1))
X_wide[:, :ncol] = X
X_wide[:, ncol] = 1.
X = X_wide

if sample_weight is not None:
sample_weight = cp.sqrt(sample_weight)
X = sample_weight[:, None] * X
y = sample_weight[:, None] * y

u, s, vh = cp.linalg.svd(X, full_matrices=False)

params = vh.T @ divide_non_zero(u.T @ y, s[:, None])

coef = params[:-1] if fit_intercept else params
intercept = params[-1] if fit_intercept else None

return coef, intercept


class LinearRegression(Base,
RegressorMixin,
LinearPredictMixin,
Expand Down Expand Up @@ -239,11 +290,11 @@ class LinearRegression(Base,
input_to_cuml_array(X, check_dtype=[np.float32, np.float64])
X_ptr = X_m.ptr

y_m, _, _, _ = \
y_m, _, y_cols, _ = \
input_to_cuml_array(y, check_dtype=self.dtype,
convert_to_dtype=(self.dtype if convert_dtype
else None),
check_rows=n_rows, check_cols=1)
check_rows=n_rows)
y_ptr = y_m.ptr

if sample_weight is not None:
Expand All @@ -270,6 +321,14 @@ class LinearRegression(Base,
"column currently.", UserWarning)
self.algo = 0

if 1 < y_cols:
if sample_weight is None:
sample_weight_m = None

return self._fit_multi_target(
X_m, y_m, convert_dtype, sample_weight_m
)

self.coef_ = CumlArray.zeros(self.n_cols, dtype=self.dtype)
cdef uintptr_t coef_ptr = self.coef_.ptr

Expand Down Expand Up @@ -316,6 +375,61 @@ class LinearRegression(Base,

return self

def _fit_multi_target(self, X, y, convert_dtype=True, sample_weight=None):
# In the cuml C++ layer, there is no support yet for multi-target
# regression, i.e., a y vector with multiple columns.
# We implement the regression in Python here.

if self.algo != 0:
warnings.warn("Changing solver to 'svd' as this is the " +
"only solver that support multiple targets " +
"currently.", UserWarning)
self.algo = 0
if self.normalize:
raise ValueError(
"The normalize option is not supported when `y` has "
"multiple columns."
)

X_cupy = input_to_cupy_array(
X,
convert_to_dtype=(self.dtype if convert_dtype else None),
).array
y_cupy, _, y_cols, _ = input_to_cupy_array(
y,
convert_to_dtype=(self.dtype if convert_dtype else None),
)
if sample_weight is None:
sample_weight_cupy = None
else:
sample_weight_cupy = input_to_cupy_array(
sample_weight,
convert_to_dtype=(self.dtype if convert_dtype else None),
).array
coef, intercept = fit_multi_target(
X_cupy,
y_cupy,
fit_intercept=self.fit_intercept,
sample_weight=sample_weight_cupy
)
self.coef_, _, _, _ = input_to_cuml_array(
coef,
check_dtype=self.dtype,
check_rows=self.n_cols,
check_cols=y_cols
)
if self.fit_intercept:
self.intercept_, _, _, _ = input_to_cuml_array(
intercept,
check_dtype=self.dtype,
check_rows=y_cols,
check_cols=1
)
else:
self.intercept_ = CumlArray.zeros(y_cols, dtype=self.dtype)

return self

def _predict(self, X, convert_dtype=True) -> CumlArray:
self.dtype = self.coef_.dtype
self.n_cols = self.coef_.shape[0]
Expand All @@ -329,3 +443,7 @@ class LinearRegression(Base,

def get_attributes_names(self):
return ['coef_', 'intercept_']

@staticmethod
def _more_static_tags():
return {"multioutput": True}
24 changes: 18 additions & 6 deletions python/cuml/tests/test_linear_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ def cuml_compatible_dataset(X_train, X_test, y_train, _=None):
)


@pytest.mark.parametrize("ntargets", [1, 2])
@pytest.mark.parametrize("datatype", [np.float32, np.float64])
@pytest.mark.parametrize("algorithm", ["eig", "svd"])
@pytest.mark.parametrize(
Expand All @@ -137,16 +138,19 @@ def cuml_compatible_dataset(X_train, X_test, y_train, _=None):
stress_param([1000, 500])
],
)
def test_linear_regression_model(datatype, algorithm, nrows, column_info):

def test_linear_regression_model(
datatype, algorithm, nrows, column_info, ntargets
):
if algorithm == "svd" and nrows > 46340:
pytest.skip("svd solver is not supported for the data that has more"
"than 46340 rows or columns if you are using CUDA version"
"10.x")
if 1 < ntargets and algorithm != "svd":
pytest.skip("The multi-target fit only supports using the svd solver.")

ncols, n_info = column_info
X_train, X_test, y_train, y_test = make_regression_dataset(
datatype, nrows, ncols, n_info
datatype, nrows, ncols, n_info, n_targets=ntargets
)

# Initialization of cuML's linear regression model
Expand All @@ -168,6 +172,7 @@ def test_linear_regression_model(datatype, algorithm, nrows, column_info):
assert array_equal(skols_predict, cuols_predict, 1e-1, with_sign=True)


@pytest.mark.parametrize("ntargets", [1, 2])
@pytest.mark.parametrize("datatype", [np.float32, np.float64])
@pytest.mark.parametrize("algorithm", ["eig", "svd", "qr", "svd-qr"])
@pytest.mark.parametrize(
Expand All @@ -180,13 +185,20 @@ def test_linear_regression_model(datatype, algorithm, nrows, column_info):
(False, False, "uniform"),
]
)
def test_weighted_linear_regression(datatype, algorithm, fit_intercept,
normalize, distribution):
def test_weighted_linear_regression(
ntargets, datatype, algorithm, fit_intercept, normalize, distribution
):
nrows, ncols, n_info = 1000, 20, 10
max_weight = 10
noise = 20

if 1 < ntargets and normalize:
pytest.skip("The multi-target fit does not support normalization.")
if 1 < ntargets and algorithm != "svd":
pytest.skip("The multi-target fit only supports using the svd solver.")

X_train, X_test, y_train, y_test = make_regression_dataset(
datatype, nrows, ncols, n_info, noise=noise
datatype, nrows, ncols, n_info, noise=noise, n_targets=ntargets
)

# set weight per sample to be from 1 to max_weight
Expand Down

0 comments on commit 99e427f

Please sign in to comment.