Skip to content

Commit

Permalink
Restore the functionality of RF score() (#3685)
Browse files Browse the repository at this point in the history
Closes #3682

In #3609, I unintentionally broke the function `score()` in the random forest. This PR restores the functionality. In addition, I added `score()` to one of the unit tests to ensure that `score()` does not break again.

Authors:
  - Philip Hyunsu Cho (https://github.com/hcho3)

Approvers:
  - John Zedlewski (https://github.com/JohnZed)

URL: #3685
  • Loading branch information
hcho3 authored Apr 1, 2021
1 parent 4ca603d commit 66931b9
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 1 deletion.
2 changes: 1 addition & 1 deletion python/cuml/ensemble/randomforestclassifier.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -859,7 +859,7 @@ class RandomForestClassifier(BaseRandomForestModel,
convert_to_dtype=(np.int32 if convert_dtype
else False))
y_ptr = y_m.ptr
preds = self.predict(X, output_class=True,
preds = self.predict(X,
threshold=threshold, algo=algo,
convert_dtype=convert_dtype,
predict_model=predict_model,
Expand Down
2 changes: 2 additions & 0 deletions python/cuml/test/test_random_forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -518,6 +518,8 @@ def test_rf_classification_sparse(small_clf, datatype,
algo=algo)
fil_preds = np.reshape(fil_preds, np.shape(y_test))
fil_acc = accuracy_score(y_test, fil_preds)
np.testing.assert_almost_equal(fil_acc,
cuml_model.score(X_test, y_test))

fil_model = cuml_model.convert_to_fil_model()

Expand Down

0 comments on commit 66931b9

Please sign in to comment.