Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[REVIEW] Expose silhouette score in Python #3164

Merged
merged 12 commits into from
Nov 24, 2020
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# cuML 0.17.0 (Date TBD)

## New Features
- PR #3164: Expose silhouette score in Python
- PR #2659: Add initial max inner product sparse knn
- PR #2836: Refactor UMAP to accept sparse inputs

Expand Down
2 changes: 1 addition & 1 deletion cpp/include/cuml/metrics/metrics.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ double rand_index(const raft::handle_t &handle, double *y, double *y_hat,
*/
double silhouette_score(const raft::handle_t &handle, double *y, int nRows,
int nCols, int *labels, int nLabels, double *silScores,
int metric);
raft::distance::DistanceType metric);
/**
* Calculates the "adjusted rand index"
*
Expand Down
3 changes: 2 additions & 1 deletion cpp/src/metrics/silhouette_score.cu
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
* limitations under the License.
*/

#include <raft/linalg/distance_type.h>
#include <cuml/metrics/metrics.hpp>
#include <metrics/silhouette_score.cuh>

Expand All @@ -23,7 +24,7 @@ namespace ML {
namespace Metrics {
double silhouette_score(const raft::handle_t &handle, double *y, int nRows,
int nCols, int *labels, int nLabels, double *silScores,
int metric) {
raft::distance::DistanceType metric) {
return MLCommon::Metrics::silhouette_score<double, int>(
y, nRows, nCols, labels, nLabels, silScores, handle.get_device_allocator(),
handle.get_stream(), metric);
Expand Down
5 changes: 4 additions & 1 deletion cpp/src_prims/metrics/silhouette_score.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

#include <math.h>
#include <raft/cudart_utils.h>
#include <raft/linalg/distance_type.h>
#include <algorithm>
#include <common/device_buffer.hpp>
#include <cub/cub.cuh>
Expand Down Expand Up @@ -176,7 +177,9 @@ template <typename DataT, typename LabelT>
DataT silhouette_score(DataT *X_in, int nRows, int nCols, LabelT *labels,
int nLabels, DataT *silhouette_scorePerSample,
std::shared_ptr<MLCommon::deviceAllocator> allocator,
cudaStream_t stream, int metric = 4) {
cudaStream_t stream,
raft::distance::DistanceType metric =
raft::distance::DistanceType::EucUnexpandedL2) {
ASSERT(nLabels >= 2 && nLabels <= (nRows - 1),
"silhouette Score not defined for the given number of labels!");

Expand Down
19 changes: 12 additions & 7 deletions cpp/test/prims/silhouette_score.cu
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
*/
#include <gtest/gtest.h>
#include <raft/cudart_utils.h>
#include <raft/linalg/distance_type.h>
#include <algorithm>
#include <cuml/common/cuml_allocator.hpp>
#include <iostream>
Expand All @@ -30,7 +31,7 @@ struct silhouetteScoreParam {
int nRows;
int nCols;
int nLabels;
int metric;
raft::distance::DistanceType metric;
double tolerance;
};

Expand Down Expand Up @@ -79,9 +80,9 @@ class silhouetteScoreTest
double *h_distanceMatrix =
(double *)malloc(nRows * nRows * sizeof(double *));

MLCommon::Distance::pairwise_distance(
d_X, d_X, d_distanceMatrix.data(), nRows, nRows, nCols, workspace,
static_cast<raft::distance::DistanceType>(params.metric), stream);
MLCommon::Distance::pairwise_distance(d_X, d_X, d_distanceMatrix.data(),
nRows, nRows, nCols, workspace,
params.metric, stream);

CUDA_CHECK(cudaStreamSynchronize(stream));

Expand Down Expand Up @@ -189,9 +190,13 @@ class silhouetteScoreTest

//setting test parameter values
const std::vector<silhouetteScoreParam> inputs = {
{4, 2, 3, 0, 0.00001}, {4, 2, 2, 5, 0.00001}, {8, 8, 3, 4, 0.00001},
{11, 2, 5, 0, 0.00001}, {40, 2, 8, 0, 0.00001}, {12, 7, 3, 2, 0.00001},
{7, 5, 5, 3, 0.00001}};
{4, 2, 3, raft::distance::DistanceType::EucExpandedL2, 0.00001},
{4, 2, 2, raft::distance::DistanceType::EucUnexpandedL2Sqrt, 0.00001},
{8, 8, 3, raft::distance::DistanceType::EucUnexpandedL2, 0.00001},
{11, 2, 5, raft::distance::DistanceType::EucExpandedL2, 0.00001},
{40, 2, 8, raft::distance::DistanceType::EucExpandedL2, 0.00001},
{12, 7, 3, raft::distance::DistanceType::EucExpandedCosine, 0.00001},
{7, 5, 5, raft::distance::DistanceType::EucUnexpandedL1, 0.00001}};

//writing the test suite
typedef silhouetteScoreTest<int, double> silhouetteScoreTestClass;
Expand Down
4 changes: 4 additions & 0 deletions python/cuml/metrics/cluster/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,7 @@
from cuml.metrics.cluster.mutual_info_score import \
cython_mutual_info_score as mutual_info_score
from cuml.metrics.cluster.entropy import cython_entropy as entropy
from cuml.metrics.cluster.silhouette_score import \
cython_silhouette_score as silhouette_score
from cuml.metrics.cluster.silhouette_score import \
cython_silhouette_samples as silhouette_samples
181 changes: 181 additions & 0 deletions python/cuml/metrics/cluster/silhouette_score.pyx
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
#
# Copyright (c) 2020, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import cupy as cp
import numpy as np

from libc.stdint cimport uintptr_t

from cuml.common import input_to_cuml_array
from cuml.metrics.pairwise_distances import _determine_metric
from cuml.raft.common.handle cimport handle_t
from cuml.raft.common.handle import Handle
from cuml.metrics.distance_type cimport DistanceType


cdef extern from "cuml/metrics/metrics.hpp" namespace "ML::Metrics":
double silhouette_score(
const handle_t &handle,
double *y,
int n_rows,
int n_cols,
int *labels,
int n_labels,
double *sil_scores,
DistanceType metric) except +


def _silhouette_coeff(
X, labels, metric='euclidean', sil_scores=None, handle=None):
"""Function wrapped by silhouette_score and silhouette_samples to compute
silhouette coefficients

Parameters
----------
X : array-like, shape = (n_samples, n_features)
The feature vectors for all samples.
labels : array-like, shape = (n_samples,)
The assigned cluster labels for each sample.
metric : string
A string representation of the distance metric to use for evaluating
the silhouette schore. Available options are "cityblock", "cosine",
"euclidean", "l1", "l2", "manhattan", and "sqeuclidean".
sil_scores : array_like, shape = (1, n_samples), dtype='float64'
An optional array in which to store the silhouette score for each
sample.
handle : cuml.Handle
Specifies the cuml.handle that holds internal CUDA state for
computations in this model. Most importantly, this specifies the CUDA
stream that will be used for the model's computations, so users can
run different models concurrently in different streams by creating
handles in several streams.
If it is None, a new one is created.
"""
handle = Handle() if handle is None else handle
cdef handle_t *handle_ = <handle_t*> <size_t> handle.getHandle()

data, n_rows, n_cols, _ = input_to_cuml_array(
X,
order='C',
convert_to_dtype=np.float64
)

labels, _, _, _ = input_to_cuml_array(
labels,
order='C',
convert_to_dtype=np.int32
)

n_labels = cp.unique(
labels.to_output(output_type='cupy', output_dtype='int')
).shape[0]

cdef uintptr_t scores_ptr
if sil_scores is None:
scores_ptr = <uintptr_t> NULL
else:
sil_scores = input_to_cuml_array(
sil_scores,
check_dtype=np.float64)[0]

scores_ptr = sil_scores.ptr

metric = _determine_metric(metric)

return silhouette_score(handle_[0],
<double*> <uintptr_t> data.ptr,
n_rows,
n_cols,
<int*> <uintptr_t> labels.ptr,
n_labels,
<double*> scores_ptr,
metric)


def cython_silhouette_score(
X,
labels,
metric='euclidean',
handle=None):
"""Calculate the mean silhouette coefficient for the provided data

Given a set of cluster labels for every sample in the provided data,
compute the mean intra-cluster distance (a) and the mean nearest-cluster
distance (b) for each sample. The silhouette coefficient for a sample is
then (b - a) / max(a, b).

Parameters
----------
X : array-like, shape = (n_samples, n_features)
The feature vectors for all samples.
labels : array-like, shape = (n_samples,)
The assigned cluster labels for each sample.
metric : string
A string representation of the distance metric to use for evaluating
the silhouette schore. Available options are "cityblock", "cosine",
"euclidean", "l1", "l2", "manhattan", and "sqeuclidean".
handle : cuml.Handle
Specifies the cuml.handle that holds internal CUDA state for
computations in this model. Most importantly, this specifies the CUDA
stream that will be used for the model's computations, so users can
run different models concurrently in different streams by creating
handles in several streams.
If it is None, a new one is created.
"""

return _silhouette_coeff(
X, labels, metric=metric, handle=handle
)


def cython_silhouette_samples(
X,
labels,
metric='euclidean',
handle=None):
"""Calculate the silhouette coefficient for each sample in the provided data

Given a set of cluster labels for every sample in the provided data,
compute the mean intra-cluster distance (a) and the mean nearest-cluster
distance (b) for each sample. The silhouette coefficient for a sample is
then (b - a) / max(a, b).

Parameters
----------
X : array-like, shape = (n_samples, n_features)
The feature vectors for all samples.
labels : array-like, shape = (n_samples,)
The assigned cluster labels for each sample.
metric : string
A string representation of the distance metric to use for evaluating
the silhouette schore. Available options are "cityblock", "cosine",
"euclidean", "l1", "l2", "manhattan", and "sqeuclidean".
handle : cuml.Handle
Specifies the cuml.handle that holds internal CUDA state for
computations in this model. Most importantly, this specifies the CUDA
stream that will be used for the model's computations, so users can
run different models concurrently in different streams by creating
handles in several streams.
If it is None, a new one is created.
"""

sil_scores = cp.empty((X.shape[0],), dtype='float64')

_silhouette_coeff(
X, labels, metric=metric, sil_scores=sil_scores, handle=handle
)

return sil_scores
9 changes: 9 additions & 0 deletions python/cuml/metrics/distance_type.pxd
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
cdef extern from "raft/linalg/distance_type.h" namespace "raft::distance":

ctypedef enum DistanceType:
EucExpandedL2 "raft::distance::DistanceType::EucExpandedL2"
EucExpandedL2Sqrt "raft::distance::DistanceType::EucExpandedL2Sqrt"
EucExpandedCosine "raft::distance::DistanceType::EucExpandedCosine"
EucUnexpandedL1 "raft::distance::DistanceType::EucUnexpandedL1"
EucUnexpandedL2 "raft::distance::DistanceType::EucUnexpandedL2"
EucUnexpandedL2Sqrt "raft::distance::DistanceType::EucUnexpandedL2Sqrt"
11 changes: 1 addition & 10 deletions python/cuml/metrics/pairwise_distances.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -28,16 +28,7 @@ import cuml.internals
from cuml.common.base import _determine_stateless_output_type
from cuml.common import (input_to_cuml_array, CumlArray, logger)
from cuml.metrics.cluster.utils import prepare_cluster_metric_inputs

cdef extern from "raft/linalg/distance_type.h" namespace "raft::distance":

cdef enum DistanceType:
EucExpandedL2 "raft::distance::DistanceType::EucExpandedL2"
EucExpandedL2Sqrt "raft::distance::DistanceType::EucExpandedL2Sqrt"
EucExpandedCosine "raft::distance::DistanceType::EucExpandedCosine"
EucUnexpandedL1 "raft::distance::DistanceType::EucUnexpandedL1"
EucUnexpandedL2 "raft::distance::DistanceType::EucUnexpandedL2"
EucUnexpandedL2Sqrt "raft::distance::DistanceType::EucUnexpandedL2Sqrt"
from cuml.metrics.distance_type cimport DistanceType

cdef extern from "cuml/metrics/metrics.hpp" namespace "ML::Metrics":
void pairwise_distance(const handle_t &handle, const double *x,
Expand Down
Loading