Skip to content

Commit

Permalink
Expose sparse distances via semiring to Python API (#3516)
Browse files Browse the repository at this point in the history
Closes #3478.

Authors:
  - Micka (https://github.com/lowener)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #3516
  • Loading branch information
lowener authored Apr 1, 2021
1 parent fb088d9 commit 4a946e0
Show file tree
Hide file tree
Showing 7 changed files with 564 additions and 42 deletions.
2 changes: 1 addition & 1 deletion cpp/cmake/Dependencies.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ else(DEFINED ENV{RAFT_PATH})

ExternalProject_Add(raft
GIT_REPOSITORY https://github.com/rapidsai/raft.git
GIT_TAG a57cf7df757b24230454e442c83f8491f97a4843
GIT_TAG d1fd927bc4ec67bfd765620b5fa93f17c54cfa70
PREFIX ${RAFT_DIR}
CONFIGURE_COMMAND ""
BUILD_COMMAND ""
Expand Down
13 changes: 13 additions & 0 deletions cpp/include/cuml/metrics/metrics.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -321,5 +321,18 @@ void pairwise_distance(const raft::handle_t &handle, const float *x,
raft::distance::DistanceType metric,
bool isRowMajor = true);

/**
 * @brief Pairwise distances between the rows of two sparse matrices
 *        (double-precision values, int indices).
 *
 * The implementation forwards every argument to a templated helper that hands
 * the work to raft::sparse::distance::pairwiseDistance.
 *
 * NOTE(review): the indptr/indices/data triplets suggest CSR storage and the
 * pointers are presumably device pointers -- confirm against the caller.
 *
 * @param handle raft handle supplying the cuSPARSE handle, device allocator
 *        and stream used for the computation
 * @param x non-zero values of the first sparse matrix
 * @param y non-zero values of the second sparse matrix
 * @param dist output buffer receiving the computed distances (layout is
 *        decided by raft -- TODO confirm shape/ordering)
 * @param x_nrows number of rows of the first matrix
 * @param y_nrows number of rows of the second matrix
 * @param n_cols number of columns, shared by both matrices
 * @param x_nnz number of non-zero entries in the first matrix
 * @param y_nnz number of non-zero entries in the second matrix
 * @param x_indptr index-pointer array of the first matrix
 * @param y_indptr index-pointer array of the second matrix
 * @param x_indices indices array of the first matrix
 * @param y_indices indices array of the second matrix
 * @param metric distance metric to compute
 * @param metric_arg extra scalar forwarded unchanged together with the metric
 *        (presumably a metric-specific parameter -- TODO confirm)
 */
void pairwiseDistance_sparse(const raft::handle_t &handle, double *x, double *y,
double *dist, int x_nrows, int y_nrows, int n_cols,
int x_nnz, int y_nnz, int *x_indptr, int *y_indptr,
int *x_indices, int *y_indices,
raft::distance::DistanceType metric,
float metric_arg);
/**
 * @brief Pairwise distances between the rows of two sparse matrices
 *        (single-precision values, int indices).
 *
 * The implementation forwards every argument to a templated helper that hands
 * the work to raft::sparse::distance::pairwiseDistance.
 *
 * NOTE(review): the indptr/indices/data triplets suggest CSR storage and the
 * pointers are presumably device pointers -- confirm against the caller.
 *
 * @param handle raft handle supplying the cuSPARSE handle, device allocator
 *        and stream used for the computation
 * @param x non-zero values of the first sparse matrix
 * @param y non-zero values of the second sparse matrix
 * @param dist output buffer receiving the computed distances (layout is
 *        decided by raft -- TODO confirm shape/ordering)
 * @param x_nrows number of rows of the first matrix
 * @param y_nrows number of rows of the second matrix
 * @param n_cols number of columns, shared by both matrices
 * @param x_nnz number of non-zero entries in the first matrix
 * @param y_nnz number of non-zero entries in the second matrix
 * @param x_indptr index-pointer array of the first matrix
 * @param y_indptr index-pointer array of the second matrix
 * @param x_indices indices array of the first matrix
 * @param y_indices indices array of the second matrix
 * @param metric distance metric to compute
 * @param metric_arg extra scalar forwarded unchanged together with the metric
 *        (presumably a metric-specific parameter -- TODO confirm)
 */
void pairwiseDistance_sparse(const raft::handle_t &handle, float *x, float *y,
float *dist, int x_nrows, int y_nrows, int n_cols,
int x_nnz, int y_nnz, int *x_indptr, int *y_indptr,
int *x_indices, int *y_indices,
raft::distance::DistanceType metric,
float metric_arg);

} // namespace Metrics
} // namespace ML
58 changes: 57 additions & 1 deletion cpp/src/metrics/pairwise_distance.cu
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@

/*
* Copyright (c) 2019-2020, NVIDIA CORPORATION.
* Copyright (c) 2019-2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -15,8 +15,10 @@
* limitations under the License.
*/

#include <raft/sparse/distance/common.h>
#include <cuml/metrics/metrics.hpp>
#include <metrics/pairwise_distance.cuh>
#include <raft/sparse/distance/distance.cuh>

namespace ML {

Expand All @@ -37,5 +39,59 @@ void pairwise_distance(const raft::handle_t &handle, const float *x,
handle.get_stream(), isRowMajor);
}

/**
 * @brief Type-generic helper behind the sparse pairwise-distance overloads.
 *
 * Fills a raft `distances_config_t` with both sparse inputs plus the cuSPARSE
 * handle, device allocator and stream taken from `handle`, then delegates to
 * `raft::sparse::distance::pairwiseDistance`, which writes into `dist`.
 *
 * Note the operand mapping: `x` populates the config's `b_*` fields and `y`
 * populates its `a_*` fields.
 * NOTE(review): presumably deliberate to match the output layout expected by
 * the caller -- confirm before changing.
 *
 * @tparam value_idx integer type of the row counts, nnz counts and index arrays
 * @tparam value_t floating-point type of the matrix values and of `dist`
 */
template <typename value_idx = int, typename value_t = float>
void pairwiseDistance_sparse(const raft::handle_t &handle, value_t *x,
                             value_t *y, value_t *dist, value_idx x_nrows,
                             value_idx y_nrows, value_idx n_cols,
                             value_idx x_nnz, value_idx y_nnz,
                             value_idx *x_indptr, value_idx *y_indptr,
                             value_idx *x_indices, value_idx *y_indices,
                             raft::distance::DistanceType metric,
                             float metric_arg) {
  raft::sparse::distance::distances_config_t<value_idx, value_t> config;

  // Execution resources borrowed from the raft handle.
  config.handle = handle.get_cusparse_handle();
  config.allocator = handle.get_device_allocator();
  config.stream = handle.get_stream();

  // Operand "a" of the raft computation is the y matrix.
  config.a_data = y;
  config.a_indices = y_indices;
  config.a_indptr = y_indptr;
  config.a_nnz = y_nnz;
  config.a_nrows = y_nrows;
  config.a_ncols = n_cols;

  // Operand "b" of the raft computation is the x matrix.
  config.b_data = x;
  config.b_indices = x_indices;
  config.b_indptr = x_indptr;
  config.b_nnz = x_nnz;
  config.b_nrows = x_nrows;
  config.b_ncols = n_cols;

  raft::sparse::distance::pairwiseDistance(dist, config, metric, metric_arg);
}

/**
 * Single-precision entry point: forwards every argument unchanged to the
 * templated helper instantiated with int indices and float values.
 */
void pairwiseDistance_sparse(const raft::handle_t &handle, float *x, float *y,
                             float *dist, int x_nrows, int y_nrows, int n_cols,
                             int x_nnz, int y_nnz, int *x_indptr, int *y_indptr,
                             int *x_indices, int *y_indices,
                             raft::distance::DistanceType metric,
                             float metric_arg) {
  pairwiseDistance_sparse<int, float>(handle,
                                      x, y, dist,
                                      x_nrows, y_nrows, n_cols,
                                      x_nnz, y_nnz,
                                      x_indptr, y_indptr,
                                      x_indices, y_indices,
                                      metric, metric_arg);
}

/**
 * Double-precision entry point: forwards every argument unchanged to the
 * templated helper instantiated with int indices and double values.
 */
void pairwiseDistance_sparse(const raft::handle_t &handle, double *x, double *y,
                             double *dist, int x_nrows, int y_nrows, int n_cols,
                             int x_nnz, int y_nnz, int *x_indptr, int *y_indptr,
                             int *x_indices, int *y_indices,
                             raft::distance::DistanceType metric,
                             float metric_arg) {
  pairwiseDistance_sparse<int, double>(handle,
                                       x, y, dist,
                                       x_nrows, y_nrows, n_cols,
                                       x_nnz, y_nnz,
                                       x_indptr, y_indptr,
                                       x_indices, y_indices,
                                       metric, metric_arg);
}
} // namespace Metrics
} // namespace ML
6 changes: 4 additions & 2 deletions python/cuml/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@
cython_mutual_info_score as mutual_info_score
from cuml.metrics.confusion_matrix import confusion_matrix
from cuml.metrics.cluster.entropy import cython_entropy as entropy
from cuml.metrics.pairwise_distances import pairwise_distances, \
PAIRWISE_DISTANCE_METRICS
from cuml.metrics.pairwise_distances import pairwise_distances
from cuml.metrics.pairwise_distances import sparse_pairwise_distances
from cuml.metrics.pairwise_distances import PAIRWISE_DISTANCE_METRICS
from cuml.metrics.pairwise_distances import PAIRWISE_DISTANCE_SPARSE_METRICS
from cuml.metrics.hinge_loss import hinge_loss
Loading

0 comments on commit 4a946e0

Please sign in to comment.