diff --git a/cpp/include/cudf/table/experimental/row_operators.cuh b/cpp/include/cudf/table/experimental/row_operators.cuh index 2ed45c71633..eb5be4287e2 100644 --- a/cpp/include/cudf/table/experimental/row_operators.cuh +++ b/cpp/include/cudf/table/experimental/row_operators.cuh @@ -22,6 +22,7 @@ #include #include #include +#include #include #include #include @@ -89,6 +90,7 @@ namespace lexicographic { template class device_row_comparator { friend class self_comparator; + // friend class two_table_device_row_comparator_adapter; /** * @brief Construct a function object for performing a lexicographic @@ -277,7 +279,7 @@ struct preprocessed_table { * @brief Preprocess table for use with lexicographical comparison * * Sets up the table for use with lexicographical comparison. The resulting preprocessed table can - * be passed to the constructor of `lex::self_comparator` to avoid preprocessing again. + * be passed to the constructor of `lexicographic::self_comparator` to avoid preprocessing again. * * @param table The table to preprocess * @param column_order Optional, host array the same length as a row that indicates the desired @@ -427,6 +429,162 @@ class self_comparator { std::shared_ptr d_t; }; +template +class two_table_device_row_comparator_adapter { + friend class two_table_comparator; + + public: + /** + * @brief Checks whether the row at `lhs_index` in the `lhs` table compares + * lexicographically less than the row at `rhs_index` in the `rhs` table. 
+ * + * @param lhs_index The index of row in the `lhs` table to examine + * @param rhs_index The index of the row in the `rhs` table to examine + * @return `true` if row from the `lhs` table compares less than row in the `rhs` table + */ + __device__ bool operator()(lhs_index_type const lhs_index, + rhs_index_type const rhs_index) const noexcept + { + return comp(static_cast(lhs_index), static_cast(rhs_index)); + } + + /** + * @brief Checks whether the row at `lhs_index` in the `lhs` table does not compare + * lexicographically less than the row at `rhs_index` in the `rhs` table (so equal rows yield `true`). + * + * @param rhs_index The index of the row in the `rhs` table to examine + * @param lhs_index The index of the row in the `lhs` table to examine + * @return `true` unless the row from the `lhs` table compares less than the row in the `rhs` table + */ + __device__ bool operator()(rhs_index_type const rhs_index, + lhs_index_type const lhs_index) const noexcept + { + // TODO: "not lhs < rhs" isn't quite the same as "rhs < lhs". The case of + // equality returns true for operator(rhs, lhs), while operator(lhs, rhs) + // returns false. This would have to be handled at a lower level, if it + // matters. Do we just document that this means "rhs <= lhs"? + return not comp(static_cast(lhs_index), + static_cast(rhs_index)); + } + + private: + /** + * @brief Construct a function object for performing a lexicographic + * comparison between the rows of two tables with strongly typed table index + * types. + * + * @param check_nulls Indicates if either input table contains columns with nulls. + * @param lhs The first table + * @param rhs The second table (may be the same table as `lhs`) + * @param depth Optional, device array the same length as a row that contains starting depths of + * columns if they're nested, and 0 otherwise. + * @param column_order Optional, device array the same length as a row that indicates the desired + * ascending/descending order of each column in a row. 
If `nullopt`, it is assumed all columns are + * sorted in ascending order. + * @param null_precedence Optional, device array the same length as a row and indicates how null + * values compare to all other for every column. If `nullopt`, then null precedence would be + * `null_order::BEFORE` for all columns. + */ + two_table_device_row_comparator_adapter( + Nullate check_nulls, + table_device_view lhs, + table_device_view rhs, + std::optional> depth = std::nullopt, + std::optional> column_order = std::nullopt, + std::optional> null_precedence = std::nullopt) + : comp{check_nulls, lhs, rhs, depth, column_order, null_precedence} + { + } + + device_row_comparator comp; +}; + +/** + * @brief An owning object that can be used to lexicographically compare rows of two different + * tables + * + * This class takes two table_views and preprocesses certain columns to allow for lexicographical + * comparison. The preprocessed tables and temporary data required for the comparison are created and + * owned by this class. + * + * Alternatively, `two_table_comparator` can be constructed from two existing + * `shared_ptr<preprocessed_table>`s when sharing the same tables among multiple comparators. + * + * This class can then provide a functor object that can be used on the device. + * The object of this class must outlive the usage of the device functor. + */ +class two_table_comparator { + public: + /** + * @brief Construct an owning object for performing a lexicographic comparison between rows of + * two different tables. + * + * The left and right table are expected to have the same number of columns + * and data types for each column. + * + * @param left The left table to compare + * @param right The right table to compare + * @param column_order Optional, host array the same length as a row that indicates the desired + * ascending/descending order of each column in a row. If empty, it is assumed all columns are + * sorted in ascending order. 
+ * @param null_precedence Optional, host array the same length as a row that indicates how null + * values compare to all others for every column. If empty, then null precedence would be + * `null_order::BEFORE` for all columns. + * @param stream The stream to construct this object on. Not the stream that will be used for + * comparisons using this object. + */ + two_table_comparator(table_view const& left, + table_view const& right, + host_span column_order = {}, + host_span null_precedence = {}, + rmm::cuda_stream_view stream = rmm::cuda_stream_default) + : d_left_table{preprocessed_table::create(left, column_order, null_precedence, stream)}, + d_right_table{preprocessed_table::create(right, column_order, null_precedence, stream)} + { + } + + /** + * @brief Construct an owning object for performing a lexicographic comparison between rows of + * two preprocessed tables. + * + * This constructor allows independently constructing the `preprocessed_table`s and sharing them among + * multiple comparators. + * + * @param left The left table, preprocessed for lexicographic comparison + * @param right The right table, preprocessed for lexicographic comparison + */ + two_table_comparator(std::shared_ptr left, + std::shared_ptr right) + : d_left_table{std::move(left)}, d_right_table{std::move(right)} + { + } + + /** + * @brief Return the binary operator for comparing rows of the left and right tables. + * + * Returns a binary callable, `F`, with signature `bool F(lhs_index_type, rhs_index_type)`. + * + * `F(i,j)` returns true if and only if row `i` of the left table compares + * lexicographically less than row `j` of the right table. Depths, column order and null precedence are taken from the left table. + * + * @tparam Nullate A cudf::nullate type describing whether to check for nulls. 
+ */ + template + two_table_device_row_comparator_adapter device_comparator(Nullate nullate = {}) const + { + return two_table_device_row_comparator_adapter(nullate, + *d_left_table, + *d_right_table, + d_left_table->depths(), + d_left_table->column_order(), + d_left_table->null_precedence()); + } + + private: + std::shared_ptr d_left_table; + std::shared_ptr d_right_table; +}; + } // namespace lexicographic namespace hash {