Skip to content

Commit

Permalink
New Feature StratifiedKFold (rapidsai#3109)
Browse files Browse the repository at this point in the history
Add equivalent of [sklearn's StratifiedKFold](https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.StratifiedKFold.html) to `cuml`.

Authors:
  - Jiwei Liu (https://github.com/daxiongshu)

Approvers:
  - Dante Gama Dessavre (https://github.com/dantegd)

URL: rapidsai#3109
  • Loading branch information
daxiongshu authored Sep 22, 2022
1 parent 5498d95 commit b7b1a67
Show file tree
Hide file tree
Showing 3 changed files with 171 additions and 2 deletions.
5 changes: 3 additions & 2 deletions python/cuml/model_selection/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#
# Copyright (c) 2021, NVIDIA CORPORATION.
# Copyright (c) 2022, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -15,6 +15,7 @@
#

from cuml.model_selection._split import train_test_split
from cuml.model_selection._split import StratifiedKFold
from cuml.common.import_utils import has_sklearn

if has_sklearn():
Expand All @@ -27,4 +28,4 @@
with CUDA-based data and cuML estimators, but all of the underlying code
is due to the scikit-learn developers.\n\n""" + GridSearchCV.__doc__

__all__ = ['train_test_split', 'GridSearchCV']
__all__ = ['train_test_split', 'GridSearchCV', 'StratifiedKFold']
103 changes: 103 additions & 0 deletions python/cuml/model_selection/_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import numpy as np

from cuml.common.memory_utils import _strides_to_order
from cuml.common import input_to_cuml_array
from numba import cuda
from typing import Union

Expand Down Expand Up @@ -471,3 +472,105 @@ def train_test_split(X,
return X_train, X_test, y_train, y_test
else:
return X_train, X_test


class StratifiedKFold:
"""
A cudf based implementation of Stratified K-Folds cross-validator.
Provides train/test indices to split data into stratified K folds.
The percentage of samples for each class are maintained in each
fold.
Parameters
----------
n_splits : int, default=5
Number of folds. Must be at least 2.
shuffle : boolean, default=False
Whether to shuffle each class's samples before splitting.
random_state : int (default=None)
Random seed
Examples
--------
Splitting X,y into stratified K folds
.. code-block:: python
import cupy
X = cupy.random.rand(12,10)
y = cupy.arange(12)%4
kf = StratifiedKFold(n_splits=3)
for tr,te in kf.split(X,y):
print(tr, te)
Output:
.. code-block:: python
[ 4 5 6 7 8 9 10 11] [0 1 2 3]
[ 0 1 2 3 8 9 10 11] [4 5 6 7]
[0 1 2 3 4 5 6 7] [ 8 9 10 11]
"""

def __init__(self, n_splits=5, shuffle=False, random_state=None):
if n_splits < 2 or not isinstance(n_splits, int):
raise ValueError(
f'n_splits {n_splits} is not a integer at least 2')

if random_state is not None and not isinstance(random_state, int):
raise ValueError(f'random_state {random_state} is not an integer')

self.n_splits = n_splits
self.shuffle = shuffle
self.seed = random_state

def get_n_splits(self, X=None, y=None):
return self.n_splits

def split(self, x, y):
if len(x) != len(y):
raise ValueError('Expecting same length of x and y')
y = input_to_cuml_array(y).array.to_output('cupy')
if len(cp.unique(y)) < 2:
raise ValueError(
'number of unique classes cannot be less than 2')
df = cudf.DataFrame()
ids = cp.arange(y.shape[0])

if self.shuffle:
cp.random.seed(self.seed)
cp.random.shuffle(ids)
y = y[ids]

df['y'] = y
df['ids'] = ids
grpby = df.groupby(['y'])

dg = grpby.agg({'y': 'count'})
col = dg.columns[0]
msg = f'n_splits={self.n_splits} cannot be greater ' + \
'than the number of members in each class.'
if self.n_splits > dg[col].min():
raise ValueError(msg)

def get_order_in_group(y, ids, order):
for i in range(cuda.threadIdx.x, len(y), cuda.blockDim.x):
order[i] = i

got = grpby.apply_grouped(get_order_in_group, incols=['y', 'ids'],
outcols={'order': 'int32'},
tpb=64)
got = got.sort_values('ids')

for i in range(self.n_splits):
mask = got['order'] % self.n_splits == i
train = got.loc[~mask, 'ids'].values
test = got.loc[mask, 'ids'].values
if len(test) == 0:
break
yield train, test

def _check_array_shape(self, y):
if y is None:
raise ValueError("Expecting 1D array, got None")
elif hasattr(y, 'shape') and len(y.shape) > 1 and y.shape[1] > 1:
raise ValueError(f"Expecting 1D array, got {y.shape}")
else:
pass
65 changes: 65 additions & 0 deletions python/cuml/test/test_stratified_kfold.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
# Copyright (c) 2019-2022, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import cudf
import cupy as cp
import pytest

from cuml.model_selection import StratifiedKFold


def get_x_y(n_samples, n_classes):
X = cudf.DataFrame({"x": range(n_samples)})
y = cp.arange(n_samples) % n_classes
cp.random.shuffle(y)
y = cudf.Series(y)
return X, y


@pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("n_splits", [5, 10])
@pytest.mark.parametrize("n_samples", [10000])
@pytest.mark.parametrize("n_classes", [2, 10])
def test_split_dataframe(n_samples, n_classes, n_splits, shuffle):
X, y = get_x_y(n_samples, n_classes)

kf = StratifiedKFold(n_splits=n_splits, shuffle=shuffle)
for train_index, test_index in kf.split(X, y):
assert len(train_index)+len(test_index) == n_samples
assert len(train_index) == len(test_index)*(n_splits-1)
for i in range(n_classes):
ratio_tr = (y[train_index] == i).sum() / len(train_index)
ratio_te = (y[test_index] == i).sum() / len(test_index)
assert ratio_tr == ratio_te


def test_num_classes_check():
X, y = get_x_y(n_samples=1000, n_classes=1)
kf = StratifiedKFold(n_splits=5)
err_msg = "number of unique classes cannot be less than 2"
with pytest.raises(ValueError, match=err_msg):
for train_index, test_index in kf.split(X, y):
pass


@pytest.mark.parametrize("n_splits", [0, 1])
def test_invalid_folds(n_splits):
X, y = get_x_y(n_samples=1000, n_classes=2)

err_msg = f'n_splits {n_splits} is not a integer at least 2'
with pytest.raises(ValueError, match=err_msg):
kf = StratifiedKFold(n_splits=n_splits)
for train_index, test_index in kf.split(X, y):
break

0 comments on commit b7b1a67

Please sign in to comment.