Skip to content

Commit

Permalink
Move shared lhs/rhs logic into launch_search.
Browse files Browse the repository at this point in the history
  • Loading branch information
bdice committed May 3, 2022
1 parent 1fd199d commit 18bd9f0
Showing 1 changed file with 30 additions and 53 deletions.
83 changes: 30 additions & 53 deletions cpp/src/search/search.cu
Original file line number Diff line number Diff line change
Expand Up @@ -45,37 +45,6 @@

namespace cudf {
namespace {
template <typename DataIterator,
typename ValuesIterator,
typename OutputIterator,
typename Comparator>
void launch_search(DataIterator it_data,
ValuesIterator it_vals,
size_type data_size,
size_type values_size,
OutputIterator it_output,
Comparator comp,
bool find_first,
rmm::cuda_stream_view stream)
{
if (find_first) {
thrust::lower_bound(rmm::exec_policy(stream),
it_data,
it_data + data_size,
it_vals,
it_vals + values_size,
it_output,
comp);
} else {
thrust::upper_bound(rmm::exec_policy(stream),
it_data,
it_data + data_size,
it_vals,
it_vals + values_size,
it_output,
comp);
}
}

struct make_lhs_index {
__device__ lhs_index_type operator()(size_type i) const { return static_cast<lhs_index_type>(i); }
Expand All @@ -95,6 +64,35 @@ auto make_rhs_index_counting_iterator(size_type start)
return cudf::detail::make_counting_transform_iterator(start, make_rhs_index{});
};

template <typename OutputIterator, typename Comparator>
void launch_search(size_type search_table_size,
size_type values_size,
OutputIterator it_output,
Comparator comp,
bool find_first,
rmm::cuda_stream_view stream)
{
auto const it_lhs = cudf::make_lhs_index_counting_iterator(0);
auto const it_rhs = cudf::make_rhs_index_counting_iterator(0);
if (find_first) {
thrust::lower_bound(rmm::exec_policy(stream),
it_lhs,
it_lhs + search_table_size,
it_rhs,
it_rhs + values_size,
it_output,
comp);
} else {
thrust::upper_bound(rmm::exec_policy(stream),
it_rhs,
it_rhs + search_table_size,
it_lhs,
it_lhs + values_size,
it_output,
comp);
}
}

std::unique_ptr<column> search_ordered(table_view const& t,
table_view const& values,
bool find_first,
Expand Down Expand Up @@ -135,28 +133,7 @@ std::unique_ptr<column> search_ordered(table_view const& t,
auto const has_null_elements = has_nested_nulls(lhs) or has_nested_nulls(rhs);
auto const d_comparator = comparator.device_comparator(nullate::DYNAMIC{has_null_elements});

auto const left_it = cudf::make_lhs_index_counting_iterator(0);
auto const right_it = cudf::make_rhs_index_counting_iterator(0);

if (find_first) {
launch_search(left_it,
right_it,
t.num_rows(),
values.num_rows(),
result_out,
d_comparator,
find_first,
stream);
} else {
launch_search(right_it,
left_it,
t.num_rows(),
values.num_rows(),
result_out,
d_comparator,
find_first,
stream);
}
launch_search(t.num_rows(), values.num_rows(), result_out, d_comparator, find_first, stream);
/*
// Prepare to flatten the structs column
auto const has_null_elements = has_nested_nulls(t) or has_nested_nulls(values);
Expand Down

0 comments on commit 18bd9f0

Please sign in to comment.