Skip to content

Commit

Permalink
Add two table comparator and adapter.
Browse files Browse the repository at this point in the history
  • Loading branch information
bdice committed May 3, 2022
1 parent d67f17e commit 464ed2b
Showing 1 changed file with 159 additions and 1 deletion.
160 changes: 159 additions & 1 deletion cpp/include/cudf/table/experimental/row_operators.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include <cudf/detail/utilities/algorithm.cuh>
#include <cudf/detail/utilities/assert.cuh>
#include <cudf/detail/utilities/hash_functions.cuh>
#include <cudf/detail/utilities/strong_index.hpp>
#include <cudf/lists/list_device_view.cuh>
#include <cudf/lists/lists_column_device_view.cuh>
#include <cudf/sorting.hpp>
Expand Down Expand Up @@ -89,6 +90,7 @@ namespace lexicographic {
template <typename Nullate>
class device_row_comparator {
friend class self_comparator;
// friend class two_table_device_row_comparator_adapter<Nullate>;

/**
* @brief Construct a function object for performing a lexicographic
Expand Down Expand Up @@ -277,7 +279,7 @@ struct preprocessed_table {
* @brief Preprocess table for use with lexicographical comparison
*
* Sets up the table for use with lexicographical comparison. The resulting preprocessed table can
* be passed to the constructor of `lex::self_comparator` to avoid preprocessing again.
* be passed to the constructor of `lexicographic::self_comparator` to avoid preprocessing again.
*
* @param table The table to preprocess
* @param column_order Optional, host array the same length as a row that indicates the desired
Expand Down Expand Up @@ -427,6 +429,162 @@ class self_comparator {
std::shared_ptr<preprocessed_table> d_t;
};

/**
 * @brief Device functor that lexicographically compares a row of one table against a row of
 * another table, using strongly typed row indices to identify which table each index refers to.
 *
 * Both argument orders are supported, and both implement a strict "less than":
 * - `(lhs_index_type, rhs_index_type)` answers "is the `lhs` row less than the `rhs` row?"
 * - `(rhs_index_type, lhs_index_type)` answers "is the `rhs` row less than the `lhs` row?"
 *
 * Two underlying comparators are stored, one per direction. The reversed direction cannot be
 * derived by negating the forward result: `not (lhs < rhs)` is `rhs <= lhs`, which would report
 * `true` for equal rows and break the irreflexivity expected of a "less than" predicate.
 *
 * @tparam Nullate A cudf::nullate type describing whether to check for nulls.
 */
template <typename Nullate>
class two_table_device_row_comparator_adapter {
  friend class two_table_comparator;

 public:
  /**
   * @brief Checks whether the row at `lhs_index` in the `lhs` table compares
   * lexicographically less than the row at `rhs_index` in the `rhs` table.
   *
   * @param lhs_index The index of row in the `lhs` table to examine
   * @param rhs_index The index of the row in the `rhs` table to examine
   * @return `true` if row from the `lhs` table compares less than row in the `rhs` table
   */
  __device__ bool operator()(lhs_index_type const lhs_index,
                             rhs_index_type const rhs_index) const noexcept
  {
    return comp(static_cast<cudf::size_type>(lhs_index), static_cast<cudf::size_type>(rhs_index));
  }

  /**
   * @brief Checks whether the row at `rhs_index` in the `rhs` table compares
   * lexicographically less than the row at `lhs_index` in the `lhs` table.
   *
   * @param rhs_index The index of row in the `rhs` table to examine
   * @param lhs_index The index of the row in the `lhs` table to examine
   * @return `true` if row from the `rhs` table compares less than row in the `lhs` table
   */
  __device__ bool operator()(rhs_index_type const rhs_index,
                             lhs_index_type const lhs_index) const noexcept
  {
    // Use the comparator built with the tables swapped so that equal rows yield `false`,
    // exactly as they do for the (lhs, rhs) overload.
    return reversed_comp(static_cast<cudf::size_type>(rhs_index),
                         static_cast<cudf::size_type>(lhs_index));
  }

 private:
  /**
   * @brief Construct a function object for performing a lexicographic
   * comparison between the rows of two tables with strongly typed table index
   * types.
   *
   * @param check_nulls Indicates if either input table contains columns with nulls.
   * @param lhs The first table
   * @param rhs The second table (may be the same table as `lhs`)
   * @param depth Optional, device array the same length as a row that contains starting depths of
   * columns if they're nested, and 0 otherwise.
   * @param column_order Optional, device array the same length as a row that indicates the desired
   * ascending/descending order of each column in a row. If `nullopt`, it is assumed all columns are
   * sorted in ascending order.
   * @param null_precedence Optional, device array the same length as a row and indicates how null
   * values compare to all other for every column. If `nullopt`, then null precedence would be
   * `null_order::BEFORE` for all columns.
   */
  two_table_device_row_comparator_adapter(
    Nullate check_nulls,
    table_device_view lhs,
    table_device_view rhs,
    std::optional<device_span<int const>> depth                  = std::nullopt,
    std::optional<device_span<order const>> column_order         = std::nullopt,
    std::optional<device_span<null_order const>> null_precedence = std::nullopt)
    : comp{check_nulls, lhs, rhs, depth, column_order, null_precedence},
      reversed_comp{check_nulls, rhs, lhs, depth, column_order, null_precedence}
  {
  }

  // Compares a row of `lhs` (first index) against a row of `rhs` (second index).
  device_row_comparator<Nullate> comp;
  // Same configuration with the tables swapped: compares a row of `rhs` (first index) against a
  // row of `lhs` (second index).
  device_row_comparator<Nullate> reversed_comp;
};

/**
 * @brief Owning, host-side object used to lexicographically compare rows of two different tables.
 *
 * Given two table_views, this class preprocesses certain columns so that their rows can be
 * compared lexicographically. The preprocessed tables and any temporary data needed for the
 * comparison are created and owned by this class.
 *
 * When the same tables are shared among multiple comparators, a `two_table_comparator` can
 * instead be built from two existing `shared_ptr<preprocessed_table>`s, which avoids
 * preprocessing them again.
 *
 * The class hands out a functor object that can be used on the device. This object must outlive
 * every use of that device functor.
 */
class two_table_comparator {
 public:
  /**
   * @brief Build a comparator for rows of two different tables, preprocessing both tables.
   *
   * The left and right table are expected to have the same number of columns
   * and data types for each column.
   *
   * @param left The left table to compare
   * @param right The right table to compare
   * @param column_order Optional, host array the same length as a row that indicates the desired
   * ascending/descending order of each column in a row. If empty, it is assumed all columns are
   * sorted in ascending order.
   * @param null_precedence Optional, host array the same length as a row that indicates how null
   * values compare to all other values for every column. If empty, then null precedence would be
   * `null_order::BEFORE` for all columns.
   * @param stream The stream to construct this object on. Not the stream that will be used for
   * comparisons using this object.
   */
  two_table_comparator(table_view const& left,
                       table_view const& right,
                       host_span<order const> column_order         = {},
                       host_span<null_order const> null_precedence = {},
                       rmm::cuda_stream_view stream                = rmm::cuda_stream_default)
    : d_left_table(preprocessed_table::create(left, column_order, null_precedence, stream)),
      d_right_table(preprocessed_table::create(right, column_order, null_precedence, stream))
  {
  }

  /**
   * @brief Build a comparator for rows of two tables that have already been preprocessed.
   *
   * This constructor allows `preprocessed_table`s to be constructed independently and shared
   * among multiple comparators.
   *
   * @param left The left table, preprocessed for lexicographic comparison
   * @param right The right table, preprocessed for lexicographic comparison
   */
  two_table_comparator(std::shared_ptr<preprocessed_table> left,
                       std::shared_ptr<preprocessed_table> right)
    : d_left_table(std::move(left)), d_right_table(std::move(right))
  {
  }

  /**
   * @brief Return the binary operator for comparing rows of the two tables.
   *
   * Returns a binary callable, `F`, with signature `bool F(lhs_index_type, rhs_index_type)`.
   *
   * `F(i,j)` returns true if and only if row `i` of the left table compares
   * lexicographically less than row `j` of the right table.
   *
   * @tparam Nullate A cudf::nullate type describing whether to check for nulls.
   * @param nullate Value of the nullate type forwarded to the device comparator
   * @return The device functor comparing rows of the left table against rows of the right table
   */
  template <typename Nullate>
  two_table_device_row_comparator_adapter<Nullate> device_comparator(Nullate nullate = {}) const
  {
    using adapter_type = two_table_device_row_comparator_adapter<Nullate>;

    // NOTE(review): depths, column order and null precedence are read from the left table only;
    // presumably both preprocessed tables are expected to agree on them -- confirm.
    return adapter_type(nullate,
                        *d_left_table,
                        *d_right_table,
                        d_left_table->depths(),
                        d_left_table->column_order(),
                        d_left_table->null_precedence());
  }

 private:
  std::shared_ptr<preprocessed_table> d_left_table;   // preprocessed left-hand table
  std::shared_ptr<preprocessed_table> d_right_table;  // preprocessed right-hand table
};

} // namespace lexicographic

namespace hash {
Expand Down

0 comments on commit 464ed2b

Please sign in to comment.