Skip to content

Commit

Permalink
Two-table comparators with strong index types (#10730)
Browse files Browse the repository at this point in the history
This PR resolves #10508. It introduces two-table lexicographic row comparators with strongly typed index types. Given tables `lhs` and `rhs`, the `two_table_comparator` can create a device comparator whose strongly typed call operator can compare bidirectionally: `lhs[i] < rhs[j]` and `rhs[i] < lhs[j]`. The strong typing indicates which index belongs to which table.

This PR also contains a sample implementation in `search_ordered.cu`, which implements `lower_bound` and `upper_bound` algorithms.

Authors:
  - Bradley Dice (https://github.com/bdice)

Approvers:
  - Nghia Truong (https://github.com/ttnghia)
  - Jake Hemstad (https://github.com/jrhemstad)

URL: #10730
  • Loading branch information
bdice authored May 18, 2022
1 parent dee435f commit 7f9d51b
Show file tree
Hide file tree
Showing 5 changed files with 233 additions and 62 deletions.
210 changes: 188 additions & 22 deletions cpp/include/cudf/table/experimental/row_operators.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@
#include <cudf/utilities/type_dispatcher.hpp>

#include <thrust/equal.h>
#include <thrust/iterator/iterator_adaptor.h>
#include <thrust/iterator/iterator_facade.h>
#include <thrust/logical.h>
#include <thrust/swap.h>
#include <thrust/transform_reduce.h>
Expand Down Expand Up @@ -69,6 +71,48 @@ struct dispatch_void_if_nested {
};

namespace row {

enum class lhs_index_type : size_type {};
enum class rhs_index_type : size_type {};

template <typename Index, typename Underlying = std::underlying_type_t<Index>>
struct strong_index_iterator : public thrust::iterator_facade<strong_index_iterator<Index>,
Index,
thrust::use_default,
thrust::random_access_traversal_tag,
Index,
Underlying> {
using super_t = thrust::iterator_adaptor<strong_index_iterator<Index>, Index>;

explicit constexpr strong_index_iterator(Underlying n) : begin{n} {}

friend class thrust::iterator_core_access;

private:
__device__ constexpr void increment() { ++begin; }
__device__ constexpr void decrement() { --begin; }

__device__ constexpr void advance(Underlying n) { begin += n; }

__device__ constexpr bool equal(strong_index_iterator<Index> const& other) const noexcept
{
return begin == other.begin;
}

__device__ constexpr Index dereference() const noexcept { return static_cast<Index>(begin); }

__device__ constexpr Underlying distance_to(
strong_index_iterator<Index> const& other) const noexcept
{
return other.begin - begin;
}

Underlying begin{};
};

using lhs_iterator = strong_index_iterator<lhs_index_type>;
using rhs_iterator = strong_index_iterator<rhs_index_type>;

namespace lexicographic {

/**
Expand All @@ -91,6 +135,8 @@ namespace lexicographic {
template <typename Nullate>
class device_row_comparator {
friend class self_comparator;
friend class two_table_comparator;

/**
* @brief Construct a function object for performing a lexicographic
* comparison between the rows of two tables.
Expand Down Expand Up @@ -183,9 +229,9 @@ class device_row_comparator {

template <typename Element,
CUDF_ENABLE_IF(not cudf::is_relationally_comparable<Element, Element>() and
not std::is_same_v<Element, cudf::struct_view>),
typename... Args>
__device__ cuda::std::pair<weak_ordering, int> operator()(Args...) const noexcept
not std::is_same_v<Element, cudf::struct_view>)>
__device__ cuda::std::pair<weak_ordering, int> operator()(size_type const,
size_type const) const noexcept
{
CUDF_UNREACHABLE("Attempted to compare elements of uncomparable types.");
}
Expand Down Expand Up @@ -234,12 +280,13 @@ class device_row_comparator {
* @brief Checks whether the row at `lhs_index` in the `lhs` table compares
* lexicographically less, greater, or equivalent to the row at `rhs_index` in the `rhs` table.
*
* @param lhs_index The index of row in the `lhs` table to examine
* @param lhs_index The index of the row in the `lhs` table to examine
* @param rhs_index The index of the row in the `rhs` table to examine
* @return weak ordering comparison of the row in the `lhs` table relative to the row in the `rhs`
* table
*/
__device__ weak_ordering operator()(size_type lhs_index, size_type rhs_index) const noexcept
__device__ weak_ordering operator()(size_type const lhs_index,
size_type const rhs_index) const noexcept
{
int last_null_depth = std::numeric_limits<int>::max();
for (size_type i = 0; i < _lhs.num_columns(); ++i) {
Expand Down Expand Up @@ -288,12 +335,14 @@ class device_row_comparator {
*/
template <typename Comparator, weak_ordering... values>
struct weak_ordering_comparator_impl {
__device__ bool operator()(size_type const lhs, size_type const rhs) const noexcept
template <typename LhsType, typename RhsType>
__device__ constexpr bool operator()(LhsType const lhs_index,
RhsType const rhs_index) const noexcept
{
weak_ordering const result = comparator(lhs, rhs);
weak_ordering const result = comparator(lhs_index, rhs_index);
return ((result == values) || ...);
}
Comparator comparator;
Comparator const comparator;
};

/**
Expand All @@ -302,14 +351,12 @@ struct weak_ordering_comparator_impl {
*
* @tparam Nullate A cudf::nullate type describing whether to check for nulls.
*/
template <typename Nullate>
using less_comparator =
weak_ordering_comparator_impl<device_row_comparator<Nullate>, weak_ordering::LESS>;
template <typename Comparator>
using less_comparator = weak_ordering_comparator_impl<Comparator, weak_ordering::LESS>;

template <typename Nullate>
using less_equivalent_comparator = weak_ordering_comparator_impl<device_row_comparator<Nullate>,
weak_ordering::LESS,
weak_ordering::EQUIVALENT>;
template <typename Comparator>
using less_equivalent_comparator =
weak_ordering_comparator_impl<Comparator, weak_ordering::LESS, weak_ordering::EQUIVALENT>;

struct preprocessed_table {
using table_device_view_owner =
Expand All @@ -319,7 +366,7 @@ struct preprocessed_table {
* @brief Preprocess table for use with lexicographical comparison
*
* Sets up the table for use with lexicographical comparison. The resulting preprocessed table can
* be passed to the constructor of `lex::self_comparator` to avoid preprocessing again.
* be passed to the constructor of `lexicographic::self_comparator` to avoid preprocessing again.
*
* @param table The table to preprocess
* @param column_order Optional, host array the same length as a row that indicates the desired
Expand All @@ -337,6 +384,7 @@ struct preprocessed_table {

private:
friend class self_comparator;
friend class two_table_comparator;

preprocessed_table(table_device_view_owner&& table,
rmm::device_uvector<order>&& column_order,
Expand Down Expand Up @@ -395,10 +443,10 @@ struct preprocessed_table {
}

private:
table_device_view_owner _t;
rmm::device_uvector<order> _column_order;
rmm::device_uvector<null_order> _null_precedence;
rmm::device_uvector<size_type> _depths;
table_device_view_owner const _t;
rmm::device_uvector<order> const _column_order;
rmm::device_uvector<null_order> const _null_precedence;
rmm::device_uvector<size_type> const _depths;
};

/**
Expand Down Expand Up @@ -459,16 +507,134 @@ class self_comparator {
* @tparam Nullate A cudf::nullate type describing whether to check for nulls.
*/
template <typename Nullate>
less_comparator<Nullate> device_comparator(Nullate nullate = {}) const
less_comparator<device_row_comparator<Nullate>> device_comparator(Nullate nullate = {}) const
{
return less_comparator<Nullate>{device_row_comparator<Nullate>(
return less_comparator<device_row_comparator<Nullate>>{device_row_comparator<Nullate>(
nullate, *d_t, *d_t, d_t->depths(), d_t->column_order(), d_t->null_precedence())};
}

private:
std::shared_ptr<preprocessed_table> d_t;
};

template <typename Comparator>
struct strong_index_comparator_adapter {
__device__ constexpr weak_ordering operator()(lhs_index_type const lhs_index,
rhs_index_type const rhs_index) const noexcept
{
return comparator(static_cast<cudf::size_type>(lhs_index),
static_cast<cudf::size_type>(rhs_index));
}

__device__ constexpr weak_ordering operator()(rhs_index_type const rhs_index,
lhs_index_type const lhs_index) const noexcept
{
auto const left_right_ordering =
comparator(static_cast<cudf::size_type>(lhs_index), static_cast<cudf::size_type>(rhs_index));

// Invert less/greater values to reflect right to left ordering
if (left_right_ordering == weak_ordering::LESS) {
return weak_ordering::GREATER;
} else if (left_right_ordering == weak_ordering::GREATER) {
return weak_ordering::LESS;
}
return weak_ordering::EQUIVALENT;
}

Comparator const comparator;
};

/**
* @brief An owning object that can be used to lexicographically compare rows of two different
* tables
*
* This class takes two table_views and preprocesses certain columns to allow for lexicographical
* comparison. The preprocessed table and temporary data required for the comparison are created and
* owned by this class.
*
* Alternatively, `two_table_comparator` can be constructed from two existing
* `shared_ptr<preprocessed_table>`s when sharing the same tables among multiple comparators.
*
* This class can then provide a functor object that can used on the device.
* The object of this class must outlive the usage of the device functor.
*/
class two_table_comparator {
public:
/**
* @brief Construct an owning object for performing a lexicographic comparison between rows of
* two different tables.
*
* The left and right table are expected to have the same number of columns
* and data types for each column.
*
* @param left The left table to compare
* @param right The right table to compare
* @param column_order Optional, host array the same length as a row that indicates the desired
* ascending/descending order of each column in a row. If empty, it is assumed all columns are
* sorted in ascending order.
* @param null_precedence Optional, device array the same length as a row and indicates how null
* values compare to all other for every column. If empty, then null precedence would be
* `null_order::BEFORE` for all columns.
* @param stream The stream to construct this object on. Not the stream that will be used for
* comparisons using this object.
*/
two_table_comparator(table_view const& left,
table_view const& right,
host_span<order const> column_order = {},
host_span<null_order const> null_precedence = {},
rmm::cuda_stream_view stream = rmm::cuda_stream_default);

/**
* @brief Construct an owning object for performing a lexicographic comparison between two rows of
* the same preprocessed table.
*
* This constructor allows independently constructing a `preprocessed_table` and sharing it among
* multiple comparators.
*
* @param left A table preprocessed for lexicographic comparison
* @param right A table preprocessed for lexicographic comparison
*/
two_table_comparator(std::shared_ptr<preprocessed_table> left,
std::shared_ptr<preprocessed_table> right)
: d_left_table{std::move(left)}, d_right_table{std::move(right)}
{
}

/**
* @brief Return the binary operator for comparing rows in the table.
*
* Returns a binary callable, `F`, with signatures
* `bool F(lhs_index_type, rhs_index_type)` and
* `bool F(rhs_index_type, lhs_index_type)`.
*
* `F(lhs_index_type i, rhs_index_type j)` returns true if and only if row
* `i` of the left table compares lexicographically less than row `j` of the
* right table.
*
* Similarly, `F(rhs_index_type i, lhs_index_type j)` returns true if and
* only if row `i` of the right table compares lexicographically less than row
* `j` of the left table.
*
* @tparam Nullate A cudf::nullate type describing whether to check for nulls.
*/
template <typename Nullate>
less_comparator<strong_index_comparator_adapter<device_row_comparator<Nullate>>>
device_comparator(Nullate nullate = {}) const
{
return less_comparator<strong_index_comparator_adapter<device_row_comparator<Nullate>>>{
device_row_comparator<Nullate>(nullate,
*d_left_table,
*d_right_table,
d_left_table->depths(),
d_left_table->column_order(),
d_left_table->null_precedence())};
}

private:
std::shared_ptr<preprocessed_table> d_left_table;
std::shared_ptr<preprocessed_table> d_right_table;
};

} // namespace lexicographic

namespace hash {
Expand Down
2 changes: 1 addition & 1 deletion cpp/include/cudf/table/row_operators.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,7 @@ class row_lexicographic_comparator {
* @brief Checks whether the row at `lhs_index` in the `lhs` table compares
* lexicographically less than the row at `rhs_index` in the `rhs` table.
*
* @param lhs_index The index of row in the `lhs` table to examine
* @param lhs_index The index of the row in the `lhs` table to examine
* @param rhs_index The index of the row in the `rhs` table to examine
* @return `true` if row from the `lhs` table compares less than row in the
* `rhs` table
Expand Down
1 change: 1 addition & 0 deletions cpp/src/search/contains.cu
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include <cudf/scalar/scalar_device_view.cuh>
#include <cudf/search.hpp>
#include <cudf/structs/detail/contains.hpp>
#include <cudf/table/experimental/row_operators.cuh>
#include <cudf/table/row_operators.cuh>
#include <cudf/table/table_device_view.cuh>
#include <cudf/table/table_view.hpp>
Expand Down
Loading

0 comments on commit 7f9d51b

Please sign in to comment.