Skip to content

Commit

Permalink
Split ivf_flat_search and interleaved_scan
Browse files Browse the repository at this point in the history
Also reduce specialization on veclen from 1, 2, .. 16/sizeof(T) to 1,
16/sizeof(T).
  • Loading branch information
ahendriksen committed Apr 18, 2023
1 parent 4d26ca9 commit 2aebe4b
Show file tree
Hide file tree
Showing 20 changed files with 1,591 additions and 1,110 deletions.
22 changes: 15 additions & 7 deletions cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -347,10 +347,14 @@ if(RAFT_COMPILE_LIBRARY)
src/distance/specializations/fused_l2_nn_double_int64.cu
src/distance/specializations/fused_l2_nn_float_int.cu
src/distance/specializations/fused_l2_nn_float_int64.cu
src/matrix/specializations/detail/select_k_float_uint32_t.cu
src/matrix/specializations/detail/select_k_float_int64_t.cu
src/matrix/specializations/detail/select_k_half_uint32_t.cu
src/matrix/specializations/detail/select_k_half_int64_t.cu
src/matrix/detail/select_k_float_uint32_t.cu
src/matrix/detail/select_k_float_uint64_t.cu
src/matrix/detail/select_k_half_uint32_t.cu
src/matrix/detail/select_k_half_uint64_t.cu
# src/matrix/specializations/detail/select_k_float_uint32_t.cu
# src/matrix/specializations/detail/select_k_float_int64_t.cu
# src/matrix/specializations/detail/select_k_half_uint32_t.cu
# src/matrix/specializations/detail/select_k_half_int64_t.cu
src/neighbors/ivfpq_build.cu
src/neighbors/ivfpq_deserialize.cu
src/neighbors/ivfpq_serialize.cu
Expand Down Expand Up @@ -395,6 +399,10 @@ if(RAFT_COMPILE_LIBRARY)
src/random/rmat_rectangular_generator_int64_double.cu
src/random/rmat_rectangular_generator_int_float.cu
src/random/rmat_rectangular_generator_int64_float.cu
src/neighbors/detail/ivf_flat_interleaved_scan_float_float_int64_t.cu
src/neighbors/detail/ivf_flat_interleaved_scan_int8_t_int32_t_int64_t.cu
src/neighbors/detail/ivf_flat_interleaved_scan_uint8_t_uint32_t_int64_t.cu
src/neighbors/detail/ivf_flat_search.cu
# src/neighbors/specializations/detail/ball_cover_lowdim_pass_one_2d.cu
# src/neighbors/specializations/detail/ball_cover_lowdim_pass_two_2d.cu
# src/neighbors/specializations/detail/ball_cover_lowdim_pass_one_3d.cu
Expand All @@ -414,9 +422,9 @@ if(RAFT_COMPILE_LIBRARY)
src/neighbors/specializations/ivfflat_extend_float_int64_t.cu
src/neighbors/specializations/ivfflat_extend_int8_t_int64_t.cu
src/neighbors/specializations/ivfflat_extend_uint8_t_int64_t.cu
src/neighbors/specializations/ivfflat_search_float_int64_t.cu
src/neighbors/specializations/ivfflat_search_int8_t_int64_t.cu
src/neighbors/specializations/ivfflat_search_uint8_t_int64_t.cu
# src/neighbors/specializations/ivfflat_search_float_int64_t.cu
# src/neighbors/specializations/ivfflat_search_int8_t_int64_t.cu
# src/neighbors/specializations/ivfflat_search_uint8_t_int64_t.cu
src/neighbors/ivfpq_build.cu
src/neighbors/ivfpq_deserialize.cu
src/neighbors/ivfpq_serialize.cu
Expand Down
1 change: 1 addition & 0 deletions cpp/include/raft/core/mdarray.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include <stddef.h>

#include <raft/core/detail/macros.hpp>
#include <raft/core/device_resources.hpp>
#include <raft/core/host_device_accessor.hpp>
#include <raft/core/mdspan.hpp>
#include <raft/core/mdspan_types.hpp>
Expand Down
98 changes: 98 additions & 0 deletions cpp/include/raft/matrix/detail/select_k-ext.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
/*
* Copyright (c) 2022-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <cstdint>
#include <cuda_fp16.h>
#include <raft/util/raft_explicit.hpp>
#include <rmm/cuda_stream_view.hpp>
#include <rmm/mr/device/device_memory_resource.hpp>

#ifdef RAFT_EXPLICIT_INSTANTIATE

namespace raft::matrix::detail {

/**
* Select k smallest or largest key/values from each row in the input data.
*
* If you think of the input data `in_val` as a row-major matrix with `len` columns and
* `batch_size` rows, then this function selects `k` smallest/largest values in each row and fills
* in the row-major matrix `out_val` of size (batch_size, k).
*
* @tparam T
* the type of the keys (what is being compared).
* @tparam IdxT
* the index type (what is being selected together with the keys).
*
* @param[in] in_val
* contiguous device array of inputs of size (len * batch_size);
* these are compared and selected.
* @param[in] in_idx
* contiguous device array of inputs of size (len * batch_size);
* typically, these are indices of the corresponding in_val.
* @param batch_size
* number of input rows, i.e. the batch size.
* @param len
* length of a single input array (row); also sometimes referred as n_cols.
* Invariant: len >= k.
* @param k
* the number of outputs to select in each input row.
* @param[out] out_val
* contiguous device array of outputs of size (k * batch_size);
* the k smallest/largest values from each row of the `in_val`.
* @param[out] out_idx
* contiguous device array of outputs of size (k * batch_size);
* the payload selected together with `out_val`.
* @param select_min
* whether to select k smallest (true) or largest (false) keys.
* @param stream
* @param mr an optional memory resource to use across the calls (you can provide a large enough
* memory pool here to avoid memory allocations within the call).
*/
template <typename T, typename IdxT>
void select_k(const T* in_val,
const IdxT* in_idx,
size_t batch_size,
size_t len,
int k,
T* out_val,
IdxT* out_idx,
bool select_min,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr = nullptr) RAFT_EXPLICIT;
} // namespace raft::matrix::detail

