From 949c2a39795b25275ece7fceccf8cba7ce59d138 Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Wed, 9 Nov 2022 18:52:02 +0100 Subject: [PATCH 1/6] LinearRegression: add support for multiple targets --- python/cuml/linear_model/base.pyx | 17 +++ .../cuml/linear_model/linear_regression.pyx | 108 +++++++++++++++++- python/cuml/tests/test_linear_model.py | 21 +++- 3 files changed, 139 insertions(+), 7 deletions(-) diff --git a/python/cuml/linear_model/base.pyx b/python/cuml/linear_model/base.pyx index 8878e56e0b..71bd79ed79 100644 --- a/python/cuml/linear_model/base.pyx +++ b/python/cuml/linear_model/base.pyx @@ -19,6 +19,7 @@ import ctypes import cuml.internals import numpy as np +import cupy as cp import warnings from numba import cuda @@ -35,6 +36,7 @@ from cuml.common.mixins import RegressorMixin from cuml.common.doc_utils import generate_docstring from pylibraft.common.handle cimport handle_t from cuml.common import input_to_cuml_array +from cuml.common.input_utils import input_to_cupy_array cdef extern from "cuml/linear_model/glm.hpp" namespace "ML::GLM": @@ -66,6 +68,21 @@ class LinearPredictMixin: Predicts `y` values for `X`. """ + coef_cp, n_feat, n_targets, _ = input_to_cupy_array(self.coef_) + if 1 < n_targets: + # Handle multi-target prediction in Python. + X_cp = input_to_cupy_array( + X, + check_dtype=self.dtype, + convert_to_dtype=(self.dtype if convert_dtype else None), + check_cols=self.n_cols + ).array + intercept_cp = input_to_cupy_array(self.intercept_).array + preds_cp = X_cp @ coef_cp + intercept_cp + preds = input_to_cuml_array(preds_cp).array + return preds + + # Handle single-target prediction in C++ X_m, n_rows, n_cols, dtype = \ input_to_cuml_array(X, check_dtype=self.dtype, convert_to_dtype=(self.dtype if convert_dtype diff --git a/python/cuml/linear_model/linear_regression.pyx b/python/cuml/linear_model/linear_regression.pyx index 28ef228cc0..4c9f8d37aa 100644 --- a/python/cuml/linear_model/linear_regression.pyx +++ b/python/cuml/linear_model/linear_regression.pyx @@ -18,6 +18,7 @@ import ctypes import numpy as np +import cupy as cp import warnings from numba import cuda @@ -37,6 +38,7 @@ from pylibraft.common.handle cimport handle_t from pylibraft.common.handle import Handle from cuml.common import input_to_cuml_array from cuml.common.mixins import FMajorInputTagMixin +from cuml.common.input_utils import input_to_cupy_array cdef extern from "cuml/linear_model/glm.hpp" namespace "ML::GLM": @@ -65,6 +67,48 @@ cdef extern from "cuml/linear_model/glm.hpp" namespace "ML::GLM": double *sample_weight) except + +def divide_non_zero(x1, x2): + # Value chosen to be consistent with the RAFT implementation in + # linalg/detail/lstsq.cuh + eps = 1e-10 + + # Do not divide by values of x2 that are smaller than eps + mask = abs(x2) < eps + x2[mask] = 1. + + return x1 / x2 + + +def fit_multi_target(X, y, fit_intercept=True, sample_weight=None): + assert X.ndim == 2 + assert y.ndim == 2 + + assert X.shape[1] > 0, "Number of columns cannot be less than one" + assert X.shape[0] > 1, "Number of rows cannot be less than two" + + if fit_intercept: + # Add column containg ones to fit intercept. + nrow, ncol = X.shape + X_wide = cp.empty_like(X, shape=(nrow, ncol + 1)) + X_wide[:, :ncol] = X + X_wide[:, ncol] = 1. 
+ X = X_wide + + if sample_weight is not None: + sample_weight = cp.sqrt(sample_weight) + X = sample_weight[:, None] * X + y = sample_weight[:, None] * y + + u, s, vh = cp.linalg.svd(X, full_matrices=False) + + params = vh.T @ divide_non_zero(u.T @ y, s[:, None]) + + coef = params[:-1] if fit_intercept else params + intercept = params[-1] if fit_intercept else None + + return coef, intercept + + class LinearRegression(Base, RegressorMixin, LinearPredictMixin, @@ -239,11 +283,11 @@ class LinearRegression(Base, input_to_cuml_array(X, check_dtype=[np.float32, np.float64]) X_ptr = X_m.ptr - y_m, _, _, _ = \ + y_m, _, y_cols, _ = \ input_to_cuml_array(y, check_dtype=self.dtype, convert_to_dtype=(self.dtype if convert_dtype else None), - check_rows=n_rows, check_cols=1) + check_rows=n_rows) y_ptr = y_m.ptr if sample_weight is not None: @@ -270,6 +314,66 @@ class LinearRegression(Base, "column currently.", UserWarning) self.algo = 0 + if 1 < y_cols: + # In the cuml C++ layer, there is no support yet for multi-target + # regression, i.e., a y vector with multiple columns. + # We implement the regression in Python here. + + if self.algo != 0: + warnings.warn("Changing solver to 'svd' as this is the " + + "only solver that support multiple targets " + + "currently.", UserWarning) + self.algo = 0 + if self.normalize: + raise ValueError( + "The normalize option is not supported when `y` has " + "multiple columns." + ) + + X_cupy = input_to_cupy_array( + X, + convert_to_dtype=(self.dtype if convert_dtype else None), + ).array + y_cupy = input_to_cupy_array( + y, + convert_to_dtype=(self.dtype if convert_dtype else None), + ).array + if sample_weight is None: + sample_weight_cupy = None + else: + sample_weight_cupy = input_to_cupy_array( + sample_weight, + convert_to_dtype=(self.dtype if convert_dtype else None), + ).array + coef, intercept = fit_multi_target( + X_cupy, + y_cupy, + fit_intercept=self.fit_intercept, + sample_weight=sample_weight_cupy + ) + self.coef_, _, _, _ = input_to_cuml_array( + coef, + check_dtype=self.dtype, + check_rows=self.n_cols, + check_cols=y_cols + ) + if self.fit_intercept: + self.intercept_, _, _, _ = input_to_cuml_array( + intercept, + check_dtype=self.dtype, + check_rows=y_cols, + check_cols=1 + ) + else: + self.intercept_ = CumlArray.zeros(y_cols, dtype=self.dtype) + + del X_m + del y_m + if sample_weight is not None: + del sample_weight_m + + return self + self.coef_ = CumlArray.zeros(self.n_cols, dtype=self.dtype) cdef uintptr_t coef_ptr = self.coef_.ptr diff --git a/python/cuml/tests/test_linear_model.py b/python/cuml/tests/test_linear_model.py index 03a21dfada..09bbb30f4b 100644 --- a/python/cuml/tests/test_linear_model.py +++ b/python/cuml/tests/test_linear_model.py @@ -124,6 +124,9 @@ def cuml_compatible_dataset(X_train, X_test, y_train, _=None): ) +@pytest.mark.parametrize( + "ntargets", [unit_param(2), quality_param(100), stress_param(1000)] +) @pytest.mark.parametrize("datatype", [np.float32, np.float64]) @pytest.mark.parametrize("algorithm", ["eig", "svd"]) @pytest.mark.parametrize( @@ -137,16 +140,18 @@ def cuml_compatible_dataset(X_train, X_test, y_train, _=None): stress_param([1000, 500]) ], ) -def test_linear_regression_model(datatype, algorithm, nrows, column_info): +def test_linear_regression_model(datatype, algorithm, nrows, column_info, ntargets): if algorithm == "svd" and nrows > 46340: pytest.skip("svd solver is not supported for the data that has more" "than 46340 rows or columns if you are using CUDA version" "10.x") + if 1 < ntargets and 
algorithm != "svd": + pytest.skip("The multi-target fit only supports using the svd solver.") ncols, n_info = column_info X_train, X_test, y_train, y_test = make_regression_dataset( - datatype, nrows, ncols, n_info + datatype, nrows, ncols, n_info, n_targets=ntargets ) # Initialization of cuML's linear regression model @@ -167,7 +172,7 @@ def test_linear_regression_model(datatype, algorithm, nrows, column_info): assert array_equal(skols_predict, cuols_predict, 1e-1, with_sign=True) - +@pytest.mark.parametrize("ntargets", [unit_param(2)]) @pytest.mark.parametrize("datatype", [np.float32, np.float64]) @pytest.mark.parametrize("algorithm", ["eig", "svd", "qr", "svd-qr"]) @pytest.mark.parametrize( @@ -180,13 +185,19 @@ def test_linear_regression_model(datatype, algorithm, nrows, column_info): (False, False, "uniform"), ] ) -def test_weighted_linear_regression(datatype, algorithm, fit_intercept, +def test_weighted_linear_regression(ntargets, datatype, algorithm, fit_intercept, normalize, distribution): nrows, ncols, n_info = 1000, 20, 10 max_weight = 10 noise = 20 + + if 1 < ntargets and normalize: + pytest.skip("The multi-target fit does not support normalization.") + if 1 < ntargets and algorithm != "svd": + pytest.skip("The multi-target fit only supports using the svd solver.") + X_train, X_test, y_train, y_test = make_regression_dataset( - datatype, nrows, ncols, n_info, noise=noise + datatype, nrows, ncols, n_info, noise=noise, n_targets=ntargets ) # set weight per sample to be from 1 to max_weight From ae6293ab2e858bee84e8f45294f7b2a9fcf5b62e Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Thu, 10 Nov 2022 14:00:41 +0100 Subject: [PATCH 2/6] Process review comments --- .../cuml/linear_model/linear_regression.pyx | 120 ++++++++++-------- python/cuml/tests/test_linear_model.py | 9 +- 2 files changed, 70 insertions(+), 59 deletions(-) diff --git a/python/cuml/linear_model/linear_regression.pyx b/python/cuml/linear_model/linear_regression.pyx index 4c9f8d37aa..f6e616d1d4 100644 --- a/python/cuml/linear_model/linear_regression.pyx +++ b/python/cuml/linear_model/linear_regression.pyx @@ -83,8 +83,15 @@ def fit_multi_target(X, y, fit_intercept=True, sample_weight=None): assert X.ndim == 2 assert y.ndim == 2 - assert X.shape[1] > 0, "Number of columns cannot be less than one" - assert X.shape[0] > 1, "Number of rows cannot be less than two" + x_rows, x_cols = X.shape + if x_cols == 0: + raise ValueError( + "Number of columns cannot be less than one" + ) + if x_rows < 2: + raise ValueError( + "Number of rows cannot be less than two" + ) if fit_intercept: # Add column containg ones to fit intercept. @@ -315,64 +322,12 @@ class LinearRegression(Base, self.algo = 0 if 1 < y_cols: - # In the cuml C++ layer, there is no support yet for multi-target - # regression, i.e., a y vector with multiple columns. - # We implement the regression in Python here. - - if self.algo != 0: - warnings.warn("Changing solver to 'svd' as this is the " + - "only solver that support multiple targets " + - "currently.", UserWarning) - self.algo = 0 - if self.normalize: - raise ValueError( - "The normalize option is not supported when `y` has " - "multiple columns." 
- ) - - X_cupy = input_to_cupy_array( - X, - convert_to_dtype=(self.dtype if convert_dtype else None), - ).array - y_cupy = input_to_cupy_array( - y, - convert_to_dtype=(self.dtype if convert_dtype else None), - ).array - if sample_weight is None: - sample_weight_cupy = None - else: - sample_weight_cupy = input_to_cupy_array( - sample_weight, - convert_to_dtype=(self.dtype if convert_dtype else None), - ).array - coef, intercept = fit_multi_target( - X_cupy, - y_cupy, - fit_intercept=self.fit_intercept, - sample_weight=sample_weight_cupy - ) - self.coef_, _, _, _ = input_to_cuml_array( - coef, - check_dtype=self.dtype, - check_rows=self.n_cols, - check_cols=y_cols - ) - if self.fit_intercept: - self.intercept_, _, _, _ = input_to_cuml_array( - intercept, - check_dtype=self.dtype, - check_rows=y_cols, - check_cols=1 - ) - else: - self.intercept_ = CumlArray.zeros(y_cols, dtype=self.dtype) - del X_m del y_m if sample_weight is not None: del sample_weight_m - return self + return self._fit_multi_target(X, y, convert_dtype, sample_weight) self.coef_ = CumlArray.zeros(self.n_cols, dtype=self.dtype) cdef uintptr_t coef_ptr = self.coef_.ptr @@ -420,6 +375,61 @@ class LinearRegression(Base, return self + def _fit_multi_target(self, X, y, convert_dtype=True, sample_weight=None): + # In the cuml C++ layer, there is no support yet for multi-target + # regression, i.e., a y vector with multiple columns. + # We implement the regression in Python here. + + if self.algo != 0: + warnings.warn("Changing solver to 'svd' as this is the " + + "only solver that support multiple targets " + + "currently.", UserWarning) + self.algo = 0 + if self.normalize: + raise ValueError( + "The normalize option is not supported when `y` has " + "multiple columns." + ) + + X_cupy = input_to_cupy_array( + X, + convert_to_dtype=(self.dtype if convert_dtype else None), + ).array + y_cupy, _, y_cols, _ = input_to_cupy_array( + y, + convert_to_dtype=(self.dtype if convert_dtype else None), + ) + if sample_weight is None: + sample_weight_cupy = None + else: + sample_weight_cupy = input_to_cupy_array( + sample_weight, + convert_to_dtype=(self.dtype if convert_dtype else None), + ).array + coef, intercept = fit_multi_target( + X_cupy, + y_cupy, + fit_intercept=self.fit_intercept, + sample_weight=sample_weight_cupy + ) + self.coef_, _, _, _ = input_to_cuml_array( + coef, + check_dtype=self.dtype, + check_rows=self.n_cols, + check_cols=y_cols + ) + if self.fit_intercept: + self.intercept_, _, _, _ = input_to_cuml_array( + intercept, + check_dtype=self.dtype, + check_rows=y_cols, + check_cols=1 + ) + else: + self.intercept_ = CumlArray.zeros(y_cols, dtype=self.dtype) + + return self + def _predict(self, X, convert_dtype=True) -> CumlArray: self.dtype = self.coef_.dtype self.n_cols = self.coef_.shape[0] diff --git a/python/cuml/tests/test_linear_model.py b/python/cuml/tests/test_linear_model.py index 09bbb30f4b..7f5ec0bc44 100644 --- a/python/cuml/tests/test_linear_model.py +++ b/python/cuml/tests/test_linear_model.py @@ -125,7 +125,7 @@ def cuml_compatible_dataset(X_train, X_test, y_train, _=None): @pytest.mark.parametrize( - "ntargets", [unit_param(2), quality_param(100), stress_param(1000)] + "ntargets", [unit_param(1), unit_param(2), quality_param(100), stress_param(1000)] ) @pytest.mark.parametrize("datatype", [np.float32, np.float64]) @pytest.mark.parametrize("algorithm", ["eig", "svd"]) @@ -172,7 +172,7 @@ def test_linear_regression_model(datatype, algorithm, nrows, column_info, ntarge assert array_equal(skols_predict, 
cuols_predict, 1e-1, with_sign=True) -@pytest.mark.parametrize("ntargets", [unit_param(2)]) +@pytest.mark.parametrize("ntargets", [1, 2]) @pytest.mark.parametrize("datatype", [np.float32, np.float64]) @pytest.mark.parametrize("algorithm", ["eig", "svd", "qr", "svd-qr"]) @pytest.mark.parametrize( @@ -185,8 +185,9 @@ def test_linear_regression_model(datatype, algorithm, nrows, column_info, ntarge (False, False, "uniform"), ] ) -def test_weighted_linear_regression(ntargets, datatype, algorithm, fit_intercept, - normalize, distribution): +def test_weighted_linear_regression( + ntargets, datatype, algorithm, fit_intercept, normalize, distribution +): nrows, ncols, n_info = 1000, 20, 10 max_weight = 10 noise = 20 From b077a329fb7e5ab5368f36671f8b20b53754af4e Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Thu, 10 Nov 2022 19:01:52 +0100 Subject: [PATCH 3/6] Fix style --- python/cuml/tests/test_linear_model.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/python/cuml/tests/test_linear_model.py b/python/cuml/tests/test_linear_model.py index 7f5ec0bc44..8b44684669 100644 --- a/python/cuml/tests/test_linear_model.py +++ b/python/cuml/tests/test_linear_model.py @@ -125,7 +125,8 @@ def cuml_compatible_dataset(X_train, X_test, y_train, _=None): @pytest.mark.parametrize( - "ntargets", [unit_param(1), unit_param(2), quality_param(100), stress_param(1000)] + "ntargets", + [unit_param(1), unit_param(2), quality_param(100), stress_param(1000)] ) @pytest.mark.parametrize("datatype", [np.float32, np.float64]) @pytest.mark.parametrize("algorithm", ["eig", "svd"]) @@ -140,8 +141,9 @@ def cuml_compatible_dataset(X_train, X_test, y_train, _=None): stress_param([1000, 500]) ], ) -def test_linear_regression_model(datatype, algorithm, nrows, column_info, ntargets): - +def test_linear_regression_model( + datatype, algorithm, nrows, column_info, ntargets +): if algorithm == "svd" and nrows > 46340: pytest.skip("svd solver is not supported for the data that has more" "than 46340 rows or columns if you are using CUDA version" @@ -172,6 +174,7 @@ def test_linear_regression_model(datatype, algorithm, nrows, column_info, ntarge assert array_equal(skols_predict, cuols_predict, 1e-1, with_sign=True) + @pytest.mark.parametrize("ntargets", [1, 2]) @pytest.mark.parametrize("datatype", [np.float32, np.float64]) @pytest.mark.parametrize("algorithm", ["eig", "svd", "qr", "svd-qr"]) From 677a08be814ebb260d5d192d52dfa4dfaebc2f98 Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Fri, 11 Nov 2022 10:56:17 +0100 Subject: [PATCH 4/6] Remove (too large) stress test --- python/cuml/tests/test_linear_model.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/python/cuml/tests/test_linear_model.py b/python/cuml/tests/test_linear_model.py index 8b44684669..77d809a9d7 100644 --- a/python/cuml/tests/test_linear_model.py +++ b/python/cuml/tests/test_linear_model.py @@ -124,10 +124,7 @@ def cuml_compatible_dataset(X_train, X_test, y_train, _=None): ) -@pytest.mark.parametrize( - "ntargets", - [unit_param(1), unit_param(2), quality_param(100), stress_param(1000)] -) +@pytest.mark.parametrize("ntargets", [1, 2]) @pytest.mark.parametrize("datatype", [np.float32, np.float64]) @pytest.mark.parametrize("algorithm", ["eig", "svd"]) @pytest.mark.parametrize( From 37261b30db4b16d3f6d6b26094d1fd6a29edf96b Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Tue, 15 Nov 2022 16:09:26 +0100 Subject: [PATCH 5/6] Implement code review comments --- python/cuml/linear_model/base.pyx | 13 
+++++++++---- python/cuml/linear_model/linear_regression.pyx | 12 +++++++----- 2 files changed, 16 insertions(+), 9 deletions(-) diff --git a/python/cuml/linear_model/base.pyx b/python/cuml/linear_model/base.pyx index 71bd79ed79..4e43a2ee29 100644 --- a/python/cuml/linear_model/base.pyx +++ b/python/cuml/linear_model/base.pyx @@ -68,9 +68,14 @@ class LinearPredictMixin: Predicts `y` values for `X`. """ - coef_cp, n_feat, n_targets, _ = input_to_cupy_array(self.coef_) - if 1 < n_targets: + if self.coef_ is None: + raise ValueError( + "LinearModel.predict() cannot be called before fit(). " + "Please fit the model first." + ) + if len(self.coef_.shape) == 2 and self.coef_.shape[1] > 1: # Handle multi-target prediction in Python. + coef_cp = input_to_cupy_array(self.coef_).array X_cp = input_to_cupy_array( X, check_dtype=self.dtype, @@ -79,8 +84,8 @@ class LinearPredictMixin: ).array intercept_cp = input_to_cupy_array(self.intercept_).array preds_cp = X_cp @ coef_cp + intercept_cp - preds = input_to_cuml_array(preds_cp).array - return preds + # preds = input_to_cuml_array(preds_cp).array # TODO:remove + return preds_cp # Handle single-target prediction in C++ X_m, n_rows, n_cols, dtype = \ diff --git a/python/cuml/linear_model/linear_regression.pyx b/python/cuml/linear_model/linear_regression.pyx index f6e616d1d4..5da18cba70 100644 --- a/python/cuml/linear_model/linear_regression.pyx +++ b/python/cuml/linear_model/linear_regression.pyx @@ -322,12 +322,10 @@ class LinearRegression(Base, self.algo = 0 if 1 < y_cols: - del X_m - del y_m - if sample_weight is not None: - del sample_weight_m + if sample_weight is None: + sample_weight_m = None - return self._fit_multi_target(X, y, convert_dtype, sample_weight) + return self._fit_multi_target(X_m, y_m, convert_dtype, sample_weight_m) self.coef_ = CumlArray.zeros(self.n_cols, dtype=self.dtype) cdef uintptr_t coef_ptr = self.coef_.ptr @@ -443,3 +441,7 @@ class LinearRegression(Base, def get_attributes_names(self): return ['coef_', 'intercept_'] + + @staticmethod + def _more_static_tags(): + return {"multioutput": True} From 9d640178548d23c951c5ab3a3bd9635a94c80024 Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Wed, 16 Nov 2022 09:52:55 +0100 Subject: [PATCH 6/6] Fix style --- python/cuml/linear_model/linear_regression.pyx | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/python/cuml/linear_model/linear_regression.pyx b/python/cuml/linear_model/linear_regression.pyx index 5da18cba70..725e4ee79e 100644 --- a/python/cuml/linear_model/linear_regression.pyx +++ b/python/cuml/linear_model/linear_regression.pyx @@ -323,9 +323,11 @@ class LinearRegression(Base, if 1 < y_cols: if sample_weight is None: - sample_weight_m = None + sample_weight_m = None - return self._fit_multi_target(X_m, y_m, convert_dtype, sample_weight_m) + return self._fit_multi_target( + X_m, y_m, convert_dtype, sample_weight_m + ) self.coef_ = CumlArray.zeros(self.n_cols, dtype=self.dtype) cdef uintptr_t coef_ptr = self.coef_.ptr
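
A minimal usage sketch of the multi-target support added in this series. The
shapes, values, and variable names below are illustrative assumptions and are
not taken from the patches themselves:

    import cupy as cp
    from cuml.linear_model import LinearRegression

    n_samples, n_features, n_targets = 100, 5, 3
    X = cp.random.standard_normal((n_samples, n_features), dtype=cp.float32)
    W = cp.random.standard_normal((n_features, n_targets), dtype=cp.float32)
    b = cp.random.standard_normal(n_targets, dtype=cp.float32)
    y = X @ W + b  # y has one column per target

    # Multi-target fits go through the Python SVD path; other solvers are
    # switched to 'svd' with a warning, and normalize=True raises an error
    # when y has multiple columns.
    ols = LinearRegression(algorithm="svd", fit_intercept=True)
    ols.fit(X, y)
    preds = ols.predict(X)  # shape: (n_samples, n_targets)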