From 18acacf6247c211c97930864b2f4b1b3e076f09f Mon Sep 17 00:00:00 2001 From: jinfeng Date: Wed, 30 Aug 2023 15:27:58 -0700 Subject: [PATCH] assert GPU CPU intercept_ equal when fit_intercept is false --- python/cuml/solvers/qn.pyx | 9 ++++++++- python/cuml/tests/test_linear_model.py | 3 +++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/python/cuml/solvers/qn.pyx b/python/cuml/solvers/qn.pyx index fcc96dac18..52bf22c654 100644 --- a/python/cuml/solvers/qn.pyx +++ b/python/cuml/solvers/qn.pyx @@ -930,8 +930,15 @@ class QN(Base, if self.fit_intercept: self.intercept_ = self._coef_[-1] - else: + return + + _num_classes_dim, _ = self.coef_.shape + _num_classes = self.get_num_classes(_num_classes_dim) + + if _num_classes == 2: self.intercept_ = CumlArray.zeros(shape=1) + else: + self.intercept_ = CumlArray.zeros(shape=_num_classes) def get_param_names(self): return super().get_param_names() + \ diff --git a/python/cuml/tests/test_linear_model.py b/python/cuml/tests/test_linear_model.py index 0dc9fdeae7..31e29b2dc2 100644 --- a/python/cuml/tests/test_linear_model.py +++ b/python/cuml/tests/test_linear_model.py @@ -561,6 +561,9 @@ def test_logistic_regression( ) assert len(np.unique(cu_preds)) == len(np.unique(y_test)) + if fit_intercept is False: + assert np.array_equal(culog.intercept_, sklog.intercept_) + @given( dtype=floating_dtypes(sizes=(32, 64)),