#endif // RAFT_EXPLICIT_INSTANTIATE

#define instantiate_raft_matrix_detail_select_k(T, IdxT) \
extern template void raft::matrix::detail::select_k(const T* in_val, \
const IdxT* in_idx, \
size_t batch_size, \
size_t len, \
int k, \
T* out_val, \
IdxT* out_idx, \
bool select_min, \
rmm::cuda_stream_view stream, \
rmm::mr::device_memory_resource* mr)

instantiate_raft_matrix_detail_select_k(__half, uint32_t);
instantiate_raft_matrix_detail_select_k(__half, int64_t);
instantiate_raft_matrix_detail_select_k(float, int64_t);
instantiate_raft_matrix_detail_select_k(float, uint32_t);

#undef instantiate_raft_matrix_detail_select_k
25 changes: 25 additions & 0 deletions cpp/include/raft/matrix/detail/select_k.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
/*
* Copyright (c) 2022-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#if !defined(RAFT_EXPLICIT_INSTANTIATE)
#include "select_k-inl.cuh"
#endif

#ifdef RAFT_COMPILED
#include "select_k-ext.cuh"
#endif
2 changes: 1 addition & 1 deletion cpp/include/raft/matrix/detail/select_warpsort.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
#include <functional>
#include <type_traits>

#include <rmm/device_vector.hpp>
#include <rmm/device_uvector.hpp>
#include <rmm/mr/device/device_memory_resource.hpp>

/*
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
/*
* Copyright (c) 2022-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <cstdint> // uintX_t
#include <raft/neighbors/ivf_flat_types.hpp> // index
#include <raft/spatial/knn/detail/ann_utils.cuh> // TODO: consider remove
#include <raft/util/raft_explicit.hpp> // RAFT_EXPLICIT
#include <rmm/cuda_stream_view.hpp> // rmm:cuda_stream_view

#ifdef RAFT_EXPLICIT_INSTANTIATE

namespace raft::neighbors::ivf_flat::detail {

using namespace raft::spatial::knn::detail; // NOLINT

/**
* @brief Configure and launch an appropriate template instance of the interleaved scan kernel.
*
* @tparam T value type
* @tparam AccT accumulated type
* @tparam IdxT type of the indices
*
* @param index previously built ivf-flat index
* @param[in] queries device pointer to the query vectors [batch_size, dim]
* @param[in] coarse_query_results device pointer to the cluster (list) ids [batch_size, n_probes]
* @param n_queries batch size
* @param metric type of the measured distance
* @param n_probes number of nearest clusters to query
* @param k number of nearest neighbors.
* NB: the maximum value of `k` is limited statically by `kMaxCapacity`.
* @param select_min whether to select nearest (true) or furthest (false) points w.r.t. the given
* metric.
* @param[out] neighbors device pointer to the result indices for each query and cluster
* [batch_size, grid_dim_x, k]
* @param[out] distances device pointer to the result distances for each query and cluster
* [batch_size, grid_dim_x, k]
* @param[inout] grid_dim_x number of blocks launched across all n_probes clusters;
* (one block processes one or more probes, hence: 1 <= grid_dim_x <= n_probes)
* @param stream
*/
template <typename T, typename AccT, typename IdxT>
void ivfflat_interleaved_scan(const raft::neighbors::ivf_flat::index<T, IdxT>& index,
const T* queries,
const uint32_t* coarse_query_results,
const uint32_t n_queries,
const raft::distance::DistanceType metric,
const uint32_t n_probes,
const uint32_t k,
const bool select_min,
IdxT* neighbors,
float* distances,
uint32_t& grid_dim_x,
rmm::cuda_stream_view stream) RAFT_EXPLICIT;

} // namespace raft::neighbors::ivf_flat::detail

#endif // RAFT_EXPLICIT_INSTANTIATE

#define instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan(T, AccT, IdxT) \
extern template void raft::neighbors::ivf_flat::detail::ivfflat_interleaved_scan<T, AccT, IdxT>( \
const raft::neighbors::ivf_flat::index<T, IdxT>& index, \
const T* queries, \
const uint32_t* coarse_query_results, \
const uint32_t n_queries, \
const raft::distance::DistanceType metric, \
const uint32_t n_probes, \
const uint32_t k, \
const bool select_min, \
IdxT* neighbors, \
float* distances, \
uint32_t& grid_dim_x, \
rmm::cuda_stream_view stream)

instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan(float, float, int64_t);
instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan(int8_t, int32_t, int64_t);
instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan(uint8_t, uint32_t, int64_t);

#undef instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan
Loading

0 comments on commit 2aebe4b

Please sign in to comment.