Skip to content

Commit

Permalink
assert GPU CPU intercept_ equal when fit_intercept is false
Browse files Browse the repository at this point in the history
  • Loading branch information
lijinf2 committed Sep 18, 2023
1 parent 9b1ee68 commit 18acacf
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 1 deletion.
9 changes: 8 additions & 1 deletion python/cuml/solvers/qn.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -930,8 +930,15 @@ class QN(Base,

if self.fit_intercept:
self.intercept_ = self._coef_[-1]
else:
return

_num_classes_dim, _ = self.coef_.shape
_num_classes = self.get_num_classes(_num_classes_dim)

if _num_classes == 2:
self.intercept_ = CumlArray.zeros(shape=1)
else:
self.intercept_ = CumlArray.zeros(shape=_num_classes)

def get_param_names(self):
return super().get_param_names() + \
Expand Down
3 changes: 3 additions & 0 deletions python/cuml/tests/test_linear_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -561,6 +561,9 @@ def test_logistic_regression(
)
assert len(np.unique(cu_preds)) == len(np.unique(y_test))

if fit_intercept is False:
assert np.array_equal(culog.intercept_, sklog.intercept_)


@given(
dtype=floating_dtypes(sizes=(32, 64)),
Expand Down

0 comments on commit 18acacf

Please sign in to comment.