Skip to content

Commit

Permalink
Use chebyshev, canberra, hellinger and minkowski distance metrics (ra…
Browse files Browse the repository at this point in the history
…pidsai#3990)

This PR relies on RAFT PR rapidsai/raft#276 which adds these distance metrics support.

Authors:
  - Mahesh Doijade (https://github.com/mdoijade)

Approvers:
  - Dante Gama Dessavre (https://github.com/dantegd)
  - Corey J. Nolet (https://github.com/cjnolet)
  - AJ Schmidt (https://github.com/ajschmidt8)

URL: rapidsai#3990
  • Loading branch information
mdoijade authored Jul 8, 2021
1 parent 232625e commit a52672e
Show file tree
Hide file tree
Showing 29 changed files with 944 additions and 115 deletions.
1 change: 1 addition & 0 deletions ci/checks/style.sh
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ cd $WORKSPACE
export GIT_DESCRIBE_TAG=`git describe --tags`
export MINOR_VERSION=`echo $GIT_DESCRIBE_TAG | grep -o -E '([0-9]+\.[0-9]+)'`
conda install "ucx-py=0.21.*" "ucx-proc=*=gpu"
conda install -c conda-forge clang=8.0.1 clang-tools=8.0.1

# Run flake8 and get results/return code
FLAKE=`flake8 --config=python/setup.cfg`
Expand Down
8 changes: 8 additions & 0 deletions cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,13 @@ if(BUILD_CUML_CPP_LIBRARY)
src/metrics/kl_divergence.cu
src/metrics/mutual_info_score.cu
src/metrics/pairwise_distance.cu
src/metrics/pairwise_distance_canberra.cu
src/metrics/pairwise_distance_chebyshev.cu
src/metrics/pairwise_distance_cosine.cu
src/metrics/pairwise_distance_euclidean.cu
src/metrics/pairwise_distance_hellinger.cu
src/metrics/pairwise_distance_l1.cu
src/metrics/pairwise_distance_minkowski.cu
src/metrics/r2_score.cu
src/metrics/rand_index.cu
src/metrics/silhouette_score.cu
Expand Down Expand Up @@ -323,6 +330,7 @@ if(BUILD_CUML_CPP_LIBRARY)
$<BUILD_INTERFACE:$<$<BOOL:${ENABLE_CUMLPRIMS_MG}>:${cumlprims_mg_INCLUDE_DIRS}>>
PRIVATE
$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/src>
$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/src/metrics>
$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/src_prims>
$<$<OR:$<BOOL:${BUILD_CUML_STD_COMMS}>,$<BOOL:${BUILD_CUML_MPI_COMMS}>>:${NCCL_INCLUDE_DIRS}>
$<$<BOOL:${BUILD_CUML_MPI_COMMS}>:${MPI_CXX_INCLUDE_PATH}>
Expand Down
6 changes: 4 additions & 2 deletions cpp/include/cuml/metrics/metrics.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -300,11 +300,12 @@ float accuracy_score_py(const raft::handle_t &handle, const int *predictions,
* @param metric the distance metric to use for the calculation
* @param isRowMajor specifies whether the x and y data pointers are row (C
* type array) or col (F type array) major
* @param metric_arg the value of `p` for Minkowski (l-p) distances.
*/
void pairwise_distance(const raft::handle_t &handle, const double *x,
const double *y, double *dist, int m, int n, int k,
raft::distance::DistanceType metric,
bool isRowMajor = true);
bool isRowMajor = true, double metric_arg = 2.0);

/**
* @brief Calculates the ij pairwise distances between two input arrays of float type
Expand All @@ -320,11 +321,12 @@ void pairwise_distance(const raft::handle_t &handle, const double *x,
* @param metric the distance metric to use for the calculation
* @param isRowMajor specifies whether the x and y data pointers are row (C
* type array) or col (F type array) major
* @param metric_arg the value of `p` for Minkowski (l-p) distances.
*/
void pairwise_distance(const raft::handle_t &handle, const float *x,
const float *y, float *dist, int m, int n, int k,
raft::distance::DistanceType metric,
bool isRowMajor = true);
bool isRowMajor = true, float metric_arg = 2.0f);

void pairwiseDistance_sparse(const raft::handle_t &handle, double *x, double *y,
double *dist, int x_nrows, int y_nrows, int n_cols,
Expand Down
12 changes: 2 additions & 10 deletions cpp/src/hierarchy/pw_dist_graph.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

#include <common/allocatorAdapter.hpp>

#include <raft/distance/distance.cuh>
#include <cuml/metrics/metrics.hpp>

#include <rmm/device_uvector.hpp>
#include <rmm/exec_policy.hpp>
Expand Down Expand Up @@ -70,7 +70,6 @@ template <typename value_idx, typename value_t>
void pairwise_distances(const raft::handle_t &handle, const value_t *X,
size_t m, size_t n, raft::distance::DistanceType metric,
value_idx *indptr, value_idx *indices, value_t *data) {
auto d_alloc = handle.get_device_allocator();
auto stream = handle.get_stream();
auto exec_policy = rmm::exec_policy(stream);

Expand All @@ -83,16 +82,10 @@ void pairwise_distances(const raft::handle_t &handle, const value_t *X,

raft::update_device(indptr + m, &nnz, 1, stream);

// TODO: Keeping raft device buffer here for now until our
// dense pairwise distances API is finished being refactored
raft::mr::device::buffer<char> workspace(d_alloc, stream, (size_t)0);

// TODO: It would ultimately be nice if the MST could accept
// dense inputs directly so we don't need to double the memory
// usage to hand it a sparse array here.
raft::distance::pairwise_distance<value_t, value_idx>(
X, X, data, m, m, n, workspace, metric, stream);

ML::Metrics::pairwise_distance(handle, X, X, data, m, m, n, metric);
// self-loops get max distance
auto transform_in = thrust::make_zip_iterator(
thrust::make_tuple(thrust::make_counting_iterator(0), data));
Expand Down Expand Up @@ -120,7 +113,6 @@ struct distance_graph_impl<raft::hierarchy::LinkageDistance::PAIRWISE,
rmm::device_uvector<value_idx> &indptr,
rmm::device_uvector<value_idx> &indices,
rmm::device_uvector<value_t> &data, int c) {
auto d_alloc = handle.get_device_allocator();
auto stream = handle.get_stream();

size_t nnz = m * m;
Expand Down
9 changes: 5 additions & 4 deletions cpp/src/kmeans/common.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@
*/
#pragma once

#include <cuml/metrics/metrics.hpp>
#include <linalg/reduce_cols_by_key.cuh>
#include <linalg/reduce_rows_by_key.cuh>
#include <matrix/gather.cuh>
#include <raft/distance/distance.cuh>

#include <raft/distance/fused_l2_nn.cuh>
#include <raft/linalg/binary_op.cuh>
Expand Down Expand Up @@ -256,9 +256,10 @@ void pairwise_distance(const raft::handle_t &handle,

ASSERT(X.getSize(1) == centroids.getSize(1),
"# features in dataset and centroids are different (must be same)");
raft::distance::pairwise_distance<DataT, IndexT>(
X.data(), centroids.data(), pairwiseDistance.data(), n_samples, n_clusters,
n_features, workspace, metric, stream);

ML::Metrics::pairwise_distance(handle, X.data(), centroids.data(),
pairwiseDistance.data(), n_samples, n_clusters,
n_features, metric);
}

// Calculates a <key, value> pair for every sample in input 'X' where key is an
Expand Down
97 changes: 81 additions & 16 deletions cpp/src/metrics/pairwise_distance.cu
Original file line number Diff line number Diff line change
Expand Up @@ -20,32 +20,97 @@
#include <raft/distance/distance.cuh>
#include <raft/handle.hpp>
#include <raft/sparse/distance/distance.cuh>
#include "pairwise_distance_canberra.cuh"
#include "pairwise_distance_chebyshev.cuh"
#include "pairwise_distance_cosine.cuh"
#include "pairwise_distance_euclidean.cuh"
#include "pairwise_distance_hellinger.cuh"
#include "pairwise_distance_l1.cuh"
#include "pairwise_distance_minkowski.cuh"

namespace ML {

namespace Metrics {
void pairwise_distance(const raft::handle_t &handle, const double *x,
const double *y, double *dist, int m, int n, int k,
raft::distance::DistanceType metric, bool isRowMajor) {
//Allocate workspace
raft::mr::device::buffer<char> workspace(handle.get_device_allocator(),
handle.get_stream(), 1);

//Call the distance function
raft::distance::pairwise_distance(x, y, dist, m, n, k, workspace, metric,
handle.get_stream(), isRowMajor);
raft::distance::DistanceType metric, bool isRowMajor,
double metric_arg) {
switch (metric) {
case raft::distance::DistanceType::L2Expanded:
case raft::distance::DistanceType::L2SqrtExpanded:
case raft::distance::DistanceType::L2Unexpanded:
case raft::distance::DistanceType::L2SqrtUnexpanded:
pairwise_distance_euclidean(handle, x, y, dist, m, n, k, metric,
isRowMajor, metric_arg);
break;
case raft::distance::DistanceType::CosineExpanded:
pairwise_distance_cosine(handle, x, y, dist, m, n, k, metric, isRowMajor,
metric_arg);
break;
case raft::distance::DistanceType::L1:
pairwise_distance_l1(handle, x, y, dist, m, n, k, metric, isRowMajor,
metric_arg);
break;
case raft::distance::DistanceType::Linf:
pairwise_distance_chebyshev(handle, x, y, dist, m, n, k, metric,
isRowMajor, metric_arg);
break;
case raft::distance::DistanceType::HellingerExpanded:
pairwise_distance_hellinger(handle, x, y, dist, m, n, k, metric,
isRowMajor, metric_arg);
break;
case raft::distance::DistanceType::LpUnexpanded:
pairwise_distance_minkowski(handle, x, y, dist, m, n, k, metric,
isRowMajor, metric_arg);
break;
case raft::distance::DistanceType::Canberra:
pairwise_distance_canberra(handle, x, y, dist, m, n, k, metric,
isRowMajor, metric_arg);
break;
default:
THROW("Unknown or unsupported distance metric '%d'!", (int)metric);
};
}

void pairwise_distance(const raft::handle_t &handle, const float *x,
const float *y, float *dist, int m, int n, int k,
raft::distance::DistanceType metric, bool isRowMajor) {
//Allocate workspace
raft::mr::device::buffer<char> workspace(handle.get_device_allocator(),
handle.get_stream(), 1);

//Call the distance function
raft::distance::pairwise_distance(x, y, dist, m, n, k, workspace, metric,
handle.get_stream(), isRowMajor);
raft::distance::DistanceType metric, bool isRowMajor,
float metric_arg) {
switch (metric) {
case raft::distance::DistanceType::L2Expanded:
case raft::distance::DistanceType::L2SqrtExpanded:
case raft::distance::DistanceType::L2Unexpanded:
case raft::distance::DistanceType::L2SqrtUnexpanded:
pairwise_distance_euclidean(handle, x, y, dist, m, n, k, metric,
isRowMajor, metric_arg);
break;
case raft::distance::DistanceType::CosineExpanded:
pairwise_distance_cosine(handle, x, y, dist, m, n, k, metric, isRowMajor,
metric_arg);
break;
case raft::distance::DistanceType::L1:
pairwise_distance_l1(handle, x, y, dist, m, n, k, metric, isRowMajor,
metric_arg);
break;
case raft::distance::DistanceType::Linf:
pairwise_distance_chebyshev(handle, x, y, dist, m, n, k, metric,
isRowMajor, metric_arg);
break;
case raft::distance::DistanceType::HellingerExpanded:
pairwise_distance_hellinger(handle, x, y, dist, m, n, k, metric,
isRowMajor, metric_arg);
break;
case raft::distance::DistanceType::LpUnexpanded:
pairwise_distance_minkowski(handle, x, y, dist, m, n, k, metric,
isRowMajor, metric_arg);
break;
case raft::distance::DistanceType::Canberra:
pairwise_distance_canberra(handle, x, y, dist, m, n, k, metric,
isRowMajor, metric_arg);
break;
default:
THROW("Unknown or unsupported distance metric '%d'!", (int)metric);
};
}

template <typename value_idx = int, typename value_t = float>
Expand Down
74 changes: 74 additions & 0 deletions cpp/src/metrics/pairwise_distance_canberra.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@

/*
* Copyright (c) 2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

//#include <cuml/metrics/metrics.hpp>
#include <raft/distance/distance.cuh>
#include <raft/handle.hpp>

namespace ML {

namespace Metrics {
void pairwise_distance_canberra(const raft::handle_t &handle, const double *x,
const double *y, double *dist, int m, int n,
int k, raft::distance::DistanceType metric,
bool isRowMajor, double metric_arg) {
//Allocate workspace
raft::mr::device::buffer<char> workspace(handle.get_device_allocator(),
handle.get_stream(), 1);

//Call the distance function
/* raft::distance::pairwise_distance(x, y, dist, m, n, k, workspace, metric,
handle.get_stream(), isRowMajor,
metric_arg);*/

switch (metric) {
case raft::distance::DistanceType::Canberra:
raft::distance::pairwise_distance_impl<
double, int, raft::distance::DistanceType::Canberra>(
x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor);
break;
default:
THROW("Unknown or unsupported distance metric '%d'!", (int)metric);
}
}

void pairwise_distance_canberra(const raft::handle_t &handle, const float *x,
const float *y, float *dist, int m, int n,
int k, raft::distance::DistanceType metric,
bool isRowMajor, float metric_arg) {
//Allocate workspace
raft::mr::device::buffer<char> workspace(handle.get_device_allocator(),
handle.get_stream(), 1);

//Call the distance function
/* raft::distance::pairwise_distance(x, y, dist, m, n, k, workspace, metric,
handle.get_stream(), isRowMajor,
metric_arg);*/

switch (metric) {
case raft::distance::DistanceType::Canberra:
raft::distance::pairwise_distance_impl<
float, int, raft::distance::DistanceType::Canberra>(
x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor);
break;
default:
THROW("Unknown or unsupported distance metric '%d'!", (int)metric);
}
}

} // namespace Metrics
} // namespace ML
37 changes: 37 additions & 0 deletions cpp/src/metrics/pairwise_distance_canberra.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@

/*
* Copyright (c) 2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <raft/distance/distance.cuh>
#include <raft/handle.hpp>

namespace ML {

namespace Metrics {
void pairwise_distance_canberra(const raft::handle_t &handle, const double *x,
const double *y, double *dist, int m, int n,
int k, raft::distance::DistanceType metric,
bool isRowMajor, double metric_arg);

void pairwise_distance_canberra(const raft::handle_t &handle, const float *x,
const float *y, float *dist, int m, int n,
int k, raft::distance::DistanceType metric,
bool isRowMajor, float metric_arg);

} // namespace Metrics
} // namespace ML
Loading

0 comments on commit a52672e

Please sign in to comment.